Skip to content

Commit

Permalink
Allow dot in parameter/metric names (#3195)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3195

SymPy variables must be valid Python identifiers, but some legacy Ax users may include dots in their parameter names. Until we can guarantee this is no longer happening, we need this sanitization.

Reviewed By: lena-kashtelyan

Differential Revision: D67414859

fbshipit-source-id: 9ff915c98c61ba04cf5fd74c2871c9878005a729
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Dec 19, 2024
1 parent 8bea777 commit 8e3b456
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
41 changes: 33 additions & 8 deletions ax/preview/api/utils/instantiation/from_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# pyre-strict

import re
from typing import Sequence

from ax.core.metric import Metric
Expand All @@ -28,6 +29,8 @@
from sympy.core.symbol import Symbol
from sympy.core.sympify import sympify

DOT_PLACEHOLDER = "__dot__"


def optimization_config_from_string(
objective_str: str, outcome_constraint_strs: Sequence[str] | None = None
Expand Down Expand Up @@ -108,7 +111,7 @@ def parse_parameter_constraint(constraint_str: str) -> ParameterConstraint:
bound = 0
for term, coefficient in coefficient_dict.items():
if term.is_symbol:
constraint_dict[term.name] = coefficient
constraint_dict[_unsanitize_dot(term.name)] = coefficient
elif term.is_number:
# Invert because we are "moving" the bound to the right hand side
bound = -1 * coefficient
Expand All @@ -129,7 +132,7 @@ def parse_objective(objective_str: str) -> Objective:
linear objectives.
"""
# Parse the objective string into a SymPy expression
expression = sympify(objective_str)
expression = sympify(_sanitize_dot(objective_str))

if isinstance(expression, tuple): # Multi-objective
return MultiObjective(
Expand Down Expand Up @@ -181,15 +184,15 @@ def parse_outcome_constraint(constraint_str: str) -> OutcomeConstraint:
term, coefficient = next(iter(constraint_dict.items()))

return OutcomeConstraint(
metric=Metric(name=term),
metric=Metric(name=_unsanitize_dot(term)),
op=ComparisonOp.LEQ if coefficient > 0 else ComparisonOp.GEQ,
bound=bound / coefficient,
relative=is_relative,
)

names, coefficients = zip(*constraint_dict.items())
return ScalarizedOutcomeConstraint(
metrics=[Metric(name=name) for name in names],
metrics=[Metric(name=_unsanitize_dot(name)) for name in names],
op=ComparisonOp.LEQ,
weights=[*coefficients],
bound=bound,
Expand All @@ -206,7 +209,9 @@ def _create_single_objective(expression: Expr) -> Objective:

# If the expression is just a Symbol, it represents a single-metric objective
if isinstance(expression, Symbol):
return Objective(metric=Metric(name=str(expression.name)), minimize=False)
return Objective(
metric=Metric(name=_unsanitize_dot(str(expression.name))), minimize=False
)

# If the expression is a Mul it likely represents a single metric objective but
# some additional validation is required
Expand All @@ -221,13 +226,15 @@ def _create_single_objective(expression: Expr) -> Objective:
# the sign from the coefficient rather than its value
minimize = bool(expression.as_coefficient(symbol) < 0)

return Objective(metric=Metric(name=str(symbol)), minimize=minimize)
return Objective(
metric=Metric(name=_unsanitize_dot(str(symbol))), minimize=minimize
)

# If the expression is an Add it represents a scalarized objective
elif isinstance(expression, Add):
names, coefficients = zip(*expression.as_coefficients_dict().items())
return ScalarizedObjective(
metrics=[Metric(name=str(name)) for name in names],
metrics=[Metric(name=_unsanitize_dot(str(name))) for name in names],
weights=[float(coefficient) for coefficient in coefficients],
minimize=False,
)
Expand All @@ -245,7 +252,7 @@ def _extract_coefficient_dict_from_inequality(
constraints.
"""
# Parse the constraint string into a SymPy inequality
inequality = sympify(inequality_str)
inequality = sympify(_sanitize_dot(inequality_str))

# Check the SymPy object is a valid inequality
if not isinstance(inequality, GreaterThan | LessThan):
Expand All @@ -261,3 +268,21 @@ def _extract_coefficient_dict_from_inequality(
return {
key: float(value) for key, value in expression.as_coefficients_dict().items()
}


def _sanitize_dot(s: str) -> str:
"""
Converts a string with normal dots to a string with sanitized dots. This is
temporarily necessary because SymPy symbol names must be valid Python identifiers,
but some legacy Ax users may include dots in their parameter names.
"""
return re.sub(r"([a-zA-Z])\.([a-zA-Z])", r"\1__dot__\2", s)


def _unsanitize_dot(s: str) -> str:
"""
Converts a string with sanitized dots back to a string with normal dots. This is
temporarily necessary because SymPy symbol names must be valid Python identifiers,
but some legacy Ax users may include dots in their parameter names.
"""
return re.sub(r"__dot__", ".", s)
12 changes: 12 additions & 0 deletions ax/preview/api/utils/instantiation/tests/test_from_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ax.core.parameter_constraint import ParameterConstraint
from ax.exceptions.core import UserInputError
from ax.preview.api.utils.instantiation.from_string import (
_sanitize_dot,
optimization_config_from_string,
parse_objective,
parse_outcome_constraint,
Expand Down Expand Up @@ -211,3 +212,14 @@ def test_parse_outcome_constraint(self) -> None:

with self.assertRaisesRegex(UserInputError, "Only linear"):
parse_outcome_constraint(constraint_str="flops * flops <= 1000000")

def test_sanitize_dot(self) -> None:
    # Direct check of the helper: each dot between letters becomes "__dot__".
    sanitized = _sanitize_dot("foo.bar.baz")
    self.assertEqual(sanitized, "foo__dot__bar__dot__baz")

    # Parsing a constraint with dotted names must yield the dotted names as keys
    # of the resulting constraint.
    actual = parse_parameter_constraint(constraint_str="foo.bar + foo.baz <= 1")
    expected = ParameterConstraint(
        constraint_dict={"foo.bar": 1, "foo.baz": 1}, bound=1.0
    )
    self.assertEqual(actual, expected)

0 comments on commit 8e3b456

Please sign in to comment.