Skip to content

Commit

Permalink
fix ui, and support DATASET_ENABLE_CACHE variable (#1319)
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet authored Jul 8, 2024
1 parent fd1dd26 commit f5b4585
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 44 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,19 @@ The complete list of supported models and datasets can be found at [Supported Mo
| Computing cards A10/A100, etc. | Support BF16 and FlashAttn |
| Huawei Ascend NPU | |

### Environment variables

- DATASET_ENABLE_CACHE: Enable caching when preprocessing datasets; you can use `1/True` or `0/False`, default `False`
- WEBUI_SHARE: Share your web-ui, you can use `1/True` or `0/False`, default `False`
- SWIFT_UI_LANG: web-ui language, you can use `en` or `zh`, default `zh`
- WEBUI_SERVER: web-ui host IP; `0.0.0.0` binds all network interfaces, `127.0.0.1` allows local access only. Default `127.0.0.1`
- WEBUI_PORT: web-ui port
- USE_HF: Use the Hugging Face endpoint or the ModelScope endpoint to download models and datasets. You can use `1/True` or `0/False`, default `False`
- FORCE_REDOWNLOAD: Force to re-download the dataset

Other variables such as `CUDA_VISIBLE_DEVICES` are also supported; they are not listed here.


## 📃 Documentation

### Documentation Compiling
Expand Down
13 changes: 13 additions & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,19 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
| 华为昇腾NPU | |


### 环境变量

- DATASET_ENABLE_CACHE:在预处理数据集时启用缓存,您可以使用`1/True`或`0/False`,默认值为`False`
- WEBUI_SHARE:共享web-ui,可以使用`1/True`或`0/False`,默认值为`False`
- SWIFT_UI_LANG:web-ui语言,您可以使用`en`或`zh`,默认值为`zh`
- WEBUI_SERVER:web-ui可访问的IP,`0.0.0.0`表示绑定所有网络接口,`127.0.0.1`仅允许本地访问。默认值为`127.0.0.1`
- WEBUI_PORT:web-ui端口
- USE_HF:使用Hugging Face endpoint或ModelScope endpoint下载模型和数据集。您可以使用`1/True`或`0/False`,默认值为`False`
- FORCE_REDOWNLOAD:强制重新下载数据集

其他变量如`CUDA_VISIBLE_DEVICES`也支持,但未在此列出。


## 📃文档

### 文档编译
Expand Down
81 changes: 43 additions & 38 deletions swift/llm/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
from tqdm.auto import tqdm
from transformers.utils import strtobool

from swift.utils import (get_logger, get_seed, is_dist, is_local_master, read_from_jsonl, safe_ddp_context,
transform_jsonl_to_df)
from swift.utils import get_logger, get_seed, is_dist, is_local_master, read_from_jsonl, transform_jsonl_to_df
from swift.utils.torch_utils import _find_local_mac
from .media import MediaCache, MediaTag
from .preprocess import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
ListPreprocessor, PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor,
TextGenerationPreprocessor, preprocess_sharegpt)
from .utils import download_dataset

dataset_enable_cache = strtobool(os.environ.get('DATASET_ENABLE_CACHE', 'False'))


def _update_fingerprint_mac(*args, **kwargs):
mac = _find_local_mac().replace(':', '')
Expand Down Expand Up @@ -378,7 +379,7 @@ def _post_preprocess(
train_sample = dataset_sample - val_sample
assert isinstance(val_sample, int)
train_dataset, val_dataset = train_dataset.train_test_split(
test_size=val_sample, seed=get_seed(random_state), load_from_cache_file=False).values()
test_size=val_sample, seed=get_seed(random_state), load_from_cache_file=dataset_enable_cache).values()

assert train_sample > 0
train_dataset = sample_dataset(train_dataset, train_sample, random_state)
Expand Down Expand Up @@ -445,7 +446,8 @@ def preprocess_row(row):
return {'image': [], 'conversations': []}
return {'image': [image]}

dataset = dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['conversations'])
dataset = dataset.map(
preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['conversations'])
return ConversationsPreprocessor(
user_role='human', assistant_role='gpt', media_type='image', error_strategy='delete')(
dataset)
Expand Down Expand Up @@ -490,7 +492,7 @@ def preprocess_row(row):
else:
return {'images': []}

return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['images'])
return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['images'])


