
[core] LTX Video #10021

Open · wants to merge 38 commits into main
Conversation

@a-r-r-o-w (Member) commented Nov 26, 2024

T2V (text-to-video):

import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)

I2V (image-to-video):

import torch
from diffusers import LTXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = LTXImageToVideoPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
)
prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
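
If the full pipeline doesn't fit in VRAM, the standard diffusers offloading helper should apply here as well; treat this as a suggestion rather than something validated on this branch:

```python
# Instead of pipe.to("cuda"): keeps submodules on CPU and moves each one to the GPU
# only while it is being used, trading some speed for lower peak memory.
pipe.enable_model_cpu_offload()
```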

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review November 27, 2024 14:22
@a-r-r-o-w a-r-r-o-w requested review from yiyixuxu, stevhliu and DN6 and removed request for yiyixuxu November 27, 2024 14:22
Comment on lines +190 to +197
scheduler = FlowMatchEulerDiscreteScheduler(
    use_dynamic_shifting=True,
    base_shift=0.95,
    max_shift=2.05,
    base_image_seq_len=1024,
    max_image_seq_len=4096,
    shift_terminal=0.1,
)
a-r-r-o-w (Member Author)

cc @yiyixuxu for the shift_terminal change
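
For readers following along, here is a minimal standalone sketch of what a terminal shift could look like, assuming the intent is to rescale the sigma schedule so its final value lands on `shift_terminal` instead of 0. It is written as a free function for illustration; the actual scheduler method may differ:

```python
import torch

def stretch_shift_to_terminal(t: torch.Tensor, shift_terminal: float = 0.1) -> torch.Tensor:
    # Linearly stretch a decreasing schedule in (0, 1] so that its last value
    # equals `shift_terminal` while a starting value of 1.0 stays at 1.0.
    one_minus_t = 1.0 - t
    scale = one_minus_t[-1] / (1.0 - shift_terminal)
    return 1.0 - one_minus_t / scale

sigmas = torch.linspace(1.0, 1.0 / 50, 50)      # toy sigma schedule
print(stretch_shift_to_terminal(sigmas)[-1])    # tensor(0.1000)
```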

Comment on lines +199 to +202
elif qk_norm == "rms_norm_across_heads":
    # LTX applies qk norm across all heads
    self.norm_q = RMSNorm(dim_head * heads, eps=eps)
    self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
a-r-r-o-w (Member Author)

@DN6 Should I follow your approach with Mochi and create a separate attention class for LTX?

Collaborator

OK, but we want to be more careful here; ideally we'd do that as part of a carefully planned-out refactor.
Maybe it would be safe to just inherit from Attention for now? E.g., we wrote code like this with the assumption in mind that we only have one attention class:
https://github.com/huggingface/diffusers/blob/e47cc1fc1a89a5375c322d296cd122fe71ab859f/src/diffusers/pipelines/pag/pag_utils.py#L57C39-L57C48

cc @DN6 here too
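
For context on the distinction being discussed, a rough sketch of per-head vs. across-heads qk normalization (shapes and names are illustrative, not the diffusers API; assumes a recent PyTorch with `nn.RMSNorm`):

```python
import torch
from torch import nn

batch, seq, heads, dim_head = 2, 16, 8, 64
q = torch.randn(batch, seq, heads * dim_head)  # projected queries, heads still fused

# Per-head variant: compute RMS statistics over each head's own dim_head slice.
per_head_norm = nn.RMSNorm(dim_head)
q_per_head = per_head_norm(q.view(batch, seq, heads, dim_head)).reshape(batch, seq, heads * dim_head)

# Across-heads variant (what the LTX checkpoint expects): compute the RMS statistic
# over the full concatenated heads * dim_head dimension in one normalization.
across_heads_norm = nn.RMSNorm(heads * dim_head)
q_across = across_heads_norm(q)
```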

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py (outdated review comment, resolved)
@@ -169,6 +170,12 @@ def _sigma_to_t(self, sigma):
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:

@stevhliu (Member) left a comment

Thanks for adding!

Outdated review comments (resolved):
docs/source/en/api/models/autoencoderkl_ltx.md
docs/source/en/api/pipelines/ltx.md
src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
src/diffusers/models/transformers/transformer_ltx.py
src/diffusers/models/transformers/transformer_ltx.py
src/diffusers/pipelines/ltx/pipeline_ltx.py
src/diffusers/pipelines/ltx/pipeline_ltx.py
src/diffusers/pipelines/ltx/pipeline_ltx.py
Comment on lines 395 to 398
hidden_states = hidden_states.reshape(
    batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
)
hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
Collaborator

Is there any reason we reshape from 5D -> 3D and then 3D -> 5D on every iteration?

Collaborator

I think it would be nice to:

  1. have `_pack_latents` and `_unpack_latents` methods like we did for Flux
  2. maybe move the rotary embedding into the pipeline so we only pack/unpack once (I know we currently have some discrepancy here, so open to discussion); or we could just configure the rotary pos embed class in the pipeline so we do not need to pass the shape info each time

a-r-r-o-w (Member Author)

Yes, I'll add the pack and unpack latent methods.

Regarding the RoPE, I think a separate layer approach is okay even though it requires recomputation at every step. This is because we're planning to work on caching hooks that would enable the outputs of any layer to be cached and reused. Since RoPE is an integral part of many models, we could add some opt-out logic so that caching is enabled by default on these kinds of model-specific RoPE layers. WDYT?
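
For illustration, a rough sketch of what packing/unpacking could look like, mirroring the reshape/permute in the snippet quoted above (function names and signatures here are illustrative and may not match the final PR):

```python
import torch

def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # (B, C, F, H, W) -> (B, F/p_t * H/p * W/p, C * p_t * p * p)
    batch_size, num_channels, num_frames, height, width = latents.shape
    latents = latents.reshape(
        batch_size, num_channels,
        num_frames // patch_size_t, patch_size_t,
        height // patch_size, patch_size,
        width // patch_size, patch_size,
    )
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)

def unpack_latents(
    latents: torch.Tensor, num_frames: int, height: int, width: int,
    patch_size: int = 1, patch_size_t: int = 1,
) -> torch.Tensor:
    # Inverse of pack_latents: (B, seq, dim) -> (B, C, F, H, W), where
    # num_frames / height / width are the post-patchification grid sizes.
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    return latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)

x = torch.randn(1, 128, 13, 16, 24)
packed = pack_latents(x)
assert torch.equal(unpack_latents(packed, 13, 16, 24), x)  # exact round trip
```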

Comment on lines +777 to +806
# ============= TODO(aryan): needs a look by YiYi
latents = latents.float()

noise_pred = self._unpack_latents(
    noise_pred,
    latent_num_frames,
    latent_height,
    latent_width,
    self.transformer_spatial_patch_size,
    self.transformer_temporal_patch_size,
)
latents = self._unpack_latents(
    latents,
    latent_num_frames,
    latent_height,
    latent_width,
    self.transformer_spatial_patch_size,
    self.transformer_temporal_patch_size,
)

noise_pred = noise_pred[:, :, 1:]
noise_latents = latents[:, :, 1:]
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]

latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
latents = self._pack_latents(
    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
latents = latents.to(dtype=latents_dtype)
# =============
a-r-r-o-w (Member Author)

@yiyixuxu They use per-latent-frame timesteps (actually, it's per-token timesteps, but all tokens corresponding to the same frame share the same timestep). Since we don't support that in our schedulers, we can't really do the normal scheduler.step(). These changes were required to make the pipeline at least generate reasonable results. The quality of the generations looks similar to me, but I will try to match it numerically.
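
To make the idea concrete, here is a purely hypothetical illustration of per-latent-frame timesteps (not a diffusers API): every token belonging to latent frame i shares that frame's timestep, and the conditioning frame stays clean.

```python
import torch

batch_size, num_latent_frames = 1, 21
current_t = 800.0  # the denoising timestep for the frames being generated

# One timestep per latent frame; the first frame is the clean image condition
# and is therefore kept at t = 0 for the whole denoising loop.
per_frame_t = torch.full((batch_size, num_latent_frames), current_t)
per_frame_t[:, 0] = 0.0
```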
