-
Notifications
You must be signed in to change notification settings - Fork 0
/
thumbgenie.py
168 lines (135 loc) · 6.28 KB
/
thumbgenie.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import torch
import numpy as np
import requests
from bs4 import BeautifulSoup
from transformers import DistilBertTokenizer, DistilBertModel
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from peft import get_peft_model, LoraConfig
from torchvision.transforms import ToTensor, Resize
from torchvision.utils import save_image
# --- Fixed settings --------------------------------------------------------
# Upper bound applied to the typed title and used as the tokenizer's
# padding/truncation length.
TITLE_MAX_LENGTH = 128
# How many times the category word is repeated in front of the title
# (converted with int() where it is used).
CATEGORY_EMBEDDING_WEIGHT = 2.0
# MediaFire landing page that is scraped for the default checkpoint link.
MEDIAFIRE_URL = "https://www.mediafire.com/file/hxj2h3gn0y6ibtz/unet_final-001.pt/file"
CACHE_DIR = "./.cache"
GENERATED_IMAGE_DIR = "./generated_images"
CATEGORIES = [
    "science",
    "news",
    "food",
    "blog",
    "tech",
    "informative",
    "comedy",
    "entertainment",
    "automobile",
    "videogames",
]
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- User-tunable settings -------------------------------------------------
# (height, width) in pixels.
IMAGE_RESOLUTION = (720, 1280)
# Path to a UNet checkpoint; None selects the default checkpoint location.
FINAL_MODEL_PATH = None

# Both working directories must exist before anything is written to them.
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(GENERATED_IMAGE_DIR, exist_ok=True)
def get_mediafire_direct_link(mediafire_url: str, timeout: float = 30.0) -> str:
    """Scrape a MediaFire file page and return its direct download URL.

    Args:
        mediafire_url: URL of the MediaFire landing page for the file.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The ``href`` of the anchor whose id is ``downloadButton``.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
        Exception: If the page has no download button carrying a link.
    """
    # Without a timeout a stalled connection would hang the script forever.
    response = requests.get(mediafire_url, timeout=timeout)
    # Surface 4xx/5xx directly instead of parsing an error page.
    response.raise_for_status()

    soup = BeautifulSoup(response.content, 'html.parser')
    download_link = soup.find('a', {'id': 'downloadButton'})
    # .get() so an anchor without href is reported like a missing anchor
    # rather than as a bare KeyError.
    href = download_link.get('href') if download_link else None
    if not href:
        raise Exception("Could not find the download link on the page")
    return href
