From 6e5cd5238afbee02ba24101aef27272b30cd4f8a Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 17 Nov 2023 10:27:52 -0800 Subject: [PATCH] Formatting fixes --- morpheus/pipeline/stage_decorator.py | 67 +++++++------------------- tests/pipeline/test_stage_decorator.py | 4 +- 2 files changed, 19 insertions(+), 52 deletions(-) diff --git a/morpheus/pipeline/stage_decorator.py b/morpheus/pipeline/stage_decorator.py index 64c4db83e1..c69c8d9bab 100644 --- a/morpheus/pipeline/stage_decorator.py +++ b/morpheus/pipeline/stage_decorator.py @@ -103,17 +103,11 @@ class WrappedFunctionSourceStage(SingleOutputSource): Additional keyword arguments to bind to `gen_fn` via `functools.partial`. """ - def __init__(self, - config: Config, - gen_fn: GeneratorType, - *gen_args, - return_type: type = None, - **gen_fn_kwargs): + def __init__(self, config: Config, gen_fn: GeneratorType, *gen_args, return_type: type = None, **gen_fn_kwargs): super().__init__(config) # collections.abc.Generator is a subclass of collections.abc.Iterator if not inspect.isgeneratorfunction(gen_fn): - raise ValueError( - "Wrapped source functions must be generator functions") + raise ValueError("Wrapped source functions must be generator functions") self._gen_fn = functools.partial(gen_fn, *gen_args, **gen_fn_kwargs) self._gen_fn_name = _get_name_from_fn(gen_fn) @@ -133,8 +127,7 @@ def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject: return builder.make_source(self.unique_name, self._gen_fn) -class PreAllocatedWrappedFunctionStage(PreallocatorMixin, - WrappedFunctionSourceStage): +class PreAllocatedWrappedFunctionStage(PreallocatorMixin, WrappedFunctionSourceStage): """ Source stage that wraps a generator function as the method for generating messages. @@ -159,21 +152,10 @@ class PreAllocatedWrappedFunctionStage(PreallocatorMixin, Additional keyword arguments to bind to `gen_fn` via `functools.partial`. """ - def __init__(self, - config: Config, - gen_fn: GeneratorType, - *gen_args, - return_type: type = None, - **gen_fn_kwargs): - super().__init__(*gen_args, - config=config, - gen_fn=gen_fn, - return_type=return_type, - **gen_fn_kwargs) + def __init__(self, config: Config, gen_fn: GeneratorType, *gen_args, return_type: type = None, **gen_fn_kwargs): + super().__init__(*gen_args, config=config, gen_fn=gen_fn, return_type=return_type, **gen_fn_kwargs) if not _is_dataframe_containing_type(self._return_type): - raise ValueError( - "PreAllocatedWrappedFunctionStage can only be used with DataFrame containing types" - ) + raise ValueError("PreAllocatedWrappedFunctionStage can only be used with DataFrame containing types") def source(gen_fn: GeneratorType): @@ -201,9 +183,7 @@ def source(gen_fn: GeneratorType): # Use wraps to ensure user's don't lose their function name and docstrinsgs, however we do want to override the # annotations to reflect that the returned function requires a config and returns a stage - @functools.wraps(gen_fn, - assigned=('__module__', '__name__', '__qualname__', - '__doc__')) + @functools.wraps(gen_fn, assigned=('__module__', '__name__', '__qualname__', '__doc__')) def wrapper(config: Config, *args, **kwargs) -> WrappedFunctionSourceStage: return_type = _determine_return_type(gen_fn) @@ -215,11 +195,7 @@ def wrapper(config: Config, *args, **kwargs) -> WrappedFunctionSourceStage: return_type=return_type, **kwargs) - return WrappedFunctionSourceStage(*args, - config=config, - gen_fn=gen_fn, - return_type=return_type, - **kwargs) + return WrappedFunctionSourceStage(*args, config=config, gen_fn=gen_fn, return_type=return_type, **kwargs) return wrapper @@ -263,8 +239,7 @@ def __init__(self, return_type: type = None, **on_data_kwargs): super().__init__(config) - self._on_data_fn = functools.partial(on_data_fn, *on_data_args, - **on_data_kwargs) + self._on_data_fn = functools.partial(on_data_fn, *on_data_args, **on_data_kwargs) self._on_data_fn_name = _get_name_from_fn(on_data_fn) # Even if both accept_type and return_type are provided, we should still need to inspect the function signature @@ -277,18 +252,16 @@ def __init__(self, if self._accept_type is signature.empty: logger.warning( "%s argument of %s has no type annotation, defaulting to typing.Any for the stage accept type", - first_param.name, self._on_data_fn_name) + first_param.name, + self._on_data_fn_name) self._accept_type = typing.Any except StopIteration as e: - raise ValueError( - f"Wrapped stage functions {self._on_data_fn_name} must have at least one parameter" - ) from e + raise ValueError(f"Wrapped stage functions {self._on_data_fn_name} must have at least one parameter") from e self._return_type = return_type or signature.return_annotation if self._return_type is signature.empty: - logger.warning( - "Return type of %s has no type annotation, defaulting to the stage's accept type", - self._on_data_fn_name) + logger.warning("Return type of %s has no type annotation, defaulting to the stage's accept type", + self._on_data_fn_name) self._return_type = self._accept_type @property @@ -309,8 +282,7 @@ def compute_schema(self, schema: StageSchema): schema.output_schema.set_type(return_type) - def _build_single(self, builder: mrc.Builder, - input_node: mrc.SegmentObject) -> mrc.SegmentObject: + def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject: node = builder.make_node(self.unique_name, ops.map(self._on_data_fn)) builder.make_edge(input_node, node) @@ -347,13 +319,8 @@ def stage(on_data_fn: typing.Callable): # Use wraps to ensure user's don't lose their function name and docstrinsgs, however we do want to override the # annotations to reflect that the returned function requires a config and returns a stage - @functools.wraps(on_data_fn, - assigned=('__module__', '__name__', '__qualname__', - '__doc__')) + @functools.wraps(on_data_fn, assigned=('__module__', '__name__', '__qualname__', '__doc__')) def wrapper(config: Config, *args, **kwargs) -> WrappedFunctionStage: - return WrappedFunctionStage(*args, - config=config, - on_data_fn=on_data_fn, - **kwargs) + return WrappedFunctionStage(*args, config=config, on_data_fn=on_data_fn, **kwargs) return wrapper diff --git a/tests/pipeline/test_stage_decorator.py b/tests/pipeline/test_stage_decorator.py index 37eaaec231..2284a3c553 100644 --- a/tests/pipeline/test_stage_decorator.py +++ b/tests/pipeline/test_stage_decorator.py @@ -29,11 +29,11 @@ from morpheus.messages import MessageMeta from morpheus.messages import MultiMessage from morpheus.pipeline import LinearPipeline -from morpheus.pipeline.stage_decorator import source -from morpheus.pipeline.stage_decorator import stage from morpheus.pipeline.stage_decorator import PreAllocatedWrappedFunctionStage from morpheus.pipeline.stage_decorator import WrappedFunctionSourceStage from morpheus.pipeline.stage_decorator import WrappedFunctionStage +from morpheus.pipeline.stage_decorator import source +from morpheus.pipeline.stage_decorator import stage from morpheus.pipeline.stage_schema import StageSchema from morpheus.stages.output.compare_dataframe_stage import CompareDataFrameStage