Commit
Showing 22 changed files with 1,089 additions and 0 deletions.
.gitignore
@@ -0,0 +1,112 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# input data, saved log, checkpoints
data/
input/
saved/
datasets/

# editor, os cache directory
.vscode/
.idea/
__MACOSX/
base/__init__.py
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *
base/base_data_loader.py
@@ -0,0 +1,61 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders
    """
    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
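For context, a minimal sketch (not part of this commit) of how BaseDataLoader is meant to be used: a concrete loader builds its dataset, passes the loader options up to the base class, and obtains the held-out validation loader from split_validation(). The MnistDataLoader name, the torchvision dataset, and the 'data/' path are assumptions for illustration only.

# Illustrative sketch only; names and package layout ('base') are assumed.
from torchvision import datasets, transforms
from base import BaseDataLoader


class MnistDataLoader(BaseDataLoader):
    """
    Hypothetical MNIST data loader built on top of BaseDataLoader.
    """
    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
        trsfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.data_dir = data_dir
        self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm)
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)


# usage: both loaders come from the same instance
# train_loader = MnistDataLoader('data/', batch_size=128, validation_split=0.1)
# valid_loader = train_loader.split_validation()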
base/base_model.py
@@ -0,0 +1,25 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models
    """
    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
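For context, a minimal sketch (not part of this commit) of a concrete model built on BaseModel: only forward() has to be implemented, and printing the instance reports the trainable parameter count via the __str__ defined above. The MnistModel name and layer sizes are assumptions for illustration.

# Illustrative sketch only; names and package layout ('base') are assumed.
import torch.nn as nn
import torch.nn.functional as F
from base import BaseModel


class MnistModel(BaseModel):
    """
    Hypothetical fully-connected classifier for 28x28 inputs.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)          # flatten the image batch
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


# print(MnistModel()) would show the module tree plus 'Trainable parameters: ...'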
base/base_trainer.py
@@ -0,0 +1,151 @@
import torch
from abc import abstractmethod
from numpy import inf
from logger import TensorboardWriter


class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))