Added CVQA related eval files #20

Open · wants to merge 80 commits into base: main

Commits (80)
e3ad3d9
pretrain LLaVA v1_5 with Aya
nahidalam Jul 8, 2024
4216e56
unk null
nahidalam Jul 8, 2024
c1efd19
siglip in pretrain script
nahidalam Jul 12, 2024
390524f
siglip_encoder.py from Snehanshu
nahidalam Jul 12, 2024
0c921d1
call siglipvisiontower
nahidalam Jul 12, 2024
8eafa41
readme update and debug print
nahidalam Jul 12, 2024
7fa8273
builder update for siglip
nahidalam Jul 12, 2024
3396dd7
siglip works
nahidalam Jul 13, 2024
3a2326d
cleanup
nahidalam Jul 14, 2024
67ab312
Merge pull request #2 from nahidalam/aya_siglip
nahidalam Jul 14, 2024
a7dade9
init llava_cohere
nahidalam Jul 14, 2024
de9fab0
update README
nahidalam Jul 14, 2024
ab31418
Added new cache_position argument
rsk2327 Jul 19, 2024
cffcdb8
Added eval notebook to run pretrained Aya models
rsk2327 Jul 19, 2024
4e6f9df
Update LLaVA_Loading_Pretrained_Model.ipynb
rsk2327 Jul 19, 2024
6188732
Create eval_utils.py
rsk2327 Jul 19, 2024
c3765f2
Added new functions to eval_utils and updated eval notebook
rsk2327 Jul 19, 2024
61d0e37
Adding finetuning code
rsk2327 Jul 20, 2024
d010966
pretraining unk_token change + changes to support multilingual json f…
Satyajitv Jul 22, 2024
c8bd707
pretraining unk_token change + changes to support multilingual json f…
Satyajitv Jul 22, 2024
33f78e3
Merge pull request #5 from Satyajitv/maya_pretrain_multilingual
nahidalam Jul 22, 2024
9912903
updated script for multilingual
nahidalam Jul 22, 2024
3029339
Added cohere finetuning notebook
rsk2327 Jul 22, 2024
9f929c7
update to batch size 4
nahidalam Jul 22, 2024
57c918d
Merge branch 'nahidalam:maya_pretrain' into maya_pretrain
rsk2327 Jul 23, 2024
18c1569
Updated builder.py to read aya pretrained model
rsk2327 Jul 24, 2024
c4f8a9b
Merge branch 'maya_pretrain' of https://github.com/rsk2327/LLaVA into…
rsk2327 Jul 24, 2024
6ad488a
Update builder.py
rsk2327 Jul 24, 2024
66af865
Updated train.py and finetune_lora.sh script for finetuning
rsk2327 Jul 24, 2024
6f50c1c
Update Finetuning_with_Pretrained_Cohere.ipynb
rsk2327 Jul 24, 2024
01bc79c
Add eval scripts and README for PALO multilingual-llava-bench-in-the-…
iuddin Jul 25, 2024
9820ea8
be able to also support loading of single json file
Satyajitv Jul 25, 2024
0138bfe
Merge pull request #8 from Satyajitv/maya_pretrain_path
nahidalam Jul 27, 2024
d751a4e
Merge pull request #6 from iuddin/palo_eval
nahidalam Jul 27, 2024
8e881e0
Merge pull request #4 from rsk2327/maya_pretrain
nahidalam Jul 27, 2024
f57a9db
Update eval_all_languages.sh to work with Maya
iuddin Jul 28, 2024
4e09840
Update llavabench_palo.sh to work with Maya
iuddin Jul 28, 2024
46dde3b
Add model_vqa_maya to support eval with Maya
iuddin Jul 28, 2024
e56b17f
Update model to gpt-4-turbo
iuddin Jul 28, 2024
ff51611
Commenting out 2 unused import lines as they throw errors
iuddin Jul 28, 2024
ec97f28
Update .gitignore to allow two eval files
iuddin Jul 28, 2024
f47f69e
Add eval review and response file for Maya-8B English
iuddin Jul 28, 2024
68a0de3
Add PALO eval instructions to README
iuddin Jul 29, 2024
9789107
Merge pull request #10 from iuddin/integrate_pretrained_eval
nahidalam Jul 29, 2024
e1b2a01
Resolved token mismatch issue
rsk2327 Aug 5, 2024
4c3ed6e
Update train.py
rsk2327 Aug 6, 2024
ec5e69c
updated finetuning scripts
rsk2327 Aug 7, 2024
1282784
instruction tune
nahidalam Aug 9, 2024
94a27cd
Merge pull request #12 from rsk2327/maya_pretrain
nahidalam Aug 9, 2024
c34f12a
script path update
nahidalam Aug 9, 2024
a4cfdcb
update batch size
nahidalam Aug 9, 2024
37f66de
avoiding issue with crop_size missing in siglip, so using default val…
Satyajitv Aug 14, 2024
eb420f6
Merge pull request #15 from Satyajitv/maya_pretrain_crop_size
nahidalam Aug 21, 2024
f9b04af
skip if image not found and fix image size
nahidalam Aug 22, 2024
9489bca
update
nahidalam Aug 22, 2024
f2fe74f
Update finetune testing code
rsk2327 Aug 28, 2024
3e5b15e
Update finetune_args.py
rsk2327 Aug 31, 2024
240d506
Adding 256 for siglip crop_size
Satyajitv Aug 31, 2024
94c1171
Updated siglip embeddings usage
rsk2327 Sep 1, 2024
a05aba3
Update README.md
rsk2327 Sep 1, 2024
cb1f8a1
Merge pull request #16 from Satyajitv/maya_pretrain_visionm_cropsize
nahidalam Sep 1, 2024
352644d
Merge pull request #17 from rsk2327/maya_pretrain
nahidalam Sep 1, 2024
b68e1ea
fix wrong function call
nahidalam Sep 2, 2024
27f7ea9
effective batch size 32
nahidalam Sep 2, 2024
0bf1323
avoid the index error
Satyajitv Sep 3, 2024
79eb4a4
Merge branch 'maya_pretrain' into maya_pretrain_visionm_cropsize
Satyajitv Sep 3, 2024
cb88fe3
avoid the index error
Satyajitv Sep 3, 2024
7dec391
Merge pull request #18 from Satyajitv/maya_pretrain_visionm_cropsize
nahidalam Sep 3, 2024
625b2a8
avoid the index error
Satyajitv Sep 3, 2024
c94dc33
Merge pull request #19 from Satyajitv/maya_pretrain_visionm_cropsize
nahidalam Sep 3, 2024
b5c05e0
update finetune script
nahidalam Sep 5, 2024
8ba6a93
batch size 16 for instruction tune
nahidalam Sep 5, 2024
3275042
update non-lora finetune script
nahidalam Sep 7, 2024
04bf235
updated eval script
nahidalam Sep 23, 2024
d8b2d49
script path fix
nahidalam Sep 24, 2024
d85d9ed
updated script
nahidalam Sep 24, 2024
6a87e63
updated script
nahidalam Sep 24, 2024
0429f62
updated eval model to gpt-4o-mini
nahidalam Sep 25, 2024
00bac15
logging message
nahidalam Sep 25, 2024
1ecc141
Added CVQA eval files
rsk2327 Oct 6, 2024
6 changes: 5 additions & 1 deletion .gitignore
@@ -10,6 +10,10 @@ dist
*.json
*.jsonl

# Exceptions
!evaluation/Maya-8B_English.jsonl
!evaluation/reviews/Maya-8B_English.jsonl

# Data
!**/alpaca-data-conversation.json

@@ -26,7 +30,7 @@ checkpoints
ckpts*

.ipynb_checkpoints
*.ipynb


# DevContainer
!.devcontainer/*
22 changes: 22 additions & 0 deletions README.md
@@ -1,3 +1,25 @@
# Maya: Multimodal Multilingual LLM

## Install

Follow the LLaVA installation process below, then install these additional pinned dependencies:

```
pip install chardet==5.2.0
pip install datasets==2.15.0
pip install deepspeed==0.14.2
pip install fastapi==0.111.0
pip install transformers==4.42.3
pip install accelerate==0.27.2
```
## Contributors
- Satya https://github.com/Satyajitv
- Ryan Chan https://github.com/rchan26
- Sangyeon Kim https://github.com/KimSangYeon-DGU
- Snehanshu https://github.com/pilot-j
- Drishti Sushma https://github.com/DrishtiShrrrma
- Roshan Santhosh https://github.com/rsk2327

# 🌋 LLaVA: Large Language and Vision Assistant

*Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.*
60 changes: 60 additions & 0 deletions evaluation/Maya-8B_English.jsonl

Large diffs are not rendered by default.

60 changes: 60 additions & 0 deletions evaluation/reviews/Maya-8B_English.jsonl

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions llava/conversation.py
@@ -369,6 +369,18 @@ def dict(self):
sep="<|im_end|>",
)

conv_aya = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="<|END_OF_TURN_TOKEN|>",
)

default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
@@ -389,6 +401,7 @@ def dict(self):
"llava_llama_2": conv_llava_llama_2,

"mpt": conv_mpt,
"aya": conv_aya
}


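The new template can be exercised directly; below is a minimal sketch of how a prompt would be built with it (the question text is illustrative):

```python
from llava.conversation import conv_templates

# Build a prompt with the new "aya" template. With SeparatorStyle.TWO, the user
# turn is closed by sep (" ") and the assistant turn by sep2 ("<|END_OF_TURN_TOKEN|>").
conv = conv_templates["aya"].copy()
conv.append_message(conv.roles[0], "Describe the image in one sentence.")
conv.append_message(conv.roles[1], None)  # leave the assistant slot open for generation
prompt = conv.get_prompt()
# prompt now looks like:
# "<system message> USER: Describe the image in one sentence. ASSISTANT:"
```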
3 changes: 2 additions & 1 deletion llava/eval/eval_gpt_review.py
@@ -14,7 +14,7 @@ def get_eval(content: str, max_tokens: int):
while True:
try:
response = openai.ChatCompletion.create(
model='gpt-4',
model='gpt-4o-mini',
messages=[{
'role': 'system',
'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
@@ -27,6 +27,7 @@
)
break
except openai.error.RateLimitError:
print('rate limit error!')
pass
except Exception as e:
print(e)
3 changes: 2 additions & 1 deletion llava/eval/eval_gpt_review_bench.py
@@ -12,7 +12,7 @@ def get_eval(content: str, max_tokens: int):
while True:
try:
response = openai.ChatCompletion.create(
model='gpt-4-0314',
model='gpt-4o-mini', #gpt-4-0314
messages=[{
'role': 'system',
'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
@@ -25,6 +25,7 @@
)
break
except openai.error.RateLimitError:
print('rate limit error!')
pass
except Exception as e:
print(e)
3 changes: 2 additions & 1 deletion llava/eval/eval_gpt_review_visual.py
@@ -12,7 +12,7 @@ def get_eval(content: str, max_tokens: int):
while True:
try:
response = openai.ChatCompletion.create(
model='gpt-4-0314',
model='gpt-4o-mini',
messages=[{
'role': 'system',
'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
@@ -25,6 +25,7 @@
)
break
except openai.error.RateLimitError:
print('rate limit error!')
pass
except Exception as e:
print(e)
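All three review scripts share the same retry loop; with the added print, the rate-limit branch still retries immediately. A minimal sketch of the same call with a simple exponential backoff (the wrapper name and retry budget are illustrative, not part of this PR):

```python
import time
import openai

def get_eval_with_backoff(content: str, max_tokens: int, max_retries: int = 5) -> str:
    """Sketch of the review-script call to gpt-4o-mini, backing off on rate limits."""
    for attempt in range(max_retries):
        try:
            response = openai.ChatCompletion.create(
                model='gpt-4o-mini',
                messages=[
                    {'role': 'system',
                     'content': 'You are a helpful and precise assistant for checking the quality of the answer.'},
                    {'role': 'user', 'content': content},
                ],
                temperature=0.2,
                max_tokens=max_tokens,
            )
            return response['choices'][0]['message']['content']
        except openai.error.RateLimitError:
            print('rate limit error!')
            time.sleep(2 ** attempt)  # back off instead of retrying immediately
        except Exception as e:
            print(e)
            time.sleep(1)
    return 'error'
```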
206 changes: 206 additions & 0 deletions llava/eval/maya/eval_utils.py
@@ -0,0 +1,206 @@
import os
import sys
import torch
import requests
from io import BytesIO
from PIL import Image


from transformers import AutoTokenizer, AutoConfig, TextStreamer
from transformers.models.cohere.tokenization_cohere_fast import CohereTokenizerFast
from llava.model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

from typing import Optional, Literal


def load_maya_model(model_base: str, model_path : str, projector_path : Optional[str] = None, mode : Literal['pretrained','finetuned'] = 'pretrained'):

""" Function that helps load a trained Maya model

Trained Maya model can be of two flavors :
1. Pretrained : The model has only gone through pretraining and the changes are restricted to the projector layer
2. Finetuned : The model has gone through instruction finetuning after the pretraining stage. This affects the whole model

This is a replication of the load_pretrained_model function from llava.model.builder that's specific to Cohere/Maya

Args:
model_base : Path of the base LLM model in HF. Eg: 'CohereForAI/aya-23-8B', 'meta-llama/Meta-Llama-3-8B-Instruct'.
This is used to instantiate the tokenizer and the model (in case of loading the pretrained model)
model_path : Path of the trained model repo in HF. Eg : 'nahidalam/Maya'
This is used to load the config file. So this path/directory should have the config.json file
For the finetuned model, this is used to load the final model weights as well
projector_path : For the pretrained model, this represents the path to the local directory which holds the mm_projector.bin file
mode : Specifies whether to load a pretrained-only model or a finetuned model

Returns:
maya: MayaModel object wrapping the LlavaCohereForCausalLM model, the CohereTokenizerFast tokenizer, the vision tower's image processor, and the context length
"""

device_map = 'auto'
kwargs = {"device_map": device_map}
kwargs['torch_dtype'] = torch.float16
kwargs['attn_implementation'] = 'flash_attention_2'

## Instantiating tokenizer and model base
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
cfg_pretrained = LlavaCohereConfig.from_pretrained(model_path)

if mode == 'pretrained':
model = LlavaCohereForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

## Loading Projector layer weights
mm_projector_weights = torch.load(projector_path, map_location='cpu')
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
model.load_state_dict(mm_projector_weights, strict=False)
else:
## Loading the fully finetuned model weights from model_path
model = LlavaCohereForCausalLM.from_pretrained(model_path, config=cfg_pretrained, **kwargs)

## Loading image processor
image_processor = None

mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model(device_map=device_map)
if device_map != 'auto':
vision_tower.to(device=device_map, dtype=torch.float16)
image_processor = vision_tower.image_processor

if hasattr(model.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
else:
context_len = 2048

maya = MayaModel(model, tokenizer, image_processor, context_len)

return maya


class MayaModel(object):

def __init__(self, model : LlavaCohereForCausalLM, tokenizer : CohereTokenizerFast, image_processor, context_length):
self.model = model
self.tokenizer = tokenizer
self.image_processor = image_processor
self.context_length = context_length

def validate_inputs(self):
pass




def load_image(image_input):
"""
Convert various image inputs to a PIL Image object.

:param image_input: Can be a URL string, a file path string, or image bytes
:return: PIL Image object
"""
try:
if isinstance(image_input, str):
if image_input.startswith(('http://', 'https://')):
# Input is a URL
response = requests.get(image_input)
response.raise_for_status() # Raise an exception for bad responses
return Image.open(BytesIO(response.content))
elif os.path.isfile(image_input):
# Input is a file path
return Image.open(image_input)
else:
raise ValueError("Invalid input: string is neither a valid URL nor a file path")
elif isinstance(image_input, bytes):
# Input is bytes
return Image.open(BytesIO(image_input))
else:
raise ValueError("Invalid input type. Expected URL string, file path string, or bytes.")
except requests.RequestException as e:
raise ValueError(f"Error fetching image from URL: {e}")
except IOError as e:
raise ValueError(f"Error opening image file: {e}")
except Exception as e:
raise ValueError(f"An unexpected error occurred: {e}")




def get_single_sample_prediction(maya_model, image_file, user_question, temperature = 0.0, max_new_tokens = 100, conv_mode = 'aya'):
"""Generates the prediction for a single image-user question pair.

Args:
maya_model (MayaModel): Trained Maya model
image_file : One of the following: online image URL, local image path, or image bytes
user_question (str): Question to be shared with the LLM
temperature (float, optional): Temperature param for LLMs. Defaults to 0.0.
max_new_tokens (int, optional): Maximum number of new tokens generated. Defaults to 100.
conv_mode (str, optional): Conversation template to be used. Defaults to 'aya'.

Returns:
output (str): Model's response to user question
"""


conv = conv_templates[conv_mode].copy()
roles = conv.roles
model = maya_model.model
tokenizer = maya_model.tokenizer
image_processor = maya_model.image_processor

image = load_image(image_file)
image_size = image.size

image_tensor = process_images([image], image_processor, model.config)
if type(image_tensor) is list:
image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
else:
image_tensor = image_tensor.to(model.device, dtype=torch.float16)

inp = user_question

if image is not None:
# first message
if model.config.mm_use_im_start_end:
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
else:
inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
# image = None

conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
image_sizes=[image_size],
do_sample=True if temperature > 0 else False,
temperature=temperature,
max_new_tokens=max_new_tokens,
streamer=streamer,
use_cache=True)

outputs = tokenizer.decode(output_ids[0]).strip()

return outputs
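
For reference, a minimal usage sketch of the helpers above; the HF repo names follow the docstring examples, and the projector and image paths are illustrative placeholders:

```python
from llava.eval.maya.eval_utils import load_maya_model, get_single_sample_prediction

# Load the pretrained (projector-only) flavor of Maya.
maya = load_maya_model(
    model_base='CohereForAI/aya-23-8B',    # base LLM (docstring example)
    model_path='nahidalam/Maya',           # repo holding config.json (docstring example)
    projector_path='./mm_projector.bin',   # local projector weights (illustrative path)
    mode='pretrained',
)

# Run a single image-question pair through the model.
answer = get_single_sample_prediction(
    maya,
    image_file='./examples/sample_image.jpg',  # URL, local path, or raw bytes
    user_question='What is shown in this image?',
    temperature=0.0,
    max_new_tokens=100,
    conv_mode='aya',
)
print(answer)
```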