forked from cxtjjcz/785-visual-story-telling
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_model_v2.py
66 lines (54 loc) · 2.21 KB
/
run_model_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os.path as osp
import time
import pickle
from torch.utils.data import Dataset, DataLoader
from beam_search import *
from model_v2 import ModelV2
from vist_api.vist import Story_in_Sequence
from dataset import StoryDataset, collate_story
from vocab import Vocabulary
from train_test import train, test
# Where the pickled Vocabulary object is written to / read from.
vocab_save_path = "vocab.pt"
# Root directories of the VIST annotations and of the per-split image folders.
vist_annotations_dir = './vist_api/'
images_dir = './vist_api/images/'

# NOTE: these datasets are loaded at import time, so merely importing this
# module triggers the Story_in_Sequence constructors below.
sis_train = Story_in_Sequence(images_dir + "train", vist_annotations_dir)
sis_val = Story_in_Sequence(images_dir + "val", vist_annotations_dir)
# sis_test = Story_in_Sequence(images_dir+"test", vist_annotations_dir)

# When True, rebuild the vocabulary from the training sentences and overwrite
# the pickle at `vocab_save_path`; when False, reuse the previously saved one.
REBUILD_VOCAB = True

if REBUILD_VOCAB:
    # Collect the text of every sentence referenced by a training story.
    corpus = []
    for story in sis_train.Stories:
        sent_ids = sis_train.Stories[story]['sent_ids']
        corpus.extend(sis_train.Sents[sent_id]['text'] for sent_id in sent_ids)

    vocab = Vocabulary(corpus, freq_cutoff=2)  # reads and builds

    # Sanity check: the word->index and index->word tables must round-trip.
    for word, index in vocab.w2i.items():
        if word != vocab.i2w[index]:
            print('Words mismatched...')

    # Persist the vocabulary so later runs can skip the rebuild.
    with open(vocab_save_path, 'wb') as file:
        pickle.dump(vocab, file)
else:
    # Unpickling can execute arbitrary code: only load a file this script
    # produced itself. The context manager makes sure the handle is closed.
    with open(vocab_save_path, 'rb') as file:
        vocab = pickle.load(file)
def main(*, is_training=True, num_epochs=10, checkpoint_path='./Training/7'):
    """Train (or restore) ModelV2 and evaluate it on the validation stories.

    Args:
        is_training: When True, train the model with ``train``; when False,
            load weights from ``checkpoint_path`` instead.
        num_epochs: First argument forwarded to ``train`` (10 by default,
            matching the previously hard-coded value).
        checkpoint_path: File passed to ``torch.load`` when not training.

    Relies on the module-level ``sis_train``, ``sis_val`` and ``vocab``
    objects, and on ``BATCH_SIZE``, which is not defined in this file
    (presumably provided by ``from beam_search import *`` -- TODO confirm).
    """
    # Imported here explicitly: the file otherwise only gets `torch` if a
    # star import happens to re-export it.
    import torch

    train_story_set = StoryDataset(sis_train, vocab)
    val_story_set = StoryDataset(sis_val, vocab)
    # test_story_set = StoryDataset(sis_test, vocab)

    # Per the original author's notes:
    #   imgs of shape [BS, 5, 3, 224, 224]
    #   sents BS * 5 * MAX_LEN
    train_loader = DataLoader(
        train_story_set,
        shuffle=False,
        batch_size=BATCH_SIZE,
        collate_fn=collate_story,
        pin_memory=False,
    )

    model_v2 = ModelV2(vocab)

    # Learning rate is the most sensitive value to set,
    # will need to test what works well past 400 instances
    optimizer = torch.optim.Adam(
        model_v2.parameters(), lr=0.05, weight_decay=1e-4
    )  # .001 for 400

    if is_training:
        train(num_epochs, model_v2, train_loader, optimizer)
    else:
        model_v2.load_state_dict(torch.load(checkpoint_path))

    # NOTE(review): evaluation is run after both branches (freshly trained
    # or restored weights) -- confirm this matches the intended flow.
    test(model_v2, val_story_set, vocab)
# Run the train/evaluate pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()