Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import threading
import random

# Disable tokenizers parallelism (TOKENIZERS_PARALLELISM) to avoid deadlocks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Global model cache
model_cache = {}

Expand All @@ -30,7 +33,7 @@
height_high = 768
width_low = 640
height_low = 384
cpu_offloading = True # enable cpu_offloading by default
cpu_offloading = torch.cuda.is_available() # enable cpu_offloading by default only when CUDA is available

# Get the current working directory and create a folder to store the model
current_directory = os.getcwd()
Expand Down Expand Up @@ -115,6 +118,10 @@ def initialize_model(variant):
model.vae.to("cuda")
model.dit.to("cuda")
model.text_encoder.to("cuda")
elif torch.mps.is_available():
model.vae.to("mps")
model.dit.to("mps")
model.text_encoder.to("mps")
else:
print("[WARNING] CUDA is not available. Proceeding without GPU.")

Expand Down Expand Up @@ -182,7 +189,7 @@ def progress_callback(i, m):

try:
print("[INFO] Starting text-to-video generation...")
with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
with torch.no_grad(), torch.autocast(model.device.type, dtype=torch_dtype_selected):
frames = model.generate(
prompt=prompt,
num_inference_steps=[20, 20, 20],
Expand Down Expand Up @@ -238,7 +245,7 @@ def progress_callback(i, m):

try:
print("[INFO] Starting image-to-video generation...")
with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
with torch.no_grad(), torch.autocast(model.device.type, dtype=torch_dtype_selected):
frames = model.generate_i2v(
prompt=prompt,
input_image=image,
Expand Down
6 changes: 3 additions & 3 deletions diffusion_schedulers/scheduling_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def set_begin_index(self, begin_index: int = 0):
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps

def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None, dtype: torch.dtype = None):
"""
Setting the timesteps and sigmas for each stage
"""
Expand All @@ -191,7 +191,7 @@ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Unio
timesteps = np.linspace(
timestep_max, timestep_min, num_inference_steps,
)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)

stage_sigmas = self.sigmas_per_stage[stage_index]
sigma_max = stage_sigmas[0].item()
Expand All @@ -200,7 +200,7 @@ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Unio
ratios = np.linspace(
sigma_max, sigma_min, num_inference_steps
)
sigmas = torch.from_numpy(ratios).to(device=device)
sigmas = torch.from_numpy(ratios).to(device=device, dtype=dtype)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

self._step_index = None
Expand Down
2 changes: 1 addition & 1 deletion pyramid_dit/flux_modules/modeling_pyramid_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."

scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)

batch_size, seq_length = pos.shape
Expand Down
2 changes: 1 addition & 1 deletion pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."

scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)

batch_size, seq_length = pos.shape
Expand Down
29 changes: 19 additions & 10 deletions pyramid_dit/pyramid_dit_for_video_gen_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,9 @@ def _enable_sequential_cpu_offload(self, model):
cpu_offload(model, device, offload_buffers=offload_buffers)

def enable_sequential_cpu_offload(self):
self._enable_sequential_cpu_offload(self.text_encoder)
self._enable_sequential_cpu_offload(self.dit)
if torch.cuda.is_available():
self._enable_sequential_cpu_offload(self.text_encoder)
self._enable_sequential_cpu_offload(self.dit)

def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
Expand Down Expand Up @@ -696,7 +697,9 @@ def prepare_latents(

def sample_block_noise(self, bs, ch, temp, height, width):
gamma = self.scheduler.config.gamma
dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma)
# Add a small epsilon so the covariance matrix is positive-definite rather than only positive-semidefinite.
epsilon = 1e-6
dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma + epsilon)
block_number = bs * ch * temp * (height // 2) * (width // 2)
noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4]
noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)',b=bs,c=ch,t=temp,h=height//2,w=width//2,p=2,q=2)
Expand All @@ -723,7 +726,7 @@ def generate_one_unit(
intermed_latents = []

for i_s in range(len(stages)):
self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device, dtype=dtype)
timesteps = self.scheduler.timesteps

if i_s > 0:
Expand Down Expand Up @@ -811,7 +814,7 @@ def generate_i2v(
if self.sequential_offload_enabled and not cpu_offloading:
print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload")
cpu_offloading=True
device = self.device if not cpu_offloading else torch.device("cuda")
device = self.device
dtype = self.dtype
if cpu_offloading:
# Skip handling the text encoder here, as it's about to be used anyway.
Expand Down Expand Up @@ -927,8 +930,11 @@ def generate_i2v(

for unit_index in tqdm(range(1, num_units)):
gc.collect()
torch.cuda.empty_cache()

if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.mps.is_available():
torch.mps.empty_cache()

if callback:
callback(unit_index, num_units)

Expand Down Expand Up @@ -1028,7 +1034,7 @@ def generate(
if self.sequential_offload_enabled and not cpu_offloading:
print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload")
cpu_offloading=True
device = self.device if not cpu_offloading else torch.device("cuda")
device = self.device
dtype = self.dtype
if cpu_offloading:
# Skip handling the text encoder here, as it's about to be used anyway.
Expand Down Expand Up @@ -1125,8 +1131,11 @@ def generate(

for unit_index in tqdm(range(num_units)):
gc.collect()
torch.cuda.empty_cache()

if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.mps.is_available():
torch.mps.empty_cache()

if callback:
callback(unit_index, num_units)

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
wheel
torch==2.1.2
torchvision==0.16.2
torch>=2.5.0
torchvision>=0.20.0
transformers==4.39.3
accelerate==0.30.0
diffusers>=0.30.1
numpy==1.24.4
numpy==1.26.4
einops
ftfy
ipython
Expand Down
2 changes: 2 additions & 0 deletions video_vae/modeling_causal_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ def chunk_decode(self, z: torch.FloatTensor, window_size=2):

dec_list = []
for idx, frames in enumerate(frame_list):
if torch.mps.is_available():
torch.mps.empty_cache()
if idx == 0:
z_h = self.post_quant_conv(frames, is_init_image=True, temporal_chunk=True)
dec = self.decoder(z_h, is_init_image=True, temporal_chunk=True)
Expand Down
1 change: 1 addition & 0 deletions video_vae/modeling_enc_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def custom_forward(*inputs):

# post-process
sample = self.conv_norm_out(sample)
sample = sample.contiguous() # MPS problem workaround.
sample = self.conv_act(sample)
sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

Expand Down
2 changes: 2 additions & 0 deletions video_vae/modeling_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def forward(
else:
hidden_states = self.norm1(hidden_states)

hidden_states = hidden_states.contiguous() # MPS problem workaround.
hidden_states = self.nonlinearity(hidden_states)

hidden_states = self.conv1(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
Expand All @@ -138,6 +139,7 @@ def forward(
else:
hidden_states = self.norm2(hidden_states)

hidden_states = hidden_states.contiguous() # MPS problem workaround.
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
Expand Down