def get_mantis_dataset(dataset_id: str,
Expand Down Expand Up @@ -575,7 +577,7 @@ def preprocess_image(example):
example['images'] = []
return example

dataset = dataset.map(preprocess_image, load_from_cache_file=False).filter(lambda row: row['images'])
dataset = dataset.map(preprocess_image, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['images'])
return ConversationsPreprocessor(
user_role='user',
assistant_role='assistant',
Expand Down Expand Up @@ -666,7 +668,7 @@ def preprocess(row):
'query': np.random.choice(caption_prompt),
}

return dataset.map(preprocess, load_from_cache_file=False)
return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache)


register_dataset(
Expand Down Expand Up @@ -717,11 +719,9 @@ def _preprocess_aishell1_dataset(dataset: HfDataset) -> HfDataset:


def _preprocess_video_chatgpt(dataset: HfDataset) -> HfDataset:
from datasets.download.download_manager import DownloadManager
url = 'https://modelscope.cn/datasets/huangjintao/VideoChatGPT/resolve/master/videos.zip'
with safe_ddp_context():
local_dir = DownloadManager().download_and_extract(url)
local_dir = os.path.join(str(local_dir), 'Test_Videos')
local_dir = MediaCache.download(url, 'video_chatgpt')
local_dir = os.path.join(local_dir, 'Test_Videos')
# only `.mp4`
mp4_set = [file[:-4] for file in os.listdir(local_dir) if file.endswith('mp4')]
query = []
Expand Down Expand Up @@ -794,7 +794,7 @@ def map_row(row):
return response

dataset = AlpacaPreprocessor()(dataset)
return dataset.map(map_row, load_from_cache_file=False)
return dataset.map(map_row, load_from_cache_file=dataset_enable_cache)


register_dataset(
Expand All @@ -821,7 +821,7 @@ def map_row(row):
title = match.group(1)
return {'response': title}

return dataset.map(map_row, load_from_cache_file=False).filter(lambda row: row['response'])
return dataset.map(map_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['response'])


register_dataset(
Expand Down Expand Up @@ -1002,7 +1002,8 @@ def reorganize_row(row):
'history': history,
}

return dataset.map(reorganize_row, load_from_cache_file=False).filter(lambda row: row['query'] is not None)
return dataset.map(
reorganize_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'] is not None)


register_dataset(
Expand Down Expand Up @@ -1067,7 +1068,7 @@ def row_can_be_parsed(row):
return False

return dataset.filter(row_can_be_parsed).map(
reorganize_row, load_from_cache_file=False).filter(lambda row: row['query'])
reorganize_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'])


register_dataset(
Expand Down Expand Up @@ -1137,7 +1138,8 @@ def preprocess_image(example):
return example

dataset = dataset.map(
preprocess_image, load_from_cache_file=False).filter(lambda example: example['images'] is not None)
preprocess_image,
load_from_cache_file=dataset_enable_cache).filter(lambda example: example['images'] is not None)
processer = ConversationsPreprocessor(
user_role='human', assistant_role='gpt', media_type='image', media_key='images', error_strategy='delete')
return processer(dataset)
Expand Down Expand Up @@ -1182,8 +1184,8 @@ def preprocess(row):
return {'response': '', 'image': None}

return dataset.map(
preprocess,
load_from_cache_file=False).filter(lambda row: row.get('response')).rename_columns({'image': 'images'})
preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response')).rename_columns(
{'image': 'images'})


def preprocess_refcoco_unofficial_caption(dataset):
Expand All @@ -1209,7 +1211,7 @@ def preprocess(row):
res['response'] = ''
return res

return dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row.get('response'))
return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response'))


register_dataset(
Expand Down Expand Up @@ -1254,7 +1256,7 @@ def preprocess(row):
res['response'] = ''
return res

return dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row.get('response'))
return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response'))


register_dataset(
Expand Down Expand Up @@ -1323,7 +1325,8 @@ def preprocess_image(example):
return example

dataset = dataset.map(
preprocess_image, load_from_cache_file=False).filter(lambda example: example['images'] is not None)
preprocess_image,
load_from_cache_file=dataset_enable_cache).filter(lambda example: example['images'] is not None)
processer = ConversationsPreprocessor(
user_role='human', assistant_role='gpt', media_type='image', media_key='images', error_strategy='delete')
return processer(dataset)
Expand Down Expand Up @@ -1386,7 +1389,7 @@ def preprocess(row):
else:
return {'image': ''}

dataset = dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row['image'])
dataset = dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['image'])
return ConversationsPreprocessor(
user_role='human', assistant_role='gpt', media_type='image', error_strategy='delete')(
dataset)
Expand All @@ -1412,7 +1415,7 @@ def reorganize_row(row):
'rejected_response': row['answer_en'],
}

