diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1e17ad97..90f5f596 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -676,6 +676,7 @@ def run(args: argparse.Namespace) -> None: entity=args.wandb_entity, name=args.wandb_name, config=wandb_config, + directory=args.wandb_dir, ) wandb.run.summary["params"] = args_dict_json diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 96f1e185..6a5d2b0e 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -568,6 +568,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: action="store_true", default=False, ) + parser.add_argument( + "--wandb_dir", + help="An absolute path to a directory where Weights and Biases metadata will be stored", + type=str, + default=None, + ) parser.add_argument( "--wandb_project", help="Weights and Biases project name", diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py index f341cf0b..95ecd7cc 100644 --- a/mace/tools/torch_tools.py +++ b/mace/tools/torch_tools.py @@ -119,10 +119,10 @@ def voigt_to_matrix(t: torch.Tensor): ) -def init_wandb(project: str, entity: str, name: str, config: dict): +def init_wandb(project: str, entity: str, name: str, config: dict, directory:str): import wandb - wandb.init(project=project, entity=entity, name=name, config=config) + wandb.init(project=project, entity=entity, name=name, config=config, dir=directory) @contextmanager