Skip to content

Commit

Permalink
Fix issue where StepInfo config could be parsed into a Step (#344)
Browse files Browse the repository at this point in the history
* fix issue where StepInfo config could be parsed into a Step

* Update tests/step_info_test.py

Co-authored-by: Pete <[email protected]>

* Update tests/step_info_test.py

Co-authored-by: Pete <[email protected]>

* Update tests/step_info_test.py

Co-authored-by: Pete <[email protected]>

* fix tests

Co-authored-by: Pete <[email protected]>
  • Loading branch information
lgatys and epwalsh authored Jul 26, 2022
1 parent 57096b2 commit 2498318
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed issue where the StepInfo config argument could be parsed into a Step.
- Restored capability to run tests out-of-tree.

## [v0.10.0](https://github.com/allenai/tango/releases/tag/v0.10.0) - 2022-07-07
Expand Down
14 changes: 9 additions & 5 deletions tango/common/from_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,11 +387,15 @@ def construct_arg(
args = getattr(annotation, "__args__", [])

# Try to guess if `popped_params` might be a step, come from a step, or contain a step.
could_be_step = try_from_step and (
origin == Step
or isinstance(popped_params, Step)
or _params_contain_step(popped_params)
or (isinstance(popped_params, (dict, Params)) and popped_params.get("type") == "ref")
could_be_step = (
try_from_step
and (
origin == Step
or isinstance(popped_params, Step)
or _params_contain_step(popped_params)
or (isinstance(popped_params, (dict, Params)) and popped_params.get("type") == "ref")
)
and not (class_name == "StepInfo" and argument_name == "config")
)
if could_be_step:
# If we think it might be a step, we try parsing as a step _first_.
Expand Down
35 changes: 35 additions & 0 deletions tests/step_info_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
from pathlib import Path
from typing import Any

from tango.common.testing.steps import FloatStep
from tango.step import Step
from tango.step_graph import StepGraph
from tango.step_info import StepInfo


Expand All @@ -23,3 +26,35 @@ def test_step_info():
serialized = json.dumps(step_info.to_json_dict())
deserialized = StepInfo.from_json_dict(json.loads(serialized))
assert deserialized == step_info


def test_step_info_with_step_dependency():
"""Checks that the StepInfo config is not parsed to a Step if it has dependencies on upstream steps"""

@Step.register("foo", exist_ok=True)
class FooStep(Step):
def run(self, bar: Any) -> str: # type: ignore
return "foo" + bar

@Step.register("bar", exist_ok=True)
class BarStep(Step):
def run(self) -> str: # type: ignore
return "Hey!"

graph = StepGraph.from_params(
{
"foo": {
"type": "foo",
"bar": {"type": "ref", "ref": "bar"},
},
"bar": {
"type": "bar",
},
}
)
step = graph["foo"]
step_info = StepInfo.new_from_step(step)

step_info_json = json.dumps(step_info.to_json_dict())
step_info = StepInfo.from_json_dict(json.loads(step_info_json))
assert isinstance(step_info.config, dict)
2 changes: 1 addition & 1 deletion tests/step_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Bar(FromParams):
def __init__(self, x: int):
self.x = x

@Step.register("foo")
@Step.register("foo", exist_ok=True)
class FooStep(Step):
def run(self, bar: Bar) -> Bar: # type: ignore
return bar
Expand Down

0 comments on commit 2498318

Please sign in to comment.