Skip to content

Commit

Permalink
2.x generic alias handling in validate_parameters (#16117)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Nov 26, 2024
1 parent fb919c6 commit a6bdfdf
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
31 changes: 24 additions & 7 deletions src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import importlib.util
import inspect
import os
import sys
import tempfile
import warnings
from functools import partial, update_wrapper
Expand Down Expand Up @@ -137,6 +138,16 @@
if TYPE_CHECKING:
from prefect.deployments.runner import FlexibleScheduleList, RunnerDeployment

# Handle Python 3.8 compatibility for GenericAlias
if sys.version_info >= (3, 9):
from types import GenericAlias # novermin

GENERIC_ALIAS = (GenericAlias,)
else:
from typing import _GenericAlias

GENERIC_ALIAS = (_GenericAlias,)


@PrefectObjectRegistry.register_instances
class Flow(Generic[P, R]):
Expand Down Expand Up @@ -530,18 +541,22 @@ def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
is_v1_type(param.annotation) for param in sig.parameters.values()
)
has_v1_models = any(
issubclass(param.annotation, V1BaseModel)
if isinstance(param.annotation, type)
else False
(
isinstance(param.annotation, type)
and not isinstance(param.annotation, GENERIC_ALIAS)
and issubclass(param.annotation, V1BaseModel)
)
for param in sig.parameters.values()
)
has_v2_types = any(
is_v2_type(param.annotation) for param in sig.parameters.values()
)
has_v2_models = any(
issubclass(param.annotation, V2BaseModel)
if isinstance(param.annotation, type)
else False
(
isinstance(param.annotation, type)
and not isinstance(param.annotation, GENERIC_ALIAS)
and issubclass(param.annotation, V2BaseModel)
)
for param in sig.parameters.values()
)

Expand Down Expand Up @@ -1601,7 +1616,9 @@ def flow(


def select_flow(
flows: Iterable[Flow], flow_name: str = None, from_message: str = None
flows: Iterable[Flow],
flow_name: Optional[str] = None,
from_message: Optional[str] = None,
) -> Flow:
"""
Select the only flow in an iterable or a flow specified by name.
Expand Down
17 changes: 17 additions & 0 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,23 @@ def my_flow(secret: SecretStr):
"secret": SecretStr("my secret")
}

@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Python 3.9+ required for GenericAlias"
)
def test_flow_signature_can_contain_generic_type_hints(self):
"""Test that generic type hints like dict[str, str] work correctly
this is a regression test for https://github.com/PrefectHQ/prefect/issues/16105
"""

@flow
def my_flow(param: dict[str, str]): # novermin
return param

test_data = {"foo": "bar"}
assert my_flow(test_data) == test_data
assert my_flow.validate_parameters({"param": test_data}) == {"param": test_data}


class TestSubflowTaskInputs:
async def test_subflow_with_one_upstream_task_future(self, prefect_client):
Expand Down

0 comments on commit a6bdfdf

Please sign in to comment.