def download_file(url: str, save_path: str, timeout: float = 30.0) -> None:
    """Stream the resource at ``url`` to ``save_path``.

    The data is first written to ``<save_path>.part`` and moved into place
    only after the transfer finished, so an interrupted download never leaves
    a truncated file at ``save_path``.

    Args:
        url: Direct download URL.
        save_path: Destination file path; its directory must already exist.
        timeout: Seconds to wait for the server (connect / between reads).

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
        OSError: If the file cannot be written or moved into place.
    """
    part_path = f"{save_path}.part"
    # The context manager releases the connection even when writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Refuse to store an HTML error page as if it were the payload.
        response.raise_for_status()
        with open(part_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty keep-alive chunks
                    file.write(chunk)
    # Publish the file only once it is complete; a stale .part file from an
    # aborted run is simply overwritten by the next attempt.
    os.replace(part_path, save_path)
# Resolve the checkpoint location at import time.
#   * An explicitly configured path must already exist.
#   * With no path configured, fall back to the default location and fetch
#     the default checkpoint if it is not on disk yet.
if FINAL_MODEL_PATH is not None:
    if not os.path.exists(FINAL_MODEL_PATH):
        raise FileNotFoundError(f"Model file not found at {FINAL_MODEL_PATH}")
else:
    FINAL_MODEL_PATH = "./final_model/unet_final-001.pt"
    if not os.path.exists(FINAL_MODEL_PATH):
        print("Final Model Path set as 'None'... Downloading default model...")
        os.makedirs(os.path.dirname(FINAL_MODEL_PATH), exist_ok=True)
        # The MediaFire page only links to the file; scrape the real URL first.
        direct_link = get_mediafire_direct_link(MEDIAFIRE_URL)
        download_file(direct_link, FINAL_MODEL_PATH)
        print(f"Default model has been downloaded and saved to {FINAL_MODEL_PATH}")
def preprocess_text(title, category):
    """Encode a category-prefixed title with DistilBERT.

    The category word is repeated ``int(CATEGORY_EMBEDDING_WEIGHT)`` times and
    placed in front of the title before tokenization. The tokenizer pads and
    truncates to ``TITLE_MAX_LENGTH`` tokens.

    Args:
        title: Title text.
        category: Category word to emphasise.

    Returns:
        The encoder's ``last_hidden_state`` with dimension 0 squeezed out
        (presumably ``(TITLE_MAX_LENGTH, hidden_size)`` -- TODO confirm).
    """
    checkpoint = 'distilbert-base-uncased'
    # NOTE: tokenizer and encoder are loaded again on every call.
    tokenizer = DistilBertTokenizer.from_pretrained(checkpoint, cache_dir=CACHE_DIR)
    text_encoder = DistilBertModel.from_pretrained(checkpoint, cache_dir=CACHE_DIR)
    text_encoder = text_encoder.to(DEVICE)

    repeat_count = int(CATEGORY_EMBEDDING_WEIGHT)
    category_prefix = ' '.join(category for _ in range(repeat_count))

    encoded = tokenizer(
        f"{category_prefix} {title}",
        return_tensors="pt",
        max_length=TITLE_MAX_LENGTH,
        padding="max_length",
        truncation=True,
    )
    # Move every tokenizer output onto the same device as the encoder.
    device_inputs = {}
    for name, tensor in encoded.items():
        device_inputs[name] = tensor.to(DEVICE)

    with torch.no_grad():
        hidden_states = text_encoder(**device_inputs).last_hidden_state
    return hidden_states.squeeze(0)
def load_model() -> StableDiffusionPipeline:
    """Build the Stable Diffusion pipeline with the fine-tuned LoRA UNet.

    Loads ``CompVis/stable-diffusion-v1-4`` in float32, replaces its scheduler
    with a ``DPMSolverMultistepScheduler`` built from the same config, moves
    the pipeline to ``DEVICE``, wraps the UNet with LoRA adapters on the
    ``to_q``/``to_v`` modules, loads the weights stored at
    ``FINAL_MODEL_PATH``, and disables the safety checker.

    Returns:
        The pipeline, with its UNet in eval mode.
    """
    model_id = "CompVis/stable-diffusion-v1-4"
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(DEVICE)

    # The adapter layout has to match the one the checkpoint was saved with,
    # otherwise load_state_dict below rejects the keys.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["to_q", "to_v"],
        lora_dropout=0.05,
        bias="none",
    )
    pipe.unet = get_peft_model(pipe.unet, lora_config)

    # SECURITY: the checkpoint may have been downloaded from a file-sharing
    # site, and torch.load unpickles it. weights_only=True restricts
    # unpickling to tensors and plain containers, which is all a state dict
    # needs, so a tampered file cannot run arbitrary code on load.
    state_dict = torch.load(FINAL_MODEL_PATH, map_location=DEVICE, weights_only=True)
    pipe.unet.load_state_dict(state_dict)
    pipe.unet.eval()
    pipe.safety_checker = None
    return pipe
def generate_images(pipe: StableDiffusionPipeline, prompt, num_images: int = 1, generator_seed: int = 42) -> list:
    """Run the pipeline for ``prompt`` and return the generated images.

    Args:
        pipe: Pipeline to call.
        prompt: Text prompt passed to the pipeline unchanged.
        num_images: Number of images to generate for the prompt.
        generator_seed: Seed for the torch generator created on ``DEVICE``;
            every call re-seeds it.

    Returns:
        The ``images`` attribute of the pipeline output. With the pipeline's
        default output type this is a list of PIL images, not a NumPy array.
    """
    generator = torch.Generator(device=DEVICE).manual_seed(generator_seed)
    height, width = IMAGE_RESOLUTION
    with torch.no_grad():
        output = pipe(
            prompt,
            num_images_per_prompt=num_images,
            generator=generator,
            guidance_scale=7.5,
            height=height,
            width=width,
        )
    return output.images
def save_generated_images(images: list, save_dir: str) -> None:
    """Resize each image to ``IMAGE_RESOLUTION`` and save it as a PNG.

    Files are named ``generated_image_<n>.png`` with ``n`` starting at 1 on
    every call, so images from an earlier call in the same directory are
    overwritten.

    Args:
        images: Images accepted by torchvision's ``ToTensor`` (for example
            the PIL images returned by ``generate_images``).
        save_dir: Directory to write into; created if it does not exist.
    """
    # Callers may pass a directory other than the one created at import time.
    os.makedirs(save_dir, exist_ok=True)

    # The transforms do not depend on the image, so build them once.
    to_tensor = ToTensor()
    resize = Resize((IMAGE_RESOLUTION[0], IMAGE_RESOLUTION[1]))

    for i, image in enumerate(images):
        # Add a batch dimension for the resize and drop it again afterwards.
        image_tensor = to_tensor(image).unsqueeze(0)
        image_resized = resize(image_tensor).squeeze(0)
        image_path = os.path.join(save_dir, f"generated_image_{i+1}.png")
        save_image(image_resized, image_path)
        print(f"Saved generated image: {image_path}")
def get_user_input():
    """Interactively ask for a title, one or two categories and an image count.

    The whole questionnaire restarts whenever an answer is invalid. Typing
    ``q`` at any prompt aborts.

    Returns:
        ``(title, categories, batch_size)`` where ``categories`` is either a
        single category or two distinct categories joined by a space and
        ``batch_size`` is a positive int, or ``(None, None, None)`` when the
        user quit.
    """
    while True:
        print("Enter 'q' to quit at any time.")
        title = input(f"Enter title (max {TITLE_MAX_LENGTH} characters): ").strip()
        if title == "q":
            return None, None, None
        if len(title) > TITLE_MAX_LENGTH:
            print(f"Title is too long. Please enter a title with at most {TITLE_MAX_LENGTH} characters.")
            continue

        print("Categories: " + ", ".join(CATEGORIES))
        category1 = input("Select first category: ").strip().lower()
        if category1 == "q":
            return None, None, None
        category2 = input("Select second category (or press Enter to use the same category): ").strip().lower()
        if category2 == "q":
            return None, None, None
        if category1 not in CATEGORIES or (category2 and category2 not in CATEGORIES):
            print("Invalid category selection. Please choose from the provided list.")
            continue
        if not category2:
            category2 = category1
        # Avoid repeating the word when both selections are the same.
        categories = f"{category1} {category2}" if category1 != category2 else category1

        batch_size = input("Enter number of images to be generated (default is 1): ").strip()
        if batch_size == "q":
            return None, None, None
        if not batch_size:
            # Pressing Enter selects the advertised default. The previous
            # check compared the string against the int 0 and never matched.
            batch_size = "1"
        # isdecimal() accepts only characters int() can parse; isdigit()
        # also admits e.g. superscript digits, which make int() raise.
        if not batch_size.isdecimal():
            print("Invalid input. Please enter a valid number.")
            continue
        # A count of zero falls back to the default of one image.
        return title, categories, max(int(batch_size), 1)
if __name__ == "__main__":
    # Interactive loop: ask for a request, generate, save, repeat until the
    # user quits.
    # The pipeline is built once, on the first request, and reused afterwards;
    # rebuilding it on every iteration reloaded the whole model each time.
    pipe = None
    while True:
        title, categories, batch_size = get_user_input()
        if title is None and categories is None:
            print("Exiting...")
            break
        prompt = f"{categories} {title}"
        if pipe is None:
            pipe = load_model()
        images = generate_images(pipe, prompt, num_images=batch_size)
        save_generated_images(images, GENERATED_IMAGE_DIR)