diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 115570c07..65e7f9b32 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -13,6 +13,7 @@ Literal, Optional, Tuple, + TypeVar, Union, cast, overload, @@ -1324,31 +1325,12 @@ def _drop(value: Dict, columns: List[str], ignore_missing: bool = False): raise -@overload -def _as_metadata_func( - func: ApplyCallbackStateful, -) -> ApplyWithMetadataCallbackStateful: ... - - -@overload -def _as_metadata_func( - func: FilterCallbackStateful, -) -> FilterWithMetadataCallbackStateful: ... - - -@overload -def _as_metadata_func( - func: UpdateCallbackStateful, -) -> UpdateWithMetadataCallbackStateful: ... +T = TypeVar("T") def _as_metadata_func( - func: Union[ApplyCallbackStateful, FilterCallbackStateful, UpdateCallbackStateful], -) -> Union[ - ApplyWithMetadataCallbackStateful, - FilterWithMetadataCallbackStateful, - UpdateWithMetadataCallbackStateful, -]: + func: Callable[[Any, State], T], +) -> Callable[[Any, Any, int, Any, State], T]: @functools.wraps(func) def wrapper( value: Any, _key: Any, _timestamp: int, _headers: Any, state: State @@ -1359,9 +1341,9 @@ def wrapper( def _as_stateful( - func: ApplyWithMetadataCallbackStateful, + func: Callable[[Any, Any, int, Any, State], T], processing_context: ProcessingContext, -) -> ApplyWithMetadataCallback: +) -> Callable[[Any, Any, int, Any], T]: @functools.wraps(func) def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any: ctx = message_context()