-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
59 lines (47 loc) · 1.82 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Entry point to the application."""
import logging
import os
import hydra
from hydra.core.hydra_config import HydraConfig
import torch
import torch.multiprocessing as mp
from omegaconf import DictConfig, OmegaConf
from trainers import graphitee
from data.dataset import load_data
# logging.basicConfig(level = logging.INFO)
def _visible_cuda_devices():
    """Return the CUDA device ids visible to this process, as strings.

    ``CUDA_VISIBLE_DEVICES`` takes precedence when it is set; otherwise every
    device reported by ``torch.cuda.device_count()`` is listed as
    ``"0"``, ``"1"``, ...  Empty entries and surrounding whitespace in the
    environment variable (e.g. ``"0, 1"`` or ``"0,"``) are ignored.
    """
    env_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if env_devices is not None:
        return [dev.strip() for dev in env_devices.split(",") if dev.strip()]
    return [str(i) for i in range(torch.cuda.device_count())]


def _launch_training(cfg, hydra_output_dir):
    """Run ``graphitee.init_process`` as rank 0 in a freshly spawned process.

    The child is pinned to the first visible CUDA device by setting
    ``CUDA_VISIBLE_DEVICES`` in this process's environment before the child
    starts, so the spawned interpreter inherits it.

    Args:
        cfg: Hydra configuration; ``cfg.backend`` selects the backend.
        hydra_output_dir: Hydra run output directory, forwarded to the trainer.

    Raises:
        ValueError: if ``cfg.backend`` is not ``"gloo"``.
        RuntimeError: if no CUDA device is visible, or the training process
            exits with a non-zero exit code.
    """
    trainer = graphitee
    if cfg.backend != "gloo":
        raise ValueError(f"Backend {cfg.backend} is not supported.")

    devices = _visible_cuda_devices()
    if not devices:
        raise RuntimeError(
            "No CUDA device is visible; cannot launch training with the "
            f"{cfg.backend} backend."
        )
    os.environ["CUDA_VISIBLE_DEVICES"] = devices[0]

    # Use a local "spawn" context instead of mp.set_start_method(): the latter
    # raises RuntimeError if the start method was already set, e.g. when Hydra
    # multirun invokes main() several times in the same interpreter.
    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=trainer.init_process, args=(0, cfg, hydra_output_dir))
    proc.start()
    proc.join()
    # Propagate a failed training run instead of exiting successfully.
    if proc.exitcode != 0:
        raise RuntimeError(f"Training process exited with code {proc.exitcode}.")


@hydra.main(config_path="config", config_name="news_recommendation", version_base=None)
def main(cfg: DictConfig) -> None:
    """Run the application selected by ``cfg.app``.

    Supported values of ``cfg.app``:

    * ``"load_data"`` -- download the dataset if not done already and build the
      processed data via ``load_data`` (useful for debugging).
    * ``"partition_data"`` -- not implemented yet.
    * ``"train"`` -- launch ``graphitee.init_process`` in a spawned process;
      only the ``"gloo"`` value of ``cfg.backend`` is supported.

    Args:
        cfg: Hydra-composed configuration.

    Raises:
        NotImplementedError: if ``cfg.app == "partition_data"``.
        ValueError: if ``cfg.app`` or ``cfg.backend`` is not supported.
        RuntimeError: if training cannot be launched or the training process
            fails.
    """
    print(OmegaConf.to_yaml(cfg))
    hydra_output_dir = HydraConfig.get().runtime.output_dir

    if cfg.app == "load_data":
        # Downloads the dataset if not done already and returns the processed
        # data; the results are bound so they can be inspected in a debugger.
        _graph, _dataset = load_data(cfg=cfg)
    elif cfg.app == "partition_data":
        # TODO: implement data partitioning. Fail loudly rather than exiting
        # successfully without doing anything.
        raise NotImplementedError(f"App {cfg.app} is not implemented yet.")
    elif cfg.app == "train":
        _launch_training(cfg, hydra_output_dir)
    else:
        raise ValueError(f"App {cfg.app} is not supported.")


if __name__ == "__main__":
    # ``cfg`` is injected by the @hydra.main decorator from the command line.
    main()  # pylint: disable=no-value-for-parameter