Skip to content

Commit

Permalink
Fix bugs (#1311)
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet authored Jul 7, 2024
1 parent 65ea69d commit 4a96f35
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 16 deletions.
24 changes: 19 additions & 5 deletions swift/llm/utils/media.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import shutil
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, Literal, Optional, Union

import numpy as np
from modelscope.hub.utils.utils import get_cache_dir

from swift.hub.utils.utils import get_cache_dir
from swift.utils import get_logger

logger = get_logger()
Expand Down Expand Up @@ -125,10 +125,24 @@ def get_url(media_type):
return f'{MediaCache.URL_PREFIX}{media_type}.{extension}'

@staticmethod
def download(media_type, media_name=None):
from swift.utils import safe_ddp_context
def download(media_type_or_url: str, local_alias: Optional[str] = None):
"""Download and extract a resource from an HTTP link.
Args:
media_type_or_url: `str`, either one of the media types in the `media_type_urls` listed in the class
field, or a remote url to download and extract. Be aware that this media type or url
needs to point to a zip or tar file.
local_alias: `Optional[str]`, the local alias name for `media_type_or_url`. If the first arg is a
media_type listed in this class, local_alias can be left as None; otherwise please pass in a name
for the url. The local dir containing the extracted files will be: {cache_dir}/{local_alias}
Returns:
The local dir containing the extracted files.
"""
from swift.utils import safe_ddp_context, FileLockContext
with safe_ddp_context():
return MediaCache._safe_download(media_type=media_type, media_name=media_name)
with FileLockContext(media_type_or_url):
return MediaCache._safe_download(media_type=media_type_or_url, media_name=local_alias)

@staticmethod
def _safe_download(media_type, media_name=None):
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import transformers
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
GenerationConfig, GPTQConfig, snapshot_download)
from modelscope.hub.utils.utils import get_cache_dir
from packaging import version
from torch import Tensor
from torch import dtype as Dtype
Expand All @@ -25,7 +26,6 @@
from transformers.utils.versions import require_version

from swift import get_logger
from swift.hub.utils.utils import get_cache_dir
from swift.utils import get_dist_setting, safe_ddp_context, subprocess_run, use_torchacc
from .template import TemplateType
from .utils import get_max_model_len, is_unsloth_available
Expand Down
9 changes: 3 additions & 6 deletions swift/llm/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
medias = self.parse_medias(d)
self.media_replacer(row, medias)
if self.media_type:
if not isinstance(self.media_key, str):
row[self.media_name] = medias
row[self.media_name] = medias
return row

def __call__(self, dataset: HfDataset) -> HfDataset:
Expand Down Expand Up @@ -248,8 +247,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
medias = self.parse_medias(d)
self.media_replacer(row, medias)
if self.media_type:
if not isinstance(self.media_key, str):
row[self.media_name] = medias
row[self.media_name] = medias
return row
except (AssertionError, SyntaxError):
if self.error_strategy == 'raise':
Expand Down Expand Up @@ -303,8 +301,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
medias = self.parse_medias(d)
self.media_replacer(row, medias)
if self.media_type:
if not isinstance(self.media_key, str):
row[self.media_name] = medias
row[self.media_name] = medias
except Exception:
if self.error_strategy == 'raise':
raise ValueError(f'conversations: {conversations}')
Expand Down
2 changes: 1 addition & 1 deletion swift/tuners/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
import numpy as np
import torch
from modelscope import snapshot_download
from modelscope.hub.utils.utils import get_cache_dir
from packaging import version
from peft.utils import CONFIG_NAME
from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper
from peft.utils import _get_submodules

from swift.hub.utils.utils import get_cache_dir
from swift.tuners.module_mapping import ModelKeys
from swift.utils.constants import BIN_EXTENSIONS
from swift.utils.logger import get_logger
Expand Down
17 changes: 17 additions & 0 deletions swift/ui/llm_train/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,14 @@ def update_log(cls, task):
ret.append(gr.update(visible=True, label=p['name']))
return ret

@classmethod
def get_initial(cls, line):
    """Return the first known progress prefix that `line` starts with.

    Args:
        line: A single log line (str).

    Returns:
        One of 'Train:', 'Map:', 'Val:' or 'Filter:' when `line` begins with
        it (checked in that order), otherwise None.
    """
    known_prefixes = ['Train:', 'Map:', 'Val:', 'Filter:']
    # next() with a default reproduces "first match wins, else None".
    return next((prefix for prefix in known_prefixes if line.startswith(prefix)), None)

