Commit

add paper

heheyas committed Mar 10, 2024
1 parent b2bd10b commit cb3369d
Showing 4 changed files with 249 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -35,8 +35,8 @@ spirals
paper/
spirals_co3d/
scene_spirals/
assets/
blenders/
colmap_results/
depth_spirals/
recon/SIBR_viewers/
recon/assets/
247 changes: 247 additions & 0 deletions app.py
@@ -0,0 +1,247 @@
# Gradio demo: turn a single input image into an orbit video with V3D.
import numpy as np
import argparse
import torch
import imageio  # assumed available for writing the sampled frames to an mp4
import gradio as gr
from omegaconf import OmegaConf
from einops import rearrange, repeat
from scripts.pub.V3D_512 import (
    get_batch,
    get_unique_embedder_keys_from_conditioner,
    load_model,
)
from sgm.util import instantiate_from_config
from safetensors.torch import load_file as load_safetensors
from PIL import Image
from kiui.op import recenter
from torchvision.transforms import ToTensor
import rembg
import os
from glob import glob


def do_sample(
    image,
    model,
    clip_model,
    ae_model,
    device,
    num_frames,
    num_steps,
    decoding_t,
    border_ratio,
    ignore_alpha,
    rembg_session,
    output_folder,
):
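    """Sample an orbit video from a single input image and write it to an mp4.

    Flow: optional background removal and recentering, CLIP/VAE encoding of
    the image for conditioning, video diffusion sampling, VAE decoding, and
    finally writing the frames to disk.
    """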
    # Gradio passes the image as a numpy array; convert to PIL first.
    image = Image.fromarray(image)
    w, h = image.size

    if border_ratio > 0:
        if image.mode != "RGBA" or ignore_alpha:
            image = image.convert("RGB")
            image = np.asarray(image)
            # Segment the foreground; the result carries an alpha mask: [H, W, 4]
            carved_image = rembg.remove(image, session=rembg_session)
        else:
            image = np.asarray(image)
            carved_image = image
        mask = carved_image[..., -1] > 0
        # Recenter the object, leaving the requested border around it.
        image = recenter(carved_image, mask, border_ratio=border_ratio)
        image = image.astype(np.float32) / 255.0
        if image.shape[-1] == 4:
            # Composite onto a white background using the alpha channel.
            image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
        image = Image.fromarray((image * 255).astype(np.uint8))
    else:
        print("Ignoring border ratio")
    # Resize in both paths (assumed) so the latent grid matches the 512-px model.
    image = image.resize((512, 512))

    image = ToTensor()(image)
    image = image * 2.0 - 1.0

    image = image.unsqueeze(0).to(device)
    H, W = image.shape[2:]
    assert image.shape[1] == 3
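    # Latent geometry: the SVD VAE downsamples by 8x spatially and uses 4
    # latent channels, so we sample num_frames latents of shape (4, H/8, W/8).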
    F = 8
    C = 4
    shape = (num_frames, C, H // F, W // F)

    value_dict = {}
    value_dict["motion_bucket_id"] = 0
    value_dict["fps_id"] = 0
    value_dict["cond_aug"] = 0.05
    # The CLIP embedding conditions cross-attention; the VAE latent (with a
    # small amount of noise matching cond_aug) conditions via concatenation.
    value_dict["cond_frames_without_noise"] = clip_model(image)
    value_dict["cond_frames"] = ae_model.encode(image)
    value_dict["cond_frames"] += 0.05 * torch.randn_like(value_dict["cond_frames"])

    with torch.no_grad():
        with torch.autocast(device):
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1, num_frames],
                T=num_frames,
                device=device,
            )
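            # Get conditional and unconditional branches for classifier-free
            # guidance; the image conditionings are zeroed out in the latter.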
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=[
                    "cond_frames",
                    "cond_frames_without_noise",
                ],
            )

            # Tile the per-image conditioning across all video frames.
            for k in ["crossattn", "concat"]:
                uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            randn = torch.randn(shape, device=device)

            additional_model_inputs = {}
            # One indicator row per CFG branch (conditional + unconditional).
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                2, num_frames
            ).to(device)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model, input, sigma, c, **additional_model_inputs
                )

            samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
            # Decode at most decoding_t latents at a time to bound memory use.
            model.en_and_decode_n_samples_a_time = decoding_t
            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

    os.makedirs(output_folder, exist_ok=True)
    base_count = len(glob(os.path.join(output_folder, "*.mp4")))
    video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")

    frames = (
        (rearrange(samples, "t c h w -> t h w c") * 255)
        .cpu()
        .numpy()
        .astype(np.uint8)
    )
    # Write the orbit video and return its path for the gr.Video output
    # (the fps value here is an assumption).
    imageio.mimwrite(video_path, frames, fps=7)

    return video_path


def change_model_params(model, min_cfg, max_cfg):
    # Assumes the loaded model exposes its CFG guider as model.sampler.guider.
    model.sampler.guider.max_scale = max_cfg
    model.sampler.guider.min_scale = min_cfg


def launch(device="cuda", port=4321, share=False):
    model_config = "scripts/pub/configs/V3D_512.yaml"
    num_frames = OmegaConf.load(
        model_config
    ).model.params.sampler_config.params.guider_config.params.num_frames
    print("Detected num_frames:", num_frames)
    # Fixed demo settings; decoding at most 14 frames at once is an assumption.
    num_steps = 25
    output_folder = "outputs/V3D_512"
    decoding_t = min(num_frames, 14)

    # The SVD-XT checkpoint bundles the CLIP image embedder and the VAE;
    # extract each submodule's weights by key prefix.
    sd = load_safetensors("./ckpts/svd_xt.safetensors")
    clip_model_config = OmegaConf.load("configs/embedder/clip_image.yaml")
    clip_model = instantiate_from_config(clip_model_config).eval()
    clip_sd = dict()
    for k, v in sd.items():
        if "conditioner.embedders.0" in k:
            clip_sd[k.replace("conditioner.embedders.0.", "")] = v
    clip_model.load_state_dict(clip_sd)
    clip_model = clip_model.to(device)

    ae_model_config = OmegaConf.load("configs/ae/video.yaml")
    ae_model = instantiate_from_config(ae_model_config).eval()
    encoder_sd = dict()
    for k, v in sd.items():
        if "first_stage_model" in k:
            encoder_sd[k.replace("first_stage_model.", "")] = v
    ae_model.load_state_dict(encoder_sd)
    ae_model = ae_model.to(device)
    rembg_session = rembg.new_session()

    # 3.5 is the initial min/max CFG scale; change_model_params overrides it
    # per run from the UI sliders.
    model = load_model(model_config, device, num_frames, num_steps, 3.5, 3.5)

    with gr.Blocks(title="V3D", theme=gr.themes.Monochrome()) as demo:
        with gr.Row(equal_height=True):
            with gr.Column():
                input_image = gr.Image(value=None, label="Input Image")

                border_ratio_slider = gr.Slider(
                    value=0.05,
                    label="Border Ratio",
                    minimum=0.05,
                    maximum=0.5,
                    step=0.05,
                )
                # CFG slider ranges are assumptions, centered on the 3.5
                # scale passed to load_model above.
                min_guidance_slider = gr.Slider(
                    value=3.5,
                    label="Min CFG Value",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                )
                max_guidance_slider = gr.Slider(
                    value=3.5,
                    label="Max CFG Value",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                )
                run_button = gr.Button(value="Run V3D")

            with gr.Column():
                output_video = gr.Video(value=None, label="Output Orbit Video")

        @run_button.click(
            inputs=[
                input_image,
                border_ratio_slider,
                min_guidance_slider,
                max_guidance_slider,
            ],
            outputs=[output_video],
        )
        def _(image, border_ratio, min_guidance, max_guidance):
            change_model_params(model, min_guidance, max_guidance)
            # do_sample writes the mp4 and returns its path for gr.Video.
            return do_sample(
                image,
                model,
                clip_model,
                ae_model,
                device,
                num_frames,
                num_steps,
                decoding_t,
                border_ratio,
                False,  # ignore_alpha
                rembg_session,
                output_folder,
            )

    demo.launch(
        inbrowser=True, inline=False, server_port=port, share=share, show_error=True
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=4321)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--share", action="store_true")

    opt = parser.parse_args()

    launch(opt.device, opt.port, opt.share)
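# Example invocation: python app.py --device cuda --port 4321 --share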
Binary file added assets/pdf/V3D.pdf
2 changes: 1 addition & 1 deletion readme.md
@@ -6,7 +6,7 @@ Tsinghua University, ShengShu
This repository contains the official implementation of [V3D: Video Diffusion Models are Effective 3D Generators](404).


-### [Paper](TBD) | [Project Page](TBD) | [HF Demo](TBD)
+### [Paper](assets/pdf/V3D.pdf) | [Project Page](TBD) | [HF Demo](TBD)

### Video results
Single Image to 3D
