diff --git a/.circleci/config.yml b/.circleci/config.yml index c6e2d2d0834..4b28d529a7b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -87,14 +87,13 @@ commands: - run: name: Preparing environment - system command: | - choco install -y --no-progress miniconda3 + choco install -y --no-progress miniconda3 --params '"/AddToPath:1"' C:\tools\miniconda3\Scripts\conda.exe init powershell choco install -y --no-progress openssl javaruntime - run: name: Preparing environment - Hydra - # Using virtualenv==20.0.33 higher versions of virtualenv are not compatible with conda on windows. Relevant issue: https://github.com/ContinuumIO/anaconda-issues/issues/12094 command: | - conda create -n hydra python=<< parameters.py_version >> virtualenv==20.0.33 -qy + conda create -n hydra python=<< parameters.py_version >> -qy conda activate hydra pip install nox dataclasses --progress-bar off - save_cache: @@ -224,6 +223,23 @@ jobs: conda activate hydra nox -s lint_plugins test_plugins -ts exit $LASTEXITCODE + test_linux_omc_dev: + parameters: + py_version: + type: string + docker: + - image: cimg/base:stable-18.04 + steps: + - linux: + py_version: << parameters.py_version >> + - run: + name: Testing Hydra + command: | + export PATH="$HOME/miniconda3/envs/hydra/bin:$PATH" + export NOX_PYTHON_VERSIONS=<< parameters.py_version >> + export USE_OMEGACONF_DEV_VERSION=1 + pip install nox dataclasses --progress-bar off + nox -s test_core -ts # Misc coverage: docker: @@ -250,7 +266,11 @@ workflows: - test_win: matrix: parameters: - py_version: ["3.6", "3.7", "3.8"] + py_version: ["3.6", "3.7", "3.8", "3.9"] + - test_linux_omc_dev: + matrix: + parameters: + py_version: ["3.6", "3.7", "3.8", "3.9"] plugin_tests: @@ -269,7 +289,7 @@ workflows: - test_plugin_win: matrix: parameters: - py_version: ["3.6", "3.7", "3.8",] + py_version: ["3.6", "3.7", "3.8", "3.9"] test_plugin: [<< pipeline.parameters.test_plugins >>] diff --git a/.flake8 b/.flake8 index 93274d40d37..85f8204374b 100644 --- a/.flake8 +++ b/.flake8 @@ -7,6 +7,7 @@ exclude = ,tools/configen/example/gen ,tools/configen/tests/test_modules/expected ,temp + ,build # flake8-copyright does not support unicode, savoirfairelinux/flake8-copyright#15 ,examples/plugins/example_configsource_plugin/hydra_plugins/example_configsource_plugin/example_configsource_plugin.py diff --git a/.isort.cfg b/.isort.cfg index 90e7833adc2..dc0547edcd9 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -24,3 +24,4 @@ skip= ,tools/configen/example/gen ,tools/configen/tests/test_modules/expected ,temp + ,build diff --git a/.mypy.ini b/.mypy.ini index 384fcae9dbc..4c25b94361f 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,6 +1,10 @@ [mypy] python_version = 3.6 mypy_path=.stubs +exclude = (?x)( + build/ + | ^hydra/grammar/gen/ + ) [mypy-antlr4.*] ignore_missing_imports = True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b5ce8f690f4..b5727ef0bd9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: files: 'tools/.*' - repo: https://github.com/psf/black - rev: 20.8b1 + rev: 22.1.0 hooks: - id: black language_version: python3.6 @@ -25,4 +25,4 @@ repos: rev: v1.25.0 hooks: - id: yamllint - args: [-c=.yamllint] + args: [-c=.yamllint, --strict] diff --git a/README.md b/README.md index 118cef31f4b..be71cee5879 100644 --- a/README.md +++ b/README.md @@ -25,14 +25,17 @@ Language grade: Python - - zulip -

A framework for elegantly configuring complex applications.

- Check the website for more information. + Check the website for more information,
+ or click the thumbnail below for a one-minute video introduction to Hydra.
+

+

+ + 1 minute overview +

@@ -43,26 +46,18 @@ #### Stable -**Hydra 1.0** is the stable version of Hydra. +**Hydra 1.1** is the stable version of Hydra. - [Documentation](https://hydra.cc/docs/intro) - Installation : `pip install hydra-core --upgrade` -#### Release candidate -**Hydra 1.1** is now a release candidate! - -Please try it out and report any issues. - -- [Documentation](https://hydra.cc/docs/next/intro) -- [Release notes](https://github.com/facebookresearch/hydra/releases/tag/v1.1.0.rc1) -- Installation : `pip install hydra-core --upgrade --pre` -- Release candidates are more likely to have bugs, please report any issues. ### License Hydra is licensed under [MIT License](LICENSE). ## Community -Ask questions in the chat or StackOverflow (Use the tag #fb-hydra or #omegaconf): -* [Chat](https://hydra-framework.zulipchat.com) + +Ask questions in Github Discussions or StackOverflow (Use the tag #fb-hydra or #omegaconf): +* [Github Discussions](https://github.com/facebookresearch/hydra/discussions) * [StackOverflow](https://stackexchange.com/filters/391828/hydra-questions) * [Twitter](https://twitter.com/Hydra_Framework) diff --git a/build_helpers/build_helpers.py b/build_helpers/build_helpers.py index 157a0d60886..c6a511d576b 100644 --- a/build_helpers/build_helpers.py +++ b/build_helpers/build_helpers.py @@ -10,7 +10,7 @@ from typing import List, Optional from setuptools import Command -from setuptools.command import build_py, develop, sdist # type: ignore +from setuptools.command import build_py, develop, sdist def find_version(*file_paths: str) -> str: @@ -143,7 +143,7 @@ def run_antlr(cmd: Command) -> None: raise -class BuildPyCommand(build_py.build_py): # type: ignore +class BuildPyCommand(build_py.build_py): def run(self) -> None: if not self.dry_run: self.run_command("clean") @@ -151,16 +151,16 @@ def run(self) -> None: build_py.build_py.run(self) -class Develop(develop.develop): # type: ignore - def run(self) -> None: +class Develop(develop.develop): + def run(self) -> None: # type: ignore if not self.dry_run: run_antlr(self) develop.develop.run(self) -class SDistCommand(sdist.sdist): # type: ignore +class SDistCommand(sdist.sdist): def run(self) -> None: - if not self.dry_run: + if not self.dry_run: # type: ignore self.run_command("clean") run_antlr(self) sdist.sdist.run(self) diff --git a/examples/advanced/ad_hoc_composition/hydra_compose_example.py b/examples/advanced/ad_hoc_composition/hydra_compose_example.py index 814e2f84a41..54d85cb69e7 100644 --- a/examples/advanced/ad_hoc_composition/hydra_compose_example.py +++ b/examples/advanced/ad_hoc_composition/hydra_compose_example.py @@ -6,6 +6,6 @@ if __name__ == "__main__": # initialize the Hydra subsystem. # This is needed for apps that cannot have a standard @hydra.main() entry point - initialize(config_path="conf") + initialize(version_base=None, config_path="conf") cfg = compose("config.yaml", overrides=["db=mysql", "db.user=${oc.env:USER}"]) print(OmegaConf.to_yaml(cfg, resolve=True)) diff --git a/examples/advanced/config_search_path/my_app.py b/examples/advanced/config_search_path/my_app.py index 6b1303d511f..7585728f215 100644 --- a/examples/advanced/config_search_path/my_app.py +++ b/examples/advanced/config_search_path/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/advanced/defaults_list_interpolation/my_app.py b/examples/advanced/defaults_list_interpolation/my_app.py index 6b1303d511f..7585728f215 100644 --- a/examples/advanced/defaults_list_interpolation/my_app.py +++ b/examples/advanced/defaults_list_interpolation/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/advanced/hydra_app_example/hydra_app/main.py b/examples/advanced/hydra_app_example/hydra_app/main.py index e8b07ff979b..ec4c5b458f9 100644 --- a/examples/advanced/hydra_app_example/hydra_app/main.py +++ b/examples/advanced/hydra_app_example/hydra_app/main.py @@ -14,7 +14,7 @@ def add(app_cfg: DictConfig, key1: str, key2: str) -> Any: return ret -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def main(cfg: DictConfig) -> None: add(cfg.app, "num1", "num2") diff --git a/examples/advanced/hydra_app_example/tests/test_example.py b/examples/advanced/hydra_app_example/tests/test_example.py index b821ec6e6c5..afc8f0d5f3d 100644 --- a/examples/advanced/hydra_app_example/tests/test_example.py +++ b/examples/advanced/hydra_app_example/tests/test_example.py @@ -12,7 +12,7 @@ # 2. The module with your configs should be importable. it needs to have a __init__.py (can be empty). # 3. The config path is relative to the file calling initialize (this file) def test_with_initialize() -> None: - with initialize(config_path="../hydra_app/conf"): + with initialize(version_base=None, config_path="../hydra_app/conf"): # config is relative to a module cfg = compose(config_name="config", overrides=["app.user=test_user"]) assert cfg == { @@ -26,7 +26,7 @@ def test_with_initialize() -> None: # 3. The module should be absolute # 4. This approach is not sensitive to the location of this file, the test can be relocated freely. def test_with_initialize_config_module() -> None: - with initialize_config_module(config_module="hydra_app.conf"): + with initialize_config_module(version_base=None, config_module="hydra_app.conf"): # config is relative to a module cfg = compose(config_name="config", overrides=["app.user=test_user"]) assert cfg == { @@ -38,7 +38,9 @@ def test_with_initialize_config_module() -> None: # Usage in unittest style tests is similar. class TestWithUnittest(unittest.TestCase): def test_generated_config(self) -> None: - with initialize_config_module(config_module="hydra_app.conf"): + with initialize_config_module( + version_base=None, config_module="hydra_app.conf" + ): cfg = compose(config_name="config", overrides=["app.user=test_user"]) assert cfg == { "app": {"user": "test_user", "num1": 10, "num2": 20}, @@ -57,6 +59,6 @@ def test_generated_config(self) -> None: ], ) def test_user_logic(overrides: List[str], expected: int) -> None: - with initialize_config_module(config_module="hydra_app.conf"): + with initialize_config_module(version_base=None, config_module="hydra_app.conf"): cfg = compose(config_name="config", overrides=overrides) assert hydra_app.main.add(cfg.app, "num1", "num2") == expected diff --git a/examples/advanced/nested_defaults_list/my_app.py b/examples/advanced/nested_defaults_list/my_app.py index 6b1303d511f..7585728f215 100644 --- a/examples/advanced/nested_defaults_list/my_app.py +++ b/examples/advanced/nested_defaults_list/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/advanced/package_overrides/simple.py b/examples/advanced/package_overrides/simple.py index 90590748f7e..76d2e4b73f7 100644 --- a/examples/advanced/package_overrides/simple.py +++ b/examples/advanced/package_overrides/simple.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="simple") +@hydra.main(version_base=None, config_path="conf", config_name="simple") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/advanced/package_overrides/two_packages.py b/examples/advanced/package_overrides/two_packages.py index 3e74134e93c..b0bdb60e852 100644 --- a/examples/advanced/package_overrides/two_packages.py +++ b/examples/advanced/package_overrides/two_packages.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="two_packages") +@hydra.main(version_base=None, config_path="conf", config_name="two_packages") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/advanced/ray_example/ray_compose_example.py b/examples/advanced/ray_example/ray_compose_example.py index 6002058e931..116cbd464cc 100644 --- a/examples/advanced/ray_example/ray_compose_example.py +++ b/examples/advanced/ray_example/ray_compose_example.py @@ -15,7 +15,7 @@ def train(overrides: List[str], cfg: DictConfig) -> Tuple[List[str], float]: return overrides, 0.9 -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def main(cfg: DictConfig) -> None: ray.init(**cfg.ray.init) diff --git a/examples/configure_hydra/custom_help/my_app.py b/examples/configure_hydra/custom_help/my_app.py index 6b1303d511f..7585728f215 100644 --- a/examples/configure_hydra/custom_help/my_app.py +++ b/examples/configure_hydra/custom_help/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/configure_hydra/job_name/no_config_file_override.py b/examples/configure_hydra/job_name/no_config_file_override.py index d7215de2f3d..fbfa37824d1 100644 --- a/examples/configure_hydra/job_name/no_config_file_override.py +++ b/examples/configure_hydra/job_name/no_config_file_override.py @@ -5,7 +5,7 @@ from hydra.core.hydra_config import HydraConfig -@hydra.main(config_path=None) +@hydra.main(version_base=None) def experiment(_cfg: DictConfig) -> None: print(HydraConfig.get().job.name) diff --git a/examples/configure_hydra/job_name/with_config_file_override.py b/examples/configure_hydra/job_name/with_config_file_override.py index 01accce93c7..01f39dd5728 100644 --- a/examples/configure_hydra/job_name/with_config_file_override.py +++ b/examples/configure_hydra/job_name/with_config_file_override.py @@ -5,7 +5,7 @@ from hydra.core.hydra_config import HydraConfig -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def experiment(_cfg: DictConfig) -> None: print(HydraConfig.get().job.name) diff --git a/examples/configure_hydra/job_override_dirname/my_app.py b/examples/configure_hydra/job_override_dirname/my_app.py index 90143cce547..8db42720dc7 100644 --- a/examples/configure_hydra/job_override_dirname/my_app.py +++ b/examples/configure_hydra/job_override_dirname/my_app.py @@ -6,7 +6,7 @@ import hydra -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(_cfg: DictConfig) -> None: print(f"Working dir {os.getcwd()}") diff --git a/examples/configure_hydra/logging/my_app.py b/examples/configure_hydra/logging/my_app.py index 1b69b45bd51..d4ba242d959 100644 --- a/examples/configure_hydra/logging/my_app.py +++ b/examples/configure_hydra/logging/my_app.py @@ -8,7 +8,7 @@ log = logging.getLogger(__name__) -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(_cfg: DictConfig) -> None: log.info("Info level message") diff --git a/examples/configure_hydra/workdir/my_app.py b/examples/configure_hydra/workdir/my_app.py index b43c4510cf7..913d3208952 100644 --- a/examples/configure_hydra/workdir/my_app.py +++ b/examples/configure_hydra/workdir/my_app.py @@ -6,7 +6,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def experiment(_cfg: DictConfig) -> None: print(os.getcwd()) diff --git a/examples/experimental/rerun/config.yaml b/examples/experimental/rerun/config.yaml new file mode 100644 index 00000000000..ce895e3e227 --- /dev/null +++ b/examples/experimental/rerun/config.yaml @@ -0,0 +1,8 @@ +foo: bar + +hydra: + callbacks: + save_job_info: + _target_: hydra.experimental.callbacks.PickleJobInfoCallback + job: + chdir: false diff --git a/examples/experimental/rerun/my_app.py b/examples/experimental/rerun/my_app.py new file mode 100644 index 00000000000..2a7f056126a --- /dev/null +++ b/examples/experimental/rerun/my_app.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import logging + +from omegaconf import DictConfig + +import hydra +from hydra.core.hydra_config import HydraConfig + +log = logging.getLogger(__name__) + + +@hydra.main(version_base=None, config_path=".", config_name="config") +def my_app(cfg: DictConfig) -> None: + log.info(f"Output_dir={HydraConfig.get().runtime.output_dir}") + log.info(f"cfg.foo={cfg.foo}") + + +if __name__ == "__main__": + my_app() diff --git a/examples/instantiate/docs_example/my_app.py b/examples/instantiate/docs_example/my_app.py index 3b1bc412d97..d5b1a1a01a6 100644 --- a/examples/instantiate/docs_example/my_app.py +++ b/examples/instantiate/docs_example/my_app.py @@ -39,7 +39,7 @@ def __repr__(self) -> str: return f"Trainer(\n optimizer={self.optimizer},\n dataset={self.dataset}\n)" -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: optimizer = instantiate(cfg.trainer.optimizer) print(optimizer) diff --git a/examples/instantiate/object/my_app.py b/examples/instantiate/object/my_app.py index 8fb948f9638..593be34bed1 100644 --- a/examples/instantiate/object/my_app.py +++ b/examples/instantiate/object/my_app.py @@ -31,7 +31,7 @@ def connect(self) -> None: print(f"PostgreSQL connecting to {self.host}") -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: connection = instantiate(cfg.db) connection.connect() diff --git a/examples/instantiate/object_recursive/my_app.py b/examples/instantiate/object_recursive/my_app.py index ecd3e6b2958..c1bccdab553 100644 --- a/examples/instantiate/object_recursive/my_app.py +++ b/examples/instantiate/object_recursive/my_app.py @@ -28,7 +28,7 @@ def drive(self) -> None: print(f"Driver : {self.driver.name}, {len(self.wheels)} wheels") -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: car: Car = instantiate(cfg.car) car.drive() diff --git a/examples/instantiate/partial/config.yaml b/examples/instantiate/partial/config.yaml new file mode 100644 index 00000000000..46699d563bc --- /dev/null +++ b/examples/instantiate/partial/config.yaml @@ -0,0 +1,6 @@ +model: + _target_: my_app.Model + optim_partial: + _partial_: true + _target_: my_app.Optimizer + algo: SGD diff --git a/examples/instantiate/partial/my_app.py b/examples/instantiate/partial/my_app.py new file mode 100644 index 00000000000..26f0bec918f --- /dev/null +++ b/examples/instantiate/partial/my_app.py @@ -0,0 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from typing import Any + +from omegaconf import DictConfig + +import hydra +from hydra.utils import instantiate + + +class Optimizer: + algo: str + lr: float + + def __init__(self, algo: str, lr: float) -> None: + self.algo = algo + self.lr = lr + + def __repr__(self) -> str: + return f"Optimizer(algo={self.algo},lr={self.lr})" + + +class Model: + def __init__(self, optim_partial: Any): + super().__init__() + self.optim = optim_partial(lr=0.1) + + def __repr__(self) -> str: + return f"Model(Optimizer={self.optim})" + + +@hydra.main(version_base=None, config_path=".", config_name="config") +def my_app(cfg: DictConfig) -> None: + model = instantiate(cfg.model) + print(model) + + +if __name__ == "__main__": + my_app() diff --git a/examples/instantiate/schema/my_app.py b/examples/instantiate/schema/my_app.py index bbeb02c60cd..1d4a5a5eb44 100644 --- a/examples/instantiate/schema/my_app.py +++ b/examples/instantiate/schema/my_app.py @@ -69,7 +69,7 @@ class Config: cs.store(group="db", name="postgresql", node=PostGreSQLConfig) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: Config) -> None: connection = instantiate(cfg.db) connection.connect() diff --git a/examples/instantiate/schema_recursive/my_app.py b/examples/instantiate/schema_recursive/my_app.py index f2a02ad6551..072fbb59e2b 100644 --- a/examples/instantiate/schema_recursive/my_app.py +++ b/examples/instantiate/schema_recursive/my_app.py @@ -46,7 +46,7 @@ def pretty_print(tree: Tree, name: str = "root", depth: int = 0) -> None: pretty_print(tree.right, name="right", depth=depth + 1) -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: Config) -> None: tree: Tree = instantiate(cfg.tree) pretty_print(tree) diff --git a/examples/jupyter_notebooks/README.md b/examples/jupyter_notebooks/README.md index ef5ab81323a..5952e41f175 100644 --- a/examples/jupyter_notebooks/README.md +++ b/examples/jupyter_notebooks/README.md @@ -1,3 +1,3 @@ ## Jupyter notebooks -- **compose_configs_in_notebook.ipynb** [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/facebookresearch/hydra/master?filepath=examples%2Fjupyter_notebooks%2Fcompose_configs_in_notebook.ipynb) +- **compose_configs_in_notebook.ipynb** [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/facebookresearch/hydra/main?filepath=examples%2Fjupyter_notebooks%2Fcompose_configs_in_notebook.ipynb) diff --git a/examples/jupyter_notebooks/compose_configs_in_notebook.ipynb b/examples/jupyter_notebooks/compose_configs_in_notebook.ipynb index 2fbf3a5f951..e91eb7f68a5 100644 --- a/examples/jupyter_notebooks/compose_configs_in_notebook.ipynb +++ b/examples/jupyter_notebooks/compose_configs_in_notebook.ipynb @@ -51,7 +51,7 @@ } ], "source": [ - "with initialize(config_path=\"cloud_app/conf\"):\n", + "with initialize(version_base=None, config_path=\"cloud_app/conf\"):\n", " cfg = compose(overrides=[\"+db=mysql\"])\n", " print(cfg)" ] @@ -79,7 +79,7 @@ } ], "source": [ - "with initialize_config_module(config_module=\"cloud_app.conf\"):\n", + "with initialize_config_module(version_base=None, config_module=\"cloud_app.conf\"):\n", " cfg = compose(overrides=[\"+db=mysql\"])\n", " print(cfg)" ] @@ -108,7 +108,7 @@ ], "source": [ "abs_config_dir=os.path.abspath(\"cloud_app/conf\")\n", - "with initialize_config_dir(config_dir=abs_config_dir):\n", + "with initialize_config_dir(version_base=None, config_dir=abs_config_dir):\n", " cfg = compose(overrides=[\"+db=mysql\"])\n", " print(cfg)" ] @@ -138,7 +138,7 @@ } ], "source": [ - "initialize(config_path=\"cloud_app/conf\")\n", + "initialize(version_base=None, config_path=\"cloud_app/conf\")\n", "compose(overrides=[\"+db=mysql\"])" ] }, @@ -290,4 +290,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/examples/patterns/configuring_experiments/my_app.py b/examples/patterns/configuring_experiments/my_app.py index 6b1303d511f..7585728f215 100644 --- a/examples/patterns/configuring_experiments/my_app.py +++ b/examples/patterns/configuring_experiments/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/patterns/extending_configs/my_app.py b/examples/patterns/extending_configs/my_app.py index 6b1303d511f..7585728f215 100644 --- a/examples/patterns/extending_configs/my_app.py +++ b/examples/patterns/extending_configs/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/patterns/multi-select/my_app.py b/examples/patterns/multi-select/my_app.py index 6b1303d511f..7585728f215 100644 --- a/examples/patterns/multi-select/my_app.py +++ b/examples/patterns/multi-select/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/patterns/specializing_config/example.py b/examples/patterns/specializing_config/example.py index 537adcff371..78d6ab11ed5 100644 --- a/examples/patterns/specializing_config/example.py +++ b/examples/patterns/specializing_config/example.py @@ -9,7 +9,7 @@ log = logging.getLogger(__name__) -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def experiment(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/patterns/write_protect_config_node/frozen.py b/examples/patterns/write_protect_config_node/frozen.py index b3470157e9d..04556759013 100644 --- a/examples/patterns/write_protect_config_node/frozen.py +++ b/examples/patterns/write_protect_config_node/frozen.py @@ -16,7 +16,7 @@ class SerialPort: cs.store(name="config", node=SerialPort) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: SerialPort) -> None: print(cfg) diff --git a/examples/plugins/example_launcher_plugin/example/my_app.py b/examples/plugins/example_launcher_plugin/example/my_app.py index 2915e42b728..0a1d9b27876 100644 --- a/examples/plugins/example_launcher_plugin/example/my_app.py +++ b/examples/plugins/example_launcher_plugin/example/my_app.py @@ -3,7 +3,7 @@ from omegaconf import OmegaConf, DictConfig -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/plugins/example_registered_plugin/README.md b/examples/plugins/example_registered_plugin/README.md new file mode 100644 index 00000000000..e1ebd41ce9d --- /dev/null +++ b/examples/plugins/example_registered_plugin/README.md @@ -0,0 +1,2 @@ +# Hydra example plugin via `Plugins.register` +This plugin is not very useful, but it demonstrates how to register a plugin using the `Plugins.register` method. \ No newline at end of file diff --git a/examples/plugins/example_registered_plugin/example_registered_plugin/__init__.py b/examples/plugins/example_registered_plugin/example_registered_plugin/__init__.py new file mode 100644 index 00000000000..2b1aa90fa33 --- /dev/null +++ b/examples/plugins/example_registered_plugin/example_registered_plugin/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from hydra.core.plugins import Plugins +from hydra.plugins.plugin import Plugin + + +class ExampleRegisteredPlugin(Plugin): + def __init__(self, v: int) -> None: + self.v = v + + def add(self, x: int) -> int: + return self.v + x + + +def register_example_plugin() -> None: + """The Hydra user should call this function before invoking @hydra.main""" + Plugins.instance().register(ExampleRegisteredPlugin) diff --git a/examples/plugins/example_registered_plugin/setup.py b/examples/plugins/example_registered_plugin/setup.py new file mode 100644 index 00000000000..d55d1053b8f --- /dev/null +++ b/examples/plugins/example_registered_plugin/setup.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# type: ignore +from setuptools import setup + +with open("README.md", "r") as fh: + LONG_DESC = fh.read() +setup( + name="hydra-example-registered-plugin", + version="1.0.0", + author="Jasha Sommer-Simpson", + author_email="jasha10@fb.com", + description="Example of Hydra Plugin Registration", + long_description=LONG_DESC, + long_description_content_type="text/markdown", + url="https://github.com/facebookresearch/hydra/", + packages=["example_registered_plugin"], + classifiers=[ + # Feel free to use another license. + "License :: OSI Approved :: MIT License", + # Hydra uses Python version and Operating system to determine + # In which environments to test this plugin + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Operating System :: OS Independent", + ], + install_requires=[ + # consider pinning to a specific major version of Hydra to avoid unexpected problems + # if a new major version of Hydra introduces breaking changes for plugins. + # e.g: "hydra-core==1.1.*", + "hydra-core", + ], +) diff --git a/tests/test_apps/run_as_module/2/conf/__init__.py b/examples/plugins/example_registered_plugin/tests/__init__.py similarity index 100% rename from tests/test_apps/run_as_module/2/conf/__init__.py rename to examples/plugins/example_registered_plugin/tests/__init__.py diff --git a/examples/plugins/example_registered_plugin/tests/test_example_registered_plugin.py b/examples/plugins/example_registered_plugin/tests/test_example_registered_plugin.py new file mode 100644 index 00000000000..e372152e79c --- /dev/null +++ b/examples/plugins/example_registered_plugin/tests/test_example_registered_plugin.py @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from hydra.core.plugins import Plugins +from hydra.plugins.plugin import Plugin + +from example_registered_plugin import ExampleRegisteredPlugin, register_example_plugin + + +def test_discovery() -> None: + # Tests that this plugin can be discovered after Plugins.register is called + + plugin_name = ExampleRegisteredPlugin.__name__ + + assert plugin_name not in [x.__name__ for x in Plugins.instance().discover(Plugin)] + + register_example_plugin() + + assert plugin_name in [x.__name__ for x in Plugins.instance().discover(Plugin)] + + +def test_example_plugin() -> None: + a = ExampleRegisteredPlugin(10) + assert a.add(20) == 30 diff --git a/examples/plugins/example_searchpath_plugin/tests/test_example_search_path_plugin.py b/examples/plugins/example_searchpath_plugin/tests/test_example_search_path_plugin.py index 0be9f0c8fb5..872b23ddefd 100644 --- a/examples/plugins/example_searchpath_plugin/tests/test_example_search_path_plugin.py +++ b/examples/plugins/example_searchpath_plugin/tests/test_example_search_path_plugin.py @@ -17,7 +17,7 @@ def test_discovery() -> None: def test_config_installed() -> None: - with initialize(config_path=None): + with initialize(version_base=None): config_loader = GlobalHydra.instance().config_loader() assert "my_default_output_dir" in config_loader.get_group_options( "hydra/output" diff --git a/examples/plugins/example_sweeper_plugin/example/my_app.py b/examples/plugins/example_sweeper_plugin/example/my_app.py index 2915e42b728..0a1d9b27876 100644 --- a/examples/plugins/example_sweeper_plugin/example/my_app.py +++ b/examples/plugins/example_sweeper_plugin/example/my_app.py @@ -3,7 +3,7 @@ from omegaconf import OmegaConf, DictConfig -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/basic/running_your_hydra_app/3_working_directory/my_app.py b/examples/tutorials/basic/running_your_hydra_app/3_working_directory/my_app.py index 3766c7e4fd7..68fef3e4bfd 100644 --- a/examples/tutorials/basic/running_your_hydra_app/3_working_directory/my_app.py +++ b/examples/tutorials/basic/running_your_hydra_app/3_working_directory/my_app.py @@ -6,7 +6,7 @@ import hydra -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(_cfg: DictConfig) -> None: print(f"Working directory : {os.getcwd()}") diff --git a/examples/tutorials/basic/running_your_hydra_app/4_logging/my_app.py b/examples/tutorials/basic/running_your_hydra_app/4_logging/my_app.py index 2bd3fdca985..ed33b183d92 100644 --- a/examples/tutorials/basic/running_your_hydra_app/4_logging/my_app.py +++ b/examples/tutorials/basic/running_your_hydra_app/4_logging/my_app.py @@ -9,7 +9,7 @@ log = logging.getLogger(__name__) -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(_cfg: DictConfig) -> None: log.info("Info level message") log.debug("Debug level message") diff --git a/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/conf/config.yaml b/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/conf/config.yaml new file mode 100644 index 00000000000..70fba8cdfd7 --- /dev/null +++ b/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/conf/config.yaml @@ -0,0 +1,9 @@ +defaults: + - db: ??? + - _self_ + +hydra: + sweeper: + params: + db: glob(*) + db.timeout: 5,10 diff --git a/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/conf/db/mysql.yaml b/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/conf/db/mysql.yaml new file mode 100644 index 00000000000..e3a277358f2 --- /dev/null +++ b/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/conf/db/mysql.yaml @@ -0,0 +1,4 @@ +driver: mysql +user: omry +password: secret +timeout: 5 diff --git a/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/conf/db/postgresql.yaml b/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/conf/db/postgresql.yaml new file mode 100644 index 00000000000..66afdff5eb5 --- /dev/null +++ b/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/conf/db/postgresql.yaml @@ -0,0 +1,4 @@ +driver: postgresql +user: postgres_user +password: drowssap +timeout: 10 diff --git a/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/my_app.py b/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/my_app.py new file mode 100644 index 00000000000..49385e57c7f --- /dev/null +++ b/examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/my_app.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from omegaconf import DictConfig + +import hydra + + +@hydra.main(version_base=None, config_path="conf", config_name="config") +def my_app(cfg: DictConfig) -> None: + print(f"driver={cfg.db.driver}, timeout={cfg.db.timeout}") + + +if __name__ == "__main__": + my_app() diff --git a/examples/tutorials/basic/your_first_hydra_app/1_simple_cli/my_app.py b/examples/tutorials/basic/your_first_hydra_app/1_simple_cli/my_app.py index 99a8af7fd56..beca9b58cc6 100644 --- a/examples/tutorials/basic/your_first_hydra_app/1_simple_cli/my_app.py +++ b/examples/tutorials/basic/your_first_hydra_app/1_simple_cli/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/basic/your_first_hydra_app/2_config_file/my_app.py b/examples/tutorials/basic/your_first_hydra_app/2_config_file/my_app.py index 07916d55b34..0839a3c526d 100644 --- a/examples/tutorials/basic/your_first_hydra_app/2_config_file/my_app.py +++ b/examples/tutorials/basic/your_first_hydra_app/2_config_file/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/basic/your_first_hydra_app/3_using_config/my_app.py b/examples/tutorials/basic/your_first_hydra_app/3_using_config/my_app.py index 641f3875435..b46c2d8cb33 100644 --- a/examples/tutorials/basic/your_first_hydra_app/3_using_config/my_app.py +++ b/examples/tutorials/basic/your_first_hydra_app/3_using_config/my_app.py @@ -5,7 +5,7 @@ import hydra -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: assert cfg.node.loompa == 10 # attribute style access assert cfg["node"]["loompa"] == 10 # dictionary style access diff --git a/examples/tutorials/basic/your_first_hydra_app/4_config_groups/my_app.py b/examples/tutorials/basic/your_first_hydra_app/4_config_groups/my_app.py index 9a19701ec9d..b6d515e1dc3 100644 --- a/examples/tutorials/basic/your_first_hydra_app/4_config_groups/my_app.py +++ b/examples/tutorials/basic/your_first_hydra_app/4_config_groups/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf") +@hydra.main(version_base=None, config_path="conf") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/basic/your_first_hydra_app/5_defaults/my_app.py b/examples/tutorials/basic/your_first_hydra_app/5_defaults/my_app.py index 6b1303d511f..7585728f215 100644 --- a/examples/tutorials/basic/your_first_hydra_app/5_defaults/my_app.py +++ b/examples/tutorials/basic/your_first_hydra_app/5_defaults/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/basic/your_first_hydra_app/6_composition/my_app.py b/examples/tutorials/basic/your_first_hydra_app/6_composition/my_app.py index 6b1303d511f..7585728f215 100644 --- a/examples/tutorials/basic/your_first_hydra_app/6_composition/my_app.py +++ b/examples/tutorials/basic/your_first_hydra_app/6_composition/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/structured_configs/1_minimal/my_app.py b/examples/tutorials/structured_configs/1_minimal/my_app.py index 421ea8ea6a4..32b0d9daaf3 100644 --- a/examples/tutorials/structured_configs/1_minimal/my_app.py +++ b/examples/tutorials/structured_configs/1_minimal/my_app.py @@ -16,7 +16,7 @@ class MySQLConfig: cs.store(name="config", node=MySQLConfig) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: MySQLConfig) -> None: print(f"Host: {cfg.host}, port: {cfg.port}") diff --git a/examples/tutorials/structured_configs/1_minimal/my_app_type_error.py b/examples/tutorials/structured_configs/1_minimal/my_app_type_error.py index 74b19666e7d..54f3aacd457 100644 --- a/examples/tutorials/structured_configs/1_minimal/my_app_type_error.py +++ b/examples/tutorials/structured_configs/1_minimal/my_app_type_error.py @@ -16,7 +16,7 @@ class MySQLConfig: cs.store(name="config", node=MySQLConfig) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: MySQLConfig) -> None: # pork should be port! if cfg.pork == 80: # type: ignore diff --git a/examples/tutorials/structured_configs/2_static_complex/my_app.py b/examples/tutorials/structured_configs/2_static_complex/my_app.py index ebcc95a5ee7..6efc59c0179 100644 --- a/examples/tutorials/structured_configs/2_static_complex/my_app.py +++ b/examples/tutorials/structured_configs/2_static_complex/my_app.py @@ -28,7 +28,7 @@ class MyConfig: cs.store(name="config", node=MyConfig) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: MyConfig) -> None: print(f"Title={cfg.ui.title}, size={cfg.ui.width}x{cfg.ui.height} pixels") diff --git a/examples/tutorials/structured_configs/3_config_groups/my_app.py b/examples/tutorials/structured_configs/3_config_groups/my_app.py index c92d94600e1..82e3ca705ef 100644 --- a/examples/tutorials/structured_configs/3_config_groups/my_app.py +++ b/examples/tutorials/structured_configs/3_config_groups/my_app.py @@ -35,7 +35,7 @@ class Config: cs.store(group="db", name="postgresql", node=PostGreSQLConfig) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: Config) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/structured_configs/3_config_groups/my_app_with_inheritance.py b/examples/tutorials/structured_configs/3_config_groups/my_app_with_inheritance.py index 477e7c71b86..0cd5263e852 100644 --- a/examples/tutorials/structured_configs/3_config_groups/my_app_with_inheritance.py +++ b/examples/tutorials/structured_configs/3_config_groups/my_app_with_inheritance.py @@ -40,7 +40,7 @@ class Config: cs.store(group="db", name="postgresql", node=PostGreSQLConfig) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: Config) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/structured_configs/4_defaults/my_app.py b/examples/tutorials/structured_configs/4_defaults/my_app.py index cee73d14126..f953f2e99fa 100644 --- a/examples/tutorials/structured_configs/4_defaults/my_app.py +++ b/examples/tutorials/structured_configs/4_defaults/my_app.py @@ -48,7 +48,7 @@ class Config: cs.store(name="config", node=Config) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: Config) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/structured_configs/5.1_structured_config_schema_same_config_group/my_app.py b/examples/tutorials/structured_configs/5.1_structured_config_schema_same_config_group/my_app.py index b423ed74c84..196d9b48a14 100644 --- a/examples/tutorials/structured_configs/5.1_structured_config_schema_same_config_group/my_app.py +++ b/examples/tutorials/structured_configs/5.1_structured_config_schema_same_config_group/my_app.py @@ -43,7 +43,7 @@ class Config: cs.store(group="db", name="base_postgresql", node=PostGreSQLConfig) -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: Config) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/examples/tutorials/structured_configs/5.2_structured_config_schema_different_config_group/my_app.py b/examples/tutorials/structured_configs/5.2_structured_config_schema_different_config_group/my_app.py index 1a8089b8e29..d3728c608a9 100644 --- a/examples/tutorials/structured_configs/5.2_structured_config_schema_different_config_group/my_app.py +++ b/examples/tutorials/structured_configs/5.2_structured_config_schema_different_config_group/my_app.py @@ -23,6 +23,7 @@ class Config: @hydra.main( + version_base=None, config_path="conf", config_name="config", ) diff --git a/hydra/__init__.py b/hydra/__init__.py index 16ec8df42d6..c636044a865 100644 --- a/hydra/__init__.py +++ b/hydra/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Source of truth for Hydra's version -__version__ = "1.2.0dev1" +__version__ = "1.2.0.dev6" from hydra import utils from hydra.errors import MissingConfigException from hydra.main import main diff --git a/hydra/_internal/config_loader_impl.py b/hydra/_internal/config_loader_impl.py index b283e84a11f..b5586d07319 100644 --- a/hydra/_internal/config_loader_impl.py +++ b/hydra/_internal/config_loader_impl.py @@ -135,6 +135,7 @@ def load_configuration( overrides: List[str], run_mode: RunMode, from_shell: bool = True, + validate_sweep_overrides: bool = True, ) -> DictConfig: try: return self._load_configuration_impl( @@ -142,6 +143,7 @@ def load_configuration( overrides=overrides, run_mode=run_mode, from_shell=from_shell, + validate_sweep_overrides=validate_sweep_overrides, ) except OmegaConfBaseException as e: raise ConfigCompositionException().with_traceback(sys.exc_info()[2]) from e @@ -220,8 +222,9 @@ def _load_configuration_impl( overrides: List[str], run_mode: RunMode, from_shell: bool = True, + validate_sweep_overrides: bool = True, ) -> DictConfig: - from hydra import __version__ + from hydra import __version__, version self.ensure_main_config_source_available() caching_repo = CachingConfigRepository(self.repository) @@ -231,9 +234,10 @@ def _load_configuration_impl( self._process_config_searchpath(config_name, parsed_overrides, caching_repo) - self.validate_sweep_overrides_legal( - overrides=parsed_overrides, run_mode=run_mode, from_shell=from_shell - ) + if validate_sweep_overrides: + self.validate_sweep_overrides_legal( + overrides=parsed_overrides, run_mode=run_mode, from_shell=from_shell + ) defaults_list = create_defaults_list( repo=caching_repo, @@ -273,6 +277,7 @@ def _load_configuration_impl( cfg.hydra.job.env_set[key] = os.environ[key] cfg.hydra.runtime.version = __version__ + cfg.hydra.runtime.version_base = version.getbase() cfg.hydra.runtime.cwd = os.getcwd() cfg.hydra.runtime.config_sources = [ @@ -306,9 +311,6 @@ def load_sweep_config( run_mode=RunMode.RUN, ) - with open_dict(sweep_config): - sweep_config.hydra.runtime.merge_with(master_config.hydra.runtime) - # Partial copy of master config cache, to ensure we get the same resolved values for timestamps cache: Dict[str, Any] = defaultdict(dict, {}) cache_master_config = OmegaConf.get_cache(master_config) @@ -527,9 +529,9 @@ def strip_defaults(cfg: Any) -> None: if cfg._is_missing() or cfg._is_none(): return with flag_override(cfg, ["readonly", "struct"], False): - if getattr(cfg, "__HYDRA_REMOVE_TOP_LEVEL_DEFAULTS__", False): + if cfg._get_flag("HYDRA_REMOVE_TOP_LEVEL_DEFAULTS"): + cfg._set_flag("HYDRA_REMOVE_TOP_LEVEL_DEFAULTS", None) cfg.pop("defaults", None) - cfg.pop("__HYDRA_REMOVE_TOP_LEVEL_DEFAULTS__") for _key, value in cfg.items_ex(resolve=False): strip_defaults(value) diff --git a/hydra/_internal/config_repository.py b/hydra/_internal/config_repository.py index 1e45930fc48..c9a4f6667c9 100644 --- a/hydra/_internal/config_repository.py +++ b/hydra/_internal/config_repository.py @@ -15,6 +15,7 @@ read_write, ) +from hydra import version from hydra.core.config_search_path import ConfigSearchPath from hydra.core.object_type import ObjectType from hydra.plugins.config_source import ConfigResult, ConfigSource @@ -190,10 +191,11 @@ def issue_deprecated_name_warning() -> None: for item in defaults._iter_ex(resolve=False): default: InputDefault if isinstance(item, DictConfig): - old_optional = None - if len(item) > 1: - if "optional" in item: - old_optional = item.pop("optional") + if not version.base_at_least("1.2"): + old_optional = None + if len(item) > 1: + if "optional" in item: + old_optional = item.pop("optional") keys = list(item.keys()) if len(keys) > 1: @@ -209,23 +211,24 @@ def issue_deprecated_name_warning() -> None: keywords = ConfigRepository.Keywords() self._extract_keywords_from_config_group(config_group, keywords) - if not keywords.optional and old_optional is not None: - keywords.optional = old_optional + if not version.base_at_least("1.2"): + if not keywords.optional and old_optional is not None: + keywords.optional = old_optional node = item._get_node(key) assert node is not None and isinstance(node, Node) config_value = node._value() - if old_optional is not None: - # DEPRECATED: remove in 1.2 - msg = dedent( - f""" - In {config_path}: 'optional: true' is deprecated. - Use 'optional {key}: {config_value}' instead. - Support for the old style will be removed in Hydra 1.2""" - ) + if not version.base_at_least("1.2"): + if old_optional is not None: + msg = dedent( + f""" + In {config_path}: 'optional: true' is deprecated. + Use 'optional {key}: {config_value}' instead. + Support for the old style is removed for Hydra version_base >= 1.2""" + ) - deprecation_warning(msg) + deprecation_warning(msg) if config_value is not None and not isinstance( config_value, (str, list) @@ -285,7 +288,7 @@ def _extract_defaults_list(self, config_path: str, cfg: Container) -> ListConfig # It will be removed later. # This is addressing an edge case where the defaults list re-appears once the dataclass is used # as a prototype during OmegaConf merge. - cfg["__HYDRA_REMOVE_TOP_LEVEL_DEFAULTS__"] = True + cfg._set_flag("HYDRA_REMOVE_TOP_LEVEL_DEFAULTS", True) defaults = cfg.get("defaults", empty) if not isinstance(defaults, ListConfig): if isinstance(defaults, DictConfig): diff --git a/hydra/_internal/config_search_path_impl.py b/hydra/_internal/config_search_path_impl.py index ab186e07b13..d0d7131f476 100644 --- a/hydra/_internal/config_search_path_impl.py +++ b/hydra/_internal/config_search_path_impl.py @@ -1,5 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from typing import List, MutableSequence, Optional +from typing import List, MutableSequence, Optional, Union from hydra.core.config_search_path import ( ConfigSearchPath, @@ -63,7 +63,10 @@ def append( self.append(provider, path, anchor=None) def prepend( - self, provider: str, path: str, anchor: Optional[SearchPathQuery] = None + self, + provider: str, + path: str, + anchor: Optional[Union[SearchPathQuery, str]] = None, ) -> None: """ Prepends to the search path. diff --git a/hydra/_internal/core_plugins/basic_sweeper.py b/hydra/_internal/core_plugins/basic_sweeper.py index f1b41264413..d92256f47c8 100644 --- a/hydra/_internal/core_plugins/basic_sweeper.py +++ b/hydra/_internal/core_plugins/basic_sweeper.py @@ -19,9 +19,10 @@ import itertools import logging import time +from collections import OrderedDict from dataclasses import dataclass from pathlib import Path -from typing import Any, Iterable, List, Optional, Sequence +from typing import Any, Dict, Iterable, List, Optional, Sequence from omegaconf import DictConfig, OmegaConf @@ -39,6 +40,7 @@ class BasicSweeperConf: _target_: str = "hydra._internal.core_plugins.basic_sweeper.BasicSweeper" max_batch_size: Optional[int] = None + params: Optional[Dict[str, str]] = None ConfigStore.instance().store( @@ -54,14 +56,20 @@ class BasicSweeper(Sweeper): Basic sweeper """ - def __init__(self, max_batch_size: Optional[int]) -> None: + def __init__( + self, max_batch_size: Optional[int], params: Optional[Dict[str, str]] = None + ) -> None: """ Instantiates """ super(BasicSweeper, self).__init__() + + if params is None: + params = {} self.overrides: Optional[Sequence[Sequence[Sequence[str]]]] = None self.batch_index = 0 self.max_batch_size = max_batch_size + self.params = params self.hydra_context: Optional[HydraContext] = None self.config: Optional[DictConfig] = None @@ -101,12 +109,13 @@ def split_arguments( ) -> List[List[List[str]]]: lists = [] + final_overrides = OrderedDict() for override in overrides: if override.is_sweep_override(): if override.is_discrete_sweep(): key = override.get_key_element() sweep = [f"{key}={val}" for val in override.sweep_string_iterator()] - lists.append(sweep) + final_overrides[key] = sweep else: assert override.value_type is not None raise HydraException( @@ -115,7 +124,10 @@ def split_arguments( else: key = override.get_key_element() value = override.get_value_element_as_str() - lists.append([f"{key}={value}"]) + final_overrides[key] = [f"{key}={value}"] + + for _, v in final_overrides.items(): + lists.append(v) all_batches = [list(x) for x in itertools.product(*lists)] assert max_batch_size is None or max_batch_size > 0 @@ -127,13 +139,22 @@ def split_arguments( ) return [x for x in chunks_iter] + def _parse_config(self) -> List[str]: + params_conf = [] + for k, v in self.params.items(): + params_conf.append(f"{k}={v}") + return params_conf + def sweep(self, arguments: List[str]) -> Any: assert self.config is not None assert self.launcher is not None assert self.hydra_context is not None + params_conf = self._parse_config() + params_conf.extend(arguments) + parser = OverridesParser.create(config_loader=self.hydra_context.config_loader) - overrides = parser.parse_overrides(arguments) + overrides = parser.parse_overrides(params_conf) self.overrides = self.split_arguments(overrides, self.max_batch_size) returns: List[Sequence[JobReturn]] = [] diff --git a/hydra/_internal/core_plugins/importlib_resources_config_source.py b/hydra/_internal/core_plugins/importlib_resources_config_source.py index 0090674300a..efd894bfc3b 100644 --- a/hydra/_internal/core_plugins/importlib_resources_config_source.py +++ b/hydra/_internal/core_plugins/importlib_resources_config_source.py @@ -2,19 +2,17 @@ import os import sys import zipfile -from typing import Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional from omegaconf import OmegaConf from hydra.core.object_type import ObjectType from hydra.plugins.config_source import ConfigLoadError, ConfigResult, ConfigSource -if sys.version_info.major >= 4 or ( - sys.version_info.major >= 3 and sys.version_info.minor >= 9 -): - from importlib import resources +if TYPE_CHECKING or (sys.version_info < (3, 9)): + import importlib_resources as resources else: - import importlib_resources as resources # type:ignore + from importlib import resources # Relevant issue: https://github.com/python/mypy/issues/1153 # Use importlib backport for Python older than 3.9 @@ -55,7 +53,7 @@ def _read_config(self, res: Any) -> ConfigResult: def load_config(self, config_path: str) -> ConfigResult: normalized_config_path = self._normalize_file_name(config_path) - res = resources.files(self.path).joinpath(normalized_config_path) # type:ignore + res = resources.files(self.path).joinpath(normalized_config_path) if not res.exists(): raise ConfigLoadError(f"Config not found : {normalized_config_path}") @@ -63,18 +61,15 @@ def load_config(self, config_path: str) -> ConfigResult: def available(self) -> bool: try: - ret = resources.is_resource(self.path, "__init__.py") # type:ignore - assert isinstance(ret, bool) - return ret - except ValueError: - return False - except ModuleNotFoundError: + files = resources.files(self.path) + except (ValueError, ModuleNotFoundError, TypeError): return False + return any(f.name == "__init__.py" and f.is_file() for f in files.iterdir()) def is_group(self, config_path: str) -> bool: try: - files = resources.files(self.path) # type:ignore - except Exception: + files = resources.files(self.path) + except (ValueError, ModuleNotFoundError, TypeError): return False res = files.joinpath(config_path) @@ -85,8 +80,8 @@ def is_group(self, config_path: str) -> bool: def is_config(self, config_path: str) -> bool: config_path = self._normalize_file_name(config_path) try: - files = resources.files(self.path) # type:ignore - except Exception: + files = resources.files(self.path) + except (ValueError, ModuleNotFoundError, TypeError): return False res = files.joinpath(config_path) ret = res.exists() and res.is_file() @@ -95,9 +90,7 @@ def is_config(self, config_path: str) -> bool: def list(self, config_path: str, results_filter: Optional[ObjectType]) -> List[str]: files: List[str] = [] - for file in ( - resources.files(self.path).joinpath(config_path).iterdir() # type:ignore - ): + for file in resources.files(self.path).joinpath(config_path).iterdir(): fname = file.name fpath = os.path.join(config_path, fname) self._list_add_result( diff --git a/hydra/_internal/defaults_list.py b/hydra/_internal/defaults_list.py index 1e5a9a30cfd..d82bf64d8bd 100644 --- a/hydra/_internal/defaults_list.py +++ b/hydra/_internal/defaults_list.py @@ -435,10 +435,7 @@ def _has_config_content(cfg: DictConfig) -> bool: return False for key in cfg.keys(): - if not OmegaConf.is_missing(cfg, key) and key not in ( - "defaults", - "__HYDRA_REMOVE_TOP_LEVEL_DEFAULTS__", - ): + if not OmegaConf.is_missing(cfg, key) and key != "defaults": return True return False diff --git a/hydra/_internal/hydra.py b/hydra/_internal/hydra.py index 5cb461e851b..b9716d54787 100644 --- a/hydra/_internal/hydra.py +++ b/hydra/_internal/hydra.py @@ -78,6 +78,23 @@ def __init__(self, task_name: str, config_loader: ConfigLoader) -> None: self.config_loader = config_loader JobRuntime().set("name", task_name) + def get_mode( + self, + config_name: Optional[str], + overrides: List[str], + ) -> Any: + try: + cfg = self.compose_config( + config_name=config_name, + overrides=overrides, + with_log_configuration=False, + run_mode=RunMode.MULTIRUN, + validate_sweep_overrides=False, + ) + return cfg.hydra.mode + except Exception: + return None + def run( self, config_name: Optional[str], @@ -91,6 +108,10 @@ def run( with_log_configuration=with_log_configuration, run_mode=RunMode.RUN, ) + if cfg.hydra.mode is None: + cfg.hydra.mode = RunMode.RUN + else: + assert cfg.hydra.mode == RunMode.RUN callbacks = Callbacks(cfg) callbacks.on_run_start(config=cfg, config_name=config_name) @@ -125,6 +146,7 @@ def multirun( with_log_configuration=with_log_configuration, run_mode=RunMode.MULTIRUN, ) + callbacks = Callbacks(cfg) callbacks.on_multirun_start(config=cfg, config_name=config_name) @@ -384,7 +406,10 @@ def _print_plugins(self) -> None: log.debug("\t\t{}".format(plugin_name)) def _print_search_path( - self, config_name: Optional[str], overrides: List[str] + self, + config_name: Optional[str], + overrides: List[str], + run_mode: RunMode = RunMode.RUN, ) -> None: assert log is not None log.debug("") @@ -395,7 +420,7 @@ def _print_search_path( cfg = self.compose_config( config_name=config_name, overrides=overrides, - run_mode=RunMode.RUN, + run_mode=run_mode, with_log_configuration=False, ) HydraConfig.instance().set_config(cfg) @@ -458,10 +483,15 @@ def _print_plugins_profiling_info(self, top_n: int) -> None: self._log_footer(header=header, filler="-") def _print_config_info( - self, config_name: Optional[str], overrides: List[str] + self, + config_name: Optional[str], + overrides: List[str], + run_mode: RunMode = RunMode.RUN, ) -> None: assert log is not None - self._print_search_path(config_name=config_name, overrides=overrides) + self._print_search_path( + config_name=config_name, overrides=overrides, run_mode=run_mode + ) self._print_defaults_tree(config_name=config_name, overrides=overrides) self._print_defaults_list(config_name=config_name, overrides=overrides) @@ -469,7 +499,7 @@ def _print_config_info( lambda: self.compose_config( config_name=config_name, overrides=overrides, - run_mode=RunMode.RUN, + run_mode=run_mode, with_log_configuration=False, ) ) @@ -480,13 +510,16 @@ def _print_config_info( log.info(OmegaConf.to_yaml(cfg)) def _print_defaults_list( - self, config_name: Optional[str], overrides: List[str] + self, + config_name: Optional[str], + overrides: List[str], + run_mode: RunMode = RunMode.RUN, ) -> None: assert log is not None defaults = self.config_loader.compute_defaults_list( config_name=config_name, overrides=overrides, - run_mode=RunMode.RUN, + run_mode=run_mode, ) box: List[List[str]] = [ @@ -534,10 +567,11 @@ def _print_debug_info( self, config_name: Optional[str], overrides: List[str], + run_mode: RunMode = RunMode.RUN, ) -> None: assert log is not None if log.isEnabledFor(logging.DEBUG): - self._print_all_info(config_name, overrides) + self._print_all_info(config_name, overrides, run_mode) def compose_config( self, @@ -546,6 +580,7 @@ def compose_config( run_mode: RunMode, with_log_configuration: bool = False, from_shell: bool = True, + validate_sweep_overrides: bool = True, ) -> DictConfig: """ :param config_name: @@ -561,26 +596,36 @@ def compose_config( overrides=overrides, run_mode=run_mode, from_shell=from_shell, + validate_sweep_overrides=validate_sweep_overrides, ) if with_log_configuration: configure_log(cfg.hydra.hydra_logging, cfg.hydra.verbose) global log log = logging.getLogger(__name__) - self._print_debug_info(config_name, overrides) + self._print_debug_info(config_name, overrides, run_mode) return cfg def _print_plugins_info( - self, config_name: Optional[str], overrides: List[str] + self, + config_name: Optional[str], + overrides: List[str], + run_mode: RunMode = RunMode.RUN, ) -> None: self._print_plugins() self._print_plugins_profiling_info(top_n=10) - def _print_all_info(self, config_name: Optional[str], overrides: List[str]) -> None: + def _print_all_info( + self, + config_name: Optional[str], + overrides: List[str], + run_mode: RunMode = RunMode.RUN, + ) -> None: + from .. import __version__ self._log_header(f"Hydra {__version__}", filler="=") self._print_plugins() - self._print_config_info(config_name, overrides) + self._print_config_info(config_name, overrides, run_mode) def _print_defaults_tree_impl( self, @@ -616,20 +661,27 @@ def to_str(node: InputDefault) -> str: log.info(pad + to_str(tree)) def _print_defaults_tree( - self, config_name: Optional[str], overrides: List[str] + self, + config_name: Optional[str], + overrides: List[str], + run_mode: RunMode = RunMode.RUN, ) -> None: assert log is not None defaults = self.config_loader.compute_defaults_list( config_name=config_name, overrides=overrides, - run_mode=RunMode.RUN, + run_mode=run_mode, ) log.info("") self._log_header("Defaults Tree", filler="*") self._print_defaults_tree_impl(defaults.defaults_tree) def show_info( - self, info: str, config_name: Optional[str], overrides: List[str] + self, + info: str, + config_name: Optional[str], + overrides: List[str], + run_mode: RunMode = RunMode.RUN, ) -> None: options = { "all": self._print_all_info, @@ -647,4 +699,6 @@ def show_info( opts = sorted(options.keys()) log.error(f"Info usage: --info [{'|'.join(opts)}]") else: - options[info](config_name=config_name, overrides=overrides) + options[info]( + config_name=config_name, overrides=overrides, run_mode=run_mode + ) diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index da668ff493c..eff8781371d 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -1,9 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import copy -import sys +import functools from enum import Enum -from typing import Any, Callable, Sequence, Tuple, Union +from textwrap import dedent +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union from omegaconf import OmegaConf, SCMode from omegaconf._utils import is_structured_config @@ -20,6 +21,7 @@ class _Keys(str, Enum): CONVERT = "_convert_" RECURSIVE = "_recursive_" ARGS = "_args_" + PARTIAL = "_partial_" def _is_target(x: Any) -> bool: @@ -30,7 +32,7 @@ def _is_target(x: Any) -> bool: return False -def _extract_pos_args(*input_args: Any, **kwargs: Any) -> Tuple[Any, Any]: +def _extract_pos_args(input_args: Any, kwargs: Any) -> Tuple[Any, Any]: config_args = kwargs.pop(_Keys.ARGS, ()) output_args = config_args @@ -39,16 +41,22 @@ def _extract_pos_args(*input_args: Any, **kwargs: Any) -> Tuple[Any, Any]: output_args = input_args else: raise InstantiationException( - f"Unsupported _args_ type: {type(config_args).__name__}. value: {config_args}" + f"Unsupported _args_ type: '{type(config_args).__name__}'. value: '{config_args}'" ) return output_args, kwargs -def _call_target(_target_: Callable, *args, **kwargs) -> Any: # type: ignore +def _call_target( + _target_: Callable[..., Any], + _partial_: bool, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + full_key: str, +) -> Any: """Call target (type) with args and kwargs.""" try: - args, kwargs = _extract_pos_args(*args, **kwargs) + args, kwargs = _extract_pos_args(args, kwargs) # detaching configs from parent. # At this time, everything is resolved and the parent link can cause # issues when serializing objects in some scenarios. @@ -58,24 +66,45 @@ def _call_target(_target_: Callable, *args, **kwargs) -> Any: # type: ignore for v in kwargs.values(): if OmegaConf.is_config(v): v._set_parent(None) - - return _target_(*args, **kwargs) except Exception as e: - raise type(e)( - f"Error instantiating '{_convert_target_to_string(_target_)}' : {e}" - ).with_traceback(sys.exc_info()[2]) + msg = ( + f"Error in collecting args and kwargs for '{_convert_target_to_string(_target_)}':" + + f"\n{repr(e)}" + ) + if full_key: + msg += f"\nfull_key: {full_key}" + + raise InstantiationException(msg) from e + + if _partial_: + try: + return functools.partial(_target_, *args, **kwargs) + except Exception as e: + msg = ( + f"Error in creating partial({_convert_target_to_string(_target_)}, ...) object:" + + f"\n{repr(e)}" + ) + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + else: + try: + return _target_(*args, **kwargs) + except Exception as e: + msg = f"Error in call to target '{_convert_target_to_string(_target_)}':\n{repr(e)}" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e def _convert_target_to_string(t: Any) -> Any: - if isinstance(t, type): - return f"{t.__module__}.{t.__name__}" - elif callable(t): + if callable(t): return f"{t.__module__}.{t.__qualname__}" else: return t -def _prepare_input_dict(d: Any) -> Any: +def _prepare_input_dict_or_list(d: Union[Dict[Any, Any], List[Any]]) -> Any: res: Any if isinstance(d, dict): res = {} @@ -83,13 +112,13 @@ def _prepare_input_dict(d: Any) -> Any: if k == "_target_": v = _convert_target_to_string(d["_target_"]) elif isinstance(v, (dict, list)): - v = _prepare_input_dict(v) + v = _prepare_input_dict_or_list(v) res[k] = v elif isinstance(d, list): res = [] for v in d: if isinstance(v, (list, dict)): - v = _prepare_input_dict(v) + v = _prepare_input_dict_or_list(v) res.append(v) else: assert False @@ -97,18 +126,23 @@ def _prepare_input_dict(d: Any) -> Any: def _resolve_target( - target: Union[str, type, Callable[..., Any]] + target: Union[str, type, Callable[..., Any]], full_key: str ) -> Union[type, Callable[..., Any]]: """Resolve target string, type or callable into type or callable.""" if isinstance(target, str): - return _locate(target) - if isinstance(target, type): - return target - if callable(target): - return target - raise InstantiationException( - f"Unsupported target type: {type(target).__name__}. value: {target}" - ) + try: + target = _locate(target) + except Exception as e: + msg = f"Error locating target '{target}', see chained exception above." + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + if not callable(target): + msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) + return target def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: @@ -128,7 +162,8 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: the exception of Structured Configs (and their fields). all : Passed objects are dicts, lists and primitives without a trace of OmegaConf containers - _args_: List-like of positional arguments + _partial_: If True, return functools.partial wrapped method or object + False by default. Configure per target. :param args: Optional positional parameters pass-through :param kwargs: Optional named parameters to override parameters in the config object. Parameters not present @@ -147,17 +182,23 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: if isinstance(config, TargetConf) and config._target_ == "???": # Specific check to give a good warning about failure to annotate _target_ as a string. raise InstantiationException( - f"Missing value for {type(config).__name__}._target_. Check that it's properly annotated and overridden." - f"\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'" + dedent( + f"""\ + Config has missing value for key `_target_`, cannot instantiate. + Config type: {type(config).__name__} + Check that the `_target_` key in your dataclass is properly annotated and overridden. + A common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'""" + ) ) + # TODO: print full key - if isinstance(config, dict): - config = _prepare_input_dict(config) + if isinstance(config, (dict, list)): + config = _prepare_input_dict_or_list(config) - kwargs = _prepare_input_dict(kwargs) + kwargs = _prepare_input_dict_or_list(kwargs) # Structured Config always converted first to OmegaConf - if is_structured_config(config) or isinstance(config, dict): + if is_structured_config(config) or isinstance(config, (dict, list)): config = OmegaConf.structured(config, flags={"allow_objects": True}) if OmegaConf.is_dict(config): @@ -176,11 +217,42 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: _recursive_ = config.pop(_Keys.RECURSIVE, True) _convert_ = config.pop(_Keys.CONVERT, ConvertMode.NONE) + _partial_ = config.pop(_Keys.PARTIAL, False) - return instantiate_node(config, *args, recursive=_recursive_, convert=_convert_) + return instantiate_node( + config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ + ) + elif OmegaConf.is_list(config): + # Finalize config (convert targets to strings, merge with kwargs) + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + OmegaConf.resolve(config) + + _recursive_ = kwargs.pop(_Keys.RECURSIVE, True) + _convert_ = kwargs.pop(_Keys.CONVERT, ConvertMode.NONE) + _partial_ = kwargs.pop(_Keys.PARTIAL, False) + + if _partial_: + raise InstantiationException( + "The _partial_ keyword is not compatible with top-level list instantiation" + ) + + return instantiate_node( + config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ + ) else: raise InstantiationException( - "Top level config has to be OmegaConf DictConfig, plain dict, or a Structured Config class or instance" + dedent( + f"""\ + Cannot instantiate config of type {type(config).__name__}. + Top level config must be an OmegaConf DictConfig/ListConfig object, + a plain dict/list, or a Structured Config class or instance.""" + ) ) @@ -200,6 +272,7 @@ def instantiate_node( *args: Any, convert: Union[str, ConvertMode] = ConvertMode.NONE, recursive: bool = True, + partial: bool = False, ) -> Any: # Return None if config is None if node is None or (OmegaConf.is_config(node) and node._is_none()): @@ -214,9 +287,21 @@ def instantiate_node( # if the key type is incompatible on get. convert = node[_Keys.CONVERT] if _Keys.CONVERT in node else convert recursive = node[_Keys.RECURSIVE] if _Keys.RECURSIVE in node else recursive + partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial + + full_key = node._get_full_key(None) if not isinstance(recursive, bool): - raise TypeError(f"_recursive_ flag must be a bool, got {type(recursive)}") + msg = f"Instantiation: _recursive_ flag must be a bool, got {type(recursive)}" + if full_key: + msg += f"\nfull_key: {full_key}" + raise TypeError(msg) + + if not isinstance(partial, bool): + msg = f"Instantiation: _partial_ flag must be a bool, got {type( partial )}" + if node and full_key: + msg += f"\nfull_key: {full_key}" + raise TypeError(msg) # If OmegaConf list, create new list of instances if recursive if OmegaConf.is_list(node): @@ -235,9 +320,9 @@ def instantiate_node( return lst elif OmegaConf.is_dict(node): - exclude_keys = set({"_target_", "_convert_", "_recursive_"}) + exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"}) if _is_target(node): - _target_ = _resolve_target(node.get(_Keys.TARGET)) + _target_ = _resolve_target(node.get(_Keys.TARGET), full_key) kwargs = {} for key, value in node.items(): if key not in exclude_keys: @@ -246,11 +331,13 @@ def instantiate_node( value, convert=convert, recursive=recursive ) kwargs[key] = _convert_node(value, convert) - return _call_target(_target_, *args, **kwargs) + + return _call_target(_target_, partial, args, kwargs, full_key) else: # If ALL or PARTIAL non structured, instantiate in dict and resolve interpolations eagerly. if convert == ConvertMode.ALL or ( - convert == ConvertMode.PARTIAL and node._metadata.object_type is None + convert == ConvertMode.PARTIAL + and node._metadata.object_type in (None, dict) ): dict_items = {} for key, value in node.items(): diff --git a/hydra/_internal/utils.py b/hydra/_internal/utils.py index 18f556bacdb..2bf2872e51d 100644 --- a/hydra/_internal/utils.py +++ b/hydra/_internal/utils.py @@ -4,11 +4,12 @@ import logging.config import os import sys +import warnings from dataclasses import dataclass from os.path import dirname, join, normpath, realpath from traceback import print_exc, print_exception from types import FrameType, TracebackType -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple from omegaconf.errors import OmegaConfBaseException @@ -20,7 +21,7 @@ InstantiationException, SearchPathException, ) -from hydra.types import TaskFunction +from hydra.types import RunMode, TaskFunction log = logging.getLogger(__name__) @@ -35,7 +36,7 @@ def _get_module_name_override() -> Optional[str]: def detect_calling_file_or_module_from_task_function( task_function: Any, -) -> Tuple[Optional[str], Optional[str], str]: +) -> Tuple[Optional[str], Optional[str]]: mdl = task_function.__module__ override = _get_module_name_override() @@ -48,12 +49,13 @@ def detect_calling_file_or_module_from_task_function( calling_file = None calling_module = mdl else: - calling_file = task_function.__code__.co_filename + try: + calling_file = inspect.getfile(task_function) + except TypeError: + calling_file = None calling_module = None - task_name = detect_task_name(calling_file, mdl) - - return calling_file, calling_module, task_name + return calling_file, calling_module def detect_calling_file_or_module_from_stack_frame( @@ -296,17 +298,18 @@ class FakeTracebackType: def _run_hydra( + args: argparse.Namespace, args_parser: argparse.ArgumentParser, task_function: TaskFunction, config_path: Optional[str], config_name: Optional[str], + caller_stack_depth: int = 2, ) -> None: from hydra.core.global_hydra import GlobalHydra from .hydra import Hydra - args = args_parser.parse_args() if args.config_name is not None: config_name = args.config_name @@ -316,8 +319,13 @@ def _run_hydra( ( calling_file, calling_module, - task_name, ) = detect_calling_file_or_module_from_task_function(task_function) + if calling_file is None and calling_module is None: + ( + calling_file, + calling_module, + ) = detect_calling_file_or_module_from_stack_frame(caller_stack_depth + 1) + task_name = detect_task_name(calling_file, calling_module) validate_config_path(config_path) @@ -373,21 +381,19 @@ def add_conf_dir() -> None: ) if num_commands == 0: args.run = True - if args.run: - run_and_report( - lambda: hydra.run( - config_name=config_name, - task_function=task_function, - overrides=args.overrides, - ) - ) - elif args.multirun: - run_and_report( - lambda: hydra.multirun( - config_name=config_name, - task_function=task_function, - overrides=args.overrides, - ) + + overrides = args.overrides + + if args.run or args.multirun: + run_mode = hydra.get_mode(config_name=config_name, overrides=overrides) + _run_app( + run=args.run, + multirun=args.multirun, + mode=run_mode, + hydra=hydra, + config_name=config_name, + task_function=task_function, + overrides=overrides, ) elif args.cfg: run_and_report( @@ -416,6 +422,50 @@ def add_conf_dir() -> None: GlobalHydra.instance().clear() +def _run_app( + run: bool, + multirun: bool, + mode: Optional[RunMode], + hydra: Any, + config_name: Optional[str], + task_function: TaskFunction, + overrides: List[str], +) -> None: + if mode is None: + if run: + mode = RunMode.RUN + overrides.extend(["hydra.mode=RUN"]) + else: + mode = RunMode.MULTIRUN + overrides.extend(["hydra.mode=MULTIRUN"]) + else: + if multirun and mode == RunMode.RUN: + warnings.warn( + message="\n" + "\tRunning Hydra app with --multirun, overriding with `hydra.mode=MULTIRUN`.", + category=UserWarning, + ) + mode = RunMode.MULTIRUN + overrides.extend(["hydra.mode=MULTIRUN"]) + + if mode == RunMode.RUN: + run_and_report( + lambda: hydra.run( + config_name=config_name, + task_function=task_function, + overrides=overrides, + ) + ) + else: + run_and_report( + lambda: hydra.multirun( + config_name=config_name, + task_function=task_function, + overrides=overrides, + ) + ) + + def _get_exec_command() -> str: if sys.argv[0].endswith(".py"): return f"python {sys.argv[0]}" @@ -515,6 +565,11 @@ def __repr__(self) -> str: help="Adds an additional config dir to the config search path", ) + parser.add_argument( + "--experimental-rerun", + help="Rerun a job from a previous config pickle", + ) + info_choices = [ "all", "config", @@ -551,7 +606,7 @@ def get_column_widths(matrix: List[List[str]]) -> List[int]: return widths -def _locate(path: str) -> Union[type, Callable[..., Any]]: +def _locate(path: str) -> Any: """ Locate an object by name or dotted path, importing as necessary. This is similar to the pydoc function `locate`, except that it checks for @@ -559,44 +614,50 @@ def _locate(path: str) -> Union[type, Callable[..., Any]]: """ if path == "": raise ImportError("Empty path") - import builtins from importlib import import_module + from types import ModuleType - parts = [part for part in path.split(".") if part] - module = None - for n in reversed(range(len(parts))): + parts = [part for part in path.split(".")] + for part in parts: + if not len(part): + raise ValueError( + f"Error loading '{path}': invalid dotstring." + + "\nRelative imports are not supported." + ) + assert len(parts) > 0 + part0 = parts[0] + try: + obj = import_module(part0) + except Exception as exc_import: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that module '{part0}' is installed?" + ) from exc_import + for m in range(1, len(parts)): + part = parts[m] try: - mod = ".".join(parts[:n]) - module = import_module(mod) - except Exception as e: - if n == 0: - raise ImportError(f"Error loading module '{path}'") from e - continue - if module: - break - if module: - obj = module - else: - obj = builtins - for part in parts[n:]: - mod = mod + "." + part - if not hasattr(obj, part): - try: - import_module(mod) - except Exception as e: - raise ImportError( - f"Encountered error: `{e}` when loading module '{path}'" - ) from e - obj = getattr(obj, part) - if isinstance(obj, type): - obj_type: type = obj - return obj_type - elif callable(obj): - obj_callable: Callable[..., Any] = obj - return obj_callable - else: - # dummy case - raise ValueError(f"Invalid type ({type(obj)}) found for {path}") + obj = getattr(obj, part) + except AttributeError as exc_attr: + parent_dotpath = ".".join(parts[:m]) + if isinstance(obj, ModuleType): + mod = ".".join(parts[: m + 1]) + try: + obj = import_module(mod) + continue + except ModuleNotFoundError as exc_import: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?" + ) from exc_import + except Exception as exc_import: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_import)}" + ) from exc_import + raise ImportError( + f"Error loading '{path}':\n{repr(exc_attr)}" + + f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?" + ) from exc_attr + return obj def _get_cls_name(config: Any, pop: bool = True) -> str: diff --git a/hydra/compose.py b/hydra/compose.py index 08c01b8d91c..fcc89228b71 100644 --- a/hydra/compose.py +++ b/hydra/compose.py @@ -4,6 +4,7 @@ from omegaconf import DictConfig, OmegaConf, open_dict +from hydra import version from hydra.core.global_hydra import GlobalHydra from hydra.types import RunMode @@ -45,15 +46,17 @@ def compose( del cfg["hydra"] if strict is not None: - # DEPRECATED: remove in 1.2 - deprecation_warning( - dedent( - """ - The strict flag in the compose API is deprecated and will be removed in the next version of Hydra. - See https://hydra.cc/docs/upgrades/0.11_to_1.0/strict_mode_flag_deprecated for more info. - """ + if version.base_at_least("1.2"): + raise TypeError("got an unexpected 'strict' argument") + else: + deprecation_warning( + dedent( + """ + The strict flag in the compose API is deprecated. + See https://hydra.cc/docs/upgrades/0.11_to_1.0/strict_mode_flag_deprecated for more info. + """ + ) ) - ) - OmegaConf.set_struct(cfg, strict) + OmegaConf.set_struct(cfg, strict) return cfg diff --git a/hydra/conf/__init__.py b/hydra/conf/__init__.py index 14f1559a970..47504c9b7c4 100644 --- a/hydra/conf/__init__.py +++ b/hydra/conf/__init__.py @@ -5,6 +5,7 @@ from omegaconf import MISSING from hydra.core.config_store import ConfigStore +from hydra.types import RunMode @dataclass @@ -46,6 +47,10 @@ class JobConf: # Job name, populated automatically unless specified by the user (in config or cli) name: str = MISSING + # Change current working dir to the output dir. + # Will be non-optional and default to False in Hydra 1.3 + chdir: Optional[bool] = None + # Populated automatically by Hydra. # Concatenation of job overrides that can be used as a part # of the directory name. @@ -91,8 +96,10 @@ class ConfigSourceInfo: @dataclass class RuntimeConf: version: str = MISSING + version_base: str = MISSING cwd: str = MISSING config_sources: List[ConfigSourceInfo] = MISSING + output_dir: str = MISSING # Composition choices dictionary # Ideally, the value type would be Union[str, List[str], None] @@ -116,6 +123,7 @@ class HydraConf: ] ) + mode: Optional[RunMode] = None # Elements to append to the config search path. # Note: This can only be configured in the primary config. searchpath: List[str] = field(default_factory=list) diff --git a/hydra/conf/hydra/job_logging/default.yaml b/hydra/conf/hydra/job_logging/default.yaml index 3be93fe95a4..5e04211ffd7 100644 --- a/hydra/conf/hydra/job_logging/default.yaml +++ b/hydra/conf/hydra/job_logging/default.yaml @@ -11,8 +11,8 @@ handlers: file: class: logging.FileHandler formatter: simple - # relative to the job log directory - filename: ${hydra.job.name}.log + # absolute file path + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log root: level: INFO handlers: [console, file] diff --git a/hydra/core/config_loader.py b/hydra/core/config_loader.py index 0babbd14333..9aa0ab4e651 100644 --- a/hydra/core/config_loader.py +++ b/hydra/core/config_loader.py @@ -22,6 +22,7 @@ def load_configuration( overrides: List[str], run_mode: RunMode, from_shell: bool = True, + validate_sweep_overrides: bool = True, ) -> DictConfig: ... diff --git a/hydra/core/config_search_path.py b/hydra/core/config_search_path.py index b2e9772fbe2..dd43a08499a 100644 --- a/hydra/core/config_search_path.py +++ b/hydra/core/config_search_path.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import MutableSequence, Optional +from typing import MutableSequence, Optional, Union class SearchPathElement: @@ -49,7 +49,10 @@ def append( @abstractmethod def prepend( - self, provider: str, path: str, anchor: Optional[SearchPathQuery] = None + self, + provider: str, + path: str, + anchor: Optional[Union[SearchPathQuery, str]] = None, ) -> None: """ Prepends to the search path. diff --git a/hydra/core/plugins.py b/hydra/core/plugins.py index b22d6820e12..ccb807da689 100644 --- a/hydra/core/plugins.py +++ b/hydra/core/plugins.py @@ -6,17 +6,12 @@ import warnings from collections import defaultdict from dataclasses import dataclass, field -from inspect import signature -from textwrap import dedent from timeit import default_timer as timer from typing import Any, Dict, List, Optional, Tuple, Type from omegaconf import DictConfig -from hydra._internal.callbacks import Callbacks -from hydra._internal.deprecation_warning import deprecation_warning from hydra._internal.sources_registry import SourcesRegistry -from hydra.core.config_loader import ConfigLoader from hydra.core.singleton import Singleton from hydra.plugins.completion_plugin import CompletionPlugin from hydra.plugins.config_source import ConfigSource @@ -27,6 +22,15 @@ from hydra.types import HydraContext, TaskFunction from hydra.utils import instantiate +PLUGIN_TYPES: List[Type[Plugin]] = [ + Plugin, + ConfigSource, + CompletionPlugin, + Launcher, + Sweeper, + SearchPathPlugin, +] + @dataclass class ScanStats: @@ -59,19 +63,32 @@ def _initialize(self) -> None: except ImportError: # If no plugins are installed the hydra_plugins package does not exist. pass - self.plugin_type_to_subclass_list, self.stats = self._scan_all_plugins( - modules=top_level - ) + + self.plugin_type_to_subclass_list = defaultdict(list) self.class_name_to_class = {} - for plugin_type, plugins in self.plugin_type_to_subclass_list.items(): - for clazz in plugins: - name = f"{clazz.__module__}.{clazz.__name__}" - self.class_name_to_class[name] = clazz - # Register config sources - for source in self.plugin_type_to_subclass_list[ConfigSource]: - assert issubclass(source, ConfigSource) - SourcesRegistry.instance().register(source) + scanned_plugins, self.stats = self._scan_all_plugins(modules=top_level) + for clazz in scanned_plugins: + self._register(clazz) + + def register(self, clazz: Type[Plugin]) -> None: + """ + Call Plugins.instance().register(MyPlugin) to manually register a plugin class. + """ + if not _is_concrete_plugin_type(clazz): + raise ValueError("Not a valid Hydra Plugin") + self._register(clazz) + + def _register(self, clazz: Type[Plugin]) -> None: + assert _is_concrete_plugin_type(clazz) + for plugin_type in PLUGIN_TYPES: + if issubclass(clazz, plugin_type): + if clazz not in self.plugin_type_to_subclass_list[plugin_type]: + self.plugin_type_to_subclass_list[plugin_type].append(clazz) + name = f"{clazz.__module__}.{clazz.__name__}" + self.class_name_to_class[name] = clazz + if issubclass(clazz, ConfigSource): + SourcesRegistry.instance().register(clazz) def _instantiate(self, config: DictConfig) -> Plugin: from hydra._internal import utils as internal_utils @@ -108,59 +125,6 @@ def is_in_toplevel_plugins_module(clazz: str) -> bool: "hydra._internal.core_plugins." ) - @staticmethod - def _setup_plugin( - plugin: Any, - task_function: TaskFunction, - config: DictConfig, - config_loader: Optional[ConfigLoader] = None, - hydra_context: Optional[HydraContext] = None, - ) -> Any: - """ - With HydraContext introduced in #1581, we need to set up the plugins in a way - that's compatible with both Hydra 1.0 and Hydra 1.1 syntax. - This method should be deleted in the next major release. - """ - assert isinstance(plugin, Sweeper) or isinstance(plugin, Launcher) - assert ( - config_loader is not None or hydra_context is not None - ), "config_loader and hydra_context cannot both be None" - - param_keys = signature(plugin.setup).parameters.keys() - - if "hydra_context" not in param_keys: - # DEPRECATED: remove in 1.2 - # hydra_context will be required in 1.2 - deprecation_warning( - message=dedent( - """ - Plugin's setup() signature has changed in Hydra 1.1. - Support for the old style will be removed in Hydra 1.2. - For more info, check https://github.com/facebookresearch/hydra/pull/1581.""" - ), - ) - config_loader = ( - config_loader - if config_loader is not None - else hydra_context.config_loader # type: ignore - ) - plugin.setup( # type: ignore - config=config, - config_loader=config_loader, - task_function=task_function, - ) - else: - if hydra_context is None: - # hydra_context could be None when an incompatible Sweeper instantiates a compatible Launcher - assert config_loader is not None - hydra_context = HydraContext( - config_loader=config_loader, callbacks=Callbacks() - ) - plugin.setup( - config=config, hydra_context=hydra_context, task_function=task_function - ) - return plugin - def instantiate_sweeper( self, *, @@ -172,56 +136,39 @@ def instantiate_sweeper( if config.hydra.sweeper is None: raise RuntimeError("Hydra sweeper is not configured") sweeper = self._instantiate(config.hydra.sweeper) - sweeper = self._setup_plugin( - plugin=sweeper, - task_function=task_function, - config=config, - config_loader=None, - hydra_context=hydra_context, - ) assert isinstance(sweeper, Sweeper) + sweeper.setup( + hydra_context=hydra_context, task_function=task_function, config=config + ) return sweeper def instantiate_launcher( self, + hydra_context: HydraContext, task_function: TaskFunction, config: DictConfig, - config_loader: Optional[ConfigLoader] = None, - hydra_context: Optional[HydraContext] = None, ) -> Launcher: Plugins.check_usage(self) if config.hydra.launcher is None: raise RuntimeError("Hydra launcher is not configured") launcher = self._instantiate(config.hydra.launcher) - launcher = self._setup_plugin( - plugin=launcher, - config=config, - task_function=task_function, - config_loader=config_loader, - hydra_context=hydra_context, - ) assert isinstance(launcher, Launcher) + launcher.setup( + hydra_context=hydra_context, task_function=task_function, config=config + ) return launcher @staticmethod def _scan_all_plugins( modules: List[Any], - ) -> Tuple[Dict[Type[Plugin], List[Type[Plugin]]], ScanStats]: + ) -> Tuple[List[Type[Plugin]], ScanStats]: stats = ScanStats() stats.total_time = timer() - ret: Dict[Type[Plugin], List[Type[Plugin]]] = defaultdict(list) + scanned_plugins: List[Type[Plugin]] = [] - plugin_types: List[Type[Plugin]] = [ - Plugin, - ConfigSource, - CompletionPlugin, - Launcher, - Sweeper, - SearchPathPlugin, - ] for mdl in modules: for importer, modname, ispkg in pkgutil.walk_packages( path=mdl.__path__, prefix=mdl.__name__ + ".", onerror=lambda x: None @@ -234,7 +181,8 @@ def _scan_all_plugins( if module_name.startswith("_") and not module_name.startswith("__"): continue import_time = timer() - m = importer.find_module(modname) + m = importer.find_module(modname) # type: ignore + assert m is not None with warnings.catch_warnings(record=True) as recorded_warnings: loaded_mod = m.load_module(modname) import_time = timer() - import_time @@ -259,14 +207,8 @@ def _scan_all_plugins( if loaded_mod is not None: for name, obj in inspect.getmembers(loaded_mod): - if ( - inspect.isclass(obj) - and issubclass(obj, Plugin) - and not inspect.isabstract(obj) - ): - for plugin_type in plugin_types: - if issubclass(obj, plugin_type): - ret[plugin_type].append(obj) + if _is_concrete_plugin_type(obj): + scanned_plugins.append(obj) except ImportError as e: warnings.warn( message=f"\n" @@ -278,7 +220,7 @@ def _scan_all_plugins( ) stats.total_time = timer() - stats.total_time - return ret, stats + return scanned_plugins, stats def get_stats(self) -> Optional[ScanStats]: return self.stats @@ -308,3 +250,9 @@ def check_usage(self_: Any) -> None: raise ValueError( f"Plugins is now a Singleton. usage: Plugins.instance().{inspect.stack()[1][3]}(...)" ) + + +def _is_concrete_plugin_type(obj: Any) -> bool: + return ( + inspect.isclass(obj) and issubclass(obj, Plugin) and not inspect.isabstract(obj) + ) diff --git a/hydra/core/utils.py b/hydra/core/utils.py index f7ccac040ba..10ad6e8b300 100644 --- a/hydra/core/utils.py +++ b/hydra/core/utils.py @@ -11,18 +11,16 @@ from os.path import splitext from pathlib import Path from textwrap import dedent -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast +from typing import Any, Dict, Optional, Sequence, Union, cast from omegaconf import DictConfig, OmegaConf, open_dict, read_write +from hydra import version from hydra._internal.deprecation_warning import deprecation_warning from hydra.core.hydra_config import HydraConfig from hydra.core.singleton import Singleton from hydra.types import HydraContext, TaskFunction -if TYPE_CHECKING: - from hydra._internal.callbacks import Callbacks - log = logging.getLogger(__name__) @@ -85,25 +83,17 @@ def filter_overrides(overrides: Sequence[str]) -> Sequence[str]: return [x for x in overrides if not x.startswith("hydra.")] -def _get_callbacks_for_run_job(hydra_context: Optional[HydraContext]) -> "Callbacks": +def _check_hydra_context(hydra_context: Optional[HydraContext]) -> None: if hydra_context is None: - # DEPRECATED: remove in 1.2 - # hydra_context will be required in 1.2 - deprecation_warning( - message=dedent( + # hydra_context is required as of Hydra 1.2. + # We can remove this check in Hydra 1.3. + raise TypeError( + dedent( """ - run_job's signature has changed in Hydra 1.1. Please pass in hydra_context. - Support for the old style will be removed in Hydra 1.2. + run_job's signature has changed: the `hydra_context` arg is now required. For more info, check https://github.com/facebookresearch/hydra/pull/1581.""" ), ) - from hydra._internal.callbacks import Callbacks - - callbacks = Callbacks() - else: - callbacks = hydra_context.callbacks - - return callbacks def run_job( @@ -111,24 +101,35 @@ def run_job( config: DictConfig, job_dir_key: str, job_subdir_key: Optional[str], + hydra_context: HydraContext, configure_logging: bool = True, - hydra_context: Optional[HydraContext] = None, ) -> "JobReturn": - callbacks = _get_callbacks_for_run_job(hydra_context) + _check_hydra_context(hydra_context) + callbacks = hydra_context.callbacks old_cwd = os.getcwd() orig_hydra_cfg = HydraConfig.instance().cfg + + # init Hydra config for config evaluation HydraConfig.instance().set_config(config) - working_dir = str(OmegaConf.select(config, job_dir_key)) + + output_dir = str(OmegaConf.select(config, job_dir_key)) if job_subdir_key is not None: # evaluate job_subdir_key lazily. # this is running on the client side in sweep and contains things such as job:id which # are only available there. subdir = str(OmegaConf.select(config, job_subdir_key)) - working_dir = os.path.join(working_dir, subdir) + output_dir = os.path.join(output_dir, subdir) + + with read_write(config.hydra.runtime): + with open_dict(config.hydra.runtime): + config.hydra.runtime.output_dir = os.path.abspath(output_dir) + + # update Hydra config + HydraConfig.instance().set_config(config) + _chdir = None try: ret = JobReturn() - ret.working_dir = working_dir task_cfg = copy.deepcopy(config) with read_write(task_cfg): with open_dict(task_cfg): @@ -142,14 +143,39 @@ def run_job( assert isinstance(overrides, list) ret.overrides = overrides # handle output directories here - Path(str(working_dir)).mkdir(parents=True, exist_ok=True) - os.chdir(working_dir) + Path(str(output_dir)).mkdir(parents=True, exist_ok=True) + + _chdir = hydra_cfg.hydra.job.chdir + + if _chdir is None: + if version.base_at_least("1.2"): + _chdir = False + + if _chdir is None: + url = "https://hydra.cc/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir" + deprecation_warning( + message=dedent( + f"""\ + Future Hydra versions will no longer change working directory at job runtime by default. + See {url} for more information.""" + ), + stacklevel=2, + ) + _chdir = True + + if _chdir: + os.chdir(output_dir) + ret.working_dir = output_dir + else: + ret.working_dir = os.getcwd() if configure_logging: configure_log(config.hydra.job_logging, config.hydra.verbose) if config.hydra.output_subdir is not None: - hydra_output = Path(config.hydra.output_subdir) + hydra_output = Path(config.hydra.runtime.output_dir) / Path( + config.hydra.output_subdir + ) _save_config(task_cfg, "config.yaml", hydra_output) _save_config(hydra_cfg, "hydra.yaml", hydra_output) _save_config(config.hydra.overrides.task, "overrides.yaml", hydra_output) @@ -172,7 +198,8 @@ def run_job( return ret finally: HydraConfig.instance().cfg = orig_hydra_cfg - os.chdir(old_cwd) + if _chdir: + os.chdir(old_cwd) def get_valid_filename(s: str) -> str: diff --git a/hydra/experimental/callbacks.py b/hydra/experimental/callbacks.py new file mode 100644 index 00000000000..42f60c9c8d6 --- /dev/null +++ b/hydra/experimental/callbacks.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import logging +import pickle +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig + +from hydra.core.utils import JobReturn, JobStatus +from hydra.experimental.callback import Callback + + +class LogJobReturnCallback(Callback): + def __init__(self) -> None: + self.log = logging.getLogger(f"{__name__}.{self.__class__.__name__}") + + def on_job_end( + self, config: DictConfig, job_return: JobReturn, **kwargs: Any + ) -> None: + if job_return.status == JobStatus.COMPLETED: + self.log.info(f"Succeeded with return value: {job_return.return_value}") + elif job_return.status == JobStatus.FAILED: + self.log.error("", exc_info=job_return._return_value) + else: + self.log.error("Status unknown. This should never happen.") + + +class PickleJobInfoCallback(Callback): + output_dir: Path + + def __init__(self) -> None: + self.log = logging.getLogger(f"{__name__}.{self.__class__.__name__}") + + def on_job_start(self, config: DictConfig, **kwargs: Any) -> None: + self.output_dir = Path(config.hydra.runtime.output_dir) / Path( + config.hydra.output_subdir + ) + filename = "config.pickle" + self._save_pickle(obj=config, filename=filename, output_dir=self.output_dir) + self.log.info(f"Saving job configs in {self.output_dir / filename}") + + def on_job_end( + self, config: DictConfig, job_return: JobReturn, **kwargs: Any + ) -> None: + filename = "job_return.pickle" + self._save_pickle(obj=job_return, filename=filename, output_dir=self.output_dir) + self.log.info(f"Saving job_return in {self.output_dir / filename}") + + def _save_pickle(self, obj: Any, filename: str, output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + assert output_dir is not None + with open(str(output_dir / filename), "wb") as file: + pickle.dump(obj, file, protocol=4) diff --git a/hydra/extra/pytest_plugin.py b/hydra/extra/pytest_plugin.py index 522a6c76338..f01d5dbca93 100644 --- a/hydra/extra/pytest_plugin.py +++ b/hydra/extra/pytest_plugin.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import copy from pathlib import Path -from typing import Callable, List, Optional +from typing import Callable, Generator, List, Optional from pytest import fixture @@ -10,8 +10,8 @@ from hydra.types import TaskFunction -@fixture(scope="function") # type: ignore -def hydra_restore_singletons() -> None: +@fixture(scope="function") +def hydra_restore_singletons() -> Generator[None, None, None]: """ Restore singletons state after the function returns """ @@ -20,7 +20,7 @@ def hydra_restore_singletons() -> None: Singleton.set_state(state) -@fixture(scope="function") # type: ignore +@fixture(scope="function") def hydra_sweep_runner() -> Callable[ [ Optional[str], @@ -29,7 +29,6 @@ def hydra_sweep_runner() -> Callable[ Optional[str], Optional[str], Optional[List[str]], - Optional[bool], Optional[Path], bool, ], @@ -42,7 +41,6 @@ def _( config_path: Optional[str], config_name: Optional[str], overrides: Optional[List[str]], - strict: Optional[bool] = None, temp_dir: Optional[Path] = None, configure_logging: bool = False, ) -> SweepTaskFunction: @@ -52,7 +50,6 @@ def _( sweep.task_function = task_function sweep.config_path = config_path sweep.config_name = config_name - sweep.strict = strict sweep.overrides = overrides or [] sweep.temp_dir = str(temp_dir) sweep.configure_logging = configure_logging @@ -61,7 +58,7 @@ def _( return _ -@fixture(scope="function") # type: ignore +@fixture(scope="function") def hydra_task_runner() -> Callable[ [ Optional[str], @@ -69,7 +66,6 @@ def hydra_task_runner() -> Callable[ Optional[str], Optional[str], Optional[List[str]], - Optional[bool], bool, ], TaskTestFunction, @@ -80,7 +76,6 @@ def _( config_path: Optional[str], config_name: Optional[str], overrides: Optional[List[str]] = None, - strict: Optional[bool] = None, configure_logging: bool = False, ) -> TaskTestFunction: task = TaskTestFunction() @@ -89,7 +84,6 @@ def _( task.config_name = config_name task.calling_module = calling_module task.config_path = config_path - task.strict = strict task.configure_logging = configure_logging return task diff --git a/hydra/grammar/OverrideLexer.g4 b/hydra/grammar/OverrideLexer.g4 index 3bec861db6b..d5475c72c58 100644 --- a/hydra/grammar/OverrideLexer.g4 +++ b/hydra/grammar/OverrideLexer.g4 @@ -59,7 +59,7 @@ BOOL: NULL: [Nn][Uu][Ll][Ll]; -UNQUOTED_CHAR: [/\-\\+.$%*@?]; // other characters allowed in unquoted strings +UNQUOTED_CHAR: [/\-\\+.$%*@?|]; // other characters allowed in unquoted strings ID: (CHAR|'_') (CHAR|DIGIT|'_')*; // Note: when adding more characters to the ESC rule below, also add them to // the `_ESC` string in `_internal/grammar/utils.py`. diff --git a/hydra/grammar/OverrideParser.g4 b/hydra/grammar/OverrideParser.g4 index 58d819df6ce..e14b776ce1f 100644 --- a/hydra/grammar/OverrideParser.g4 +++ b/hydra/grammar/OverrideParser.g4 @@ -59,7 +59,7 @@ primitive: | FLOAT // 3.14, -20.0, 1e-1, -10e3 | BOOL // true, TrUe, false, False | INTERPOLATION // ${foo.bar}, ${oc.env:USER,me} - | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ? + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, | | COLON // : | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, | WS // whitespaces @@ -72,7 +72,7 @@ dictKey: | INT // 0, 10, -20, 1_000_000 | FLOAT // 3.14, -20.0, 1e-1, -10e3 | BOOL // true, TrUe, false, False - | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ? + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, | | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, | WS // whitespaces )+; diff --git a/hydra/initialize.py b/hydra/initialize.py index 040fe05c3c7..ed94e736f2e 100644 --- a/hydra/initialize.py +++ b/hydra/initialize.py @@ -4,6 +4,7 @@ from textwrap import dedent from typing import Any, Optional +from hydra import version from hydra._internal.deprecation_warning import deprecation_warning from hydra._internal.hydra import Hydra from hydra._internal.utils import ( @@ -52,24 +53,30 @@ class initialize: def __init__( self, config_path: Optional[str] = _UNSPECIFIED_, + version_base: Optional[str] = _UNSPECIFIED_, job_name: Optional[str] = None, caller_stack_depth: int = 1, ) -> None: self._gh_backup = get_gh_backup() - # DEPRECATED: remove in 1.2 - # in 1.2, the default config_path should be changed to None + version.setbase(version_base) + if config_path is _UNSPECIFIED_: - url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path" - deprecation_warning( - message=dedent( - f"""\ - config_path is not specified in hydra.initialize(). - See {url} for more information.""" - ), - stacklevel=2, - ) - config_path = "." + if version.base_at_least("1.2"): + config_path = None + elif version_base is _UNSPECIFIED_: + url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path" + deprecation_warning( + message=dedent( + f"""\ + config_path is not specified in hydra.initialize(). + See {url} for more information.""" + ), + stacklevel=2, + ) + config_path = "." + else: + config_path = "." if config_path is not None and os.path.isabs(config_path): raise HydraException("config_path in initialize() must be relative") @@ -106,9 +113,16 @@ class initialize_config_module: :param job_name: the value for hydra.job.name (default is 'app') """ - def __init__(self, config_module: str, job_name: str = "app"): + def __init__( + self, + config_module: str, + version_base: Optional[str] = _UNSPECIFIED_, + job_name: str = "app", + ): self._gh_backup = get_gh_backup() + version.setbase(version_base) + Hydra.create_main_hydra_file_or_module( calling_file=None, calling_module=f"{config_module}.{job_name}", @@ -135,8 +149,16 @@ class initialize_config_dir: :param job_name: the value for hydra.job.name (default is 'app') """ - def __init__(self, config_dir: str, job_name: str = "app") -> None: + def __init__( + self, + config_dir: str, + version_base: Optional[str] = _UNSPECIFIED_, + job_name: str = "app", + ) -> None: self._gh_backup = get_gh_backup() + + version.setbase(version_base) + # Relative here would be interpreted as relative to cwd, which - depending on when it run # may have unexpected meaning. best to force an absolute path to avoid confusion. # Can consider using hydra.utils.to_absolute_path() to convert it at a future point if there is demand. diff --git a/hydra/main.py b/hydra/main.py index 8acf2280839..9f58271396f 100644 --- a/hydra/main.py +++ b/hydra/main.py @@ -1,20 +1,51 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import copy import functools +import pickle +import warnings +from pathlib import Path from textwrap import dedent -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional -from omegaconf import DictConfig +from omegaconf import DictConfig, open_dict, read_write +from . import version from ._internal.deprecation_warning import deprecation_warning from ._internal.utils import _run_hydra, get_args_parser +from .core.hydra_config import HydraConfig +from .core.utils import _flush_loggers, configure_log from .types import TaskFunction _UNSPECIFIED_: Any = object() +def _get_rerun_conf(file_path: str, overrides: List[str]) -> DictConfig: + msg = "Experimental rerun CLI option, other command line args are ignored." + warnings.warn(msg, UserWarning) + file = Path(file_path) + if not file.exists(): + raise ValueError(f"File {file} does not exist!") + + if len(overrides) > 0: + msg = "Config overrides are not supported as of now." + warnings.warn(msg, UserWarning) + + with open(str(file), "rb") as input: + config = pickle.load(input) # nosec + configure_log(config.hydra.job_logging, config.hydra.verbose) + HydraConfig.instance().set_config(config) + task_cfg = copy.deepcopy(config) + with read_write(task_cfg): + with open_dict(task_cfg): + del task_cfg["hydra"] + assert isinstance(task_cfg, DictConfig) + return task_cfg + + def main( config_path: Optional[str] = _UNSPECIFIED_, config_name: Optional[str] = None, + version_base: Optional[str] = _UNSPECIFIED_, ) -> Callable[[TaskFunction], Any]: """ :param config_path: The config path, a directory relative to the declaring python file. @@ -22,19 +53,24 @@ def main( :param config_name: The name of the config (usually the file name without the .yaml extension) """ - # DEPRECATED: remove in 1.2 - # in 1.2, the default config_path should be changed to None + version.setbase(version_base) + if config_path is _UNSPECIFIED_: - url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path" - deprecation_warning( - message=dedent( - f""" - config_path is not specified in @hydra.main(). - See {url} for more information.""" - ), - stacklevel=2, - ) - config_path = "." + if version.base_at_least("1.2"): + config_path = None + elif version_base is _UNSPECIFIED_: + url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path" + deprecation_warning( + message=dedent( + f""" + config_path is not specified in @hydra.main(). + See {url} for more information.""" + ), + stacklevel=2, + ) + config_path = "." + else: + config_path = "." def main_decorator(task_function: TaskFunction) -> Callable[[], None]: @functools.wraps(task_function) @@ -42,15 +78,22 @@ def decorated_main(cfg_passthrough: Optional[DictConfig] = None) -> Any: if cfg_passthrough is not None: return task_function(cfg_passthrough) else: - args = get_args_parser() - # no return value from run_hydra() as it may sometime actually run the task_function - # multiple times (--multirun) - _run_hydra( - args_parser=args, - task_function=task_function, - config_path=config_path, - config_name=config_name, - ) + args_parser = get_args_parser() + args = args_parser.parse_args() + if args.experimental_rerun is not None: + cfg = _get_rerun_conf(args.experimental_rerun, args.overrides) + task_function(cfg) + _flush_loggers() + else: + # no return value from run_hydra() as it may sometime actually run the task_function + # multiple times (--multirun) + _run_hydra( + args=args, + args_parser=args_parser, + task_function=task_function, + config_path=config_path, + config_name=config_name, + ) return decorated_main diff --git a/hydra/plugins/completion_plugin.py b/hydra/plugins/completion_plugin.py index 0c57920174f..c77e30a286b 100644 --- a/hydra/plugins/completion_plugin.py +++ b/hydra/plugins/completion_plugin.py @@ -158,6 +158,12 @@ def str_rep(in_key: Any, in_value: Any) -> str: return matches def _query_config_groups(self, word: str) -> Tuple[List[str], bool]: + is_addition = word.startswith("+") + is_deletion = word.startswith("~") + if is_addition or is_deletion: + prefix, word = word[0], word[1:] + else: + prefix = "" last_eq_index = word.rfind("=") last_slash_index = word.rfind("/") exact_match: bool = False @@ -191,12 +197,13 @@ def _query_config_groups(self, word: str) -> Tuple[List[str], bool]: dirs = self.config_loader.get_group_options( group_name=name, results_filter=ObjectType.GROUP ) - if len(dirs) == 0 and len(files) > 0: + if len(dirs) == 0 and len(files) > 0 and not is_deletion: name = name + "=" elif len(dirs) > 0 and len(files) == 0: name = name + "/" matched_groups.append(name) + matched_groups = [f"{prefix}{group}" for group in matched_groups] return matched_groups, exact_match def _query(self, config_name: Optional[str], line: str) -> List[str]: diff --git a/hydra/plugins/config_source.py b/hydra/plugins/config_source.py index 2a1025bd549..44ca0e89f71 100644 --- a/hydra/plugins/config_source.py +++ b/hydra/plugins/config_source.py @@ -1,16 +1,16 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import re - from abc import abstractmethod from dataclasses import dataclass -from typing import List, Optional, Dict +from typing import Dict, List, Optional -from hydra.core.default_element import InputDefault -from hydra.errors import HydraException from omegaconf import Container -from hydra.core.object_type import ObjectType +from hydra import version from hydra._internal.deprecation_warning import deprecation_warning +from hydra.core.default_element import InputDefault +from hydra.core.object_type import ObjectType +from hydra.errors import HydraException from hydra.plugins.plugin import Plugin @@ -117,12 +117,14 @@ def full_path(self) -> str: @staticmethod def _normalize_file_name(filename: str) -> str: - if filename.endswith(".yml"): - # DEPRECATED: remove in 1.2 - deprecation_warning( - "Support for .yml files is deprecated. Use .yaml extension for Hydra config files" - ) - if not any(filename.endswith(ext) for ext in [".yaml", ".yml"]): + supported_extensions = [".yaml"] + if not version.base_at_least("1.2"): + supported_extensions.append(".yml") + if filename.endswith(".yml"): + deprecation_warning( + "Support for .yml files is deprecated. Use .yaml extension for Hydra config files" + ) + if not any(filename.endswith(ext) for ext in supported_extensions): filename += ".yaml" return filename diff --git a/hydra/test_utils/completion.py b/hydra/test_utils/completion.py index 2bdc901d66c..bcfd21fa062 100644 --- a/hydra/test_utils/completion.py +++ b/hydra/test_utils/completion.py @@ -4,7 +4,9 @@ import hydra -@hydra.main(config_path="configs/completion_test", config_name="config") +@hydra.main( + version_base=None, config_path="configs/completion_test", config_name="config" +) def run_cli(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/hydra/test_utils/example_app.py b/hydra/test_utils/example_app.py index 77ae9d5cd44..1cc5d9be73f 100644 --- a/hydra/test_utils/example_app.py +++ b/hydra/test_utils/example_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="configs", config_name="db_conf") +@hydra.main(version_base=None, config_path="configs", config_name="db_conf") def run_cli(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/hydra/test_utils/launcher_common_tests.py b/hydra/test_utils/launcher_common_tests.py index 4116e604970..cb13cd988bc 100644 --- a/hydra/test_utils/launcher_common_tests.py +++ b/hydra/test_utils/launcher_common_tests.py @@ -620,7 +620,7 @@ def test_custom_sweeper_run_workdir( integration_test( tmpdir=self.get_test_scratch_dir(tmpdir), task_config=cfg, - overrides=overrides, + overrides=overrides + ["hydra.job.chdir=True"], prints="os.getcwd()", expected_outputs=expected_outputs, generate_custom_cmd=self.generate_custom_cmd(), @@ -649,6 +649,7 @@ def test_to_absolute_path_multirun( ) -> None: expected_dir = "cli_dir/cli_dir_0" overrides = extra_flags + [ + "hydra.job.chdir=True", "hydra.sweep.dir=cli_dir", "hydra.sweep.subdir=cli_dir_${hydra.job.num}", ] diff --git a/hydra/test_utils/test_utils.py b/hydra/test_utils/test_utils.py index 189df5a54b5..9e5a53efb67 100644 --- a/hydra/test_utils/test_utils.py +++ b/hydra/test_utils/test_utils.py @@ -290,7 +290,7 @@ def integration_test( $PROLOG -@hydra.main(config_path='.', config_name='config') +@hydra.main(version_base=None, config_path='.', config_name='config') def experiment(cfg): with open("$OUTPUT_FILE", "w") as f: $PRINTS @@ -342,9 +342,7 @@ def experiment(cfg): expected_outputs ), f"Unexpected number of output lines from {task_file}, output lines:\n\n{file_str}" for idx in range(len(output)): - assert ( - output[idx] == expected_outputs[idx] - ), f"Unexpected output for {prints[idx]} : expected {expected_outputs[idx]}, got {output[idx]}" + assert_regex_match(expected_outputs[idx], output[idx]) # some tests are parsing the file output for more specialized testing. return file_str finally: diff --git a/hydra/utils.py b/hydra/utils.py index 6583722f144..33c1ac1dc8d 100644 --- a/hydra/utils.py +++ b/hydra/utils.py @@ -22,18 +22,25 @@ def get_class(path: str) -> type: try: cls = _locate(path) if not isinstance(cls, type): - raise ValueError(f"Located non-class in {path} : {type(cls).__name__}") + raise ValueError( + f"Located non-class of type '{type(cls).__name__}'" + + f" while loading '{path}'" + ) return cls except Exception as e: - log.error(f"Error initializing class at {path} : {e}") + log.error(f"Error initializing class at {path}: {e}") raise e def get_method(path: str) -> Callable[..., Any]: try: - cl = _locate(path) - if not callable(cl): - raise ValueError(f"Non callable object located : {type(cl).__name__}") + obj = _locate(path) + if not callable(obj): + raise ValueError( + f"Located non-callable of type '{type(obj).__name__}'" + + f" while loading '{path}'" + ) + cl: Callable[..., Any] = obj return cl except Exception as e: log.error(f"Error getting callable at {path} : {e}") diff --git a/hydra/version.py b/hydra/version.py new file mode 100644 index 00000000000..ec6c329b0ea --- /dev/null +++ b/hydra/version.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +# Source of truth for Hydra's version + +from textwrap import dedent +from typing import Any, Optional + +from packaging.version import Version + +from . import __version__ +from ._internal.deprecation_warning import deprecation_warning +from .core.singleton import Singleton +from .errors import HydraException + +_UNSPECIFIED_: Any = object() + +__compat_version__: Version = Version("1.1") + + +class VersionBase(metaclass=Singleton): + def __init__(self) -> None: + self.version_base: Optional[Version] = _UNSPECIFIED_ + + def setbase(self, version: "Version") -> None: + assert isinstance( + version, Version + ), f"Unexpected Version type : {type(version)}" + self.version_base = version + + def getbase(self) -> Optional[Version]: + return self.version_base + + @staticmethod + def instance(*args: Any, **kwargs: Any) -> "VersionBase": + return Singleton.instance(VersionBase, *args, **kwargs) # type: ignore + + @staticmethod + def set_instance(instance: "VersionBase") -> None: + assert isinstance(instance, VersionBase) + Singleton._instances[VersionBase] = instance # type: ignore + + +def _get_version(ver: str) -> Version: + # Only consider major.minor as packaging will compare "1.2.0.dev2" < "1.2" + pver = Version(ver) + return Version(f"{pver.major}.{pver.minor}") + + +def base_at_least(ver: str) -> bool: + _version_base = VersionBase.instance().getbase() + if type(_version_base) is type(_UNSPECIFIED_): + VersionBase.instance().setbase(__compat_version__) + _version_base = __compat_version__ + assert isinstance(_version_base, Version) + return _version_base >= _get_version(ver) + + +def getbase() -> Optional[Version]: + return VersionBase.instance().getbase() + + +def setbase(ver: Any) -> None: + if type(ver) is type(_UNSPECIFIED_): + deprecation_warning( + message=dedent( + f""" + The version_base parameter is not specified. + Please specify a compatability version level, or None. + Will assume defaults for version {__compat_version__}""" + ), + stacklevel=3, + ) + _version_base = __compat_version__ + elif ver is None: + _version_base = _get_version(__version__) + else: + _version_base = _get_version(ver) + if _version_base < __compat_version__: + raise HydraException(f'version_base must be >= "{__compat_version__}"') + VersionBase.instance().setbase(_version_base) diff --git a/news/1283.feature b/news/1283.feature new file mode 100644 index 00000000000..64c8fa98139 --- /dev/null +++ b/news/1283.feature @@ -0,0 +1 @@ +Add support to Hydra's instantiation API for creation of `functools.partial` instances via a `_partial_` keyword. diff --git a/news/1376.feature b/news/1376.feature new file mode 100644 index 00000000000..82413c4db74 --- /dev/null +++ b/news/1376.feature @@ -0,0 +1,2 @@ +Support defining basic sweeping in input config. + \ No newline at end of file diff --git a/news/1805.feature b/news/1805.feature new file mode 100644 index 00000000000..b308a29db33 --- /dev/null +++ b/news/1805.feature @@ -0,0 +1 @@ +Add `--experimental-rerun` command-line option to reproduce pickled single runs diff --git a/news/1841.feature b/news/1841.feature new file mode 100644 index 00000000000..cf6dfe105c3 --- /dev/null +++ b/news/1841.feature @@ -0,0 +1 @@ +Implement tab completions for appending to the defaults list (+group=option) and deleting from the defaults list (~group). diff --git a/news/1850.feature b/news/1850.feature new file mode 100644 index 00000000000..69bbbf3063a --- /dev/null +++ b/news/1850.feature @@ -0,0 +1 @@ +Enable the use of the pipe symbol `|` in unquoted strings when parsing command-line overrides. diff --git a/news/1863.feature b/news/1863.feature new file mode 100644 index 00000000000..6a547924624 --- /dev/null +++ b/news/1863.feature @@ -0,0 +1 @@ +Improve clarity of error messages when `hydra.utils.instantiate` encounters a `_target_` that cannot be located diff --git a/news/1882.bugfix b/news/1882.bugfix new file mode 100644 index 00000000000..9e8847666f6 --- /dev/null +++ b/news/1882.bugfix @@ -0,0 +1 @@ +`hydra.runtime.choices` updated correctly during multi-run diff --git a/news/1897.bugfix b/news/1897.bugfix new file mode 100644 index 00000000000..b19bc598a5e --- /dev/null +++ b/news/1897.bugfix @@ -0,0 +1 @@ +hydra.verbose=True now works with multirun. diff --git a/news/1911.api_change b/news/1911.api_change new file mode 100644 index 00000000000..ce18b83d433 --- /dev/null +++ b/news/1911.api_change @@ -0,0 +1 @@ +If user code raises an exception when called by `instantiate`, raise an `InstantiateError` exception instead of an instance of the same exception class that was raised by the user code. diff --git a/news/1914.bugfix b/news/1914.bugfix new file mode 100644 index 00000000000..fe14076e8f9 --- /dev/null +++ b/news/1914.bugfix @@ -0,0 +1 @@ +Fix a resolution error occurring when a nested class is passed as a `_target_` keyword argument to `instantiate` diff --git a/news/1950.feature b/news/1950.feature new file mode 100644 index 00000000000..b57536960e9 --- /dev/null +++ b/news/1950.feature @@ -0,0 +1 @@ +The `instantiate` API now accepts `ListConfig`/`list`-type config as top-level input. diff --git a/news/1952.maintenance b/news/1952.maintenance new file mode 100644 index 00000000000..d0fd668ffe1 --- /dev/null +++ b/news/1952.maintenance @@ -0,0 +1 @@ +For version_base >= 1.2, remove deprecated "old optional" defaults list syntax diff --git a/news/1953.maintenance b/news/1953.maintenance new file mode 100644 index 00000000000..f8c33bd78f8 --- /dev/null +++ b/news/1953.maintenance @@ -0,0 +1 @@ +Remove support for deprecated arg `config_loader` to Plugin.setup, and update signature of `run_job` to require `hydra_context`. diff --git a/news/2042.bugfix b/news/2042.bugfix new file mode 100644 index 00000000000..ec17202c6ed --- /dev/null +++ b/news/2042.bugfix @@ -0,0 +1 @@ +It is now possible to pass other callable objects (besides functions) to `hydra.main`. diff --git a/news/2092.feature b/news/2092.feature new file mode 100644 index 00000000000..753b89c6c8f --- /dev/null +++ b/news/2092.feature @@ -0,0 +1 @@ +Add experimental Callback for pickling job info. diff --git a/news/2099.feature b/news/2099.feature new file mode 100644 index 00000000000..293e325303e --- /dev/null +++ b/news/2099.feature @@ -0,0 +1 @@ +Improve error messages raised in case of instantiation failure. diff --git a/news/2100.feature b/news/2100.feature new file mode 100644 index 00000000000..9964f627f01 --- /dev/null +++ b/news/2100.feature @@ -0,0 +1 @@ +Add callback for logging JobReturn. diff --git a/news/394.config b/news/394.config new file mode 100644 index 00000000000..92c68994154 --- /dev/null +++ b/news/394.config @@ -0,0 +1 @@ +Add hydra.mode config. diff --git a/news/910.feature b/news/910.feature new file mode 100644 index 00000000000..81b7dd3681c --- /dev/null +++ b/news/910.feature @@ -0,0 +1 @@ +Support disable changing working directory at runtime. diff --git a/noxfile.py b/noxfile.py index 65382b7fa2a..a8f3abdddbf 100644 --- a/noxfile.py +++ b/noxfile.py @@ -31,6 +31,7 @@ SKIP_CORE_TESTS = "0" SKIP_CORE_TESTS = os.environ.get("SKIP_CORE_TESTS", SKIP_CORE_TESTS) != "0" +USE_OMEGACONF_DEV_VERSION = os.environ.get("USE_OMEGACONF_DEV_VERSION", "0") != "0" FIX = os.environ.get("FIX", "0") == "1" VERBOSE = os.environ.get("VERBOSE", "0") SILENT = VERBOSE == "0" @@ -41,6 +42,7 @@ class Plugin: name: str path: str module: str + source_dir: str def get_current_os() -> str: @@ -57,6 +59,7 @@ def get_current_os() -> str: print(f"FIX\t\t\t:\t{FIX}") print(f"VERBOSE\t\t\t:\t{VERBOSE}") print(f"INSTALL_EDITABLE_MODE\t:\t{INSTALL_EDITABLE_MODE}") +print(f"USE_OMEGACONF_DEV_VERSION\t:\t{USE_OMEGACONF_DEV_VERSION}") def _upgrade_basic(session): @@ -79,19 +82,30 @@ def find_dirs(path: str): yield fullname +def _print_installed_omegaconf_version(session): + pip_list: str = session.run("pip", "list", silent=True) + for line in pip_list.split("\n"): + if "omegaconf" in line: + print(f"Installed omegaconf version: {line}") + + def install_hydra(session, cmd): # needed for build session.install("read-version", silent=SILENT) # clean install hydra session.chdir(BASE) + if USE_OMEGACONF_DEV_VERSION: + session.install("--pre", "omegaconf", silent=SILENT) session.run(*cmd, ".", silent=SILENT) + if USE_OMEGACONF_DEV_VERSION: + _print_installed_omegaconf_version(session) if not SILENT: session.install("pipdeptree", silent=SILENT) session.run("pipdeptree", "-p", "hydra-core") def pytest_args(*args): - ret = ["pytest", "-Werror"] + ret = ["pytest"] ret.extend(args) return ret @@ -167,11 +181,19 @@ def select_plugins(session, directory: str) -> List[Plugin]: ) continue + if "hydra_plugins" in os.listdir(os.path.join(BASE, directory, plugin["path"])): + module = "hydra_plugins." + plugin["dir_name"] + source_dir = "hydra_plugins" + else: + module = plugin["dir_name"] + source_dir = plugin["dir_name"] + ret.append( Plugin( name=plugin_name, path=plugin["path"], - module="hydra_plugins." + plugin["dir_name"], + source_dir=source_dir, + module=module, ) ) @@ -184,7 +206,6 @@ def select_plugins(session, directory: str) -> List[Plugin]: def install_dev_deps(session): - _upgrade_basic(session) session.run("pip", "install", "-r", "requirements/dev.txt", silent=SILENT) @@ -204,6 +225,7 @@ def _isort_cmd(): @nox.session(python=PYTHON_VERSIONS) def lint(session): + _upgrade_basic(session) install_dev_deps(session) install_hydra(session, ["pip", "install", "-e"]) @@ -227,28 +249,61 @@ def lint(session): "tools/configen/example/gen", "tools/configen/tests/test_modules/expected", "temp", + "build", ] isort = _isort_cmd() + [f"--skip={skip}" for skip in skiplist] session.run(*isort, silent=SILENT) - session.run("mypy", ".", "--strict", silent=SILENT) + session.run( + "mypy", + ".", + "--strict", + "--install-types", + "--non-interactive", + "--exclude=^examples/", + "--exclude=^tests/standalone_apps/", + "--exclude=^tests/test_apps/", + "--exclude=^tools/", + "--exclude=^plugins/", + silent=SILENT, + ) session.run("flake8", "--config", ".flake8") - session.run("yamllint", ".") + session.run("yamllint", "--strict", ".") - example_dirs = [ - "examples/advanced/", + mypy_check_subdirs = [ + "examples/advanced", "examples/configure_hydra", "examples/patterns", "examples/instantiate", "examples/tutorials/basic/your_first_hydra_app", "examples/tutorials/basic/running_your_hydra_app", - "examples/tutorials/structured_configs/", + "examples/tutorials/structured_configs", + "tests/standalone_apps", + "tests/test_apps", ] - for edir in example_dirs: - dirs = find_dirs(path=edir) + for sdir in mypy_check_subdirs: + dirs = find_dirs(path=sdir) for d in dirs: - session.run("mypy", d, "--strict", silent=SILENT) + session.run( + "mypy", + d, + "--strict", + "--install-types", + "--non-interactive", + silent=SILENT, + ) + + for sdir in ["tools"]: # no --strict flag for tools + dirs = find_dirs(path=sdir) + for d in dirs: + session.run( + "mypy", + d, + "--install-types", + "--non-interactive", + silent=SILENT, + ) # lint example plugins lint_plugins_in_dir(session=session, directory="examples/plugins") @@ -259,6 +314,7 @@ def lint(session): @nox.session(python=PYTHON_VERSIONS) def lint_plugins(session): + _upgrade_basic(session) lint_plugins_in_dir(session, "plugins") @@ -279,6 +335,7 @@ def lint_plugins_in_dir(session, directory: str) -> None: # Mypy for plugins for plugin in plugins: path = os.path.join(directory, plugin.path) + source_dir = plugin.source_dir session.chdir(path) session.run(*_black_cmd(), silent=SILENT) session.run(*_isort_cmd(), silent=SILENT) @@ -293,7 +350,9 @@ def lint_plugins_in_dir(session, directory: str) -> None: session.run( "mypy", "--strict", - f"{path}/hydra_plugins", + "--install-types", + "--non-interactive", + f"{path}/{source_dir}", "--config-file", f"{BASE}/.mypy.ini", silent=SILENT, @@ -301,6 +360,8 @@ def lint_plugins_in_dir(session, directory: str) -> None: session.run( "mypy", "--strict", + "--install-types", + "--non-interactive", "--namespace-packages", "--config-file", f"{BASE}/.mypy.ini", @@ -311,8 +372,8 @@ def lint_plugins_in_dir(session, directory: str) -> None: @nox.session(python=PYTHON_VERSIONS) def test_tools(session): - install_cmd = ["pip", "install"] _upgrade_basic(session) + install_cmd = ["pip", "install"] session.install("pytest") install_hydra(session, install_cmd) @@ -370,6 +431,7 @@ def test_core(session): @nox.session(python=PYTHON_VERSIONS) def test_plugins(session): + _upgrade_basic(session) test_plugins_in_directory( session=session, install_cmd=INSTALL_COMMAND, @@ -381,7 +443,6 @@ def test_plugins(session): def test_plugins_in_directory( session, install_cmd, directory: str, test_hydra_core: bool ): - _upgrade_basic(session) session.install("pytest") install_hydra(session, install_cmd) selected_plugin = select_plugins(session=session, directory=directory) @@ -418,13 +479,13 @@ def test_plugins_in_directory( @nox.session(python="3.8") def coverage(session): + _upgrade_basic(session) coverage_env = { "COVERAGE_HOME": BASE, "COVERAGE_FILE": f"{BASE}/.coverage", "COVERAGE_RCFILE": f"{BASE}/.coveragerc", } - _upgrade_basic(session) session.install("coverage", "pytest") install_hydra(session, ["pip", "install", "-e"]) session.run("coverage", "erase", env=coverage_env) @@ -466,13 +527,16 @@ def coverage(session): @nox.session(python=PYTHON_VERSIONS) def test_jupyter_notebooks(session): + _upgrade_basic(session) versions = copy.copy(DEFAULT_PYTHON_VERSIONS) if session.python not in versions: session.skip( f"Not testing Jupyter notebook on Python {session.python}, supports [{','.join(versions)}]" ) - session.install("jupyter", "nbval", "pyzmq") + session.install( + "jupyter", "nbval", "pyzmq", "pytest<7.0.0" + ) # pytest pinned due to https://github.com/computationalmodelling/nbval/issues/180 if platform.system() == "Windows": # Newer versions of pywin32 are causing CI issues on Windows. # see https://github.com/mhammond/pywin32/issues/1709 @@ -480,17 +544,17 @@ def test_jupyter_notebooks(session): install_hydra(session, ["pip", "install", "-e"]) args = pytest_args( - "--nbval", "examples/jupyter_notebooks/compose_configs_in_notebook.ipynb" + "--nbval", + "-W ignore::ResourceWarning", + "examples/jupyter_notebooks/compose_configs_in_notebook.ipynb", ) - # Jupyter notebook test on Windows yield warnings - args = [x for x in args if x != "-Werror"] session.run(*args, silent=SILENT) notebooks_dir = Path("tests/jupyter") for notebook in [ file for file in notebooks_dir.iterdir() if str(file).endswith(".ipynb") ]: - args = pytest_args("--nbval", str(notebook)) + args = pytest_args("--nbval", "-W ignore::ResourceWarning", str(notebook)) args = [x for x in args if x != "-Werror"] session.run(*args, silent=SILENT) diff --git a/plugins/hydra_ax_sweeper/example/banana.py b/plugins/hydra_ax_sweeper/example/banana.py index 05e0ad5c0f5..e8d10cbb5df 100644 --- a/plugins/hydra_ax_sweeper/example/banana.py +++ b/plugins/hydra_ax_sweeper/example/banana.py @@ -8,13 +8,13 @@ log = logging.getLogger(__name__) -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def banana(cfg: DictConfig) -> Any: x = cfg.banana.x y = cfg.banana.y a = 1 b = 100 - z = (a - x) ** 2 + b * ((y - x ** 2) ** 2) + z = (a - x) ** 2 + b * ((y - x**2) ** 2) log.info(f"Banana_Function(x={x}, y={y})={z}") return z diff --git a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/__init__.py b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/__init__.py index cab3b5a1524..461e882dd9f 100644 --- a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/__init__.py +++ b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -__version__ = "1.2.0dev1" +__version__ = "1.2.0.dev1" diff --git a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py index e442da73411..46976ce6b55 100644 --- a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py +++ b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py @@ -312,6 +312,7 @@ def create_range_param_using_interval_override(override: Override) -> Dict[str, "name": key, "type": "range", "bounds": [value.start, value.end], + "log_scale": "log" in value.tags, } return param diff --git a/plugins/hydra_ax_sweeper/news/1870.feature b/plugins/hydra_ax_sweeper/news/1870.feature new file mode 100644 index 00000000000..102c990fa8e --- /dev/null +++ b/plugins/hydra_ax_sweeper/news/1870.feature @@ -0,0 +1 @@ +Support sampling in log scale. diff --git a/plugins/hydra_ax_sweeper/tests/apps/polynomial.py b/plugins/hydra_ax_sweeper/tests/apps/polynomial.py index 974c95d6cbd..3acc43898a3 100644 --- a/plugins/hydra_ax_sweeper/tests/apps/polynomial.py +++ b/plugins/hydra_ax_sweeper/tests/apps/polynomial.py @@ -5,7 +5,7 @@ from omegaconf import DictConfig -@hydra.main(config_path=".", config_name="polynomial") +@hydra.main(version_base=None, config_path=".", config_name="polynomial") def polynomial(cfg: DictConfig) -> Any: x = cfg.polynomial.x y = cfg.polynomial.y @@ -13,7 +13,7 @@ def polynomial(cfg: DictConfig) -> Any: a = 100 b = 10 c = 1 - result = a * (x ** 2) + b * y + c * z + result = a * (x**2) + b * y + c * z return result diff --git a/plugins/hydra_ax_sweeper/tests/apps/polynomial_with_dict_coefficients.py b/plugins/hydra_ax_sweeper/tests/apps/polynomial_with_dict_coefficients.py index b7e4441955d..b2a6f07017f 100644 --- a/plugins/hydra_ax_sweeper/tests/apps/polynomial_with_dict_coefficients.py +++ b/plugins/hydra_ax_sweeper/tests/apps/polynomial_with_dict_coefficients.py @@ -5,13 +5,15 @@ from omegaconf import DictConfig -@hydra.main(config_path=".", config_name="polynomial_with_dict_coefficients") +@hydra.main( + version_base=None, config_path=".", config_name="polynomial_with_dict_coefficients" +) def polynomial_with_dict_coefficients(cfg: DictConfig) -> Any: coeff = cfg.polynomial.coefficients a = 100 b = 10 c = 1 - return a * (coeff.x ** 2) + b * coeff.y + c * coeff.z + return a * (coeff.x**2) + b * coeff.y + c * coeff.z if __name__ == "__main__": diff --git a/plugins/hydra_ax_sweeper/tests/apps/polynomial_with_list_coefficients.py b/plugins/hydra_ax_sweeper/tests/apps/polynomial_with_list_coefficients.py index 63ae6fa6a98..657d331c42d 100644 --- a/plugins/hydra_ax_sweeper/tests/apps/polynomial_with_list_coefficients.py +++ b/plugins/hydra_ax_sweeper/tests/apps/polynomial_with_list_coefficients.py @@ -5,13 +5,15 @@ from omegaconf import DictConfig -@hydra.main(config_path=".", config_name="polynomial_with_coefficients") +@hydra.main( + version_base=None, config_path=".", config_name="polynomial_with_coefficients" +) def polynomial_with_list_coefficients(cfg: DictConfig) -> Any: x, y, z = cfg.polynomial.coefficients a = 100 b = 10 c = 1 - return a * (x ** 2) + b * y + c * z + return a * (x**2) + b * y + c * z if __name__ == "__main__": diff --git a/plugins/hydra_ax_sweeper/tests/config/params/logscale.yaml b/plugins/hydra_ax_sweeper/tests/config/params/logscale.yaml new file mode 100644 index 00000000000..72e037910a3 --- /dev/null +++ b/plugins/hydra_ax_sweeper/tests/config/params/logscale.yaml @@ -0,0 +1,9 @@ +# @package hydra.sweeper.ax_config.params +quadratic.x: + type: range + bounds: [1e-6, 1] + log_scale: true + +quadratic.y: + type: range + bounds: [-1, 1] diff --git a/plugins/hydra_ax_sweeper/tests/test_ax_sweeper_plugin.py b/plugins/hydra_ax_sweeper/tests/test_ax_sweeper_plugin.py index f4d21f53d2a..b65830b7694 100644 --- a/plugins/hydra_ax_sweeper/tests/test_ax_sweeper_plugin.py +++ b/plugins/hydra_ax_sweeper/tests/test_ax_sweeper_plugin.py @@ -30,7 +30,7 @@ def test_discovery() -> None: def quadratic(cfg: DictConfig) -> Any: - return 100 * (cfg.quadratic.x ** 2) + 1 * cfg.quadratic.y + return 100 * (cfg.quadratic.x**2) + 1 * cfg.quadratic.y @mark.parametrize( @@ -87,14 +87,17 @@ def test_jobs_dirs(hydra_sweep_runner: TSweepRunner) -> None: assert len(dirs) == 6 # and a total of 6 unique output directories -def test_jobs_configured_via_config(hydra_sweep_runner: TSweepRunner) -> None: +@mark.parametrize("test_conf", ["basic", "logscale"]) +def test_jobs_configured_via_config( + hydra_sweep_runner: TSweepRunner, test_conf: str +) -> None: sweep = hydra_sweep_runner( calling_file="tests/test_ax_sweeper_plugin.py", calling_module=None, task_function=quadratic, config_path="config", config_name="config.yaml", - overrides=["hydra/launcher=basic", "params=basic"], + overrides=["hydra/launcher=basic", f"params={test_conf}"], ) with sweep: assert sweep.returns is None @@ -107,7 +110,16 @@ def test_jobs_configured_via_config(hydra_sweep_runner: TSweepRunner) -> None: assert math.isclose(best_parameters["quadratic.y"], -1.0, abs_tol=1e-4) -def test_jobs_configured_via_cmd(hydra_sweep_runner: TSweepRunner) -> None: +@mark.parametrize( + "test_conf, override, expected_x", + [ + ("basic", "int(interval(1, 5))", 1.0), + ("logscale", "tag(log, interval(0.00001, 1))", 0.00001), + ], +) +def test_jobs_configured_via_cmd( + hydra_sweep_runner: TSweepRunner, test_conf: str, override: str, expected_x: float +) -> None: sweep = hydra_sweep_runner( calling_file="tests/test_ax_sweeper_plugin.py", calling_module=None, @@ -116,9 +128,9 @@ def test_jobs_configured_via_cmd(hydra_sweep_runner: TSweepRunner) -> None: config_name="config.yaml", overrides=[ "hydra/launcher=basic", - "quadratic.x=int(interval(-5, -2))", + f"quadratic.x={override}", "quadratic.y=int(interval(-2, 2))", - "params=basic", + f"params={test_conf}", ], ) with sweep: @@ -128,8 +140,8 @@ def test_jobs_configured_via_cmd(hydra_sweep_runner: TSweepRunner) -> None: assert returns["optimizer"] == "ax" assert len(returns) == 2 best_parameters = returns.ax - assert math.isclose(best_parameters["quadratic.x"], -2.0, abs_tol=1e-4) - assert math.isclose(best_parameters["quadratic.y"], 2.0, abs_tol=1e-4) + assert math.isclose(best_parameters["quadratic.x"], expected_x, abs_tol=1e-4) + assert math.isclose(best_parameters["quadratic.y"], -2.0, abs_tol=1e-4) def test_jobs_configured_via_cmd_and_config(hydra_sweep_runner: TSweepRunner) -> None: @@ -205,6 +217,7 @@ def test_ax_logging(tmpdir: Path, cmd_arg: str, expected_str: str) -> None: "tests/apps/polynomial.py", "-m", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "polynomial.x=interval(-5, -2)", "polynomial.z=10", "hydra.sweeper.ax_config.max_trials=2", @@ -227,6 +240,7 @@ def test_search_space_exhausted_exception(tmpdir: Path, cmd_args: List[str]) -> "tests/apps/polynomial.py", "-m", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "hydra.sweeper.ax_config.max_trials=2", ] + cmd_args run_python_script(cmd) @@ -260,6 +274,7 @@ def test_jobs_using_choice_between_lists( "tests/apps/polynomial_with_list_coefficients.py", "-m", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "hydra.sweeper.ax_config.max_trials=3", ] + [cmd_arg] result, _ = run_python_script(cmd) @@ -296,6 +311,7 @@ def test_jobs_using_choice_between_dicts( "tests/apps/polynomial_with_dict_coefficients.py", "-m", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "hydra.sweeper.ax_config.max_trials=3", ] + [cmd_arg] result, _ = run_python_script(cmd) @@ -309,6 +325,7 @@ def test_example_app(tmpdir: Path) -> None: "example/banana.py", "-m", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "banana.x=int(interval(-5, 5))", "banana.y=interval(-5, 10.1)", "hydra.sweeper.ax_config.max_trials=2", diff --git a/plugins/hydra_colorlog/hydra_plugins/hydra_colorlog/__init__.py b/plugins/hydra_colorlog/hydra_plugins/hydra_colorlog/__init__.py index cab3b5a1524..461e882dd9f 100644 --- a/plugins/hydra_colorlog/hydra_plugins/hydra_colorlog/__init__.py +++ b/plugins/hydra_colorlog/hydra_plugins/hydra_colorlog/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -__version__ = "1.2.0dev1" +__version__ = "1.2.0.dev1" diff --git a/plugins/hydra_colorlog/tests/test_colorlog.py b/plugins/hydra_colorlog/tests/test_colorlog.py index e0592b91339..3b6bd42d653 100644 --- a/plugins/hydra_colorlog/tests/test_colorlog.py +++ b/plugins/hydra_colorlog/tests/test_colorlog.py @@ -12,7 +12,9 @@ def test_config_installed() -> None: Tests that color options are available for both hydra/hydra_logging and hydra/job_logging """ - with initialize(config_path="../hydra_plugins/hydra_colorlog/conf"): + with initialize( + version_base=None, config_path="../hydra_plugins/hydra_colorlog/conf" + ): config_loader = GlobalHydra.instance().config_loader() assert "colorlog" in config_loader.get_group_options("hydra/job_logging") assert "colorlog" in config_loader.get_group_options("hydra/hydra_logging") diff --git a/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/__init__.py b/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/__init__.py index cab3b5a1524..461e882dd9f 100644 --- a/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/__init__.py +++ b/plugins/hydra_joblib_launcher/hydra_plugins/hydra_joblib_launcher/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -__version__ = "1.2.0dev1" +__version__ = "1.2.0.dev1" diff --git a/plugins/hydra_nevergrad_sweeper/example/my_app.py b/plugins/hydra_nevergrad_sweeper/example/my_app.py index 0da0ac725c5..279fc198209 100644 --- a/plugins/hydra_nevergrad_sweeper/example/my_app.py +++ b/plugins/hydra_nevergrad_sweeper/example/my_app.py @@ -7,7 +7,7 @@ log = logging.getLogger(__name__) -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def dummy_training(cfg: DictConfig) -> float: """A dummy function to minimize Minimum is 0.0 at: diff --git a/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/__init__.py b/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/__init__.py index cab3b5a1524..461e882dd9f 100644 --- a/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/__init__.py +++ b/plugins/hydra_nevergrad_sweeper/hydra_plugins/hydra_nevergrad_sweeper/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -__version__ = "1.2.0dev1" +__version__ = "1.2.0.dev1" diff --git a/plugins/hydra_nevergrad_sweeper/setup.py b/plugins/hydra_nevergrad_sweeper/setup.py index 1e629c3b4fb..d0303350e35 100644 --- a/plugins/hydra_nevergrad_sweeper/setup.py +++ b/plugins/hydra_nevergrad_sweeper/setup.py @@ -26,7 +26,7 @@ ], install_requires=[ "hydra-core>=1.1.0.dev7", - "nevergrad>=0.4.3.post2,<0.4.3.post7", # https://github.com/facebookresearch/hydra/issues/1768 + "nevergrad>=0.4.3.post9", "cma==3.0.3", # https://github.com/facebookresearch/hydra/issues/1684 "numpy<1.20.0", # remove once nevergrad is upgraded to support numpy 1.20 ], diff --git a/plugins/hydra_nevergrad_sweeper/tests/test_nevergrad_sweeper_plugin.py b/plugins/hydra_nevergrad_sweeper/tests/test_nevergrad_sweeper_plugin.py index 7915cc1d8ea..d3747df5813 100644 --- a/plugins/hydra_nevergrad_sweeper/tests/test_nevergrad_sweeper_plugin.py +++ b/plugins/hydra_nevergrad_sweeper/tests/test_nevergrad_sweeper_plugin.py @@ -134,6 +134,7 @@ def test_nevergrad_example(with_commandline: bool, tmpdir: Path) -> None: "example/my_app.py", "-m", "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", f"hydra.sweeper.optim.budget={budget}", # small budget to test fast f"hydra.sweeper.optim.num_workers={min(8, budget)}", "hydra.sweeper.optim.seed=12", # avoid random failures diff --git a/plugins/hydra_optuna_sweeper/NEWS.md b/plugins/hydra_optuna_sweeper/NEWS.md index 9b29b7e101b..aa9359dc2aa 100644 --- a/plugins/hydra_optuna_sweeper/NEWS.md +++ b/plugins/hydra_optuna_sweeper/NEWS.md @@ -1,3 +1,11 @@ +1.1.2 (2022-01-23) +======================= + +### Bug Fixes + +- Fix a bug where Optuna Sweeper parses the override value incorrectly ([#1811](https://github.com/facebookresearch/hydra/issues/1811)) + + 1.1.1 (2021-09-01) ======================= diff --git a/plugins/hydra_optuna_sweeper/example/conf/config.yaml b/plugins/hydra_optuna_sweeper/example/conf/config.yaml index 98281c7fd8e..bf1bd345212 100644 --- a/plugins/hydra_optuna_sweeper/example/conf/config.yaml +++ b/plugins/hydra_optuna_sweeper/example/conf/config.yaml @@ -11,16 +11,9 @@ hydra: storage: null n_trials: 20 n_jobs: 1 - - search_space: - x: - type: float - low: -5.5 - high: 5.5 - step: 0.5 - y: - type: categorical - choices: [-5, 0, 5] + params: + x: range(-5.5, 5.5, step=0.5) + y: choice(-5 ,0 ,5) x: 1 y: 1 diff --git a/plugins/hydra_optuna_sweeper/example/custom-search-space-objective.py b/plugins/hydra_optuna_sweeper/example/custom-search-space-objective.py new file mode 100644 index 00000000000..a389a63edbe --- /dev/null +++ b/plugins/hydra_optuna_sweeper/example/custom-search-space-objective.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import hydra +from omegaconf import DictConfig +from optuna.trial import Trial + + +@hydra.main(version_base=None, config_path="custom-search-space", config_name="config") +def multi_dimensional_sphere(cfg: DictConfig) -> float: + w: float = cfg.w + x: float = cfg.x + y: float = cfg.y + z: float = cfg.z + return w**2 + x**2 + y**2 + z**2 + + +def configure(cfg: DictConfig, trial: Trial) -> None: + x_value = trial.params["x"] + trial.suggest_float( + "z", + x_value - cfg.max_z_difference_from_x, + x_value + cfg.max_z_difference_from_x, + ) + trial.suggest_float("+w", 0.0, 1.0) # note +w here, not w as w is a new parameter + + +if __name__ == "__main__": + multi_dimensional_sphere() diff --git a/plugins/hydra_optuna_sweeper/example/custom-search-space/config.yaml b/plugins/hydra_optuna_sweeper/example/custom-search-space/config.yaml new file mode 100644 index 00000000000..f11a2aaed95 --- /dev/null +++ b/plugins/hydra_optuna_sweeper/example/custom-search-space/config.yaml @@ -0,0 +1,24 @@ +defaults: + - override hydra/sweeper: optuna + +hydra: + sweeper: + sampler: + seed: 123 + direction: minimize + study_name: custom-search-space + storage: null + n_trials: 20 + n_jobs: 1 + + params: + x: range(-5.5, 5.5, 0.5) + y: choice(-5, 0, 5) + # `custom_search_space` should be a dotpath pointing to a + # callable that provides search-space configuration logic: + custom_search_space: custom-search-space-objective.configure + +x: 1 +y: 1 +z: 100 +max_z_difference_from_x: 0.5 diff --git a/plugins/hydra_optuna_sweeper/example/multi-objective-conf/config.yaml b/plugins/hydra_optuna_sweeper/example/multi-objective-conf/config.yaml index dd86c282bc9..d4cc4f2d749 100644 --- a/plugins/hydra_optuna_sweeper/example/multi-objective-conf/config.yaml +++ b/plugins/hydra_optuna_sweeper/example/multi-objective-conf/config.yaml @@ -11,18 +11,9 @@ hydra: storage: null n_trials: 20 n_jobs: 1 - - search_space: - x: - type: float - low: 0 - high: 5 - step: 0.5 - y: - type: float - low: 0 - high: 3 - step: 0.5 + params: + x: range(0, 5, step=0.5) + y: range(0, 3, step=0.5) x: 1 y: 1 diff --git a/plugins/hydra_optuna_sweeper/example/multi-objective.py b/plugins/hydra_optuna_sweeper/example/multi-objective.py index bbf80d3877b..80d2666c4ec 100644 --- a/plugins/hydra_optuna_sweeper/example/multi-objective.py +++ b/plugins/hydra_optuna_sweeper/example/multi-objective.py @@ -5,12 +5,12 @@ from omegaconf import DictConfig -@hydra.main(config_path="multi-objective-conf", config_name="config") +@hydra.main(version_base=None, config_path="multi-objective-conf", config_name="config") def binh_and_korn(cfg: DictConfig) -> Tuple[float, float]: x: float = cfg.x y: float = cfg.y - v0 = 4 * x ** 2 + 4 * y ** 2 + v0 = 4 * x**2 + 4 * y**2 v1 = (x - 5) ** 2 + (y - 5) ** 2 return v0, v1 diff --git a/plugins/hydra_optuna_sweeper/example/sphere.py b/plugins/hydra_optuna_sweeper/example/sphere.py index dc00daa2672..e93921e937a 100644 --- a/plugins/hydra_optuna_sweeper/example/sphere.py +++ b/plugins/hydra_optuna_sweeper/example/sphere.py @@ -3,11 +3,11 @@ from omegaconf import DictConfig -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def sphere(cfg: DictConfig) -> float: x: float = cfg.x y: float = cfg.y - return x ** 2 + y ** 2 + return x**2 + y**2 if __name__ == "__main__": diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/__init__.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/__init__.py index cab3b5a1524..461e882dd9f 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/__init__.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -__version__ = "1.2.0dev1" +__version__ = "1.2.0.dev1" diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py index ac05835b43c..266f9b2f569 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py @@ -1,9 +1,22 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import logging import sys -from typing import Any, Dict, List, MutableMapping, MutableSequence, Optional +import warnings +from textwrap import dedent +from typing import ( + Any, + Callable, + Dict, + List, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, +) import optuna +from hydra._internal.deprecation_warning import deprecation_warning from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.core.override_parser.types import ( ChoiceSweep, @@ -15,6 +28,7 @@ from hydra.core.plugins import Plugins from hydra.plugins.sweeper import Sweeper from hydra.types import HydraContext, TaskFunction +from hydra.utils import get_method from omegaconf import DictConfig, OmegaConf from optuna.distributions import ( BaseDistribution, @@ -26,6 +40,7 @@ LogUniformDistribution, UniformDistribution, ) +from optuna.trial import Trial from .config import Direction, DistributionConfig, DistributionType @@ -70,8 +85,8 @@ def create_optuna_distribution_from_override(override: Override) -> Any: assert isinstance(value, ChoiceSweep) for x in override.sweep_iterator(transformer=Transformer.encode): assert isinstance( - x, (str, int, float, bool) - ), f"A choice sweep expects str, int, float, or bool type. Got {type(x)}." + x, (str, int, float, bool, type(None)) + ), f"A choice sweep expects str, int, float, bool, or None type. Got {type(x)}." choices.append(x) return CategoricalDistribution(choices) @@ -82,10 +97,16 @@ def create_optuna_distribution_from_override(override: Override) -> Any: if value.shuffle: for x in override.sweep_iterator(transformer=Transformer.encode): assert isinstance( - x, (str, int, float, bool) - ), f"A choice sweep expects str, int, float, or bool type. Got {type(x)}." + x, (str, int, float, bool, type(None)) + ), f"A choice sweep expects str, int, float, bool, or None type. Got {type(x)}." choices.append(x) return CategoricalDistribution(choices) + if ( + isinstance(value.start, float) + or isinstance(value.stop, float) + or isinstance(value.step, float) + ): + return DiscreteUniformDistribution(value.start, value.stop, value.step) return IntUniformDistribution( int(value.start), int(value.stop), step=int(value.step) ) @@ -106,16 +127,35 @@ def create_optuna_distribution_from_override(override: Override) -> Any: raise NotImplementedError(f"{override} is not supported by Optuna sweeper.") +def create_params_from_overrides( + arguments: List[str], +) -> Tuple[Dict[str, BaseDistribution], Dict[str, Any]]: + parser = OverridesParser.create() + parsed = parser.parse_overrides(arguments) + search_space_distributions = dict() + fixed_params = dict() + for override in parsed: + param_name = override.get_key_element() + value = create_optuna_distribution_from_override(override) + if isinstance(value, BaseDistribution): + search_space_distributions[param_name] = value + else: + fixed_params[param_name] = value + return search_space_distributions, fixed_params + + class OptunaSweeperImpl(Sweeper): def __init__( self, sampler: Any, direction: Any, - storage: Optional[str], + storage: Optional[Any], study_name: Optional[str], n_trials: int, n_jobs: int, search_space: Optional[DictConfig], + custom_search_space: Optional[str], + params: Optional[DictConfig], ) -> None: self.sampler = sampler self.direction = direction @@ -123,15 +163,48 @@ def __init__( self.study_name = study_name self.n_trials = n_trials self.n_jobs = n_jobs - self.search_space = {} - if search_space: - assert isinstance(search_space, DictConfig) - self.search_space = { - str(x): create_optuna_distribution_from_config(y) - for x, y in search_space.items() - } + self.custom_search_space_extender: Optional[ + Callable[[DictConfig, Trial], None] + ] = None + if custom_search_space: + self.custom_search_space_extender = get_method(custom_search_space) + self.search_space = search_space + self.params = params self.job_idx: int = 0 + def _process_searchspace_config(self) -> None: + url = ( + "https://hydra.cc/docs/next/upgrades/1.1_to_1.2/changes_to_sweeper_config/" + ) + if self.params is None and self.search_space is None: + self.params = OmegaConf.create({}) + elif self.search_space is not None: + if self.params is not None: + warnings.warn( + "Both hydra.sweeper.params and hydra.sweeper.search_space are configured." + "\nHydra will use hydra.sweeper.params for defining search space." + f"\n{url}" + ) + self.search_space = None + else: + deprecation_warning( + message=dedent( + f"""\ + `hydra.sweeper.search_space` is deprecated and will be removed in the next major release. + Please configure with `hydra.sweeper.params`. + {url} + """ + ), + ) + self.params = OmegaConf.create( + { + str(x): create_optuna_distribution_from_config(y) + for x, y in self.search_space.items() + } + ) + self.search_space = None + assert self.search_space is None + def setup( self, *, @@ -147,38 +220,70 @@ def setup( ) self.sweep_dir = config.hydra.sweep.dir + def _get_directions(self) -> List[str]: + if isinstance(self.direction, MutableSequence): + return [d.name if isinstance(d, Direction) else d for d in self.direction] + elif isinstance(self.direction, str): + return [self.direction] + return [self.direction.name] + + def _configure_trials( + self, + trials: List[Trial], + search_space_distributions: Dict[str, BaseDistribution], + fixed_params: Dict[str, Any], + ) -> Sequence[Sequence[str]]: + overrides = [] + for trial in trials: + for param_name, distribution in search_space_distributions.items(): + assert type(param_name) is str + trial._suggest(param_name, distribution) + for param_name, value in fixed_params.items(): + trial.set_user_attr(param_name, value) + + if self.custom_search_space_extender: + assert self.config is not None + self.custom_search_space_extender(self.config, trial) + + overlap = trial.params.keys() & trial.user_attrs + if len(overlap): + raise ValueError( + "Overlapping fixed parameters and search space parameters found!" + f"Overlapping parameters: {list(overlap)}" + ) + params = dict(trial.params) + params.update(fixed_params) + + overrides.append(tuple(f"{name}={val}" for name, val in params.items())) + return overrides + + def _parse_sweeper_params_config(self) -> List[str]: + params_conf = [] + assert self.params is not None + for k, v in self.params.items(): + params_conf.append(f"{k}={v}") + return params_conf + def sweep(self, arguments: List[str]) -> None: assert self.config is not None assert self.launcher is not None assert self.hydra_context is not None assert self.job_idx is not None + assert self.search_space is None - parser = OverridesParser.create() - parsed = parser.parse_overrides(arguments) + self._process_searchspace_config() + params_conf = self._parse_sweeper_params_config() + params_conf.extend(arguments) + search_space_distributions, fixed_params = create_params_from_overrides( + params_conf + ) - search_space = dict(self.search_space) - fixed_params = dict() - for override in parsed: - value = create_optuna_distribution_from_override(override) - if isinstance(value, BaseDistribution): - search_space[override.get_key_element()] = value - else: - fixed_params[override.get_key_element()] = value # Remove fixed parameters from Optuna search space. for param_name in fixed_params: - if param_name in search_space: - del search_space[param_name] + if param_name in search_space_distributions: + del search_space_distributions[param_name] - directions: List[str] - if isinstance(self.direction, MutableSequence): - directions = [ - d.name if isinstance(d, Direction) else d for d in self.direction - ] - else: - if isinstance(self.direction, str): - directions = [self.direction] - else: - directions = [self.direction.name] + directions = self._get_directions() study = optuna.create_study( study_name=self.study_name, @@ -199,14 +304,9 @@ def sweep(self, arguments: List[str]) -> None: batch_size = min(n_trials_to_go, batch_size) trials = [study.ask() for _ in range(batch_size)] - overrides = [] - for trial in trials: - for param_name, distribution in search_space.items(): - trial._suggest(param_name, distribution) - - params = dict(trial.params) - params.update(fixed_params) - overrides.append(tuple(f"{name}={val}" for name, val in params.items())) + overrides = self._configure_trials( + trials, search_space_distributions, fixed_params + ) returns = self.launcher.launch(overrides, initial_job_idx=self.job_idx) self.job_idx += len(returns) diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/config.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/config.py index ba2e5b30588..b950e3fd234 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/config.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/config.py @@ -147,7 +147,7 @@ class OptunaSweeperConf: # For example, you can use SQLite if you set 'sqlite:///example.db' # Please refer to the reference for further details # https://optuna.readthedocs.io/en/stable/reference/storages.html - storage: Optional[str] = None + storage: Optional[Any] = None # Name of study to persist optimization results study_name: Optional[str] = None @@ -158,7 +158,15 @@ class OptunaSweeperConf: # Number of parallel workers n_jobs: int = 2 - search_space: Dict[str, Any] = field(default_factory=dict) + search_space: Optional[Dict[str, Any]] = None + + params: Optional[Dict[str, str]] = None + + # Allow custom trial configuration via Python methods. + # If given, `custom_search_space` should be a an instantiate-style dotpath targeting + # a callable with signature Callable[[DictConfig, optuna.trial.Trial], None]. + # https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/002_configurations.html + custom_search_space: Optional[str] = None ConfigStore.instance().store( diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py index a9b0072310d..c39723fb062 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py @@ -15,16 +15,26 @@ def __init__( self, sampler: SamplerConfig, direction: Any, - storage: Optional[str], + storage: Optional[Any], study_name: Optional[str], n_trials: int, n_jobs: int, search_space: Optional[DictConfig], + custom_search_space: Optional[str], + params: Optional[DictConfig], ) -> None: from ._impl import OptunaSweeperImpl self.sweeper = OptunaSweeperImpl( - sampler, direction, storage, study_name, n_trials, n_jobs, search_space + sampler, + direction, + storage, + study_name, + n_trials, + n_jobs, + search_space, + custom_search_space, + params, ) def setup( diff --git a/plugins/hydra_optuna_sweeper/news/1811.bugfix b/plugins/hydra_optuna_sweeper/news/1811.bugfix deleted file mode 100644 index 1dff1871ba4..00000000000 --- a/plugins/hydra_optuna_sweeper/news/1811.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug where Optuna Sweeper parses the override value incorrectly diff --git a/plugins/hydra_optuna_sweeper/news/1890.config b/plugins/hydra_optuna_sweeper/news/1890.config new file mode 100644 index 00000000000..3e802aa8c93 --- /dev/null +++ b/plugins/hydra_optuna_sweeper/news/1890.config @@ -0,0 +1 @@ +Add hydra.sweeper.params and deprecate hydra.sweeper.search_space diff --git a/plugins/hydra_optuna_sweeper/news/1906.feature b/plugins/hydra_optuna_sweeper/news/1906.feature new file mode 100644 index 00000000000..e9003963d15 --- /dev/null +++ b/plugins/hydra_optuna_sweeper/news/1906.feature @@ -0,0 +1 @@ +Add experimental 'custom_search_space' configuration node to allow extending trial objects programmatically. \ No newline at end of file diff --git a/plugins/hydra_optuna_sweeper/setup.py b/plugins/hydra_optuna_sweeper/setup.py index efc477c9457..bd65fca5b2c 100644 --- a/plugins/hydra_optuna_sweeper/setup.py +++ b/plugins/hydra_optuna_sweeper/setup.py @@ -27,8 +27,7 @@ ], install_requires=[ "hydra-core>=1.1.0.dev7", - "optuna>=2.5.0", - "alembic<1.7.0", # https://github.com/facebookresearch/hydra/issues/1806 + "optuna>=2.10.0", ], include_package_data=True, ) diff --git a/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py b/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py index f7a16466fe8..4ff19d5fe84 100644 --- a/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py +++ b/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py @@ -1,7 +1,9 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import os + +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from functools import partial from pathlib import Path -from typing import Any, List +from typing import Any, List, Optional import optuna from hydra.core.override_parser.overrides_parser import OverridesParser @@ -22,9 +24,12 @@ LogUniformDistribution, UniformDistribution, ) -from pytest import mark +from optuna.samplers import RandomSampler +from pytest import mark, warns from hydra_plugins.hydra_optuna_sweeper import _impl +from hydra_plugins.hydra_optuna_sweeper._impl import OptunaSweeperImpl +from hydra_plugins.hydra_optuna_sweeper.config import Direction from hydra_plugins.hydra_optuna_sweeper.optuna_sweeper import OptunaSweeper chdir_plugin_root() @@ -91,6 +96,7 @@ def test_create_optuna_distribution_from_config(input: Any, expected: Any) -> No ("key=int(interval(1, 5))", IntUniformDistribution(1, 5)), ("key=tag(log, interval(1, 5))", LogUniformDistribution(1, 5)), ("key=tag(log, int(interval(1, 5)))", IntLogUniformDistribution(1, 5)), + ("key=range(0.5, 5.5, step=1)", DiscreteUniformDistribution(0.5, 5.5, 1)), ], ) def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> None: @@ -100,6 +106,32 @@ def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> check_distribution(expected, actual) +@mark.parametrize( + "input, expected", + [ + (["key=choice(1,2)"], ({"key": CategoricalDistribution([1, 2])}, {})), + (["key=5"], ({}, {"key": "5"})), + ( + ["key1=choice(1,2)", "key2=5"], + ({"key1": CategoricalDistribution([1, 2])}, {"key2": "5"}), + ), + ( + ["key1=choice(1,2)", "key2=5", "key3=range(1,3)"], + ( + { + "key1": CategoricalDistribution([1, 2]), + "key3": IntUniformDistribution(1, 3), + }, + {"key2": "5"}, + ), + ), + ], +) +def test_create_params_from_overrides(input: Any, expected: Any) -> None: + actual = _impl.create_params_from_overrides(input) + assert actual == expected + + def test_launch_jobs(hydra_sweep_runner: TSweepRunner) -> None: sweep = hydra_sweep_runner( calling_file=None, @@ -126,6 +158,7 @@ def test_optuna_example(with_commandline: bool, tmpdir: Path) -> None: "example/sphere.py", "--multirun", "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", "hydra.sweeper.n_trials=20", "hydra.sweeper.n_jobs=1", f"hydra.sweeper.storage={storage}", @@ -148,6 +181,7 @@ def test_optuna_example(with_commandline: bool, tmpdir: Path) -> None: assert returns["best_params"]["x"] == best_trial.params["x"] if with_commandline: assert "y" not in returns["best_params"] + assert "y" not in best_trial.params else: assert returns["best_params"]["y"] == best_trial.params["y"] assert returns["best_value"] == best_trial.value @@ -164,6 +198,7 @@ def test_optuna_multi_objective_example(with_commandline: bool, tmpdir: Path) -> "example/multi-objective.py", "--multirun", "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", "hydra.sweeper.n_trials=20", "hydra.sweeper.n_jobs=1", "hydra/sweeper/sampler=random", @@ -198,3 +233,79 @@ def _dominates(values_x: List[float], values_y: List[float]) -> bool: return all(x <= y for x, y in zip(values_x, values_y)) and any( x < y for x, y in zip(values_x, values_y) ) + + +def test_optuna_custom_search_space_example(tmpdir: Path) -> None: + max_z_difference_from_x = 0.3 + cmd = [ + "example/custom-search-space-objective.py", + "--multirun", + "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", + "hydra.sweeper.n_trials=20", + "hydra.sweeper.n_jobs=1", + "hydra/sweeper/sampler=random", + "hydra.sweeper.sampler.seed=123", + f"max_z_difference_from_x={max_z_difference_from_x}", + ] + run_python_script(cmd) + returns = OmegaConf.load(f"{tmpdir}/optimization_results.yaml") + assert isinstance(returns, DictConfig) + assert returns.name == "optuna" + assert ( + abs(returns["best_params"]["x"] - returns["best_params"]["z"]) + <= max_z_difference_from_x + ) + w = returns["best_params"]["+w"] + assert 0 <= w <= 1 + + +@mark.parametrize( + "search_space,params,raise_warning,msg", + [ + (None, None, False, None), + ( + {}, + {}, + True, + r"Both hydra.sweeper.params and hydra.sweeper.search_space are configured.*", + ), + ( + {}, + None, + True, + r"`hydra.sweeper.search_space` is deprecated and will be removed in the next major release.*", + ), + (None, {}, False, None), + ], +) +def test_warnings( + tmpdir: Path, + search_space: Optional[DictConfig], + params: Optional[DictConfig], + raise_warning: bool, + msg: Optional[str], +) -> None: + partial_sweeper = partial( + OptunaSweeperImpl, + sampler=RandomSampler(), + direction=Direction.minimize, + storage=None, + study_name="test", + n_trials=1, + n_jobs=1, + custom_search_space=None, + ) + if search_space is not None: + search_space = OmegaConf.create(search_space) + if params is not None: + params = OmegaConf.create(params) + sweeper = partial_sweeper(search_space=search_space, params=params) + if raise_warning: + with warns( + UserWarning, + match=msg, + ): + sweeper._process_searchspace_config() + else: + sweeper._process_searchspace_config() diff --git a/plugins/hydra_ray_launcher/examples/upload_download/train.py b/plugins/hydra_ray_launcher/examples/upload_download/train.py index 5b0d5598c73..f636e0f123a 100644 --- a/plugins/hydra_ray_launcher/examples/upload_download/train.py +++ b/plugins/hydra_ray_launcher/examples/upload_download/train.py @@ -8,7 +8,7 @@ log = logging.getLogger(__name__) -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def main(cfg: DictConfig) -> None: log.info("Start training...") model = MyModel(cfg.random_seed) diff --git a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/__init__.py b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/__init__.py index cab3b5a1524..461e882dd9f 100644 --- a/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/__init__.py +++ b/plugins/hydra_ray_launcher/hydra_plugins/hydra_ray_launcher/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -__version__ = "1.2.0dev1" +__version__ = "1.2.0.dev1" diff --git a/plugins/hydra_ray_launcher/integration_test_tools/README.md b/plugins/hydra_ray_launcher/integration_test_tools/README.md new file mode 100644 index 00000000000..eb3a4a12b4a --- /dev/null +++ b/plugins/hydra_ray_launcher/integration_test_tools/README.md @@ -0,0 +1,21 @@ +## Setting up a new testing AMI for ray launcher. + +To run the tool: + +- Make sure the dependencies in `setup_integration_test_ami.py` matches exactly ray launcher's `setup.py`. + +- Before running the tool, set up your aws profile with admin access to the Hydra test AWS account. +``` +AWS_PROFILE=jieru python create_integration_test_ami.py +``` +You will see a new AMI created in the output +```commandline +ec2.Image(id='ami-0d65d5647e065a180') current state pending +... +``` +Sometimes it could take hours for a new AMI to be created. Proceed to the next step once the +AMI becomes available. + +- Update the `AWS_RAY_AMI` env variable in `tests/test_ray_aws_launcher.py` +- Run the test locally and debug if needed. +- Create a PR and make sure all CI pass! diff --git a/plugins/hydra_ray_launcher/integration_test_tools/create_integration_test_ami.py b/plugins/hydra_ray_launcher/integration_test_tools/create_integration_test_ami.py index e7d670d8747..70be9c76f62 100644 --- a/plugins/hydra_ray_launcher/integration_test_tools/create_integration_test_ami.py +++ b/plugins/hydra_ray_launcher/integration_test_tools/create_integration_test_ami.py @@ -29,7 +29,9 @@ def _run_command(command: str) -> str: return output -@hydra.main(config_name="create_integration_test_ami_config") +@hydra.main( + version_base=None, config_path=".", config_name="create_integration_test_ami_config" +) def set_up_machine(cfg: DictConfig) -> None: security_group_id = cfg.security_group_id assert security_group_id != "", "Security group cannot be empty!" diff --git a/plugins/hydra_ray_launcher/integration_test_tools/create_integration_test_ami_config.yaml b/plugins/hydra_ray_launcher/integration_test_tools/create_integration_test_ami_config.yaml index 47085e7c2be..bea49b03b15 100644 --- a/plugins/hydra_ray_launcher/integration_test_tools/create_integration_test_ami_config.yaml +++ b/plugins/hydra_ray_launcher/integration_test_tools/create_integration_test_ami_config.yaml @@ -2,6 +2,7 @@ security_group_id: sg-095ac7c26aa0d33bb python_versions: - 3.7 - 3.8 + - 3.9 ray_yaml: cluster_name: ray_test_base_AMI min_workers: 0 diff --git a/plugins/hydra_ray_launcher/integration_test_tools/setup_integration_test_ami.py b/plugins/hydra_ray_launcher/integration_test_tools/setup_integration_test_ami.py index 317296992a7..72d955a5b26 100644 --- a/plugins/hydra_ray_launcher/integration_test_tools/setup_integration_test_ami.py +++ b/plugins/hydra_ray_launcher/integration_test_tools/setup_integration_test_ami.py @@ -9,6 +9,7 @@ "ray[default]==1.6.0", "cloudpickle==1.6.0", "pickle5==0.0.11", + "aiohttp!=3.8", ] diff --git a/plugins/hydra_ray_launcher/news/1205.feature b/plugins/hydra_ray_launcher/news/1205.feature new file mode 100644 index 00000000000..20440f60a4f --- /dev/null +++ b/plugins/hydra_ray_launcher/news/1205.feature @@ -0,0 +1 @@ +Add support for python 3.9 diff --git a/plugins/hydra_ray_launcher/pytest.ini b/plugins/hydra_ray_launcher/pytest.ini index 212205028bc..8b1a7e71cbb 100644 --- a/plugins/hydra_ray_launcher/pytest.ini +++ b/plugins/hydra_ray_launcher/pytest.ini @@ -2,4 +2,3 @@ addopts = -p no:warnings log_cli = True log_cli_level = INFO - diff --git a/plugins/hydra_ray_launcher/setup.py b/plugins/hydra_ray_launcher/setup.py index 3fca9b7f7ec..4a64a977977 100644 --- a/plugins/hydra_ray_launcher/setup.py +++ b/plugins/hydra_ray_launcher/setup.py @@ -19,7 +19,7 @@ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", - # "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.9", "Operating System :: MacOS", "Operating System :: POSIX :: Linux", ], @@ -27,6 +27,8 @@ "boto3==1.17.17", "hydra-core>=1.1.0.dev7", "ray[default]==1.6.0", + # https://github.com/aio-libs/aiohttp/issues/6203 + "aiohttp!=3.8.0", "cloudpickle==1.6.0", "pickle5==0.0.11", ], diff --git a/plugins/hydra_ray_launcher/tests/test_ray_aws_launcher.py b/plugins/hydra_ray_launcher/tests/test_ray_aws_launcher.py index e1c89def69a..2b9a2f557ca 100644 --- a/plugins/hydra_ray_launcher/tests/test_ray_aws_launcher.py +++ b/plugins/hydra_ray_launcher/tests/test_ray_aws_launcher.py @@ -19,6 +19,7 @@ LauncherTestSuite, ) from hydra.test_utils.test_utils import chdir_hydra_root, chdir_plugin_root +from omegaconf import OmegaConf from pytest import fixture, mark from hydra_plugins.hydra_ray_launcher.ray_aws_launcher import ( # type: ignore @@ -54,7 +55,7 @@ aws_not_configured = True -ami = os.environ.get("AWS_RAY_AMI", "ami-0d03f5ce1006a7ed5") +ami = os.environ.get("AWS_RAY_AMI", "ami-0436072b623a028fa") security_group_id = os.environ.get("AWS_RAY_SECURITY_GROUP", "sg-0a12b09a5ff961aee") subnet_id = os.environ.get("AWS_RAY_SUBNET", "subnet-acd2cfe7") instance_role = os.environ.get( @@ -116,32 +117,33 @@ chdir_plugin_root() +def run_command(commands: str) -> str: + log.info(f"running: {commands}") + output = subprocess.getoutput(commands) + log.info(f"outputs: {output}") + return output + + def build_ray_launcher_wheel(tmp_wheel_dir: str) -> str: chdir_hydra_root() plugin = "hydra_ray_launcher" os.chdir(Path("plugins") / plugin) log.info(f"Build wheel for {plugin}, save wheel to {tmp_wheel_dir}.") - subprocess.getoutput( - f"python setup.py sdist bdist_wheel && cp dist/*.whl {tmp_wheel_dir}" - ) + run_command(f"python setup.py sdist bdist_wheel && cp dist/*.whl {tmp_wheel_dir}") log.info("Download all plugin dependency wheels.") - subprocess.getoutput(f"pip download . -d {tmp_wheel_dir}") - plugin_wheel = subprocess.getoutput("ls dist/*.whl").split("/")[-1] + run_command(f"pip download . -d {tmp_wheel_dir}") + plugin_wheel = run_command("ls dist/*.whl").split("/")[-1] chdir_hydra_root() return plugin_wheel def build_core_wheel(tmp_wheel_dir: str) -> str: chdir_hydra_root() - subprocess.getoutput( - f"python setup.py sdist bdist_wheel && cp dist/*.whl {tmp_wheel_dir}" - ) + run_command(f"python setup.py sdist bdist_wheel && cp dist/*.whl {tmp_wheel_dir}") # download dependency wheel for hydra-core - subprocess.getoutput( - f"pip download -r requirements/requirements.txt -d {tmp_wheel_dir}" - ) - wheel = subprocess.getoutput("ls dist/*.whl").split("/")[-1] + run_command(f"pip download -r requirements/requirements.txt -d {tmp_wheel_dir}") + wheel = run_command("ls dist/*.whl").split("/")[-1] return wheel @@ -161,12 +163,16 @@ def upload_and_install_wheels( sdk.run_on_cluster( connect_config, cmd=f"pip install --no-index --find-links={temp_remote_wheel_dir} {temp_remote_wheel_dir}{core_wheel}", + with_output=True, + ) + log.info(f"Install plugin wheel {plugin_wheel} ") + log.info( + f"pip install --no-index --find-links={temp_remote_wheel_dir} {temp_remote_wheel_dir}{plugin_wheel}" ) - - log.info(f"Install plugin wheel {plugin_wheel}") sdk.run_on_cluster( connect_config, cmd=f"pip install --no-index --find-links={temp_remote_wheel_dir} {temp_remote_wheel_dir}{plugin_wheel}", + with_output=True, ) @@ -214,8 +220,8 @@ def manage_cluster() -> Generator[None, None, None]: # build all the wheels tmpdir = tempfile.mkdtemp() - plugin_wheel = build_ray_launcher_wheel(tmpdir) core_wheel = build_core_wheel(tmpdir) + plugin_wheel = build_ray_launcher_wheel(tmpdir) connect_config = { "cluster_name": cluster_name, "provider": { @@ -233,14 +239,28 @@ def manage_cluster() -> Generator[None, None, None]: "head_node": ray_nodes_conf, "worker_nodes": ray_nodes_conf, } + + # save connect_config as yaml, this could be useful for debugging + # you can run `ray attach .yaml` and log on to the AWS cluster for debugging. + conf = OmegaConf.create(connect_config) + with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as fp: + OmegaConf.save(config=conf, f=fp.name, resolve=True) + log.info(f"Saving config to {fp.name}") + sdk.create_or_update_cluster( connect_config, ) sdk.run_on_cluster( - connect_config, run_env="auto", cmd=f"mkdir -p {temp_remote_dir}" + connect_config, + run_env="auto", + cmd=f"mkdir -p {temp_remote_dir}", + with_output=True, ) sdk.run_on_cluster( - connect_config, run_env="auto", cmd=f"mkdir -p {temp_remote_wheel_dir}" + connect_config, + run_env="auto", + cmd=f"mkdir -p {temp_remote_wheel_dir}", + with_output=True, ) upload_and_install_wheels(tmpdir, connect_config, core_wheel, plugin_wheel) validate_lib_version(connect_config) diff --git a/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/__init__.py b/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/__init__.py index cab3b5a1524..461e882dd9f 100644 --- a/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/__init__.py +++ b/plugins/hydra_rq_launcher/hydra_plugins/hydra_rq_launcher/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -__version__ = "1.2.0dev1" +__version__ = "1.2.0.dev1" diff --git a/plugins/hydra_submitit_launcher/example/my_app.py b/plugins/hydra_submitit_launcher/example/my_app.py index bc9a1dae551..8f366ff2d6a 100644 --- a/plugins/hydra_submitit_launcher/example/my_app.py +++ b/plugins/hydra_submitit_launcher/example/my_app.py @@ -10,7 +10,7 @@ log = logging.getLogger(__name__) -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: env = submitit.JobEnvironment() log.info(f"Process ID {os.getpid()} executing task {cfg.task}, with {env}") diff --git a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/__init__.py b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/__init__.py index cab3b5a1524..96ec9eecfa9 100644 --- a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/__init__.py +++ b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -__version__ = "1.2.0dev1" +__version__ = "1.2.0.dev2" diff --git a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py index 366ae237e7f..01d35a0fc3d 100644 --- a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py +++ b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py @@ -25,6 +25,8 @@ class BaseQueueConf: nodes: int = 1 # name of the job name: str = "${hydra.job.name}" + # redirect stderr to stdout + stderr_to_stdout: bool = False @dataclass @@ -52,6 +54,7 @@ class SlurmQueueConf(BaseQueueConf): gpus_per_task: Optional[int] = None mem_per_gpu: Optional[str] = None mem_per_cpu: Optional[str] = None + account: Optional[str] = None # Following parameters are submitit specifics # diff --git a/plugins/hydra_submitit_launcher/news/1920.feature b/plugins/hydra_submitit_launcher/news/1920.feature new file mode 100644 index 00000000000..ebb0a441ee3 --- /dev/null +++ b/plugins/hydra_submitit_launcher/news/1920.feature @@ -0,0 +1 @@ +Add support for submitit parameter `account` diff --git a/plugins/hydra_submitit_launcher/news/1967.feature b/plugins/hydra_submitit_launcher/news/1967.feature new file mode 100644 index 00000000000..c2d937747d9 --- /dev/null +++ b/plugins/hydra_submitit_launcher/news/1967.feature @@ -0,0 +1 @@ +Add support for submitit parameter `stderr_to_stdout` diff --git a/plugins/hydra_submitit_launcher/setup.py b/plugins/hydra_submitit_launcher/setup.py index e02e410588b..28c4438c67f 100644 --- a/plugins/hydra_submitit_launcher/setup.py +++ b/plugins/hydra_submitit_launcher/setup.py @@ -26,7 +26,7 @@ ], install_requires=[ "hydra-core>=1.1.0.dev7", - "submitit>=1.0.0", + "submitit>=1.3.3", ], include_package_data=True, ) diff --git a/pyproject.toml b/pyproject.toml index a28ed91f34c..b49ce616bc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ exclude = ''' | .venv | _build | dist + | build ) ) ''' diff --git a/pytest.ini b/pytest.ini index 6d816713c5e..b1b8c080f83 100644 --- a/pytest.ini +++ b/pytest.ini @@ -9,3 +9,9 @@ norecursedirs = tests/standalone_apps # tested separately under nox tools # tools are tested individually +filterwarnings = + error + ; Remove when default changes + ignore:.*Future Hydra versions will no longer change working directory at job runtime by default.*:UserWarning + ; Jupyter notebook test on Windows yield warnings + ignore:.*Proactor event loop does not implement add_reader family of methods required for zmq.*:RuntimeWarning \ No newline at end of file diff --git a/requirements/dev.txt b/requirements/dev.txt index 3bb2c822695..83a6c8dd25d 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,12 +1,12 @@ -r requirements.txt bandit -black==20.8b1 +black>=22.1.0 build coverage flake8 flake8-copyright isort==5.5.2 -mypy==0.790 +mypy nox packaging pre-commit diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 6a7cf19bfc1..79de29cf14c 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,3 +1,4 @@ -omegaconf==2.1.* +omegaconf~=2.1 antlr4-python3-runtime==4.8 importlib-resources;python_version<'3.9' +packaging diff --git a/tests/data.py b/tests/data.py new file mode 100644 index 00000000000..6aa20225508 --- /dev/null +++ b/tests/data.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +def foo() -> None: + ... + + +def foo_main_module() -> None: + ... + + +foo_main_module.__module__ = "__main__" + + +class Bar: + ... + + +bar_instance = Bar() + +bar_instance_main_module = Bar() +bar_instance_main_module.__module__ = "__main__" diff --git a/tests/defaults_list/test_defaults_list.py b/tests/defaults_list/test_defaults_list.py index 91374a5950f..b13dce5a990 100644 --- a/tests/defaults_list/test_defaults_list.py +++ b/tests/defaults_list/test_defaults_list.py @@ -5,6 +5,7 @@ from pytest import mark, param, raises, warns +from hydra import version from hydra._internal.defaults_list import create_defaults_list from hydra.core.default_element import ( ConfigDefault, @@ -80,23 +81,47 @@ def test_loaded_defaults_list( ), ], ) -def test_deprecated_optional( - config_path: str, expected_list: List[InputDefault] -) -> None: - repo = create_repo() - warning = dedent( - """ - In optional_deprecated: 'optional: true' is deprecated. - Use 'optional group1: file1' instead. - Support for the old style will be removed in Hydra 1.2""" - ) - with warns( - UserWarning, - match=re.escape(warning), - ): - result = repo.load_config(config_path=config_path) - assert result is not None - assert result.defaults_list == expected_list +class TestDeprecatedOptional: + def test_version_base_1_1( + self, + config_path: str, + expected_list: List[InputDefault], + ) -> None: + curr_base = version.getbase() + version.setbase("1.1") + repo = create_repo() + warning = dedent( + """ + In optional_deprecated: 'optional: true' is deprecated. + Use 'optional group1: file1' instead. + Support for the old style is removed for Hydra version_base >= 1.2""" + ) + with warns( + UserWarning, + match=re.escape(warning), + ): + result = repo.load_config(config_path=config_path) + assert result is not None + assert result.defaults_list == expected_list + version.setbase(str(curr_base)) + + @mark.parametrize("version_base", ["1.2", None]) + def test_version_base_1_2( + self, + config_path: str, + expected_list: List[InputDefault], + version_base: Optional[str], + ) -> None: + curr_base = version.getbase() + version.setbase(version_base) + repo = create_repo() + err = "In optional_deprecated: Too many keys in default item {'group1': 'file1', 'optional': True}" + with raises( + ValueError, + match=re.escape(err), + ): + repo.load_config(config_path=config_path) + version.setbase(str(curr_base)) def _test_defaults_list_impl( diff --git a/tests/instantiate/__init__.py b/tests/instantiate/__init__.py index 135e8ab9fee..f7283b2b432 100644 --- a/tests/instantiate/__init__.py +++ b/tests/instantiate/__init__.py @@ -2,11 +2,52 @@ import collections import collections.abc from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from functools import partial +from typing import Any, Dict, List, NoReturn, Optional, Tuple -from omegaconf import MISSING +from omegaconf import MISSING, DictConfig, ListConfig from hydra.types import TargetConf +from tests.instantiate.module_shadowed_by_function import a_function + +module_shadowed_by_function = a_function + + +def _convert_type(obj: Any) -> Any: + if isinstance(obj, DictConfig): + obj = dict(obj) + elif isinstance(obj, ListConfig): + obj = list(obj) + return obj + + +def partial_equal(obj1: Any, obj2: Any) -> bool: + if obj1 == obj2: + return True + + obj1, obj2 = _convert_type(obj1), _convert_type(obj2) + + if type(obj1) != type(obj2): + return False + if isinstance(obj1, dict): + if len(obj1) != len(obj2): + return False + for i in obj1.keys(): + if not partial_equal(obj1[i], obj2[i]): + return False + return True + if isinstance(obj1, list): + if len(obj1) != len(obj2): + return False + return all([partial_equal(obj1[i], obj2[i]) for i in range(len(obj1))]) + if not (isinstance(obj1, partial) and isinstance(obj2, partial)): + return False + return all( + [ + partial_equal(getattr(obj1, attr), getattr(obj2, attr)) + for attr in ["func", "args", "keywords"] + ] + ) class ArgsClass: @@ -26,6 +67,23 @@ def __eq__(self, other: Any) -> Any: return NotImplemented +class OuterClass: + def __init__(self) -> None: + pass + + @staticmethod + def method() -> str: + return "OuterClass.method return" + + class Nested: + def __init__(self) -> None: + pass + + @staticmethod + def method() -> str: + return "OuterClass.Nested.method return" + + def add_values(a: int, b: int) -> int: return a + b @@ -34,6 +92,20 @@ def module_function(x: int) -> int: return x +def module_function2() -> str: + return "fn return" + + +class ExceptionTakingNoArgument(Exception): + def __init__(self) -> None: + """Init method taking only one argument (self)""" + super().__init__("Err message") + + +def raise_exception_taking_no_argument() -> NoReturn: + raise ExceptionTakingNoArgument() + + @dataclass class AClass: a: Any @@ -55,8 +127,9 @@ class BClass: @dataclass -class TargetInParamsClass: +class KeywordsInParamsClass: target: Any + partial: Any @dataclass @@ -192,10 +265,7 @@ def __init__(self, transforms: List[Transform]): self.transforms = transforms def __eq__(self, other: Any) -> Any: - if isinstance(other, type(self)): - return self.transforms == other.transforms - else: - return False + return partial_equal(self.transforms, other.transforms) class Tree: @@ -212,9 +282,9 @@ def __init__(self, value: Any, left: Any = None, right: Any = None) -> None: def __eq__(self, other: Any) -> Any: if isinstance(other, type(self)): return ( - self.value == other.value - and self.left == other.left - and self.right == other.right + partial_equal(self.value, other.value) + and partial_equal(self.left, other.left) + and partial_equal(self.right, other.right) ) else: @@ -236,7 +306,9 @@ def __init__( def __eq__(self, other: Any) -> Any: if isinstance(other, type(self)): - return self.dictionary == other.dictionary and self.value == other.value + return partial_equal(self.dictionary, other.dictionary) and partial_equal( + self.value, other.value + ) else: return False @@ -253,6 +325,7 @@ class TransformConf: @dataclass class CenterCropConf(TransformConf): _target_: str = "tests.instantiate.CenterCrop" + _partial_: bool = False size: int = MISSING @@ -265,12 +338,14 @@ class RotationConf(TransformConf): @dataclass class ComposeConf: _target_: str = "tests.instantiate.Compose" + _partial_: bool = False transforms: List[TransformConf] = MISSING @dataclass class TreeConf: _target_: str = "tests.instantiate.Tree" + _partial_: bool = False left: Optional["TreeConf"] = None right: Optional["TreeConf"] = None value: Any = MISSING @@ -279,10 +354,16 @@ class TreeConf: @dataclass class MappingConf: _target_: str = "tests.instantiate.Mapping" + _partial_: bool = False dictionary: Optional[Dict[str, "MappingConf"]] = None - def __init__(self, dictionary: Optional[Dict[str, "MappingConf"]] = None): + def __init__( + self, + dictionary: Optional[Dict[str, "MappingConf"]] = None, + _partial_: bool = False, + ): self.dictionary = dictionary + self._partial_ = _partial_ @dataclass diff --git a/tests/instantiate/import_error.py b/tests/instantiate/import_error.py new file mode 100644 index 00000000000..816f9acb238 --- /dev/null +++ b/tests/instantiate/import_error.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +assert False diff --git a/tests/instantiate/module_shadowed_by_function.py b/tests/instantiate/module_shadowed_by_function.py new file mode 100644 index 00000000000..2930b2e0d0e --- /dev/null +++ b/tests/instantiate/module_shadowed_by_function.py @@ -0,0 +1,3 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +def a_function() -> None: + pass diff --git a/tests/instantiate/test_helpers.py b/tests/instantiate/test_helpers.py index e55e0066b52..7e4de9ef111 100644 --- a/tests/instantiate/test_helpers.py +++ b/tests/instantiate/test_helpers.py @@ -1,12 +1,14 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import datetime import re +from textwrap import dedent from typing import Any from _pytest.python_api import RaisesContext, raises -from pytest import mark +from pytest import mark, param from hydra._internal.utils import _locate -from hydra.utils import get_class +from hydra.utils import get_class, get_method from tests.instantiate import ( AClass, Adam, @@ -16,26 +18,118 @@ Parameters, ) +from .module_shadowed_by_function import a_function + @mark.parametrize( "name,expected", [ + param( + "int", + raises( + ImportError, + match=dedent( + r""" + Error loading 'int': + ModuleNotFoundError\("No module named 'int'",?\) + Are you sure that module 'int' is installed\? + """ + ).strip(), + ), + id="int", + ), + param("builtins.int", int, id="builtins_explicit"), + param("builtins.int.from_bytes", int.from_bytes, id="method_of_builtin"), + param( + "builtins.int.not_found", + raises( + ImportError, + match=dedent( + r""" + Error loading 'builtins\.int\.not_found': + AttributeError\("type object 'int' has no attribute 'not_found'",?\) + Are you sure that 'not_found' is an attribute of 'builtins\.int'\? + """ + ).strip(), + ), + id="builtin_attribute_error", + ), + param( + "datetime", + datetime, + id="top_level_module", + ), ("tests.instantiate.Adam", Adam), ("tests.instantiate.Parameters", Parameters), ("tests.instantiate.AClass", AClass), + param( + "tests.instantiate.AClass.static_method", + AClass.static_method, + id="staticmethod", + ), + param( + "tests.instantiate.AClass.not_found", + raises( + ImportError, + match=dedent( + r""" + Error loading 'tests\.instantiate\.AClass\.not_found': + AttributeError\("type object 'AClass' has no attribute 'not_found'",?\) + Are you sure that 'not_found' is an attribute of 'tests\.instantiate\.AClass'\? + """ + ).strip(), + ), + id="class_attribute_error", + ), ("tests.instantiate.ASubclass", ASubclass), ("tests.instantiate.NestingClass", NestingClass), ("tests.instantiate.AnotherClass", AnotherClass), - ("", raises(ImportError, match=re.escape("Empty path"))), - [ - "not_found", - raises(ImportError, match=re.escape("Error loading module 'not_found'")), - ], - ( + ("tests.instantiate.module_shadowed_by_function", a_function), + param( + "", + raises(ImportError, match=("Empty path")), + id="invalid-path-empty", + ), + param( + "toplevel_not_found", + raises( + ImportError, + match=dedent( + r""" + Error loading 'toplevel_not_found': + ModuleNotFoundError\("No module named 'toplevel_not_found'",?\) + Are you sure that module 'toplevel_not_found' is installed\? + """ + ).strip(), + ), + id="toplevel_not_found", + ), + param( "tests.instantiate.b.c.Door", raises( - ImportError, match=re.escape("No module named 'tests.instantiate.b'") + ImportError, + match=dedent( + r""" + Error loading 'tests\.instantiate\.b\.c\.Door': + ModuleNotFoundError\("No module named 'tests\.instantiate\.b'",?\) + Are you sure that 'b' is importable from module 'tests\.instantiate'\?""" + ).strip(), + ), + id="nested_not_found", + ), + param( + "tests.instantiate.import_error", + raises( + ImportError, + match=re.escape( + dedent( + """\ + Error loading 'tests.instantiate.import_error': + AssertionError()""" + ) + ), ), + id="import_assertion_error", ), ], ) @@ -47,6 +141,75 @@ def test_locate(name: str, expected: Any) -> None: assert _locate(name) == expected -@mark.parametrize("path,expected_type", [("tests.instantiate.AClass", AClass)]) -def test_get_class(path: str, expected_type: type) -> None: - assert get_class(path) == expected_type +@mark.parametrize( + "name", + [ + param(".", id="invalid-path-period"), + param("..", id="invalid-path-period2"), + param(".mod", id="invalid-path-relative"), + param("..mod", id="invalid-path-relative2"), + param("mod.", id="invalid-path-trailing-dot"), + param("mod..another", id="invalid-path-two-dots"), + ], +) +def test_locate_relative_import_fails(name: str) -> None: + with raises( + ValueError, + match=r"Error loading '.*': invalid dotstring\." + + re.escape("\nRelative imports are not supported."), + ): + _locate(name) + + +@mark.parametrize( + "path,expected", + [ + param("tests.instantiate.AClass", AClass, id="class"), + param("builtins.print", print, id="callable"), + param( + "datetime", + raises( + ValueError, + match="Located non-callable of type 'module' while loading 'datetime'", + ), + id="module-error", + ), + ], +) +def test_get_method(path: str, expected: Any) -> None: + if isinstance(expected, RaisesContext): + with expected: + get_method(path) + else: + assert get_method(path) == expected + + +@mark.parametrize( + "path,expected", + [ + param("tests.instantiate.AClass", AClass, id="class"), + param( + "builtins.print", + raises( + ValueError, + match="Located non-class of type 'builtin_function_or_method'" + + " while loading 'builtins.print'", + ), + id="callable-error", + ), + param( + "datetime", + raises( + ValueError, + match="Located non-class of type 'module' while loading 'datetime'", + ), + id="module-error", + ), + ], +) +def test_get_class(path: str, expected: Any) -> None: + if isinstance(expected, RaisesContext): + with expected: + get_class(path) + else: + assert get_class(path) == expected diff --git a/tests/instantiate/test_instantiate.py b/tests/instantiate/test_instantiate.py index b223fcb5e0a..cd6b8f05da8 100644 --- a/tests/instantiate/test_instantiate.py +++ b/tests/instantiate/test_instantiate.py @@ -3,13 +3,16 @@ import pickle import re from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from functools import partial +from textwrap import dedent +from typing import Any, Callable, Dict, List, Optional, Tuple from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf from pytest import fixture, mark, param, raises, warns import hydra from hydra.errors import InstantiationException +from hydra.test_utils.test_utils import assert_multiline_regex_search from hydra.types import ConvertMode, TargetConf from tests.instantiate import ( AClass, @@ -25,10 +28,12 @@ Compose, ComposeConf, IllegalType, + KeywordsInParamsClass, Mapping, MappingConf, NestedConf, NestingClass, + OuterClass, Parameters, Rotation, RotationConf, @@ -37,12 +42,15 @@ SimpleClassNonPrimitiveConf, SimpleClassPrimitiveConf, SimpleDataClass, - TargetInParamsClass, Tree, TreeConf, UntypedPassthroughClass, UntypedPassthroughConf, User, + add_values, + module_function, + module_function2, + partial_equal, recisinstance, ) @@ -94,25 +102,102 @@ def config(request: Any, src: Any) -> Any: AClass(10, 20, 30, 40), id="class", ), + param( + { + "_target_": "tests.instantiate.AClass", + "_partial_": True, + "a": 10, + "b": 20, + "c": 30, + }, + {}, + partial(AClass, a=10, b=20, c=30), + id="class+partial", + ), + param( + [ + { + "_target_": "tests.instantiate.AClass", + "_partial_": True, + "a": 10, + "b": 20, + "c": 30, + }, + { + "_target_": "tests.instantiate.BClass", + "a": 50, + "b": 60, + "c": 70, + }, + ], + {}, + [partial(AClass, a=10, b=20, c=30), BClass(a=50, b=60, c=70)], + id="list_of_partial_class", + ), param( {"_target_": "tests.instantiate.AClass", "b": 20, "c": 30}, {"a": 10, "d": 40}, AClass(10, 20, 30, 40), id="class+override", ), + param( + {"_target_": "tests.instantiate.AClass", "b": 20, "c": 30}, + {"a": 10, "_partial_": True}, + partial(AClass, a=10, b=20, c=30), + id="class+override+partial1", + ), + param( + { + "_target_": "tests.instantiate.AClass", + "_partial_": True, + "c": 30, + }, + {"a": 10, "d": 40}, + partial(AClass, a=10, c=30, d=40), + id="class+override+partial2", + ), param( {"_target_": "tests.instantiate.AClass", "b": 200, "c": "${b}"}, {"a": 10, "b": 99, "d": 40}, AClass(10, 99, 99, 40), id="class+override+interpolation", ), + param( + {"_target_": "tests.instantiate.AClass", "b": 200, "c": "${b}"}, + {"a": 10, "b": 99, "_partial_": True}, + partial(AClass, a=10, b=99, c=99), + id="class+override+interpolation+partial1", + ), + param( + { + "_target_": "tests.instantiate.AClass", + "b": 200, + "_partial_": True, + "c": "${b}", + }, + {"a": 10, "b": 99}, + partial(AClass, a=10, b=99, c=99), + id="class+override+interpolation+partial2", + ), # Check class and static methods + param( + {"_target_": "tests.instantiate.ASubclass.class_method", "_partial_": True}, + {}, + partial(ASubclass.class_method), + id="class_method+partial", + ), param( {"_target_": "tests.instantiate.ASubclass.class_method", "y": 10}, {}, ASubclass(11), id="class_method", ), + param( + {"_target_": "tests.instantiate.AClass.static_method", "_partial_": True}, + {}, + partial(AClass.static_method), + id="static_method+partial", + ), param( {"_target_": "tests.instantiate.AClass.static_method", "z": 43}, {}, @@ -126,12 +211,27 @@ def config(request: Any, src: Any) -> Any: NestingClass(ASubclass(10)), id="class_with_nested_class", ), + param( + {"_target_": "tests.instantiate.nesting.a.class_method", "_partial_": True}, + {}, + partial(ASubclass.class_method), + id="class_method_on_an_object_nested_in_a_global+partial", + ), param( {"_target_": "tests.instantiate.nesting.a.class_method", "y": 10}, {}, ASubclass(11), id="class_method_on_an_object_nested_in_a_global", ), + param( + { + "_target_": "tests.instantiate.nesting.a.static_method", + "_partial_": True, + }, + {}, + partial(ASubclass.static_method), + id="static_method_on_an_object_nested_in_a_global+partial", + ), param( {"_target_": "tests.instantiate.nesting.a.static_method", "z": 43}, {}, @@ -139,6 +239,12 @@ def config(request: Any, src: Any) -> Any: id="static_method_on_an_object_nested_in_a_global", ), # Check that default value is respected + param( + {"_target_": "tests.instantiate.AClass"}, + {"a": 10, "b": 20, "_partial_": True, "d": "new_default"}, + partial(AClass, a=10, b=20, d="new_default"), + id="instantiate_respects_default_value+partial", + ), param( {"_target_": "tests.instantiate.AClass"}, {"a": 10, "b": 20, "c": 30}, @@ -146,6 +252,15 @@ def config(request: Any, src: Any) -> Any: id="instantiate_respects_default_value", ), # call a function from a module + param( + { + "_target_": "tests.instantiate.module_function", + "_partial_": True, + }, + {}, + partial(module_function), + id="call_function_in_module", + ), param( {"_target_": "tests.instantiate.module_function", "x": 43}, {}, @@ -153,6 +268,12 @@ def config(request: Any, src: Any) -> Any: id="call_function_in_module", ), # Check builtins + param( + {"_target_": "builtins.int", "base": 2, "_partial_": True}, + {}, + partial(int, base=2), + id="builtin_types+partial", + ), param( {"_target_": "builtins.str", "object": 43}, {}, @@ -166,12 +287,34 @@ def config(request: Any, src: Any) -> Any: AClass(a=10, b=20, c=30), id="passthrough", ), + param( + {"_target_": "tests.instantiate.AClass"}, + {"a": 10, "b": 20, "_partial_": True}, + partial(AClass, a=10, b=20), + id="passthrough+partial", + ), param( {"_target_": "tests.instantiate.AClass"}, {"a": 10, "b": 20, "c": 30, "d": {"x": IllegalType()}}, AClass(a=10, b=20, c=30, d={"x": IllegalType()}), id="oc_incompatible_passthrough", ), + param( + {"_target_": "tests.instantiate.AClass", "_partial_": True}, + {"a": 10, "b": 20, "d": {"x": IllegalType()}}, + partial(AClass, a=10, b=20, d={"x": IllegalType()}), + id="oc_incompatible_passthrough+partial", + ), + param( + {"_target_": "tests.instantiate.AClass", "_partial_": True}, + { + "a": 10, + "b": 20, + "d": {"x": [10, IllegalType()]}, + }, + partial(AClass, a=10, b=20, d={"x": [10, IllegalType()]}), + id="passthrough:list+partial", + ), param( {"_target_": "tests.instantiate.AClass"}, { @@ -190,10 +333,32 @@ def config(request: Any, src: Any) -> Any: id="untyped_passthrough", ), param( - TargetInParamsClass, - {"target": "test"}, - TargetInParamsClass(target="test"), - id="target_in_params", + KeywordsInParamsClass, + {"target": "foo", "partial": "bar"}, + KeywordsInParamsClass(target="foo", partial="bar"), + id="keywords_in_params", + ), + param([], {}, [], id="list_as_toplevel0"), + param( + [ + { + "_target_": "tests.instantiate.AClass", + "a": 10, + "b": 20, + "c": 30, + "d": 40, + }, + { + "_target_": "tests.instantiate.BClass", + "a": 50, + "b": 60, + "c": 70, + "d": 80, + }, + ], + {}, + [AClass(10, 20, 30, 40), BClass(50, 60, 70, 80)], + id="list_as_toplevel2", ), ], ) @@ -206,7 +371,7 @@ def test_class_instantiate( ) -> Any: passthrough["_recursive_"] = recursive obj = instantiate_func(config, **passthrough) - assert obj == expected + assert partial_equal(obj, expected) def test_none_cases( @@ -258,6 +423,33 @@ def test_none_cases( AClass(99, 20, 30, 40), id="interpolation_into_parent", ), + param( + { + "node": { + "_target_": "tests.instantiate.AClass", + "_partial_": True, + "a": "${value}", + "b": 20, + }, + "value": 99, + }, + {}, + partial(AClass, a=99, b=20), + id="interpolation_into_parent_partial", + ), + param( + { + "A": {"_target_": "tests.instantiate.add_values", "a": 1, "b": 2}, + "node": { + "_target_": "tests.instantiate.add_values", + "_partial_": True, + "a": "${A}", + }, + }, + {}, + partial(add_values, a=3), + id="interpolation_from_recursive_partial", + ), param( { "A": {"_target_": "tests.instantiate.add_values", "a": 1, "b": 2}, @@ -274,12 +466,18 @@ def test_none_cases( ], ) def test_interpolation_accessing_parent( - instantiate_func: Any, input_conf: Any, passthrough: Dict[str, Any], expected: Any + instantiate_func: Any, + input_conf: Any, + passthrough: Dict[str, Any], + expected: Any, ) -> Any: cfg_copy = OmegaConf.create(input_conf) input_conf = OmegaConf.create(input_conf) obj = instantiate_func(input_conf.node, **passthrough) - assert obj == expected + if isinstance(expected, partial): + assert partial_equal(obj, expected) + else: + assert obj == expected assert input_conf == cfg_copy @@ -303,7 +501,10 @@ def test_class_instantiate_omegaconf_node(instantiate_func: Any, config: Any) -> @mark.parametrize("src", [{"_target_": "tests.instantiate.Adam"}]) def test_instantiate_adam(instantiate_func: Any, config: Any) -> None: - with raises(TypeError): + with raises( + InstantiationException, + match=r"Error in call to target 'tests\.instantiate\.Adam':\nTypeError\(.*\)", + ): # can't instantiate without passing params instantiate_func(config) @@ -312,7 +513,8 @@ def test_instantiate_adam(instantiate_func: Any, config: Any) -> None: assert res == Adam(params=adam_params) -def test_regression_1483(instantiate_func: Any) -> None: +@mark.parametrize("is_partial", [True, False]) +def test_regression_1483(instantiate_func: Any, is_partial: bool) -> None: """ In 1483, pickle is failing because the parent node of lst node contains a generator, which is not picklable. @@ -325,21 +527,38 @@ def gen() -> Any: res: ArgsClass = instantiate_func( {"_target_": "tests.instantiate.ArgsClass"}, + _partial_=is_partial, gen=gen(), lst=[1, 2], ) - pickle.dumps(res.kwargs["lst"]) + if is_partial: + # res is of type functools.partial + pickle.dumps(res.keywords["lst"]) # type: ignore + else: + pickle.dumps(res.kwargs["lst"]) -def test_instantiate_adam_conf(instantiate_func: Any) -> None: - with raises(TypeError): +@mark.parametrize( + "is_partial,expected_params", + [(True, Parameters([1, 2, 3])), (False, partial(Parameters))], +) +def test_instantiate_adam_conf( + instantiate_func: Any, is_partial: bool, expected_params: Any +) -> None: + with raises( + InstantiationException, + match=r"Error in call to target 'tests\.instantiate\.Adam':\nTypeError\(.*\)", + ): # can't instantiate without passing params instantiate_func(AdamConf()) - adam_params = Parameters([1, 2, 3]) + adam_params = expected_params res = instantiate_func(AdamConf(lr=0.123), params=adam_params) expected = Adam(lr=0.123, params=adam_params) - assert res.params == expected.params + if is_partial: + partial_equal(res.params, expected.params) + else: + assert res.params == expected.params assert res.lr == expected.lr assert list(res.betas) == list(expected.betas) # OmegaConf converts tuples to lists assert res.eps == expected.eps @@ -371,29 +590,127 @@ def test_targetconf_deprecated() -> None: def test_instantiate_bad_adam_conf(instantiate_func: Any, recwarn: Any) -> None: - msg = ( - "Missing value for BadAdamConf._target_. Check that it's properly annotated and overridden." - "\nA common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'" + msg = re.escape( + dedent( + """\ + Config has missing value for key `_target_`, cannot instantiate. + Config type: BadAdamConf + Check that the `_target_` key in your dataclass is properly annotated and overridden. + A common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'""" + ) ) with raises( InstantiationException, - match=re.escape(msg), + match=msg, ): instantiate_func(BadAdamConf()) def test_instantiate_with_missing_module(instantiate_func: Any) -> None: + _target_ = "tests.instantiate.ClassWithMissingModule" with raises( - ModuleNotFoundError, match=re.escape("No module named 'some_missing_module'") + InstantiationException, + match=dedent( + rf""" + Error in call to target '{re.escape(_target_)}': + ModuleNotFoundError\("No module named 'some_missing_module'",?\)""" + ).strip(), ): # can't instantiate when importing a missing module - instantiate_func({"_target_": "tests.instantiate.ClassWithMissingModule"}) + instantiate_func({"_target_": _target_}) -def test_pass_extra_variables(instantiate_func: Any) -> None: - cfg = OmegaConf.create({"_target_": "tests.instantiate.AClass", "a": 10, "b": 20}) - assert instantiate_func(cfg, c=30) == AClass(a=10, b=20, c=30) +def test_instantiate_target_raising_exception_taking_no_arguments( + instantiate_func: Any, +) -> None: + _target_ = "tests.instantiate.raise_exception_taking_no_argument" + with raises( + InstantiationException, + match=( + dedent( + rf""" + Error in call to target '{re.escape(_target_)}': + ExceptionTakingNoArgument\('Err message',?\)""" + ).strip() + ), + ): + instantiate_func({}, _target_=_target_) + + +def test_instantiate_target_raising_exception_taking_no_arguments_nested( + instantiate_func: Any, +) -> None: + _target_ = "tests.instantiate.raise_exception_taking_no_argument" + with raises( + InstantiationException, + match=( + dedent( + rf""" + Error in call to target '{re.escape(_target_)}': + ExceptionTakingNoArgument\('Err message',?\) + full_key: foo + """ + ).strip() + ), + ): + instantiate_func({"foo": {"_target_": _target_}}) + + +def test_toplevel_list_partial_not_allowed(instantiate_func: Any) -> None: + config = [{"_target_": "tests.instantiate.ClassA", "a": 10, "b": 20, "c": 30}] + with raises( + InstantiationException, + match=re.escape( + "The _partial_ keyword is not compatible with top-level list instantiation" + ), + ): + instantiate_func(config, _partial_=True) + + +@mark.parametrize("is_partial", [True, False]) +def test_pass_extra_variables(instantiate_func: Any, is_partial: bool) -> None: + cfg = OmegaConf.create( + { + "_target_": "tests.instantiate.AClass", + "a": 10, + "b": 20, + "_partial_": is_partial, + } + ) + if is_partial: + assert partial_equal( + instantiate_func(cfg, c=30), partial(AClass, a=10, b=20, c=30) + ) + else: + assert instantiate_func(cfg, c=30) == AClass(a=10, b=20, c=30) + + +@mark.parametrize( + "target, expected", + [ + param(module_function2, lambda x: x == "fn return", id="fn"), + param(OuterClass, lambda x: isinstance(x, OuterClass), id="OuterClass"), + param( + OuterClass.method, + lambda x: x == "OuterClass.method return", + id="classmethod", + ), + param( + OuterClass.Nested, lambda x: isinstance(x, OuterClass.Nested), id="nested" + ), + param( + OuterClass.Nested.method, + lambda x: x == "OuterClass.Nested.method return", + id="nested_method", + ), + ], +) +def test_instantiate_with_callable_target_keyword( + instantiate_func: Any, target: Callable[[], None], expected: Callable[[Any], bool] +) -> None: + ret = instantiate_func({}, _target_=target) + assert expected(ret) @mark.parametrize( @@ -594,6 +911,195 @@ def test_recursive_instantiation( assert obj == expected +@mark.parametrize( + "src, passthrough, expected", + [ + # direct + param( + { + "_target_": "tests.instantiate.Tree", + "_partial_": True, + "left": { + "_target_": "tests.instantiate.Tree", + "value": 21, + }, + "right": { + "_target_": "tests.instantiate.Tree", + "value": 22, + }, + }, + {}, + partial(Tree, left=Tree(value=21), right=Tree(value=22)), + ), + param( + {"_target_": "tests.instantiate.Tree", "_partial_": True}, + {"value": 1}, + partial(Tree, value=1), + ), + param( + {"_target_": "tests.instantiate.Tree"}, + { + "value": 1, + "left": {"_target_": "tests.instantiate.Tree", "_partial_": True}, + }, + Tree(value=1, left=partial(Tree)), + ), + param( + {"_target_": "tests.instantiate.Tree"}, + { + "value": 1, + "left": {"_target_": "tests.instantiate.Tree", "_partial_": True}, + "right": {"_target_": "tests.instantiate.Tree", "value": 3}, + }, + Tree(value=1, left=partial(Tree), right=Tree(3)), + ), + param( + TreeConf( + value=1, + left=TreeConf(value=21, _partial_=True), + right=TreeConf(value=22), + ), + {}, + Tree( + value=1, + left=partial(Tree, value=21, left=None, right=None), + right=Tree(value=22), + ), + ), + param( + TreeConf( + _partial_=True, + value=1, + left=TreeConf(value=21, _partial_=True), + right=TreeConf(value=22, _partial_=True), + ), + {}, + partial( + Tree, + value=1, + left=partial(Tree, value=21, left=None, right=None), + right=partial(Tree, value=22, left=None, right=None), + ), + ), + param( + TreeConf( + _partial_=True, + value=1, + left=TreeConf( + value=21, + ), + right=TreeConf(value=22, left=TreeConf(_partial_=True, value=42)), + ), + {}, + partial( + Tree, + value=1, + left=Tree(value=21), + right=Tree( + value=22, left=partial(Tree, value=42, left=None, right=None) + ), + ), + ), + # list + # note that passthrough to a list element is not currently supported + param( + ComposeConf( + _partial_=True, + transforms=[ + CenterCropConf(size=10), + RotationConf(degrees=45), + ], + ), + {}, + partial( + Compose, + transforms=[ + CenterCrop(size=10), + Rotation(degrees=45), + ], + ), + ), + param( + ComposeConf( + transforms=[ + CenterCropConf(_partial_=True, size=10), + RotationConf(degrees=45), + ], + ), + {}, + Compose( + transforms=[ + partial(CenterCrop, size=10), # type: ignore + Rotation(degrees=45), + ], + ), + ), + param( + { + "_target_": "tests.instantiate.Compose", + "transforms": [ + {"_target_": "tests.instantiate.CenterCrop", "_partial_": True}, + {"_target_": "tests.instantiate.Rotation", "degrees": 45}, + ], + }, + {}, + Compose( + transforms=[ + partial(CenterCrop), # type: ignore + Rotation(degrees=45), + ] + ), + id="recursive:list:dict", + ), + # map + param( + MappingConf( + dictionary={ + "a": MappingConf(_partial_=True), + "b": MappingConf(), + } + ), + {}, + Mapping( + dictionary={ + "a": partial(Mapping, dictionary=None), # type: ignore + "b": Mapping(), + } + ), + ), + param( + { + "_target_": "tests.instantiate.Mapping", + "_partial_": True, + "dictionary": { + "a": {"_target_": "tests.instantiate.Mapping", "_partial_": True}, + }, + }, + { + "dictionary": { + "b": {"_target_": "tests.instantiate.Mapping", "_partial_": True}, + }, + }, + partial( + Mapping, + dictionary={ + "a": partial(Mapping), + "b": partial(Mapping), + }, + ), + ), + ], +) +def test_partial_instantiate( + instantiate_func: Any, + config: Any, + passthrough: Dict[str, Any], + expected: Any, +) -> None: + obj = instantiate_func(config, **passthrough) + assert obj == expected or partial_equal(obj, expected) + + @mark.parametrize( ("src", "passthrough", "expected"), [ @@ -848,26 +1354,65 @@ def test_instantiate_from_class_in_dict( @mark.parametrize( - "config, passthrough, expected", + "config, passthrough, err_msg", [ param( OmegaConf.create({"_target_": AClass}), {}, - AClass(10, 20, 30, 40), - id="class_in_config_dict", + re.escape( + "Expected a callable target, got" + + " '{'a': '???', 'b': '???', 'c': '???', 'd': 'default_value'}' of type 'DictConfig'" + ), + id="instantiate-from-dataclass-in-dict-fails", + ), + param( + OmegaConf.create({"foo": {"_target_": AClass}}), + {}, + re.escape( + "Expected a callable target, got" + + " '{'a': '???', 'b': '???', 'c': '???', 'd': 'default_value'}' of type 'DictConfig'" + + "\nfull_key: foo" + ), + id="instantiate-from-dataclass-in-dict-fails-nested", ), ], ) def test_instantiate_from_dataclass_in_dict_fails( - instantiate_func: Any, config: Any, passthrough: Any, expected: Any + instantiate_func: Any, config: Any, passthrough: Any, err_msg: str ) -> None: - # not the best error, but it will get the user to check their input config. - msg = "Unsupported target type: DictConfig. value: {'a': '???', 'b': '???', 'c': '???', 'd': 'default_value'}" with raises( InstantiationException, - match=re.escape(msg), + match=err_msg, ): - assert instantiate_func(config, **passthrough) == expected + instantiate_func(config, **passthrough) + + +def test_cannot_locate_target(instantiate_func: Any) -> None: + cfg = OmegaConf.create({"foo": {"_target_": "not_found"}}) + with raises( + InstantiationException, + match=re.escape( + dedent( + """\ + Error locating target 'not_found', see chained exception above. + full_key: foo""" + ) + ), + ) as exc_info: + instantiate_func(cfg) + err = exc_info.value + assert hasattr(err, "__cause__") + chained = err.__cause__ + assert isinstance(chained, ImportError) + assert_multiline_regex_search( + dedent( + """\ + Error loading 'not_found': + ModuleNotFoundError\\("No module named 'not_found'",?\\) + Are you sure that module 'not_found' is installed\\?""" + ), + chained.args[0], + ) @mark.parametrize( diff --git a/tests/instantiate/test_positional.py b/tests/instantiate/test_positional.py index e75112ec009..c0525a2bf59 100644 --- a/tests/instantiate/test_positional.py +++ b/tests/instantiate/test_positional.py @@ -1,8 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from textwrap import dedent from typing import Any -from pytest import mark, param +from pytest import mark, param, raises +from hydra.errors import InstantiationException from hydra.utils import instantiate from tests.instantiate import ArgsClass @@ -47,6 +49,43 @@ def test_instantiate_args_kwargs(cfg: Any, expected: Any) -> None: assert instantiate(cfg) == expected +@mark.parametrize( + "cfg, msg", + [ + param( + {"_target_": "tests.instantiate.ArgsClass", "_args_": {"foo": "bar"}}, + dedent( + """\ + Error in collecting args and kwargs for 'tests\\.instantiate\\.ArgsClass': + InstantiationException\\("Unsupported _args_ type: 'DictConfig'\\. value: '{'foo': 'bar'}'",?\\)""" + ), + id="unsupported-args-type", + ), + param( + { + "foo": { + "_target_": "tests.instantiate.ArgsClass", + "_args_": {"foo": "bar"}, + } + }, + dedent( + """\ + Error in collecting args and kwargs for 'tests\\.instantiate\\.ArgsClass': + InstantiationException\\("Unsupported _args_ type: 'DictConfig'\\. value: '{'foo': 'bar'}'",?\\) + full_key: foo""" + ), + id="unsupported-args-type-nested", + ), + ], +) +def test_instantiate_unsupported_args_type(cfg: Any, msg: str) -> None: + with raises( + InstantiationException, + match=msg, + ): + instantiate(cfg) + + @mark.parametrize( ("cfg", "expected"), [ diff --git a/tests/instantiate/test_positional_only_arguments.py b/tests/instantiate/test_positional_only_arguments.py index 52ed65f2f99..49cc007ab7a 100644 --- a/tests/instantiate/test_positional_only_arguments.py +++ b/tests/instantiate/test_positional_only_arguments.py @@ -8,7 +8,7 @@ if sys.version_info < (3, 8): skip( - msg="Positional-only syntax is only supported in Python 3.8 or newer", + reason="Positional-only syntax is only supported in Python 3.8 or newer", allow_module_level=True, ) diff --git a/tests/jupyter/%run_test.ipynb b/tests/jupyter/%run_test.ipynb index 8dbd67a9de6..f32b6245cef 100644 --- a/tests/jupyter/%run_test.ipynb +++ b/tests/jupyter/%run_test.ipynb @@ -50,7 +50,7 @@ "source": [ "import tempfile\n", "tmpdir = tempfile.mkdtemp()\n", - "%run ../../examples/tutorials/basic/your_first_hydra_app/6_composition/my_app.py hydra.run.dir=$tmpdir" + "%run ../../examples/tutorials/basic/your_first_hydra_app/6_composition/my_app.py hydra.run.dir=$tmpdir hydra.job.chdir=True" ] } ], @@ -75,4 +75,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/tests/jupyter/a_module.py b/tests/jupyter/a_module.py index a59799bd798..bc940ff3f9d 100644 --- a/tests/jupyter/a_module.py +++ b/tests/jupyter/a_module.py @@ -5,13 +5,15 @@ def hydra_initialize() -> None: - initialize(config_path="../../hydra/test_utils/configs") + initialize(version_base=None, config_path="../../hydra/test_utils/configs") def hydra_initialize_config_dir() -> None: abs_conf_dir = Path.cwd() / "../../hydra/test_utils/configs" - initialize_config_dir(config_dir=str(abs_conf_dir)) + initialize_config_dir(version_base=None, config_dir=str(abs_conf_dir)) def hydra_initialize_config_module() -> None: - initialize_config_module(config_module="hydra.test_utils.configs") + initialize_config_module( + version_base=None, config_module="hydra.test_utils.configs" + ) diff --git a/tests/jupyter/test_initialize_in_module.ipynb b/tests/jupyter/test_initialize_in_module.ipynb index bf18fa5ec97..c6bcd28a252 100644 --- a/tests/jupyter/test_initialize_in_module.ipynb +++ b/tests/jupyter/test_initialize_in_module.ipynb @@ -98,7 +98,7 @@ ], "source": [ "GlobalHydra.instance().clear()\n", - "initialize(config_path=\"../../hydra/test_utils/configs\")\n", + "initialize(version_base=None, config_path=\"../../hydra/test_utils/configs\")\n", "compose(overrides=[\"+group1=file1\"])" ] }, @@ -121,7 +121,7 @@ "source": [ "GlobalHydra.instance().clear()\n", "abs_conf_dir = Path.cwd() / \"../../hydra/test_utils/configs\"\n", - "initialize_config_dir(config_dir=str(abs_conf_dir))\n", + "initialize_config_dir(version_base=None, config_dir=str(abs_conf_dir))\n", "compose(overrides=[\"+group1=file1\"])" ] }, @@ -143,7 +143,7 @@ ], "source": [ "GlobalHydra.instance().clear()\n", - "initialize_config_module(config_module=\"hydra.test_utils.configs\")\n", + "initialize_config_module(version_base=None, config_module=\"hydra.test_utils.configs\")\n", "compose(overrides=[\"+group1=file1\"])" ] } diff --git a/tests/standalone_apps/initialization_test_app/initialization_test_app/main.py b/tests/standalone_apps/initialization_test_app/initialization_test_app/main.py index b61dfddc987..9ad64dee82c 100644 --- a/tests/standalone_apps/initialization_test_app/initialization_test_app/main.py +++ b/tests/standalone_apps/initialization_test_app/initialization_test_app/main.py @@ -6,36 +6,42 @@ def main() -> None: - with initialize(config_path="conf"): + with initialize(version_base=None, config_path="conf"): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} assert cfg.hydra.job.name == "main" - with initialize(config_path="conf", job_name="test_job"): + with initialize(version_base=None, config_path="conf", job_name="test_job"): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} assert cfg.hydra.job.name == "test_job" abs_config_dir = os.path.abspath("initialization_test_app/conf") - with initialize_config_dir(config_dir=abs_config_dir): + with initialize_config_dir(version_base=None, config_dir=abs_config_dir): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} assert cfg.hydra.job.name == "app" - with initialize_config_dir(config_dir=abs_config_dir, job_name="test_job"): + with initialize_config_dir( + version_base=None, config_dir=abs_config_dir, job_name="test_job" + ): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} assert cfg.hydra.job.name == "test_job" # Those tests can only work if the module is installed if len(sys.argv) > 1 and sys.argv[1] == "module_installed": - with initialize_config_module(config_module="initialization_test_app.conf"): + with initialize_config_module( + version_base=None, config_module="initialization_test_app.conf" + ): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} assert cfg.hydra.job.name == "app" with initialize_config_module( - config_module="initialization_test_app.conf", job_name="test_job" + version_base=None, + config_module="initialization_test_app.conf", + job_name="test_job", ): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} diff --git a/tests/standalone_apps/namespace_pkg_config_source_test/namespace_test/test_namespace.py b/tests/standalone_apps/namespace_pkg_config_source_test/namespace_test/test_namespace.py index 2bdd02dd178..49905e4d01d 100644 --- a/tests/standalone_apps/namespace_pkg_config_source_test/namespace_test/test_namespace.py +++ b/tests/standalone_apps/namespace_pkg_config_source_test/namespace_test/test_namespace.py @@ -27,7 +27,9 @@ class TestCoreConfigSources(ConfigSourceTestSuite): def test_config_in_dir() -> None: - with initialize(config_path="../some_namespace/namespace_test/dir"): + with initialize( + version_base=None, config_path="../some_namespace/namespace_test/dir" + ): config_loader = GlobalHydra.instance().config_loader() assert "cifar10" in config_loader.get_group_options("dataset") assert "imagenet" in config_loader.get_group_options("dataset") diff --git a/tests/test_apps/app_can_fail/my_app.py b/tests/test_apps/app_can_fail/my_app.py index 758e958497e..de907cf641f 100644 --- a/tests/test_apps/app_can_fail/my_app.py +++ b/tests/test_apps/app_can_fail/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(cfg: DictConfig) -> None: val = 1 / cfg.divisor print(f"val={val}") diff --git a/tests/test_apps/app_change_dir/my_app.py b/tests/test_apps/app_change_dir/my_app.py new file mode 100644 index 00000000000..58cda268e97 --- /dev/null +++ b/tests/test_apps/app_change_dir/my_app.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import os +from pathlib import Path + +from omegaconf import DictConfig + +import hydra +from hydra.core.hydra_config import HydraConfig + + +@hydra.main(version_base=None) +def main(_: DictConfig) -> None: + subdir = Path(HydraConfig.get().run.dir) / Path("subdir") + subdir.mkdir(exist_ok=True, parents=True) + os.chdir(subdir) + + +if __name__ == "__main__": + main() + +print(f"current dir: {os.getcwd()}") diff --git a/tests/test_apps/app_exception/my_app.py b/tests/test_apps/app_exception/my_app.py index 71d5e1515f0..97aa8755081 100644 --- a/tests/test_apps/app_exception/my_app.py +++ b/tests/test_apps/app_exception/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(_: DictConfig) -> None: 1 / 0 diff --git a/tests/test_apps/run_as_module/3/module/__init__.py b/tests/test_apps/app_print_hydra_mode/__init__.py similarity index 100% rename from tests/test_apps/run_as_module/3/module/__init__.py rename to tests/test_apps/app_print_hydra_mode/__init__.py diff --git a/tests/test_apps/app_print_hydra_mode/conf/config.yaml b/tests/test_apps/app_print_hydra_mode/conf/config.yaml new file mode 100644 index 00000000000..d508cf75634 --- /dev/null +++ b/tests/test_apps/app_print_hydra_mode/conf/config.yaml @@ -0,0 +1 @@ +x: 1 diff --git a/tests/test_apps/app_print_hydra_mode/conf/hydra/sweeper/test.yaml b/tests/test_apps/app_print_hydra_mode/conf/hydra/sweeper/test.yaml new file mode 100644 index 00000000000..abc05e977ff --- /dev/null +++ b/tests/test_apps/app_print_hydra_mode/conf/hydra/sweeper/test.yaml @@ -0,0 +1,3 @@ +defaults: + - basic +max_batch_size: 1 diff --git a/tests/test_apps/app_print_hydra_mode/my_app.py b/tests/test_apps/app_print_hydra_mode/my_app.py new file mode 100644 index 00000000000..5d65ddc0fbc --- /dev/null +++ b/tests/test_apps/app_print_hydra_mode/my_app.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from omegaconf import DictConfig + +import hydra +from hydra.core.hydra_config import HydraConfig + + +@hydra.main(version_base=None, config_path="conf", config_name="config") +def my_app(_: DictConfig) -> None: + print(HydraConfig.get().mode) + + +if __name__ == "__main__": + my_app() diff --git a/tests/test_apps/app_with_callbacks/custom_callback/my_app.py b/tests/test_apps/app_with_callbacks/custom_callback/my_app.py index 6854379a3c2..3d17d71d364 100644 --- a/tests/test_apps/app_with_callbacks/custom_callback/my_app.py +++ b/tests/test_apps/app_with_callbacks/custom_callback/my_app.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import logging +from typing import Any from omegaconf import DictConfig, OmegaConf @@ -12,30 +13,32 @@ class CustomCallback(Callback): - def __init__(self, callback_name): + def __init__(self, callback_name: str) -> None: self.name = callback_name log.info(f"Init {self.name}") - def on_job_start(self, config: DictConfig, **kwargs) -> None: + def on_job_start(self, config: DictConfig, **kwargs: Any) -> None: log.info(f"{self.name} on_job_start") - def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs) -> None: + def on_job_end( + self, config: DictConfig, job_return: JobReturn, **kwargs: Any + ) -> None: log.info(f"{self.name} on_job_end") - def on_run_start(self, config: DictConfig, **kwargs) -> None: + def on_run_start(self, config: DictConfig, **kwargs: Any) -> None: log.info(f"{self.name} on_run_start") - def on_run_end(self, config: DictConfig, **kwargs) -> None: + def on_run_end(self, config: DictConfig, **kwargs: Any) -> None: log.info(f"{self.name} on_run_end") - def on_multirun_start(self, config: DictConfig, **kwargs) -> None: + def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None: log.info(f"{self.name} on_multirun_start") - def on_multirun_end(self, config: DictConfig, **kwargs) -> None: + def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None: log.info(f"{self.name} on_multirun_end") -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: log.info(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/app_with_cfg/my_app.py b/tests/test_apps/app_with_cfg/my_app.py index 9933c2ca1d9..15ea2577dd7 100644 --- a/tests/test_apps/app_with_cfg/my_app.py +++ b/tests/test_apps/app_with_cfg/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(_: DictConfig) -> None: pass diff --git a/tests/test_apps/app_with_cfg_groups/conf/config_with_runtime_option.yaml b/tests/test_apps/app_with_cfg_groups/conf/config_with_runtime_option.yaml new file mode 100644 index 00000000000..67a417c3290 --- /dev/null +++ b/tests/test_apps/app_with_cfg_groups/conf/config_with_runtime_option.yaml @@ -0,0 +1,5 @@ +defaults: + - optimizer: nesterov + - _self_ + +optimizer_option: ${hydra:runtime.choices.optimizer} diff --git a/tests/test_apps/app_with_cfg_groups/my_app.py b/tests/test_apps/app_with_cfg_groups/my_app.py index bb130df1d50..652ff3cc949 100644 --- a/tests/test_apps/app_with_cfg_groups/my_app.py +++ b/tests/test_apps/app_with_cfg_groups/my_app.py @@ -6,7 +6,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> Any: return cfg diff --git a/tests/test_apps/app_with_cfg_groups/my_app_with_runtime_choices_print.py b/tests/test_apps/app_with_cfg_groups/my_app_with_runtime_choices_print.py new file mode 100644 index 00000000000..9a6ec63d59d --- /dev/null +++ b/tests/test_apps/app_with_cfg_groups/my_app_with_runtime_choices_print.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from typing import Any + +from omegaconf import DictConfig + +import hydra + + +@hydra.main( + version_base=None, config_path="conf", config_name="config_with_runtime_option" +) +def my_app(cfg: DictConfig) -> Any: + print(cfg.optimizer_option) + + +if __name__ == "__main__": + my_app() diff --git a/tests/test_apps/app_with_cfg_groups_no_header/my_app.py b/tests/test_apps/app_with_cfg_groups_no_header/my_app.py index 6b1303d511f..7585728f215 100644 --- a/tests/test_apps/app_with_cfg_groups_no_header/my_app.py +++ b/tests/test_apps/app_with_cfg_groups_no_header/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/app_with_config_with_free_group/my_app.py b/tests/test_apps/app_with_config_with_free_group/my_app.py index 7f118245370..f69a111143c 100644 --- a/tests/test_apps/app_with_config_with_free_group/my_app.py +++ b/tests/test_apps/app_with_config_with_free_group/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(_: DictConfig) -> None: pass diff --git a/tests/test_apps/run_as_module/3/module/conf/__init__.py b/tests/test_apps/app_with_log_jobreturn_callback/__init__.py similarity index 100% rename from tests/test_apps/run_as_module/3/module/conf/__init__.py rename to tests/test_apps/app_with_log_jobreturn_callback/__init__.py diff --git a/tests/test_apps/app_with_log_jobreturn_callback/config.yaml b/tests/test_apps/app_with_log_jobreturn_callback/config.yaml new file mode 100644 index 00000000000..22dbd7e8ec6 --- /dev/null +++ b/tests/test_apps/app_with_log_jobreturn_callback/config.yaml @@ -0,0 +1,6 @@ +foo: bar + +hydra: + callbacks: + log_job_return: + _target_: hydra.experimental.callbacks.LogJobReturnCallback diff --git a/tests/test_apps/app_with_log_jobreturn_callback/my_app.py b/tests/test_apps/app_with_log_jobreturn_callback/my_app.py new file mode 100644 index 00000000000..d5802e0add6 --- /dev/null +++ b/tests/test_apps/app_with_log_jobreturn_callback/my_app.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from omegaconf import DictConfig + +import hydra + + +@hydra.main(config_path=".", config_name="config") +def my_app(cfg: DictConfig) -> None: + val = 1 / cfg.divisor + print(f"val={val}") + + +if __name__ == "__main__": + my_app() diff --git a/tests/test_apps/app_with_multiple_config_dirs/my_app.py b/tests/test_apps/app_with_multiple_config_dirs/my_app.py index ddb7946b9f1..faa693373de 100644 --- a/tests/test_apps/app_with_multiple_config_dirs/my_app.py +++ b/tests/test_apps/app_with_multiple_config_dirs/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=".") +@hydra.main(version_base=None, config_path=".") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/app_with_no_chdir_override/my_app.py b/tests/test_apps/app_with_no_chdir_override/my_app.py new file mode 100644 index 00000000000..8472b6cd635 --- /dev/null +++ b/tests/test_apps/app_with_no_chdir_override/my_app.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from omegaconf import DictConfig + +import hydra + + +@hydra.main(version_base="1.1", config_path=".") +def my_app(_: DictConfig) -> None: + pass + + +if __name__ == "__main__": + my_app() diff --git a/tests/test_apps/run_as_module/4/module/__init__.py b/tests/test_apps/app_with_pickle_job_info_callback/__init__.py similarity index 100% rename from tests/test_apps/run_as_module/4/module/__init__.py rename to tests/test_apps/app_with_pickle_job_info_callback/__init__.py diff --git a/tests/test_apps/app_with_pickle_job_info_callback/config.yaml b/tests/test_apps/app_with_pickle_job_info_callback/config.yaml new file mode 100644 index 00000000000..fdc7a961b64 --- /dev/null +++ b/tests/test_apps/app_with_pickle_job_info_callback/config.yaml @@ -0,0 +1,6 @@ +foo: bar + +hydra: + callbacks: + save_job_info: + _target_: hydra.experimental.callbacks.PickleJobInfoCallback diff --git a/tests/test_apps/app_with_pickle_job_info_callback/my_app.py b/tests/test_apps/app_with_pickle_job_info_callback/my_app.py new file mode 100644 index 00000000000..869d937aeeb --- /dev/null +++ b/tests/test_apps/app_with_pickle_job_info_callback/my_app.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import logging +import pickle +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig + +import hydra +from hydra.core.hydra_config import HydraConfig + +log = logging.getLogger(__name__) + + +@hydra.main(version_base=None, config_path=".", config_name="config") +def my_app(cfg: DictConfig) -> str: + def pickle_cfg(path: Path, obj: Any) -> Any: + with open(str(path), "wb") as file: + pickle.dump(obj, file) + + hydra_cfg = HydraConfig.get() + output_dir = Path(hydra_cfg.runtime.output_dir) + pickle_cfg(Path(output_dir) / "task_cfg.pickle", cfg) + pickle_cfg(Path(output_dir) / "hydra_cfg.pickle", hydra_cfg) + log.info("Running my_app") + + return "hello world" + + +if __name__ == "__main__": + my_app() diff --git a/tests/test_apps/app_with_runtime_config_error/my_app.py b/tests/test_apps/app_with_runtime_config_error/my_app.py index 3b46379ab09..d16048bf294 100644 --- a/tests/test_apps/app_with_runtime_config_error/my_app.py +++ b/tests/test_apps/app_with_runtime_config_error/my_app.py @@ -8,7 +8,7 @@ def foo(cfg: DictConfig) -> None: cfg.foo = "bar" # does not exist in the config -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: foo(cfg) diff --git a/tests/test_apps/app_with_unicode_in_config/my_app.py b/tests/test_apps/app_with_unicode_in_config/my_app.py index 07916d55b34..0839a3c526d 100644 --- a/tests/test_apps/app_with_unicode_in_config/my_app.py +++ b/tests/test_apps/app_with_unicode_in_config/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/app_without_config/my_app.py b/tests/test_apps/app_without_config/my_app.py index 956198866d0..2117ee3c52b 100644 --- a/tests/test_apps/app_without_config/my_app.py +++ b/tests/test_apps/app_without_config/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(_: DictConfig) -> None: pass diff --git a/tests/test_apps/custom_env_defaults/my_app.py b/tests/test_apps/custom_env_defaults/my_app.py index 6e82825a64c..2b21c608c2a 100644 --- a/tests/test_apps/custom_env_defaults/my_app.py +++ b/tests/test_apps/custom_env_defaults/my_app.py @@ -9,7 +9,7 @@ log = logging.getLogger(__name__) -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(_: DictConfig) -> None: assert os.getenv("FOO") == "bar" diff --git a/tests/test_apps/defaults_in_schema_missing/my_app.py b/tests/test_apps/defaults_in_schema_missing/my_app.py index d7e5e4ae8c7..3fefe68d0f9 100644 --- a/tests/test_apps/defaults_in_schema_missing/my_app.py +++ b/tests/test_apps/defaults_in_schema_missing/my_app.py @@ -34,7 +34,7 @@ class Config: cs.store(name="config", node=Config, provider="main") -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/defaults_pkg_with_dot/my_app.py b/tests/test_apps/defaults_pkg_with_dot/my_app.py index b4a1eb58595..4ac205cb1bf 100644 --- a/tests/test_apps/defaults_pkg_with_dot/my_app.py +++ b/tests/test_apps/defaults_pkg_with_dot/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: print(cfg) diff --git a/tests/test_apps/deprecation_warning/my_app.py b/tests/test_apps/deprecation_warning/my_app.py index 85ba663329e..dbd811e3297 100644 --- a/tests/test_apps/deprecation_warning/my_app.py +++ b/tests/test_apps/deprecation_warning/my_app.py @@ -5,7 +5,7 @@ from hydra._internal.deprecation_warning import deprecation_warning -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(cfg: DictConfig) -> None: deprecation_warning("Feature FooBar is deprecated") diff --git a/tests/test_apps/hydra_resolver_in_output_dir/my_app.py b/tests/test_apps/hydra_resolver_in_output_dir/my_app.py new file mode 100644 index 00000000000..1daabd36d0e --- /dev/null +++ b/tests/test_apps/hydra_resolver_in_output_dir/my_app.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from omegaconf import DictConfig + +import hydra +from hydra.core.hydra_config import HydraConfig + + +@hydra.main(version_base=None) +def my_app(_: DictConfig) -> None: + print(HydraConfig.instance().get().runtime.output_dir) + + +if __name__ == "__main__": + my_app() diff --git a/tests/test_apps/hydra_to_cfg_interpolation/my_app.py b/tests/test_apps/hydra_to_cfg_interpolation/my_app.py index 90aa9789fa7..258bab59164 100644 --- a/tests/test_apps/hydra_to_cfg_interpolation/my_app.py +++ b/tests/test_apps/hydra_to_cfg_interpolation/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: print(cfg.c) diff --git a/tests/test_apps/run_as_module/4/module/conf/__init__.py b/tests/test_apps/hydra_verbose/__init__.py similarity index 100% rename from tests/test_apps/run_as_module/4/module/conf/__init__.py rename to tests/test_apps/hydra_verbose/__init__.py diff --git a/tests/test_apps/hydra_verbose/config.yaml b/tests/test_apps/hydra_verbose/config.yaml new file mode 100644 index 00000000000..29f89e65232 --- /dev/null +++ b/tests/test_apps/hydra_verbose/config.yaml @@ -0,0 +1,3 @@ +hydra: + verbose: + true diff --git a/tests/test_apps/hydra_verbose/my_app.py b/tests/test_apps/hydra_verbose/my_app.py new file mode 100644 index 00000000000..15ea2577dd7 --- /dev/null +++ b/tests/test_apps/hydra_verbose/my_app.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from omegaconf import DictConfig + +import hydra + + +@hydra.main(version_base=None, config_path=".", config_name="config") +def my_app(_: DictConfig) -> None: + pass + + +if __name__ == "__main__": + my_app() diff --git a/tests/test_apps/init_in_app_without_module/main.py b/tests/test_apps/init_in_app_without_module/main.py index 0a459f06164..2bd8be11b68 100644 --- a/tests/test_apps/init_in_app_without_module/main.py +++ b/tests/test_apps/init_in_app_without_module/main.py @@ -4,23 +4,25 @@ from hydra import compose, initialize, initialize_config_dir if __name__ == "__main__": - with initialize(config_path="."): + with initialize(version_base=None, config_path="."): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} assert cfg.hydra.job.name == "main" - with initialize(config_path=".", job_name="test_job"): + with initialize(version_base=None, config_path=".", job_name="test_job"): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} assert cfg.hydra.job.name == "test_job" abs_config__dir = os.path.abspath("") - with initialize_config_dir(config_dir=abs_config__dir): + with initialize_config_dir(version_base=None, config_dir=abs_config__dir): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} assert cfg.hydra.job.name == "app" - with initialize_config_dir(config_dir=abs_config__dir, job_name="test_job"): + with initialize_config_dir( + version_base=None, config_dir=abs_config__dir, job_name="test_job" + ): cfg = compose(config_name="config", return_hydra_config=True) assert cfg.config == {"hello": "world"} assert cfg.hydra.job.name == "test_job" diff --git a/tests/test_apps/multirun_structured_conflict/my_app.py b/tests/test_apps/multirun_structured_conflict/my_app.py index 008c3a2a26b..ab13c3664d0 100644 --- a/tests/test_apps/multirun_structured_conflict/my_app.py +++ b/tests/test_apps/multirun_structured_conflict/my_app.py @@ -16,8 +16,8 @@ class TestConfig: config_store.store(group="test", name="default", node=TestConfig) -@hydra.main(config_path=".", config_name="config") -def run(config: DictConfig): +@hydra.main(version_base=None, config_path=".", config_name="config") +def run(config: DictConfig) -> None: print(config.test.param) diff --git a/tests/test_apps/schema-overrides-hydra/__init__.py b/tests/test_apps/passes_callable_class_to_hydra_main/__init__.py similarity index 100% rename from tests/test_apps/schema-overrides-hydra/__init__.py rename to tests/test_apps/passes_callable_class_to_hydra_main/__init__.py diff --git a/tests/test_apps/passes_callable_class_to_hydra_main/my_app.py b/tests/test_apps/passes_callable_class_to_hydra_main/my_app.py new file mode 100644 index 00000000000..81c6a0eb184 --- /dev/null +++ b/tests/test_apps/passes_callable_class_to_hydra_main/my_app.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from omegaconf import DictConfig + +import hydra +from hydra.core.hydra_config import HydraConfig + + +class MyCallable: + def __init__(self, state: int = 123) -> None: + self._state = state + + def __call__(self, cfg: DictConfig) -> None: + print(self._state) + print(HydraConfig.get().job.name) + + +my_callable = MyCallable() +my_app = hydra.main(version_base=None)(my_callable) + +if __name__ == "__main__": + my_app() diff --git a/tests/test_apps/run_as_module/1/config.yaml b/tests/test_apps/run_as_module_1/config.yaml similarity index 100% rename from tests/test_apps/run_as_module/1/config.yaml rename to tests/test_apps/run_as_module_1/config.yaml diff --git a/tests/test_apps/run_as_module/4/module/my_app.py b/tests/test_apps/run_as_module_1/my_app.py similarity index 77% rename from tests/test_apps/run_as_module/4/module/my_app.py rename to tests/test_apps/run_as_module_1/my_app.py index 6b1303d511f..0839a3c526d 100644 --- a/tests/test_apps/run_as_module/4/module/my_app.py +++ b/tests/test_apps/run_as_module_1/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/run_as_module_2/conf/__init__.py b/tests/test_apps/run_as_module_2/conf/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tests/test_apps/run_as_module_2/conf/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tests/test_apps/run_as_module/2/conf/config.yaml b/tests/test_apps/run_as_module_2/conf/config.yaml similarity index 100% rename from tests/test_apps/run_as_module/2/conf/config.yaml rename to tests/test_apps/run_as_module_2/conf/config.yaml diff --git a/tests/test_apps/run_as_module/2/my_app.py b/tests/test_apps/run_as_module_2/my_app.py similarity index 76% rename from tests/test_apps/run_as_module/2/my_app.py rename to tests/test_apps/run_as_module_2/my_app.py index 6b1303d511f..7585728f215 100644 --- a/tests/test_apps/run_as_module/2/my_app.py +++ b/tests/test_apps/run_as_module_2/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/run_as_module_3/module/__init__.py b/tests/test_apps/run_as_module_3/module/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tests/test_apps/run_as_module_3/module/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tests/test_apps/run_as_module_3/module/conf/__init__.py b/tests/test_apps/run_as_module_3/module/conf/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tests/test_apps/run_as_module_3/module/conf/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tests/test_apps/run_as_module/3/module/config.yaml b/tests/test_apps/run_as_module_3/module/config.yaml similarity index 100% rename from tests/test_apps/run_as_module/3/module/config.yaml rename to tests/test_apps/run_as_module_3/module/config.yaml diff --git a/tests/test_apps/run_as_module/1/my_app.py b/tests/test_apps/run_as_module_3/module/my_app.py similarity index 82% rename from tests/test_apps/run_as_module/1/my_app.py rename to tests/test_apps/run_as_module_3/module/my_app.py index 07916d55b34..1de2eb6bbb0 100644 --- a/tests/test_apps/run_as_module/1/my_app.py +++ b/tests/test_apps/run_as_module_3/module/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/run_as_module_4/module/__init__.py b/tests/test_apps/run_as_module_4/module/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tests/test_apps/run_as_module_4/module/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tests/test_apps/run_as_module_4/module/conf/__init__.py b/tests/test_apps/run_as_module_4/module/conf/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tests/test_apps/run_as_module_4/module/conf/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tests/test_apps/run_as_module/4/module/conf/config.yaml b/tests/test_apps/run_as_module_4/module/conf/config.yaml similarity index 100% rename from tests/test_apps/run_as_module/4/module/conf/config.yaml rename to tests/test_apps/run_as_module_4/module/conf/config.yaml diff --git a/tests/test_apps/run_as_module/3/module/my_app.py b/tests/test_apps/run_as_module_4/module/my_app.py similarity index 76% rename from tests/test_apps/run_as_module/3/module/my_app.py rename to tests/test_apps/run_as_module_4/module/my_app.py index 681400921da..7585728f215 100644 --- a/tests/test_apps/run_as_module/3/module/my_app.py +++ b/tests/test_apps/run_as_module_4/module/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/run_dir_test/my_app.py b/tests/test_apps/run_dir_test/my_app.py index 390005b708b..9c4ac4d64ec 100644 --- a/tests/test_apps/run_dir_test/my_app.py +++ b/tests/test_apps/run_dir_test/my_app.py @@ -9,7 +9,7 @@ from hydra.utils import get_original_cwd -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(_: DictConfig) -> None: run_dir = str(Path.cwd().relative_to(get_original_cwd())) time.sleep(2) diff --git a/tests/test_apps/schema_overrides_hydra/__init__.py b/tests/test_apps/schema_overrides_hydra/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tests/test_apps/schema_overrides_hydra/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tests/test_apps/schema-overrides-hydra/config.yaml b/tests/test_apps/schema_overrides_hydra/config.yaml similarity index 100% rename from tests/test_apps/schema-overrides-hydra/config.yaml rename to tests/test_apps/schema_overrides_hydra/config.yaml diff --git a/tests/test_apps/schema-overrides-hydra/group/a.yaml b/tests/test_apps/schema_overrides_hydra/group/a.yaml similarity index 100% rename from tests/test_apps/schema-overrides-hydra/group/a.yaml rename to tests/test_apps/schema_overrides_hydra/group/a.yaml diff --git a/tests/test_apps/schema-overrides-hydra/my_app.py b/tests/test_apps/schema_overrides_hydra/my_app.py similarity index 90% rename from tests/test_apps/schema-overrides-hydra/my_app.py rename to tests/test_apps/schema_overrides_hydra/my_app.py index ec738adbea8..669b09be92b 100644 --- a/tests/test_apps/schema-overrides-hydra/my_app.py +++ b/tests/test_apps/schema_overrides_hydra/my_app.py @@ -19,7 +19,7 @@ class Config: ConfigStore.instance().store(name="config_schema", node=Config) -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: Config) -> None: print( f"job_name: {HydraConfig().get().job.name}, " diff --git a/tests/test_apps/simple_app/my_app.py b/tests/test_apps/simple_app/my_app.py index 378e48a178d..e76efd68962 100644 --- a/tests/test_apps/simple_app/my_app.py +++ b/tests/test_apps/simple_app/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg, resolve=True)) diff --git a/tests/test_apps/simple_interpolation/my_app.py b/tests/test_apps/simple_interpolation/my_app.py index 90aa9789fa7..258bab59164 100644 --- a/tests/test_apps/simple_interpolation/my_app.py +++ b/tests/test_apps/simple_interpolation/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: print(cfg.c) diff --git a/tests/test_apps/structured_with_none_list/my_app.py b/tests/test_apps/structured_with_none_list/my_app.py index 4b6188d129e..fe3d7a3cc41 100644 --- a/tests/test_apps/structured_with_none_list/my_app.py +++ b/tests/test_apps/structured_with_none_list/my_app.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from typing import List, Optional +from omegaconf import DictConfig + import hydra from hydra.core.config_store import ConfigStore @@ -15,8 +17,8 @@ class Config: cs.store(name="config", node=Config) -@hydra.main(config_path=None, config_name="config") -def main(cfg): +@hydra.main(version_base=None, config_name="config") +def main(cfg: DictConfig) -> None: print(cfg) diff --git a/tests/test_apps/sweep_complex_defaults/my_app.py b/tests/test_apps/sweep_complex_defaults/my_app.py index 6b1303d511f..7585728f215 100644 --- a/tests/test_apps/sweep_complex_defaults/my_app.py +++ b/tests/test_apps/sweep_complex_defaults/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_apps/sys_exit/my_app.py b/tests/test_apps/sys_exit/my_app.py index 8bcb45ef1bc..95acab0a2f4 100644 --- a/tests/test_apps/sys_exit/my_app.py +++ b/tests/test_apps/sys_exit/my_app.py @@ -6,7 +6,7 @@ import hydra -@hydra.main(config_path=None) +@hydra.main(version_base=None) def my_app(_: DictConfig) -> None: sys.exit(42) diff --git a/tests/test_apps/user-config-dir/my_app.py b/tests/test_apps/user-config-dir/my_app.py index 6b1303d511f..7585728f215 100644 --- a/tests/test_apps/user-config-dir/my_app.py +++ b/tests/test_apps/user-config-dir/my_app.py @@ -4,7 +4,7 @@ import hydra -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/tests/test_basic_sweeper.py b/tests/test_basic_sweeper.py index 849aedfdcb2..2db8c1e549a 100644 --- a/tests/test_basic_sweeper.py +++ b/tests/test_basic_sweeper.py @@ -70,6 +70,7 @@ def test_partial_failure( "--multirun", "+divisor=1,0", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "hydra.hydra_logging.formatters.simple.format='[HYDRA] %(message)s'", ] out, err = run_process(cmd=cmd, print_error=False, raise_exception=False) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index e64b503690a..4170137841b 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,13 +1,20 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import copy +import os +import pickle +import sys from pathlib import Path from textwrap import dedent -from typing import List +from typing import Any, List +from omegaconf import open_dict, read_write from pytest import mark +from hydra.core.utils import JobReturn, JobStatus from hydra.test_utils.test_utils import ( assert_regex_match, chdir_hydra_root, + run_process, run_python_script, ) @@ -81,6 +88,7 @@ def test_app_with_callbacks( cmd = [ app_path, "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "hydra.hydra_logging.formatters.simple.format='[HYDRA] %(message)s'", "hydra.job_logging.formatters.simple.format='[JOB] %(message)s'", ] @@ -93,3 +101,126 @@ def test_app_with_callbacks( from_name="Expected output", to_name="Actual output", ) + + +@mark.parametrize("multirun", [True, False]) +def test_experimental_save_job_info_callback(tmpdir: Path, multirun: bool) -> None: + app_path = "tests/test_apps/app_with_pickle_job_info_callback/my_app.py" + + cmd = [ + app_path, + "hydra.run.dir=" + str(tmpdir), + "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", + ] + if multirun: + cmd.append("-m") + _, _err = run_python_script(cmd) + + def load_pickle(path: Path) -> Any: + with open(str(path), "rb") as input: + obj = pickle.load(input) # nosec + return obj + + # load pickles from callbacks + callback_output = tmpdir / Path("0") / ".hydra" if multirun else tmpdir / ".hydra" + config_on_job_start = load_pickle(callback_output / "config.pickle") + job_return_on_job_end: JobReturn = load_pickle( + callback_output / "job_return.pickle" + ) + + task_cfg_from_callback = copy.deepcopy(config_on_job_start) + with read_write(task_cfg_from_callback): + with open_dict(task_cfg_from_callback): + del task_cfg_from_callback["hydra"] + + # load pickles generated from the application + app_output_dir = tmpdir / "0" if multirun else tmpdir + task_cfg_from_app = load_pickle(app_output_dir / "task_cfg.pickle") + hydra_cfg_from_app = load_pickle(app_output_dir / "hydra_cfg.pickle") + + # verify the cfg pickles are the same on_job_start + assert task_cfg_from_callback == task_cfg_from_app + assert config_on_job_start.hydra == hydra_cfg_from_app + + # verify pickled object are the same on_job_end + assert job_return_on_job_end.cfg == task_cfg_from_app + assert job_return_on_job_end.hydra_cfg.hydra == hydra_cfg_from_app # type: ignore + assert job_return_on_job_end.return_value == "hello world" + assert job_return_on_job_end.status == JobStatus.COMPLETED + + +@mark.parametrize("multirun", [True, False]) +def test_save_job_return_callback(tmpdir: Path, multirun: bool) -> None: + app_path = "tests/test_apps/app_with_log_jobreturn_callback/my_app.py" + cmd = [ + sys.executable, + app_path, + "hydra.run.dir=" + str(tmpdir), + "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", + ] + if multirun: + extra = ["+x=0,1", "-m"] + cmd.extend(extra) + log_msg = "omegaconf.errors.ConfigAttributeError: Key 'divisor' is not in struct\n" + run_process(cmd=cmd, print_error=False, raise_exception=False) + + if multirun: + log_paths = [tmpdir / "0" / "my_app.log", tmpdir / "1" / "my_app.log"] + else: + log_paths = [tmpdir / "my_app.log"] + + for p in log_paths: + with open(p, "r") as file: + logs = file.readlines() + assert log_msg in logs + + +@mark.parametrize( + "warning_msg,overrides", + [ + ("Experimental rerun CLI option", []), + ("Config overrides are not supported as of now", ["+x=1"]), + ], +) +def test_experimental_rerun( + tmpdir: Path, warning_msg: str, overrides: List[str] +) -> None: + app_path = "tests/test_apps/app_with_pickle_job_info_callback/my_app.py" + + cmd = [ + app_path, + "hydra.run.dir=" + str(tmpdir), + "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=False", + "hydra.hydra_logging.formatters.simple.format='[HYDRA] %(message)s'", + "hydra.job_logging.formatters.simple.format='[JOB] %(message)s'", + ] + run_python_script(cmd) + + config_file = tmpdir / ".hydra" / "config.pickle" + log_file = tmpdir / "my_app.log" + assert config_file.exists() + assert log_file.exists() + + with open(log_file, "r") as file: + logs = file.read().splitlines() + assert "[JOB] Running my_app" in logs + + os.remove(str(log_file)) + assert not log_file.exists() + + # then rerun the application and verify log file is created again + cmd = [ + app_path, + "--experimental-rerun", + str(config_file), + ] + cmd.extend(overrides) + result, err = run_python_script(cmd, allow_warnings=True) + assert warning_msg in err + + with open(log_file, "r") as file: + logs = file.read().splitlines() + assert "[JOB] Running my_app" in logs diff --git a/tests/test_completion.py b/tests/test_completion.py index fa1613893fe..f9aee44617d 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -8,7 +8,7 @@ from typing import List from packaging import version -from pytest import mark, param, skip +from pytest import mark, param, skip, xfail from hydra._internal.config_loader_impl import ConfigLoaderImpl from hydra._internal.core_plugins.bash_completion import BashCompletion @@ -157,6 +157,46 @@ def test_bash_completion_with_dot_in_path() -> None: param("group=dict group.dict=", 2, ["group.dict=true"], id="group"), param("group=dict group=", 2, ["group=dict", "group=list"], id="group"), param("group=dict group=", 2, ["group=dict", "group=list"], id="group"), + param("+", 2, ["+group=", "+hydra", "+test_hydra/"], id="bare_plus"), + param("+gro", 2, ["+group="], id="append_group_partial"), + param("+group=di", 2, ["+group=dict"], id="append_group_partial_option"), + param("+group=", 2, ["+group=dict", "+group=list"], id="group_append"), + param( + "group=dict +group=", 2, ["+group=dict", "+group=list"], id="group_append2" + ), + param("~gro", 2, ["~group"], id="delete_group_partial"), + param("~group=di", 2, ["~group=dict"], id="delete_group_partial_option"), + param("~group=", 2, ["~group=dict", "~group=list"], id="group_delete"), + param( + "group=dict ~group=", 2, ["~group=dict", "~group=list"], id="group_delete2" + ), + param("~", 2, ["~group", "~hydra", "~test_hydra/"], id="bare_tilde"), + param("+test_hydra/lau", 2, ["+test_hydra/launcher="], id="nested_plus"), + param( + "+test_hydra/launcher=", + 2, + ["+test_hydra/launcher=fairtask"], + id="nested_plus_equal", + ), + param( + "+test_hydra/launcher=fa", + 2, + ["+test_hydra/launcher=fairtask"], + id="nested_plus_equal_partial", + ), + param("~test_hydra/lau", 2, ["~test_hydra/launcher"], id="nested_tilde"), + param( + "~test_hydra/launcher=", + 2, + ["~test_hydra/launcher=fairtask"], + id="nested_tilde_equal", + ), + param( + "~test_hydra/launcher=fa", + 2, + ["~test_hydra/launcher=fairtask"], + id="nested_tilde_equal_partial", + ), ], ) class TestRunCompletion: @@ -197,6 +237,10 @@ def test_shell_integration( skip("fish is not installed or the version is too old") if shell == "zsh" and not is_zsh_supported(): skip("zsh is not installed or the version is too old") + if shell in ("zsh", "fish") and any( + word.startswith("~") for word in line.split(" ") + ): + xfail(f"{shell} treats words prefixed by the tilde symbol specially") # verify expect will be running the correct Python. # This preemptively detect a much harder to understand error from expect. diff --git a/tests/test_compose.py b/tests/test_compose.py index fe391a3136f..999e466d2db 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -3,6 +3,7 @@ import subprocess import sys from dataclasses import dataclass, field +from enum import Enum from pathlib import Path from textwrap import dedent from typing import Any, Dict, List, Optional @@ -10,7 +11,14 @@ from omegaconf import MISSING, OmegaConf from pytest import fixture, mark, param, raises, warns -from hydra import compose, initialize, initialize_config_dir, initialize_config_module +from hydra import ( + __version__, + compose, + initialize, + initialize_config_dir, + initialize_config_module, + version, +) from hydra._internal.config_search_path_impl import ConfigSearchPathImpl from hydra.core.config_search_path import SearchPathQuery from hydra.core.config_store import ConfigStore @@ -24,7 +32,7 @@ @fixture def initialize_hydra(config_path: Optional[str]) -> Any: try: - init = initialize(config_path=config_path) + init = initialize(version_base=None, config_path=config_path) init.__enter__() yield finally: @@ -34,7 +42,7 @@ def initialize_hydra(config_path: Optional[str]) -> Any: @fixture def initialize_hydra_no_path() -> Any: try: - init = initialize(config_path=None) + init = initialize(version_base=None) init.__enter__() yield finally: @@ -43,13 +51,54 @@ def initialize_hydra_no_path() -> Any: def test_initialize(hydra_restore_singletons: Any) -> None: assert not GlobalHydra().is_initialized() - initialize(config_path=None) + initialize(version_base=None) assert GlobalHydra().is_initialized() +def test_initialize_old_version_base(hydra_restore_singletons: Any) -> None: + assert not GlobalHydra().is_initialized() + with raises( + HydraException, + match=f'version_base must be >= "{version.__compat_version__}"', + ): + initialize(version_base="1.0") + + +def test_initialize_bad_version_base(hydra_restore_singletons: Any) -> None: + assert not GlobalHydra().is_initialized() + with raises( + TypeError, + match="expected string or bytes-like object", + ): + initialize(version_base=1.1) # type: ignore + + +def test_initialize_dev_version_base(hydra_restore_singletons: Any) -> None: + assert not GlobalHydra().is_initialized() + # packaging will compare "1.2.0.dev2" < "1.2", so need to ensure handled correctly + initialize(version_base="1.2.0.dev2") + assert version.base_at_least("1.2") + + +def test_initialize_cur_version_base(hydra_restore_singletons: Any) -> None: + assert not GlobalHydra().is_initialized() + initialize(version_base=None) + assert version.base_at_least(__version__) + + +def test_initialize_compat_version_base(hydra_restore_singletons: Any) -> None: + assert not GlobalHydra().is_initialized() + with raises( + UserWarning, + match=f"Will assume defaults for version {version.__compat_version__}", + ): + initialize() + assert version.base_at_least(str(version.__compat_version__)) + + def test_initialize_with_config_path(hydra_restore_singletons: Any) -> None: assert not GlobalHydra().is_initialized() - initialize(config_path="../hydra/test_utils/configs") + initialize(version_base=None, config_path="../hydra/test_utils/configs") assert GlobalHydra().is_initialized() gh = GlobalHydra.instance() @@ -181,7 +230,10 @@ class TestComposeInits: def test_initialize_ctx( self, config_file: str, overrides: List[str], expected: Any ) -> None: - with initialize(config_path="../examples/jupyter_notebooks/cloud_app/conf"): + with initialize( + version_base=None, + config_path="../examples/jupyter_notebooks/cloud_app/conf", + ): ret = compose(config_file, overrides) assert ret == expected @@ -196,6 +248,7 @@ def test_initialize_config_dir_ctx_with_relative_dir( ): with initialize_config_dir( config_dir="../examples/jupyter_notebooks/cloud_app/conf", + version_base=None, job_name="job_name", ): ret = compose(config_file, overrides) @@ -206,6 +259,7 @@ def test_initialize_config_module_ctx( ) -> None: with initialize_config_module( config_module="examples.jupyter_notebooks.cloud_app.conf", + version_base=None, job_name="job_name", ): ret = compose(config_file, overrides) @@ -218,7 +272,7 @@ def test_initialize_ctx_with_absolute_dir( with raises( HydraException, match=re.escape("config_path in initialize() must be relative") ): - with initialize(config_path=str(tmpdir)): + with initialize(version_base=None, config_path=str(tmpdir)): compose(overrides=["+test_group=test"]) @@ -233,7 +287,10 @@ def test_initialize_config_dir_ctx_with_absolute_dir( with open(str(cfg_file), "w") as f: OmegaConf.save(cfg, f) - with initialize_config_dir(config_dir=str(tmpdir)): + with initialize_config_dir( + config_dir=str(tmpdir), + version_base=None, + ): ret = compose(overrides=["+test_group=test"]) assert ret == {"test_group": cfg} @@ -245,7 +302,9 @@ def test_jobname_override_initialize_ctx( hydra_restore_singletons: Any, job_name: Optional[str], expected: str ) -> None: with initialize( - config_path="../examples/jupyter_notebooks/cloud_app/conf", job_name=job_name + version_base=None, + config_path="../examples/jupyter_notebooks/cloud_app/conf", + job_name=job_name, ): ret = compose(return_hydra_config=True) assert ret.hydra.job.name == expected @@ -254,26 +313,33 @@ def test_jobname_override_initialize_ctx( def test_jobname_override_initialize_config_dir_ctx( hydra_restore_singletons: Any, tmpdir: Any ) -> None: - with initialize_config_dir(config_dir=str(tmpdir), job_name="test_job"): + with initialize_config_dir( + config_dir=str(tmpdir), version_base=None, job_name="test_job" + ): ret = compose(return_hydra_config=True) assert ret.hydra.job.name == "test_job" def test_initialize_config_module_ctx(hydra_restore_singletons: Any) -> None: with initialize_config_module( - config_module="examples.jupyter_notebooks.cloud_app.conf" + config_module="examples.jupyter_notebooks.cloud_app.conf", + version_base=None, ): ret = compose(return_hydra_config=True) assert ret.hydra.job.name == "app" with initialize_config_module( - config_module="examples.jupyter_notebooks.cloud_app.conf", job_name="test_job" + config_module="examples.jupyter_notebooks.cloud_app.conf", + job_name="test_job", + version_base=None, ): ret = compose(return_hydra_config=True) assert ret.hydra.job.name == "test_job" with initialize_config_module( - config_module="examples.jupyter_notebooks.cloud_app.conf", job_name="test_job" + config_module="examples.jupyter_notebooks.cloud_app.conf", + job_name="test_job", + version_base=None, ): ret = compose(return_hydra_config=True) assert ret.hydra.job.name == "test_job" @@ -287,7 +353,8 @@ def test_missing_init_py_error(hydra_restore_singletons: Any) -> None: with raises(Exception, match=re.escape(expected)): with initialize_config_module( - config_module="hydra.test_utils.configs.missing_init_py" + config_module="hydra.test_utils.configs.missing_init_py", + version_base=None, ): hydra = GlobalHydra.instance().hydra assert hydra is not None @@ -301,7 +368,10 @@ def test_missing_bad_config_dir_error(hydra_restore_singletons: Any) -> None: ) with raises(Exception, match=re.escape(expected)): - with initialize_config_dir(config_dir="/no_way_in_hell_1234567890"): + with initialize_config_dir( + config_dir="/no_way_in_hell_1234567890", + version_base=None, + ): hydra = GlobalHydra.instance().hydra assert hydra is not None compose(config_name="test.yaml", overrides=[]) @@ -309,7 +379,9 @@ def test_missing_bad_config_dir_error(hydra_restore_singletons: Any) -> None: def test_initialize_with_module(hydra_restore_singletons: Any) -> None: with initialize_config_module( - config_module="tests.test_apps.app_with_cfg_groups.conf", job_name="my_pp" + config_module="tests.test_apps.app_with_cfg_groups.conf", + job_name="my_pp", + version_base=None, ): assert compose(config_name="config") == { "optimizer": {"type": "nesterov", "lr": 0.001} @@ -317,7 +389,9 @@ def test_initialize_with_module(hydra_restore_singletons: Any) -> None: def test_hydra_main_passthrough(hydra_restore_singletons: Any) -> None: - with initialize(config_path="test_apps/app_with_cfg_groups/conf"): + with initialize( + version_base=None, config_path="test_apps/app_with_cfg_groups/conf" + ): from tests.test_apps.app_with_cfg_groups.my_app import my_app # type: ignore cfg = compose(config_name="config", overrides=["optimizer.lr=1.0"]) @@ -568,7 +642,7 @@ def test_deprecated_compose() -> None: from hydra import initialize from hydra.experimental import compose as expr_compose - with initialize(config_path=None): + with initialize(version_base=None): with warns( expected_warning=UserWarning, match=re.escape( @@ -600,7 +674,9 @@ def test_deprecated_initialize_config_dir() -> None: "hydra.experimental.initialize_config_dir() is no longer experimental. Use hydra.initialize_config_dir()" ), ): - with expr_initialize_config_dir(config_dir=str(Path(".").absolute())): + with expr_initialize_config_dir( + config_dir=str(Path(".").absolute()), + ): assert compose() == {} @@ -617,7 +693,7 @@ def test_deprecated_initialize_config_module() -> None: ), ): with expr_initialize_config_module( - config_module="examples.jupyter_notebooks.cloud_app.conf" + config_module="examples.jupyter_notebooks.cloud_app.conf", ): assert compose() == {} @@ -670,16 +746,22 @@ def test_deprecated_compose_strict_flag(strict: bool) -> None: msg = dedent( """\ - The strict flag in the compose API is deprecated and will be removed in the next version of Hydra. + The strict flag in the compose API is deprecated. See https://hydra.cc/docs/upgrades/0.11_to_1.0/strict_mode_flag_deprecated for more info. """ ) + curr_base = version.getbase() + version.setbase("1.1") + with warns( expected_warning=UserWarning, match=re.escape(msg), ): cfg = compose(overrides=[], strict=strict) + + version.setbase(str(curr_base)) + assert cfg == {} assert OmegaConf.is_struct(cfg) is strict @@ -703,3 +785,23 @@ class Trainer: cfg = compose("trainer/base_trainer") assert cfg == {"trainer": {"reducer": {}}} + + +@mark.usefixtures("initialize_hydra_no_path") +def test_enum_with_removed_defaults_list(hydra_restore_singletons: Any) -> None: + class Category(Enum): + X = 0 + Y = 1 + Z = 2 + + @dataclass + class Conf: + enum_dict: Dict[Category, str] = field(default_factory=dict) + int_dict: Dict[int, str] = field(default_factory=dict) + str_dict: Dict[str, str] = field(default_factory=dict) + + cs = ConfigStore.instance() + cs.store(name="conf", node=Conf) + + cfg = compose("conf") + assert cfg == {"enum_dict": {}, "int_dict": {}, "str_dict": {}} diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 38b273fb912..11aaf5b361d 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -7,6 +7,7 @@ from omegaconf import MISSING, OmegaConf, ValidationError, open_dict from pytest import mark, param, raises, warns +from hydra import version from hydra._internal.config_loader_impl import ConfigLoaderImpl from hydra._internal.utils import create_config_search_path from hydra.core.config_store import ConfigStore, ConfigStoreWithProvider @@ -189,6 +190,8 @@ def test_load_yml_file(self, path: str) -> None: config_loader = ConfigLoaderImpl( config_search_path=create_config_search_path(path) ) + curr_base = version.getbase() + version.setbase("1.1") with warns( UserWarning, match="Support for .yml files is deprecated. Use .yaml extension for Hydra config files", @@ -198,6 +201,7 @@ def test_load_yml_file(self, path: str) -> None: overrides=[], run_mode=RunMode.RUN, ) + version.setbase(str(curr_base)) with open_dict(cfg): del cfg["hydra"] @@ -326,11 +330,11 @@ def test_load_config_with_validation_error( msg = dedent( """\ In 'schema_validation_error': ValidationError raised while composing config: - Value 'not_an_int' could not be converted to Integer + Value 'not_an_int'( of type 'str')? could not be converted to Integer full_key: port object_type=MySQLConfig""" ) - with raises(ConfigCompositionException, match=re.escape(msg)): + with raises(ConfigCompositionException, match=msg): config_loader.load_configuration( config_name="schema_validation_error", overrides=[], diff --git a/tests/test_env_defaults.py b/tests/test_env_defaults.py index 5d1d20d9ade..c3b78680550 100644 --- a/tests/test_env_defaults.py +++ b/tests/test_env_defaults.py @@ -11,5 +11,6 @@ def test_env_defaults(tmpdir: Path) -> None: cmd = [ "tests/test_apps/custom_env_defaults/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] run_python_script(cmd) diff --git a/tests/test_examples/test_advanced_config_search_path.py b/tests/test_examples/test_advanced_config_search_path.py index aaad3636fb2..64d4508f329 100644 --- a/tests/test_examples/test_advanced_config_search_path.py +++ b/tests/test_examples/test_advanced_config_search_path.py @@ -37,6 +37,7 @@ def test_config_search_path( cmd = [ "examples/advanced/config_search_path/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] cmd.extend(args) if error is not None: diff --git a/tests/test_examples/test_advanced_package_overrides.py b/tests/test_examples/test_advanced_package_overrides.py index 787bb6ff8f8..5c393704eb1 100644 --- a/tests/test_examples/test_advanced_package_overrides.py +++ b/tests/test_examples/test_advanced_package_overrides.py @@ -12,6 +12,7 @@ def test_advanced_package_override_simple(tmpdir: Path) -> None: cmd = [ "examples/advanced/package_overrides/simple.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert OmegaConf.create(result) == { @@ -23,6 +24,7 @@ def test_advanced_package_override_two_packages(tmpdir: Path) -> None: cmd = [ "examples/advanced/package_overrides/two_packages.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert OmegaConf.create(result) == { diff --git a/tests/test_examples/test_basic_sweep.py b/tests/test_examples/test_basic_sweep.py new file mode 100644 index 00000000000..899412d477a --- /dev/null +++ b/tests/test_examples/test_basic_sweep.py @@ -0,0 +1,86 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from pathlib import Path +from textwrap import dedent +from typing import List + +from pytest import mark + +from hydra.test_utils.test_utils import ( + assert_regex_match, + chdir_hydra_root, + run_python_script, +) + +chdir_hydra_root() + + +@mark.parametrize( + "args,expected", + [ + ( + [], + dedent( + """\ + [HYDRA] Launching 4 jobs locally + [HYDRA] \t#0 : db=mysql db.timeout=5 + driver=mysql, timeout=5 + [HYDRA] \t#1 : db=mysql db.timeout=10 + driver=mysql, timeout=10 + [HYDRA] \t#2 : db=postgresql db.timeout=5 + driver=postgresql, timeout=5 + [HYDRA] \t#3 : db=postgresql db.timeout=10 + driver=postgresql, timeout=10""" + ), + ), + ( + ["db=glob([m*],exclude=postgresql)"], + dedent( + """\ + [HYDRA] Launching 2 jobs locally + [HYDRA] \t#0 : db=mysql db.timeout=5 + driver=mysql, timeout=5 + [HYDRA] \t#1 : db=mysql db.timeout=10 + driver=mysql, timeout=10""" + ), + ), + ( + ["db=mysql", "db.user=choice(one,two)"], + dedent( + """\ + [HYDRA] Launching 4 jobs locally + [HYDRA] \t#0 : db=mysql db.timeout=5 db.user=one + driver=mysql, timeout=5 + [HYDRA] \t#1 : db=mysql db.timeout=5 db.user=two + driver=mysql, timeout=5 + [HYDRA] \t#2 : db=mysql db.timeout=10 db.user=one + driver=mysql, timeout=10 + [HYDRA] \t#3 : db=mysql db.timeout=10 db.user=two + driver=mysql, timeout=10""" + ), + ), + ], +) +def test_basic_sweep_example( + tmpdir: Path, + args: List[str], + expected: str, +) -> None: + app_path = "examples/tutorials/basic/running_your_hydra_app/5_basic_sweep/my_app.py" + + cmd = [ + app_path, + "--multirun", + "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", + "hydra.hydra_logging.formatters.simple.format='[HYDRA] %(message)s'", + "hydra.job_logging.formatters.simple.format='[JOB] %(message)s'", + ] + cmd.extend(args) + result, _err = run_python_script(cmd) + + assert_regex_match( + from_line=expected, + to_line=result, + from_name="Expected output", + to_name="Actual output", + ) diff --git a/tests/test_examples/test_configure_hydra.py b/tests/test_examples/test_configure_hydra.py index ff357dc2c49..27a3e35fe4b 100644 --- a/tests/test_examples/test_configure_hydra.py +++ b/tests/test_examples/test_configure_hydra.py @@ -17,6 +17,7 @@ def test_custom_help(tmpdir: Path) -> None: [ "examples/configure_hydra/custom_help/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "--help", ] ) @@ -55,6 +56,7 @@ def test_job_name_no_config_override(tmpdir: Path) -> None: cmd = [ "examples/configure_hydra/job_name/no_config_file_override.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert result == "no_config_file_override" @@ -64,6 +66,7 @@ def test_job_name_with_config_override(tmpdir: Path) -> None: cmd = [ "examples/configure_hydra/job_name/with_config_file_override.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert result == "name_from_config_file" @@ -73,6 +76,7 @@ def test_job_override_dirname(tmpdir: Path) -> None: cmd = [ "examples/configure_hydra/job_override_dirname/my_app.py", "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", "learning_rate=0.1,0.01", "batch_size=32", "seed=999", @@ -87,6 +91,7 @@ def test_logging(tmpdir: Path) -> None: cmd = [ "examples/configure_hydra/logging/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert result == "[INFO] - Info level message" @@ -96,6 +101,7 @@ def test_disabling_logging(tmpdir: Path) -> None: cmd = [ "examples/configure_hydra/logging/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "hydra/job_logging=none", "hydra/hydra_logging=none", ] @@ -106,11 +112,11 @@ def test_disabling_logging(tmpdir: Path) -> None: def test_workdir_config(monkeypatch: Any, tmpdir: Path) -> None: script = str(Path("examples/configure_hydra/workdir/my_app.py").absolute()) monkeypatch.chdir(tmpdir) - result, _err = run_python_script([script]) + result, _err = run_python_script([script, "hydra.job.chdir=True"]) assert Path(result) == Path(tmpdir) / "run_dir" result, _err = run_python_script( - [script, "--multirun", "hydra/hydra_logging=disabled"] + [script, "--multirun", "hydra/hydra_logging=disabled", "hydra.job.chdir=True"] ) assert Path(result) == Path(tmpdir) / "sweep_dir" / "0" @@ -118,5 +124,11 @@ def test_workdir_config(monkeypatch: Any, tmpdir: Path) -> None: def test_workdir_override(monkeypatch: Any, tmpdir: Path) -> None: script = str(Path("examples/configure_hydra/workdir/my_app.py").absolute()) monkeypatch.chdir(tmpdir) - result, _err = run_python_script([script, "hydra.run.dir=blah"]) + result, _err = run_python_script( + [ + script, + "hydra.run.dir=blah", + "hydra.job.chdir=True", + ] + ) assert Path(result) == Path(tmpdir) / "blah" diff --git a/tests/test_examples/test_experimental.py b/tests/test_examples/test_experimental.py new file mode 100644 index 00000000000..f3f903d4556 --- /dev/null +++ b/tests/test_examples/test_experimental.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from pathlib import Path + +from hydra.test_utils.test_utils import run_python_script + + +def test_rerun(tmpdir: Path) -> None: + cmd = [ + "examples/experimental/rerun/my_app.py", + "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", + "hydra.hydra_logging.formatters.simple.format='[HYDRA] %(message)s'", + "hydra.job_logging.formatters.simple.format='[JOB] %(message)s'", + ] + + result, _err = run_python_script(cmd) + assert "[JOB] cfg.foo=bar" in result diff --git a/tests/test_examples/test_instantiate_examples.py b/tests/test_examples/test_instantiate_examples.py index a6074209876..324bb0b425c 100644 --- a/tests/test_examples/test_instantiate_examples.py +++ b/tests/test_examples/test_instantiate_examples.py @@ -25,6 +25,7 @@ def test_instantiate_object(tmpdir: Path, overrides: List[str], output: str) -> cmd = [ "examples/instantiate/object/my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] + overrides result, _err = run_python_script(cmd) assert result == output @@ -46,11 +47,22 @@ def test_instantiate_object_recursive( cmd = [ "examples/instantiate/object_recursive/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] + overrides result, _err = run_python_script(cmd) assert result == output +def test_instantiate_object_partial(tmpdir: Path) -> None: + cmd = [ + "examples/instantiate/partial/my_app.py", + "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", + ] + result, _err = run_python_script(cmd) + assert result == "Model(Optimizer=Optimizer(algo=SGD,lr=0.1))" + + @mark.parametrize( "overrides,output", [ @@ -62,6 +74,7 @@ def test_instantiate_schema(tmpdir: Path, overrides: List[str], output: str) -> cmd = [ "examples/instantiate/schema/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] + overrides result, _err = run_python_script(cmd) assert result == output @@ -90,6 +103,7 @@ def test_instantiate_schema_recursive( cmd = [ "examples/instantiate/schema_recursive/my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] + overrides result, _err = run_python_script(cmd) assert_text_same(result, expected) @@ -128,6 +142,7 @@ def test_instantiate_docs_example( cmd = [ "examples/instantiate/docs_example/my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] + overrides result, _err = run_python_script(cmd) assert_text_same(result, expected) diff --git a/tests/test_examples/test_patterns.py b/tests/test_examples/test_patterns.py index 3146444053b..5bb60f23af1 100644 --- a/tests/test_examples/test_patterns.py +++ b/tests/test_examples/test_patterns.py @@ -40,6 +40,7 @@ def test_write_protect_config_node(tmpdir: Any) -> None: cmd = [ "examples/patterns/write_protect_config_node/frozen.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "data_bits=10", ] @@ -68,7 +69,11 @@ def test_extending_configs( monkeypatch: Any, tmpdir: Path, overrides: List[str] ) -> None: monkeypatch.chdir("examples/patterns/extending_configs") - cmd = ["my_app.py", "hydra.run.dir=" + str(tmpdir)] + overrides + cmd = [ + "my_app.py", + "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", + ] + overrides result, _err = run_python_script(cmd) assert OmegaConf.create(result) == { "db": { @@ -105,7 +110,11 @@ def test_configuring_experiments( monkeypatch: Any, tmpdir: Path, overrides: List[str], expected: Any ) -> None: monkeypatch.chdir("examples/patterns/configuring_experiments") - cmd = ["my_app.py", "hydra.run.dir=" + str(tmpdir)] + overrides + cmd = [ + "my_app.py", + "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", + ] + overrides result, _err = run_python_script(cmd) assert OmegaConf.create(result) == expected @@ -172,6 +181,10 @@ def test_multi_select( monkeypatch: Any, tmpdir: Path, overrides: List[str], expected: Any ) -> None: monkeypatch.chdir("examples/patterns/multi-select") - cmd = ["my_app.py", "hydra.run.dir=" + str(tmpdir)] + overrides + cmd = [ + "my_app.py", + "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", + ] + overrides result, _err = run_python_script(cmd) assert OmegaConf.create(result) == expected diff --git a/tests/test_examples/test_structured_configs_tutorial.py b/tests/test_examples/test_structured_configs_tutorial.py index 3cadc8b9d7d..19022d2d76b 100644 --- a/tests/test_examples/test_structured_configs_tutorial.py +++ b/tests/test_examples/test_structured_configs_tutorial.py @@ -21,6 +21,7 @@ def test_1_basic_run(tmpdir: Path) -> None: [ "examples/tutorials/structured_configs/1_minimal/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] ) assert result == "Host: localhost, port: 3306" @@ -37,6 +38,7 @@ def test_1_basic_run_with_override_error(tmpdir: Path) -> None: [ "examples/tutorials/structured_configs/1_minimal/my_app_type_error.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] ) assert re.search(re.escape(expected), err) is not None @@ -47,6 +49,7 @@ def test_1_basic_override(tmpdir: Path) -> None: [ "examples/tutorials/structured_configs/1_minimal/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "port=9090", ] ) @@ -57,18 +60,19 @@ def test_1_basic_override_type_error(tmpdir: Path) -> None: cmd = [ "examples/tutorials/structured_configs/1_minimal/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "port=foo", ] expected = dedent( """\ - Value 'foo' could not be converted to Integer + Value 'foo'( of type 'str')? could not be converted to Integer full_key: port object_type=MySQLConfig""" ) err = run_with_error(cmd) - assert re.search(re.escape(expected), err) is not None + assert re.search(expected, err) is not None def test_2_static_complex(tmpdir: Path) -> None: @@ -76,6 +80,7 @@ def test_2_static_complex(tmpdir: Path) -> None: [ "examples/tutorials/structured_configs/2_static_complex/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] ) assert result == "Title=My app, size=1024x768 pixels" @@ -95,6 +100,7 @@ def test_3_config_groups(tmpdir: Path, overrides: Any, expected: Any) -> None: cmd = [ "examples/tutorials/structured_configs/3_config_groups/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] cmd.extend(overrides) result, _err = run_python_script(cmd) @@ -129,6 +135,7 @@ def test_3_config_groups_with_inheritance( cmd = [ "examples/tutorials/structured_configs/3_config_groups/my_app_with_inheritance.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] + overrides result, _err = run_python_script(cmd) res = OmegaConf.create(result) @@ -139,6 +146,7 @@ def test_4_defaults(tmpdir: Path) -> None: cmd = [ "examples/tutorials/structured_configs/4_defaults/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert OmegaConf.create(result) == { @@ -160,7 +168,7 @@ def test_4_defaults(tmpdir: Path) -> None: ], ) def test_5_structured_config_schema(tmpdir: Path, path: str) -> None: - cmd = [path, "hydra.run.dir=" + str(tmpdir)] + cmd = [path, "hydra.run.dir=" + str(tmpdir), "hydra.job.chdir=True"] result, _err = run_python_script(cmd) assert OmegaConf.create(result) == { "db": { diff --git a/tests/test_examples/test_tutorials_basic.py b/tests/test_examples/test_tutorials_basic.py index 1671602daaa..2d748bc5c4a 100644 --- a/tests/test_examples/test_tutorials_basic.py +++ b/tests/test_examples/test_tutorials_basic.py @@ -37,6 +37,7 @@ def test_tutorial_simple_cli_app( cmd = [ "examples/tutorials/basic/your_first_hydra_app/1_simple_cli/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] cmd.extend(args) result, _err = run_python_script(cmd) @@ -47,6 +48,7 @@ def test_tutorial_working_directory(tmpdir: Path) -> None: cmd = [ "examples/tutorials/basic/running_your_hydra_app/3_working_directory/my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert result == "Working directory : {}".format(tmpdir) @@ -63,6 +65,7 @@ def test_tutorial_logging(tmpdir: Path, args: List[str], expected: List[str]) -> cmd = [ "examples/tutorials/basic/running_your_hydra_app/4_logging/my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] cmd.extend(args) result, _err = run_python_script(cmd) @@ -87,6 +90,7 @@ def test_tutorial_config_file(tmpdir: Path, args: List[str], output_conf: Any) - cmd = [ "examples/tutorials/basic/your_first_hydra_app/2_config_file/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] cmd.extend(args) result, _err = run_python_script(cmd) @@ -108,11 +112,12 @@ def test_tutorial_config_file(tmpdir: Path, args: List[str], output_conf: Any) - def test_tutorial_config_file_bad_key( tmpdir: Path, args: List[str], expected: Any ) -> None: - """ Similar to the previous test, but also tests exception values""" + """Similar to the previous test, but also tests exception values""" cmd = [ "examples/tutorials/basic/your_first_hydra_app/2_config_file/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] cmd.extend(args) if isinstance(expected, RaisesContext): @@ -148,6 +153,7 @@ def test_tutorial_config_groups( cmd = [ "examples/tutorials/basic/your_first_hydra_app/4_config_groups/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] cmd.extend(args) result, _err = run_python_script(cmd) @@ -186,6 +192,7 @@ def test_tutorial_defaults(tmpdir: Path, args: List[str], expected: DictConfig) cmd = [ "examples/tutorials/basic/your_first_hydra_app/5_defaults/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] cmd.extend(args) result, _err = run_python_script(cmd) @@ -288,6 +295,7 @@ def test_advanced_ad_hoc_composition( cmd = [ "examples/advanced/ad_hoc_composition/hydra_compose_example.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert OmegaConf.create(result) == OmegaConf.create(expected) @@ -297,6 +305,7 @@ def test_examples_using_the_config_object(tmpdir: Path) -> None: cmd = [ "examples/tutorials/basic/your_first_hydra_app/3_using_config/my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] run_python_script(cmd) diff --git a/tests/test_hydra.py b/tests/test_hydra.py index d4249153386..9b5b49eb769 100644 --- a/tests/test_hydra.py +++ b/tests/test_hydra.py @@ -11,7 +11,7 @@ from omegaconf import DictConfig, OmegaConf from pytest import mark, param, raises -from hydra import MissingConfigException +from hydra import MissingConfigException, version from hydra.test_utils.test_utils import ( TSweepRunner, TTaskRunner, @@ -70,7 +70,9 @@ def test_missing_conf_file( def test_run_dir() -> None: - run_python_script(["tests/test_apps/run_dir_test/my_app.py"]) + run_python_script( + ["tests/test_apps/run_dir_test/my_app.py", "hydra.job.chdir=True"] + ) @mark.parametrize( @@ -347,6 +349,7 @@ def test_short_module_name(tmpdir: Path) -> None: cmd = [ "examples/tutorials/basic/your_first_hydra_app/2_config_file/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] out, _err = run_python_script(cmd) assert OmegaConf.create(out) == { @@ -376,6 +379,7 @@ def test_module_env_override(tmpdir: Path, env_name: str) -> None: cmd = [ "examples/tutorials/basic/your_first_hydra_app/2_config_file/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] modified_env = os.environ.copy() modified_env[env_name] = "hydra.test_utils.configs.Foo" @@ -392,6 +396,7 @@ def test_cfg(tmpdir: Path, flag: str, resolve: bool, expected_keys: List[str]) - cmd = [ "examples/tutorials/basic/your_first_hydra_app/5_defaults/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", flag, ] if resolve: @@ -451,6 +456,7 @@ def test_cfg_with_package( cmd = [ "examples/tutorials/basic/your_first_hydra_app/5_defaults/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] + flags if resolve: cmd.append("--resolve") @@ -505,7 +511,12 @@ def test_cfg_with_package( def test_cfg_resolve_interpolation( tmpdir: Path, script: str, resolve: bool, flags: List[str], expected: str ) -> None: - cmd = [script, "hydra.run.dir=" + str(tmpdir), "--cfg=job"] + flags + cmd = [ + script, + "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", + "--cfg=job", + ] + flags if resolve: cmd.append("--resolve") @@ -513,6 +524,34 @@ def test_cfg_resolve_interpolation( assert_text_same(result, expected) +@mark.parametrize( + "script,expected", + [ + param( + "tests/test_apps/passes_callable_class_to_hydra_main/my_app.py", + dedent( + """\ + 123 + my_app + """ + ), + id="passes_callable_class_to_hydra_main", + ), + ], +) +def test_pass_callable_class_to_hydra_main( + tmpdir: Path, script: str, expected: str +) -> None: + cmd = [ + script, + "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", + ] + + result, _err = run_python_script(cmd) + assert_text_same(result, expected) + + @mark.parametrize( "other_flag", [None, "--run", "--multirun", "--info", "--shell-completion", "--hydra-help"], @@ -521,6 +560,7 @@ def test_resolve_flag_errmsg(tmpdir: Path, other_flag: Optional[str]) -> None: cmd = [ "examples/tutorials/basic/your_first_hydra_app/3_using_config/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", "--resolve", ] if other_flag is not None: @@ -698,6 +738,7 @@ def test_sweep_complex_defaults( The config_path is relative to the Python file declaring @hydra.main() --config-name,-cn : Overrides the config_name specified in hydra.main() --config-dir,-cd : Adds an additional config dir to the config search path + --experimental-rerun : Rerun a job from a previous config pickle --info,-i : Print Hydra information [all|config|defaults|defaults-tree|plugins|searchpath] Overrides : Any key=value arguments to override config values (use dots for.nested=overrides) """ @@ -755,6 +796,7 @@ def test_sweep_complex_defaults( The config_path is relative to the Python file declaring @hydra.main() --config-name,-cn : Overrides the config_name specified in hydra.main() --config-dir,-cd : Adds an additional config dir to the config search path + --experimental-rerun : Rerun a job from a previous config pickle --info,-i : Print Hydra information [all|config|defaults|defaults-tree|plugins|searchpath] Overrides : Any key=value arguments to override config values (use dots for.nested=overrides) """ @@ -766,7 +808,7 @@ def test_sweep_complex_defaults( def test_help( tmpdir: Path, script: str, flags: List[str], overrides: List[str], expected: Any ) -> None: - cmd = [script, "hydra.run.dir=" + str(tmpdir)] + cmd = [script, "hydra.run.dir=" + str(tmpdir), "hydra.job.chdir=True"] cmd.extend(overrides) cmd.extend(flags) result, _err = run_python_script(cmd) @@ -796,7 +838,7 @@ def test_help( def test_searchpath_config(tmpdir: Path, overrides: List[str], expected: str) -> None: cmd = ["examples/advanced/config_search_path/my_app.py"] cmd.extend(overrides) - cmd.extend(["hydra.run.dir=" + str(tmpdir)]) + cmd.extend(["hydra.run.dir=" + str(tmpdir), "hydra.job.chdir=True"]) result, _err = run_python_script(cmd) assert re.match(expected, result, re.DOTALL) @@ -833,6 +875,7 @@ def test_sys_exit(tmpdir: Path) -> None: "-Werror", "tests/test_apps/sys_exit/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] assert subprocess.run(cmd).returncode == 42 @@ -840,15 +883,19 @@ def test_sys_exit(tmpdir: Path) -> None: @mark.parametrize( "task_config, overrides, expected_dir", [ - ({"hydra": {"run": {"dir": "foo"}}}, [], "foo"), - ({}, ["hydra.run.dir=bar"], "bar"), - ({"hydra": {"run": {"dir": "foo"}}}, ["hydra.run.dir=boom"], "boom"), + ({"hydra": {"run": {"dir": "foo"}}}, ["hydra.job.chdir=True"], "foo"), + ({}, ["hydra.run.dir=bar", "hydra.job.chdir=True"], "bar"), + ( + {"hydra": {"run": {"dir": "foo"}}}, + ["hydra.run.dir=boom", "hydra.job.chdir=True"], + "boom", + ), ( { "hydra": {"run": {"dir": "foo-${hydra.job.override_dirname}"}}, "app": {"a": 1, "b": 2}, }, - ["app.a=20"], + ["app.a=20", "hydra.job.chdir=True"], "foo-app.a=20", ), ( @@ -856,7 +903,7 @@ def test_sys_exit(tmpdir: Path) -> None: "hydra": {"run": {"dir": "foo-${hydra.job.override_dirname}"}}, "app": {"a": 1, "b": 2}, }, - ["app.b=10", "app.a=20"], + ["app.b=10", "app.a=20", "hydra.job.chdir=True"], "foo-app.a=20,app.b=10", ), ], @@ -968,6 +1015,7 @@ def test_config_name_and_path_overrides( cmd = [ "tests/test_apps/app_with_multiple_config_dirs/my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", f"--config-name={config_name}", f"--config-path={config_path}", ] @@ -1017,14 +1065,14 @@ def test_hydra_output_dir( "directory,file,module, error", [ ( - "tests/test_apps/run_as_module/1", + "tests/test_apps/run_as_module_1", "my_app.py", "my_app", "Primary config module is empty", ), - ("tests/test_apps/run_as_module/2", "my_app.py", "my_app", None), - ("tests/test_apps/run_as_module/3", "module/my_app.py", "module.my_app", None), - ("tests/test_apps/run_as_module/4", "module/my_app.py", "module.my_app", None), + ("tests/test_apps/run_as_module_2", "my_app.py", "my_app", None), + ("tests/test_apps/run_as_module_3", "module/my_app.py", "module.my_app", None), + ("tests/test_apps/run_as_module_4", "module/my_app.py", "module.my_app", None), ], ) def test_module_run( @@ -1033,6 +1081,7 @@ def test_module_run( cmd = [ directory + "/" + file, "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] modified_env = os.environ.copy() modified_env["PYTHONPATH"] = directory @@ -1053,28 +1102,28 @@ def test_module_run( ["test.param=1,2"], True, dedent( - """\ - Ambiguous value for argument 'test.param=1,2' - 1. To use it as a list, use key=[value1,value2] - 2. To use it as string, quote the value: key=\\'value1,value2\\' - 3. To sweep over it, add --multirun to your command line + r""" + Ambiguous value for argument 'test\.param=1,2' + 1\. To use it as a list, use key=\[value1,value2\] + 2\. To use it as string, quote the value: key=\\'value1,value2\\' + 3\. To sweep over it, add --multirun to your command line - Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.""" - ), + Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace\.""" + ).strip(), id="run:choice_sweep", ), param( ["test.param=[1,2]"], True, dedent( - """\ - Error merging override test.param=[1,2] - Value '[1, 2]' could not be converted to Integer - full_key: test.param + r""" + Error merging override test.param=\[1,2\] + Value '\[1, 2\]'( of type 'list')? could not be converted to Integer + full_key: test\.param object_type=TestConfig - Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.""" - ), + Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace\.""" + ).strip(), id="run:list_value", ), param(["test.param=1", "-m"], False, "1", id="multirun:value"), @@ -1087,14 +1136,15 @@ def test_multirun_structured_conflict( cmd = [ "tests/test_apps/multirun_structured_conflict/my_app.py", "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] cmd.extend(overrides) if error: expected = normalize_newlines(expected) ret = run_with_error(cmd) - assert_regex_match( - from_line=expected, - to_line=ret, + assert_multiline_regex_search( + pattern=expected, + string=ret, from_name="Expected output", to_name="Actual output", ) @@ -1120,6 +1170,7 @@ def test_run_with_missing_default( ) -> None: cmd = cmd_base + [ "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", "--config-name=unspecified_mandatory_default", "--config-path=../../../hydra/test_utils/configs", ] @@ -1138,6 +1189,7 @@ def test_command_line_interpolations_evaluated_lazily( ) -> None: cmd = cmd_base + [ "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", "+foo=10,20", "+bar=${foo}", "--multirun", @@ -1155,6 +1207,7 @@ def test_multirun_config_overrides_evaluated_lazily( ) -> None: cmd = cmd_base + [ "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", "+foo=10,20", "+bar=${foo}", "--multirun", @@ -1170,6 +1223,7 @@ def test_multirun_config_overrides_evaluated_lazily( def test_multirun_defaults_override(self, cmd_base: List[str], tmpdir: Any) -> None: cmd = cmd_base + [ "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", "group1=file1,file2", "--multirun", "--config-path=../../../hydra/test_utils/configs", @@ -1186,6 +1240,7 @@ def test_multirun_defaults_override(self, cmd_base: List[str], tmpdir: Any) -> N def test_run_pass_list(self, cmd_base: List[str], tmpdir: Any) -> None: cmd = cmd_base + [ "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", "+foo=[1,2,3]", ] expected = {"foo": [1, 2, 3]} @@ -1197,6 +1252,7 @@ def test_app_with_error_exception_sanitized(tmpdir: Any, monkeypatch: Any) -> No cmd = [ "tests/test_apps/app_with_runtime_config_error/my_app.py", f"hydra.sweep.dir={tmpdir}", + "hydra.job.chdir=True", ] expected = dedent( """\ @@ -1226,6 +1282,7 @@ def test_hydra_to_job_config_interpolation(tmpdir: Any) -> Any: cmd = [ "tests/test_apps/hydra_to_cfg_interpolation/my_app.py", "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", "b=${a}", "a=foo", ] @@ -1254,6 +1311,7 @@ def test_config_dir_argument( cmd = [ "my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] cmd.extend(overrides) result, _err = run_python_script(cmd) @@ -1261,10 +1319,11 @@ def test_config_dir_argument( def test_schema_overrides_hydra(monkeypatch: Any, tmpdir: Path) -> None: - monkeypatch.chdir("tests/test_apps/schema-overrides-hydra") + monkeypatch.chdir("tests/test_apps/schema_overrides_hydra") cmd = [ "my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert result == "job_name: test, name: James Bond, age: 7, group: a" @@ -1275,6 +1334,7 @@ def test_defaults_pkg_with_dot(monkeypatch: Any, tmpdir: Path) -> None: cmd = [ "my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert OmegaConf.create(result) == { @@ -1322,7 +1382,11 @@ def test_job_exception( expected: str, ) -> None: ret = run_with_error( - ["tests/test_apps/app_exception/my_app.py", f"hydra.run.dir={tmpdir}"] + [ + "tests/test_apps/app_exception/my_app.py", + f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", + ] ) assert_regex_match( from_line=expected, @@ -1334,7 +1398,11 @@ def test_job_exception( def test_job_exception_full_error(tmpdir: Any) -> None: ret = run_with_error( - ["tests/test_apps/app_exception/my_app.py", f"hydra.run.dir={tmpdir}"], + [ + "tests/test_apps/app_exception/my_app.py", + f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", + ], env={**os.environ, "HYDRA_FULL_ERROR": "1"}, ) @@ -1347,6 +1415,7 @@ def test_structured_with_none_list(monkeypatch: Any, tmpdir: Path) -> None: cmd = [ "my_app.py", "hydra.run.dir=" + str(tmpdir), + "hydra.job.chdir=True", ] result, _err = run_python_script(cmd) assert result == "{'list': None}" @@ -1359,7 +1428,7 @@ def test_self_hydra_config_interpolation_integration(tmpdir: Path) -> None: task_config=cfg, overrides=[], prints="HydraConfig.get().job_logging.handlers.file.filename", - expected_outputs="task.log", + expected_outputs=r".+task.log$", ) @@ -1378,11 +1447,17 @@ def test_hydra_main_without_config_path(tmpdir: Path) -> None: cmd = [ "tests/test_apps/hydra_main_without_config_path/my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] _, err = run_python_script(cmd, allow_warnings=True) expected = dedent( - """ + f""" + .*my_app.py:7: UserWarning: + The version_base parameter is not specified. + Please specify a compatability version level, or None. + Will assume defaults for version {version.__compat_version__} + @hydra.main() .*my_app.py:7: UserWarning: config_path is not specified in @hydra.main(). See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path for more information. @@ -1397,10 +1472,33 @@ def test_hydra_main_without_config_path(tmpdir: Path) -> None: ) +def test_job_chdir_not_specified(tmpdir: Path) -> None: + cmd = [ + "tests/test_apps/app_with_no_chdir_override/my_app.py", + f"hydra.run.dir={tmpdir}", + ] + out, err = run_python_script(cmd, allow_warnings=True) + + expected = dedent( + """ + .*UserWarning: Future Hydra versions will no longer change working directory at job runtime by default. + See https://hydra.cc/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir for more information..* + .* + """ + ) + assert_regex_match( + from_line=expected, + to_line=err, + from_name="Expected error", + to_name="Actual error", + ) + + def test_app_with_unicode_config(tmpdir: Path) -> None: cmd = [ "tests/test_apps/app_with_unicode_in_config/my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] out, _ = run_python_script(cmd) assert out == "config: 数据库" @@ -1424,6 +1522,7 @@ def test_frozen_primary_config( cmd = [ "examples/patterns/write_protect_config_node/frozen.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] cmd.extend(overrides) ret, _err = run_python_script(cmd) @@ -1435,7 +1534,7 @@ def test_frozen_primary_config( [ param( False, - r"^\S*/my_app\.py:10: UserWarning: Feature FooBar is deprecated$", + r"^\S*[/\\]my_app\.py:10: UserWarning: Feature FooBar is deprecated$", id="deprecation_warning", ), param( @@ -1444,7 +1543,7 @@ def test_frozen_primary_config( r""" ^Error executing job with overrides: \[\]\n? Traceback \(most recent call last\): - File "\S*/my_app.py", line 10, in my_app + File "\S*[/\\]my_app.py", line 10, in my_app deprecation_warning\("Feature FooBar is deprecated"\) File "\S*\.py", line 11, in deprecation_warning raise HydraDeprecationError\(.*\) @@ -1463,6 +1562,7 @@ def test_hydra_deprecation_warning( cmd = [ "tests/test_apps/deprecation_warning/my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", ] env = os.environ.copy() if env_deprecation_err: @@ -1471,3 +1571,292 @@ def test_hydra_deprecation_warning( cmd, env=env, allow_warnings=True, print_error=False, raise_exception=False ) assert_multiline_regex_search(expected, err) + + +@mark.parametrize( + "multirun,expected", + [ + (False, ["my_app.log", ".hydra/config.yaml"]), + (True, ["0/my_app.log", "0/.hydra/config.yaml", "multirun.yaml"]), + ], +) +def test_disable_chdir(tmpdir: Path, multirun: bool, expected: List[str]) -> None: + cmd = [ + "examples/tutorials/basic/running_your_hydra_app/3_working_directory/my_app.py", + f"hydra.run.dir={tmpdir}", + f"hydra.sweep.dir={tmpdir}", + "hydra.job.chdir=False", + ] + if multirun: + cmd += ["-m"] + result, _err = run_python_script(cmd) + assert f"Working directory : {os.getcwd()}" in result + for p in expected: + path = os.path.join(str(tmpdir), p) + assert Path(path).exists() + + +@mark.parametrize( + "chdir", + [True, False], +) +def test_disable_chdir_with_app_chdir(tmpdir: Path, chdir: bool) -> None: + cmd = [ + "tests/test_apps/app_change_dir/my_app.py", + f"hydra.run.dir={tmpdir}", + f"hydra.job.chdir={chdir}", + ] + result, _err = run_python_script(cmd) + _path = os.getcwd() if chdir else Path(tmpdir) / "subdir" + assert f"current dir: {_path}" in result + + +@mark.parametrize( + "multirun", + [False, True], +) +def test_hydra_verbose_1897(tmpdir: Path, multirun: bool) -> None: + cmd = [ + "tests/test_apps/hydra_verbose/my_app.py", + f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=False", + ] + if multirun: + cmd += ["+a=1,2", "-m"] + run_python_script(cmd) + + +@mark.parametrize( + "multirun", + [False, True], +) +def test_hydra_resolver_in_output_dir(tmpdir: Path, multirun: bool) -> None: + from hydra import __version__ + + subdir = "dir" + "${hydra:runtime.version}" + + output_dir = str(Path(tmpdir) / subdir) + + cmd = [ + "tests/test_apps/hydra_resolver_in_output_dir/my_app.py", + f"hydra.run.dir='{output_dir}'", + f"hydra.sweep.subdir='{subdir}'", + f"hydra.sweep.dir={str(Path(tmpdir))}", + "hydra.job.chdir=False", + ] + + if multirun: + cmd += ["-m"] + + out, _ = run_python_script(cmd) + + expected_subdir = f"dir{__version__}" + expected_output_dir = str(Path(tmpdir) / expected_subdir) + expected_log_file = Path(expected_output_dir) / "my_app.log" + assert expected_log_file.exists() + assert expected_output_dir in out + + +@mark.parametrize( + "overrides,expected_output,error,warning,warning_msg", + [ + param( + ["hydra.mode=RUN"], + "RunMode.RUN", + False, + False, + None, + id="single_run_config", + ), + param( + ["x=1", "hydra.mode=MULTIRUN"], + dedent( + """\ + [HYDRA] Launching 1 jobs locally + [HYDRA] \t#0 : x=1 + RunMode.MULTIRUN""" + ), + False, + False, + None, + id="multi_run_config", + ), + param( + ["--multirun", "x=1"], + dedent( + """\ + [HYDRA] Launching 1 jobs locally + [HYDRA] \t#0 : x=1 + RunMode.MULTIRUN""" + ), + False, + False, + None, + id="multi_run_commandline", + ), + param( + [], + dedent( + """\ + RunMode.RUN""" + ), + False, + False, + None, + id="run_with_no_config", + ), + param( + ["x=1,2", "hydra.mode=RUN"], + dedent( + """\ + Ambiguous value for argument 'x=1,2' + 1. To use it as a list, use key=[value1,value2] + 2. To use it as string, quote the value: key=\\'value1,value2\\' + 3. To sweep over it, add --multirun to your command line + + Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. + """ + ), + True, + False, + None, + id="illegal_sweep_run", + ), + param( + ["x=1,2", "hydra.mode=MULTIRUN"], + dedent( + """\ + [HYDRA] Launching 2 jobs locally + [HYDRA] \t#0 : x=1 + RunMode.MULTIRUN + [HYDRA] \t#1 : x=2 + RunMode.MULTIRUN""" + ), + False, + False, + None, + id="sweep_from_config", + ), + param( + ["x=1,2", "hydra.mode=MULTIRUN", "hydra/sweeper=test"], + dedent( + """\ + [HYDRA] Launching 1 jobs locally + [HYDRA] \t#0 : x=1 + RunMode.MULTIRUN + [HYDRA] Launching 1 jobs locally + [HYDRA] \t#1 : x=2 + RunMode.MULTIRUN + """ + ), + False, + False, + None, + id="sweep_from_config_with_custom_sweeper", + ), + param( + ["--multirun", "x=1,2", "hydra.mode=RUN"], + dedent( + """ + [HYDRA] Launching 2 jobs locally + [HYDRA] \t#0 : x=1 + RunMode.MULTIRUN + [HYDRA] \t#1 : x=2 + RunMode.MULTIRUN""" + ), + False, + True, + dedent( + """ + .*: UserWarning: + \tRunning Hydra app with --multirun, overriding with `hydra.mode=MULTIRUN`. + .* + """ + ), + id="multirun_commandline_with_run_config_with_warning", + ), + ], +) +def test_hydra_mode( + tmpdir: Path, + overrides: List[str], + expected_output: str, + error: bool, + warning: bool, + warning_msg: Optional[str], +) -> None: + cmd = [ + "tests/test_apps/app_print_hydra_mode/my_app.py", + ] + cmd.extend(overrides) + cmd.extend( + [ + f"hydra.run.dir='{tmpdir}'", + f"hydra.sweep.dir={tmpdir}", + "hydra.job.chdir=False", + "hydra.hydra_logging.formatters.simple.format='[HYDRA] %(message)s'", + "hydra.job_logging.formatters.simple.format='[JOB] %(message)s'", + ] + ) + if error: + + expected = normalize_newlines(expected_output) + ret = run_with_error(cmd) + assert_regex_match( + from_line=expected, + to_line=ret, + from_name="Expected output", + to_name="Actual output", + ) + elif warning: + + out, err = run_python_script(cmd, allow_warnings=True) + assert_regex_match( + from_line=expected_output, + to_line=out, + from_name="Expected output", + to_name="Actual output", + ) + assert warning_msg is not None + assert_regex_match( + from_line=warning_msg, + to_line=err, + from_name="Expected error", + to_name="Actual error", + ) + else: + out, _ = run_python_script(cmd) + assert_regex_match( + from_line=expected_output, + to_line=out, + from_name="Expected output", + to_name="Actual output", + ) + + +def test_hydra_runtime_choice_1882(tmpdir: Path) -> None: + cmd = [ + "tests/test_apps/app_with_cfg_groups/my_app_with_runtime_choices_print.py", + "--multirun", + f"hydra.sweep.dir={tmpdir}", + "hydra.hydra_logging.formatters.simple.format='[HYDRA] %(message)s'", + "hydra.job_logging.formatters.simple.format='[JOB] %(message)s'", + "hydra.job.chdir=False", + "optimizer=adam,nesterov", + ] + expected_output = dedent( + """ + [HYDRA] Launching 2 jobs locally + [HYDRA] \t#0 : optimizer=adam + adam + [HYDRA] \t#1 : optimizer=nesterov + nesterov""" + ) + + out, _ = run_python_script(cmd) + assert_regex_match( + from_line=expected_output, + to_line=out, + from_name="Expected output", + to_name="Actual output", + ) diff --git a/tests/test_hydra_context_warnings.py b/tests/test_hydra_context_warnings.py index ead9b44e82b..6d5f0e97fdf 100644 --- a/tests/test_hydra_context_warnings.py +++ b/tests/test_hydra_context_warnings.py @@ -5,7 +5,7 @@ from unittest.mock import Mock from omegaconf import DictConfig, OmegaConf -from pytest import mark, warns +from pytest import mark, raises from hydra import TaskFunction from hydra._internal.callbacks import Callbacks @@ -13,7 +13,7 @@ from hydra._internal.utils import create_config_search_path from hydra.core.config_loader import ConfigLoader from hydra.core.plugins import Plugins -from hydra.core.utils import JobReturn, _get_callbacks_for_run_job +from hydra.core.utils import JobReturn, _check_hydra_context from hydra.plugins.launcher import Launcher from hydra.plugins.sweeper import Sweeper from hydra.test_utils.test_utils import chdir_hydra_root @@ -73,19 +73,13 @@ def test_setup_plugins( monkeypatch.setattr(Plugins, "check_usage", lambda _: None) monkeypatch.setattr(plugin_instance, "_instantiate", lambda _: plugin) - msg = dedent( - """ - Plugin's setup() signature has changed in Hydra 1.1. - Support for the old style will be removed in Hydra 1.2. - For more info, check https://github.com/facebookresearch/hydra/pull/1581.""" - ) - with warns(expected_warning=UserWarning, match=re.escape(msg)): + msg = "setup() got an unexpected keyword argument 'hydra_context'" + with raises(TypeError, match=re.escape(msg)): if isinstance(plugin, Launcher): Plugins.instance().instantiate_launcher( + hydra_context=hydra_context, task_function=task_function, config=config, - config_loader=config_loader, - hydra_context=hydra_context, ) else: Plugins.instance().instantiate_sweeper( @@ -99,9 +93,8 @@ def test_run_job() -> None: hydra_context = None msg = dedent( """ - run_job's signature has changed in Hydra 1.1. Please pass in hydra_context. - Support for the old style will be removed in Hydra 1.2. + run_job's signature has changed: the `hydra_context` arg is now required. For more info, check https://github.com/facebookresearch/hydra/pull/1581.""" ) - with warns(expected_warning=UserWarning, match=msg): - _get_callbacks_for_run_job(hydra_context) + with raises(TypeError, match=msg): + _check_hydra_context(hydra_context) diff --git a/tests/test_internal_utils.py b/tests/test_internal_utils.py index 0c372ed61af..cc8ff65580a 100644 --- a/tests/test_internal_utils.py +++ b/tests/test_internal_utils.py @@ -1,10 +1,11 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from typing import Any +from typing import Any, Callable, Optional from omegaconf import DictConfig, OmegaConf from pytest import mark, param from hydra._internal import utils +from tests import data @mark.parametrize( @@ -32,3 +33,23 @@ def test_get_column_widths(matrix: Any, expected: Any) -> None: ) def test_get_class_name(config: DictConfig, expected: Any) -> None: assert utils._get_cls_name(config) == expected + + +@mark.parametrize( + "task_function, expected_file, expected_module", + [ + param(data.foo, None, "tests.data", id="function"), + param(data.foo_main_module, data.__file__, None, id="function-main-module"), + param(data.Bar, None, "tests.data", id="class"), + param(data.bar_instance, None, "tests.data", id="class_inst"), + param(data.bar_instance_main_module, None, None, id="class_inst-main-module"), + ], +) +def test_detect_calling_file_or_module_from_task_function( + task_function: Callable[..., None], + expected_file: Optional[str], + expected_module: Optional[str], +) -> None: + file, module = utils.detect_calling_file_or_module_from_task_function(task_function) + assert file == expected_file + assert module == expected_module diff --git a/tests/test_overrides_parser.py b/tests/test_overrides_parser.py index d4dfb01245a..cbfffb1acdc 100644 --- a/tests/test_overrides_parser.py +++ b/tests/test_overrides_parser.py @@ -31,7 +31,7 @@ ) from hydra.errors import HydraException -UNQUOTED_SPECIAL = r"/-\+.$%*@?" # special characters allowed in unquoted strings +UNQUOTED_SPECIAL = r"/-\+.$%*@?|" # special characters allowed in unquoted strings parser = OverridesParser(create_functions()) diff --git a/tests/test_plugin_interface.py b/tests/test_plugin_interface.py index 66ab7be9f84..8f1bb056f13 100644 --- a/tests/test_plugin_interface.py +++ b/tests/test_plugin_interface.py @@ -1,8 +1,9 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from typing import List, Type -from pytest import mark +from pytest import mark, raises +from hydra.core.config_search_path import ConfigSearchPath from hydra.core.plugins import Plugins from hydra.plugins.launcher import Launcher from hydra.plugins.plugin import Plugin @@ -31,3 +32,23 @@ def test_discover(plugin_type: Type[Plugin], expected: List[str]) -> None: expected_classes = [get_class(c) for c in expected] for ex in expected_classes: assert ex in plugins + + +def test_register_plugin() -> None: + class MyPlugin(SearchPathPlugin): + def manipulate_search_path(self, search_path: ConfigSearchPath) -> None: + ... + + Plugins.instance().register(MyPlugin) + + assert MyPlugin in Plugins.instance().discover(Plugin) + assert MyPlugin in Plugins.instance().discover(SearchPathPlugin) + assert MyPlugin not in Plugins.instance().discover(Launcher) + + +def test_register_bad_plugin() -> None: + class NotAPlugin: + ... + + with raises(ValueError, match="Not a valid Hydra Plugin"): + Plugins.instance().register(NotAPlugin) # type: ignore diff --git a/tools/configen/configen/configen.py b/tools/configen/configen/configen.py index 21a3e7a2d9d..a7e8ad0f592 100644 --- a/tools/configen/configen/configen.py +++ b/tools/configen/configen/configen.py @@ -57,6 +57,7 @@ def init_config(conf_dir: str) -> None: sys.exit(1) sample_config = pkgutil.get_data(__name__, "templates/sample_config.yaml") + assert sample_config is not None file.write_bytes(sample_config) @@ -147,7 +148,7 @@ def get_default_flags(module: ModuleConf) -> List[Parameter]: Parameter( name="_recursive_", type_str="bool", - default=module.default_flags._recursive_, + default=str(module.default_flags._recursive_), ) ) @@ -156,7 +157,7 @@ def get_default_flags(module: ModuleConf) -> List[Parameter]: def generate_module(cfg: ConfigenConf, module: ModuleConf) -> str: classes_map: Dict[str, ClassInfo] = {} - imports = set() + imports: Set[Any] = set() string_imports: Set[str] = set() default_flags = get_default_flags(module) @@ -165,7 +166,7 @@ def generate_module(cfg: ConfigenConf, module: ModuleConf) -> str: full_name = f"{module.name}.{class_name}" cls = hydra.utils.get_class(full_name) sig = inspect.signature(cls) - resolved_hints = get_type_hints(cls.__init__) + resolved_hints = get_type_hints(cls.__init__) # type: ignore params: List[Parameter] = [] params = params + default_flags @@ -237,7 +238,7 @@ def generate_module(cfg: ConfigenConf, module: ModuleConf) -> str: ) -@hydra.main(config_path=None, config_name="configen_schema") +@hydra.main(version_base=None, config_name="configen_schema") def main(cfg: Config): if cfg.init_config_dir is not None: init_config(cfg.init_config_dir) diff --git a/tools/configen/configen/utils.py b/tools/configen/configen/utils.py index fc7e624e9b1..f102a42970b 100644 --- a/tools/configen/configen/utils.py +++ b/tools/configen/configen/utils.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import sys from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from omegaconf._utils import ( _resolve_optional, @@ -65,7 +65,7 @@ def is_tuple_annotation(type_: Any) -> bool: return origin is tuple # pragma: no cover -def convert_imports(imports: Set[Type], string_imports: List[str]) -> List[str]: +def convert_imports(imports: Set[Any], string_imports: Iterable[str]) -> List[str]: tmp = set() for imp in string_imports: tmp.add(imp) @@ -94,7 +94,7 @@ def convert_imports(imports: Set[Type], string_imports: List[str]) -> List[str]: return sorted(list(tmp)) -def collect_imports(imports: Set[Type], type_: Type) -> None: +def collect_imports(imports: Set[Any], type_: Any) -> None: if is_list_annotation(type_): collect_imports(imports, get_list_element_type(type_)) type_ = List diff --git a/tools/configen/example/__init__.py b/tools/configen/example/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tools/configen/example/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tools/configen/example/config/__init__.py b/tools/configen/example/config/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tools/configen/example/config/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tools/configen/example/config/configen/__init__.py b/tools/configen/example/config/configen/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tools/configen/example/config/configen/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tools/configen/example/config/configen/samples/__init__.py b/tools/configen/example/config/configen/samples/__init__.py new file mode 100644 index 00000000000..168f9979a46 --- /dev/null +++ b/tools/configen/example/config/configen/samples/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tools/configen/example/my_app.py b/tools/configen/example/my_app.py index 2c831bebac1..f2e5eba8958 100644 --- a/tools/configen/example/my_app.py +++ b/tools/configen/example/my_app.py @@ -18,7 +18,7 @@ ) -@hydra.main(config_path=".", config_name="config") +@hydra.main(version_base=None, config_path=".", config_name="config") def my_app(cfg: DictConfig) -> None: user: User = hydra.utils.instantiate(cfg.user) admin: Admin = hydra.utils.instantiate(cfg.admin) diff --git a/tools/configen/setup.py b/tools/configen/setup.py index cb2c6815196..4daede298bf 100644 --- a/tools/configen/setup.py +++ b/tools/configen/setup.py @@ -5,7 +5,7 @@ setup( name="hydra-configen", - version="0.9.0dev8", + version="0.9.0.dev8", packages=find_packages(include=["configen"]), entry_points={"console_scripts": ["configen = configen.configen:main"]}, author="Omry Yadan, Rosario Scalise", diff --git a/tools/configen/tests/test_generate.py b/tools/configen/tests/test_generate.py index 2ab69375b7d..10f339bd947 100644 --- a/tools/configen/tests/test_generate.py +++ b/tools/configen/tests/test_generate.py @@ -194,7 +194,7 @@ def test_generated_code_with_default_flags( {"param": "str"}, [], {}, - UnionArg(param="str"), + UnionArg(param="str"), # type: ignore id="UnionArg:illegal_but_ok_arg", ), param( @@ -280,7 +280,8 @@ def test_instantiate_classes( full_class = f"{MODULE_NAME}.generated.{classname}Conf" schema = OmegaConf.structured(get_class(full_class)) cfg = OmegaConf.merge(schema, params) - obj = instantiate(config=cfg, *args, **kwargs) + kwargs["config"] = cfg + obj = instantiate(*args, **kwargs) assert obj == expected @@ -289,6 +290,7 @@ def test_example_application(monkeypatch: Any, tmpdir: Path): cmd = [ "my_app.py", f"hydra.run.dir={tmpdir}", + "hydra.job.chdir=True", "user.name=Batman", ] result, _err = run_python_script(cmd) diff --git a/tools/configen/tests/test_modules/future_annotations.py b/tools/configen/tests/test_modules/future_annotations.py index c0b3d0c2b1c..79591658b4b 100644 --- a/tools/configen/tests/test_modules/future_annotations.py +++ b/tools/configen/tests/test_modules/future_annotations.py @@ -1,5 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from __future__ import annotations # noqa: F407 +from __future__ import annotations # type: ignore # noqa: F407 from dataclasses import dataclass from typing import List, Optional diff --git a/tools/release/conf/config.yaml b/tools/release/conf/config.yaml index 99d861e3841..8413c391a34 100644 --- a/tools/release/conf/config.yaml +++ b/tools/release/conf/config.yaml @@ -1,3 +1,8 @@ defaults: - config_schema - set: all + - _self_ + +hydra: + job: + chdir: true diff --git a/tools/release/release.py b/tools/release/release.py index 5a5468c46db..97695e9b80c 100644 --- a/tools/release/release.py +++ b/tools/release/release.py @@ -72,7 +72,7 @@ def get_releases(metadata: DictConfig) -> List[Version]: @dataclass -class Package: +class PackageInfo: name: str local_version: Version latest_version: Version @@ -84,7 +84,7 @@ def parse_version(ver: str) -> Version: return v -def get_package_info(path: str) -> Package: +def get_package_info(path: str) -> PackageInfo: try: prev = os.getcwd() path = os.path.abspath(path) @@ -101,7 +101,7 @@ def get_package_info(path: str) -> Package: remote_metadata = get_metadata(package_name) latest = get_releases(remote_metadata)[-1] - return Package( + return PackageInfo( name=package_name, local_version=local_version, latest_version=latest ) @@ -123,18 +123,21 @@ def build_package(cfg: Config, pkg_path: str, build_dir: str) -> None: def _next_version(version: str) -> str: cur = parse(version) + assert isinstance(cur, Version) if cur.is_devrelease: prefix = "dev" + assert cur.dev is not None num = cur.dev + 1 new_version = f"{cur.major}.{cur.minor}.{cur.micro}.{prefix}{num}" elif cur.is_prerelease: + assert cur.pre is not None prefix = cur.pre[0] num = cur.pre[1] + 1 new_version = f"{cur.major}.{cur.minor}.{cur.micro}.{prefix}{num}" elif cur.is_postrelease: - prefix = cur.post[0] - num = cur.post[1] + 1 - new_version = f"{cur.major}.{cur.minor}.{cur.micro}.{prefix}{num}" + assert cur.post is not None + num = cur.post + 1 + new_version = f"{cur.major}.{cur.minor}.{cur.micro}.{num}" else: micro = cur.micro + 1 new_version = f"{cur.major}.{cur.minor}.{micro}" @@ -174,7 +177,7 @@ def bump_version(cfg: Config, package: Package, hydra_root: str) -> None: OmegaConf.register_new_resolver("parent_key", lambda _parent_: _parent_._key()) -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def main(cfg: Config) -> None: hydra_root = find_parent_dir_containing(target="ATTRIBUTION") build_dir = f"{os.getcwd()}/{cfg.build_dir}" diff --git a/website/docs/advanced/compose_api.md b/website/docs/advanced/compose_api.md index 7521f0d23d3..140e50823da 100644 --- a/website/docs/advanced/compose_api.md +++ b/website/docs/advanced/compose_api.md @@ -35,6 +35,8 @@ There are 3 initialization methods: All 3 can be used as methods or contexts. When used as methods, they are initializing Hydra globally and should only be called once. When used as contexts, they are initializing Hydra within the context can be used multiple times. +Like @hydra.main() all three support the [version_base](../upgrades/version_base.md) parameter +to define the compatability level to use. ### Code example ```python @@ -43,12 +45,12 @@ from omegaconf import OmegaConf if __name__ == "__main__": # context initialization - with initialize(config_path="conf", job_name="test_app"): + with initialize(version_base=None, config_path="conf", job_name="test_app"): cfg = compose(config_name="config", overrides=["db=mysql", "db.user=me"]) print(OmegaConf.to_yaml(cfg)) # global initialization - initialize(config_path="conf", job_name="test_app") + initialize(version_base=None, config_path="conf", job_name="test_app") cfg = compose(config_name="config", overrides=["db=mysql", "db.user=me"]) print(OmegaConf.to_yaml(cfg)) ``` @@ -71,6 +73,7 @@ def compose( ```python title="Relative initialization" def initialize( + version_base: Optional[str], config_path: Optional[str] = None, job_name: Optional[str] = "app", caller_stack_depth: int = 1, @@ -85,6 +88,7 @@ def initialize( - Python modules - Unit tests - Jupyter notebooks. + :param version_base: compatability level to use. :param config_path: path relative to the parent of the caller :param job_name: the value for hydra.job.name (By default it is automatically detected based on the caller) :param caller_stack_depth: stack depth of the caller, defaults to 1 (direct caller). @@ -92,21 +96,31 @@ def initialize( ``` ```python title="Initialzing with config module" -def initialize_config_module(config_module: str, job_name: str = "app") -> None: +def initialize_config_module( + config_module: str, + version_base: Optional[str], + job_name: str = "app" +) -> None: """ Initializes Hydra and add the config_module to the config search path. The config module must be importable (an __init__.py must exist at its top level) :param config_module: absolute module name, for example "foo.bar.conf". + :param version_base: compatability level to use. :param job_name: the value for hydra.job.name (default is 'app') """ ``` ```python title="Initialzing with config directory" -def initialize_config_dir(config_dir: str, job_name: str = "app") -> None: +def initialize_config_dir( + config_dir: str, + version_base: Optional[str], + job_name: str = "app" +) -> None: """ Initializes Hydra and add an absolute config dir to the to the config search path. The config_dir is always a path on the file system and is must be an absolute path. Relative paths will result in an error. :param config_dir: absolute file system path + :param version_base: compatability level to use. :param job_name: the value for hydra.job.name (default is 'app') """ ``` diff --git a/website/docs/advanced/hydra-command-line-flags.md b/website/docs/advanced/hydra-command-line-flags.md index 0226bb9bc6a..7e31bd71c7d 100644 --- a/website/docs/advanced/hydra-command-line-flags.md +++ b/website/docs/advanced/hydra-command-line-flags.md @@ -3,8 +3,6 @@ id: hydra-command-line-flags title: Hydra's command line flags --- -Hydra is using the command line for two things: - Hydra is using the command line for two things: - Controlling Hydra - Configuring your application (See [Override Grammar](override_grammar/basic.md)) diff --git a/website/docs/advanced/instantiate_objects/config_files.md b/website/docs/advanced/instantiate_objects/config_files.md index 7880d06c049..bcc654f87f2 100644 --- a/website/docs/advanced/instantiate_objects/config_files.md +++ b/website/docs/advanced/instantiate_objects/config_files.md @@ -83,7 +83,7 @@ defaults: With this, you can instantiate the object from the configuration with a single line of code: ```python -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg): connection = hydra.utils.instantiate(cfg.db) connection.connect() diff --git a/website/docs/advanced/instantiate_objects/overview.md b/website/docs/advanced/instantiate_objects/overview.md index b56b3d47f52..a23c2b4b6da 100644 --- a/website/docs/advanced/instantiate_objects/overview.md +++ b/website/docs/advanced/instantiate_objects/overview.md @@ -38,6 +38,8 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: the exception of Structured Configs (and their fields). all : Passed objects are dicts, lists and primitives without a trace of OmegaConf containers + _partial_: If True, return functools.partial wrapped method or object + False by default. Configure per target. :param args: Optional positional parameters pass-through :param kwargs: Optional named parameters to override parameters in the config object. Parameters not present @@ -198,17 +200,184 @@ Trainer( ) ``` -## Parameter conversion strategies -By default, the parameters passed to the target are either primitives (int, float, bool etc) or -OmegaConf containers (DictConfig, ListConfig). -OmegaConf containers have many advantages over primitive dicts and lists but in some cases -it's desired to pass a real dicts and lists (for example, for performance reasons). +### Parameter conversion strategies +By default, the parameters passed to the target are either primitives (int, +float, bool etc) or OmegaConf containers (`DictConfig`, `ListConfig`). +OmegaConf containers have many advantages over primitive dicts and lists, +including convenient attribute access for keys, +[duck-typing as instances of dataclasses or attrs classes](https://omegaconf.readthedocs.io/en/latest/structured_config.html), and +support for [variable interpolation](https://omegaconf.readthedocs.io/en/latest/usage.html#variable-interpolation) +and [custom resolvers](https://omegaconf.readthedocs.io/en/latest/custom_resolvers.html). +If the callable targeted by `instantiate` leverages OmegaConf's features, it +will make sense to pass `DictConfig` and `ListConfig` instances directly to +that callable. + +That being said, in many cases it's desired to pass normal Python dicts and +lists, rather than `DictConfig` or `ListConfig` instances, as arguments to your +callable. You can change instantiate's argument conversion strategy using the +`_convert_` parameter. Supported values are: + +- `"none"` : Default behavior, Use OmegaConf containers +- `"partial"` : Convert OmegaConf containers to dict and list, except Structured Configs. +- `"all"` : Convert everything to primitive containers + +The conversion strategy applies recursively to all subconfigs of the instantiation target. +Here is an example demonstrating the various conversion strategies: -You can change the parameter conversion strategy using the `_convert_` parameter (in your config or the call-site). -Supported values are: +```python +from dataclasses import dataclass +from omegaconf import DictConfig, OmegaConf +from hydra.utils import instantiate + +@dataclass +class Foo: + a: int = 123 + +class MyTarget: + def __init__(self, foo, bar): + self.foo = foo + self.bar = bar + +cfg = OmegaConf.create( + { + "_target_": "__main__.MyTarget", + "foo": Foo(), + "bar": {"b": 456}, + } +) + +obj_none = instantiate(cfg, _convert_="none") +assert isinstance(obj_none, MyTarget) +assert isinstance(obj_none.foo, DictConfig) +assert isinstance(obj_none.bar, DictConfig) + +obj_partial = instantiate(cfg, _convert_="partial") +assert isinstance(obj_partial, MyTarget) +assert isinstance(obj_partial.foo, DictConfig) +assert isinstance(obj_partial.bar, dict) + +obj_all = instantiate(cfg, _convert_="all") +assert isinstance(obj_none, MyTarget) +assert isinstance(obj_all.foo, dict) +assert isinstance(obj_all.bar, dict) +``` + +Passing the `_convert_` keyword argument to `instantiate` has the same effect as defining +a `_convert_` attribute on your config object. Here is an example creating +instances of `MyTarget` that are equivalent to the above: + +```python +cfg_none = OmegaConf.create({..., "_convert_": "none"}) +obj_none = instantiate(cfg_none) + +cfg_partial = OmegaConf.create({..., "_convert_": "partial"}) +obj_partial = instantiate(cfg_partial) + +cfg_all = OmegaConf.create({..., "_convert_": "all"}) +obj_all = instantiate(cfg_all) +``` + +If performance is a concern, note that the `_convert_="none"` strategy does the +least work -- no conversion (from `DictConfig`/`ListConfig` to native python +containers) is taking place. The `_convert_="partial"` strategy does more work, +and `_convert_="all"` does more work yet. + +### Partial Instantiation +Sometimes you may not set all parameters needed to instantiate an object from the configuration, in this case you can set +`_partial_` to be `True` to get a `functools.partial` wrapped object or method, then complete initializing the object in +the application code. Here is an example: + +```python title="Example classes" +class Optimizer: + algo: str + lr: float + + def __init__(self, algo: str, lr: float) -> None: + self.algo = algo + self.lr = lr + + def __repr__(self) -> str: + return f"Optimizer(algo={self.algo},lr={self.lr})" + + +class Model: + def __init__(self, optim_partial: Any, lr: float): + super().__init__() + self.optim = optim_partial(lr=lr) + self.lr = lr + + def __repr__(self) -> str: + return f"Model(Optimizer={self.optim},lr={self.lr})" +``` + +
-- `none` : Default behavior, Use OmegaConf containers -- `partial` : Convert OmegaConf containers to dict and list, except Structured Configs. -- `all` : Convert everything to primitive containers +
-Note that the conversion strategy applies to all the parameters passed to the target. \ No newline at end of file +```yaml title="Config" +model: + _target_: my_app.Model + optim_partial: + _partial_: true + _target_: my_app.Optimizer + algo: SGD + lr: 0.01 +``` + + +
+ +
+ +```python title="Instantiation" +model = instantiate(cfg.model) +print(model) +# "Model(Optimizer=Optimizer(algo=SGD,lr=0.01),lr=0.01) +``` + +
+
+ +If you are repeatedly instantiating the same config, +using `_partial_=True` may provide a significant speedup as compared with regular (non-partial) instantiation. +```python +factory = instantiate(config, _partial_=True) +obj = factory() +``` +In the above example, repeatedly calling `factory` would be faster than repeatedly calling `instantiate(config)`. +A caveat of this approach is that the same keyword arguments would be re-used in each call to `factory`. +```python +class Foo: + ... + +class Bar: + def __init__(self, foo): + self.foo = foo + +bar_conf = { + "_target_": "__main__.Bar", + "foo": {"_target_": "__main__.Foo"}, +} + +bar_factory = instantiate(bar_conf, _partial_=True) +bar1 = bar_factory() +bar2 = bar_factory() + +assert bar1 is not bar2 +assert bar1.foo is bar2.foo # the `Foo` instance is re-used here +``` +This does not apply if `_partial_=False`, +in which case a new `Foo` instance would be created with each call to `instantiate`. + + +### Instantiation of builtins + +The value of `_target_` passed to `instantiate` should be a "dotpath" pointing +to some callable that can be looked up via a combination of `import` and `getattr`. +If you want to target one of Python's [built-in functions](https://docs.python.org/3/library/functions.html) (such as `len` or `print` or `divmod`), +you will need to provide a dotpath looking up that function in Python's [`builtins`](https://docs.python.org/3/library/builtins.html) module. +```python +from hydra.utils import instantiate +# instantiate({"_target_": "len"}, [1,2,3]) # this gives an InstantiationException +instantiate({"_target_": "builtins.len"}, [1,2,3]) # this works, returns the number 3 +``` diff --git a/website/docs/advanced/override_grammar/basic.md b/website/docs/advanced/override_grammar/basic.md index 2894bd96cbe..66574961a2e 100644 --- a/website/docs/advanced/override_grammar/basic.md +++ b/website/docs/advanced/override_grammar/basic.md @@ -22,9 +22,9 @@ The rest are manipulating the config object. - Removing a config value : `~foo.bar`, `~foo.bar=value` ### Modifying the Defaults List -- Overriding selected Option: `db=mysql` -- Appending to Defaults List: `+db=mysql` -- Deleting from Defaults List: `~db`, `~db=mysql` +- Overriding selected Option: `db=mysql`, `server/db=mysql` +- Appending to Defaults List: `+db=mysql`, `+server/db=mysql` +- Deleting from Defaults List: `~db`, `~db=mysql`, `~server/db`, `~server/db=mysql` ## Grammar Hydra supports a rich [DSL](https://en.wikipedia.org/wiki/Domain-specific_language) in the command line. @@ -84,7 +84,7 @@ primitive: | FLOAT // 3.14, -20.0, 1e-1, -10e3 | BOOL // true, TrUe, false, False | INTERPOLATION // ${foo.bar}, ${oc.env:USER,me} - | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ? + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, | | COLON // : | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, | WS // whitespaces @@ -97,7 +97,7 @@ dictKey: | INT // 0, 10, -20, 1_000_000 | FLOAT // 3.14, -20.0, 1e-1, -10e3 | BOOL // true, TrUe, false, False - | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ? + | UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @, ?, | | ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \, | WS // whitespaces )+; @@ -147,6 +147,29 @@ Quoted strings can accept any value between the quotes, but some characters need +It may be necessary to use multiple pairs of quotes to prevent your +shell from consuming quotation marks before they are passed to hydra. + +```shell +$ python my_app.py '+foo="{a: 10}"' +foo: '{a: 10}' + +$ python my_app.py '+foo={a: 10}' +foo: + a: 10 + +``` + +Here are some best practices around quoting in CLI overrides: +- Quote the whole key=value pair with single quotes, as in the first two + examples above. These quotes are for the benefit of the shell. +- Do not quote keys. +- Only quote values if they contain a space. It will work if you always quote + values, but it will turn numbers/dicts/lists into strings (as in the first + example above). +- When you are quoting values, use double quotes to avoid collision with the + outer single quoted consumed by the shell. + ### Whitespaces in unquoted values Unquoted Override values can contain non leading or trailing whitespaces. For example, `msg=hello world` is a legal override (key is `msg` and value is the string `hello world`). @@ -164,7 +187,8 @@ $ python my_app.py 'msg= hello world ' ``` ### Escaped characters in unquoted values -Some otherwise special characters may be included in unquoted values by escaping them with a `\`. +Hydra's parser considers some characters to be illegal in unquoted strings. +These otherwise special characters may be included in unquoted values by escaping them with a `\`. These characters are: `\()[]{}:=, \t` (the last two ones being the whitespace and tab characters). As an example, in the following `dir` is set to the string `job{a=1,b=2,c=3}`: @@ -173,6 +197,14 @@ As an example, in the following `dir` is set to the string `job{a=1,b=2,c=3}`: $ python my_app.py 'dir=job\{a\=1\,b\=2\,c\=3\}' ``` +As an alternative to escaping special characters with a backslash, the value containing the special character may be quoted: + +```shell +$ python my_app.py 'dir=A[B' # parser error +$ python my_app.py 'dir="A[B"' # ok +$ python my_app.py 'dir=A\[B' # ok +``` + ### Primitives - `id` : oompa10, loompa_12 - `null`: null diff --git a/website/docs/advanced/plugins/develop.md b/website/docs/advanced/plugins/develop.md index 0f2b6d4a615..d4b9805d350 100644 --- a/website/docs/advanced/plugins/develop.md +++ b/website/docs/advanced/plugins/develop.md @@ -11,17 +11,37 @@ If you develop plugins, please join the None: + """Hydra users should call this function before invoking @hydra.main""" + Plugins.instance().register(MyPlugin) +``` + ## Getting started The best way to get started developing a Hydra plugin is to base your new plugin on one of the example plugins: diff --git a/website/docs/advanced/search_path.md b/website/docs/advanced/search_path.md index 208f30ab7ca..7a0587092dc 100644 --- a/website/docs/advanced/search_path.md +++ b/website/docs/advanced/search_path.md @@ -34,8 +34,15 @@ Using the `config_path` parameter `@hydra.main()`. The `config_path` is relati In some cases you may want to add multiple locations to the search path. For example, an app may want to read the configs from an additional Python module or -an additional directory on the file system. -Configure this using `hydra.searchpath` in your primary config or your command line. +an additional directory on the file system. Another example is in unit testing, +where the defaults list in a config loaded from the `tests/configs` folder may +make reference to another config from the `app/configs` folder. If the +`config_path` or `config_dir` argument passed to `@hydra.main` or to one of the +[initialization methods](compose_api.md#initialization-methods) points to +`tests/configs`, the configs located in `app/configs` will not be discoverable +unless Hydra's search path is modified. + +You can configure `hydra.searchpath` in your primary config or from the command line. :::info hydra.searchpath can **only** be configured in the primary config. Attempting to configure it in other configs will result in an error. ::: @@ -62,7 +69,7 @@ In this example, we add a second config directory - `additional_conf`, next to t ```python title="my_app.py" -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/website/docs/advanced/terminology.md b/website/docs/advanced/terminology.md index 8a3e7c1c719..905231bbec7 100644 --- a/website/docs/advanced/terminology.md +++ b/website/docs/advanced/terminology.md @@ -72,7 +72,7 @@ defaults: ``` ## Config Group -A Config Group is directory in the [Config Search Path](#config-search-path) that contains [Input Configs](#input-configs). +A Config Group is a directory in the [Config Search Path](#config-search-path) that contains [Input Configs](#input-configs). Config Groups can be nested, and in that case the path elements are separated by a forward slash ('/') regardless of the operating system. ## Config Group Option @@ -128,4 +128,4 @@ The [Config Search Path](search_path.md) is a list of paths that are searched in the Python [PYTHONPATH](https://docs.python.org/3/using/cmdline.html#envvar-PYTHONPATH). ## Plugins -[Plugins](plugins/intro.md) extend Hydra's capabilities. Hydra has several plugin types, for examples Launcher and Sweeper. \ No newline at end of file +[Plugins](plugins/intro.md) extend Hydra's capabilities. Hydra has several plugin types, for example Launcher and Sweeper. diff --git a/website/docs/advanced/unit_testing.md b/website/docs/advanced/unit_testing.md index 83974cb0423..bc0325a4c58 100644 --- a/website/docs/advanced/unit_testing.md +++ b/website/docs/advanced/unit_testing.md @@ -18,11 +18,15 @@ from hydra import initialize, compose # it needs to have a __init__.py (can be empty). # 3. THe config path is relative to the file calling initialize (this file) def test_with_initialize() -> None: - with initialize(config_path="../hydra_app/conf"): + with initialize(version_base=None, config_path="../hydra_app/conf"): # config is relative to a module cfg = compose(config_name="config", overrides=["app.user=test_user"]) assert cfg == { "app": {"user": "test_user", "num1": 10, "num2": 20}, "db": {"host": "localhost", "port": 3306}, } -``` \ No newline at end of file +``` + +For an idea about how to modify Hydra's search path when using `compose` in +unit tests, see the page on +[overriding the `hydra.searchpath` config](search_path.md#overriding-hydrasearchpath-config). diff --git a/website/docs/configure_hydra/Intro.md b/website/docs/configure_hydra/Intro.md index 272a3a12278..8792894053d 100644 --- a/website/docs/configure_hydra/Intro.md +++ b/website/docs/configure_hydra/Intro.md @@ -83,6 +83,7 @@ You can find more details in the [Job Configuration](job.md) page. Fields under **hydra.job**: - **name** : Job name, defaults to the Python file name without the suffix. can be overridden. - **override_dirname** : Pathname derived from the overrides for this job +- **chdir**: If `True`, Hydra calls `os.chdir(output_dir)` before calling back to the user's main function. - **id** : Job ID in the underlying jobs system (SLURM etc) - **num** : job serial number in sweep - **config_name** : The name of the config used by the job (Output only) @@ -90,13 +91,53 @@ Fields under **hydra.job**: - **env_copy**: Environment variable to copy from the launching machine - **config**: fine-grained configuration for job +### hydra.run: +Used in single-run mode (i.e. when the `--multirun` command-line flag is omitted). +See [configuration for run](workdir.md#configuration-for-run). +- **dir**: used to specify the output directory. + +### hydra.sweep: +Used in multi-run mode (i.e. when the `--multirun` command-line flag is given) +See [configuration for multirun](workdir.md#configuration-for-multirun). +- **dir**: used to specify the output directory common to all jobs in the multirun sweep +- **subdir**: used to specify the a pattern for creation of job-specific subdirectory + ### hydra.runtime: Fields under **hydra.runtime** are populated automatically and should not be overridden. - **version**: Hydra's version - **cwd**: Original working directory the app was executed from +- **output_dir**: This is the directory created by Hydra for saving logs and + yaml config files, as configured by [customizing the working directory pattern](workdir.md). - **choices**: A dictionary containing the final config group choices. - **config_sources**: The final list of config sources used to compose the config. +### hydra.overrides +Fields under **hydra.overrides** are populated automatically and should not be overridden. +- **task**: Contains a list of the command-line overrides used, except `hydra` config overrides. + Contains the same information as the `.hydra/overrides.yaml` file. + See [Output/Working directory](/tutorials/basic/running_your_app/3_working_directory.md). +- **hydra**: Contains a list of the command-line `hydra` config overrides used. + +### hydra.mode +See [multirun](/tutorials/basic/running_your_app/2_multirun.md) for more info. + +### Other Hydra settings +The following fields are present at the top level of the Hydra Config. +- **searchpath**: A list of paths that Hydra searches in order to find configs. + See [overriding `hydra.searchpath`](advanced/search_path.md#overriding-hydrasearchpath-config) +- **job_logging** and **hydra_logging**: Configure logging settings. + See [logging](/tutorials/basic/running_your_app/4_logging.md) and [customizing logging](logging.md). +- **sweeper**: [Sweeper](/tutorials/basic/running_your_app/2_multirun.md#sweeper) plugin settings. Defaults to basic sweeper. +- **launcher**: [Launcher](/tutorials/basic/running_your_app/2_multirun.md#launcher) plugin settings. Defaults to basic launcher. +- **callbacks**: [Experimental callback support](/experimental/callbacks.md). +- **help**: Configures your app's `--help` CLI flag. See [customizing application's help](app_help.md). +- **hydra_help**: Configures the `--hydra-help` CLI flag. +- **output_subdir**: Configures the `.hydra` subdirectory name. + See [changing or disabling the output subdir](/tutorials/basic/running_your_app/3_working_directory.md#changing-or-disabling-hydras-output-subdir). +- **verbose**: Configures per-file DEBUG-level logging. + See [logging](/tutorials/basic/running_your_app/4_logging.md). + + ### Resolvers provided by Hydra Hydra provides the following [OmegaConf resolvers](https://omegaconf.readthedocs.io/en/latest/usage.html#resolvers) by default. diff --git a/website/docs/configure_hydra/job.md b/website/docs/configure_hydra/job.md index 8e3f1541dab..2b5be2b0e5c 100644 --- a/website/docs/configure_hydra/job.md +++ b/website/docs/configure_hydra/job.md @@ -17,6 +17,9 @@ class JobConf: # Job name, populated automatically unless specified by the user (in config or cli) name: str = MISSING + # Change current working dir to the output dir. + chdir: bool = True + # Concatenation of job overrides that can be used as a part # of the directory name. # This can be configured in hydra.job.config.override_dirname @@ -59,6 +62,12 @@ The job name is used by different things in Hydra, such as the log file name (`$ It is normally derived from the Python file name (The job name of the file `train.py` is `train`). You can override it via the command line, or your config file. +### hydra.job.chdir + +Decides whether Hydra changes the current working directory to the output directory for each job. +Learn more at the [Output/Working directory](/tutorials/basic/running_your_app/3_working_directory.md#disable-changing-current-working-dir-to-jobs-output-dir) page. + + ### hydra.job.override_dirname Enables the creation of an output directory which is based on command line overrides. Learn more at the [Customizing Working Directory](/configure_hydra/workdir.md) page. diff --git a/website/docs/configure_hydra/workdir.md b/website/docs/configure_hydra/workdir.md index 02883e53ffd..e082b373fe5 100644 --- a/website/docs/configure_hydra/workdir.md +++ b/website/docs/configure_hydra/workdir.md @@ -8,6 +8,11 @@ import {ExampleGithubLink} from "@site/src/components/GithubLink" +Hydra automatically creates an output directory used to store log files and +save yaml configs. This directory can be configured by setting `hydra.run.dir` +(for single hydra runs) or `hydra.sweep.dir`/`hydra.sweep.subdir` (for multirun +sweeps). At runtime, the path of the output directory can be +[accessed](Intro.md#accessing-the-hydra-config) via the `hydra.runtime.output_dir` variable. Below are a few examples of customizing output directory patterns. ### Configuration for run diff --git a/website/docs/development/release.md b/website/docs/development/release.md index df52c506f16..57b96be5b6c 100644 --- a/website/docs/development/release.md +++ b/website/docs/development/release.md @@ -11,4 +11,7 @@ The release process may be automated in the future. - Update NEWS.md with towncrier - Create a wheel and source dist for hydra-core: `python -m build` - Upload pip package: `python -m twine upload dist/*` - +- Update the link to the latest stable release in `website/docs/intro.md` +- If you are creating a new release branch: + - [tag a new versioned copy of the docs using docusaurus](https://docusaurus.io/docs/versioning#tagging-a-new-version) + - update `website/docusaurus.config.js` with a pointer to the new release branch on github diff --git a/website/docs/experimental/callbacks.md b/website/docs/experimental/callbacks.md index 49ad5339216..a270468f727 100644 --- a/website/docs/experimental/callbacks.md +++ b/website/docs/experimental/callbacks.md @@ -5,7 +5,9 @@ sidebar_label: Callbacks --- import GithubLink from "@site/src/components/GithubLink" +import {ExampleGithubLink} from "@site/src/components/GithubLink" + The Callback interface enables custom code to be triggered by various Hydra events. @@ -88,7 +90,7 @@ class MyCallback(Callback): print(f"Job ended,uploading...") # uploading... -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/website/docs/experimental/rerun.md b/website/docs/experimental/rerun.md new file mode 100644 index 00000000000..11b0e2ee9c1 --- /dev/null +++ b/website/docs/experimental/rerun.md @@ -0,0 +1,91 @@ +--- +id: rerun +title: Re-run a job from previous config +sidebar_label: Re-run +--- + +import {ExampleGithubLink} from "@site/src/components/GithubLink" + + + +:::caution +This is an experimental feature. Please read through this page to understand what is supported. +::: + +We use the example app linked above for demonstration. To save the configs for re-run, first use the experimental +Hydra Callback for saving the job info: + + +```yaml title="config.yaml" +hydra: + callbacks: + save_job_info: + _target_: hydra.experimental.pickle_job_info_callback.PickleJobInfoCallback +``` + + + + +```python title="Example function" +@hydra.main(config_path=".", config_name="config") +def my_app(cfg: DictConfig) -> None: + log.info(f"output_dir={HydraConfig.get().runtime.output_dir}") + log.info(f"cfg.foo={cfg.foo}") +``` + + +Run the example app: +```commandline +$ python my_app.py +[2022-03-16 14:51:30,905][hydra.experimental.pickle_job_info_callback][INFO] - Saving job configs in /Users/jieru/workspace/hydra/examples/experimental/outputs/2022-03-16/14-51-30/.hydra/config.pickle +[2022-03-16 14:51:30,906][__main__][INFO] - Output_dir=/Users/jieru/workspace/hydra/examples/experimental/outputs/2022-03-16/14-51-30 +[2022-03-16 14:51:30,906][__main__][INFO] - cfg.foo=bar +[2022-03-16 14:51:30,906][hydra.experimental.pickle_job_info_callback][INFO] - Saving job_return in /Users/jieru/workspace/hydra/examples/experimental/outputs/2022-03-16/14-51-30/.hydra/job_return.pickle +``` +The Callback saves `config.pickle` in `.hydra` sub dir, this is what we will use for rerun. + +Now rerun the app +```commandline +$ OUTPUT_DIR=/Users/jieru/workspace/hydra/examples/experimental/outputs/2022-03-16/14-51-30/.hydra/ +$ python my_app.py --experimental-rerun $OUTPUT_DIR/config.pickle +/Users/jieru/workspace/hydra/hydra/main.py:23: UserWarning: Experimental rerun CLI option. + warnings.warn(msg, UserWarning) +[2022-03-16 14:59:21,666][__main__][INFO] - Output_dir=/Users/jieru/workspace/hydra/examples/experimental/outputs/2022-03-16/14-51-30 +[2022-03-16 14:59:21,666][__main__][INFO] - cfg.foo=bar +``` +You will notice `my_app.log` is updated with the logging from the second run, but Callbacks are not called this time. Read on to learn more. + + +### Important Notes +This is an experimental feature. Please reach out if you have any question. +- Only single run is supported. +- `--experimental-rerun` cannot be used with other command-line options or overrides. They will simply be ignored. +- Rerun passes in a cfg_passthrough directly to your application, this means except for logging, no other `hydra.main` +functions are called (such as change working dir, or calling callbacks.) +- The configs are preserved and reconstructed to the best efforts. Meaning we can only guarantee that the `cfg` object +itself passed in by `hydra.main` stays the same across runs. However, configs are resolved lazily. Meaning we cannot +guarantee your application will behave the same if your application resolves configs during run time. In the following example, +`cfg.time_now` will resolve to different value every run. + +
+
+ +```yaml title="config.yaml" +time_now: ${now:%H-%M-%S} + + + +``` + +
+ +
+ +```python title="Example function" +@hydra.main(config_path=".", config_name="config") +def my_app(cfg: DictConfig) -> None: + val = cfg.time_now + # the rest of the application +``` +
+
diff --git a/website/docs/intro.md b/website/docs/intro.md index 38a9b80e6fe..c5cd64b5d50 100644 --- a/website/docs/intro.md +++ b/website/docs/intro.md @@ -27,14 +27,15 @@ Use the version switcher in the top bar to switch between documentation versions | | Version | Release notes | Python Versions | | -------|---------------------------|-------------------------------------------------------------------------------------| -------------------| -| ►| 1.1 (Stable) | [Release notes](https://github.com/facebookresearch/hydra/releases/tag/v1.1.0) | **3.6 - 3.9** | -| | 1.0 | [Release notes](https://github.com/facebookresearch/hydra/releases/tag/v1.0.0rc1) | **3.6 - 3.8** | -| | 0.11 | [Release notes](https://github.com/facebookresearch/hydra/releases/tag/0.11.0) | **2.7, 3.5 - 3.8** | +| ►| 1.1 (Stable) | [Release notes](https://github.com/facebookresearch/hydra/releases/tag/v1.1.1) | **3.6 - 3.9** | +| | 1.0 | [Release notes](https://github.com/facebookresearch/hydra/releases/tag/v1.0.7) | **3.6 - 3.8** | +| | 0.11 | [Release notes](https://github.com/facebookresearch/hydra/releases/tag/v0.11.3) | **2.7, 3.5 - 3.8** | ## Quick start guide -This guide will show you some of the most important features of Hydra. -Read the [tutorial](tutorials/basic/your_first_app/1_simple_cli.md) to gain a deeper understanding. +This guide will show you some of the most important features you get by writing your application as a Hydra app. +If you only want to use Hydra for config composition, check out Hydra's [compose API](advanced/compose_api.md) for an alternative. +Please also read the full [tutorial](tutorials/basic/your_first_app/1_simple_cli.md) to gain a deeper understanding. ### Installation ```commandline @@ -54,7 +55,7 @@ Application: import hydra from omegaconf import DictConfig, OmegaConf -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg : DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) @@ -141,8 +142,8 @@ There is a whole lot more to Hydra. Read the [tutorial](tutorials/basic/your_fir ## Other stuff ### Community -Ask questions in the chat or StackOverflow (Use the tag #fb-hydra): -* [Zulip Chat](https://hydra-framework.zulipchat.com) +Ask questions on github or StackOverflow (Use the tag #fb-hydra): +* [github](https://github.com/facebookresearch/hydra/discussions) * [StackOverflow](https://stackoverflow.com/questions/tagged/fb-hydra) Follow Hydra on Twitter and Facebook: diff --git a/website/docs/patterns/specializing_config.md b/website/docs/patterns/specializing_config.md index ca3e996e6d7..c13c8ba1d29 100644 --- a/website/docs/patterns/specializing_config.md +++ b/website/docs/patterns/specializing_config.md @@ -5,7 +5,7 @@ title: Specializing configuration import {ExampleGithubLink} from "@site/src/components/GithubLink" - + In some cases the desired configuration should depend on other configuration choices. For example, You may want to use only 5 layers in your Alexnet model if the dataset of choice is cifar10, and the dafault 7 otherwise. @@ -54,6 +54,8 @@ We want the model for alexnet, when trained on cifar - to have 5 layers. ### dataset_model/cifar10_alexnet.yaml ```yaml +# @package _global_ + model: num_layers: 5 ``` diff --git a/website/docs/plugins/ax_sweeper.md b/website/docs/plugins/ax_sweeper.md index 5ceaa88416b..0e534ad26f7 100644 --- a/website/docs/plugins/ax_sweeper.md +++ b/website/docs/plugins/ax_sweeper.md @@ -62,7 +62,6 @@ ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: In this example, we set the range of `x` parameter as an integer in the interval `[-5, 5]` and the range of `y` parameter as a float in the interval `[-5, 10.1]`. Note that in the case of `x`, we used `int(interval(...))` and hence only integers are sampled. In the case of `y`, we used `interval(...)` which refers to a floating-point interval. Other supported formats are fixed parameters (e.g.` banana.x=5.0`), choice parameters (eg `banana.x=choice(1,2,3)`) and range (eg `banana.x=range(1, 10)`). Note that `interval`, `choice` etc. are functions provided by Hydra, and you can read more about them [here](https://hydra.cc/docs/next/advanced/override_grammar/extended/). An important thing to remember is, use [`interval`](https://hydra.cc/docs/next/advanced/override_grammar/extended/#interval-sweep) when we want Ax to sample values from an interval. [`RangeParameter`](https://ax.dev/api/ax.html#ax.RangeParameter) in Ax is equivalent to `interval` in Hydra. Remember to use `int(interval(...))` if you want to sample only integer points from the interval. [`range`](https://hydra.cc/docs/next/advanced/override_grammar/extended/#range-sweep) can be used as an alternate way of specifying choice parameters. For example `python example/banana.py -m banana.x=choice(1, 2, 3, 4)` is equivalent to `python example/banana.py -m banana.x=range(1, 5)`. - The values of the `x` and `y` parameters can also be set using the config file `plugins/hydra_ax_sweeper/example/conf/config.yaml`. For instance, the configuration corresponding to the commandline arguments is as follows: ``` @@ -75,6 +74,14 @@ banana.y: bounds: [-5, 10.1] ``` +To sample in log space, you can tag the commandline override with `log`. E.g. `python example/banana.py -m banana.x=tag(log, interval(1, 1000))`. You can set `log_scale: true` in the input config to achieve the same. +``` +banana.z: + type: range + bounds: [1, 100] + log_scale: true +``` + In general, the plugin supports setting all the Ax supported [Parameters](https://ax.dev/api/core.html?highlight=range#module-ax.core.parameter) in the config. According to the [Ax documentation](https://ax.dev/api/service.html#ax.service.ax_client.AxClient.create_experiment), the required elements in the config are: * `name` - Name of the parameter. It is of type string. diff --git a/website/docs/plugins/nevergrad_sweeper.md b/website/docs/plugins/nevergrad_sweeper.md index ae8d81281ca..f203a1e0376 100644 --- a/website/docs/plugins/nevergrad_sweeper.md +++ b/website/docs/plugins/nevergrad_sweeper.md @@ -27,7 +27,7 @@ defaults: ``` The default configuration is defined and documented here. -There are several standard approaches for configuring plugins. Check [this page](../patterns/configuring_plugins) for more information. +There are several standard approaches for configuring plugins. Check [this page](../patterns/configuring_plugins.md) for more information. ## Example of training using Nevergrad hyperparameter search diff --git a/website/docs/plugins/optuna_sweeper.md b/website/docs/plugins/optuna_sweeper.md index 445cb821f4a..3ddb723fc5c 100644 --- a/website/docs/plugins/optuna_sweeper.md +++ b/website/docs/plugins/optuna_sweeper.md @@ -70,18 +70,9 @@ storage: null study_name: sphere n_trials: 20 n_jobs: 1 -search_space: - x: - type: float - low: -5.5 - high: 5.5 - step: 0.5 - 'y': - type: categorical - choices: - - -5 - - 0 - - 5 +params: + x: range(-5.5,5.5,step=0.5) + y: choice(-5,0,5) ``` The function decorated with `@hydra.main()` returns a float which we want to minimize, the minimum is 0 and reached for: @@ -156,6 +147,7 @@ The output is as follows: #### Range override `range` is converted to [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html). If you apply `shuffle` to `range`, [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html) is used instead. +If any of `range`'s start, stop or step is of type float, it will be converted to [`DiscreteUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.DiscreteUniformDistribution.html)
Example for range override @@ -211,39 +203,8 @@ The output is as follows: ### Configuring through config file -#### Int parameters - -`int` parameters can be defined with the following fields: - -- `type`: `int` -- `low`: lower bound -- `high`: upper bound -- `step`: discretization step (optional) -- `log`: if `true`, space is converted to the log domain - -If `log` is `false`, the parameter is mapped to [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html). Otherwise, the parameter is mapped to [`IntLogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntLogUniformDistribution.html). Please note that `step` can not be set if `log=true`. - -#### Float parameters - -`float` parameters can be defined with the following fields: - -- `type`: `float` -- `low`: lower bound -- `high`: upper bound -- `step`: discretization step -- `log`: if `true`, space is converted to the log domain - -If `log` is `false`, the parameter is mapped to [`UniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.UniformDistribution.html) or [`DiscreteUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.DiscreteUniformDistribution.html) depending on the presence or absence of the `step` field, respectively. Otherwise, the parameter is mapped to [`LogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.LogUniformDistribution.html). Please note that `step` can not be set if `log=true`. - -#### Categorical parameters - -`categorical` parameters can be defined with the following fields: - - - `type`: `categorical` - - `choices`: a list of parameter value candidates - -The parameters are mapped to [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html). - +The syntax in config file is consistent with the above commandline override. For example, a commandline override +`x=range(1,4)` can be expressed in config file as `x: range(1,4)` under the `hydra.sweeper.params` node. ## Example 2: Multi-Objective Optimization @@ -275,17 +236,9 @@ storage: null study_name: multi-objective n_trials: 20 n_jobs: 1 -search_space: - x: - type: float - low: 0 - high: 5 - step: 0.5 - 'y': - type: float - low: 0 - high: 3 - step: 0.5 +params: + x: range(0, 5, step=0.5) + y: range(0, 3, step=0.5) ```
@@ -299,3 +252,72 @@ python example/multi-objective.py --multirun For problems with trade-offs between two different objectives, there may be no single solution that simultaneously minimizes both objectives. Instead, we obtained a set of solutions, namely [Pareto optimal solutions](https://en.wikipedia.org/wiki/Pareto_efficiency), that show the best trade-offs possible between the objectives. In the following figure, the blue dots show the Pareto optimal solutions in the optimization results. ![Pareto-optimal solutions](/plugins/optuna_sweeper/multi_objective_result.png) + +## EXPERIMENTAL: Custom-Search-Space Optimization + +Hydra's Optuna Sweeper allows users to provide a hook for custom search space configuration. +This means you can work directly with the `optuna.trial.Trial` object to suggest parameters. +To use this feature, define a python function with signature `Callable[[DictConfig, optuna.trial.Trial], None]` +and set the `hydra.sweeper.custom_search_space` key in your config to target that function. + +You can find a full example in the same directory as before, where `example/custom-search-space-objective.py` implements a benchmark function to be minimized. +The example shows the use of Optuna's [pythonic search spaces](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/002_configurations.html) in combination with Hydra. +Part of the search space configuration is defined in config files, and part of it is written in Python. + +
Example: Custom search space configuration + +```yaml +defaults: + - override hydra/sweeper: optuna + +hydra: + sweeper: + sampler: + seed: 123 + direction: minimize + study_name: custom-search-space + storage: null + n_trials: 20 + n_jobs: 1 + + params: + x: range(-5.5, 5.5, 0.5) + y: choice(-5, 0, 5) + # `custom_search_space` should be a dotpath pointing to a + # callable that provides search-space configuration logic: + custom_search_space: .custom-search-space-objective.configure + +x: 1 +y: 1 +z: 100 +max_z_difference_from_x: 0.5 +``` +```python +# example/custom-search-space-objective.py + +... + +def configure(cfg: DictConfig, trial: Trial) -> None: + x_value = trial.params["x"] + trial.suggest_float( + "z", + x_value - cfg.max_z_difference_from_x, + x_value + cfg.max_z_difference_from_x, + ) + trial.suggest_float("+w", 0.0, 1.0) # note +w here, not w as w is a new parameter + +... +``` + +
+ +The method that `custom_search_space` points to must accepts both a DictConfig with already set options and a trial object which needs further configuration. In this example we limit `z` the difference between `x` and `z` to be no more than 0.5. +Note that this `custom_search_space` API should be considered experimental and is subject to change. + +### Order of trial configuration +Configuring a trial object is done in the following sequence: + - search space parameters are set from the `hydra.sweeper.params` key in the config + - Command line overrides are set + - `custom_search_space` parameters are set + +It is not allowed to set search space parameters in the `custom_search_space` method for parameters which have a fixed value from command line overrides. [Trial.user_attrs](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.user_attrs) can be inspected to find any of such fixed parameters. \ No newline at end of file diff --git a/website/docs/plugins/submitit_launcher.md b/website/docs/plugins/submitit_launcher.md index 623bb6fbed4..37dd080a050 100644 --- a/website/docs/plugins/submitit_launcher.md +++ b/website/docs/plugins/submitit_launcher.md @@ -55,6 +55,7 @@ cpus_per_gpu: null gpus_per_task: null mem_per_gpu: null mem_per_cpu: null +account: null signal_delay_s: 120 max_num_timeout: 0 additional_parameters: {} diff --git a/website/docs/tutorials/basic/running_your_app/2_multirun.md b/website/docs/tutorials/basic/running_your_app/2_multirun.md index cd898324b45..8a6b875cb86 100644 --- a/website/docs/tutorials/basic/running_your_app/2_multirun.md +++ b/website/docs/tutorials/basic/running_your_app/2_multirun.md @@ -7,10 +7,14 @@ sidebar_label: Multi-run Sometimes you want to run the same application with multiple different configurations. E.g. running a performance test on each of the databases with each of the schemas. -Use the `--multirun` (`-m`) flag and pass a comma separated list specifying the values for each dimension you want to sweep. +You can multirun a Hydra application via either commandline or configuration: -The following sweeps over all combinations of the dbs and schemas. -```text title="$ python my_app.py -m db=mysql,postgresql schema=warehouse,support,school" +### Configure `hydra.mode` (new in Hydra 1.2) +You can configure `hydra.mode` in any supported way. The legal values are `RUN` and `MULTIRUN`. +The following shows how to override from the command-line and sweep over all combinations of the dbs and schemas. +Setting `hydra.mode=MULTIRUN` in your input config would make your application multi-run by default. + +```text title="$ python my_app.py hydra.mode=MULTIRUN db=mysql,postgresql schema=warehouse,support,school" [2021-01-20 17:25:03,317][HYDRA] Launching 6 jobs locally [2021-01-20 17:25:03,318][HYDRA] #0 : db=mysql schema=warehouse [2021-01-20 17:25:03,458][HYDRA] #1 : db=mysql schema=support @@ -21,6 +25,69 @@ The following sweeps over all combinations of the dbs and schemas. ``` The printed configurations have been omitted for brevity. +### `--multirun (-m)` from the command-line +You can achieve the above from command-line as well: +```commandline +python my_app.py --multirun db=mysql,postgresql schema=warehouse,support,school +``` +or +```commandline +python my_app.py -m db=mysql,postgresql schema=warehouse,support,school +``` + +You can access `hydra.mode` at runtime to determine whether the application is in RUN or MULTIRUN mode. Check [here](/configure_hydra/Intro.md) +on how to access Hydra config at run time. + +If conflicts arise (eg, `hydra.mode=RUN` and the application was run with `--multirun`), Hydra will determine the value of `hydra.mode` +at run time. The following table shows what runtime `hydra.mode` value you'd get with different input configs and commandline combinations. + +[//]: # (Conversion matrix) + +| | No multirun commandline flag | --multirun ( -m) | +|-------------------- |-------------------------------------|-------------------------------------| +|hydra.mode=RUN | RunMode.RUN | RunMode.MULTIRUN (with UserWarning) | +|hydra.mode=MULTIRUN | RunMode.MULTIRUN | RunMode.MULTIRUN | +|hydra.mode=None (default) | RunMode.RUN | RunMode.MULTIRUN | + + +:::important +Hydra composes configs lazily at job launching time. If you change code or configs after launching a job/sweep, the final +composed configs might be impacted. +::: + +### Sweeping via `hydra.sweeper.params` + +import {ExampleGithubLink} from "@site/src/components/GithubLink" + + + +You can also define sweeping in the input configs by overriding +`hydra.sweeper.params`. Using the above example, the same multirun could be achieved via the following config. + +```yaml +hydra: + sweeper: + params: + db: mysql,postgresql + schema: warehouse,support,school +``` + +The syntax are consistent for both input configs and commandline overrides. +If a sweep is specified in both an input config and at the command line, +then the commandline sweep will take precedence over the sweep defined +in the input config. If we run the same application with the above input config and a new commandline override: + +```text title="$ python my_app.py -m db=mysql" +[2021-01-20 17:25:03,317][HYDRA] Launching 3 jobs locally +[2021-01-20 17:25:03,318][HYDRA] #0 : db=mysql schema=warehouse +[2021-01-20 17:25:03,458][HYDRA] #1 : db=mysql schema=support +[2021-01-20 17:25:03,602][HYDRA] #2 : db=mysql schema=school +``` +:::info +The above configuration methods only apply to Hydra's default `BasicSweeper` for now. For other sweepers, please checkout the +corresponding documentations. +::: + ### Additional sweep types Hydra supports other kinds of sweeps, e.g: ```python diff --git a/website/docs/tutorials/basic/running_your_app/3_working_directory.md b/website/docs/tutorials/basic/running_your_app/3_working_directory.md index a39621016f8..43c91d7733c 100644 --- a/website/docs/tutorials/basic/running_your_app/3_working_directory.md +++ b/website/docs/tutorials/basic/running_your_app/3_working_directory.md @@ -53,14 +53,38 @@ Inside the Hydra output directory we have: And in the main output directory: * `my_app.log`: A log file created for this run -### Changing or disabling the output subdir +### Disable changing current working dir to job's output dir + +By default, Hydra's `@hydra.main` decorator makes a call to `os.chdir` before passing control to the user's decorated main function. +Set `hydra.job.chdir=False` to disable this behavior. +```bash +# check current working dir +$ pwd +/home/omry/dev/hydra + +# working dir remains the same +$ python my_app.py hydra.job.chdir=False +Working directory : /home/omry/dev/hydra + +# output dir and files are still created, even if `chdir` is disabled: +$ tree -a outputs/2021-10-25/09-46-26/ +outputs/2021-10-25/09-46-26/ +├── .hydra +│   ├── config.yaml +│   ├── hydra.yaml +│   └── overrides.yaml +└── my_app.log +``` + + +### Changing or disabling Hydra's output subdir You can change the `.hydra` subdirectory name by overriding `hydra.output_subdir`. You can disable its creation by overriding `hydra.output_subdir` to `null`. -### Original working directory +### Accessing the original working directory in your application -You can still access the original working directory via `get_original_cwd()` and `to_absolute_path()` in `hydra.utils`: +With `hydra.job.chdir=True`, you can still access the original working directory by importing `get_original_cwd()` and `to_absolute_path()` in `hydra.utils`: ```python from hydra.utils import get_original_cwd, to_absolute_path @@ -71,6 +95,9 @@ def my_app(_cfg: DictConfig) -> None: print(f"Orig working directory : {get_original_cwd()}") print(f"to_absolute_path('foo') : {to_absolute_path('foo')}") print(f"to_absolute_path('/foo') : {to_absolute_path('/foo')}") + +if __name__ == "__main__": + my_app() ``` ```text title="$ python examples/tutorial/8_working_directory/original_cwd.py" diff --git a/website/docs/tutorials/basic/running_your_app/4_logging.md b/website/docs/tutorials/basic/running_your_app/4_logging.md index 5c33c4fae2c..5ad115909fc 100644 --- a/website/docs/tutorials/basic/running_your_app/4_logging.md +++ b/website/docs/tutorials/basic/running_your_app/4_logging.md @@ -42,7 +42,8 @@ You can enable DEBUG level logging from the command line by overriding `hydra.v Examples: * `hydra.verbose=true` : Sets the log level of **all** loggers to `DEBUG` -* `hydra.verbose=NAME` : Sets the log level of the logger `NAME` to `DEBUG` +* `hydra.verbose=NAME` : Sets the log level of the logger `NAME` to `DEBUG`. + Equivalent to `import logging; logging.getLogger(NAME).setLevel(logging.DEBUG)`. * `hydra.verbose=[NAME1,NAME2]`: Sets the log level of the loggers `NAME1` and `NAME2` to `DEBUG` Example output: diff --git a/website/docs/tutorials/basic/running_your_app/6_tab_completion.md b/website/docs/tutorials/basic/running_your_app/6_tab_completion.md index 4de0f8909b6..9ee7f680dfa 100644 --- a/website/docs/tutorials/basic/running_your_app/6_tab_completion.md +++ b/website/docs/tutorials/basic/running_your_app/6_tab_completion.md @@ -16,15 +16,24 @@ import Script from '@site/src/components/Script.jsx'; ### Install tab completion Get the exact command to install the completion from `--hydra-help`. -Currently, Bash, zsh and Fish are supported. +Currently, Bash, zsh and Fish are supported. We are relying on the community to implement tab completion plugins for additional shells. +#### Fish instructions Fish support requires version >= 3.1.2. Previous versions will work but add an extra space after `.`. +Because the fish shell implements special behavior for expanding words prefixed +with a tilde character '~', command-line completion does not work for +[tilde deletions](/advanced/override_grammar/basic.md#modifying-the-defaults-list). + #### Zsh instructions Zsh is compatible with the existing Bash shell completion by appending ``` autoload -Uz bashcompinit && bashcompinit ``` -to the `.zshrc` file after `compinit`, restarting the shell and then using the commands provided for Bash. \ No newline at end of file +to the `.zshrc` file after `compinit`, restarting the shell and then using the commands provided for Bash. + +Because the zsh shell implements special behavior for expanding words prefixed +with a tilde character '~', command-line completion does not work for +[tilde deletions](/advanced/override_grammar/basic.md#modifying-the-defaults-list). \ No newline at end of file diff --git a/website/docs/tutorials/basic/your_first_app/1_simple_cli.md b/website/docs/tutorials/basic/your_first_app/1_simple_cli.md index ff2b71a193b..d81b2967caa 100644 --- a/website/docs/tutorials/basic/your_first_app/1_simple_cli.md +++ b/website/docs/tutorials/basic/your_first_app/1_simple_cli.md @@ -17,7 +17,7 @@ The examples in this tutorial are available None: print(OmegaConf.to_yaml(cfg)) + +if __name__ == "__main__": + my_app() ``` Running `my_app.py` without requesting a configuration will print an empty config. diff --git a/website/docs/tutorials/basic/your_first_app/5_defaults.md b/website/docs/tutorials/basic/your_first_app/5_defaults.md index 18cab3fca11..21f7827b697 100644 --- a/website/docs/tutorials/basic/your_first_app/5_defaults.md +++ b/website/docs/tutorials/basic/your_first_app/5_defaults.md @@ -24,9 +24,15 @@ defaults: Remember to specify the `config_name`: ```python -@hydra.main(config_path="conf", config_name="config") +from omegaconf import DictConfig, OmegaConf +import hydra + +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) + +if __name__ == "__main__": + my_app() ``` When you run the updated application, MySQL is loaded by default. diff --git a/website/docs/tutorials/structured_config/10_config_store.md b/website/docs/tutorials/structured_config/10_config_store.md index 948524bafb0..a7de100805b 100644 --- a/website/docs/tutorials/structured_config/10_config_store.md +++ b/website/docs/tutorials/structured_config/10_config_store.md @@ -2,10 +2,10 @@ id: config_store title: Config Store API --- -`ConfigStore` is a singleton storing configs in memory. -The primary API for interacting with the `ConfigStore` is the store method described below. - +Throughout the rest of tutorials, we will be using `ConfigStore` to register dataclasses as input configs in Hydra. +`ConfigStore` is a singleton storing configs in memory. +The primary API for interacting with the `ConfigStore` is the store method described below. ### API ```python @@ -33,14 +33,123 @@ class ConfigStore(metaclass=Singleton): ... ``` +### ConfigStore and YAML input configs + +`ConfigStore` has feature parity with YAML input configs. On top of that, it also provides typing validation. +`ConfigStore` can be used alone or together with YAML. We will see more examples later in this series of tutorials. +For now, let's see how the `ConfigStore` API translates into the YAML input configs, which we've become more familiar +with after the basic tutorials. + +Say we have a simple application and a `db` config group with a `mysql` option: + +
+ +
+ +```python title="my_app.py" +@hydra.main(version_base=None, config_path="conf") +def my_app(cfg: DictConfig) -> None: + print(OmegaConf.to_yaml(cfg)) + + +if __name__ == "__main__": + my_app() +``` +
+
+ +```text title="Directory layout" +├─ conf +│ └─ db +│ └─ mysql.yaml +└── my_app.py + + + +``` +
+
+ +```yaml title="db/mysql.yaml" +driver: mysql +user: omry +password: secret + + + + +``` +
+
+ +What if we want to add an `postgresql` option now? Yes, we can easily add a `db/postgresql.yaml` config group option. But +that is not the only way! We can also use `ConfigStore` to make another config group option for `db` available to Hydra. + +To achieve this, we add a few lines (highlighted) in the above `my_app.py` file: + + +```python title="my_app.py" {1-9} +@dataclass +class PostgresSQLConfig: + driver: str = "postgresql" + user: str = "jieru" + password: str = "secret" + +cs = ConfigStore.instance() +# Registering the Config class with the name `postgresql` with the config group `db` +cs.store(name="postgresql", group="db", node=PostgresSQLConfig) + +@hydra.main(version_base=None, config_path="conf") +def my_app(cfg: DictConfig) -> None: + print(OmegaConf.to_yaml(cfg)) + + +if __name__ == "__main__": + my_app() +``` + + +Now that our application has access to both `db` config group options, let's run the application to verify: + +
+ +
+ +```commandline title="python my_app.py +db=mysql" +db: + driver: mysql + user: omry + password: secret + +``` +
+
+ +```commandline title="python my_app.py +db=postgresql" +db: + driver: postgresql + user: jieru + password: secret + +``` +
+
+ + ### Example node values A few examples of supported node values parameters: ```python +from dataclasses import dataclass + +from hydra.core.config_store import ConfigStore + @dataclass class MySQLConfig: host: str = "localhost" port: int = 3306 +cs = ConfigStore.instance() + # Using the type cs.store(name="config1", node=MySQLConfig) # Using an instance, overriding some default values diff --git a/website/docs/tutorials/structured_config/1_minimal_example.md b/website/docs/tutorials/structured_config/1_minimal_example.md index 8739155bbd6..668cae8d672 100644 --- a/website/docs/tutorials/structured_config/1_minimal_example.md +++ b/website/docs/tutorials/structured_config/1_minimal_example.md @@ -30,7 +30,7 @@ cs = ConfigStore.instance() # Registering the Config class with the name 'config'. cs.store(name="config", node=MySQLConfig) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: MySQLConfig) -> None: # pork should be port! if cfg.pork == 80: diff --git a/website/docs/tutorials/structured_config/2_hierarchical_static_config.md b/website/docs/tutorials/structured_config/2_hierarchical_static_config.md index f7aae5177c4..66dc5e373c0 100644 --- a/website/docs/tutorials/structured_config/2_hierarchical_static_config.md +++ b/website/docs/tutorials/structured_config/2_hierarchical_static_config.md @@ -11,6 +11,11 @@ import {ExampleGithubLink} from "@site/src/components/GithubLink" Dataclasses can be nested and then accessed via a common root. The entire tree is type checked. ```python +from dataclasses import dataclass + +import hydra +from hydra.core.config_store import ConfigStore + @dataclass class MySQLConfig: host: str = "localhost" @@ -30,7 +35,7 @@ class MyConfig: cs = ConfigStore.instance() cs.store(name="config", node=MyConfig) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: MyConfig) -> None: print(f"Title={cfg.ui.title}, size={cfg.ui.width}x{cfg.ui.height} pixels") diff --git a/website/docs/tutorials/structured_config/3_config_groups.md b/website/docs/tutorials/structured_config/3_config_groups.md index 64464d1d587..fe4b8d39310 100644 --- a/website/docs/tutorials/structured_config/3_config_groups.md +++ b/website/docs/tutorials/structured_config/3_config_groups.md @@ -11,6 +11,11 @@ Structured Configs can be used to implement config groups. Special care needs to default value for fields populated by a config group. We will look at why below. ```python title="Defining a config group for database" {16-17,22-23} +from dataclasses import dataclass + +import hydra +from hydra.core.config_store import ConfigStore + @dataclass class MySQLConfig: driver: str = "mysql" @@ -35,9 +40,12 @@ cs.store(name="config", node=Config) cs.store(group="db", name="mysql", node=MySQLConfig) cs.store(group="db", name="postgresql", node=PostGreSQLConfig) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: Config) -> None: print(OmegaConf.to_yaml(cfg)) + +if __name__ == "__main__": + my_app() ``` :::caution diff --git a/website/docs/tutorials/structured_config/4_defaults.md b/website/docs/tutorials/structured_config/4_defaults.md index 5923c2392fe..3f47f86f379 100644 --- a/website/docs/tutorials/structured_config/4_defaults.md +++ b/website/docs/tutorials/structured_config/4_defaults.md @@ -15,6 +15,10 @@ NOTE: You can still place your defaults list in your primary (YAML) config file
```python {11-14,19,25} +from dataclasses import dataclass + +import hydra +from hydra.core.config_store import ConfigStore from omegaconf import MISSING, OmegaConf @dataclass @@ -44,7 +48,7 @@ cs.store(group="db", name="postgresql", node=PostGreSQLConfig) cs.store(name="config", node=Config) -@hydra.main(config_path=None, config_name="config") +@hydra.main(version_base=None, config_name="config") def my_app(cfg: Config) -> None: print(OmegaConf.to_yaml(cfg)) diff --git a/website/docs/tutorials/structured_config/5_schema.md b/website/docs/tutorials/structured_config/5_schema.md index ff062bdb222..bce5762ac12 100644 --- a/website/docs/tutorials/structured_config/5_schema.md +++ b/website/docs/tutorials/structured_config/5_schema.md @@ -75,6 +75,11 @@ The primary Defaults List will come from `config.yaml`.
my_app.py (Click to expand) ```python {28-30} +from dataclasses import dataclass + +import hydra +from hydra.core.config_store import ConfigStore + @dataclass class DBConfig: driver: str = MISSING @@ -106,7 +111,7 @@ cs.store(name="base_config", node=Config) cs.store(group="db", name="base_mysql", node=MySQLConfig) cs.store(group="db", name="base_postgresql", node=PostGreSQLConfig) -@hydra.main(config_path="conf", config_name="config") +@hydra.main(version_base=None, config_path="conf", config_name="config") def my_app(cfg: Config) -> None: print(OmegaConf.to_yaml(cfg)) @@ -190,6 +195,11 @@ we want to validate against.
```python title="my_app.py" +from dataclasses import dataclass + +import hydra +from hydra.core.config_store import ConfigStore + import database_lib @@ -221,6 +231,10 @@ if __name__ == "__main__":
```python title="database_lib.py" {17,22} +from dataclasses import dataclass + +from hydra.core.config_store import ConfigStore + @dataclass class DBConfig: ... diff --git a/website/docs/upgrades/1.0_to_1.1/hydra_main_config_path.md b/website/docs/upgrades/1.0_to_1.1/hydra_main_config_path.md index 08801aa5e35..d2d3d0bd6cb 100644 --- a/website/docs/upgrades/1.0_to_1.1/hydra_main_config_path.md +++ b/website/docs/upgrades/1.0_to_1.1/hydra_main_config_path.md @@ -23,7 +23,7 @@ hydra.initialize(config_path="conf") ### No config directory For applications that do not define config files next to the Python script (typically applications using only Structured Configs), it is recommended that you pass `None` as the config_path, indicating that no directory should be added to the config search path. -This will become the default in Hydra 1.2. +This will become the default with [version_base](../version_base.md) >= "1.2" ```python @hydra.main(config_path=None) # or: diff --git a/website/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir.md b/website/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir.md new file mode 100644 index 00000000000..f41573e3ae3 --- /dev/null +++ b/website/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir.md @@ -0,0 +1,16 @@ +--- +id: changes_to_job_working_dir +title: Changes to job's runtime working directory +hide_title: true +--- + +Hydra 1.2 introduces `hydra.job.chdir`. This config allows users to specify whether Hydra should change the runtime working +directory to the job's output directory. +`hydra.job.chdir` will default to `False` if version_base is set to >= "1.2" (or None), +or otherwise will use the old behavior and default to `True`, with a warning being issued if `hydra.job.chdir` is not set. + +If you want to keep the old Hydra behavior, please set `hydra.job.chdir=True` explicitly for your application. + +For more information about `hydra.job.chdir`, +see [Output/Working directory](/tutorials/basic/running_your_app/3_working_directory.md#disable-changing-current-working-dir-to-jobs-output-dir) +and [Job Configuration - hydra.job.chdir](/configure_hydra/job.md#hydrajobchdir). diff --git a/website/docs/upgrades/1.1_to_1.2/changes_to_sweeper_config.md b/website/docs/upgrades/1.1_to_1.2/changes_to_sweeper_config.md new file mode 100644 index 00000000000..99128cae295 --- /dev/null +++ b/website/docs/upgrades/1.1_to_1.2/changes_to_sweeper_config.md @@ -0,0 +1,58 @@ +--- +id: changes_to_sweeper_config +title: Changes to configuring sweeper's search space +hide_title: true +--- + +Hydra 1.2 introduces `hydra.sweeper.params`. All Hydra Sweepers (BasicSweeper and HPOs) search +space will be defined under this config node. + + +### Optuna +For migration, move search space definition from `hydra.sweeper.search_space` to `hydra.sweeper.params`. Change the search space +definition to be consistent with how you'd override a value from commandline. For example: + +
+
+ +```yaml title="Hydra 1.1" +hydra: + sweeper: + search_space: + search_space: + x: + type: float + low: -5.5 + high: 5.5 + step: 0.5 + 'y': + type: categorical + choices: + - -5 + - 0 + - 5 +``` +
+
+ +```bash title="Hydra 1.2" +hydra: + sweeper: + params: + x: range(-5.5, 5.5, step=0.5) + y: choice(-5, 0, 5) + + + + + + + + + + +``` +
+
+ +Check out [Optuna Sweeper](/plugins/optuna_sweeper.md) for more info. \ No newline at end of file diff --git a/website/docs/upgrades/1.1_to_1.2/hydra_main_config_path.md b/website/docs/upgrades/1.1_to_1.2/hydra_main_config_path.md new file mode 100644 index 00000000000..6f61bb8b91c --- /dev/null +++ b/website/docs/upgrades/1.1_to_1.2/hydra_main_config_path.md @@ -0,0 +1,8 @@ +--- +id: changes_to_hydra_main_config_path +title: Changes to @hydra.main() and hydra.initialize() +--- + +Prior to Hydra 1.2, **@hydra.main()** and **hydra.initialize()** default `config path` was the directory containing the Python app (calling **@hydra.main()** or **hydra.initialize()**). +Starting with Hydra 1.1 we give [control over the default config path](../1.0_to_1.1/hydra_main_config_path.md), +and starting with Hydra 1.2, with [version_base](../version_base.md) >= "1.2", we choose a default config_path=None, indicating that no directory should be added to the config search path. diff --git a/website/docs/upgrades/intro.md b/website/docs/upgrades/intro.md index e65a2201d85..139d4a3b473 100644 --- a/website/docs/upgrades/intro.md +++ b/website/docs/upgrades/intro.md @@ -5,7 +5,9 @@ sidebar_label: Introduction --- Upgrading to a new Hydra version is usually an easy process. - +Also since Hydra version 1.2, backwards compatibility is improved +by giving the user more control over appropriate defaults +through the use of the [version_base parameter](version_base.md). :::info NOTE Hydra versioning has only major versions and patch versions. A bump of the first two version digits is considered a major release. diff --git a/website/docs/upgrades/version_base.md b/website/docs/upgrades/version_base.md new file mode 100644 index 00000000000..246e9889c7a --- /dev/null +++ b/website/docs/upgrades/version_base.md @@ -0,0 +1,19 @@ +--- +id: version_base +title: version_base +--- + +Hydra since version 1.2 supports backwards compatible upgrades by default +through the use of the `version_base` parameter to **@hydra.main()** and **hydra.initialize()**. + +There are three classes of values that the `version_base` parameter supports, +given new and existing users greater control of the default behaviors to use. + +1. If the `version_base` parameter is **not specified**, Hydra 1.x will use defaults compatible with version 1.1. +Also in this case, a warning is issued to indicate an explicit `version_base` is preferred. + +2. If the `version_base` parameter is **None**, then the defaults are chosen for the current minor Hydra version. +For example for Hydra 1.2, then would imply `config_path=None` and `hydra.job.chdir=False`. + +3. If the `version_base` parameter is an **explicit version string** like "1.1", +then the defaults appropriate to that version are used. diff --git a/website/docusaurus.config.js b/website/docusaurus.config.js index bb96b5f20be..b72a2143aeb 100755 --- a/website/docusaurus.config.js +++ b/website/docusaurus.config.js @@ -10,18 +10,30 @@ module.exports = { tagline: 'A framework for elegantly configuring complex applications', url: 'https://hydra.cc', baseUrl: '/', + onBrokenLinks: 'throw', + onBrokenMarkdownLinks: 'warn', + trailingSlash: true, favicon: 'img/Hydra-head.svg', organizationName: 'facebookresearch', // Usually your GitHub org/user name. projectName: 'hydra', // Usually your repo name. customFields: { githubLinkVersionToBaseUrl: { - // TODO: Update once a branch is cut for 1.1 - "1.1": "https://github.com/facebookresearch/hydra/blob/main/", + // TODO: Update once a branch is cut for 1.2 + "1.2": "https://github.com/facebookresearch/hydra/blob/main/", + "1.1": "https://github.com/facebookresearch/hydra/blob/1.1_branch/", "1.0": "https://github.com/facebookresearch/hydra/blob/1.0_branch/", current: "https://github.com/facebookresearch/hydra/blob/main/", }, }, themeConfig: { + announcementBar: { + id: 'support_ukraine', + content: + 'Support Ukraine 🇺🇦
Help Provide Humanitarian Aid to Ukraine.', + backgroundColor: '#20232a', + textColor: '#fff', + isCloseable: false, + }, googleAnalytics: { trackingID: 'UA-149862507-1', }, @@ -30,11 +42,11 @@ module.exports = { indexName: 'hydra', algoliaOptions: {}, }, - announcementBar: { - id: 'supportus', - content: - '⭐️ If you like Hydra, give it a star on GitHub! ⭐️', - }, + // announcementBar: { + // id: 'supportus', + // content: + // '⭐️ If you like Hydra, give it a star on GitHub! ⭐️', + // }, prism: { additionalLanguages: ['antlr4'], }, @@ -106,7 +118,7 @@ module.exports = { alt: 'Facebook Open Source Logo', src: 'https://docusaurus.io/img/oss_logo.png', }, - copyright: `Copyright © ${new Date().getFullYear()} Facebook, Inc.`, + copyright: `Copyright © ${new Date().getFullYear()} Meta Platforms, Inc`, }, }, presets: [ diff --git a/website/package.json b/website/package.json old mode 100755 new mode 100644 index 0b61bbd54fd..137f07bdc7e --- a/website/package.json +++ b/website/package.json @@ -10,12 +10,12 @@ "deploy": "docusaurus deploy" }, "dependencies": { - "@docusaurus/core": "^2.0.0-beta.6", - "@docusaurus/preset-classic": "^2.0.0-beta.6", + "@docusaurus/core": "^2.0.0-beta.14", + "@docusaurus/preset-classic": "^2.0.0-beta.14", "classnames": "^2.2.6", - "docusaurus-plugin-internaldocs-fb": "^0.8.5", + "docusaurus-plugin-internaldocs-fb": "0.10.4", "is-svg": "4.3.1", - "node-fetch": "^2.6.1", + "node-fetch": "^2.6.7", "prism-react-renderer": "1.2.1", "react": "^17.0.2", "react-dom": "^17.0.2" diff --git a/website/sidebars.js b/website/sidebars.js index 51e8b2f2a50..44251cb06f5 100755 --- a/website/sidebars.js +++ b/website/sidebars.js @@ -47,12 +47,12 @@ module.exports = { label: 'Structured Configs Tutorial', items: [ 'tutorials/structured_config/intro', + 'tutorials/structured_config/config_store', 'tutorials/structured_config/minimal_example', 'tutorials/structured_config/hierarchical_static_config', 'tutorials/structured_config/config_groups', 'tutorials/structured_config/defaults', 'tutorials/structured_config/schema', - 'tutorials/structured_config/config_store', ], }, ], @@ -133,6 +133,7 @@ module.exports = { "Experimental": [ "experimental/intro", "experimental/callbacks", + "experimental/rerun", ], 'Developer Guide': [ @@ -145,6 +146,16 @@ module.exports = { 'Upgrade Guide': [ 'upgrades/intro', + 'upgrades/version_base', + { + type: 'category', + label: '1.1 to 1.2', + items: [ + 'upgrades/1.1_to_1.2/changes_to_hydra_main_config_path', + 'upgrades/1.1_to_1.2/changes_to_job_working_dir', + 'upgrades/1.1_to_1.2/changes_to_sweeper_config', + ], + }, { type: 'category', label: '1.0 to 1.1', @@ -167,6 +178,7 @@ module.exports = { 'upgrades/0.11_to_1.0/object_instantiation_changes', ], }, + ], 'FB Only': FBInternalOnly([ diff --git a/website/src/components/GithubLink.jsx b/website/src/components/GithubLink.jsx index fc5b76647a8..fddacd2cc9e 100644 --- a/website/src/components/GithubLink.jsx +++ b/website/src/components/GithubLink.jsx @@ -30,12 +30,12 @@ export default function GithubLink(props) { } export function ExampleGithubLink(props) { - const text = props.text ?? "Example" + const text = props.text ?? "Example (Click Here)" return (  Example ); diff --git a/website/src/css/custom.css b/website/src/css/custom.css index aecbf6d4656..234a4490a25 100755 --- a/website/src/css/custom.css +++ b/website/src/css/custom.css @@ -52,3 +52,39 @@ html[data-theme='dark'] .docusaurus-highlight-code-line { background-color: /* Color which works with dark mode syntax highlighting theme */ } +/* Announcement banner */ + +:root { + --docusaurus-announcement-bar-height: auto !important; +} + +div[class^="announcementBarContent"] { + line-height: 40px; + font-size: 20px; + font-weight: bold; + padding: 8px 30px; +} + +div[class^="announcementBarContent"] a { + text-decoration: underline; + display: inline-block; + color: var(--ifm-color-primary-lightest) !important; +} + +div[class^="announcementBarContent"] a:hover { + color: var(--brand) !important; +} + +@media only screen and (max-width: 768px) { + .announcement { + font-size: 18px; + } +} + +@media only screen and (max-width: 500px) { + .announcement { + font-size: 15px; + line-height: 22px; + padding: 6px 30px; + } +} diff --git a/website/src/pages/index.js b/website/src/pages/index.js index 4486c473a45..31bda2852be 100755 --- a/website/src/pages/index.js +++ b/website/src/pages/index.js @@ -46,6 +46,39 @@ const features = [ }, ]; +function VideoContainer() { + return ( +
+
+
+

Hydra overview; 1 minute video

+