Skip to content

Commit

Permalink
Formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Nov 17, 2023
1 parent 97f7dcd commit 6e5cd52
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 52 deletions.
67 changes: 17 additions & 50 deletions morpheus/pipeline/stage_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,11 @@ class WrappedFunctionSourceStage(SingleOutputSource):
Additional keyword arguments to bind to `gen_fn` via `functools.partial`.
"""

def __init__(self,
config: Config,
gen_fn: GeneratorType,
*gen_args,
return_type: type = None,
**gen_fn_kwargs):
def __init__(self, config: Config, gen_fn: GeneratorType, *gen_args, return_type: type = None, **gen_fn_kwargs):
super().__init__(config)
# collections.abc.Generator is a subclass of collections.abc.Iterator
if not inspect.isgeneratorfunction(gen_fn):
raise ValueError(
"Wrapped source functions must be generator functions")
raise ValueError("Wrapped source functions must be generator functions")

self._gen_fn = functools.partial(gen_fn, *gen_args, **gen_fn_kwargs)
self._gen_fn_name = _get_name_from_fn(gen_fn)
Expand All @@ -133,8 +127,7 @@ def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject:
return builder.make_source(self.unique_name, self._gen_fn)


class PreAllocatedWrappedFunctionStage(PreallocatorMixin,
WrappedFunctionSourceStage):
class PreAllocatedWrappedFunctionStage(PreallocatorMixin, WrappedFunctionSourceStage):
"""
Source stage that wraps a generator function as the method for generating messages.
Expand All @@ -159,21 +152,10 @@ class PreAllocatedWrappedFunctionStage(PreallocatorMixin,
Additional keyword arguments to bind to `gen_fn` via `functools.partial`.
"""

def __init__(self,
config: Config,
gen_fn: GeneratorType,
*gen_args,
return_type: type = None,
**gen_fn_kwargs):
super().__init__(*gen_args,
config=config,
gen_fn=gen_fn,
return_type=return_type,
**gen_fn_kwargs)
def __init__(self, config: Config, gen_fn: GeneratorType, *gen_args, return_type: type = None, **gen_fn_kwargs):
super().__init__(*gen_args, config=config, gen_fn=gen_fn, return_type=return_type, **gen_fn_kwargs)
if not _is_dataframe_containing_type(self._return_type):
raise ValueError(
"PreAllocatedWrappedFunctionStage can only be used with DataFrame containing types"
)
raise ValueError("PreAllocatedWrappedFunctionStage can only be used with DataFrame containing types")


def source(gen_fn: GeneratorType):
Expand Down Expand Up @@ -201,9 +183,7 @@ def source(gen_fn: GeneratorType):

# Use wraps to ensure user's don't lose their function name and docstrinsgs, however we do want to override the
# annotations to reflect that the returned function requires a config and returns a stage
@functools.wraps(gen_fn,
assigned=('__module__', '__name__', '__qualname__',
'__doc__'))
@functools.wraps(gen_fn, assigned=('__module__', '__name__', '__qualname__', '__doc__'))
def wrapper(config: Config, *args, **kwargs) -> WrappedFunctionSourceStage:
return_type = _determine_return_type(gen_fn)

Expand All @@ -215,11 +195,7 @@ def wrapper(config: Config, *args, **kwargs) -> WrappedFunctionSourceStage:
return_type=return_type,
**kwargs)

return WrappedFunctionSourceStage(*args,
config=config,
gen_fn=gen_fn,
return_type=return_type,
**kwargs)
return WrappedFunctionSourceStage(*args, config=config, gen_fn=gen_fn, return_type=return_type, **kwargs)

return wrapper

Expand Down Expand Up @@ -263,8 +239,7 @@ def __init__(self,
return_type: type = None,
**on_data_kwargs):
super().__init__(config)
self._on_data_fn = functools.partial(on_data_fn, *on_data_args,
**on_data_kwargs)
self._on_data_fn = functools.partial(on_data_fn, *on_data_args, **on_data_kwargs)
self._on_data_fn_name = _get_name_from_fn(on_data_fn)

# Even if both accept_type and return_type are provided, we should still need to inspect the function signature
Expand All @@ -277,18 +252,16 @@ def __init__(self,
if self._accept_type is signature.empty:
logger.warning(
"%s argument of %s has no type annotation, defaulting to typing.Any for the stage accept type",
first_param.name, self._on_data_fn_name)
first_param.name,
self._on_data_fn_name)
self._accept_type = typing.Any
except StopIteration as e:
raise ValueError(
f"Wrapped stage functions {self._on_data_fn_name} must have at least one parameter"
) from e
raise ValueError(f"Wrapped stage functions {self._on_data_fn_name} must have at least one parameter") from e

self._return_type = return_type or signature.return_annotation
if self._return_type is signature.empty:
logger.warning(
"Return type of %s has no type annotation, defaulting to the stage's accept type",
self._on_data_fn_name)
logger.warning("Return type of %s has no type annotation, defaulting to the stage's accept type",
self._on_data_fn_name)
self._return_type = self._accept_type

@property
Expand All @@ -309,8 +282,7 @@ def compute_schema(self, schema: StageSchema):

schema.output_schema.set_type(return_type)

def _build_single(self, builder: mrc.Builder,
input_node: mrc.SegmentObject) -> mrc.SegmentObject:
def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
node = builder.make_node(self.unique_name, ops.map(self._on_data_fn))
builder.make_edge(input_node, node)

Expand Down Expand Up @@ -347,13 +319,8 @@ def stage(on_data_fn: typing.Callable):

# Use wraps to ensure user's don't lose their function name and docstrinsgs, however we do want to override the
# annotations to reflect that the returned function requires a config and returns a stage
@functools.wraps(on_data_fn,
assigned=('__module__', '__name__', '__qualname__',
'__doc__'))
@functools.wraps(on_data_fn, assigned=('__module__', '__name__', '__qualname__', '__doc__'))
def wrapper(config: Config, *args, **kwargs) -> WrappedFunctionStage:
return WrappedFunctionStage(*args,
config=config,
on_data_fn=on_data_fn,
**kwargs)
return WrappedFunctionStage(*args, config=config, on_data_fn=on_data_fn, **kwargs)

return wrapper
4 changes: 2 additions & 2 deletions tests/pipeline/test_stage_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
from morpheus.messages import MessageMeta
from morpheus.messages import MultiMessage
from morpheus.pipeline import LinearPipeline
from morpheus.pipeline.stage_decorator import source
from morpheus.pipeline.stage_decorator import stage
from morpheus.pipeline.stage_decorator import PreAllocatedWrappedFunctionStage
from morpheus.pipeline.stage_decorator import WrappedFunctionSourceStage
from morpheus.pipeline.stage_decorator import WrappedFunctionStage
from morpheus.pipeline.stage_decorator import source
from morpheus.pipeline.stage_decorator import stage
from morpheus.pipeline.stage_schema import StageSchema
from morpheus.stages.output.compare_dataframe_stage import CompareDataFrameStage

Expand Down

0 comments on commit 6e5cd52

Please sign in to comment.