Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
salvaRC committed Feb 13, 2024
1 parent e4523b3 commit d075af7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/utilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def update_dict_with_other(d1: Dict[str, Any], other: Dict[str, Any]): # _and_r
"""
diff = []
for k, v in other.items():
if isinstance(v, dict):
if isinstance(v, dict) and d1.get(k) is not None:
d1[k], diff_sub = update_dict_with_other(d1.get(k, {}), v)
diff += [f"{k}.{x}" for x in diff_sub]
else:
Expand Down
35 changes: 25 additions & 10 deletions src/utilities/wandb_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,30 @@ def load_hydra_config_from_wandb(
raise ValueError(f"override_key_value must be a list of strings, but has type {type(override_key_value)}")
# copy overrides to new list
overrides = list(override_key_value.copy())
rank = os.environ.get("RANK", None) or os.environ.get("LOCAL_RANK", 0)

# Find latest hydra_config-v{VERSION}.yaml file in wandb cloud
hydra_config_files = [f.name for f in run.files() if "hydra_config" in f.name]
if len(hydra_config_files) == 0:
raise ValueError(f"Could not find any hydra_config file in wandb run {run_path}")
elif len(hydra_config_files) == 1:
assert hydra_config_files[0] == "hydra_config.yaml", f"Only one hydra_config file found: {hydra_config_files}"
else:
hydra_config_files = [f for f in hydra_config_files if "hydra_config-v" in f]
assert len(hydra_config_files) > 0, f"Could not find any hydra_config-v file in wandb run {run_path}"
# Sort by version number (largest is last, earliest are hydra_config.yaml and hydra_config-v1.yaml),
hydra_config_files = sorted(hydra_config_files, key=lambda x: int(x.split("-v")[-1].split(".")[0]))

hydra_config_file = hydra_config_files[-1]
if hydra_config_file != "hydra_config.yaml":
log.info(f" Reloading from hydra config file: {hydra_config_file}")

# Download from wandb cloud
wandb_restore_kwargs = dict(run_path=run_path, replace=True, root=os.getcwd())
try:
wandb.restore("hydra_config.yaml", **wandb_restore_kwargs)
except ValueError: # hydra_config has not been saved to wandb :(
overrides += json.load(wandb.restore("wandb-metadata.json", **wandb_restore_kwargs))["args"]
if len(overrides) == 0:
raise ValueError("wandb-metadata.json had no args, are you sure this is correct?")
# also wandb-metadata.json is unexpected (was likely overwritten)
if os.path.exists(hydra_config_file) and rank not in ["0", 0]:
pass
else:
wandb.restore(hydra_config_file, **wandb_restore_kwargs)

# remove overrides of the form k=v, where k has no dot in it. We don't support this.
overrides = [o for o in overrides if "=" in o and "." in o.split("=")[0]]
Expand All @@ -285,7 +300,7 @@ def load_hydra_config_from_wandb(
f"logger.wandb.tags={run.tags}",
f"logger.wandb.group={run.group}",
]
config = OmegaConf.load("hydra_config.yaml")
config = OmegaConf.load(hydra_config_file)
overrides = OmegaConf.from_dotlist(overrides)
config = OmegaConf.unsafe_merge(config, overrides)

Expand All @@ -297,8 +312,8 @@ def load_hydra_config_from_wandb(
# override config with override_config (which needs to be the second argument of OmegaConf.merge)
config = OmegaConf.unsafe_merge(config, override_config) # unsafe_merge since override_config is not needed

os.remove("hydra_config.yaml") if os.path.exists("hydra_config.yaml") else None
os.remove("../../hydra_config.yaml") if os.path.exists("../../hydra_config.yaml") else None
os.remove(hydra_config_file) if os.path.exists(hydra_config_file) else None
os.remove(f"../../{hydra_config_file}") if os.path.exists(f"../../{hydra_config_file}") else None

if run.id != config.logger.wandb.id and run.id in config.logger.wandb.name:
config.logger.wandb.id = run.id
Expand Down

0 comments on commit d075af7

Please sign in to comment.