Fix lint
mzr1996 committed Dec 21, 2023
1 parent 29f88a7 commit 932d2c0
Showing 4 changed files with 23 additions and 14 deletions.
8 changes: 4 additions & 4 deletions configs/llava/README.md
@@ -33,11 +33,11 @@ print(out)
 
 ### Image Caption on COCO
 
-| Model                 | Params (M) | Config                           | Download               |
-| :-------------------- | :--------: | :------------------------------: | :--------------------: |
-| `llava-7b-v1_caption` | 7045.82 | [config](llava-7b-v1_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth) |
+| Model                   | Params (M) | Config                             | Download                                                                                                        |
+| :---------------------- | :--------: | :--------------------------------: | :-------------------------------------------------------------------------------------------------------------: |
+| `llava-7b-v1_caption`   | 7045.82    | [config](llava-7b-v1_caption.py)   | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth)   |
 | `llava-7b-v1.5_caption` | 7062.90    | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
-| `llava-7b-v1.5_vqa` | 7062.90 | [config](llava-7b-v1.5_vqa.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
+| `llava-7b-v1.5_vqa`     | 7062.90    | [config](llava-7b-v1.5_vqa.py)     | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
 
 ## Citation
 
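For context on the models listed in the table above: they can be driven through mmpretrain's high-level inference API. A minimal sketch, assuming `inference_model` resolves the registered model name `llava-7b-v1_caption` to the config/checkpoint pair above and that a local `demo.jpg` exists:

```python
from mmpretrain import inference_model

# The model name is resolved to the config and checkpoint from the
# table above; the weights are downloaded on first use.
out = inference_model('llava-7b-v1_caption', 'demo.jpg')
print(out['pred_caption'])  # assumed result key for caption-task inferencers
```
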
1 change: 0 additions & 1 deletion configs/llava/llava-7b-v1_caption.py
@@ -35,7 +35,6 @@
     ),
     task='caption',
     prompt_tmpl=prompt_tmpl,
-    # generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0),
     generation_cfg=dict(max_new_tokens=50),
 )
 
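The deleted line was a commented-out beam-search variant of `generation_cfg`. These keys appear to be passed through to the language model's HuggingFace-style `generate()` call, so they follow `transformers` generation semantics. A hedged sketch of the two settings expressed as `transformers` objects (the equivalence is assumed, not confirmed by this diff):

```python
from transformers import GenerationConfig

# Kept setting: greedy decoding, capped at 50 new tokens.
greedy = GenerationConfig(max_new_tokens=50)

# The removed comment's setting: 3-beam search; a negative length
# penalty biases beam scoring toward shorter captions.
beams = GenerationConfig(num_beams=3, max_new_tokens=50, length_penalty=-1.0)
```
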
3 changes: 2 additions & 1 deletion mmpretrain/models/multimodal/llava/llava.py
@@ -77,7 +77,8 @@ def __init__(self,
         self.tokenizer = TOKENIZER.build(tokenizer)
         # add Llava special tokens to the tokenizer
         if use_im_patch:
-            self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
+            self.tokenizer.add_tokens([self.im_patch_token],
+                                      special_tokens=True)
         if use_im_start_end:
             self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
                                       special_tokens=True)
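The wrapped call is the standard `transformers` tokenizer API. A self-contained sketch of the same pattern, using LLaVA's conventional token strings and a placeholder base tokenizer (both are assumptions, not taken from this diff):

```python
from transformers import AutoTokenizer

# Placeholder base model; LLaVA-7B v1 builds on a LLaMA tokenizer.
tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')

# add_tokens returns how many tokens were genuinely new to the vocab;
# special_tokens=True prevents the tokenizer from splitting them.
n = tokenizer.add_tokens(['<im_patch>'], special_tokens=True)
n += tokenizer.add_tokens(['<im_start>', '<im_end>'], special_tokens=True)

# The language model's embedding table must grow to match, e.g.:
# model.resize_token_embeddings(len(tokenizer))
```
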
25 changes: 17 additions & 8 deletions mmpretrain/models/multimodal/llava/modules.py
@@ -58,7 +58,8 @@ def __init__(self,
             modules = [nn.Linear(self.mm_hidden_size, self.lang_hidden_size)]
             for _ in range(1, mm_proj_depth):
                 modules.append(nn.GELU())
-                modules.append(nn.Linear(self.lang_hidden_size, self.lang_hidden_size))
+                modules.append(
+                    nn.Linear(self.lang_hidden_size, self.lang_hidden_size))
             mm_projector = nn.Sequential(*modules)
             self.lang_encoder.model.add_module('mm_projector', mm_projector)
         elif mm_proj_depth == 0:
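For context, the loop in this hunk builds the LLaVA-1.5-style multimodal projector: one input Linear, then (GELU, Linear) pairs up to `mm_proj_depth`. A standalone sketch with illustrative sizes (the real values come from the vision tower and language model configs):

```python
import torch.nn as nn

mm_hidden_size, lang_hidden_size = 1024, 4096  # illustrative sizes
mm_proj_depth = 2                              # LLaVA-1.5 uses a 2-layer MLP

modules = [nn.Linear(mm_hidden_size, lang_hidden_size)]
for _ in range(1, mm_proj_depth):
    modules.append(nn.GELU())
    modules.append(nn.Linear(lang_hidden_size, lang_hidden_size))
mm_projector = nn.Sequential(*modules)
# -> Linear(1024, 4096) -> GELU -> Linear(4096, 4096)
```
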
@@ -137,9 +138,15 @@ def forward_vision_tower(
         labels: torch.LongTensor,
         images: Union[torch.FloatTensor, None] = None,
     ):
-        if self.vision_tower is None or images is None or input_ids.shape[1] == 1:
-            if past_key_values is not None and self.vision_tower is not None and images is not None and input_ids.shape[1] == 1:
-                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
+        if self.vision_tower is None or images is None or input_ids.shape[
+                1] == 1:
+            if (past_key_values is not None and self.vision_tower is not None
+                    and images is not None and input_ids.shape[1] == 1):
+                attention_mask = torch.ones(
+                    (attention_mask.shape[0],
+                     past_key_values[-1][-1].shape[-2] + 1),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device)
             return input_ids, attention_mask, past_key_values, None, labels
 
         with torch.no_grad():
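The early-return branch in this hunk handles incremental decoding with a KV cache: once only a single new token id arrives per step, the attention mask is rebuilt to span every cached position plus the new token. A toy illustration of the shape arithmetic (names are illustrative):

```python
import torch

batch_size, cached_len = 2, 10  # 10 positions already held in the KV cache
attention_mask = torch.ones((batch_size, cached_len + 1), dtype=torch.long)
print(attention_mask.shape)     # torch.Size([2, 11]): cache + one new token
```
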
@@ -192,16 +199,18 @@ def forward_vision_tower(
                 cur_new_labels = torch.cat([
                     labels[batch_idx, :img_idx],
                     labels.new_full((cur_img.size(0), ), -100),
-                    labels[batch_idx, img_idx+1:],
-                ], dim=0)
+                    labels[batch_idx, img_idx + 1:],
+                ],
+                                           dim=0)
                 new_labels.append(cur_new_labels)
 
             if attention_mask is not None:
                 cur_attn_mask = torch.cat([
                     attention_mask[batch_idx, :img_idx],
                     attention_mask.new_full((cur_img.size(0), ), True),
-                    attention_mask[batch_idx, img_idx+1:],
-                ], dim=0)
+                    attention_mask[batch_idx, img_idx + 1:],
+                ],
+                                          dim=0)
                 new_attn_mask.append(cur_attn_mask)
 
             inputs_embeds = torch.stack(new_input_embeds, dim=0)
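The reformatted `torch.cat` calls above implement a per-sample splice: the single image placeholder token is replaced by `cur_img.size(0)` visual embeddings, so the labels are padded with `-100` (PyTorch's cross-entropy ignore index) and the attention mask with ones over that span. A self-contained sketch of the same splice on a toy tensor (names are illustrative, not the module's):

```python
import torch

labels = torch.tensor([101, 42, 9999, 7])  # 9999 stands in for the image token
img_idx, n_img_embeds = 2, 3               # the image expands to 3 embeddings

new_labels = torch.cat([
    labels[:img_idx],                         # text before the image
    labels.new_full((n_img_embeds, ), -100),  # no loss on image positions
    labels[img_idx + 1:],                     # text after the image
], dim=0)
print(new_labels)  # tensor([ 101,   42, -100, -100, -100,    7])
```
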
