[core] LTX Video #10021
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
scheduler = FlowMatchEulerDiscreteScheduler(
    use_dynamic_shifting=True,
    base_shift=0.95,
    max_shift=2.05,
    base_image_seq_len=1024,
    max_image_seq_len=4096,
    shift_terminal=0.1,
)
cc @yiyixuxu for the shift_terminal change
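For context, other flow-matching pipelines (e.g. Flux) turn config values like these into a per-resolution mu with a small linear-interpolation helper before calling scheduler.set_timesteps(..., mu=mu). A sketch of that wiring is below; the helper name and defaults are illustrative, not taken from this PR.

# Illustrative only: how use_dynamic_shifting is typically driven in
# flow-matching pipelines (e.g. Flux); the defaults echo the config above.
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 1024,   # base_image_seq_len
    max_seq_len: int = 4096,    # max_image_seq_len
    base_shift: float = 0.95,   # base_shift
    max_shift: float = 2.05,    # max_shift
) -> float:
    # Linearly interpolate mu as the token count grows from base_seq_len to
    # max_seq_len; mu is then passed to scheduler.set_timesteps(..., mu=mu)
    # and consumed by time_shift.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b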
elif qk_norm == "rms_norm_across_heads":
    # LTX applies qk norm across all heads
    self.norm_q = RMSNorm(dim_head * heads, eps=eps)
    self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
@DN6 Should I follow your approach with Mochi and create a separate attention class for LTX?
OK, but we want to be more careful here; ideally we do that as part of a carefully planned-out refactor. Maybe it would be safe to just inherit from Attention for now? E.g. we wrote code like this with the assumption in mind that we only have one attention class:
https://github.com/huggingface/diffusers/blob/e47cc1fc1a89a5375c322d296cd122fe71ab859f/src/diffusers/pipelines/pag/pag_utils.py#L57C39-L57C48
cc @DN6 here too
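A minimal sketch of the "just inherit from Attention for now" option; the class name and the defaulted kwargs are hypothetical, not part of this PR.

# Hypothetical sketch: a thin subclass keeps isinstance(module, Attention)
# checks (e.g. in the PAG utilities linked above) working unchanged.
from diffusers.models.attention_processor import Attention


class LTXVideoAttention(Attention):
    def __init__(self, *args, **kwargs):
        # LTX applies QK norm across all heads rather than per head.
        kwargs.setdefault("qk_norm", "rms_norm_across_heads")
        super().__init__(*args, **kwargs)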
@@ -169,6 +170,12 @@ def _sigma_to_t(self, sigma):
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
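The hunk above only shows the new signature. The intended behavior is to rescale the shifted timesteps so the schedule terminates at shift_terminal instead of running down to 0; a sketch along these lines, treating the exact formula as an assumption rather than the merged implementation:

def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
    # Sketch: stretch the schedule so its last value equals shift_terminal
    # (0.1 in the config above) rather than the original terminal value.
    one_minus_z = 1 - t
    scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
    return 1 - one_minus_z / scale_factor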
cc @yiyixuxu
Thanks for adding!
Co-authored-by: Steven Liu <[email protected]>
hidden_states = hidden_states.reshape(
    batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
)
hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
is there any reason we reshape from 5d -> 3d and then 3d -> 5d on every iteration?
I think it would be nice to:
- have a _pack_latent and _unpack_latent like we did for Flux
- maybe move the rotary embedding into the pipeline so we only pack/unpack once (I know we currently have some discrepancy here, so open to discussion); or we can just configure the rotary pos embeds class in the pipeline, so we do not need to give the shape info each time
Yes, I'll add the pack and unpack latent methods.
Regarding the RoPE, I think a separate layer approach is okay despite requiring recomputation at every step. This is because we're planning to work on caching hooks that would enable the outputs of any layer to be cached and reused. Since RoPE is an integral part of many models, we could enable caching by default on these kinds of model-specific RoPE layers, with some opt-out code for those who don't want it. WDYT?
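A rough sketch of what the two helpers could look like, mirroring the reshape/permute shown earlier; the signatures, and the assumption that num_frames/height/width are the post-patch grid sizes, are mine rather than this PR's.

# Rough sketch of Flux-style pack/unpack helpers for video latents.
import torch


def _pack_latents(latents: torch.Tensor, patch_size: int, patch_size_t: int) -> torch.Tensor:
    # [B, C, F, H, W] -> [B, (F/pt) * (H/p) * (W/p), C * pt * p * p]
    batch_size, num_channels, num_frames, height, width = latents.shape
    latents = latents.reshape(
        batch_size,
        num_channels,
        num_frames // patch_size_t, patch_size_t,
        height // patch_size, patch_size,
        width // patch_size, patch_size,
    )
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)


def _unpack_latents(
    latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int, patch_size_t: int
) -> torch.Tensor:
    # Inverse of _pack_latents: [B, N, D] -> [B, C, F, H, W]
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7)
    return latents.reshape(batch_size, -1, num_frames * patch_size_t, height * patch_size, width * patch_size)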
# ============= TODO(aryan): needs a look by YiYi
latents = latents.float()

noise_pred = self._unpack_latents(
    noise_pred,
    latent_num_frames,
    latent_height,
    latent_width,
    self.transformer_spatial_patch_size,
    self.transformer_temporal_patch_size,
)
latents = self._unpack_latents(
    latents,
    latent_num_frames,
    latent_height,
    latent_width,
    self.transformer_spatial_patch_size,
    self.transformer_temporal_patch_size,
)

# Step only latent frames 1:, keeping the first latent frame unchanged.
noise_pred = noise_pred[:, :, 1:]
noise_latents = latents[:, :, 1:]
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]

# Prepend the untouched first frame and repack to the transformer's token layout.
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
latents = self._pack_latents(
    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
latents = latents.to(dtype=latents_dtype)
# =============
@yiyixuxu They use per-latent-frame timesteps (actually, it's per-token timesteps, but all tokens corresponding to the same frame have the same timestep), and since we don't support that in our schedulers, we can't really do the normal scheduler.step(). These changes were required to make the pipeline at least generate reasonable results. The quality of the generations looks similar to me, but I will try to numerically match.
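For reference, the flow-matching Euler update itself broadcasts cleanly over per-frame sigmas, so native support would mostly be a scheduler API question. A hypothetical sketch, not part of this PR or the current scheduler API (which takes a scalar timestep):

# Hypothetical: a flow-matching Euler step with one sigma per latent frame.
import torch


def per_frame_euler_step(
    model_output: torch.Tensor,  # [B, C, F, H, W] predicted velocity
    sample: torch.Tensor,        # [B, C, F, H, W] current latents
    sigma: torch.Tensor,         # [F] current sigma for each latent frame
    sigma_next: torch.Tensor,    # [F] next sigma for each latent frame
) -> torch.Tensor:
    # Broadcast the per-frame sigmas over batch, channel and spatial dims.
    sigma = sigma.view(1, 1, -1, 1, 1)
    sigma_next = sigma_next.view(1, 1, -1, 1, 1)
    # Standard flow-matching Euler update; a conditioning frame whose sigma
    # does not change (sigma_next == sigma) is left untouched.
    return sample + (sigma_next - sigma) * model_output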
T2V:
I2V: