Skip to content

Commit

Permalink
Merge pull request #4 from shonenkov-AI/feature/up_version
Browse files — browse the repository at this point in the history
up version
  • Loading branch information
shonenkov authored Jun 27, 2022
2 parents c59820c + 506476b commit 10fb15d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Generate any arbitrary aspect ratio images using the ruDALLE models
### Installing

```
- pip install rudalle==1.0.0
+ pip install rudalle==1.1.1
git clone https://github.com/shonenkov-AI/rudalle-aspect-ratio
```

Expand All @@ -31,7 +31,7 @@ rudalle_ar = RuDalleAspectRatio(
dalle=dalle, vae=vae, tokenizer=tokenizer,
aspect_ratio=32/9, bs=4, device=device
)
- _, result_pil_images = rudalle_ar.generate_images('готический квартал', 1024, 0.975, 4)
+ _, result_pil_images = rudalle_ar.generate_images('готический квартал', 768, 0.99, 4)
show(result_pil_images, 1)
```
![](./pics/h_example.jpg)
Expand All @@ -42,7 +42,7 @@ rudalle_ar = RuDalleAspectRatio(
dalle=dalle, vae=vae, tokenizer=tokenizer,
aspect_ratio=9/32, bs=4, device=device
)
- _, result_pil_images = rudalle_ar.generate_images('голубой цветок', 512, 0.975, 4)
+ _, result_pil_images = rudalle_ar.generate_images('голубой цветок', 768, 0.99, 4)
show(result_pil_images, 4)
```

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
- rudalle==1.0.0
+ rudalle==1.1.1
12 changes: 6 additions & 6 deletions rudalle_aspect_ratio/aspect_ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def generate_w_codebooks(self, text, top_k, top_p, images_num, image_prompts=Non
torch.ones((chunk_bs, 1, self.total_seq_length, self.total_seq_length), device=self.device)
)
out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(self.device)
has_cache = False
cache = {}
if image_prompts is not None:
prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
range_out = range(out.shape[1], self.total_seq_length)
Expand All @@ -97,8 +97,8 @@ def generate_w_codebooks(self, text, top_k, top_p, images_num, image_prompts=Non
if image_prompts is not None and idx in prompts_idx:
out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1)
else:
logits, has_cache = self.dalle(out, attention_mask,
has_cache=has_cache, use_cache=use_cache, return_loss=False)
logits, cache = self.dalle(out, attention_mask,
cache=cache, use_cache=use_cache, return_loss=False)
logits = logits[:, -1, self.vocab_size:]
logits /= temperature
filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
Expand Down Expand Up @@ -146,10 +146,10 @@ def generate_h_codebooks(self, text, top_k, top_p, images_num, temperature=1.0,
full_context[:, self.text_seq_length:][:, -j * self.image_tokens_per_dim:]
), dim=-1)

has_cache = False
cache = {}
for _ in range(self.image_tokens_per_dim):
logits, has_cache = self.dalle(out, attention_mask,
has_cache=has_cache, use_cache=use_cache, return_loss=False)
logits, cache = self.dalle(out, attention_mask,
cache=cache, use_cache=use_cache, return_loss=False)
logits = logits[:, -1, self.vocab_size:]
logits /= temperature
filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
Expand Down
11 changes: 6 additions & 5 deletions rudalle_aspect_ratio/image_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,23 @@ def _get_image_prompts(self, img, borders, vae, crop_first):
vqg_img = torch.zeros((bs, vqg_img_h, vqg_img_w), dtype=torch.int32, device=img.device)
if borders['down'] != 0:
down_border = borders['down'] * 8
_, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :])
_, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :], disable_gumbel_softmax=True)
vqg_img[:, -borders['down']:, :] = down_vqg_img
if borders['right'] != 0:
right_border = borders['right'] * 8
_, _, [_, _, right_vqg_img] = vae.model.encode(img[:, :, :, -right_border:])
_, _, [_, _, right_vqg_img] = vae.model.encode(
img[:, :, :, -right_border:], disable_gumbel_softmax=True)
vqg_img[:, :, -borders['right']:] = right_vqg_img
if borders['left'] != 0:
left_border = borders['left'] * 8
_, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, :left_border])
_, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, :left_border], disable_gumbel_softmax=True)
vqg_img[:, :, :borders['left']] = left_vqg_img
if borders['up'] != 0:
up_border = borders['up'] * 8
_, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :])
_, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :], disable_gumbel_softmax=True)
vqg_img[:, :borders['up'], :] = up_vqg_img
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)
_, _, [_, _, vqg_img] = vae.model.encode(img, disable_gumbel_softmax=True)

bs, vqg_img_h, vqg_img_w = vqg_img.shape
mask = torch.zeros(vqg_img_h, vqg_img_w)
Expand Down

0 comments on commit 10fb15d

Please sign in to comment.