forked from Seunghyeon-Cho/AdamNODE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
odelstm_data.py
117 lines (95 loc) · 3.69 KB
/
odelstm_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import os
from einops import rearrange
import torch
class Walker2dImitationData:
def __init__(self, seq_len, device='cpu'):
self.seq_len = seq_len
self.device = device
all_files = sorted(
[
os.path.join("data/walker", d)
for d in os.listdir("data/walker")
if d.endswith(".npy")
]
)
self.rng = np.random.RandomState(891374)
np.random.RandomState(125487).shuffle(all_files)
# 15% test set, 10% validation set, the rest is for training
test_n = int(0.15 * len(all_files))
valid_n = int((0.15 + 0.1) * len(all_files))
test_files = all_files[:test_n]
valid_files = all_files[test_n:valid_n]
train_files = all_files[valid_n:]
train_x, train_t, train_y = self._load_files(train_files)
valid_x, valid_t, valid_y = self._load_files(valid_files)
test_x, test_t, test_y = self._load_files(test_files)
train_x, train_t, train_y = self.perturb_sequences(train_x, train_t, train_y)
valid_x, valid_t, valid_y = self.perturb_sequences(valid_x, valid_t, valid_y)
test_x, test_t, test_y = self.perturb_sequences(test_x, test_t, test_y)
self.train_x, self.train_times, self.train_y = self.align_sequences(
train_x, train_t, train_y
)
self.valid_x, self.valid_times, self.valid_y = self.align_sequences(
valid_x, valid_t, valid_y
)
self.test_x, self.test_times, self.test_y = self.align_sequences(
test_x, test_t, test_y
)
self.input_size = self.train_x.shape[-1]
# print("train_times: ", str(self.train_times.shape))
# print("train_x: ", str(self.train_x.shape))
# print("train_y: ", str(self.train_y.shape))
def align_sequences(self, set_x, set_t, set_y):
times = []
x = []
y = []
for i in range(len(set_y)):
seq_x = set_x[i]
seq_t = set_t[i]
seq_y = set_y[i]
for t in range(0, seq_y.shape[0] - self.seq_len, self.seq_len // 4):
x.append(seq_x[t: t + self.seq_len])
times.append(seq_t[t: t + self.seq_len])
y.append(seq_y[t: t + self.seq_len])
L = (
np.stack(x, axis=0),
np.stack(times, axis=0),
np.stack(y, axis=0),
)
return [rearrange(torch.Tensor(i), 'b t ... -> t b ...').to(self.device) for i in L]
def perturb_sequences(self, set_x, set_t, set_y):
x = []
times = []
y = []
for i in range(len(set_y)):
seq_x = set_x[i]
seq_y = set_y[i]
new_x, new_times = [], []
new_y = []
skip = 0
for t in range(seq_y.shape[0]):
skip += 1
if self.rng.rand() < 0.9:
new_x.append(seq_x[t])
new_times.append(skip)
new_y.append(seq_y[t])
skip = 0
x.append(np.stack(new_x, axis=0))
times.append(np.stack(new_times, axis=0))
y.append(np.stack(new_y, axis=0))
return x, times, y
def _load_files(self, files):
all_x = []
all_t = []
all_y = []
for f in files:
arr = np.load(f)
x_state = arr[:-1, :].astype(np.float32)
y = arr[1:, :].astype(np.float32)
x_times = np.ones(x_state.shape[0])
all_x.append(x_state)
all_t.append(x_times)
all_y.append(y)
# print("Loaded file '{}' of length {:d}".format(f, x_state.shape[0]))
return all_x, all_t, all_y