@classmethod
def wait(cls, logging_dir, task):
if not logging_dir:
Expand Down Expand Up @@ -334,6 +342,15 @@ def wait(cls, logging_dir, task):
else:
latest_data = ''
lines.extend(latest_lines)
start = cls.get_initial(lines[-1])
if start:
i = len(lines) - 2
while i >= 0:
if lines[i].startswith(start):
del lines[i]
i -= 1
else:
break
yield ['\n'.join(lines)] + Runtime.plot(task)
except IOError:
pass
Expand Down
6 changes: 3 additions & 3 deletions swift/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
from .torch_utils import (activate_model_parameters, broadcast_string, freeze_model_parameters, get_dist_setting,
get_model_info, is_ddp_plus_mp, is_dist, is_local_master, is_master, is_mp, is_on_same_device,
show_layers, time_synchronize, torchacc_trim_graph, use_torchacc)
from .utils import (add_version_to_work_dir, check_json_format, get_pai_tensorboard_dir, is_pai_training_job,
lower_bound, parse_args, read_multi_line, safe_ddp_context, seed_everything, subprocess_run,
test_time, upper_bound)
from .utils import (FileLockContext, add_version_to_work_dir, check_json_format, get_pai_tensorboard_dir,
is_pai_training_job, lower_bound, parse_args, read_multi_line, safe_ddp_context, seed_everything,
subprocess_run, test_time, upper_bound)
48 changes: 48 additions & 0 deletions swift/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import datetime as dt
import fcntl
import hashlib
import os
import random
import re
Expand All @@ -11,6 +13,7 @@

import numpy as np
import torch.distributed as dist
from modelscope.hub.utils.utils import get_cache_dir
from transformers import HfArgumentParser, enable_full_determinism, set_seed

from .logger import get_logger
Expand All @@ -20,6 +23,51 @@
logger = get_logger()


class FileLockContext:
    """Exclusive inter-process lock built on ``fcntl.flock`` and a lock file.

    Each `origin_symbol` maps to one lock file named after the MD5 hex digest
    of the symbol, stored under `cache_dir`. Use it as a context manager::

        with FileLockContext('some-key'):
            ...  # critical section

    Args:
        origin_symbol: Arbitrary string identifying the protected resource.
        timeout: Maximum number of seconds `acquire` keeps retrying before
            raising. A falsy value (0 / None) means retry forever.
    """

    # Directory that holds every lock file created by this class.
    cache_dir = os.path.join(get_cache_dir(), 'lockers')

    def __init__(self, origin_symbol: str, timeout: int = 60 * 30):
        self.origin_symbol = origin_symbol
        # MD5 is used only to derive a filesystem-safe, fixed-length file name.
        self.file_path = hashlib.md5(origin_symbol.encode('utf-8')).hexdigest() + '.lock'
        self.file_path = os.path.join(FileLockContext.cache_dir, self.file_path)
        self.file_handle = None
        self.timeout = timeout

    def acquire(self):
        """Acquire the lock, polling once per second until it is available.

        Returns:
            True once the lock is held.

        Raises:
            IOError: If `timeout` is truthy and the lock could not be obtained
                within `timeout` seconds.
        """
        start_time = time.time()
        while True:
            try:
                os.makedirs(FileLockContext.cache_dir, exist_ok=True)
                # Mode 'w' creates the lock file when it does not exist yet.
                self.file_handle = open(self.file_path, 'w')
                # LOCK_NB makes flock raise instead of blocking when another
                # process holds the lock; without it the retry loop and the
                # timeout below would never be reached.
                fcntl.flock(self.file_handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
                return True
            except IOError as e:
                if self.file_handle:
                    self.file_handle.close()
                    self.file_handle = None
                if self.timeout and (time.time() - start_time) >= self.timeout:
                    raise IOError(f'Cannot acquire the file lock from {self.origin_symbol} '
                                  f'as the timeout reaches: {self.timeout} seconds') from e
                time.sleep(1)

    def release(self):
        """Release the lock and close the lock file; no-op if not held."""
        if self.file_handle:
            try:
                fcntl.flock(self.file_handle, fcntl.LOCK_UN)
            finally:
                # Always close the descriptor, even if unlocking fails, so the
                # handle is not leaked (closing also drops the flock).
                self.file_handle.close()
                self.file_handle = None

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()


@contextmanager
def safe_ddp_context():
if is_dist() and not is_local_master():
Expand Down

0 comments on commit 4a96f35

Please sign in to comment.