return dataset.map(reorganize_row, load_from_cache_file=False)
return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache)


def process_ultrafeedback_kto(dataset: HfDataset):
Expand All @@ -1424,7 +1427,7 @@ def reorganize_row(row):
'label': row['label'],
}

return dataset.map(reorganize_row, load_from_cache_file=False)
return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache)


register_dataset(
Expand Down Expand Up @@ -1466,7 +1469,8 @@ def preprocess_row(row):
'response': output,
}

return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['query'] and row['response'])
return dataset.map(
preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'] and row['response'])


register_dataset(
Expand Down Expand Up @@ -1495,7 +1499,7 @@ def preprocess_row(row):
'response': response,
}

return dataset.map(preprocess_row, load_from_cache_file=False)
return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache)


register_dataset(
Expand Down Expand Up @@ -1537,7 +1541,7 @@ def preprocess(row):
'query': query,
}

return dataset.map(preprocess, load_from_cache_file=False).rename_column('image', 'images')
return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images')


register_dataset(
Expand All @@ -1560,7 +1564,7 @@ def preprocess(row):
'query': query,
}

return dataset.map(preprocess, load_from_cache_file=False).rename_column('image', 'images')
return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images')


register_dataset(
Expand All @@ -1584,7 +1588,7 @@ def preprocess(row):
'query': query,
}

return dataset.map(preprocess, load_from_cache_file=False).rename_column('image', 'images')
return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images')


register_dataset(
Expand All @@ -1606,7 +1610,8 @@ def preprocess_row(row):
return {'query': query, 'response': f'{solution}\nSo the final answer is:{response}'}

return dataset.map(
preprocess_row, load_from_cache_file=False).filter(lambda row: row['image']).rename_columns({'image': 'images'})
preprocess_row,
load_from_cache_file=dataset_enable_cache).filter(lambda row: row['image']).rename_columns({'image': 'images'})


register_dataset(
Expand Down Expand Up @@ -1660,7 +1665,7 @@ def preprocess_row(row):

return {'images': images, 'response': response, 'objects': json.dumps(objects or [], ensure_ascii=False)}

return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['objects'])
return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['objects'])


register_dataset(
Expand All @@ -1687,7 +1692,7 @@ def preprocess_row(row):
else:
return {'query': '', 'response': '', 'images': ''}

return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['query'])
return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'])


register_dataset(
Expand Down Expand Up @@ -1720,7 +1725,7 @@ def preprocess_row(row):
return {'messages': rounds}

dataset = dataset.map(
preprocess_row, load_from_cache_file=False).map(
preprocess_row, load_from_cache_file=dataset_enable_cache).map(
ConversationsPreprocessor(
user_role='user',
assistant_role='assistant',
Expand All @@ -1730,7 +1735,7 @@ def preprocess_row(row):
media_key='images',
media_type='image',
).preprocess,
load_from_cache_file=False)
load_from_cache_file=dataset_enable_cache)
return dataset


Expand Down Expand Up @@ -1787,8 +1792,8 @@ def preprocess(row):
}

return dataset.map(
preprocess,
load_from_cache_file=False).filter(lambda r: r['source'] != 'toxic-dpo-v0.2' and r['query'] is not None)
preprocess, load_from_cache_file=dataset_enable_cache).filter(
lambda r: r['source'] != 'toxic-dpo-v0.2' and r['query'] is not None)


register_dataset(
Expand All @@ -1814,7 +1819,7 @@ def preprocess(row):
'response': response,
}

return dataset.map(preprocess, load_from_cache_file=False)
return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache)


register_dataset(
Expand Down Expand Up @@ -2116,7 +2121,7 @@ def reorganize_row(row):
'response': convs[-1]['value']
}

return dataset.map(reorganize_row, load_from_cache_file=False)
return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache)


register_dataset(
Expand Down
Loading

0 comments on commit f5b4585

Please sign in to comment.