[Feature] Add minigpt4 gradio demo and training script #1758

Merged: 12 commits, Oct 12, 2023
185 changes: 185 additions & 0 deletions configs/minigpt4/minigpt-4_baichuan-7b_caption.py
@@ -0,0 +1,185 @@
_base_ = [
'../_base_/default_runtime.py',
]

data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)

# dataset settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(224, 224),
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='CleanCaption',
keys='chat_content',
remove_chars='',
lowercase=False),
dict(
type='PackInputs',
algorithm_keys=['chat_content', 'lang'],
meta_keys=['image_id']),
]

train_dataloader = dict(
batch_size=2,
num_workers=4,
dataset=dict(
type='MiniGPT4Dataset',
data_root='YOUR_DATA_DIRECTORY',
ann_file='YOUR_DATA_FILE',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
drop_last=False,
)

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(224, 224),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]

test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

test_dataloader = dict(
batch_size=1,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline))

# model settings
model = dict(
type='MiniGPT4',
vision_encoder=dict(
type='BEiTViT',
# eva-g without the final layer
arch=dict(
embed_dims=1408,
num_layers=39,
num_heads=16,
feedforward_channels=6144,
),
img_size=224,
patch_size=14,
layer_scale_init_value=0.0,
frozen_stages=39,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
final_norm=False,
use_shared_rel_pos_bias=False,
out_type='raw',
pretrained= # noqa
'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth' # noqa
),
q_former_model=dict(
type='Qformer',
model_style='bert-base-uncased',
vision_model_width=1408,
add_cross_attention=True,
cross_attention_freq=2,
num_query_token=32,
pretrained= # noqa
'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth' # noqa
),
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='YOUR_PATH_TO_BAICHUAN',
trust_remote_code=True),
tokenizer=dict(
type='AutoTokenizer',
name_or_path='YOUR_PATH_TO_BAICHUAN',
trust_remote_code=True),
task='caption',
en_prompt_template='###Ask: {} ###Answer: ',
zh_prompt_template='###问:{} ###答:',
raw_prompts=[
[
'<Img><ImageHere></Img> Describe this image in detail.',
'<Img><ImageHere></Img> Take a look at this image and describe what you notice.', # noqa
'<Img><ImageHere></Img> Please provide a detailed description of the picture.', # noqa
'<Img><ImageHere></Img> Could you describe the contents of this image for me?' # noqa
],
[
'<Img><ImageHere></Img> 详细描述这张图片。',
'<Img><ImageHere></Img> 浏览这张图片并描述你注意到什么。',
'<Img><ImageHere></Img> 请对这张图片进行详细的描述。',
'<Img><ImageHere></Img> 你能为我描述这张图片的内容吗?'
]
],
max_txt_len=160,
end_sym='###')

strategy = dict(
type='DeepSpeedStrategy',
fp16=dict(
enabled=True,
auto_cast=False,
fp16_master_weights_and_grads=False,
loss_scale=0,
loss_scale_window=1000,
hysteresis=1,
min_loss_scale=1,
initial_scale_power=16,
),
inputs_to_half=[0],
zero_optimization=dict(
stage=2,
allgather_partitions=True,
allgather_bucket_size=2e8,
reduce_scatter=True,
reduce_bucket_size='auto',
overlap_comm=True,
contiguous_gradients=True,
),
)

# schedule settings
optim_wrapper = dict(
type='DeepSpeedOptimWrapper',
optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))

param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-3 / 500,
by_epoch=False,
begin=0,
end=500,
),
dict(
type='CosineAnnealingLR',
eta_min=2e-4,
by_epoch=False,
begin=500,
),
]

train_cfg = dict(by_epoch=True, max_epochs=6)
test_cfg = dict()

default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
interval=1,
by_epoch=True,
save_last=True,
max_keep_ckpts=1,
))
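For context, here is a minimal sketch of an annotation file satisfying the MiniGPT4Dataset contract introduced later in this diff. The ids, captions, and the minigpt4_dataset path are hypothetical placeholders; the schema follows the dataset docstring below.

# Sketch: build a conversation_data.json that MiniGPT4Dataset can load.
# All ids, paths and caption text are hypothetical placeholders.
import json
from pathlib import Path

data_root = Path('minigpt4_dataset')  # stands in for YOUR_DATA_DIRECTORY
(data_root / 'image').mkdir(parents=True, exist_ok=True)

conversations = [
    {
        # English sample: the '###Ask: ' prefix makes the dataset tag it 'en'.
        'id': 'id0',
        'conversation': '###Ask: <Img><ImageHere></Img> Describe this image '
                        'in detail. ###Answer: A dog lying on the grass.###',
    },
    {
        # Chinese sample: any other prefix is tagged 'zh'.
        'id': 'id1',
        'conversation': '###问:<Img><ImageHere></Img> 详细描述这张图片。'
                        '###答:草地上有一只狗。###',
    },
]

with open(data_root / 'conversation_data.json', 'w', encoding='utf-8') as f:
    json.dump(conversations, f, ensure_ascii=False, indent=2)
# Matching images are then expected at minigpt4_dataset/image/id0.jpg, id1.jpg, ...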
16 changes: 12 additions & 4 deletions configs/minigpt4/minigpt-4_vicuna-7b_caption.py
@@ -57,10 +57,18 @@
task='caption',
prompt_template='###Human: {} ###Assistant: ',
raw_prompts=[
'<Img><ImageHere></Img> Describe this image in detail.',
'<Img><ImageHere></Img> Take a look at this image and describe what you notice.', # noqa
'<Img><ImageHere></Img> Please provide a detailed description of the picture.', # noqa
'<Img><ImageHere></Img> Could you describe the contents of this image for me?', # noqa
[
'<Img><ImageHere></Img> Describe this image in detail.',
'<Img><ImageHere></Img> Take a look at this image and describe what you notice.', # noqa
'<Img><ImageHere></Img> Please provide a detailed description of the picture.', # noqa
'<Img><ImageHere></Img> Could you describe the contents of this image for me?' # noqa
],
[
'<Img><ImageHere></Img> 详细描述这张图片。',
'<Img><ImageHere></Img> 浏览这张图片并描述你注意到什么。',
'<Img><ImageHere></Img> 请对这张图片进行详细的描述。',
'<Img><ImageHere></Img> 你能为我描述这张图片的内容吗?'
]
],
max_txt_len=160,
end_sym='###')
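This change brings the vicuna config in line with the baichuan one above: raw_prompts becomes a pair of lists, English templates first, Chinese second. A hedged sketch of how such a nested structure could be consumed, assuming selection keys off the dataset's 'lang' field; the model's actual prompt-sampling code is not part of this diff.

# Illustration only: sample a prompt template by language tag.
import random

raw_prompts = [
    ['<Img><ImageHere></Img> Describe this image in detail.'],  # en
    ['<Img><ImageHere></Img> 详细描述这张图片。'],  # zh
]

def pick_prompt(lang: str) -> str:
    # lang comes from MiniGPT4Dataset ('en' or 'zh').
    return random.choice(raw_prompts[0] if lang == 'en' else raw_prompts[1])

print(pick_prompt('zh'))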
4 changes: 3 additions & 1 deletion mmpretrain/datasets/__init__.py
@@ -43,6 +43,7 @@
from .gqa_dataset import GQA
from .iconqa import IconQA
from .infographic_vqa import InfographicVQA
from .minigpt4_dataset import MiniGPT4Dataset
from .nocaps import NoCaps
from .ocr_vqa import OCRVQA
from .refcoco import RefCOCO
@@ -56,5 +57,6 @@
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA'
'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
'MiniGPT4Dataset'
])
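With the class registered and exported here, the new dataset (defined in the next file) can be built through the registry like any other dataset. A small sketch, with placeholder paths:

# Sketch: registry-based construction (data_root/ann_file are placeholders).
from mmpretrain.registry import DATASETS

dataset_cfg = dict(
    type='MiniGPT4Dataset',
    data_root='minigpt4_dataset',
    ann_file='conversation_data.json',
    pipeline=[],
)
dataset = DATASETS.build(dataset_cfg)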
79 changes: 79 additions & 0 deletions mmpretrain/datasets/minigpt4_dataset.py
@@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class MiniGPT4Dataset(BaseDataset):
"""Dataset for training MiniGPT4.

MiniGPT4 dataset directory:

minigpt4_dataset
├── image
│ ├── id0.jpg
│ ├── id1.jpg
│ ├── id2.jpg
│ └── ...
└── conversation_data.json

The structure of conversation_data.json:

[
// English data
{
"id": str(id0),
"conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
###Answer: [Answer content]"
},

// Chinese data
{
"id": str(id1),
"conversation": "###问:<Img><ImageHere></Img> [Ask content]
###答:[Answer content]"
},

...
]

Args:
data_root (str): The root directory for ``ann_file`` and ``image``.
ann_file (str): Conversation file path.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""

def load_data_list(self) -> List[dict]:
file_backend = get_file_backend(self.data_root)
conversation_path = file_backend.join_path(self.data_root,
self.ann_file)
conversation = mmengine.load(conversation_path)
img_ids = {}
n = 0
for conv in conversation:
img_id = conv['id']
            if img_id not in img_ids:
img_ids[img_id] = n
n += 1

img_root = file_backend.join_path(self.data_root, 'image')
data_list = []
for conv in conversation:
img_file = '{}.jpg'.format(conv['id'])
chat_content = conv['conversation']
lang = 'en' if chat_content.startswith('###Ask: ') else 'zh'
data_info = {
'image_id': img_ids[conv['id']],
'img_path': file_backend.join_path(img_root, img_file),
'chat_content': chat_content,
'lang': lang,
}

data_list.append(data_info)

return data_list
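A minimal usage sketch of the loader above, assuming the directory layout from the docstring; the path is hypothetical and the pipeline is left empty so the raw data dicts come back unchanged.

# Sketch: inspect what load_data_list produces for one sample.
from mmpretrain.datasets import MiniGPT4Dataset

dataset = MiniGPT4Dataset(
    data_root='minigpt4_dataset',
    ann_file='conversation_data.json',
    pipeline=[],  # the training config plugs in train_pipeline here
)
sample = dataset[0]
print(sample['img_path'])  # .../minigpt4_dataset/image/id0.jpg
print(sample['lang'])  # 'en' or 'zh', inferred from the conversation prefix
print(sample['chat_content'][:60])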