prac.py
52 lines (46 loc) · 1.77 KB
import os
import sys
import argparse
import time
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from conf import settings
from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR
parser = argparse.ArgumentParser()  # define the command-line hyperparameters
parser.add_argument('-net', type=str, required=True, help='net type')
parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not')
parser.add_argument('-b', type=int, default=128, help='batch size for dataloader')
parser.add_argument('-warm', type=int, default=1, help='warm up training phase')
# -warm: number of warm-up epochs, during which WarmUpLR (below) ramps the learning rate up linearly from ~0 to the initial lr
parser.add_argument('-lr', type=float, default=0.1, help='initial learning rate')
parser.add_argument('-resume', action='store_true', default=False, help='resume training')
# -resume: continue training from the most recently saved checkpoint instead of starting from scratch
args = parser.parse_args()  # parse the command-line arguments
net = get_network(args)
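# Example invocation, assuming get_network recognizes a 'resnet18' entry
# (the exact supported names live in utils.py / the models package, not shown here):
#   python prac.py -net resnet18 -gpu -b 128 -lr 0.1 -warm 1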
# data loading and preprocessing
cifar100_training_loader = get_training_dataloader(
    settings.CIFAR100_TRAIN_MEAN,
    settings.CIFAR100_TRAIN_STD,
    num_workers=4,
    batch_size=args.b,
    shuffle=True
)
cifar100_test_loader = get_test_dataloader(
    settings.CIFAR100_TEST_MEAN,
    settings.CIFAR100_TEST_STD,
    num_workers=4,
    batch_size=args.b,
    shuffle=True
)
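# For reference, a minimal sketch of what get_training_dataloader in utils.py
# presumably does (an assumption inferred from its arguments, not the actual
# implementation): build augmentation + normalization transforms, wrap
# torchvision's CIFAR100 dataset, and return a DataLoader.
#
# def get_training_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):
#     transform_train = transforms.Compose([
#         transforms.RandomCrop(32, padding=4),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize(mean, std),
#     ])
#     cifar100_training = torchvision.datasets.CIFAR100(
#         root='./data', train=True, download=True, transform=transform_train)
#     return torch.utils.data.DataLoader(
#         cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)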
loss_function = nn.CrossEntropyLoss()
# SGD over all trainable parameters of the network, with momentum and L2 weight decay
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# learning rate decay
train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=settings.MILESTONES, gamma=0.2)
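# Worked example of the step decay, assuming settings.MILESTONES is something
# like [60, 120, 160] (defined in conf/settings.py, not shown here): the lr is
# multiplied by gamma=0.2 at each milestone epoch, so with -lr 0.1 it steps
# 0.1 -> 0.02 -> 0.004 -> 0.0008 over the course of training.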
iter_per_epoch = len(cifar100_training_loader)  # 391 = ceil(50,000 CIFAR-100 training images / batch size 128)
warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)
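# A minimal sketch of what the imported WarmUpLR scheduler might look like
# (an assumption, not the repo's actual code): it is stepped once per batch and
# scales each base lr linearly from ~0 up to its full value over `total_iters`
# batches, i.e. over args.warm epochs of iter_per_epoch iterations each.
#
# from torch.optim.lr_scheduler import _LRScheduler
#
# class WarmUpLR(_LRScheduler):
#     def __init__(self, optimizer, total_iters, last_epoch=-1):
#         self.total_iters = total_iters
#         super().__init__(optimizer, last_epoch)
#
#     def get_lr(self):
#         # last_epoch counts step() calls, i.e. batches seen so far
#         return [base_lr * self.last_epoch / (self.total_iters + 1e-8)
#                 for base_lr in self.base_lrs]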