split_data.py (forked from DRAGNLabs/301r_retnet)
import sys

import yaml
from pathlib import Path

from utils import Struct

import dask

# Opt in to dask's query-planning backend before importing dask.dataframe
dask.config.set({'dataframe.query-planning': True})
import dask.dataframe as dd
# Import the submodule explicitly; `import dask_ml` alone does not
# expose dask_ml.model_selection
import dask_ml.model_selection
def split_data(config):
    """
    Filter and split the dataset into training, validation, and testing
    datasets.
    """
    # Create folder to save this dataset's files in
    dataset_dir = Path(config.raw_dataset_path)
    dataset_dir.mkdir(parents=True, exist_ok=True)

    print("Loading data...")

    # Read the dataset from disk
    dataset = dd.read_parquet(dataset_dir / "*.parquet")

    # Filter out rows with only whitespace
    dataset = dataset[dataset[config.dataset_feature].str.strip() != '']

    # Split into training, validation, and testing datasets
    train, test_valid = dask_ml.model_selection.train_test_split(
        dataset,
        shuffle=False,  # Shuffling is very expensive for large datasets
        train_size=config.splits[0],
        random_state=config.rand_seed)
    test, validation = dask_ml.model_selection.train_test_split(
        test_valid,
        shuffle=False,
        train_size=config.splits[1] / (config.splits[1] + config.splits[2]),
        random_state=config.rand_seed)

    train.to_parquet(dataset_dir / 'train')
    validation.to_parquet(dataset_dir / 'validation')
    test.to_parquet(dataset_dir / 'test')

    print("Finished")
if __name__ == "__main__":
    # The script's single argument is the path to a YAML config file
    config_path = sys.argv[1]

    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    config = Struct(**config)

    split_data(config)
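
# Example invocation (a sketch; the key names follow the attribute accesses
# above, the values are illustrative assumptions):
#
#   $ python split_data.py config.yaml
#
# where config.yaml might contain:
#
#   raw_dataset_path: /path/to/dataset   # folder holding the *.parquet shards
#   dataset_feature: text                # column filtered for whitespace-only rows
#   splits: [0.8, 0.1, 0.1]              # train / test / validation fractions
#   rand_seed: 42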