fix qwen2vl
Jintao-Huang committed Dec 26, 2024
1 parent 629dca3 commit 9f7fd8d
Showing 4 changed files with 19 additions and 14 deletions.
4 changes: 2 additions & 2 deletions swift/llm/__init__.py
@@ -14,7 +14,7 @@
                             RLHFArguments, WebUIArguments, BaseArguments)
     from .template import (TEMPLATE_MAPPING, Template, Word, get_template, TemplateType, register_template,
                            TemplateInputs, Messages, TemplateMeta, get_template_meta, InferRequest, load_image,
-                           MaxLengthError)
+                           MaxLengthError, load_file)
     from .model import (register_model, MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download,
                         HfConfigFactory, ModelInfo, ModelMeta, ModelKeys, register_model_arch, MultiModelKeys,
                         ModelArch, get_model_arch, MODEL_ARCH_MAPPING, get_model_info_meta, get_model_name, ModelGroup,
@@ -45,7 +45,7 @@
         'template': [
             'TEMPLATE_MAPPING', 'Template', 'Word', 'get_template', 'TemplateType', 'register_template',
             'TemplateInputs', 'Messages', 'TemplateMeta', 'get_template_meta', 'InferRequest', 'load_image',
-            'MaxLengthError'
+            'MaxLengthError', 'load_file'
         ],
         'model': [
             'MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'safe_snapshot_download', 'HfConfigFactory',
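Note: this export makes load_file available from the top-level swift.llm package, which the qwen_vl_utils patch in the next file relies on. A minimal usage sketch, assuming load_file resolves a URL- or path-like reference to a local file path (that behavior is inferred from how the patch uses it, not stated in this diff; the sample URL is purely illustrative):

# Hedged sketch of the new top-level export.
from swift.llm import load_file

local_path = load_file('https://example.com/demo.mp4')  # hypothetical input; assumed to return a local path
print(local_path)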
12 changes: 12 additions & 0 deletions swift/llm/model/model/qwen.py
@@ -474,12 +474,24 @@ def _get_cast_dtype(self) -> torch.dtype:
 
 def patch_qwen_vl_utils():
     from qwen_vl_utils import vision_process
+    if hasattr(vision_process, '_patch'):
+        return
     for key in [
             'image_factor', 'min_pixels', 'max_pixels', 'max_ratio', 'video_min_pixels', 'video_max_pixels',
             'video_total_pixels', 'frame_factor', 'fps', 'fps_min_frames', 'fps_max_frames'
     ]:
         type_func = float if key == 'fps' else int
         setattr(vision_process, key.upper(), get_env_args(key, type_func, getattr(vision_process, key.upper())))
+    from qwen_vl_utils import vision_process
+    _read_video_decord = vision_process._read_video_decord
+
+    def _new_read_video_decord(ele: dict):
+        from swift.llm import load_file
+        ele['video'] = load_file(ele['video'])
+        return _read_video_decord(ele)
+
+    vision_process.VIDEO_READER_BACKENDS['decord'] = _new_read_video_decord
+    vision_process._patch = True
 
 
 def get_model_tokenizer_qwen2_vl(model_dir: str,
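Usage note: once patch_qwen_vl_utils has run (presumably invoked from get_model_tokenizer_qwen2_vl, whose signature appears above), the qwen_vl_utils resizing limits can be tuned from the environment via get_env_args, and every decord video read is routed through load_file first. A minimal sketch under the assumption that get_env_args reads the upper-cased key name from the environment (e.g. MAX_PIXELS), which the loop above suggests but does not state; the checkpoint id is only illustrative:

# Hedged sketch: set the (assumed) environment variables before the model loads,
# so get_env_args picks them up when patch_qwen_vl_utils() runs.
import os

os.environ['MAX_PIXELS'] = str(1280 * 28 * 28)  # parsed as int per the loop above
os.environ['FPS'] = '2.0'                       # 'fps' is the one key parsed as float

from swift.llm import get_model_tokenizer

# Illustrative Qwen2-VL checkpoint; any model handled by get_model_tokenizer_qwen2_vl is patched the same way.
model, processor = get_model_tokenizer('Qwen/Qwen2-VL-7B-Instruct')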
2 changes: 1 addition & 1 deletion swift/llm/template/__init__.py
@@ -6,4 +6,4 @@
 from .template_inputs import InferRequest, Messages, TemplateInputs, Tool
 from .template_meta import TemplateMeta
 from .utils import Word, split_action_action_input, split_parts_by_regex, split_str_parts_by
-from .vision_utils import load_image
+from .vision_utils import load_file, load_image
15 changes: 4 additions & 11 deletions swift/llm/template/template/qwen.py
@@ -12,7 +12,7 @@
 from ..template_inputs import StdTemplateInputs
 from ..template_meta import TemplateMeta
 from ..utils import Context, Word, findall
-from ..vision_utils import load_audio_qwen, load_batch
+from ..vision_utils import load_audio_qwen, load_batch, load_file
 from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta
 
 
@@ -145,11 +145,6 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
 register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_audio, template_cls=Qwen2AudioTemplate))
 
 
-def _process_image_qwen(image):
-
-    return image
-
-
 class Qwen2VLTemplate(Template):
     image_token_id = 151655
     video_token_id = 151656
@@ -162,7 +157,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
             inputs.images[index] = fetch_image({'image': inputs.images[index]})
             return ['<|vision_start|><|image_pad|><|vision_end|>']
         else:
-            inputs.videos[index] = fetch_video({'video': inputs.videos[index]})
+            inputs.videos[index] = fetch_video({'video': inputs.videos[index]}).to(torch.uint8)
             return ['<|vision_start|><|video_pad|><|vision_end|>']
 
     def replace_object(self, object_: Dict[str, Any], index: int, inputs: StdTemplateInputs) -> List[Context]:
@@ -198,12 +193,10 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             if locals()[media_type]:
                 if media_type == 'images':
                     media_token = self.image_token_id
-                    media_inputs = processor.image_processor(
-                        images=images, videos=None, return_tensors='pt', do_rescale=False)
+                    media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt')
                     media_grid_thw = media_inputs['image_grid_thw']
                 else:
-                    media_inputs = processor.image_processor(
-                        images=None, videos=videos, return_tensors='pt', do_rescale=False)
+                    media_inputs = processor.image_processor(images=None, videos=videos, return_tensors='pt')
                     media_grid_thw = media_inputs['video_grid_thw']
                     media_token = self.video_token_id
                 idx_list = findall(input_ids, media_token)
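A short note on the two related edits in this file: dropping do_rescale=False lets the processor apply its default rescaling (dividing pixel values by 255), which expects uint8 input, and casting the fetch_video output to torch.uint8 keeps the video frames consistent with that expectation; that fetch_video could otherwise hand back float frames is an inference from the fix, not something stated in the diff. A minimal sketch of the dtype contract:

# Hedged sketch: uint8 frames in [0, 255] are what a do_rescale=True image
# processor expects; the cast mirrors the change made in replace_tag above.
import torch

frames = torch.rand(16, 3, 224, 224) * 255  # illustrative float frames in [0, 255]
frames = frames.to(torch.uint8)             # same cast as fetch_video(...).to(torch.uint8)
print(frames.dtype)                         # torch.uint8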
