Thanks for your excellent work and for sharing your code!
However, I have a question about the following code:
log_probs = torch.cat(
    [
        log_add_exp(log_x_start[:,:-1,:]+log_cumprod_at, log_cumprod_bt),
        log_add_exp(log_x_start[:,-1:,:]+log_1_min_cumprod_ct, log_cumprod_ct)
    ],
    dim=1
)
Why is log_add_exp used here? Looking forward to your reply!
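For context, here is a minimal sketch of what the snippet appears to compute. It assumes that log_add_exp(a, b) returns log(exp(a) + exp(b)), the same quantity as torch.logaddexp, and that the last class index is a [MASK] token. Under that reading, the snippet is the forward distribution q(x_t | x_0) written in log space: at * x_0 + bt for ordinary tokens, and (1 - ct) * x_0 + ct for [MASK]. Because every term is stored as a log-probability, adding two probabilities becomes log_add_exp of their logs. The helper names (safe_log, q_pred_log, q_pred_prob) and the schedule values are hypothetical and chosen only for illustration. Plain Python floats stand in for tensors.

```python
import math

def log_add_exp(a, b):
    """Return log(exp(a) + exp(b)): adds two probabilities given as logs."""
    m = max(a, b)
    if m == -math.inf:  # both probabilities are exactly zero
        return -math.inf
    # Shift by the max so the largest exponent is exp(0) = 1 (avoids underflow).
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def safe_log(p):
    # Hypothetical helper: log of a one-hot entry, with log(0) = -inf.
    return math.log(p) if p > 0 else -math.inf

def q_pred_log(log_x_start, log_at, log_bt, log_ct, log_1_min_ct):
    """Log-space q(x_t | x_0) for one token, mirroring the torch.cat above.

    The last entry is assumed to be the [MASK] class.
    """
    ordinary = [log_add_exp(lx + log_at, log_bt) for lx in log_x_start[:-1]]
    mask = [log_add_exp(log_x_start[-1] + log_1_min_ct, log_ct)]
    return ordinary + mask

def q_pred_prob(x_start, at, bt, ct):
    """Same distribution in probability space: at*x0 + bt and (1-ct)*x0 + ct."""
    ordinary = [at * x + bt for x in x_start[:-1]]
    mask = [(1 - ct) * x_start[-1] + ct]
    return ordinary + mask

# Hypothetical schedule values for one timestep, chosen so at + K*bt + ct = 1.
K = 4                      # number of ordinary tokens; index K is [MASK]
at, bt, ct = 0.6, 0.025, 0.3
x_start = [0.0] * (K + 1)
x_start[2] = 1.0           # one-hot x_0: ordinary token 2, not masked

log_probs = q_pred_log([safe_log(x) for x in x_start],
                       math.log(at), math.log(bt), math.log(ct), math.log(1 - ct))
probs = q_pred_prob(x_start, at, bt, ct)

for lp, p in zip(log_probs, probs):
    print(f"exp(log-space) = {math.exp(lp):.4f}   prob-space = {p:.4f}")

# Why not math.log(math.exp(a) + math.exp(b))? Tiny probabilities underflow:
a = b = -1000.0
print(math.exp(a))          # 0.0, so the naive version would take log(0)
print(log_add_exp(a, b))    # stays finite: -1000 + log(2)
```

Both columns should agree, which suggests log_add_exp is simply "+" for probabilities kept in log space. The max-shift keeps the sum finite even when both inputs are far below the range where exp() underflows.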