ddpm.py (25 changes: 18 additions & 7 deletions)
@@ -5,7 +5,7 @@
 from torch import nn
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
+from tqdm import tqdm
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -138,9 +138,9 @@ def __len__(self):
 parser = argparse.ArgumentParser()
 parser.add_argument("--experiment_name", type=str, default="base")
 parser.add_argument("--dataset", type=str, default="dino", choices=["circle", "dino", "line", "moons"])
-parser.add_argument("--train_batch_size", type=int, default=32)
+parser.add_argument("--train_batch_size", type=int, default=256)
 parser.add_argument("--eval_batch_size", type=int, default=1000)
-parser.add_argument("--num_epochs", type=int, default=200)
+parser.add_argument("--num_epochs", type=int, default=150)
 parser.add_argument("--learning_rate", type=float, default=1e-3)
 parser.add_argument("--num_timesteps", type=int, default=50)
 parser.add_argument("--beta_schedule", type=str, default="linear", choices=["linear", "quadratic"])
@@ -150,7 +150,15 @@ def __len__(self):
 parser.add_argument("--time_embedding", type=str, default="sinusoidal", choices=["sinusoidal", "learnable", "linear", "zero"])
 parser.add_argument("--input_embedding", type=str, default="sinusoidal", choices=["sinusoidal", "learnable", "linear", "identity"])
 parser.add_argument("--save_images_step", type=int, default=1)
+parser.add_argument("--device", type=str, default="cpu")
 config = parser.parse_args()
+
+if torch.cuda.is_available() and config.device[:4] == "cuda":
+    print(f"INFO: '{config.device}' is available")
+    device = config.device
+else:
+    print(f"INFO: using cpu")
+    device = config.device = "cpu"
 
 dataset = datasets.get_dataset(config.dataset)
 dataloader = DataLoader(
@@ -161,7 +169,7 @@
     hidden_layers=config.hidden_layers,
     emb_size=config.embedding_size,
     time_emb=config.time_embedding,
-    input_emb=config.input_embedding)
+    input_emb=config.input_embedding).to(device)
 
 noise_scheduler = NoiseScheduler(
     num_timesteps=config.num_timesteps,
@@ -176,6 +184,7 @@ def __len__(self):
 frames = []
 losses = []
 print("Training model...")
+
 for epoch in range(config.num_epochs):
     model.train()
     progress_bar = tqdm(total=len(dataloader))
@@ -188,8 +197,8 @@
         ).long()
 
         noisy = noise_scheduler.add_noise(batch, noise, timesteps)
-        noise_pred = model(noisy, timesteps)
-        loss = F.mse_loss(noise_pred, noise)
+        noise_pred = model(noisy.to(device), timesteps.to(device))
+        loss = F.mse_loss(noise_pred, noise.to(device))
         loss.backward(loss)
 
         nn.utils.clip_grad_norm_(model.parameters(), 1.0)
@@ -202,6 +211,7 @@
         progress_bar.set_postfix(**logs)
         global_step += 1
     progress_bar.close()
+
 
     if epoch % config.save_images_step == 0 or epoch == config.num_epochs - 1:
         # generate data with the model to later visualize the learning process
@@ -211,9 +221,10 @@
         for i, t in enumerate(tqdm(timesteps)):
             t = torch.from_numpy(np.repeat(t, config.eval_batch_size)).long()
             with torch.no_grad():
-                residual = model(sample, t)
+                residual = model(sample.to(device), t.to(device)).to("cpu")
             sample = noise_scheduler.step(residual, t[0], sample)
         frames.append(sample.numpy())
+
 
 print("Saving model...")
 outdir = f"exps/{config.experiment_name}"
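Pieced together from the hunks above, the new device handling follows a common PyTorch pattern: parse a `--device` flag, fall back to CPU when CUDA is unavailable, move the model once with `.to(device)`, and move each batch (and the sampling tensors) at use time. Below is a minimal self-contained sketch of that pattern; the `--device` flag and the fallback logic come from the diff, while the `Linear` layer and random input are illustrative stand-ins for the repo's MLP and dataloader.

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cpu")
config = parser.parse_args()

# Accept "cuda", "cuda:0", etc., but only if a GPU is actually present.
if torch.cuda.is_available() and config.device[:4] == "cuda":
    print(f"INFO: '{config.device}' is available")
    device = config.device
else:
    print("INFO: using cpu")
    device = config.device = "cpu"

model = torch.nn.Linear(2, 2).to(device)   # stand-in for the repo's MLP, moved once
x = torch.randn(8, 2)                      # stand-in for a dataloader batch
y = model(x.to(device))                    # inputs moved per forward pass
sample = y.to("cpu")                       # moved back for CPU-side code such as NoiseScheduler.step
```

With the flag in place, something like `python ddpm.py --device cuda` trains on the GPU when one is available; any other value, or a machine without CUDA, falls back to the CPU path.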
positional_embeddings.py (11 changes: 6 additions & 5 deletions)
@@ -6,17 +6,18 @@
 
 
 class SinusoidalEmbedding(nn.Module):
-    def __init__(self, size: int, scale: float = 1.0):
+    def __init__(self, size: int, scale: torch.float = 1.0):
         super().__init__()
         self.size = size
         self.scale = scale
-
-    def forward(self, x: torch.Tensor):
-        x = x * self.scale
         half_size = self.size // 2
         emb = torch.log(torch.Tensor([10000.0])) / (half_size - 1)
         emb = torch.exp(-emb * torch.arange(half_size))
-        emb = x.unsqueeze(-1) * emb.unsqueeze(0)
+        self.emb = nn.Parameter(emb, requires_grad=False)
+
+    def forward(self, x: torch.Tensor):
+        x = x * self.scale
+        emb = x.unsqueeze(-1) * self.emb.unsqueeze(0)
         emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
         return emb
 
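For readability, here is `SinusoidalEmbedding` as it reads after this change, assembled from the hunk above: the frequency table is computed once in `__init__` and stored as a non-trainable `nn.Parameter` (so it follows the module in `.to(device)` calls) instead of being rebuilt on every forward pass.

```python
import torch
from torch import nn


class SinusoidalEmbedding(nn.Module):
    def __init__(self, size: int, scale: torch.float = 1.0):
        super().__init__()
        self.size = size
        self.scale = scale
        # Precompute the frequency table once; registering it as a
        # non-trainable Parameter moves it along with the module's device.
        half_size = self.size // 2
        emb = torch.log(torch.Tensor([10000.0])) / (half_size - 1)
        emb = torch.exp(-emb * torch.arange(half_size))
        self.emb = nn.Parameter(emb, requires_grad=False)

    def forward(self, x: torch.Tensor):
        x = x * self.scale
        emb = x.unsqueeze(-1) * self.emb.unsqueeze(0)
        emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
        return emb
```

As a quick shape check, `SinusoidalEmbedding(128)(torch.arange(10))` should return a `(10, 128)` tensor, the same output the pre-change version produced.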