
Merge pull request #58 from Lightricks/feature/improve-inference
Inference: Support noise augmentation in image-to-video
yoavhacohen authored Dec 11, 2024
2 parents a01a171 + fe8ba4e commit 7c3e1b0
Showing 4 changed files with 74 additions and 31 deletions.
README.md: 4 changes (2 additions, 2 deletions)
@@ -76,13 +76,13 @@ To use our model, please follow the inference code in [inference.py](./inference
 #### For text-to-video generation:
 
 ```bash
-python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED
+python inference.py --ckpt_path 'PATH' --prompt "PROMPT" --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED
 ```
 
 #### For image-to-video generation:
 
 ```bash
-python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --input_image_path IMAGE_PATH --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED
+python inference.py --ckpt_path 'PATH' --prompt "PROMPT" --input_image_path IMAGE_PATH --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED
 ```
 
 ## ComfyUI Integration
inference.py: 69 changes (41 additions, 28 deletions)
@@ -8,7 +8,7 @@
 
 import imageio
 import numpy as np
-import safetensors.torch
+from safetensors import safe_open
 import torch
 import torch.nn.functional as F
 from PIL import Image
@@ -29,34 +29,33 @@
 MAX_NUM_FRAMES = 257
 
 
-def load_vae(vae_dir):
-    vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
-    vae_config_path = vae_dir / "config.json"
-    with open(vae_config_path, "r") as f:
-        vae_config = json.load(f)
+def load_vae(vae_config, ckpt):
     vae = CausalVideoAutoencoder.from_config(vae_config)
-    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
+    vae_state_dict = {
+        key.replace("vae.", ""): value
+        for key, value in ckpt.items()
+        if key.startswith("vae.")
+    }
     vae.load_state_dict(vae_state_dict)
     if torch.cuda.is_available():
         vae = vae.cuda()
     return vae.to(torch.bfloat16)
 
 
-def load_unet(unet_dir):
-    unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
-    unet_config_path = unet_dir / "config.json"
-    transformer_config = Transformer3DModel.load_config(unet_config_path)
+def load_transformer(transformer_config, ckpt):
     transformer = Transformer3DModel.from_config(transformer_config)
-    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
-    transformer.load_state_dict(unet_state_dict, strict=True)
+    transformer_state_dict = {
+        key.replace("model.diffusion_model.", ""): value
+        for key, value in ckpt.items()
+        if key.startswith("model.diffusion_model.")
+    }
+    transformer.load_state_dict(transformer_state_dict, strict=True)
     if torch.cuda.is_available():
         transformer = transformer.cuda()
     return transformer
 
 
-def load_scheduler(scheduler_dir):
-    scheduler_config_path = scheduler_dir / "scheduler_config.json"
-    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
+def load_scheduler(scheduler_config):
     return RectifiedFlowScheduler.from_config(scheduler_config)
 
 
@@ -168,10 +167,10 @@ def main():
 
     # Directories
     parser.add_argument(
-        "--ckpt_dir",
+        "--ckpt_path",
         type=str,
         required=True,
-        help="Path to the directory containing unet, vae, and scheduler subdirectories",
+        help="Path to a safetensors file that contains all model parts.",
     )
     parser.add_argument(
         "--input_video_path",
@@ -205,6 +204,12 @@
         default=3,
         help="Guidance scale for the pipeline",
     )
+    parser.add_argument(
+        "--image_cond_noise_scale",
+        type=float,
+        default=0.15,
+        help="Amount of noise to add to the conditioned image",
+    )
     parser.add_argument(
         "--height",
         type=int,
@@ -297,15 +302,22 @@
         media_items = None
 
     # Paths for the separate mode directories
-    ckpt_dir = Path(args.ckpt_dir)
-    unet_dir = ckpt_dir / "unet"
-    vae_dir = ckpt_dir / "vae"
-    scheduler_dir = ckpt_dir / "scheduler"
+    ckpt_path = Path(args.ckpt_path)
+    ckpt = {}
+    with safe_open(ckpt_path, framework="pt", device="cpu") as f:
+        metadata = f.metadata()
+        for k in f.keys():
+            ckpt[k] = f.get_tensor(k)
+
+    configs = json.loads(metadata["config"])
+    vae_config = configs["vae"]
+    transformer_config = configs["transformer"]
+    scheduler_config = configs["scheduler"]
 
     # Load models
-    vae = load_vae(vae_dir)
-    unet = load_unet(unet_dir)
-    scheduler = load_scheduler(scheduler_dir)
+    vae = load_vae(vae_config, ckpt)
+    transformer = load_transformer(transformer_config, ckpt)
+    scheduler = load_scheduler(scheduler_config)
     patchifier = SymmetricPatchifier(patch_size=1)
     text_encoder = T5EncoderModel.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
@@ -316,12 +328,12 @@
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
     )
 
-    if args.bfloat16 and unet.dtype != torch.bfloat16:
-        unet = unet.to(torch.bfloat16)
+    if args.bfloat16 and transformer.dtype != torch.bfloat16:
+        transformer = transformer.to(torch.bfloat16)
 
     # Use submodels for the pipeline
     submodel_dict = {
-        "transformer": unet,
+        "transformer": transformer,
         "patchifier": patchifier,
         "text_encoder": text_encoder,
         "tokenizer": tokenizer,
@@ -365,6 +377,7 @@
             if media_items is not None
             else ConditioningMethod.UNCONDITIONAL
         ),
+        image_cond_noise_scale=args.image_cond_noise_scale,
         mixed_precision=not args.bfloat16,
     ).images
 
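For reference, the single-file layout that `--ckpt_path` now expects can be assembled with `safetensors.torch.save_file`, packing the sub-model configs into the file metadata and prefixing the weight keys the same way `load_vae`/`load_transformer` strip them on load. The sketch below is illustrative only: the helper name `save_combined_checkpoint` and its argument list are not part of this commit, and the `*_config` values are assumed to be plain JSON-serializable dicts.

```python
# Illustrative sketch (not from the repo): build the combined checkpoint that
# inference.py reads back with safe_open(). Assumes `vae` and `transformer`
# are loaded torch modules and the *_config values are JSON-serializable dicts.
import json

import safetensors.torch


def save_combined_checkpoint(
    vae, transformer, vae_config, transformer_config, scheduler_config, out_path
):
    state_dict = {}
    # Key prefixes must match what load_vae / load_transformer strip on load.
    for key, value in vae.state_dict().items():
        state_dict[f"vae.{key}"] = value
    for key, value in transformer.state_dict().items():
        state_dict[f"model.diffusion_model.{key}"] = value

    # safetensors metadata values must be strings, so the configs travel as
    # one JSON blob under the "config" key that main() decodes with json.loads.
    metadata = {
        "config": json.dumps(
            {
                "vae": vae_config,
                "transformer": transformer_config,
                "scheduler": scheduler_config,
            }
        )
    }
    safetensors.torch.save_file(state_dict, out_path, metadata=metadata)
```

The exact set of weights and config fields in the released checkpoint may differ; the point is only the prefix-plus-metadata convention the new loader relies on.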
ltx_video/pipelines/pipeline_ltx_video.py: 31 changes (31 additions, 0 deletions)
@@ -655,6 +655,26 @@ def _clean_caption(self, caption):
 
         return caption.strip()
 
+    def image_cond_noise_update(
+        self,
+        t,
+        init_latents,
+        latents,
+        noise_scale,
+        conditiong_mask,
+        generator,
+    ):
+        noise = randn_tensor(
+            latents.shape,
+            generator=generator,
+            device=latents.device,
+            dtype=latents.dtype,
+        )
+        latents = (init_latents + noise_scale * noise * (t**2)) * conditiong_mask[
+            ..., None
+        ] + latents * (1 - conditiong_mask[..., None])
+        return latents
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(
         self,
@@ -897,6 +917,7 @@ def __call__(
         self.video_scale_factor = self.video_scale_factor if is_video else 1
         conditioning_method = kwargs.get("conditioning_method", None)
         vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
+        image_cond_noise_scale = kwargs.get("image_cond_noise_scale", 0.0)
         init_latents, conditioning_mask = self.prepare_conditioning(
             media_items,
             num_frames,
@@ -924,6 +945,7 @@
             latents=init_latents,
             latents_mask=conditioning_mask,
         )
+        orig_conditiong_mask = conditioning_mask
         if conditioning_mask is not None and is_video:
             assert num_images_per_prompt == 1
             conditioning_mask = (
@@ -954,6 +976,15 @@
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if conditioning_method == ConditioningMethod.FIRST_FRAME:
+                    latents = self.image_cond_noise_update(
+                        t,
+                        init_latents,
+                        latents,
+                        image_cond_noise_scale,
+                        orig_conditiong_mask,
+                        generator,
+                    )
                 latent_model_input = (
                     torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 )
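Taken together, these pipeline changes mean that under `ConditioningMethod.FIRST_FRAME` the conditioning latents are no longer held fixed during sampling: each iteration first calls `image_cond_noise_update`, which re-noises the conditioned positions before the denoising step. In the notation of the diff, with $m$ the conditioning mask, $z_{\text{init}}$ the clean image latents, $s$ the `image_cond_noise_scale`, and $\varepsilon \sim \mathcal{N}(0, I)$ fresh noise, the per-step update is effectively

$$z \leftarrow m \odot \left(z_{\text{init}} + s\, t^{2}\, \varepsilon\right) + (1 - m) \odot z .$$

Since the rectified-flow timesteps $t$ shrink toward zero over the loop, the $t^{2}$ factor injects the most extra noise into the conditioning frame early in sampling and lets it settle back toward the clean image latents by the end; setting $s = 0$ simply pins the conditioned positions to $z_{\text{init}}$ at every step.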
pyproject.toml: 1 change (0 additions, 1 deletion)
@@ -13,7 +13,6 @@ requires-python = ">=3.10"
 readme = "README.md"
 classifiers = [
     "Programming Language :: Python :: 3",
-    "License :: OSI Approved :: MIT License",
     "Operating System :: OS Independent"
 ]
 dependencies = [
