-
Notifications
You must be signed in to change notification settings - Fork 75
/
autoencode.py
55 lines (43 loc) · 1.38 KB
/
autoencode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# -*- coding: utf-8 -*-
import os
import os.path as osp
import re

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
MODEL_HUB = "BAAI/Emu3-VisionTokenizer"

# Load the vision tokenizer and its image processor from the Hugging Face Hub.
# ``trust_remote_code=True`` runs modelling code shipped with the checkpoint,
# so only use it with a repository you trust. ``.eval()`` puts the model in
# evaluation mode and ``.cuda()`` moves it to the default CUDA device.
model = AutoModel.from_pretrained(MODEL_HUB, trust_remote_code=True).eval().cuda()
processor = AutoImageProcessor.from_pretrained(MODEL_HUB, trust_remote_code=True)

# TODO: you need to modify the path here
# Directory holding the clip as one image file per frame.
VIDEO_FRAMES_PATH = "YOUR_VIDEO_FRAMES_PATH"


def _natural_key(name):
    """Sort key that orders embedded integers numerically ("2.png" < "10.png").

    Plain lexicographic sorting puts "frame_10.png" before "frame_2.png",
    which scrambles the temporal order of unpadded frame numbers.
    """
    # re.split with a capturing group alternates non-digit / digit runs, so
    # odd positions always hold digit strings and element types line up
    # between any two keys (no str-vs-int comparisons).
    parts = re.split(r"(\d+)", name)
    # The raw name breaks ties such as "01.png" vs "1.png" deterministically.
    return [int(p) if i % 2 else p for i, p in enumerate(parts)], name


def _load_frames(frames_dir):
    """Load every image file in *frames_dir*, in natural filename order.

    Directory entries that are not regular files, or whose extension Pillow
    does not recognise (e.g. ``.DS_Store``, notes, sub-folders), are skipped.
    Each frame's pixel data is read and its file handle closed immediately,
    so long clips do not keep one open file per frame.

    Raises:
        FileNotFoundError: if *frames_dir* contains no image files.
    """
    image_exts = Image.registered_extensions()
    names = sorted(
        (
            name
            for name in os.listdir(frames_dir)
            if osp.isfile(osp.join(frames_dir, name))
            and osp.splitext(name)[1].lower() in image_exts
        ),
        key=_natural_key,
    )
    if not names:
        raise FileNotFoundError(f"No image frames found in {frames_dir!r}")
    frames = []
    for name in names:
        with Image.open(osp.join(frames_dir, name)) as im:
            # copy() forces the pixel data to load before the file is closed.
            frames.append(im.copy())
    return frames


video = _load_frames(VIDEO_FRAMES_PATH)
# Preprocess all frames together. Presumably (num_frames, C, H, W) — TODO
# confirm against the Emu3 processor; later code indexes frames on dim 1
# after the unsqueeze below.
images = processor(video, return_tensors="pt")["pixel_values"]
# Add a leading batch dimension of size 1 and move the tensor to the GPU.
images = images.unsqueeze(0).cuda()
# image autoencode
# Round-trip only the first frame: images[:, 0] selects frame 0 of every
# batch entry along dim 1.
image = images[:, 0]
print(image.shape)

# Inference only: no autograd graph is needed for encode/decode.
with torch.no_grad():
    # encode
    codes = model.encode(image)

    # decode
    recon = model.decode(codes)

# Merge the first two dimensions of the decoder output so the processor sees
# a flat stack of images.
recon = recon.view(-1, *recon.shape[2:])
# postprocess returns the reconstructions (presumably PIL images, since they
# are saved below); keep the first one.
recon_image = processor.postprocess(recon)["pixel_values"][0]
recon_image.save("recon_image.png")
# video autoencode
# Regroup the frame axis (dim 1) into consecutive chunks of
# ``temporal_downsample_factor`` frames, one chunk per batch entry.
tdf = model.config.temporal_downsample_factor
num_frames = images.shape[1]
usable_frames = num_frames - num_frames % tdf
if usable_frames == 0:
    raise ValueError(
        f"Video autoencoding needs at least {tdf} frames "
        f"(temporal_downsample_factor), got {num_frames}"
    )
if usable_frames < num_frames:
    # A bare view() would fail here, so drop the incomplete trailing chunk.
    print(
        f"Dropping {num_frames - usable_frames} trailing frame(s) so the frame "
        f"count is a multiple of temporal_downsample_factor={tdf}"
    )
images = images[:, :usable_frames].reshape(-1, tdf, *images.shape[2:])
print(images.shape)

# Inference only: no autograd graph is needed for encode/decode.
with torch.no_grad():
    # encode
    codes = model.encode(images)

    # decode
    recon = model.decode(codes)

# Merge (chunk, frame-in-chunk) back into one flat sequence of frames.
recon = recon.view(-1, *recon.shape[2:])
recon_images = processor.postprocess(recon)["pixel_values"]
for idx, im in enumerate(recon_images):
    im.save(f"recon_video_{idx}.png")