diff --git a/src/lmnr_flow/context.py b/src/lmnr_flow/context.py index 9b6e170..e111211 100644 --- a/src/lmnr_flow/context.py +++ b/src/lmnr_flow/context.py @@ -10,8 +10,10 @@ class Context: def __init__(self): self.states = {} # str -> State - def get(self, key: str) -> Any: + def get(self, key: str, default: Any = None) -> Any: if key not in self.states: + if default is not None: + return default raise Exception(f"Key {key} not found in context") state = self.states[key] diff --git a/tests/test_flow.py b/tests/test_flow.py index bece04b..cd3d827 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -206,6 +206,28 @@ def task2(ctx): assert result == {"task2": "result2"} +def test_context_get_with_fallback(flow): + # Tests context handling of existing and missing keys (with/without default). + def existing_key_task(ctx): + ctx.set("existing_key", "existing_value") + return TaskOutput("output_with_key") + + flow.add_task("existing_key_task", existing_key_task) + + flow.run("existing_key_task") + + # Test retrieving an existing key. + existing_value = flow.context.get("existing_key") + assert existing_value == "existing_value" + + with pytest.raises(Exception) as exc_info: + flow.context.get("non_existent_key") + assert "Key non_existent_key not found in context" in str(exc_info.value) + + fallback_value = flow.context.get("non_existent_key_with_default", default="fallback_value") + assert fallback_value == "fallback_value" + + def test_invalid_task_reference(flow): # Test referencing non-existent task def task1(ctx):