# dataset.py
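"""List-backed dataset helpers: a Dataset base class with download,
annotate, and (de)serialization hooks, plus an AugDataset that mixes
generated examples into each training batch."""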
import json
import os
from itertools import islice

import numpy as np
from tqdm import tqdm

from utils import File

class Dataset(list):

    def __add__(self, rhs):
        # Concatenating two datasets yields an instance of the same subclass.
        return self.__class__(super().__add__(rhs))

    def __getitem__(self, item):
        # Slices come back as the same subclass; single items are returned as-is.
        result = super().__getitem__(item)
        if isinstance(result, list):
            return self.__class__(result)
        else:
            return result
    @classmethod
    def annotate(cls, tokenizer, din, dout, limit=None):
        # Subclasses convert raw files in din into annotated examples in dout.
        raise NotImplementedError()

    @classmethod
    def download(cls, dout):
        # Subclasses fetch the raw dataset files into dout.
        raise NotImplementedError()

    @classmethod
    def pull(cls, tokenizer, limit=None):
        # Download the raw data if it is missing, then annotate it, under
        # dataset/<classname>/{raw,ann}.
        dout = File.new_dir(os.path.join('dataset', cls.__name__.lower()), ensure_dir=True)
        draw = os.path.join(dout, 'raw')
        dann = os.path.join(dout, 'ann')
        if not os.path.isdir(draw):
            os.makedirs(draw)
            cls.download(draw)
        if not os.path.isdir(dann):
            os.makedirs(dann)
        anns = cls.annotate(tokenizer, draw, dann, limit=limit)
        return anns
    def compute_metrics(self, preds):
        # Subclasses score a set of predictions against this dataset.
        raise NotImplementedError()

    def accumulate_preds(self, preds, batch_preds):
        # Merge per-batch prediction dicts into a running dict.
        if preds is None:
            preds = batch_preds.copy()
        else:
            preds.update(batch_preds)
        return preds
    @classmethod
    def serialize_one(cls, ex):
        return json.dumps(ex)

    @classmethod
    def deserialize_one(cls, line):
        return json.loads(line)

    def save(self, fname, verbose=False):
        with open(fname, 'wt') as f:
            iterator = tqdm(self, desc='save') if verbose else self
            for ex in iterator:
                # serialize_one already returns a JSON string; wrapping it in
                # another json.dumps would double-encode and break load().
                f.write(self.serialize_one(ex) + '\n')

    @classmethod
    def load(cls, fname, limit=None):
        with open(fname) as f:
            # islice stops reading at the limit instead of scanning the whole
            # file and filtering afterwards.
            lines = f if limit is None else islice(f, limit)
            return cls(cls.deserialize_one(line) for line in lines)
    def keep(self, keep):
        # Filter to examples for which keep(example) is truthy.
        return self.__class__([e for e in self if keep(e)])

    def batch(self, batch_size, shuffle=False, verbose=False, desc='batch'):
        items = self[:]
        if shuffle:
            np.random.shuffle(items)
        iterator = range(0, len(items), batch_size)
        if verbose:
            iterator = tqdm(iterator, desc=desc)
        for i in iterator:
            yield items[i:i + batch_size]

    def reset(self):
        # Hook for subclasses that regenerate their contents (see AugDataset).
        return self
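
# --- Illustrative sketch (not part of the original file) ---
# A minimal concrete subclass showing how the abstract hooks above fit
# together. The raw file name, the tokenizer interface (a callable
# returning a list of tokens), and the accuracy metric are assumptions
# made for this example only.
class ToyDataset(Dataset):

    @classmethod
    def download(cls, dout):
        # A real subclass would fetch files here; this stub writes a tiny
        # raw file so the sketch is self-contained.
        with open(os.path.join(dout, 'raw.jsonl'), 'wt') as f:
            for i in range(3):
                f.write(json.dumps({'text': 'example {}'.format(i), 'label': i % 2}) + '\n')

    @classmethod
    def annotate(cls, tokenizer, din, dout, limit=None):
        anns = cls()
        with open(os.path.join(din, 'raw.jsonl')) as f:
            for i, line in enumerate(f):
                if limit is not None and i >= limit:
                    break
                ex = json.loads(line)
                ex['tokens'] = tokenizer(ex['text'])
                anns.append(ex)
        anns.save(os.path.join(dout, 'ann.jsonl'))
        return anns

    def compute_metrics(self, preds):
        # preds is assumed to map example index to a predicted label.
        correct = sum(preds.get(i) == ex['label'] for i, ex in enumerate(self))
        return {'accuracy': correct / max(len(self), 1)}
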
class AugDataset(Dataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.AugModel = self.train_data = self.generator = self.size = self.p_gen = None

    def set_augmenter(self, AugModel, generator):
        self.AugModel = AugModel
        self.generator = generator
        return self

    def set_data(self, size, p_gen, train_data):
        # size: how many generated examples to keep in the pool;
        # p_gen: fraction of each batch drawn from that pool.
        self.size = size
        self.p_gen = p_gen
        self.train_data = train_data
        return self
    def reset(self):
        # Regenerate the pool of augmented examples.
        print('Generating data')
        self.clear()
        gen = self.generator.run_gen(self.size, self.AugModel)
        self.extend(gen)
        print('Generated data size {}'.format(len(self)))
        return self

    def batch(self, batch_size, shuffle=False, verbose=False, desc='batch'):
        # Each batch mixes int(p_gen * batch_size) generated examples, sampled
        # with replacement from the pool, with real training examples.
        gen_batch_size = int(self.p_gen * batch_size)
        train_batch_size = batch_size - gen_batch_size
        for train in self.train_data.batch(train_batch_size, shuffle=shuffle, verbose=verbose, desc=desc):
            sample_inds = np.random.choice(len(self), size=gen_batch_size)
            gen = train.__class__([self[i] for i in sample_inds])
            yield gen + train
    def __getitem__(self, item):
        result = super().__getitem__(item)
        if isinstance(result, list):
            # Propagate the augmenter configuration to sliced copies.
            new = self.__class__(result)
            new.set_augmenter(self.AugModel, self.generator).set_data(self.size, self.p_gen, self.train_data)
            return new
        else:
            return result
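
# --- Illustrative sketch (not part of the original file) ---
# Demonstrates how AugDataset mixes generated and real examples in each
# batch. _StubGenerator stands in for whatever object implements
# run_gen(size, AugModel); that interface is inferred from the calls in
# reset() above, not from a documented API.
if __name__ == '__main__':
    class _StubGenerator:
        def run_gen(self, size, AugModel):
            return [{'x': -i, 'generated': True} for i in range(size)]

    train = Dataset([{'x': i, 'generated': False} for i in range(8)])
    aug = AugDataset()
    aug.set_augmenter(AugModel=None, generator=_StubGenerator())
    aug.set_data(size=16, p_gen=0.5, train_data=train)
    aug.reset()  # populate the generated pool

    # With batch_size=4 and p_gen=0.5, each batch holds 2 generated and
    # 2 real examples.
    for batch in aug.batch(4):
        print([ex['generated'] for ex in batch])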