Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support snr_gamma for rebalancing loss #279

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ def mean_flat(tensor):
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))

def mean_flat_reweight(tensor, weights):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape)))) * weights


class ModelMeanType(enum.Enum):
"""
Expand Down Expand Up @@ -726,7 +732,7 @@ def _vb_terms_bpd(
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}

def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mse_loss_weights=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
Expand Down Expand Up @@ -801,7 +807,10 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
if mse_loss_weights is not None:
terms["mse"] = mean_flat_reweight((target - model_output) ** 2, mse_loss_weights)
else:
terms["mse"] = mean_flat((target - model_output) ** 2)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
Expand Down
37 changes: 35 additions & 2 deletions opensora/train/train_t2v_t5_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,30 @@ def generate_timestep_weights(args, num_timesteps):

return weights

def compute_snr(timesteps, alphas_cumprod):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""

sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

# Compute SNR.
snr = (alpha / sigma) ** 2
return snr


#################################################################################
# Training Loop #
Expand Down Expand Up @@ -456,7 +480,16 @@ def load_model_hook(models, input_dir):
model_kwargs = dict(encoder_hidden_states=cond, attention_mask=attn_mask,
encoder_attention_mask=cond_mask, use_image_num=args.use_image_num)
t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=accelerator.device)
loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
if args.snr_gamma is not None:
snr = compute_snr(t, diffusion.alphas_cumprod)
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(t)], dim=1).min(dim=1)[0] / snr
)

loss_dict = diffusion.training_losses(model, x, t, model_kwargs, mse_loss_weights=mse_loss_weights)
else:
loss_dict = diffusion.training_losses(model, x, t, model_kwargs)

loss = loss_dict["loss"].mean()

# Gather the losses across all processes for logging (if we use distributed training).
Expand Down Expand Up @@ -746,7 +779,7 @@ def load_model_hook(models, input_dir):
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
default=5.0,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
Expand Down