Merge pull request #44 from boostcampaitech5/feat-#32/UltraGCN_CV
Feat #32/ultra gcn cv
asdftyui authored May 31, 2023
2 parents dc24ed0 + 89a53fe commit 388736d
Showing 7 changed files with 147 additions and 54 deletions.
21 changes: 18 additions & 3 deletions DKT/base/base_data_loader.py
@@ -2,15 +2,17 @@
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import KFold


class BaseDataLoader(DataLoader):
"""
Base class for all data loaders
"""
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, fold, collate_fn=default_collate):
self.validation_split = validation_split
self.shuffle = shuffle
self.fold = fold

self.batch_idx = 0
self.n_samples = len(dataset)
@@ -42,8 +44,21 @@ def _split_sampler(self, split):
else:
len_valid = int(self.n_samples * split)

valid_idx = idx_full[0:len_valid]
train_idx = np.delete(idx_full, np.arange(0, len_valid))
# carve out this fold's validation window and drop the same window from train
if self.fold == 0:
    valid_idx = idx_full[0:len_valid]
    train_idx = np.delete(idx_full, np.arange(0, len_valid))
elif self.fold == 1:
    valid_idx = idx_full[len_valid:2*len_valid]
    train_idx = np.delete(idx_full, np.arange(len_valid, 2*len_valid))
elif self.fold == 2:
    valid_idx = idx_full[2*len_valid:3*len_valid]
    train_idx = np.delete(idx_full, np.arange(2*len_valid, 3*len_valid))
elif self.fold == 3:
    valid_idx = idx_full[3*len_valid:4*len_valid]
    train_idx = np.delete(idx_full, np.arange(3*len_valid, 4*len_valid))
else:
    valid_idx = idx_full[4*len_valid:]
    train_idx = np.delete(idx_full, np.arange(4*len_valid, self.n_samples))

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
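Note that the hunk above imports sklearn's KFold but never uses it; the five-way elif chain does the splitting by hand. As a minimal sketch (not part of the commit, and the exact index assignment would differ from the manual slicing), KFold could produce the five train/validation partitions directly:

    import numpy as np
    from sklearn.model_selection import KFold

    # Sketch only: KFold yields five (train_idx, valid_idx) pairs over the
    # dataset indices, replacing the manual elif chain above.
    idx_full = np.arange(100)  # stand-in for the shuffled dataset indices
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    train_idx, valid_idx = list(kf.split(idx_full))[0]  # split for fold 0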
7 changes: 4 additions & 3 deletions DKT/base/base_trainer.py
@@ -8,14 +8,15 @@ class BaseTrainer:
"""
Base class for all trainers
"""
def __init__(self, model, criterion, metric_ftns, optimizer, config):
def __init__(self, model, criterion, metric_ftns, optimizer, config, fold):
self.config = config
self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

self.model = model
self.criterion = criterion
self.metric_ftns = metric_ftns
self.optimizer = optimizer
self.fold = fold

cfg_trainer = config['trainer']
self.epochs = cfg_trainer['epochs']
@@ -119,9 +120,9 @@ def _save_checkpoint(self, epoch, save_best=False):
torch.save(state, filename)
self.logger.info("Saving checkpoint: {} ...".format(filename))
if save_best:
best_path = str(self.checkpoint_dir / 'model_best.pth')
best_path = str(self.checkpoint_dir / 'model_best{}.pth'.format(self.fold))
torch.save(state, best_path)
self.logger.info("Saving current best: model_best.pth ...")
self.logger.info("Saving current best: model_best{}.pth ...".format(self.fold))

def _resume_checkpoint(self, resume_path):
"""
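With the per-fold suffix, each fold writes its own best checkpoint under the run directory; a small sketch of the resulting filenames, using the run directory named in the updated config:

    from pathlib import Path

    # Sketch only: one best checkpoint per fold; test_GCN.py rebuilds these
    # names by appending "{fold}.pth" to the configured model_dir.
    checkpoint_dir = Path('./saved/models/UltraGCN/0531_020223')
    for fold in range(5):
        print(checkpoint_dir / 'model_best{}.pth'.format(fold))
    # -> model_best0.pth, model_best1.pth, ..., model_best4.pth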
8 changes: 5 additions & 3 deletions DKT/config/config_ultraGCN.json
@@ -1,6 +1,7 @@
{
"name": "UltraGCN",
"n_gpu": 1,
"fold":false,

"arch": {
"type": "UltraGCN",
@@ -19,7 +20,8 @@
"batch_size": 512,
"shuffle": true,
"num_workers": 2,
"validation_split": 0.2
"validation_split": 0.2,
"random_seed": 42
}
},
"optimizer": {
@@ -56,8 +58,8 @@
},
"test": {
"data_dir": "~/input/data/test_data_modify.csv",
"model_dir": "./saved/models/UltraGCN/0518_033541/model_best.pth",
"submission_dir": "~/level2_dkt-recsys-09/DKT/submission/UltraGCN_submission.csv",
"model_dir": "./saved/models/UltraGCN/0531_020223/model_best",
"submission_dir": "./submission/UltraGCN_submission.csv",
"sample_submission_dir": "~/input/data/sample_submission.csv",
"batch_size": 512
}
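The new "fold" key gates the 5-fold code paths in train_GCN.py and test_GCN.py. A minimal sketch of flipping it on programmatically (assuming the repository-relative config path):

    import json

    # Sketch only: setting "fold" to true enables 5-fold training and the
    # fold-ensemble inference path end to end.
    path = 'DKT/config/config_ultraGCN.json'
    with open(path) as f:
        cfg = json.load(f)
    cfg['fold'] = True
    with open(path, 'w') as f:
        json.dump(cfg, f, indent=4)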
8 changes: 4 additions & 4 deletions DKT/data_loader/data_loaders_GCN.py
@@ -43,14 +43,14 @@ def __len__(self):


class UltraGCNDataLoader(BaseDataLoader):
def __init__(self, data_dir, batch_size, shuffle=False, num_workers=1, validation_split=0.0):
def __init__(self, data_dir, batch_size, shuffle=False, num_workers=1, validation_split=0.0, random_seed=42, fold=0):

self.data_dir = data_dir
self.random_seed = random_seed
self.dataset = UltraGCNDataset(data_dir)

super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers, fold)


class HMDataset(Dataset):
def __init__(self, data, max_seq_len):
self.data = data
@@ -120,4 +120,4 @@ def collate(self, batch):
for i, _ in enumerate(col_list):
col_list[i] = torch.stack(col_list[i])

return tuple(col_list)
35 changes: 28 additions & 7 deletions DKT/test_GCN.py
@@ -3,6 +3,8 @@
import model.model_GCN as module_arch
from parse_config import ConfigParser
import pandas as pd
import numpy as np
import os
from torch.utils.data import DataLoader, TensorDataset


@@ -11,15 +13,34 @@ def main(config):
test_dataset = TensorDataset(torch.LongTensor(data.values))
test_dataloader = DataLoader(test_dataset, batch_size=config['test']['batch_size'], shuffle=False)

# build model architecture
model = config.init_obj('arch', module_arch)
model.load_state_dict(torch.load(config['test']['model_dir'])['state_dict'])
model.eval()
if config['fold']:
predicts_list = list()
for fold in range(5):
# build model architecture
model = config.init_obj('arch', module_arch)
model_path = config['test']['model_dir']+"{}.pth".format(fold)
model.load_state_dict(torch.load(model_path)['state_dict'])
model.eval()

predict = list()
for idx, data in enumerate(test_dataloader):
predict.extend(model(data[0]).tolist())
predicts_list.append(predict)
predicts = np.mean(predicts_list, axis=0)
else:
# build model architecture
model = config.init_obj('arch', module_arch)
model_path = config['test']['model_dir']+"0.pth"
model.load_state_dict(torch.load(model_path)['state_dict'])
model.eval()

predicts = list()
for idx, data in enumerate(test_dataloader):
predicts.extend(model(data[0]).tolist())
predicts = list()
for idx, data in enumerate(test_dataloader):
predicts.extend(model(data[0]).tolist())

dir_path = "./submission"
if not os.path.exists(dir_path):
os.makedirs(dir_path)
write_path = config['test']['submission_dir']
submission = pd.read_csv(config['test']['sample_submission_dir'])
submission['prediction'] = predicts
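The fold ensemble in the new branch is a plain element-wise mean over the five models' outputs; a self-contained sketch of that step with illustrative numbers:

    import numpy as np

    # Sketch only: predicts_list holds one prediction list per fold model,
    # shape (5, n_test); the submission uses their element-wise mean.
    predicts_list = [
        [0.20, 0.80, 0.55],  # fold 0 model (made-up values)
        [0.30, 0.70, 0.45],  # fold 1
        [0.25, 0.75, 0.50],  # fold 2
        [0.15, 0.85, 0.60],  # fold 3
        [0.10, 0.90, 0.40],  # fold 4
    ]
    predicts = np.mean(predicts_list, axis=0)  # -> array([0.2, 0.8, 0.5])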
117 changes: 85 additions & 32 deletions DKT/train_GCN.py
@@ -2,7 +2,7 @@
import collections
import torch
import numpy as np
import data_loader.data_loaders_GCN as module_data
from data_loader.data_loaders_GCN import UltraGCNDataLoader
import model.loss_GCN as module_loss
import model.metric_GCN as module_metric
import model.model_GCN as module_arch
@@ -12,6 +12,7 @@
import wandb
import os

import data_loader.data_loaders_GCN as module_data
os.environ['WANDB_MODE'] = 'offline'  # fixed env var name; wandb ignores 'wandb mode'

# fix random seeds for reproducibility
@@ -25,38 +26,90 @@ def main(config):
wandb.login()
logger = config.get_logger('train')

# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

wandb.init(project=config['name'], config=config, entity="ffm")
# build model architecture, then print to console
model = config.init_obj('arch', module_arch)
logger.info(model)

# prepare for (multi-device) GPU training
device, device_ids = prepare_device(config['n_gpu'])
model = model.to(device)
if len(device_ids) > 1:
model = torch.nn.DataParallel(model, device_ids=device_ids)

# get function handles of loss and metrics
criterion = getattr(module_loss, config['loss'])
metrics = [getattr(module_metric, met) for met in config['metrics']]

# build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

trainer = Trainer(model, criterion, metrics, optimizer,
config=config,
device=device,
data_loader=data_loader,
valid_data_loader=valid_data_loader,
lr_scheduler=lr_scheduler)

trainer.train()

if config['fold']:
for fold in range(5):
print(
f"-------------------------START FOLD {fold + 1} TRAINING---------------------------"
)
print(
f"-------------------------START FOLD {fold + 1} MODEL LOADING----------------------"
)

data_loader = UltraGCNDataLoader(data_dir=config['data_loader']['args']['data_dir'], batch_size=config['data_loader']['args']['batch_size'],
shuffle=config['data_loader']['args']['shuffle'], num_workers=config['data_loader']['args']['num_workers'],
validation_split=config['data_loader']['args']['validation_split'], random_seed=config['data_loader']['args']['random_seed'],
fold=fold)
valid_data_loader = data_loader.split_validation()

model = config.init_obj('arch', module_arch)
logger.info(model)

# prepare for (multi-device) GPU training
device, device_ids = prepare_device(config['n_gpu'])
model = model.to(device)
if len(device_ids) > 1:
model = torch.nn.DataParallel(model, device_ids=device_ids)

# get function handles of loss and metrics
criterion = getattr(module_loss, config['loss'])
metrics = [getattr(module_metric, met) for met in config['metrics']]

# build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

print(
f"-------------------------DONE FOLD {fold + 1} MODEL LOADING-----------------------"
)

trainer = Trainer(model, criterion, metrics, optimizer,
config=config,
device=device,
data_loader=data_loader,
valid_data_loader=valid_data_loader,
lr_scheduler=lr_scheduler,
fold=fold)

trainer.train()
print(
f"---------------------------DONE FOLD {fold + 1} TRAINING--------------------------"
)
else:
# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

wandb.init(project=config['name'], config=config, entity="ffm")
# build model architecture, then print to console
model = config.init_obj('arch', module_arch)
logger.info(model)

# prepare for (multi-device) GPU training
device, device_ids = prepare_device(config['n_gpu'])
model = model.to(device)
if len(device_ids) > 1:
model = torch.nn.DataParallel(model, device_ids=device_ids)

# get function handles of loss and metrics
criterion = getattr(module_loss, config['loss'])
metrics = [getattr(module_metric, met) for met in config['metrics']]

# build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

trainer = Trainer(model, criterion, metrics, optimizer,
config=config,
device=device,
data_loader=data_loader,
valid_data_loader=valid_data_loader,
lr_scheduler=lr_scheduler)

trainer.train()


if __name__ == '__main__':
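The per-fold UltraGCNDataLoader call spells out every argument from the config block. As a refactor sketch only (assuming, as the original call suggests, that the keys under data_loader.args match the constructor parameters), the block could be unpacked with only the fold overridden:

    # Sketch only, inside the fold loop of train_GCN.py:
    loader_args = dict(config['data_loader']['args'])
    data_loader = UltraGCNDataLoader(**loader_args, fold=fold)
    valid_data_loader = data_loader.split_validation()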
5 changes: 3 additions & 2 deletions DKT/trainer/trainer_GCN.py
@@ -11,11 +11,12 @@ class Trainer(BaseTrainer):
Trainer class
"""
def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
super().__init__(model, criterion, metric_ftns, optimizer, config)
data_loader, fold=0, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
super().__init__(model, criterion, metric_ftns, optimizer, config, fold)
self.config = config
self.device = device
self.data_loader = data_loader
self.fold = fold
if len_epoch is None:
# epoch-based training
self.len_epoch = len(self.data_loader)
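Since fold defaults to 0 in both Trainer and UltraGCNDataLoader, a run with "fold": false still trains a single model and saves it as model_best0.pth, which is exactly the "0.pth" file the non-fold branch of test_GCN.py loads.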
