diff --git a/app.py b/app.py index 6f910d4..7038734 100644 --- a/app.py +++ b/app.py @@ -10,6 +10,9 @@ import threading import random +# Disabling parallelism to avoid deadlocks. +os.environ["TOKENIZERS_PARALLELISM"] = "false" + # Global model cache model_cache = {} @@ -30,7 +33,7 @@ height_high = 768 width_low = 640 height_low = 384 -cpu_offloading = True # enable cpu_offloading by default +cpu_offloading = torch.cuda.is_available() # enable cpu_offloading by default only when CUDA is available # Get the current working directory and create a folder to store the model current_directory = os.getcwd() @@ -115,6 +118,10 @@ def initialize_model(variant): model.vae.to("cuda") model.dit.to("cuda") model.text_encoder.to("cuda") + elif torch.mps.is_available(): + model.vae.to("mps") + model.dit.to("mps") + model.text_encoder.to("mps") else: print("[WARNING] CUDA is not available. Proceeding without GPU.") @@ -182,7 +189,7 @@ def progress_callback(i, m): try: print("[INFO] Starting text-to-video generation...") - with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected): + with torch.no_grad(), torch.autocast(model.device.type, dtype=torch_dtype_selected): frames = model.generate( prompt=prompt, num_inference_steps=[20, 20, 20], @@ -238,7 +245,7 @@ def progress_callback(i, m): try: print("[INFO] Starting image-to-video generation...") - with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected): + with torch.no_grad(), torch.autocast(model.device.type, dtype=torch_dtype_selected): frames = model.generate_i2v( prompt=prompt, input_image=image, diff --git a/diffusion_schedulers/scheduling_flow_matching.py b/diffusion_schedulers/scheduling_flow_matching.py index 0a20f68..58f31d2 100644 --- a/diffusion_schedulers/scheduling_flow_matching.py +++ b/diffusion_schedulers/scheduling_flow_matching.py @@ -176,7 +176,7 @@ def set_begin_index(self, begin_index: int = 0): def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps(self, 
num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None, dtype: torch.dtype = None): """ Setting the timesteps and sigmas for each stage """ @@ -191,7 +191,7 @@ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Unio timesteps = np.linspace( timestep_max, timestep_min, num_inference_steps, ) - self.timesteps = torch.from_numpy(timesteps).to(device=device) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype) stage_sigmas = self.sigmas_per_stage[stage_index] sigma_max = stage_sigmas[0].item() @@ -200,7 +200,7 @@ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Unio ratios = np.linspace( sigma_max, sigma_min, num_inference_steps ) - sigmas = torch.from_numpy(ratios).to(device=device) + sigmas = torch.from_numpy(ratios).to(device=device, dtype=dtype) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self._step_index = None diff --git a/pyramid_dit/flux_modules/modeling_pyramid_flux.py b/pyramid_dit/flux_modules/modeling_pyramid_flux.py index 0021c31..271f6a6 100644 --- a/pyramid_dit/flux_modules/modeling_pyramid_flux.py +++ b/pyramid_dit/flux_modules/modeling_pyramid_flux.py @@ -28,7 +28,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." 
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) batch_size, seq_length = pos.shape diff --git a/pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py b/pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py index 1cb50b5..b6a675c 100644 --- a/pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py +++ b/pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py @@ -28,7 +28,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." - scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) batch_size, seq_length = pos.shape diff --git a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py index 75ed51f..782da67 100644 --- a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py +++ b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py @@ -207,8 +207,9 @@ def _enable_sequential_cpu_offload(self, model): cpu_offload(model, device, offload_buffers=offload_buffers) def enable_sequential_cpu_offload(self): - self._enable_sequential_cpu_offload(self.text_encoder) - self._enable_sequential_cpu_offload(self.dit) + if torch.cuda.is_available(): + self._enable_sequential_cpu_offload(self.text_encoder) + self._enable_sequential_cpu_offload(self.dit) def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs): checkpoint = torch.load(checkpoint_path, map_location='cpu') @@ -696,7 +697,9 @@ def prepare_latents( def sample_block_noise(self, bs, ch, temp, height, width): gamma = self.scheduler.config.gamma - dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma) + # Add a small epsilon so the covariance matrix is positive-definite rather than merely positive-semidefinite. 
+ epsilon = 1e-6 + dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma + epsilon) block_number = bs * ch * temp * (height // 2) * (width // 2) noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4] noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)',b=bs,c=ch,t=temp,h=height//2,w=width//2,p=2,q=2) @@ -723,7 +726,7 @@ def generate_one_unit( intermed_latents = [] for i_s in range(len(stages)): - self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device) + self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device, dtype=dtype) timesteps = self.scheduler.timesteps if i_s > 0: @@ -811,7 +814,7 @@ def generate_i2v( if self.sequential_offload_enabled and not cpu_offloading: print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload") cpu_offloading=True - device = self.device if not cpu_offloading else torch.device("cuda") + device = self.device dtype = self.dtype if cpu_offloading: # skip caring about the text encoder here as its about to be used anyways. @@ -927,8 +930,11 @@ def generate_i2v( for unit_index in tqdm(range(1, num_units)): gc.collect() - torch.cuda.empty_cache() - + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.mps.is_available(): + torch.mps.empty_cache() + if callback: callback(unit_index, num_units) @@ -1028,7 +1034,7 @@ def generate( if self.sequential_offload_enabled and not cpu_offloading: print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload") cpu_offloading=True - device = self.device if not cpu_offloading else torch.device("cuda") + device = self.device dtype = self.dtype if cpu_offloading: # skip caring about the text encoder here as its about to be used anyways. 
@@ -1125,8 +1131,11 @@ def generate( for unit_index in tqdm(range(num_units)): gc.collect() - torch.cuda.empty_cache() - + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.mps.is_available(): + torch.mps.empty_cache() + if callback: callback(unit_index, num_units) diff --git a/requirements.txt b/requirements.txt index 297b460..69fb646 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ wheel -torch==2.1.2 -torchvision==0.16.2 +torch>=2.5.0 +torchvision>=0.20.0 transformers==4.39.3 accelerate==0.30.0 diffusers>=0.30.1 -numpy==1.24.4 +numpy==1.26.4 einops ftfy ipython diff --git a/video_vae/modeling_causal_vae.py b/video_vae/modeling_causal_vae.py index 2dc5c81..43e01f3 100644 --- a/video_vae/modeling_causal_vae.py +++ b/video_vae/modeling_causal_vae.py @@ -361,6 +361,8 @@ def chunk_decode(self, z: torch.FloatTensor, window_size=2): dec_list = [] for idx, frames in enumerate(frame_list): + if torch.mps.is_available(): + torch.mps.empty_cache() if idx == 0: z_h = self.post_quant_conv(frames, is_init_image=True, temporal_chunk=True) dec = self.decoder(z_h, is_init_image=True, temporal_chunk=True) diff --git a/video_vae/modeling_enc_dec.py b/video_vae/modeling_enc_dec.py index a8cd464..0c2601e 100644 --- a/video_vae/modeling_enc_dec.py +++ b/video_vae/modeling_enc_dec.py @@ -360,6 +360,7 @@ def custom_forward(*inputs): # post-process sample = self.conv_norm_out(sample) + sample = sample.contiguous() # MPS problem workaround. sample = self.conv_act(sample) sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk) diff --git a/video_vae/modeling_resnet.py b/video_vae/modeling_resnet.py index a1ddda0..644bf52 100644 --- a/video_vae/modeling_resnet.py +++ b/video_vae/modeling_resnet.py @@ -126,6 +126,7 @@ def forward( else: hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states.contiguous() # MPS problem workaround. 
hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk) @@ -138,6 +139,7 @@ def forward( else: hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states.contiguous() # MPS problem workaround. hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)