-
Notifications
You must be signed in to change notification settings - Fork 0
/
export.py
40 lines (27 loc) · 957 Bytes
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pathlib import Path

import torch
from torchvision.transforms import transforms

from utils.data_loader import data_loader
from utils.models import Captioner
# Image preprocessing pipeline: upscale, take a random 299x299 crop
# (presumably the encoder's expected input size — confirm against Captioner),
# convert to a tensor, and normalize every channel to roughly [-1, 1].
transform = transforms.Compose([
    transforms.Resize((356, 356)),
    transforms.RandomCrop((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Build the training DataLoader (`data`) and keep the underlying dataset
# (`dataset`) so its vocabulary size can be read below.
# NOTE(review): assumes data_loader returns a (loader, dataset) pair —
# verify against utils/data_loader.py.
data, dataset = data_loader(root_dir="Data/Images/train",
                            caption_file="Data/caption_train.csv",
                            transform=transform, num_workers=6)
# Hyperparameters — must match the configuration the model was trained with.
embed_size = 256
hidden_size = 256
vocabulary_size = len(dataset.vocabulary)
num_layer = 1

# One representative batch from the loader to drive the trace.
images, captions = next(iter(data))

model = Captioner(embed_size, hidden_size, vocabulary_size, num_layer)
# NOTE(review): no trained checkpoint is loaded here, so the exported file
# contains randomly initialized weights. A state_dict load is almost
# certainly intended before tracing, e.g.:
#     model.load_state_dict(torch.load("checkpoints/caption.pth")["state_dict"])
# TODO: confirm the checkpoint path and load it before export.
model.eval()

# Trace in inference mode so no autograd bookkeeping is recorded in the graph.
with torch.no_grad():
    traced_model = torch.jit.trace(model, (images, captions))

# Ensure the target directory exists; TorchScript's save() does not create it.
Path("deployment/checkpoints").mkdir(parents=True, exist_ok=True)
traced_model.save("deployment/checkpoints/caption.pt")