diff --git a/swift/llm/model/model/qwen.py b/swift/llm/model/model/qwen.py
index d1f1c9dca..66069df3c 100644
--- a/swift/llm/model/model/qwen.py
+++ b/swift/llm/model/model/qwen.py
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple, Type
 
 import torch
 from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizerBase
diff --git a/swift/llm/template/vision_utils.py b/swift/llm/template/vision_utils.py
index 48cbda30f..fb1e68d9a 100644
--- a/swift/llm/template/vision_utils.py
+++ b/swift/llm/template/vision_utils.py
@@ -9,7 +9,6 @@
 import numpy as np
 import requests
 import torch
-from packaging import version
 from PIL import Image, ImageDraw
 
 from swift.utils import get_env_args
diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py
index e172dd6a8..906a3e116 100644
--- a/swift/llm/train/rlhf.py
+++ b/swift/llm/train/rlhf.py
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import List, Optional, Union
+from typing import List, Union
 
 from swift.utils import patch_getattr
 from ..argument import RLHFArguments
diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py
index 24084598d..56ce3d87c 100644
--- a/swift/llm/train/sft.py
+++ b/swift/llm/train/sft.py
@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 from functools import partial
-from typing import List, Optional, Union
+from typing import List, Union
 
 from datasets import Dataset as HfDataset
 
@@ -14,7 +14,6 @@
 from ..dataset import EncodePreprocessor, GetLengthPreprocessor, LazyLLMDataset, PackingPreprocessor, load_dataset
 from ..infer import prepare_generation_config
 from ..model import get_model_arch
-from ..template import get_template
 from ..utils import deep_getattr, dynamic_gradient_checkpointing
 from .tuner import TunerMixin
 
diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
index f9b401bb1..aa3458862 100644
--- a/swift/trainers/mixin.py
+++ b/swift/trainers/mixin.py
@@ -283,8 +283,9 @@ def train(self, *args, **kwargs):
         logger.info(f'Successfully registered post_encode hook: {[model.__class__.__name__ for model in models]}')
         self._save_initial_model(self.args.output_dir)
         with self.hub.patch_hub(), self._patch_loss_function():
-            return super().train(*args, **kwargs)
+            res = super().train(*args, **kwargs)
         self.template.remove_post_encode_hook()
+        return res
 
     def push_to_hub(self, *args, **kwargs):
         with self.hub.patch_hub():
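
Note for reviewers: the only behavioral change is in swift/trainers/mixin.py. Previously the method returned from inside the `with` block, so the `self.template.remove_post_encode_hook()` call after it was never reached. The sketch below shows the resulting control flow of the `train` override; it is assembled only from the hunk's context and added lines, and the leading `...` elides the rest of the method, which this patch does not touch.

    def train(self, *args, **kwargs):
        ...
        logger.info(f'Successfully registered post_encode hook: {[model.__class__.__name__ for model in models]}')
        self._save_initial_model(self.args.output_dir)
        with self.hub.patch_hub(), self._patch_loss_function():
            # Capture the result instead of returning inside the context managers,
            # so the post_encode hook is actually removed once training finishes.
            res = super().train(*args, **kwargs)
        self.template.remove_post_encode_hook()
        return res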