From 9f7fd8dd6069f4ccde2d90c8ea79b05bb40505f6 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 26 Dec 2024 20:51:58 +0800
Subject: [PATCH] fix qwen2vl

---
 swift/llm/__init__.py               |  4 ++--
 swift/llm/model/model/qwen.py       | 12 ++++++++++++
 swift/llm/template/__init__.py      |  2 +-
 swift/llm/template/template/qwen.py | 15 ++++-----------
 4 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py
index 2173ce93b..8e8145b6a 100644
--- a/swift/llm/__init__.py
+++ b/swift/llm/__init__.py
@@ -14,7 +14,7 @@
                            RLHFArguments, WebUIArguments, BaseArguments)
     from .template import (TEMPLATE_MAPPING, Template, Word, get_template, TemplateType, register_template,
                            TemplateInputs, Messages, TemplateMeta, get_template_meta, InferRequest, load_image,
-                           MaxLengthError)
+                           MaxLengthError, load_file)
     from .model import (register_model, MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download,
                         HfConfigFactory, ModelInfo, ModelMeta, ModelKeys, register_model_arch, MultiModelKeys,
                         ModelArch, get_model_arch, MODEL_ARCH_MAPPING, get_model_info_meta, get_model_name, ModelGroup,
@@ -45,7 +45,7 @@
         'template': [
             'TEMPLATE_MAPPING', 'Template', 'Word', 'get_template', 'TemplateType', 'register_template',
             'TemplateInputs', 'Messages', 'TemplateMeta', 'get_template_meta', 'InferRequest', 'load_image',
-            'MaxLengthError'
+            'MaxLengthError', 'load_file'
         ],
         'model': [
            'MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'safe_snapshot_download', 'HfConfigFactory',
diff --git a/swift/llm/model/model/qwen.py b/swift/llm/model/model/qwen.py
index 66069df3c..7a9b4255a 100644
--- a/swift/llm/model/model/qwen.py
+++ b/swift/llm/model/model/qwen.py
@@ -474,12 +474,24 @@ def _get_cast_dtype(self) -> torch.dtype:
 
 def patch_qwen_vl_utils():
     from qwen_vl_utils import vision_process
+    if hasattr(vision_process, '_patch'):
+        return
     for key in [
             'image_factor', 'min_pixels', 'max_pixels', 'max_ratio', 'video_min_pixels', 'video_max_pixels',
             'video_total_pixels', 'frame_factor', 'fps', 'fps_min_frames', 'fps_max_frames'
     ]:
         type_func = float if key == 'fps' else int
         setattr(vision_process, key.upper(), get_env_args(key, type_func, getattr(vision_process, key.upper())))
+    from qwen_vl_utils import vision_process
+    _read_video_decord = vision_process._read_video_decord
+
+    def _new_read_video_decord(ele: dict):
+        from swift.llm import load_file
+        ele['video'] = load_file(ele['video'])
+        return _read_video_decord(ele)
+
+    vision_process.VIDEO_READER_BACKENDS['decord'] = _new_read_video_decord
+    vision_process._patch = True
 
 
 def get_model_tokenizer_qwen2_vl(model_dir: str,
diff --git a/swift/llm/template/__init__.py b/swift/llm/template/__init__.py
index 52f99f0f1..f650558b1 100644
--- a/swift/llm/template/__init__.py
+++ b/swift/llm/template/__init__.py
@@ -6,4 +6,4 @@
 from .template_inputs import InferRequest, Messages, TemplateInputs, Tool
 from .template_meta import TemplateMeta
 from .utils import Word, split_action_action_input, split_parts_by_regex, split_str_parts_by
-from .vision_utils import load_image
+from .vision_utils import load_file, load_image
diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py
index 9c9e3becb..0ec5469e1 100644
--- a/swift/llm/template/template/qwen.py
+++ b/swift/llm/template/template/qwen.py
@@ -12,7 +12,7 @@
 from ..template_inputs import StdTemplateInputs
 from ..template_meta import TemplateMeta
 from ..utils import Context, Word, findall
-from ..vision_utils import load_audio_qwen, load_batch
+from ..vision_utils import load_audio_qwen, load_batch, load_file
 from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta
 
 
@@ -145,11 +145,6 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
 register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_audio, template_cls=Qwen2AudioTemplate))
 
 
-def _process_image_qwen(image):
-
-    return image
-
-
 class Qwen2VLTemplate(Template):
     image_token_id = 151655
     video_token_id = 151656
@@ -162,7 +157,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
             inputs.images[index] = fetch_image({'image': inputs.images[index]})
             return ['<|vision_start|><|image_pad|><|vision_end|>']
         else:
-            inputs.videos[index] = fetch_video({'video': inputs.videos[index]})
+            inputs.videos[index] = fetch_video({'video': inputs.videos[index]}).to(torch.uint8)
             return ['<|vision_start|><|video_pad|><|vision_end|>']
 
     def replace_object(self, object_: Dict[str, Any], index: int, inputs: StdTemplateInputs) -> List[Context]:
@@ -198,12 +193,10 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             if locals()[media_type]:
                 if media_type == 'images':
                     media_token = self.image_token_id
-                    media_inputs = processor.image_processor(
-                        images=images, videos=None, return_tensors='pt', do_rescale=False)
+                    media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt')
                     media_grid_thw = media_inputs['image_grid_thw']
                 else:
-                    media_inputs = processor.image_processor(
-                        images=None, videos=videos, return_tensors='pt', do_rescale=False)
+                    media_inputs = processor.image_processor(images=None, videos=videos, return_tensors='pt')
                     media_grid_thw = media_inputs['video_grid_thw']
                     media_token = self.video_token_id
                 idx_list = findall(input_ids, media_token)
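
Note on this patch: the heart of the fix is the decord wrapper installed by
patch_qwen_vl_utils. It (1) makes patching idempotent via a `_patch` sentinel
set on the vision_process module, and (2) swaps the 'decord' entry of
VIDEO_READER_BACKENDS for a wrapper that first runs the video source through
swift's load_file, so URLs and other non-local sources are materialized to a
local path before decord opens them. The sketch below reproduces just that
wrap-and-guard pattern in isolation; every name in it (the registry dict, the
reader, this load_file) is a simplified stand-in, not the real qwen_vl_utils
API.

    # sketch_patch_video_reader.py -- illustrative stand-ins only.
    # Mirrors the pattern in patch_qwen_vl_utils: wrap the reader registered
    # under 'decord' so every video source is first materialized to a local
    # file, and use a module-level sentinel so patching is idempotent.
    import os
    import tempfile
    import urllib.request


    def load_file(source: str) -> str:
        # Stand-in for swift.llm's load_file: URLs are downloaded to a temp
        # file; local paths pass through unchanged.
        if source.startswith(('http://', 'https://')):
            suffix = os.path.splitext(source)[1] or '.mp4'
            fd, path = tempfile.mkstemp(suffix=suffix)
            with os.fdopen(fd, 'wb') as f:
                f.write(urllib.request.urlopen(source).read())
            return path
        return source


    def _read_video_decord(ele: dict) -> str:
        # Stand-in for the real decord backend, which only accepts local paths.
        return f"decoded {ele['video']}"


    VIDEO_READER_BACKENDS = {'decord': _read_video_decord}
    _PATCHED = False  # plays the role of vision_process._patch


    def patch_video_reader() -> None:
        global _PATCHED
        if _PATCHED:  # already wrapped; a second call must not re-wrap
            return
        original = VIDEO_READER_BACKENDS['decord']

        def wrapped(ele: dict):
            ele['video'] = load_file(ele['video'])  # normalize source -> local path
            return original(ele)

        VIDEO_READER_BACKENDS['decord'] = wrapped
        _PATCHED = True


    if __name__ == '__main__':
        patch_video_reader()
        patch_video_reader()  # no-op thanks to the sentinel
        print(VIDEO_READER_BACKENDS['decord']({'video': 'clip.mp4'}))

Relatedly, the template change casts fetch_video's output to torch.uint8 and
drops do_rescale=False from both image_processor calls; the apparent intent is
to hand the Qwen2-VL processor uint8 pixel data and let its default rescaling
normalize images and videos uniformly, rather than assuming pre-rescaled float
inputs.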