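"""Auto-caption a folder of images for text-to-image fine-tuning.

Downloads pinned LLaVA-v1.5-13B and CLIP vision-tower weights from Replicate's
weights mirror, then uses the model to write a prompt-style caption for every
image, saved as a sibling .txt file next to each image.
"""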
import sys

# Make a local LLaVA checkout importable before pulling in the llava package.
if "LLaVA" not in sys.path:
    sys.path.append("LLaVA")

import os
import subprocess
import time
from pathlib import Path

import requests
import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

# URL for the weights mirror
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"

# Files to download from the weights mirror: each entry copies "files" from the
# "src" path on the mirror into the local "dest" directory.
weights = [
    {
        "dest": "liuhaotian/llava-v1.5-13b",
        # git commit hash from huggingface
        "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
        "files": [
            "config.json",
            "generation_config.json",
            "pytorch_model-00001-of-00003.bin",
            "pytorch_model-00002-of-00003.bin",
            "pytorch_model-00003-of-00003.bin",
            "pytorch_model.bin.index.json",
            "special_tokens_map.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ],
    },
    {
        "dest": "openai/clip-vit-large-patch14-336",
        "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
        "files": ["config.json", "preprocessor_config.json", "pytorch_model.bin"],
    },
]

PROMPT = """
Write a four sentence caption for this image. In the first sentence describe the style and type (painting, photo, etc) of the image. Describe in the remaining sentences the contents and composition of the image. Only use language that would be used to prompt a text to image model. Do not include usage. Comma separate keywords rather than using "or". Precise composition is important. Avoid phrases like "conveys a sense of" and "capturing the", just use the terms themselves.
Good examples are:
"Photo of an alien woman with a glowing halo standing on top of a mountain, wearing a white robe and silver mask in the futuristic style with futuristic design, sky background, soft lighting, dynamic pose, a sense of future technology, a science fiction movie scene rendered in the Unreal Engine."
"A scene from the cartoon series Masters of the Universe depicts Man-At-Arms wearing a gray helmet and gray armor with red gloves. He is holding an iron bar above his head while looking down on Orko, a pink blob character. Orko is sitting behind Man-At-Arms facing left on a chair. Both characters are standing near each other, with Orko inside a yellow chestplate over a blue shirt and black pants. The scene is drawn in the style of the Masters of the Universe cartoon series."
"An emoji, digital illustration, playful, whimsical. A cartoon zombie character with green skin and tattered clothes reaches forward with two hands, they have green skin, messy hair, an open mouth and gaping teeth, one eye is half closed."
""".strip()


def download_json(url: str, dest: Path):
    res = requests.get(url, allow_redirects=True)
    if res.status_code == 200 and res.content:
        with dest.open("wb") as f:
            f.write(res.content)
    else:
        print(f"Failed to download {url}. Status code: {res.status_code}")


def download_weights(baseurl: str, basedest: str, files: list[str]):
    base_dir = Path(basedest)
    start = time.time()
    print("downloading to: ", base_dir)
    base_dir.mkdir(parents=True, exist_ok=True)
    for f in files:
        dest = base_dir / f
        url = f"{REPLICATE_WEIGHTS_URL}/{baseurl}/{f}"
        if not dest.exists():
            print("downloading url: ", url)
            if dest.suffix == ".json":
                download_json(url, dest)
            else:
                # Large binaries go through pget, Replicate's parallel downloader.
                subprocess.check_call(["pget", url, str(dest)], close_fds=False)
    print("downloading took: ", time.time() - start)


class Captioner:
    def load_models(self):
        for weight in weights:
            download_weights(weight["src"], weight["dest"], weight["files"])
        disable_torch_init()
        self.tokenizer, self.model, self.image_processor, self.context_len = (
            load_pretrained_model(
                "liuhaotian/llava-v1.5-13b",
                model_name="llava-v1.5-13b",
                model_base=None,
                load_8bit=False,
                load_4bit=False,
            )
        )

    def iter_images_captions(self, image_folder: Path):
        """Yield (image_path, caption_path) pairs; captions are .txt files beside each image."""
        for root, _, files in os.walk(image_folder):
            for filename in files:
                if filename.lower().endswith(
                    (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp")
                ):
                    image_path = Path(root) / filename
                    caption_filename = image_path.stem + ".txt"
                    caption_path = image_path.parent / caption_filename
                    yield image_path, caption_path

    def all_images_are_captioned(self, image_folder: Path):
        for _, caption_path in self.iter_images_captions(image_folder):
            if not caption_path.exists():
                return False
        return True

    def caption_images(
        self, image_folder: Path, autocaption_prefix: str, autocaption_suffix: str
    ):
        for image_path, caption_path in self.iter_images_captions(image_folder):
            if caption_path.exists():
                print(f"{image_path.name} is already captioned")
            else:
                self.caption_image(
                    image_path, caption_path, autocaption_prefix, autocaption_suffix
                )

    def caption_image(
        self,
        image_path: Path,
        caption_path: Path,
        autocaption_prefix: str,
        autocaption_suffix: str,
    ):
        conv_mode = "llava_v1"
        conv = conv_templates[conv_mode].copy()
        image_data = Image.open(image_path).convert("RGB")
        image_tensor = (
            self.image_processor.preprocess(image_data, return_tensors="pt")[
                "pixel_values"
            ]
            .half()
            .cuda()
        )
        # just one turn, always prepend image token
        inp = DEFAULT_IMAGE_TOKEN + "\n"
        inp += PROMPT
        if autocaption_prefix:
            inp += f"\n\nYou must start the caption with '{autocaption_prefix}'. "
        if autocaption_suffix:
            inp += f"\n\nYou must end the caption with '{autocaption_suffix}'."
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = (
            tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)  # pyright: ignore
            .cuda()
        )
        with torch.inference_mode():
            output_ids = self.model.generate(
                inputs=input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=0.2,
                top_p=1.0,
                max_new_tokens=512,
                use_cache=True,
            )
        output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[
            0
        ].strip()
        print(f"Caption for {image_path}: {output}")
        caption_path.write_text(output)
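

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: assumes a CUDA GPU,
    # a local LLaVA checkout on sys.path, pget available on PATH, and a folder of
    # images to caption. The --prefix/--suffix flags here are hypothetical CLI
    # names wrapping the autocaption_prefix/autocaption_suffix parameters above.
    import argparse

    parser = argparse.ArgumentParser(
        description="Caption a folder of images with LLaVA-v1.5-13B"
    )
    parser.add_argument("image_folder", type=Path)
    parser.add_argument("--prefix", default="", help="text every caption must start with")
    parser.add_argument("--suffix", default="", help="text every caption must end with")
    args = parser.parse_args()

    captioner = Captioner()
    if captioner.all_images_are_captioned(args.image_folder):
        print("All images are already captioned")
    else:
        captioner.load_models()
        captioner.caption_images(args.image_folder, args.prefix, args.suffix)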