Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Dec 24, 2024
1 parent 98b5bed commit f8727b1
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion swift/llm/model/model/qwen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple, Type

import torch
from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizerBase
Expand Down
1 change: 0 additions & 1 deletion swift/llm/template/vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
import requests
import torch
-from packaging import version
from PIL import Image, ImageDraw

from swift.utils import get_env_args
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/train/rlhf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import List, Optional, Union
+from typing import List, Union

from swift.utils import patch_getattr
from ..argument import RLHFArguments
Expand Down
3 changes: 1 addition & 2 deletions swift/llm/train/sft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
-from typing import List, Optional, Union
+from typing import List, Union

from datasets import Dataset as HfDataset

Expand All @@ -14,7 +14,6 @@
from ..dataset import EncodePreprocessor, GetLengthPreprocessor, LazyLLMDataset, PackingPreprocessor, load_dataset
from ..infer import prepare_generation_config
from ..model import get_model_arch
-from ..template import get_template
from ..utils import deep_getattr, dynamic_gradient_checkpointing
from .tuner import TunerMixin

Expand Down
3 changes: 2 additions & 1 deletion swift/trainers/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,9 @@ def train(self, *args, **kwargs):
logger.info(f'Successfully registered post_encode hook: {[model.__class__.__name__ for model in models]}')
self._save_initial_model(self.args.output_dir)
with self.hub.patch_hub(), self._patch_loss_function():
-return super().train(*args, **kwargs)
+res = super().train(*args, **kwargs)
 self.template.remove_post_encode_hook()
+return res

def push_to_hub(self, *args, **kwargs):
with self.hub.patch_hub():
Expand Down

0 comments on commit f8727b1

Please sign in to comment.