From 9ba941c479b18edb32d279aba9e7a328975080cb Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 1 May 2024 18:51:58 -0700 Subject: [PATCH] Update GenerationNode.model_to_gen_from_name (#2407) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2407 Updates the behavior to match the docstring. Previously, this would call `model_spec_to_gen_from`, which could lead to fitting multiple models if the node had multiple model specs. The return type was always `str` rather than `Optional[str]` suggested by the type hints. Reviewed By: mgarrard Differential Revision: D56735619 fbshipit-source-id: a6b39b3bc921f720bcb43daece55dc9da1ee96c6 --- ax/modelbridge/generation_node.py | 5 ++++- ax/modelbridge/tests/test_generation_node.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 289884908df..3f5b493249e 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -143,7 +143,10 @@ def model_to_gen_from_name(self) -> Optional[str]: """Returns the name of the model that will be used for gen, if there is one. Otherwise, returns None. """ - return self.model_spec_to_gen_from.model_key + if self._model_spec_to_gen_from is not None: + return self._model_spec_to_gen_from.model_key + else: + return None @property def generation_strategy(self) -> modelbridge.generation_strategy.GenerationStrategy: diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index 829ee72e7ea..115d38be506 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -113,6 +113,7 @@ def test_properties(self) -> None: ), ], ) + self.assertIsNone(node.model_to_gen_from_name) dat = self.branin_experiment.lookup_data() node.fit( experiment=self.branin_experiment, @@ -124,6 +125,7 @@ def test_properties(self) -> None: self.assertEqual( node.model_spec_to_gen_from.model_kwargs, node.model_specs[0].model_kwargs ) + self.assertEqual(node.model_to_gen_from_name, "GPEI") self.assertEqual( node.model_spec_to_gen_from.model_gen_kwargs, node.model_specs[0].model_gen_kwargs,