Skip to content

Commit

Permalink
fix TimeAsFeature transform for MOO (#3178)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3178

see title. For MOO, fixed features can be empty in L931 in ax/modelbridge/torch.py. This led to an issue where `start_time` was not in the parameters dict, causing issues when untransforming. This fixes the issue

Reviewed By: Balandat

Differential Revision: D67216689

fbshipit-source-id: 4f7985d9cbd9e20c23c4c94a47185eb0bdc72c7b
  • Loading branch information
sdaulton authored and facebook-github-bot committed Dec 16, 2024
1 parent 1b9a8e3 commit c9303da
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def test_TransformObservationFeatures(self) -> None:
obsf_trans[0],
ObservationFeatures({"x": 2.5, "duration": 0.5, "start_time": 5.0}),
)
# test untransforming observation features that do not have
# start/end time (important for fixed features in MOO when un-
# transforming objective thresholds)
obsf_trans = [ObservationFeatures({"x": 2.5})]
obsf_untrans = self.t.untransform_observation_features(obsf_trans)
self.assertEqual(obsf_untrans, obsf_trans)

def test_TransformSearchSpace(self) -> None:
ss2 = deepcopy(self.search_space)
Expand Down
21 changes: 11 additions & 10 deletions ax/modelbridge/transforms/time_as_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from ax.utils.common.timeutils import unixtime_to_pandas_ts
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws
from pyre_extensions import assert_is_instance, none_throws

if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
Expand Down Expand Up @@ -139,12 +138,14 @@ def untransform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
start_time = checked_cast(float, obsf.parameters.pop("start_time"))
obsf.start_time = unixtime_to_pandas_ts(start_time)
obsf.end_time = unixtime_to_pandas_ts(
checked_cast(float, obsf.parameters.pop("duration"))
* self.duration_range
+ self.min_duration
+ start_time
)
start_time = obsf.parameters.pop("start_time", None)
duration = obsf.parameters.pop("duration", None)
if start_time is not None:
start_time = assert_is_instance(start_time, float)
obsf.start_time = unixtime_to_pandas_ts(start_time)
if duration is not None:
duration = assert_is_instance(duration, float)
obsf.end_time = unixtime_to_pandas_ts(
duration * self.duration_range + self.min_duration + start_time
)
return observation_features

0 comments on commit c9303da

Please sign in to comment.