Skip to content

Commit

Permalink
mypy: make quixstreams.core.* pass type checks
Browse files Browse the repository at this point in the history
  • Loading branch information
quentin-quix committed Dec 17, 2024
1 parent b3321ef commit 7a2c526
Show file tree
Hide file tree
Showing 12 changed files with 338 additions and 96 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ ignore_errors = true

[[tool.mypy.overrides]]
module = [
"quixstreams.core.*",
"quixstreams.dataframe.*",
"quixstreams.dataframe.series.*",
"quixstreams.dataframe.windows.*",
"quixstreams.rowproducer.*"
]
ignore_errors = true
9 changes: 7 additions & 2 deletions quixstreams/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def alter_context(value):
_current_message_context.set(context)


def message_context() -> Optional[MessageContext]:
def message_context() -> MessageContext:
"""
Get a MessageContext for the current message, which houses most of the message
metadata, like:
Expand All @@ -75,6 +75,11 @@ def message_context() -> Optional[MessageContext]:
:return: instance of `MessageContext`
"""
try:
return _current_message_context.get()
ctx = _current_message_context.get()
except LookupError:
raise MessageContextNotSetError("Message context is not set")

if ctx is None:
raise MessageContextNotSetError("Message context is not set")

return ctx
61 changes: 51 additions & 10 deletions quixstreams/core/stream/functions/apply.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import Any
from typing import Any, Literal, Union, overload

from .base import StreamFunction
from .types import ApplyCallback, ApplyWithMetadataCallback, VoidExecutor
from .types import (
ApplyCallback,
ApplyExpandedCallback,
ApplyWithMetadataCallback,
ApplyWithMetadataExpandedCallback,
VoidExecutor,
)

__all__ = ("ApplyFunction", "ApplyWithMetadataFunction")

Expand All @@ -14,22 +20,34 @@ class ApplyFunction(StreamFunction):
and its result will always be passed downstream.
"""

@overload
def __init__(self, func: ApplyCallback, expand: Literal[False] = False) -> None: ...

@overload
def __init__(self, func: ApplyExpandedCallback, expand: Literal[True]) -> None: ...

def __init__(
self,
func: ApplyCallback,
func: Union[ApplyCallback, ApplyExpandedCallback],
expand: bool = False,
):
super().__init__(func)

self.func: Union[ApplyCallback, ApplyExpandedCallback]
self.expand = expand

def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
child_executor = self._resolve_branching(*child_executors)
func = self.func

if self.expand:

def wrapper(
value: Any, key: Any, timestamp: int, headers: Any, func=self.func
):
value: Any,
key: Any,
timestamp: int,
headers: Any,
) -> None:
# Execute a function on a single value and wrap results into a list
# to expand them downstream
result = func(value)
Expand All @@ -39,8 +57,11 @@ def wrapper(
else:

def wrapper(
value: Any, key: Any, timestamp: int, headers: Any, func=self.func
):
value: Any,
key: Any,
timestamp: int,
headers: Any,
) -> None:
# Execute a function on a single value and return its result
result = func(value)
child_executor(result, key, timestamp, headers)
Expand All @@ -57,20 +78,37 @@ class ApplyWithMetadataFunction(StreamFunction):
and its result will always be passed downstream.
"""

@overload
def __init__(
self, func: ApplyWithMetadataCallback, expand: Literal[False] = False
) -> None: ...

@overload
def __init__(
self, func: ApplyWithMetadataExpandedCallback, expand: Literal[True]
) -> None: ...

def __init__(
self,
func: ApplyWithMetadataCallback,
func: Union[ApplyWithMetadataCallback, ApplyWithMetadataExpandedCallback],
expand: bool = False,
):
super().__init__(func)

self.func: Union[ApplyWithMetadataCallback, ApplyWithMetadataExpandedCallback]
self.expand = expand

def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
child_executor = self._resolve_branching(*child_executors)
func = self.func

if self.expand:

def wrapper(
value: Any, key: Any, timestamp: int, headers: Any, func=self.func
value: Any,
key: Any,
timestamp: int,
headers: Any,
):
# Execute a function on a single value and wrap results into a list
# to expand them downstream
Expand All @@ -81,7 +119,10 @@ def wrapper(
else:

def wrapper(
value: Any, key: Any, timestamp: int, headers: Any, func=self.func
value: Any,
key: Any,
timestamp: int,
headers: Any,
):
# Execute a function on a single value and return its result
result = func(value, key, timestamp, headers)
Expand Down
3 changes: 3 additions & 0 deletions quixstreams/core/stream/functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def _resolve_branching(self, *child_executors: VoidExecutor) -> VoidExecutor:
If there's only one executor - copying is not neccessary, and the executor
is returned as is.
"""
if not child_executors:
raise TypeError("At least one executor is required")

if len(child_executors) > 1:

def wrapper(
Expand Down
22 changes: 18 additions & 4 deletions quixstreams/core/stream/functions/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,18 @@ class FilterFunction(StreamFunction):

def __init__(self, func: FilterCallback):
super().__init__(func)
self.func: FilterCallback

def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
child_executor = self._resolve_branching(*child_executors)

def wrapper(value: Any, key: Any, timestamp: int, headers: Any, func=self.func):
func = self.func

def wrapper(
value: Any,
key: Any,
timestamp: int,
headers: Any,
):
# Filter a single value
if func(value):
child_executor(value, key, timestamp, headers)
Expand All @@ -42,11 +49,18 @@ class FilterWithMetadataFunction(StreamFunction):

def __init__(self, func: FilterWithMetadataCallback):
super().__init__(func)
self.func: FilterWithMetadataCallback

def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
child_executor = self._resolve_branching(*child_executors)

def wrapper(value: Any, key: Any, timestamp: int, headers: Any, func=self.func):
func = self.func

def wrapper(
value: Any,
key: Any,
timestamp: int,
headers: Any,
):
# Filter a single value
if func(value, key, timestamp, headers):
child_executor(value, key, timestamp, headers)
Expand Down
20 changes: 16 additions & 4 deletions quixstreams/core/stream/functions/transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Union
from typing import Any, Literal, Union, cast, overload

from .base import StreamFunction
from .types import TransformCallback, TransformExpandedCallback, VoidExecutor
Expand All @@ -21,38 +21,50 @@ class TransformFunction(StreamFunction):
The result of the callback will always be passed downstream.
"""

@overload
def __init__(
self, func: TransformCallback, expand: Literal[False] = False
) -> None: ...

@overload
def __init__(
self, func: TransformExpandedCallback, expand: Literal[True]
) -> None: ...

def __init__(
self,
func: Union[TransformCallback, TransformExpandedCallback],
expand: bool = False,
):
super().__init__(func)

self.func: Union[TransformCallback, TransformExpandedCallback]
self.expand = expand

def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
child_executor = self._resolve_branching(*child_executors)

if self.expand:
expanded_func = cast(TransformExpandedCallback, self.func)

def wrapper(
value: Any,
key: Any,
timestamp: int,
headers: Any,
func: TransformExpandedCallback = self.func,
):
result = func(value, key, timestamp, headers)
result = expanded_func(value, key, timestamp, headers)
for new_value, new_key, new_timestamp, new_headers in result:
child_executor(new_value, new_key, new_timestamp, new_headers)

else:
func = cast(TransformCallback, self.func)

def wrapper(
value: Any,
key: Any,
timestamp: int,
headers: Any,
func: TransformCallback = self.func,
):
# Execute a function on a single value and return its result
new_value, new_key, new_timestamp, new_headers = func(
Expand Down
10 changes: 8 additions & 2 deletions quixstreams/core/stream/functions/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ class UpdateFunction(StreamFunction):
def __init__(self, func: UpdateCallback):
super().__init__(func)

self.func: UpdateCallback

def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
child_executor = self._resolve_branching(*child_executors)
func = self.func

def wrapper(value: Any, key: Any, timestamp: int, headers: Any, func=self.func):
def wrapper(value: Any, key: Any, timestamp: int, headers: Any):
# Update a single value and forward it
func(value)
child_executor(value, key, timestamp, headers)
Expand All @@ -45,10 +48,13 @@ class UpdateWithMetadataFunction(StreamFunction):
def __init__(self, func: UpdateWithMetadataCallback):
super().__init__(func)

self.func: UpdateWithMetadataCallback

def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
child_executor = self._resolve_branching(*child_executors)
func = self.func

def wrapper(value: Any, key: Any, timestamp: int, headers: Any, func=self.func):
def wrapper(value: Any, key: Any, timestamp: int, headers: Any):
# Update a single value and forward it
func(value, key, timestamp, headers)
child_executor(value, key, timestamp, headers)
Expand Down
Loading

0 comments on commit 7a2c526

Please sign in to comment.