Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/world_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ def __init__(
model_config_overrides: Optional[Dict] = None,
device=None,
dtype=torch.bfloat16,
load_weights: bool = True
load_weights: bool = True,
cpu_offload: bool = False,
):
"""
model_uri: HF URI or local folder containing model.safetensors and config.yaml
quant: None | intw8a8 | fp8w8a8 | nvfp4
model_config_overrides: Dict to override model config values
- auto_aspect_ratio: set to False to work in ae raw space, otherwise in/out are 720p or 360p
cpu_offload: build model on CPU before moving to GPU (reduces peak VRAM)
"""
self.device = torch.get_default_device() if device is None else device
self.dtype = torch.get_default_dtype() if dtype is None else dtype
Expand Down Expand Up @@ -76,10 +78,18 @@ def __init__(
if self.model_cfg.prompt_conditioning is not None:
self.prompt_encoder = PromptEncoder(self.model_cfg.prompt_encoder_uri, dtype=dtype).eval()

self.model = WorldModel.from_pretrained(
model_uri, cfg=self.model_cfg, device=self.device, dtype=dtype, load_weights=load_weights
).eval()
apply_inference_patches(self.model)
if cpu_offload and str(self.device) != "cpu":
with torch.device("cpu"):
self.model = WorldModel.from_pretrained(
model_uri, cfg=self.model_cfg, device="cpu", dtype=dtype, load_weights=load_weights
).eval()
apply_inference_patches(self.model)
self.model = self.model.to(self.device)
else:
self.model = WorldModel.from_pretrained(
model_uri, cfg=self.model_cfg, device=self.device, dtype=dtype, load_weights=load_weights
).eval()
apply_inference_patches(self.model)
if quant is not None:
quantize_model(self.model, quant)

Expand Down