diff --git a/DKT/.gitignore b/DKT/.gitignore
new file mode 100644
index 0000000..7bcc0d9
--- /dev/null
+++ b/DKT/.gitignore
@@ -0,0 +1,112 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# input data, saved log, checkpoints
+data/
+input/
+saved/
+datasets/
+
+# editor, os cache directory
+.vscode/
+.idea/
+__MACOSX/
diff --git a/DKT/base/__init__.py b/DKT/base/__init__.py
new file mode 100644
index 0000000..19c2224
--- /dev/null
+++ b/DKT/base/__init__.py
@@ -0,0 +1,3 @@
+from .base_data_loader import *
+from .base_model import *
+from .base_trainer import *
diff --git a/DKT/base/base_data_loader.py b/DKT/base/base_data_loader.py
new file mode 100644
index 0000000..91a0d98
--- /dev/null
+++ b/DKT/base/base_data_loader.py
@@ -0,0 +1,61 @@
+import numpy as np
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import default_collate
+from torch.utils.data.sampler import SubsetRandomSampler
+
+
+class BaseDataLoader(DataLoader):
+    """
+    Base class for all data loaders
+    """
+    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
+        self.validation_split = validation_split
+        self.shuffle = shuffle
+
+        self.batch_idx = 0
+        self.n_samples = len(dataset)
+
+        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
+
+        self.init_kwargs = {
+            'dataset': dataset,
+            'batch_size': batch_size,
+            'shuffle': self.shuffle,
+            'collate_fn': collate_fn,
+            'num_workers': num_workers
+        }
+        super().__init__(sampler=self.sampler, **self.init_kwargs)
+
+    def _split_sampler(self, split):
+        if split == 0.0:
+            return None, None
+
+        idx_full = np.arange(self.n_samples)
+
+        np.random.seed(0)
+        np.random.shuffle(idx_full)
+
+        if isinstance(split, int):
+            assert split > 0
+            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
+            len_valid = split
+        else:
+            len_valid = int(self.n_samples * split)
+
+        valid_idx = idx_full[0:len_valid]
+        train_idx = np.delete(idx_full, np.arange(0, len_valid))
+
+        train_sampler = SubsetRandomSampler(train_idx)
+        valid_sampler = SubsetRandomSampler(valid_idx)
+
+        # turn off shuffle option which is mutually exclusive with sampler
+        self.shuffle = False
+        self.n_samples = len(train_idx)
+
+        return train_sampler, valid_sampler
+
+    def split_validation(self):
+        if self.valid_sampler is None:
+            return None
+        else:
+            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
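Note on `BaseDataLoader`: the validation split is carved out with two `SubsetRandomSampler`s over a fixed permutation (seed 0), and `n_samples` is rebound to the training share. A minimal sketch of the intended usage, assuming a toy `TensorDataset` (the dataset and sizes below are illustrative, not part of this diff):

```python
import torch
from torch.utils.data import TensorDataset
from base import BaseDataLoader

# 100 fake samples; validation_split=0.2 reserves 20 of them.
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
loader = BaseDataLoader(dataset, batch_size=16, shuffle=True,
                        validation_split=0.2, num_workers=0)
valid_loader = loader.split_validation()

assert loader.n_samples == 80           # n_samples now counts only the train share
assert len(valid_loader.sampler) == 20  # validation indices drawn without replacement
```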
diff --git a/DKT/base/base_model.py b/DKT/base/base_model.py
new file mode 100644
index 0000000..ad73507
--- /dev/null
+++ b/DKT/base/base_model.py
@@ -0,0 +1,25 @@
+import torch.nn as nn
+import numpy as np
+from abc import abstractmethod
+
+
+class BaseModel(nn.Module):
+    """
+    Base class for all models
+    """
+    @abstractmethod
+    def forward(self, *inputs):
+        """
+        Forward pass logic
+
+        :return: Model output
+        """
+        raise NotImplementedError
+
+    def __str__(self):
+        """
+        Model prints with number of trainable parameters
+        """
+        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
+        params = sum([np.prod(p.size()) for p in model_parameters])
+        return super().__str__() + '\nTrainable parameters: {}'.format(params)
diff --git a/DKT/base/base_trainer.py b/DKT/base/base_trainer.py
new file mode 100644
index 0000000..e43e33b
--- /dev/null
+++ b/DKT/base/base_trainer.py
@@ -0,0 +1,150 @@
+import torch
+from abc import abstractmethod
+from numpy import inf
+from logger import TensorboardWriter
+
+
+class BaseTrainer:
+    """
+    Base class for all trainers
+    """
+    def __init__(self, model, criterion, metric_ftns, optimizer, config):
+        self.config = config
+        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
+
+        self.model = model
+        self.criterion = criterion
+        self.metric_ftns = metric_ftns
+        self.optimizer = optimizer
+
+        cfg_trainer = config['trainer']
+        self.epochs = cfg_trainer['epochs']
+        self.save_period = cfg_trainer['save_period']
+        self.monitor = cfg_trainer.get('monitor', 'off')
+
+        # configuration to monitor model performance and save best
+        if self.monitor == 'off':
+            self.mnt_mode = 'off'
+            self.mnt_best = 0
+        else:
+            self.mnt_mode, self.mnt_metric = self.monitor.split()
+            assert self.mnt_mode in ['min', 'max']
+
+            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
+            self.early_stop = cfg_trainer.get('early_stop', inf)
+            if self.early_stop <= 0:
+                self.early_stop = inf
+
+        self.start_epoch = 1
+
+        self.checkpoint_dir = config.save_dir
+
+        # setup visualization writer instance
+        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
+
+        if config.resume is not None:
+            self._resume_checkpoint(config.resume)
+
+    @abstractmethod
+    def _train_epoch(self, epoch):
+        """
+        Training logic for an epoch
+
+        :param epoch: Current epoch number
+        """
+        raise NotImplementedError
+
+    def train(self):
+        """
+        Full training logic
+        """
+        not_improved_count = 0
+        for epoch in range(self.start_epoch, self.epochs + 1):
+            result = self._train_epoch(epoch)
+
+            # save logged information into log dict
+            log = {'epoch': epoch}
+            log.update(result)
+
+            # print logged information to the screen
+            for key, value in log.items():
+                self.logger.info('    {:15s}: {}'.format(str(key), value))
+
+            # evaluate model performance according to configured metric, save best checkpoint as model_best
+            best = False
+            if self.mnt_mode != 'off':
+                try:
+                    # check whether model performance improved or not, according to specified metric(mnt_metric)
+                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
+                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
+                except KeyError:
+                    self.logger.warning("Warning: Metric '{}' is not found. "
+                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
+                    self.mnt_mode = 'off'
+                    improved = False
+
+                if improved:
+                    self.mnt_best = log[self.mnt_metric]
+                    not_improved_count = 0
+                    best = True
+                else:
+                    not_improved_count += 1
+
+                if not_improved_count > self.early_stop:
+                    self.logger.info("Validation performance didn't improve for {} epochs. "
+                                     "Training stops.".format(self.early_stop))
+                    break
+
+            if epoch % self.save_period == 0:
+                self._save_checkpoint(epoch, save_best=best)
+
+    def _save_checkpoint(self, epoch, save_best=False):
+        """
+        Saving checkpoints
+
+        :param epoch: current epoch number
+        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
+        """
+        arch = type(self.model).__name__
+        state = {
+            'arch': arch,
+            'epoch': epoch,
+            'state_dict': self.model.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'monitor_best': self.mnt_best,
+            'config': self.config
+        }
+        filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
+        torch.save(state, filename)
+        self.logger.info("Saving checkpoint: {} ...".format(filename))
+        if save_best:
+            best_path = str(self.checkpoint_dir / 'model_best.pth')
+            torch.save(state, best_path)
+            self.logger.info("Saving current best: model_best.pth ...")
+
+    def _resume_checkpoint(self, resume_path):
+        """
+        Resume from saved checkpoints
+
+        :param resume_path: Checkpoint path to be resumed
+        """
+        resume_path = str(resume_path)
+        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
+        checkpoint = torch.load(resume_path)
+        self.start_epoch = checkpoint['epoch'] + 1
+        self.mnt_best = checkpoint['monitor_best']
+
+        # load architecture params from checkpoint.
+        if checkpoint['config']['arch'] != self.config['arch']:
+            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
+                                "checkpoint. This may yield an exception while state_dict is being loaded.")
+        self.model.load_state_dict(checkpoint['state_dict'])
+
+        # load optimizer state from checkpoint only when optimizer type is not changed.
+        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
+            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
+                                "Optimizer parameters not being resumed.")
+        else:
+            self.optimizer.load_state_dict(checkpoint['optimizer'])
+
+        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
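Note on `BaseTrainer`: the `monitor` entry of the trainer config is a two-token string, mode plus metric key. A sketch of how it is interpreted (values mirror the `config.json` below):

```python
# "monitor": "min val_loss" splits into a mode and a metric key:
mnt_mode, mnt_metric = "min val_loss".split()  # -> ('min', 'val_loss')

# 'min' means a lower val_loss counts as an improvement; with
# "early_stop": 10, training stops after 10 epochs without improvement,
# and the best epoch is additionally saved as model_best.pth.
```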
diff --git a/DKT/config.json b/DKT/config.json
new file mode 100644
index 0000000..0339e6a
--- /dev/null
+++ b/DKT/config.json
@@ -0,0 +1,50 @@
+{
+    "name": "Mnist_LeNet",
+    "n_gpu": 1,
+
+    "arch": {
+        "type": "MnistModel",
+        "args": {}
+    },
+    "data_loader": {
+        "type": "MnistDataLoader",
+        "args":{
+            "data_dir": "data/",
+            "batch_size": 128,
+            "shuffle": true,
+            "validation_split": 0.1,
+            "num_workers": 2
+        }
+    },
+    "optimizer": {
+        "type": "Adam",
+        "args":{
+            "lr": 0.001,
+            "weight_decay": 0,
+            "amsgrad": true
+        }
+    },
+    "loss": "nll_loss",
+    "metrics": [
+        "accuracy", "top_k_acc"
+    ],
+    "lr_scheduler": {
+        "type": "StepLR",
+        "args": {
+            "step_size": 50,
+            "gamma": 0.1
+        }
+    },
+    "trainer": {
+        "epochs": 100,
+
+        "save_dir": "saved/",
+        "save_period": 1,
+        "verbosity": 2,
+
+        "monitor": "min val_loss",
+        "early_stop": 10,
+
+        "tensorboard": true
+    }
+}
diff --git a/DKT/data_loader/data_loaders.py b/DKT/data_loader/data_loaders.py
new file mode 100644
index 0000000..f44b129
--- /dev/null
+++ b/DKT/data_loader/data_loaders.py
@@ -0,0 +1,16 @@
+from torchvision import datasets, transforms
+from base import BaseDataLoader
+
+
+class MnistDataLoader(BaseDataLoader):
+    """
+    MNIST data loading demo using BaseDataLoader
+    """
+    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
+        trsfm = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.1307,), (0.3081,))
+        ])
+        self.data_dir = data_dir
+        self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm)
+        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
diff --git a/DKT/logger/__init__.py b/DKT/logger/__init__.py
new file mode 100644
index 0000000..5f3763b
--- /dev/null
+++ b/DKT/logger/__init__.py
@@ -0,0 +1,2 @@
+from .logger import *
+from .visualization import *
\ No newline at end of file
diff --git a/DKT/logger/logger.py b/DKT/logger/logger.py
new file mode 100644
index 0000000..4599fb0
--- /dev/null
+++ b/DKT/logger/logger.py
@@ -0,0 +1,22 @@
+import logging
+import logging.config
+from pathlib import Path
+from utils import read_json
+
+
+def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
+    """
+    Setup logging configuration
+    """
+    log_config = Path(log_config)
+    if log_config.is_file():
+        config = read_json(log_config)
+        # modify logging paths based on run config
+        for _, handler in config['handlers'].items():
+            if 'filename' in handler:
+                handler['filename'] = str(save_dir / handler['filename'])
+
+        logging.config.dictConfig(config)
+    else:
+        print("Warning: logging configuration file is not found in {}.".format(log_config))
+        logging.basicConfig(level=default_level)
diff --git a/DKT/logger/logger_config.json b/DKT/logger/logger_config.json
new file mode 100644
index 0000000..c3e7e02
--- /dev/null
+++ b/DKT/logger/logger_config.json
@@ -0,0 +1,32 @@
+
+{
+    "version": 1,
+    "disable_existing_loggers": false,
+    "formatters": {
+        "simple": {"format": "%(message)s"},
+        "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
+    },
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "level": "DEBUG",
+            "formatter": "simple",
+            "stream": "ext://sys.stdout"
+        },
+        "info_file_handler": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "level": "INFO",
+            "formatter": "datetime",
+            "filename": "info.log",
+            "maxBytes": 10485760,
+            "backupCount": 20, "encoding": "utf8"
+        }
+    },
+    "root": {
+        "level": "INFO",
+        "handlers": [
+            "console",
+            "info_file_handler"
+        ]
+    }
+}
\ No newline at end of file
"utf8" + } + }, + "root": { + "level": "INFO", + "handlers": [ + "console", + "info_file_handler" + ] + } +} \ No newline at end of file diff --git a/DKT/logger/visualization.py b/DKT/logger/visualization.py new file mode 100644 index 0000000..34ef64f --- /dev/null +++ b/DKT/logger/visualization.py @@ -0,0 +1,73 @@ +import importlib +from datetime import datetime + + +class TensorboardWriter(): + def __init__(self, log_dir, logger, enabled): + self.writer = None + self.selected_module = "" + + if enabled: + log_dir = str(log_dir) + + # Retrieve vizualization writer. + succeeded = False + for module in ["torch.utils.tensorboard", "tensorboardX"]: + try: + self.writer = importlib.import_module(module).SummaryWriter(log_dir) + succeeded = True + break + except ImportError: + succeeded = False + self.selected_module = module + + if not succeeded: + message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ + "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ + "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." + logger.warning(message) + + self.step = 0 + self.mode = '' + + self.tb_writer_ftns = { + 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', + 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' + } + self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} + self.timer = datetime.now() + + def set_step(self, step, mode='train'): + self.mode = mode + self.step = step + if step == 0: + self.timer = datetime.now() + else: + duration = datetime.now() - self.timer + self.add_scalar('steps_per_sec', 1 / duration.total_seconds()) + self.timer = datetime.now() + + def __getattr__(self, name): + """ + If visualization is configured to use: + return add_data() methods of tensorboard with additional information (step, tag) added. + Otherwise: + return a blank function handle that does nothing + """ + if name in self.tb_writer_ftns: + add_data = getattr(self.writer, name, None) + + def wrapper(tag, data, *args, **kwargs): + if add_data is not None: + # add mode(train/valid) tag + if name not in self.tag_mode_exceptions: + tag = '{}/{}'.format(tag, self.mode) + add_data(tag, data, self.step, *args, **kwargs) + return wrapper + else: + # default action for returning methods defined in this class, set_step() for instance. 
diff --git a/DKT/model/loss.py b/DKT/model/loss.py
new file mode 100644
index 0000000..1cd24cb
--- /dev/null
+++ b/DKT/model/loss.py
@@ -0,0 +1,5 @@
+import torch.nn.functional as F
+
+
+def nll_loss(output, target):
+    return F.nll_loss(output, target)
diff --git a/DKT/model/metric.py b/DKT/model/metric.py
new file mode 100644
index 0000000..df08e03
--- /dev/null
+++ b/DKT/model/metric.py
@@ -0,0 +1,20 @@
+import torch
+
+
+def accuracy(output, target):
+    with torch.no_grad():
+        pred = torch.argmax(output, dim=1)
+        assert pred.shape[0] == len(target)
+        correct = 0
+        correct += torch.sum(pred == target).item()
+    return correct / len(target)
+
+
+def top_k_acc(output, target, k=3):
+    with torch.no_grad():
+        pred = torch.topk(output, k, dim=1)[1]
+        assert pred.shape[0] == len(target)
+        correct = 0
+        for i in range(k):
+            correct += torch.sum(pred[:, i] == target).item()
+    return correct / len(target)
diff --git a/DKT/model/model.py b/DKT/model/model.py
new file mode 100644
index 0000000..7fa23c5
--- /dev/null
+++ b/DKT/model/model.py
@@ -0,0 +1,22 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from base import BaseModel
+
+
+class MnistModel(BaseModel):
+    def __init__(self, num_classes=10):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.conv2_drop = nn.Dropout2d()
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, num_classes)
+
+    def forward(self, x):
+        x = F.relu(F.max_pool2d(self.conv1(x), 2))
+        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+        x = x.view(-1, 320)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
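A quick sanity check of the two metrics on hand-made tensors (values are illustrative; assumes the snippet runs from the repo root so `model.metric` is importable):

```python
import torch
from model.metric import accuracy, top_k_acc

output = torch.tensor([[0.1, 0.7, 0.2],    # predicted class 1
                       [0.8, 0.1, 0.1]])   # predicted class 0
target = torch.tensor([1, 2])

print(accuracy(output, target))   # 0.5 -> only the first prediction matches
print(top_k_acc(output, target))  # 1.0 -> with 3 classes, every target is in the top-3
```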
diff --git a/DKT/parse_config.py b/DKT/parse_config.py
new file mode 100644
index 0000000..309f153
--- /dev/null
+++ b/DKT/parse_config.py
@@ -0,0 +1,157 @@
+import os
+import logging
+from pathlib import Path
+from functools import reduce, partial
+from operator import getitem
+from datetime import datetime
+from logger import setup_logging
+from utils import read_json, write_json
+
+
+class ConfigParser:
+    def __init__(self, config, resume=None, modification=None, run_id=None):
+        """
+        Class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
+        and logging module.
+        :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example.
+        :param resume: String, path to the checkpoint being loaded.
+        :param modification: Dict keychain:value, specifying position values to be replaced from config dict.
+        :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is used as default.
+        """
+        # load config file and apply modification
+        self._config = _update_config(config, modification)
+        self.resume = resume
+
+        # set save_dir where trained model and log will be saved.
+        save_dir = Path(self.config['trainer']['save_dir'])
+
+        exper_name = self.config['name']
+        if run_id is None: # use timestamp as default run-id
+            run_id = datetime.now().strftime(r'%m%d_%H%M%S')
+        self._save_dir = save_dir / 'models' / exper_name / run_id
+        self._log_dir = save_dir / 'log' / exper_name / run_id
+
+        # make directory for saving checkpoints and log.
+        exist_ok = run_id == ''
+        self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
+        self.log_dir.mkdir(parents=True, exist_ok=exist_ok)
+
+        # save updated config file to the checkpoint dir
+        write_json(self.config, self.save_dir / 'config.json')
+
+        # configure logging module
+        setup_logging(self.log_dir)
+        self.log_levels = {
+            0: logging.WARNING,
+            1: logging.INFO,
+            2: logging.DEBUG
+        }
+
+    @classmethod
+    def from_args(cls, args, options=''):
+        """
+        Initialize this class from some cli arguments. Used in train, test.
+        """
+        for opt in options:
+            args.add_argument(*opt.flags, default=None, type=opt.type)
+        if not isinstance(args, tuple):
+            args = args.parse_args()
+
+        if args.device is not None:
+            os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+        if args.resume is not None:
+            resume = Path(args.resume)
+            cfg_fname = resume.parent / 'config.json'
+        else:
+            msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example."
+            assert args.config is not None, msg_no_cfg
+            resume = None
+            cfg_fname = Path(args.config)
+
+        config = read_json(cfg_fname)
+        if args.config and resume:
+            # update new config for fine-tuning
+            config.update(read_json(args.config))
+
+        # parse custom cli options into dictionary
+        modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
+        return cls(config, resume, modification)
+
+    def init_obj(self, name, module, *args, **kwargs):
+        """
+        Finds a function handle with the name given as 'type' in config, and returns the
+        instance initialized with corresponding arguments given.
+
+        `object = config.init_obj('name', module, a, b=1)`
+        is equivalent to
+        `object = module.name(a, b=1)`
+        """
+        module_name = self[name]['type']
+        module_args = dict(self[name]['args'])
+        assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
+        module_args.update(kwargs)
+        return getattr(module, module_name)(*args, **module_args)
+
+    def init_ftn(self, name, module, *args, **kwargs):
+        """
+        Finds a function handle with the name given as 'type' in config, and returns the
+        function with given arguments fixed with functools.partial.
+
+        `function = config.init_ftn('name', module, a, b=1)`
+        is equivalent to
+        `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
+        """
+        module_name = self[name]['type']
+        module_args = dict(self[name]['args'])
+        assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
+        module_args.update(kwargs)
+        return partial(getattr(module, module_name), *args, **module_args)
+
+    def __getitem__(self, name):
+        """Access items like ordinary dict."""
+        return self.config[name]
+
+    def get_logger(self, name, verbosity=2):
+        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
+        assert verbosity in self.log_levels, msg_verbosity
+        logger = logging.getLogger(name)
+        logger.setLevel(self.log_levels[verbosity])
+        return logger
+
+    # setting read-only attributes
+    @property
+    def config(self):
+        return self._config
+
+    @property
+    def save_dir(self):
+        return self._save_dir
+
+    @property
+    def log_dir(self):
+        return self._log_dir
+
+# helper functions to update config dict with custom cli options
+def _update_config(config, modification):
+    if modification is None:
+        return config
+
+    for k, v in modification.items():
+        if v is not None:
+            _set_by_path(config, k, v)
+    return config
+
+def _get_opt_name(flags):
+    for flg in flags:
+        if flg.startswith('--'):
+            return flg.replace('--', '')
+    return flags[0].replace('--', '')
+
+def _set_by_path(tree, keys, value):
+    """Set a value in a nested object in tree by sequence of keys."""
+    keys = keys.split(';')
+    _get_by_path(tree, keys[:-1])[keys[-1]] = value
+
+def _get_by_path(tree, keys):
+    """Access a nested object in tree by sequence of keys."""
+    return reduce(getitem, keys, tree)
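Note on `ConfigParser.init_obj`: it resolves the `type`/`args` pair of a config block against a module, which is exactly how `train.py` below builds its optimizer. A sketch, assuming `config` and `model` are already constructed:

```python
import torch

# "optimizer": {"type": "Adam", "args": {"lr": 0.001, ...}} in config.json
# resolves to torch.optim.Adam(trainable_params, lr=0.001, ...):
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
```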
diff --git a/DKT/requirements.txt b/DKT/requirements.txt
new file mode 100644
index 0000000..baa85b8
--- /dev/null
+++ b/DKT/requirements.txt
@@ -0,0 +1,5 @@
+torch>=1.1
+torchvision
+numpy
+tqdm
+tensorboard>=1.14
diff --git a/DKT/test.py b/DKT/test.py
new file mode 100644
index 0000000..fc084fa
--- /dev/null
+++ b/DKT/test.py
@@ -0,0 +1,81 @@
+import argparse
+import torch
+from tqdm import tqdm
+import data_loader.data_loaders as module_data
+import model.loss as module_loss
+import model.metric as module_metric
+import model.model as module_arch
+from parse_config import ConfigParser
+
+
+def main(config):
+    logger = config.get_logger('test')
+
+    # setup data_loader instances
+    data_loader = getattr(module_data, config['data_loader']['type'])(
+        config['data_loader']['args']['data_dir'],
+        batch_size=512,
+        shuffle=False,
+        validation_split=0.0,
+        training=False,
+        num_workers=2
+    )
+
+    # build model architecture
+    model = config.init_obj('arch', module_arch)
+    logger.info(model)
+
+    # get function handles of loss and metrics
+    loss_fn = getattr(module_loss, config['loss'])
+    metric_fns = [getattr(module_metric, met) for met in config['metrics']]
+
+    logger.info('Loading checkpoint: {} ...'.format(config.resume))
+    checkpoint = torch.load(config.resume)
+    state_dict = checkpoint['state_dict']
+    if config['n_gpu'] > 1:
+        model = torch.nn.DataParallel(model)
+    model.load_state_dict(state_dict)
+
+    # prepare model for testing
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = model.to(device)
+    model.eval()
+
+    total_loss = 0.0
+    total_metrics = torch.zeros(len(metric_fns))
+
+    with torch.no_grad():
+        for data, target in tqdm(data_loader):
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+
+            #
+            # save sample images, or do something with output here
+            #
+
+            # computing loss, metrics on test set
+            loss = loss_fn(output, target)
+            batch_size = data.shape[0]
+            total_loss += loss.item() * batch_size
+            for i, metric in enumerate(metric_fns):
+                total_metrics[i] += metric(output, target) * batch_size
+
+    n_samples = len(data_loader.sampler)
+    log = {'loss': total_loss / n_samples}
+    log.update({
+        met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
+    })
+    logger.info(log)
+
+
+if __name__ == '__main__':
+    args = argparse.ArgumentParser(description='PyTorch Template')
+    args.add_argument('-c', '--config', default=None, type=str,
+                      help='config file path (default: None)')
+    args.add_argument('-r', '--resume', default=None, type=str,
+                      help='path to latest checkpoint (default: None)')
+    args.add_argument('-d', '--device', default=None, type=str,
+                      help='indices of GPUs to enable (default: all)')
+
+    config = ConfigParser.from_args(args)
+    main(config)
diff --git a/DKT/train.py b/DKT/train.py
new file mode 100644
index 0000000..a43f6c4
--- /dev/null
+++ b/DKT/train.py
@@ -0,0 +1,73 @@
+import argparse
+import collections
+import torch
+import numpy as np
+import data_loader.data_loaders as module_data
+import model.loss as module_loss
+import model.metric as module_metric
+import model.model as module_arch
+from parse_config import ConfigParser
+from trainer import Trainer
+from utils import prepare_device
+
+
+# fix random seeds for reproducibility
+SEED = 123
+torch.manual_seed(SEED)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+np.random.seed(SEED)
+
+def main(config):
+    logger = config.get_logger('train')
+
+    # setup data_loader instances
+    data_loader = config.init_obj('data_loader', module_data)
+    valid_data_loader = data_loader.split_validation()
+
+    # build model architecture, then print to console
+    model = config.init_obj('arch', module_arch)
+    logger.info(model)
+
+    # prepare for (multi-device) GPU training
+    device, device_ids = prepare_device(config['n_gpu'])
+    model = model.to(device)
+    if len(device_ids) > 1:
+        model = torch.nn.DataParallel(model, device_ids=device_ids)
+
+    # get function handles of loss and metrics
+    criterion = getattr(module_loss, config['loss'])
+    metrics = [getattr(module_metric, met) for met in config['metrics']]
+
+    # build optimizer, learning rate scheduler. delete every line containing lr_scheduler to disable the scheduler
+    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
+    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
+    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
+
+    trainer = Trainer(model, criterion, metrics, optimizer,
+                      config=config,
+                      device=device,
+                      data_loader=data_loader,
+                      valid_data_loader=valid_data_loader,
+                      lr_scheduler=lr_scheduler)
+
+    trainer.train()
+
+
+if __name__ == '__main__':
+    args = argparse.ArgumentParser(description='PyTorch Template')
+    args.add_argument('-c', '--config', default=None, type=str,
+                      help='config file path (default: None)')
+    args.add_argument('-r', '--resume', default=None, type=str,
+                      help='path to latest checkpoint (default: None)')
+    args.add_argument('-d', '--device', default=None, type=str,
+                      help='indices of GPUs to enable (default: all)')
+
+    # custom cli options to modify configuration from default values given in json file.
+    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
+    options = [
+        CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
+        CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
+    ]
+    config = ConfigParser.from_args(args, options)
+    main(config)
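Note on the custom CLI options in `train.py`: each `CustomArgs` target is a `;`-separated key chain that `_set_by_path` follows into the config dict, so a new flag only needs one entry. A sketch of extending the list (the `--ep` flag is hypothetical, not part of this diff):

```python
options = [
    CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
    CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size'),
    # hypothetical extra flag overriding the epoch count:
    CustomArgs(['--ep', '--epochs'], type=int, target='trainer;epochs'),
]
```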
diff --git a/DKT/trainer/__init__.py b/DKT/trainer/__init__.py
new file mode 100644
index 0000000..5c0a8a4
--- /dev/null
+++ b/DKT/trainer/__init__.py
@@ -0,0 +1 @@
+from .trainer import *
diff --git a/DKT/trainer/trainer.py b/DKT/trainer/trainer.py
new file mode 100644
index 0000000..ae71d4b
--- /dev/null
+++ b/DKT/trainer/trainer.py
@@ -0,0 +1,110 @@
+import numpy as np
+import torch
+from torchvision.utils import make_grid
+from base import BaseTrainer
+from utils import inf_loop, MetricTracker
+
+
+class Trainer(BaseTrainer):
+    """
+    Trainer class
+    """
+    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
+                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
+        super().__init__(model, criterion, metric_ftns, optimizer, config)
+        self.config = config
+        self.device = device
+        self.data_loader = data_loader
+        if len_epoch is None:
+            # epoch-based training
+            self.len_epoch = len(self.data_loader)
+        else:
+            # iteration-based training
+            self.data_loader = inf_loop(data_loader)
+            self.len_epoch = len_epoch
+        self.valid_data_loader = valid_data_loader
+        self.do_validation = self.valid_data_loader is not None
+        self.lr_scheduler = lr_scheduler
+        self.log_step = int(np.sqrt(data_loader.batch_size))
+
+        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
+        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
+
+    def _train_epoch(self, epoch):
+        """
+        Training logic for an epoch
+
+        :param epoch: Integer, current training epoch.
+        :return: A log that contains average loss and metric in this epoch.
+        """
+        self.model.train()
+        self.train_metrics.reset()
+        for batch_idx, (data, target) in enumerate(self.data_loader):
+            data, target = data.to(self.device), target.to(self.device)
+
+            self.optimizer.zero_grad()
+            output = self.model(data)
+            loss = self.criterion(output, target)
+            loss.backward()
+            self.optimizer.step()
+
+            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
+            self.train_metrics.update('loss', loss.item())
+            for met in self.metric_ftns:
+                self.train_metrics.update(met.__name__, met(output, target))
+
+            if batch_idx % self.log_step == 0:
+                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
+                    epoch,
+                    self._progress(batch_idx),
+                    loss.item()))
+                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+            if batch_idx == self.len_epoch:
+                break
+        log = self.train_metrics.result()
+
+        if self.do_validation:
+            val_log = self._valid_epoch(epoch)
+            log.update(**{'val_'+k : v for k, v in val_log.items()})
+
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
+        return log
+
+    def _valid_epoch(self, epoch):
+        """
+        Validate after training an epoch
+
+        :param epoch: Integer, current training epoch.
+        :return: A log that contains information about validation
+        """
+        self.model.eval()
+        self.valid_metrics.reset()
+        with torch.no_grad():
+            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
+                data, target = data.to(self.device), target.to(self.device)
+
+                output = self.model(data)
+                loss = self.criterion(output, target)
+
+                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
+                self.valid_metrics.update('loss', loss.item())
+                for met in self.metric_ftns:
+                    self.valid_metrics.update(met.__name__, met(output, target))
+                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+        # add histogram of model parameters to the tensorboard
+        for name, p in self.model.named_parameters():
+            self.writer.add_histogram(name, p, bins='auto')
+        return self.valid_metrics.result()
+
+    def _progress(self, batch_idx):
+        base = '[{}/{} ({:.0f}%)]'
+        if hasattr(self.data_loader, 'n_samples'):
+            current = batch_idx * self.data_loader.batch_size
+            total = self.data_loader.n_samples
+        else:
+            current = batch_idx
+            total = self.len_epoch
+        return base.format(current, total, 100.0 * current / total)
diff --git a/DKT/utils/__init__.py b/DKT/utils/__init__.py
new file mode 100644
index 0000000..46d3a15
--- /dev/null
+++ b/DKT/utils/__init__.py
@@ -0,0 +1 @@
+from .util import *
diff --git a/DKT/utils/util.py b/DKT/utils/util.py
new file mode 100644
index 0000000..d8894bf
--- /dev/null
+++ b/DKT/utils/util.py
@@ -0,0 +1,67 @@
+import json
+import torch
+import pandas as pd
+from pathlib import Path
+from itertools import repeat
+from collections import OrderedDict
+
+
+def ensure_dir(dirname):
+    dirname = Path(dirname)
+    if not dirname.is_dir():
+        dirname.mkdir(parents=True, exist_ok=False)
+
+def read_json(fname):
+    fname = Path(fname)
+    with fname.open('rt') as handle:
+        return json.load(handle, object_hook=OrderedDict)
+
+def write_json(content, fname):
+    fname = Path(fname)
+    with fname.open('wt') as handle:
+        json.dump(content, handle, indent=4, sort_keys=False)
+
+def inf_loop(data_loader):
+    ''' wrapper function for endless data loader. '''
+    for loader in repeat(data_loader):
+        yield from loader
+
+def prepare_device(n_gpu_use):
+    """
+    Setup GPU device if available. Get gpu device indices which are used for DataParallel.
+    """
+    n_gpu = torch.cuda.device_count()
+    if n_gpu_use > 0 and n_gpu == 0:
+        print("Warning: There's no GPU available on this machine, "
+              "training will be performed on CPU.")
+        n_gpu_use = 0
+    if n_gpu_use > n_gpu:
+        print(f"Warning: The number of GPUs configured to use is {n_gpu_use}, but only {n_gpu} are "
+              "available on this machine.")
+        n_gpu_use = n_gpu
+    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
+    list_ids = list(range(n_gpu_use))
+    return device, list_ids
+
+class MetricTracker:
+    def __init__(self, *keys, writer=None):
+        self.writer = writer
+        self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
+        self.reset()
+
+    def reset(self):
+        for col in self._data.columns:
+            self._data[col].values[:] = 0
+
+    def update(self, key, value, n=1):
+        if self.writer is not None:
+            self.writer.add_scalar(key, value)
+        self._data.total[key] += value * n
+        self._data.counts[key] += n
+        self._data.average[key] = self._data.total[key] / self._data.counts[key]
+
+    def avg(self, key):
+        return self._data.average[key]
+
+    def result(self):
+        return dict(self._data.average)
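Note on `MetricTracker`: it keeps a running weighted average per key, which is what `Trainer` aggregates each epoch. A minimal worked example (values are illustrative):

```python
from utils import MetricTracker

tracker = MetricTracker('loss')
tracker.update('loss', 0.9, n=32)  # batch of 32 with mean loss 0.9
tracker.update('loss', 0.7, n=32)  # second batch
print(tracker.avg('loss'))         # (0.9*32 + 0.7*32) / 64 = 0.8
print(tracker.result())            # {'loss': 0.8}
```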