diff --git a/src/world_engine.py b/src/world_engine.py
index cbc5ab2..6b6182c 100644
--- a/src/world_engine.py
+++ b/src/world_engine.py
@@ -41,13 +41,15 @@ def __init__(
         model_config_overrides: Optional[Dict] = None,
         device=None,
         dtype=torch.bfloat16,
-        load_weights: bool = True
+        load_weights: bool = True,
+        cpu_offload: bool = False,
     ):
         """
         model_uri: HF URI or local folder containing model.safetensors and config.yaml
         quant: None | intw8a8 | fp8w8a8 | nvfp4
         model_config_overrides: Dict to override model config values
-        auto_aspect_ratio: set to False to work in ae raw space, otherwise in/out are 720p or 360p
+        cpu_offload: build model on CPU before moving to GPU (reduces peak VRAM)
         """
         self.device = torch.get_default_device() if device is None else device
         self.dtype = torch.get_default_dtype() if dtype is None else dtype
@@ -76,10 +78,18 @@ def __init__(
         if self.model_cfg.prompt_conditioning is not None:
             self.prompt_encoder = PromptEncoder(self.model_cfg.prompt_encoder_uri, dtype=dtype).eval()

-        self.model = WorldModel.from_pretrained(
-            model_uri, cfg=self.model_cfg, device=self.device, dtype=dtype, load_weights=load_weights
-        ).eval()
-        apply_inference_patches(self.model)
+        if cpu_offload and str(self.device) != "cpu":
+            with torch.device("cpu"):
+                self.model = WorldModel.from_pretrained(
+                    model_uri, cfg=self.model_cfg, device="cpu", dtype=dtype, load_weights=load_weights
+                ).eval()
+                apply_inference_patches(self.model)
+            self.model = self.model.to(self.device)
+        else:
+            self.model = WorldModel.from_pretrained(
+                model_uri, cfg=self.model_cfg, device=self.device, dtype=dtype, load_weights=load_weights
+            ).eval()
+            apply_inference_patches(self.model)

         if quant is not None:
             quantize_model(self.model, quant)
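
For reference, a minimal usage sketch of the new flag. It assumes the patched __init__ belongs to a class named WorldEngine importable from src/world_engine.py (the class name is not visible in the hunk headers) and uses a placeholder model URI; only the cpu_offload argument itself comes from this diff.

import torch

from world_engine import WorldEngine  # assumed class name and import path

# With cpu_offload=True the model is built in host memory first and only then
# moved to the CUDA device, keeping peak VRAM during construction low.
engine = WorldEngine(
    "org/world-model",            # placeholder model_uri (HF URI or local folder)
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
    cpu_offload=True,
)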