-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
29 lines (21 loc) · 868 Bytes
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from omegaconf import OmegaConf
import src.data as data_module
from src.data.DataInterface import DataInterface
from src.model import Model
def app(
    config_path: str = "./config/config.yaml",
    model_config_dir: str = "./config/model_configs",
) -> None:
    """Run the pipeline: data pre-processing/export, then model train & inference.

    Steps:
      1. Load the top-level config from ``config_path``.
      2. Load the per-model config ``<model_config_dir>/<model type lowercased>.yaml``.
      3. Instantiate the class named by ``cfg["data"]["type"]`` from ``src.data``.
         Per the original comment, constructing it performs data pre-processing
         and exports data for RecBole. The instance itself is not used afterwards.
      4. Construct ``Model(cfg)``. Per the original comment, this runs
         (train or train/valid) and inference according to ``cfg["model"]["mode"]``.
         TODO confirm against ``src.model``.

    Args:
        config_path: Path to the top-level YAML config. The default keeps the
            original hard-coded location, which is relative to the current
            working directory.
        model_config_dir: Directory containing the per-model YAML configs.
            The default keeps the original hard-coded location.

    Raises:
        AttributeError: If ``cfg["data"]["type"]`` does not name an attribute
            of the ``src.data`` package.
    """
    import os

    cfg = OmegaConf.load(config_path)
    model_type: str = cfg["model"]["type"]
    # Per-model config files are named after the lowercased model type.
    recbole_cfg_path = os.path.join(model_config_dir, f"{model_type.lower()}.yaml")
    model_cfg = OmegaConf.load(recbole_cfg_path)

    ## Data Pre-Processing & Export data(for recbole)
    print("---------Start Data Pre-Processing---------")
    data_type: str = cfg["data"]["type"]
    data_path = cfg["data"]["base_path"]
    try:
        # The data class is looked up dynamically by name from the config.
        data_cls = getattr(data_module, data_type)
    except AttributeError as err:
        raise AttributeError(
            f"Unknown data type {data_type!r}: no such attribute in src.data"
        ) from err
    # Instantiated only for its side effects (pre-processing/export); the
    # instance is intentionally discarded.
    _: DataInterface = data_cls(model_cfg["dataset"], data_path, model_cfg["data_path"])

    ## (Train or Train/Valid) & Inference
    mode = cfg["model"]["mode"]
    print(f"-------Start Train{mode}------------")
    Model(cfg)
if __name__ == "__main__":
    # Script entry point: run the pipeline only when this file is executed
    # directly, not when it is imported as a module.
    app()