From 7abf210bcd2dbf8d97e9b8fdeb7a9f4ebe80c4d8 Mon Sep 17 00:00:00 2001 From: Yoshimasa Niwa Date: Wed, 16 Oct 2024 20:52:24 +0900 Subject: [PATCH 1/2] Use MPS if it's available on Apple Silicon devices. - Use PyTorch 2.5.0 instead of nightly. - FIX: activation error on MPS. MPS can't apply the SiLU activation correctly and produces randomly broken results if the tensor memory format is not contiguous. This does not happen on macOS 15 and later because it uses native strides, but macOS 14 is affected. --- app.py | 13 +++++++--- .../scheduling_flow_matching.py | 6 ++--- .../flux_modules/modeling_pyramid_flux.py | 2 +- .../mmdit_modules/modeling_pyramid_mmdit.py | 2 +- .../pyramid_dit_for_video_gen_pipeline.py | 25 ++++++++++++------- requirements.txt | 6 ++--- video_vae/modeling_causal_vae.py | 2 ++ video_vae/modeling_enc_dec.py | 1 + video_vae/modeling_resnet.py | 2 ++ 9 files changed, 39 insertions(+), 20 deletions(-) diff --git a/app.py b/app.py index 6f910d4..7038734 100644 --- a/app.py +++ b/app.py @@ -10,6 +10,9 @@ import threading import random +# Disabling parallelism to avoid deadlocks. +os.environ["TOKENIZERS_PARALLELISM"] = "false" + # Global model cache model_cache = {} @@ -30,7 +33,7 @@ height_high = 768 width_low = 640 height_low = 384 -cpu_offloading = True # enable cpu_offloading by default +cpu_offloading = torch.cuda.is_available() # enable cpu_offloading by default # Get the current working directory and create a folder to store the model current_directory = os.getcwd() @@ -115,6 +118,10 @@ def initialize_model(variant): model.vae.to("cuda") model.dit.to("cuda") model.text_encoder.to("cuda") + elif torch.mps.is_available(): + model.vae.to("mps") + model.dit.to("mps") + model.text_encoder.to("mps") else: print("[WARNING] CUDA is not available.
Proceeding without GPU.") @@ -182,7 +189,7 @@ def progress_callback(i, m): try: print("[INFO] Starting text-to-video generation...") - with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected): + with torch.no_grad(), torch.autocast(model.device.type, dtype=torch_dtype_selected): frames = model.generate( prompt=prompt, num_inference_steps=[20, 20, 20], @@ -238,7 +245,7 @@ def progress_callback(i, m): try: print("[INFO] Starting image-to-video generation...") - with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected): + with torch.no_grad(), torch.autocast(model.device.type, dtype=torch_dtype_selected): frames = model.generate_i2v( prompt=prompt, input_image=image, diff --git a/diffusion_schedulers/scheduling_flow_matching.py b/diffusion_schedulers/scheduling_flow_matching.py index 0a20f68..58f31d2 100644 --- a/diffusion_schedulers/scheduling_flow_matching.py +++ b/diffusion_schedulers/scheduling_flow_matching.py @@ -176,7 +176,7 @@ def set_begin_index(self, begin_index: int = 0): def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None, dtype: torch.dtype = None): """ Setting the timesteps and sigmas for each stage """ @@ -191,7 +191,7 @@ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Unio timesteps = np.linspace( timestep_max, timestep_min, num_inference_steps, ) - self.timesteps = torch.from_numpy(timesteps).to(device=device) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype) stage_sigmas = self.sigmas_per_stage[stage_index] sigma_max = stage_sigmas[0].item() @@ -200,7 +200,7 @@ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Unio ratios = np.linspace( sigma_max, sigma_min, num_inference_steps ) - sigmas = 
torch.from_numpy(ratios).to(device=device) + sigmas = torch.from_numpy(ratios).to(device=device, dtype=dtype) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self._step_index = None diff --git a/pyramid_dit/flux_modules/modeling_pyramid_flux.py b/pyramid_dit/flux_modules/modeling_pyramid_flux.py index 0021c31..271f6a6 100644 --- a/pyramid_dit/flux_modules/modeling_pyramid_flux.py +++ b/pyramid_dit/flux_modules/modeling_pyramid_flux.py @@ -28,7 +28,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." - scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) batch_size, seq_length = pos.shape diff --git a/pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py b/pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py index 1cb50b5..b6a675c 100644 --- a/pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py +++ b/pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py @@ -28,7 +28,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." 
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) batch_size, seq_length = pos.shape diff --git a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py index 75ed51f..eb95a28 100644 --- a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py +++ b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py @@ -207,8 +207,9 @@ def _enable_sequential_cpu_offload(self, model): cpu_offload(model, device, offload_buffers=offload_buffers) def enable_sequential_cpu_offload(self): - self._enable_sequential_cpu_offload(self.text_encoder) - self._enable_sequential_cpu_offload(self.dit) + if torch.cuda.is_available(): + self._enable_sequential_cpu_offload(self.text_encoder) + self._enable_sequential_cpu_offload(self.dit) def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs): checkpoint = torch.load(checkpoint_path, map_location='cpu') @@ -723,7 +724,7 @@ def generate_one_unit( intermed_latents = [] for i_s in range(len(stages)): - self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device) + self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device, dtype=dtype) timesteps = self.scheduler.timesteps if i_s > 0: @@ -811,7 +812,7 @@ def generate_i2v( if self.sequential_offload_enabled and not cpu_offloading: print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload") cpu_offloading=True - device = self.device if not cpu_offloading else torch.device("cuda") + device = self.device dtype = self.dtype if cpu_offloading: # skip caring about the text encoder here as its about to be used anyways. 
@@ -927,8 +928,11 @@ def generate_i2v( for unit_index in tqdm(range(1, num_units)): gc.collect() - torch.cuda.empty_cache() - + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.mps.is_available(): + torch.mps.empty_cache() + if callback: callback(unit_index, num_units) @@ -1028,7 +1032,7 @@ def generate( if self.sequential_offload_enabled and not cpu_offloading: print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload") cpu_offloading=True - device = self.device if not cpu_offloading else torch.device("cuda") + device = self.device dtype = self.dtype if cpu_offloading: # skip caring about the text encoder here as its about to be used anyways. @@ -1125,8 +1129,11 @@ def generate( for unit_index in tqdm(range(num_units)): gc.collect() - torch.cuda.empty_cache() - + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.mps.is_available(): + torch.mps.empty_cache() + if callback: callback(unit_index, num_units) diff --git a/requirements.txt b/requirements.txt index 297b460..69fb646 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ wheel -torch==2.1.2 -torchvision==0.16.2 +torch>=2.5.0 +torchvision>=0.20.0 transformers==4.39.3 accelerate==0.30.0 diffusers>=0.30.1 -numpy==1.24.4 +numpy==1.26.4 einops ftfy ipython diff --git a/video_vae/modeling_causal_vae.py b/video_vae/modeling_causal_vae.py index 2dc5c81..43e01f3 100644 --- a/video_vae/modeling_causal_vae.py +++ b/video_vae/modeling_causal_vae.py @@ -361,6 +361,8 @@ def chunk_decode(self, z: torch.FloatTensor, window_size=2): dec_list = [] for idx, frames in enumerate(frame_list): + if torch.mps.is_available(): + torch.mps.empty_cache() if idx == 0: z_h = self.post_quant_conv(frames, is_init_image=True, temporal_chunk=True) dec = self.decoder(z_h, is_init_image=True, temporal_chunk=True) diff --git a/video_vae/modeling_enc_dec.py b/video_vae/modeling_enc_dec.py index a8cd464..0c2601e 100644 --- 
a/video_vae/modeling_enc_dec.py +++ b/video_vae/modeling_enc_dec.py @@ -360,6 +360,7 @@ def custom_forward(*inputs): # post-process sample = self.conv_norm_out(sample) + sample = sample.contiguous() # MPS problem workaround. sample = self.conv_act(sample) sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk) diff --git a/video_vae/modeling_resnet.py b/video_vae/modeling_resnet.py index a1ddda0..644bf52 100644 --- a/video_vae/modeling_resnet.py +++ b/video_vae/modeling_resnet.py @@ -126,6 +126,7 @@ def forward( else: hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states.contiguous() # MPS problem workaround. hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk) @@ -138,6 +139,7 @@ def forward( else: hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states.contiguous() # MPS problem workaround. hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk) From 1389532ca8c8a4b0cc764d1783df68a3ddacc343 Mon Sep 17 00:00:00 2001 From: Yoshimasa Niwa Date: Mon, 18 Nov 2024 22:37:43 -0800 Subject: [PATCH 2/2] Address PyTorch 2.5 behavior change.
--- pyramid_dit/pyramid_dit_for_video_gen_pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py index eb95a28..782da67 100644 --- a/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py +++ b/pyramid_dit/pyramid_dit_for_video_gen_pipeline.py @@ -697,7 +697,9 @@ def prepare_latents( def sample_block_noise(self, bs, ch, temp, height, width): gamma = self.scheduler.config.gamma - dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma) + # Add a small epsilon so the covariance matrix is strictly positive-definite rather than only positive-semidefinite. + epsilon = 1e-6 + dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma + epsilon) block_number = bs * ch * temp * (height // 2) * (width // 2) noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4] noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)',b=bs,c=ch,t=temp,h=height//2,w=width//2,p=2,q=2)