-
Notifications
You must be signed in to change notification settings - Fork 2
/
data_loader.py
83 lines (64 loc) · 1.99 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class DataLoader:
    """
    Superclass for data loaders responsible for generating data.

    Subclasses provide the actual sampling logic by overriding the methods
    listed below. On this base class every one of them (other than
    ``__init__``) raises ``NotImplementedError``.

    Methods
    -------
    _sample_batch(*args, **kwargs)
        Sample a flat batch of data (no explicit task structure)
    _sample_episode(*args, **kwargs)
        Sample an episode consisting of a support and query set
    generator(*args, **kwargs)
        Data generator that yields batches/tasks
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the attributes of the DataLoader.

        The base implementation accepts any positional and keyword
        arguments and ignores them; it sets no attributes.
        """
        pass

    def _sample_batch(self, *args, **kwargs):
        """Sample flat data

        Sample a flat batch of data as often done in regular machine
        learning contexts.

        Parameters
        ----------
        *args : iterable
            Positional arguments
        **kwargs : dict
            Keyword arguments

        Returns
        -------
        data
            A flat batch of data (x, y)

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _sample_batch()"
        )

    def _sample_episode(self, *args, **kwargs):
        """Sample a task

        Sample data in the form of a task consisting of a support and
        query set.

        Parameters
        ----------
        *args : iterable
            Positional arguments
        **kwargs : dict
            Keyword arguments

        Returns
        -------
        data
            A task (train_x, train_y, test_x, test_y)

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _sample_episode()"
        )

    def generator(self, *args, **kwargs):
        """Generator object that yields batches or episodes

        Iteratively yield a sampled batch or task. Subclasses are expected
        to implement this as a generator.

        Parameters
        ----------
        *args : iterable
            Positional arguments
        **kwargs : dict
            Keyword arguments

        Yields
        ------
        data
            A task (train_x, train_y, test_x, test_y). When the generator
            samples flat batches of data, test_x and test_y are None.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement generator()"
        )