-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feat-#32/UltraGCN_CV
- Loading branch information
Showing
50 changed files
with
5,749 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import argparse | ||
|
||
|
||
def parse_args(mode="train"): | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument("--seed", default=42, type=int, help="seed") | ||
parser.add_argument("--device", default="cpu", type=str, help="cpu or gpu") | ||
|
||
# -- 데이터 경로 및 파일 이름 설정 | ||
parser.add_argument( | ||
"--data_dir", | ||
default="/opt/ml/input/data/", | ||
type=str, | ||
help="data directory", | ||
) | ||
parser.add_argument( | ||
"--asset_dir", default="asset/", type=str, help="data directory" | ||
) | ||
parser.add_argument( | ||
"--file_name", default="train_data.csv", type=str, help="train file name" | ||
) | ||
|
||
# -- 모델의 경로 및 이름, 결과 저장 | ||
parser.add_argument( | ||
"--model_dir", default="models/", type=str, help="model directory" | ||
) | ||
parser.add_argument( | ||
"--model_name", default="model.pt", type=str, help="model file name" | ||
) | ||
parser.add_argument( | ||
"--output_dir", default="output/", type=str, help="output directory" | ||
) | ||
parser.add_argument( | ||
"--test_file_name", default="test_data.csv", type=str, help="test file name" | ||
) | ||
|
||
parser.add_argument( | ||
"--max_seq_len", default=30, type=int, help="max sequence length" | ||
) | ||
parser.add_argument("--num_workers", default=4, type=int, help="number of workers") | ||
|
||
# 모델 | ||
parser.add_argument( | ||
"--hidden_dim", default=300, type=int, help="hidden dimension size" | ||
) | ||
parser.add_argument("--n_layers", default=2, type=int, help="number of layers") | ||
parser.add_argument("--n_heads", default=4, type=int, help="number of heads") | ||
parser.add_argument("--drop_out", default=0.2, type=float, help="drop out rate") | ||
|
||
# 훈련 | ||
parser.add_argument("--n_epochs", default=30, type=int, help="number of epochs") | ||
parser.add_argument("--batch_size", default=64, type=int, help="batch size") | ||
parser.add_argument("--lr", default=0.009668, type=float, help="learning rate") | ||
parser.add_argument("--clip_grad", default=10, type=int, help="clip grad") | ||
parser.add_argument("--patience", default=10, type=int, help="for early stopping") | ||
|
||
parser.add_argument( | ||
"--log_steps", default=50, type=int, help="print log per n steps" | ||
) | ||
|
||
### 중요 ### | ||
parser.add_argument("--model", default="LastQuery", type=str, help="model type") | ||
parser.add_argument("--optimizer", default="adam", type=str, help="optimizer type") | ||
parser.add_argument( | ||
"--scheduler", default="plateau", type=str, help="scheduler type" | ||
) | ||
|
||
# -- Data split methods : default(user), k-fold, ... | ||
parser.add_argument( | ||
"--split_method", default="k-fold", type=str, help="data split strategy" | ||
) | ||
parser.add_argument( | ||
"--n_splits", default=5, type=str, help="number of k-fold splits" | ||
) | ||
|
||
### Argumentation 관련 ### | ||
|
||
parser.add_argument( | ||
"--window", default=True, type=bool, help="Arumentation with stridde window" | ||
) | ||
parser.add_argument( | ||
"--shuffle", default=False, type=bool, help="data shuffle option" | ||
) | ||
parser.add_argument("--stride", default=80, type=int) | ||
parser.add_argument("--shuffle_n", default=2, type=int) | ||
|
||
### Tfixup 관련 ### | ||
parser.add_argument("--Tfixup", default=False, type=bool, help="Tfixup") | ||
|
||
args = parser.parse_args() | ||
|
||
# args.stride = args.max_seq_len | ||
|
||
return args |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
{ | ||
"name": "HybridModel", | ||
"n_gpu": 1, | ||
|
||
"arch": { | ||
"type": "HMModel_lstm", | ||
"args": { | ||
"n_test": 1537, | ||
"n_tag": 912, | ||
"gamma": 1e-4, | ||
"lambda": 0.8, | ||
"hidden_dim": 256, | ||
"n_layers": 3, | ||
"n_heads": 4, | ||
"drop_out": 0.4, | ||
"model_dir": "/opt/ml/level2_dkt-recsys-09/DKT/saved/models/UltraGCN/0524_043901/model_best.pth", | ||
"ultragcn": { | ||
"user_num": 7442, | ||
"item_num": 9454, | ||
"embedding_dim": 64, | ||
"gamma": 1e-4, | ||
"lambda": 0.8 | ||
} | ||
} | ||
}, | ||
"data_loader": { | ||
"type": "HMDataLoader", | ||
"args":{ | ||
"data_dir": "/opt/ml/input/data", | ||
"asset_dir": "./asset", | ||
"batch_size": 512, | ||
"shuffle": true, | ||
"num_workers": 2, | ||
"max_seq_len": 200, | ||
"validation_split": 0.2, | ||
"stride": 10, | ||
"shuffle_n": 2, | ||
"shuffle_aug": false | ||
} | ||
}, | ||
"optimizer": { | ||
"type": "Adam", | ||
"args":{ | ||
"lr": 0.001, | ||
"weight_decay": 0, | ||
"amsgrad": true | ||
} | ||
}, | ||
"loss": "BCE_loss", | ||
"metrics": [ | ||
"accuracy", "auc" | ||
], | ||
"lr_scheduler": { | ||
"type": "StepLR", | ||
"args": { | ||
"step_size": 50, | ||
"gamma": 0.1 | ||
} | ||
}, | ||
"trainer": { | ||
"epochs": 100, | ||
|
||
"save_dir": "saved/", | ||
"save_period": 1, | ||
"verbosity": 2, | ||
|
||
"monitor": "min val_loss", | ||
"early_stop": 10, | ||
|
||
"tensorboard": true | ||
}, | ||
"test": { | ||
"data_dir": "~/input/data/test_data_modify.csv", | ||
"model_dir": "/opt/ml/level2_dkt-recsys-09/DKT/saved/models/HybridModel/0524_162035/model_best.pth", | ||
"submission_dir": "~/level2_dkt-recsys-09/DKT/submission/UltraGCN_HM_aug_lstm.csv", | ||
"sample_submission_dir": "~/input/data/sample_submission.csv", | ||
"batch_size": 128 | ||
} | ||
} |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
{ | ||
"name": "lgcnLSTMattn", | ||
"n_gpu": 1, | ||
|
||
"arch": { | ||
"type": "lgcnLSTMattn", | ||
"args": { | ||
"user_num": 7442, | ||
"item_num": 9454, | ||
"embedding_dim": 64, | ||
"gamma": 1e-4, | ||
"lambda": 0.8 | ||
} | ||
}, | ||
"data_loader": { | ||
"type": "lgcnLSTMattnDataLoader", | ||
"args":{ | ||
"data_dir": "/opt/ml/input/data/", | ||
"batch_size": 512, | ||
"shuffle": true, | ||
"num_workers": 2, | ||
"validation_split": 0.2 | ||
} | ||
}, | ||
"optimizer": { | ||
"type": "Adam", | ||
"args":{ | ||
"lr": 0.001, | ||
"weight_decay": 0, | ||
"amsgrad": true | ||
} | ||
}, | ||
"loss": "lgcnLSTMattn_loss", | ||
"metrics": [ | ||
"accuracy", | ||
"auc" | ||
], | ||
"lr_scheduler": { | ||
"type": "StepLR", | ||
"args": { | ||
"step_size": 50, | ||
"gamma": 0.1 | ||
} | ||
}, | ||
"model": { | ||
"max_seq_len": 200, | ||
"hidden_dim": 256, | ||
"n_layers": 2, | ||
"n_heads": 4, | ||
"drop_out": 0.4, | ||
"gcn_n_layes": 2, | ||
"alpha": 1.0, | ||
"beta": 1.0 | ||
}, | ||
"trainer": { | ||
"n_epochs": 60, | ||
"batch_size": 70, | ||
"lr": 0.000001, | ||
"clip_grad" : 10, | ||
"patience": 100, | ||
"log_step": 50, | ||
|
||
"save_dir": "saved/", | ||
"save_period": 1, | ||
"verbosity": 2, | ||
|
||
"monitor": "min val_loss", | ||
"early_stop": 10, | ||
|
||
"tensorboard": false | ||
}, | ||
"test": { | ||
"data_dir": "~/input/data/test_data_modify.csv", | ||
"model_dir": "./saved/models/LGCNtrans/0518_033541/model_best.pth", | ||
"submission_dir": "~/level2_dkt-recsys-09/DKT/submission/lgcnLSTMattn_submission.csv", | ||
"sample_submission_dir": "~/input/data/sample_submission.csv", | ||
"batch_size": 512 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .data_preprocess_HM import * | ||
from .data_loaders_GCN import * | ||
from .data_preprocess_LQ import * | ||
from dataloader_lgcnlstmattn import * | ||
|
Oops, something went wrong.