forked from DRAGNLabs/301r_retnet

tokenize_data.py · 73 lines (53 loc) · 2.2 KB

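# Usage sketch (inferred from the argparse setup and the config fields read
# below; the example paths and exact YAML layout are assumptions, not taken
# from the upstream repo):
#
#   python tokenize_data.py configs/my_config.yaml train
#
# where the YAML config is expected to provide at least:
#
#   raw_dataset_path: data/raw              # parquet shards at <raw>/<split>/*.parquet
#   tokenized_dataset_path: data/tokenized  # output directory, created if missing
#   tokenizer_path: ./tokenizers/bpe        # loadable by PreTrainedTokenizerFast
#   dataset_feature: text                   # name of the text column to tokenize
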
import argparse

import dask

# Opt in to dask's query-planning backend; this must be set before
# dask.dataframe is imported for it to take effect.
dask.config.set({'dataframe.query-planning': True})

import dask.dataframe as dd
import pyarrow as pa
import yaml
from dask.diagnostics import ProgressBar
from pathlib import Path
from transformers import PreTrainedTokenizerFast

from utils import Struct

# Show a progress bar for every dask computation in this script.
ProgressBar().register()

def tokenize_data(config, split):
    # Glob pattern for the raw parquet shards of the requested split
    dataset_path = Path(config.raw_dataset_path) / split / '*.parquet'

    # Lazily load the dataset from disk into a dask dataframe, keeping only
    # the text column that will be tokenized
    dataset = dd.read_parquet(path=dataset_path,
                              columns=[config.dataset_feature])

    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_path)
    # Tokenize one pandas partition at a time: the tokenizer is mapped over
    # every string in the text column of the partition
    def tokenization_partition(partition):
        tokenization_dataframe = lambda series: \
            tokenizer(
                series,
                padding=False,
                truncation=False,
                return_token_type_ids=False,
                return_attention_mask=False)["input_ids"]

        tokenized_data = partition[config.dataset_feature] \
            .map(tokenization_dataframe, na_action='ignore').to_frame()

        return tokenized_data

    dataset = dataset.map_partitions(tokenization_partition)
    # Make sure the directory for the tokenized dataset exists
    tokenized_dataset_dir = Path(config.tokenized_dataset_path)
    tokenized_dataset_dir.mkdir(parents=True, exist_ok=True)

    print(f"Saving tokenized data to {config.tokenized_dataset_path}")
    # The output column keeps the name of config.dataset_feature, so the
    # parquet schema is keyed on that name rather than a hard-coded "text"
    dataset.to_parquet(tokenized_dataset_dir / split,
                       schema={config.dataset_feature: pa.list_(pa.int64())})
    print('Done!')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Tokenize data')
    parser.add_argument('config_path',
                        type=str,
                        help='Path to the config file')
    parser.add_argument('split',
                        type=str,
                        choices=['train', 'test', 'validation'],
                        help='Dataset split to use')
    args = parser.parse_args()

    config_path = args.config_path
    split = args.split

    # Load the YAML config and wrap it so fields can be read as attributes
    # (config.raw_dataset_path, config.tokenizer_path, etc.)
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    config = Struct(**config)

    tokenize_data(config, split)
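
# A quick way to sanity-check the output (a sketch; assumes the example paths
# from the usage note above and the same dask/pyarrow environment):
#
#   import dask.dataframe as dd
#   tokens = dd.read_parquet("data/tokenized/train")
#   print(tokens.head())  # each row holds the list of token ids for one document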