-
Notifications
You must be signed in to change notification settings - Fork 19
/
learn_synthetic.py
executable file
·69 lines (52 loc) · 2.2 KB
/
learn_synthetic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from model import *
# Smoke-test script: build one of the `model` variants, train it on randomly
# generated games, and print the summed training objective after every pass
# over the games (plus a sample of predictions every 10 passes).


def _select_model(model_name):
    """Return ``(base_learning_rate, make_model_fn)`` for *model_name*.

    Recognised names are "simplest", "pmf" (probabilistic matrix
    factorization), "pmf_with_pace" and "full".  Only the builder for the
    selected branch is looked up.

    Raises:
        ValueError: if *model_name* is not a recognised model name.
    """
    if model_name == "simplest":
        return .005, make_simplest_learning_functions
    elif model_name == "pmf":  # probabilistic matrix factorization
        return .005, make_vanilla_pmf_functions
    elif model_name == "pmf_with_pace":
        return .001, make_pmf_plus_pace_functions
    elif model_name == "full":
        return .0001, make_learning_functions
    # A real exception rather than ``assert False``: asserts vanish under -O.
    raise ValueError("unsupported model: %r" % (model_name,))


def _make_synthetic_games(num_teams, num_games):
    """Draw *num_games* uniformly random games between *num_teams* teams.

    Returns a tuple of six integer arrays of length *num_games*:
    ``(team1_ids, team1_locs, team2_ids, team2_locs, team1_scores,
    team2_scores)``.  Team ids lie in ``[0, num_teams)``, location codes in
    ``[0, 3)`` and scores in ``[50, 70)``.  All draws come from numpy's global
    RNG in a fixed order, so seeding it makes the data reproducible.
    """
    team1_ids = np.random.randint(num_teams, size=num_games)
    team2_ids = np.random.randint(num_teams, size=num_games)
    # NOTE(review): the codes 0/1/2 presumably mean home/away/neutral --
    # confirm against `model`.
    team1_locs = np.random.randint(3, size=num_games)
    team2_locs = np.random.randint(3, size=num_games)
    team1_scores = np.random.randint(50, high=70, size=num_games)
    team2_scores = np.random.randint(50, high=70, size=num_games)
    # NOTE(review): team1_ids[g] may equal team2_ids[g] (a team "playing
    # itself"); harmless for a smoke test, but unlike a real schedule.
    return (team1_ids, team1_locs, team2_ids, team2_locs,
            team1_scores, team2_scores)


if __name__ == "__main__":
    N = 300  # num teams
    # only used for matrix factorization algorithms
    D = 10  # output latent vector dimension
    Hp = 5  # hidden units for pacing component
    # only used for matrix factorization algorithms
    D0 = 30  # base latent vector dimension
    H = 20  # num hidden units for transformation networks
    reg_param1 = .001
    reg_param2 = .001
    NUM_ITS = 500  # passes over the synthetic games
    MODEL = "pmf_with_pace"
    G = 500  # num games
    # Set to an int for reproducible runs; None leaves numpy's RNG unseeded.
    SEED = None

    if SEED is not None:
        # Seed before building the model so that any parameter
        # initialisation drawing from numpy's global RNG is reproducible too.
        np.random.seed(SEED)

    BASE_LEARNING_RATE, make_model_fn = _select_model(MODEL)
    out_fn, train_fn, params = \
        make_model_fn(N, D0, H, D, Hp, reg_param1, reg_param2)

    # make synthetic data
    (team1_ids, team1_locs, team2_ids, team2_locs,
     team1_scores, team2_scores) = _make_synthetic_games(N, G)
    print(team1_ids)
    print(team1_locs)

    for t in range(NUM_ITS):
        obj = 0
        # Step size decays as 1 / (1 + sqrt(t)).
        learning_rate = BASE_LEARNING_RATE / (1.0 + np.sqrt(t))
        for g in range(G):
            obj += train_fn(team1_ids[g], team1_locs[g],
                            team2_ids[g], team2_locs[g],
                            team1_scores[g], team2_scores[g], learning_rate)
            if t % 10 == 0 and g % 50 == 0:
                # Predictions are only displayed for a sample of games, so
                # only compute them then (after this game's update).
                # Assumes out_fn has no side effects -- confirm in `model`.
                pred_team1, pred_team2 = out_fn(team1_ids[g], team1_locs[g],
                                                team2_ids[g], team2_locs[g])
                print("%s-%s vs %s-%s" % (pred_team1, pred_team2,
                                          team1_scores[g], team2_scores[g]))
        print("%s\t%s" % (t, obj))