Skip to content

Commit

Permalink
Extend list function (facebookresearch#2950)
Browse files Browse the repository at this point in the history
* Add list_extend function to override functions

* Add tests for list_extend

* Add documentation for extend_list
  • Loading branch information
jesszzzz authored Oct 15, 2024
1 parent 253e72a commit 4d9d6e7
Show file tree
Hide file tree
Showing 11 changed files with 210 additions and 14 deletions.
9 changes: 9 additions & 0 deletions hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,15 @@ def _apply_overrides_to_config(overrides: List[Override], cfg: DictConfig) -> No
)
elif override.is_force_add():
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
elif override.is_list_extend():
config_val = OmegaConf.select(cfg, key, throw_on_missing=True)
if not OmegaConf.is_list(config_val):
raise ConfigCompositionException(
"Could not append to config list. The existing value of"
f" '{override.key_or_group}' is {config_val} which is not"
f" a list."
)
config_val.extend(value)
else:
try:
OmegaConf.update(cfg, key, value, merge=True)
Expand Down
8 changes: 8 additions & 0 deletions hydra/_internal/grammar/grammar_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ChoiceSweep,
Glob,
IntervalSweep,
ListExtensionOverrideValue,
ParsedElementType,
QuotedString,
RangeSweep,
Expand Down Expand Up @@ -399,3 +400,10 @@ def glob(
exclude = [exclude]

return Glob(include=include, exclude=exclude)


def extend_list(*args: Any) -> ListExtensionOverrideValue:
"""
Extends an existing list in the config with the given values.
"""
return ListExtensionOverrideValue(values=list(args))
1 change: 1 addition & 0 deletions hydra/core/override_parser/overrides_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,5 @@ def create_functions() -> Functions:
functions.register(name="sort", func=grammar_functions.sort)
functions.register(name="shuffle", func=grammar_functions.shuffle)
functions.register(name="glob", func=grammar_functions.glob)
functions.register(name="extend_list", func=grammar_functions.extend_list)
return functions
8 changes: 8 additions & 0 deletions hydra/core/override_parser/overrides_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Glob,
IntervalSweep,
Key,
ListExtensionOverrideValue,
Override,
OverrideType,
ParsedElementType,
Expand Down Expand Up @@ -190,6 +191,13 @@ def visitOverride(self, ctx: OverrideParser.OverrideContext) -> Override:
value_type = ValueType.RANGE_SWEEP
else:
value_type = ValueType.ELEMENT
if isinstance(value, ListExtensionOverrideValue):
if not override_type == OverrideType.CHANGE:
raise HydraException(
"Trying to use override symbols when extending a list"
)
override_type = OverrideType.EXTEND_LIST
value = value.values

return Override(
type=override_type,
Expand Down
12 changes: 12 additions & 0 deletions hydra/core/override_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ class OverrideType(Enum):
ADD = 2
FORCE_ADD = 3
DEL = 4
EXTEND_LIST = 5


class ValueType(Enum):
Expand Down Expand Up @@ -227,6 +228,11 @@ def match(s: str, globs: List[str]) -> bool:
return res


@dataclass
class ListExtensionOverrideValue:
values: List["ParsedElementType"]


class Transformer:
@staticmethod
def identity(x: ParsedElementType) -> ParsedElementType:
Expand Down Expand Up @@ -286,6 +292,12 @@ def is_force_add(self) -> bool:
"""
return self.type == OverrideType.FORCE_ADD

def is_list_extend(self) -> bool:
"""
:return: True if this override represents appending to a list config value
"""
return self.type == OverrideType.EXTEND_LIST

@staticmethod
def _convert_value(value: ParsedElementType) -> Optional[ElementType]:
if isinstance(value, list):
Expand Down
1 change: 1 addition & 0 deletions news/1547.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add extend_list function to override syntax
49 changes: 48 additions & 1 deletion tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from hydra.core.config_search_path import SearchPathQuery
from hydra.core.config_store import ConfigStore
from hydra.core.global_hydra import GlobalHydra
from hydra.errors import ConfigCompositionException, HydraException
from hydra.errors import (
ConfigCompositionException,
HydraException,
OverrideParseException,
)
from hydra.test_utils.test_utils import chdir_hydra_root

chdir_hydra_root()
Expand Down Expand Up @@ -429,6 +433,49 @@ class Config:
compose(config_name="config", overrides=overrides)


@mark.usefixtures("initialize_hydra_no_path")
@mark.parametrize(
("overrides", "expected"),
[
param(
["list_key=extend_list(d, e)"],
{"list_key": ["a", "b", "c", "d", "e"]},
id="extend_list_with_str",
),
param(
["list_key=extend_list([d1, d2])"],
{"list_key": ["a", "b", "c", ["d1", "d2"]]},
id="extend_list_with_list",
),
param(
["list_key=extend_list(d, [e1])", "list_key=extend_list(f)"],
{"list_key": ["a", "b", "c", "d", ["e1"], "f"]},
id="extend_list_twice",
),
param(
["+list_key=extend_list([d1, d2])"],
raises(OverrideParseException),
id="extend_list_with_append_key",
),
],
)
def test_extending_list(
hydra_restore_singletons: Any, overrides: List[str], expected: Any
) -> None:
@dataclass
class Config:
list_key: Any = field(default_factory=lambda: ["a", "b", "c"])

ConfigStore.instance().store(name="config", node=Config)

if isinstance(expected, dict):
cfg = compose(config_name="config", overrides=overrides)
assert cfg == expected
else:
with expected:
compose(config_name="config", overrides=overrides)


@mark.parametrize("override", ["hydra.foo=bar", "hydra.job_logging.foo=bar"])
def test_hydra_node_validated(initialize_hydra_no_path: Any, override: str) -> None:
with raises(ConfigCompositionException):
Expand Down
11 changes: 11 additions & 0 deletions tests/test_hydra_cli_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@
ValueError while evaluating 'choice()': empty choice is not legal""",
id="empty choice",
),
param(
"+key=extend_list(1, 2, 3)",
"""Error parsing override '+key=extend_list(1, 2, 3)'
Trying to use override symbols when extending a list""",
id="plus key extend_list",
),
param(
"key={inner_key=extend_list(1, 2, 3)}",
"no viable alternative at input '{inner_key='",
id="embedded extend_list",
),
param(
["+key=choice(choice(a,b))", "-m"],
"""Error parsing override '+key=choice(choice(a,b))'
Expand Down
57 changes: 57 additions & 0 deletions tests/test_overrides_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Glob,
IntervalSweep,
Key,
ListExtensionOverrideValue,
Override,
OverrideType,
Quote,
Expand Down Expand Up @@ -174,6 +175,21 @@ def test_element(value: str, expected: Any) -> None:
ChoiceSweep(list=[1, 2, 3], shuffle=True),
id="shuffle(choice(1,2,3))",
),
param(
"extend_list(1,2,three)",
ListExtensionOverrideValue(values=[1, 2, "three"]),
id="extend_list(1,2,three)",
),
param(
"extend_list('5')",
ListExtensionOverrideValue(values=["5"]),
id="extend_list('5')",
),
param(
"extend_list([1,2,3], {a:1, b:2})",
ListExtensionOverrideValue(values=[[1, 2, 3], {"a": 1, "b": 2}]),
id="extend_list([1,2,3], {a:1, b:2})",
),
],
)
def test_value(value: str, expected: Any) -> None:
Expand Down Expand Up @@ -523,6 +539,15 @@ def test_interval_sweep(value: str, expected: Any) -> None:
raises(HydraException, match=re.escape("mismatched input '/'")),
id="error:dollar_in_group",
),
param(
"override",
"+key=extend_list(foobar)",
raises(
HydraException,
match=re.escape("Trying to use override symbols when extending a list"),
),
id="error:plus_in_extend_list_key",
),
],
)
def test_parse_errors(rule: str, value: str, expected: Any) -> None:
Expand Down Expand Up @@ -997,6 +1022,38 @@ def test_override(
assert ret == expected


@mark.parametrize(
"value,expected_key,expected_value",
[
param(
"key=extend_list([1,2])",
"key",
[[1, 2]],
id="extend_list_of_list",
),
param(
"key=extend_list(1,2,3)",
"key",
[1, 2, 3],
id="extend_list_with_multiple_vals",
),
],
)
def test_list_extend_override(
value: str,
expected_key: str,
expected_value: Any,
) -> None:
test_override(
"",
value,
OverrideType.EXTEND_LIST,
expected_key,
expected_value,
ValueType.ELEMENT,
)


def test_deprecated_name_package(hydra_restore_singletons: Any) -> None:
msg = (
"In override key@_name_=value: _name_ keyword is deprecated in packages, "
Expand Down
2 changes: 2 additions & 0 deletions website/docs/advanced/override_grammar/basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ foo=[1,2,3]
nested=[a,[b,[c]]]
```

Lists are assigned, not merged. To extend an existing list, use the [`extend_list` function](extended.md#extending-lists).

### Dictionaries
```python
foo={a:10,b:20}
Expand Down
Loading

0 comments on commit 4d9d6e7

Please sign in to comment.