From 932d2c0f7853ff08bc891ec8066f998d72e0b1f9 Mon Sep 17 00:00:00 2001 From: mzr1996 Date: Thu, 21 Dec 2023 18:36:23 +0800 Subject: [PATCH] Fix lint --- configs/llava/README.md | 8 +++--- configs/llava/llava-7b-v1_caption.py | 1 - mmpretrain/models/multimodal/llava/llava.py | 3 ++- mmpretrain/models/multimodal/llava/modules.py | 25 +++++++++++++------ 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/configs/llava/README.md b/configs/llava/README.md index b129518e96..581abfe5a6 100644 --- a/configs/llava/README.md +++ b/configs/llava/README.md @@ -33,11 +33,11 @@ print(out) ### Image Caption on COCO -| Model | Params (M) | Config | Download | -| :-------------------- | :--------: | :------------------------------: | :--------------------: | -| `llava-7b-v1_caption` | 7045.82 | [config](llava-7b-v1_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth) | +| Model | Params (M) | Config | Download | +| :---------------------- | :--------: | :--------------------------------: | :-------------------------------------------------------------------------------------------------------------: | +| `llava-7b-v1_caption` | 7045.82 | [config](llava-7b-v1_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth) | | `llava-7b-v1.5_caption` | 7062.90 | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) | -| `llava-7b-v1.5_vqa` | 7062.90 | [config](llava-7b-v1.5_vqa.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) | +| `llava-7b-v1.5_vqa` | 7062.90 | [config](llava-7b-v1.5_vqa.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) | ## Citation diff --git a/configs/llava/llava-7b-v1_caption.py b/configs/llava/llava-7b-v1_caption.py index 5759720680..92e2d1fb65 100644 --- a/configs/llava/llava-7b-v1_caption.py +++ b/configs/llava/llava-7b-v1_caption.py @@ -35,7 +35,6 @@ ), task='caption', prompt_tmpl=prompt_tmpl, - # generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0), generation_cfg=dict(max_new_tokens=50), ) diff --git a/mmpretrain/models/multimodal/llava/llava.py b/mmpretrain/models/multimodal/llava/llava.py index 6cc72e4606..f829b09214 100644 --- a/mmpretrain/models/multimodal/llava/llava.py +++ b/mmpretrain/models/multimodal/llava/llava.py @@ -77,7 +77,8 @@ def __init__(self, self.tokenizer = TOKENIZER.build(tokenizer) # add Llava special tokens to the tokenizer if use_im_patch: - self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True) + self.tokenizer.add_tokens([self.im_patch_token], + special_tokens=True) if use_im_start_end: self.tokenizer.add_tokens([self.im_start_token, self.im_end_token], special_tokens=True) diff --git a/mmpretrain/models/multimodal/llava/modules.py b/mmpretrain/models/multimodal/llava/modules.py index 48d5710ebf..fa3c6bbbcc 100644 --- a/mmpretrain/models/multimodal/llava/modules.py +++ b/mmpretrain/models/multimodal/llava/modules.py @@ -58,7 +58,8 @@ def __init__(self, modules = [nn.Linear(self.mm_hidden_size, self.lang_hidden_size)] for _ in range(1, mm_proj_depth): modules.append(nn.GELU()) - modules.append(nn.Linear(self.lang_hidden_size, self.lang_hidden_size)) + modules.append( + nn.Linear(self.lang_hidden_size, self.lang_hidden_size)) mm_projector = nn.Sequential(*modules) self.lang_encoder.model.add_module('mm_projector', mm_projector) elif mm_proj_depth == 0: @@ -137,9 +138,15 @@ def forward_vision_tower( labels: torch.LongTensor, images: Union[torch.FloatTensor, None] = None, ): - if self.vision_tower is None or images is None or input_ids.shape[1] == 1: - if past_key_values is not None and self.vision_tower is not None and images is not None and input_ids.shape[1] == 1: - attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) + if self.vision_tower is None or images is None or input_ids.shape[ + 1] == 1: + if (past_key_values is not None and self.vision_tower is not None + and images is not None and input_ids.shape[1] == 1): + attention_mask = torch.ones( + (attention_mask.shape[0], + past_key_values[-1][-1].shape[-2] + 1), + dtype=attention_mask.dtype, + device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels with torch.no_grad(): @@ -192,16 +199,18 @@ def forward_vision_tower( cur_new_labels = torch.cat([ labels[batch_idx, :img_idx], labels.new_full((cur_img.size(0), ), -100), - labels[batch_idx, img_idx+1:], - ], dim=0) + labels[batch_idx, img_idx + 1:], + ], + dim=0) new_labels.append(cur_new_labels) if attention_mask is not None: cur_attn_mask = torch.cat([ attention_mask[batch_idx, :img_idx], attention_mask.new_full((cur_img.size(0), ), True), - attention_mask[batch_idx, img_idx+1:], - ], dim=0) + attention_mask[batch_idx, img_idx + 1:], + ], + dim=0) new_attn_mask.append(cur_attn_mask) inputs_embeds = torch.stack(new_input_embeds, dim=0)