Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved type-hints for stage and source decorators #1831

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions morpheus/pipeline/stage_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
from morpheus.messages import MultiMessage

logger = logging.getLogger(__name__)
GeneratorType = typing.Callable[..., collections.abc.Iterator[typing.Any]]

_InputT = typing.TypeVar('_InputT')
_OutputT = typing.TypeVar('_OutputT')
_P = typing.ParamSpec('_P')

GeneratorType = typing.Callable[_P, collections.abc.Iterator[_OutputT]]
ComputeSchemaType = typing.Callable[[_pipeline.StageSchema], None]


Expand Down Expand Up @@ -134,7 +139,12 @@ class PreAllocatedWrappedFunctionStage(_pipeline.PreallocatorMixin, WrappedFunct
"""


def source(gen_fn: GeneratorType = None, *, name: str = None, compute_schema_fn: ComputeSchemaType = None):
def source(
gen_fn: GeneratorType = None,
*,
name: str = None,
compute_schema_fn: ComputeSchemaType = None
) -> typing.Callable[typing.Concatenate[Config, _P], WrappedFunctionSourceStage]:
"""
Decorator for wrapping a function as a source stage. The function must be a generator method, and provide a
provide a return type annotation.
Expand Down Expand Up @@ -162,7 +172,7 @@ def source(gen_fn: GeneratorType = None, *, name: str = None, compute_schema_fn:
# Use wraps to ensure user's don't lose their function name and docstrinsgs, however we do want to override the
# annotations to reflect that the returned function requires a config and returns a stage
@functools.wraps(gen_fn, assigned=('__module__', '__name__', '__qualname__', '__doc__'))
def wrapper(config: Config, **kwargs) -> WrappedFunctionSourceStage:
def wrapper(config: Config, **kwargs: _P.kwargs) -> WrappedFunctionSourceStage:
nonlocal name
nonlocal compute_schema_fn

Expand Down Expand Up @@ -271,12 +281,15 @@ def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) ->
return node


def stage(on_data_fn: typing.Callable = None,
DecoratedStageType = typing.Callable[typing.Concatenate[Config, _P], WrappedFunctionStage]


def stage(on_data_fn: typing.Callable[typing.Concatenate[_InputT, _P], _OutputT] = None,
*,
name: str = None,
accept_type: type = None,
compute_schema_fn: ComputeSchemaType = None,
needed_columns: dict[str, TypeId] = None):
needed_columns: dict[str, TypeId] = None) -> DecoratedStageType:
"""
Decorator for wrapping a function as a stage. The function must receive at least one argument, the first argument
must be the incoming message, and must return a value.
Expand Down Expand Up @@ -317,7 +330,7 @@ def stage(on_data_fn: typing.Callable = None,
# Use wraps to ensure user's don't lose their function name and docstrinsgs, however we do want to override the
# annotations to reflect that the returned function requires a config and returns a stage
@functools.wraps(on_data_fn, assigned=('__module__', '__name__', '__qualname__', '__doc__'))
def wrapper(config: Config, **kwargs) -> WrappedFunctionStage:
def wrapper(config: Config, **kwargs: _P.kwargs) -> WrappedFunctionStage:
nonlocal name
nonlocal accept_type
nonlocal compute_schema_fn
Expand Down
Loading