From 506476bfa2da72b046442241db0488be22a6de0a Mon Sep 17 00:00:00 2001
From: shonenkov
Date: Mon, 27 Jun 2022 01:31:57 +0400
Subject: [PATCH] up version

---
 README.md                             |  6 +++---
 requirements.txt                      |  2 +-
 rudalle_aspect_ratio/aspect_ratio.py  | 12 ++++++------
 rudalle_aspect_ratio/image_prompts.py | 11 ++++++-----
 4 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 9848cd4..2c7bc61 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ Generate any arbitrary aspect ratio images using the ruDALLE models
 ### Installing
 
 ```
-pip install rudalle==1.0.0
+pip install rudalle==1.1.1
 git clone https://github.com/shonenkov-AI/rudalle-aspect-ratio
 ```
 
@@ -31,7 +31,7 @@ rudalle_ar = RuDalleAspectRatio(
     dalle=dalle, vae=vae, tokenizer=tokenizer, aspect_ratio=32/9, bs=4, device=device
 )
-_, result_pil_images = rudalle_ar.generate_images('готический квартал', 1024, 0.975, 4)
+_, result_pil_images = rudalle_ar.generate_images('готический квартал', 768, 0.99, 4)
 show(result_pil_images, 1)
 ```
 
 ![](./pics/h_example.jpg)
@@ -42,7 +42,7 @@ rudalle_ar = RuDalleAspectRatio(
     dalle=dalle, vae=vae, tokenizer=tokenizer, aspect_ratio=9/32, bs=4, device=device
 )
-_, result_pil_images = rudalle_ar.generate_images('голубой цветок', 512, 0.975, 4)
+_, result_pil_images = rudalle_ar.generate_images('голубой цветок', 768, 0.99, 4)
 show(result_pil_images, 4)
 ```
 
diff --git a/requirements.txt b/requirements.txt
index 199b651..0259e51 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-rudalle==1.0.0
\ No newline at end of file
+rudalle==1.1.1
\ No newline at end of file
diff --git a/rudalle_aspect_ratio/aspect_ratio.py b/rudalle_aspect_ratio/aspect_ratio.py
index d4feda4..b947e20 100644
--- a/rudalle_aspect_ratio/aspect_ratio.py
+++ b/rudalle_aspect_ratio/aspect_ratio.py
@@ -86,7 +86,7 @@ def generate_w_codebooks(self, text, top_k, top_p, images_num, image_prompts=Non
             torch.ones((chunk_bs, 1, self.total_seq_length, self.total_seq_length), device=self.device)
         )
         out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(self.device)
-        has_cache = False
+        cache = {}
         if image_prompts is not None:
             prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
         range_out = range(out.shape[1], self.total_seq_length)
@@ -97,8 +97,8 @@
             if image_prompts is not None and idx in prompts_idx:
                 out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1)
             else:
-                logits, has_cache = self.dalle(out, attention_mask,
-                                               has_cache=has_cache, use_cache=use_cache, return_loss=False)
+                logits, cache = self.dalle(out, attention_mask,
+                                           cache=cache, use_cache=use_cache, return_loss=False)
                 logits = logits[:, -1, self.vocab_size:]
                 logits /= temperature
                 filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
@@ -146,10 +146,10 @@ def generate_h_codebooks(self, text, top_k, top_p, images_num, temperature=1.0,
                 full_context[:, self.text_seq_length:][:, -j * self.image_tokens_per_dim:]
             ), dim=-1)
 
-            has_cache = False
+            cache = {}
             for _ in range(self.image_tokens_per_dim):
-                logits, has_cache = self.dalle(out, attention_mask,
-                                               has_cache=has_cache, use_cache=use_cache, return_loss=False)
+                logits, cache = self.dalle(out, attention_mask,
+                                           cache=cache, use_cache=use_cache, return_loss=False)
                 logits = logits[:, -1, self.vocab_size:]
                 logits /= temperature
                 filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
diff --git a/rudalle_aspect_ratio/image_prompts.py b/rudalle_aspect_ratio/image_prompts.py
index b4a9ff7..7cbda20 100644
--- a/rudalle_aspect_ratio/image_prompts.py
+++ b/rudalle_aspect_ratio/image_prompts.py
@@ -24,22 +24,23 @@ def _get_image_prompts(self, img, borders, vae, crop_first):
             vqg_img = torch.zeros((bs, vqg_img_h, vqg_img_w), dtype=torch.int32, device=img.device)
             if borders['down'] != 0:
                 down_border = borders['down'] * 8
-                _, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :])
+                _, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :], disable_gumbel_softmax=True)
                 vqg_img[:, -borders['down']:, :] = down_vqg_img
             if borders['right'] != 0:
                 right_border = borders['right'] * 8
-                _, _, [_, _, right_vqg_img] = vae.model.encode(img[:, :, :, -right_border:])
+                _, _, [_, _, right_vqg_img] = vae.model.encode(
+                    img[:, :, :, -right_border:], disable_gumbel_softmax=True)
                 vqg_img[:, :, -borders['right']:] = right_vqg_img
             if borders['left'] != 0:
                 left_border = borders['left'] * 8
-                _, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, :left_border])
+                _, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, :left_border], disable_gumbel_softmax=True)
                 vqg_img[:, :, :borders['left']] = left_vqg_img
             if borders['up'] != 0:
                 up_border = borders['up'] * 8
-                _, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :])
+                _, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :], disable_gumbel_softmax=True)
                 vqg_img[:, :borders['up'], :] = up_vqg_img
         else:
-            _, _, [_, _, vqg_img] = vae.model.encode(img)
+            _, _, [_, _, vqg_img] = vae.model.encode(img, disable_gumbel_softmax=True)
 
         bs, vqg_img_h, vqg_img_w = vqg_img.shape
         mask = torch.zeros(vqg_img_h, vqg_img_w)
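
For context, the README snippets above call `generate_images(text, top_k, top_p, images_num)`, mirroring the `generate_w_codebooks` signature in aspect_ratio.py; the patch tightens the sampling parameters to top_k=768 and top_p=0.99. A minimal end-to-end sketch, assuming the public rudalle loaders (`get_rudalle_model`, `get_vae`, `get_tokenizer`) and an available CUDA device:

```python
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle_aspect_ratio import RuDalleAspectRatio

device = 'cuda'  # assumption: a CUDA device is available
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
vae, tokenizer = get_vae().to(device), get_tokenizer()

# Wide (32:9) generation with the updated README parameters:
# positional args are (text, top_k=768, top_p=0.99, images_num=4).
rudalle_ar = RuDalleAspectRatio(
    dalle=dalle, vae=vae, tokenizer=tokenizer, aspect_ratio=32/9, bs=4, device=device
)
_, result_pil_images = rudalle_ar.generate_images('готический квартал', 768, 0.99, 4)
```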
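The aspect_ratio.py hunks track rudalle 1.1.1's changed forward signature: the transformer now threads a `cache` dict (empty on the first step) instead of a `has_cache` bool, and returns the updated dict alongside the logits. A minimal sketch of the resulting sampling loop, assuming the standard ruDALLE continuation (softmax, multinomial, concatenate) after the filtering step shown in the patch; the function name and argument bundle are hypothetical:

```python
import torch
import torch.nn.functional as F
import transformers


def sample_image_tokens(dalle, out, attention_mask, total_seq_length, vocab_size,
                        top_k=768, top_p=0.99, temperature=1.0):
    """Sketch of the rudalle==1.1.1 loop: thread a cache dict through the
    forward pass instead of a has_cache flag."""
    cache = {}  # empty on the first step; the model fills it with key/value states
    for _ in range(out.shape[1], total_seq_length):
        logits, cache = dalle(out, attention_mask,
                              cache=cache, use_cache=True, return_loss=False)
        logits = logits[:, -1, vocab_size:]  # image-token logits at the last position
        logits /= temperature
        filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
        probs = F.softmax(filtered_logits, dim=-1)
        sample = torch.multinomial(probs, 1)  # one codebook token per batch item
        out = torch.cat((out, sample), dim=-1)
    return out
```

Because the dict is returned and passed back in each iteration, resetting it (`cache = {}`) before a new strip, as `generate_h_codebooks` does above, is enough to discard stale key/value state.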
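The image_prompts.py hunks add `disable_gumbel_softmax=True` to every `vae.model.encode` call, which bypasses the Gumbel-softmax sampling in rudalle 1.1.1's VQGAN encoder so border strips quantize to deterministic codebook indices. A small sketch of the call in isolation, with a random tensor standing in for a preprocessed image batch; the printed shape is an assumption inferred from the `vqg_img` indexing above (one index per 8x8 pixel patch):

```python
import torch
from rudalle import get_vae

vae = get_vae()  # rudalle's Gumbel VQGAN wrapper
img = torch.randn(1, 3, 256, 256)  # stand-in for a preprocessed RGB image batch
left_border = 8 * 8  # 8 latent columns, 8 pixels per codebook index

# Deterministic quantization of the left border strip (rudalle 1.1.1):
_, _, [_, _, left_codes] = vae.model.encode(img[:, :, :, :left_border],
                                            disable_gumbel_softmax=True)
print(left_codes.shape)  # expected (1, 32, 8): 32 rows x 8 columns of codes
```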