From 266776d79868e138b7dcdf70f3dda410d7f7643e Mon Sep 17 00:00:00 2001 From: Daniil Gusev <133032822+daniil-quix@users.noreply.github.com> Date: Wed, 15 Nov 2023 10:37:41 +0100 Subject: [PATCH] Allow StreamingDataFrame.apply() to be assigned to keys and used as a filter (#238) --- README.md | 37 +- src/StreamingDataFrames/docs/serialization.md | 28 +- .../docs/stateful-processing.md | 35 +- .../docs/streamingdataframe.md | 313 +++++--- src/StreamingDataFrames/examples/README.md | 6 +- .../examples/bank_example/README.md | 4 +- .../bank_example/json_version/consumer.py | 29 +- .../quix_platform_version/consumer.py | 33 +- .../quix_platform_version/producer.py | 5 +- .../quixstreams/__init__.py | 1 + src/StreamingDataFrames/quixstreams/app.py | 17 +- .../quixstreams/context.py | 49 ++ .../quixstreams/core/__init__.py | 0 .../quixstreams/core/stream/__init__.py | 2 + .../quixstreams/core/stream/functions.py | 125 +++ .../quixstreams/core/stream/stream.py | 239 ++++++ .../quixstreams/dataframe/__init__.py | 2 +- .../quixstreams/dataframe/base.py | 20 + .../quixstreams/dataframe/column.py | 140 ---- .../quixstreams/dataframe/dataframe.py | 419 +++++----- .../quixstreams/dataframe/exceptions.py | 7 - .../quixstreams/dataframe/pipeline.py | 82 -- .../quixstreams/dataframe/series.py | 234 ++++++ .../quixstreams/models/__init__.py | 1 + .../models/{context.py => messagecontext.py} | 0 .../quixstreams/models/rows.py | 2 +- .../quixstreams/models/topics.py | 2 +- .../tests/test_quixstreams/test_app.py | 96 +-- .../tests/test_quixstreams/test_context.py | 52 ++ .../test_quixstreams/test_core/__init__.py | 0 .../test_core/test_functions.py | 42 + .../test_quixstreams/test_core/test_stream.py | 107 +++ .../test_dataframe/fixtures.py | 41 +- .../test_dataframe/test_column.py | 328 -------- .../test_dataframe/test_dataframe.py | 724 +++++++++++------- .../test_dataframe/test_pipeline.py | 32 - .../test_dataframe/test_series.py | 313 ++++++++ 37 files changed, 2210 insertions(+), 1357 deletions(-) create mode 100644 src/StreamingDataFrames/quixstreams/context.py create mode 100644 src/StreamingDataFrames/quixstreams/core/__init__.py create mode 100644 src/StreamingDataFrames/quixstreams/core/stream/__init__.py create mode 100644 src/StreamingDataFrames/quixstreams/core/stream/functions.py create mode 100644 src/StreamingDataFrames/quixstreams/core/stream/stream.py create mode 100644 src/StreamingDataFrames/quixstreams/dataframe/base.py delete mode 100644 src/StreamingDataFrames/quixstreams/dataframe/column.py delete mode 100644 src/StreamingDataFrames/quixstreams/dataframe/exceptions.py delete mode 100644 src/StreamingDataFrames/quixstreams/dataframe/pipeline.py create mode 100644 src/StreamingDataFrames/quixstreams/dataframe/series.py rename src/StreamingDataFrames/quixstreams/models/{context.py => messagecontext.py} (100%) create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_context.py create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_core/__init__.py create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_core/test_functions.py create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_core/test_stream.py delete mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_column.py delete mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_pipeline.py create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_series.py diff --git a/README.md b/README.md index 
cb55e845d..5afc20166 100644 --- a/README.md +++ b/README.md @@ -63,12 +63,12 @@ See [requirements.txt](./src/StreamingDataFrames/requirements.txt) for the full Here's an example of how to process data from a Kafka Topic with Quix Streams: ```python -from quixstreams import Application, MessageContext, State +from quixstreams import Application, State # Define an application app = Application( - broker_address="localhost:9092", # Kafka broker address - consumer_group="consumer-group-name", # Kafka consumer group + broker_address="localhost:9092", # Kafka broker address + consumer_group="consumer-group-name", # Kafka consumer group ) # Define the input and output topics. By default, "json" serialization will be used @@ -76,13 +76,7 @@ input_topic = app.topic("my_input_topic") output_topic = app.topic("my_output_topic") -def add_one(data: dict, ctx: MessageContext): - for field, value in data.items(): - if isinstance(value, int): - data[field] += 1 - - -def count(data: dict, ctx: MessageContext, state: State): +def count(data: dict, state: State): # Get a value from state for the current Kafka message key total = state.get('total', default=0) total += 1 @@ -91,27 +85,34 @@ def count(data: dict, ctx: MessageContext, state: State): # Update your message data with a value from the state data['total'] = total + # Create a StreamingDataFrame instance # StreamingDataFrame is a primary interface to define the message processing pipeline sdf = app.dataframe(topic=input_topic) # Print the incoming messages -sdf = sdf.apply(lambda value, ctx: print('Received a message:', value)) +sdf = sdf.update(lambda value: print('Received a message:', value)) # Select fields from incoming messages -sdf = sdf[["field_0", "field_2", "field_8"]] +sdf = sdf[["field_1", "field_2", "field_3"]] -# Filter only messages with "field_0" > 10 and "field_2" != "test" -sdf = sdf[(sdf["field_0"] > 10) & (sdf["field_2"] != "test")] +# Filter only messages with "field_1" > 10 and "field_2" != "test" +sdf = sdf[(sdf["field_1"] > 10) & (sdf["field_2"] != "test")] + +# Filter messages using custom functions +sdf = sdf[sdf.apply(lambda value: 0 < (value['field_1'] + value['field_3']) < 1000)] + +# Generate a new value based on the current one +sdf = sdf.apply(lambda value: {**value, 'new_field': 'new_value'}) -# Apply custom function to transform the message -sdf = sdf.apply(add_one) +# Set a new key "field_4" based on the entire message content +sdf['field_4'] = sdf.apply(lambda value: value['field_1'] + value['field_3']) -# Apply a stateful function to persist data to the state store -sdf = sdf.apply(count, stateful=True) +# Use a stateful function to persist data to the state store and update the value in place +sdf = sdf.update(count, stateful=True) # Print the result before producing it -sdf = sdf.apply(lambda value, ctx: print('Producing a message:', value)) +sdf = sdf.update(lambda value: print('Producing a message:', value)) # Produce the result to the output topic sdf = sdf.to_topic(output_topic) diff --git a/src/StreamingDataFrames/docs/serialization.md b/src/StreamingDataFrames/docs/serialization.md index 63d8825a5..2883c6285 100644 --- a/src/StreamingDataFrames/docs/serialization.md +++ b/src/StreamingDataFrames/docs/serialization.md @@ -21,30 +21,31 @@ By default, message values are serialized with `JSON`, message keys are seriali - `quix_events` & `quix_timeseries` - for serializers only.
## Using SerDes -To set a serializer, you may either pass a string shorthand for it, or an instance of `streamingdataframes.models.serializers.Serializer` and `streamingdataframes.models.serializers.Deserializer` directly +To set a serializer, you may either pass a string shorthand for it, or an instance of `quixstreams.models.serializers.Serializer` and `quixstreams.models.serializers.Deserializer` directly to the `Application.topic()`. Example with format shorthands: ```python -from streamingdataframes.models.serializers import JSONDeserializer -app = Application(...) +from quixstreams import Application +app = Application(broker_address='localhost:9092', consumer_group='consumer') # Deserializing message values from JSON to objects and message keys as strings -input_topic = app.topic(value_deserializer='json', key_deserializer='string') +input_topic = app.topic('input', value_deserializer='json', key_deserializer='string') # Serializing message values to JSON and message keys to bytes -output_topic = app.topic(value_serializer='json', key_deserializer='bytes') +output_topic = app.topic('output', value_serializer='json', key_serializer='bytes') ``` Passing `Serializer` and `Deserializer` instances directly: ```python -from streamingdataframes.models.serializers import JSONDeserializer, JSONSerializer -app = Application(...) -input_topic = app.topic(value_deserializer=JSONDeserializer()) -output_topic = app.topic(value_deserializer=JSONSerializer()) +from quixstreams import Application +from quixstreams.models.serializers import JSONDeserializer, JSONSerializer +app = Application(broker_address='localhost:9092', consumer_group='consumer') +input_topic = app.topic('input', value_deserializer=JSONDeserializer()) +output_topic = app.topic('output', value_serializer=JSONSerializer()) ``` -You can find all available serializers in `streamingdataframes.models.serializers` module. +You can find all available serializers in the `quixstreams.models.serializers` module. We also plan on including other popular ones like Avro and Protobuf in the near future. @@ -56,8 +57,9 @@ The Deserializer object will wrap the received value to the dictionary with `col Example: ```python -from streamingdataframes.models.serializers import IntegerDeserializer -app = Application(...) -input_topic = app.topic(value_deserializer=IntegerDeserializer(column_name='number')) +from quixstreams import Application +from quixstreams.models.serializers import IntegerDeserializer +app = Application(broker_address='localhost:9092', consumer_group='consumer') +input_topic = app.topic('input', value_deserializer=IntegerDeserializer(column_name='number')) # Will deserialize message with value "123" to "{'number': 123}" ...
``` diff --git a/src/StreamingDataFrames/docs/stateful-processing.md b/src/StreamingDataFrames/docs/stateful-processing.md index 9138d5cc6..9ffcff882 100644 --- a/src/StreamingDataFrames/docs/stateful-processing.md +++ b/src/StreamingDataFrames/docs/stateful-processing.md @@ -45,28 +45,31 @@ When another consumer reads the message with `KEY_B`, it will not be able to rea ## Using State -The state is available in functions passed to `StreamingDataFrame.apply()` with parameter `stateful=True`: +The state is available in functions passed to `StreamingDataFrame.apply()`, `StreamingDataFrame.update()` and `StreamingDataFrame.filter()` with parameter `stateful=True`: ```python -from quixstreams import Application, MessageContext, State -app = Application() +from quixstreams import Application, State +app = Application( + broker_address='localhost:9092', + consumer_group='consumer', +) topic = app.topic('topic') sdf = app.dataframe(topic) -def count_messages(value: dict, ctx: MessageContext, state: State): +def count_messages(value: dict, state: State): total = state.get('total', default=0) total += 1 state.set('total', total) - value['total'] = total + return {**value, 'total': total} -# Apply a custom function and inform StreamingDataFrame to provide a State instance to it -# by passing "stateful=True" -sdf.apply(count_messages, stateful=True) + +# Apply a custom function and inform StreamingDataFrame to provide a State instance to it by passing "stateful=True" +sdf = sdf.apply(count_messages, stateful=True) ``` -Currently, only functions passed to `StreamingDataFrame.apply()` may use State. +Currently, only functions passed to `StreamingDataFrame.apply()`, `StreamingDataFrame.update()` and `StreamingDataFrame.filter()` may use State.
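For example, the same `State` API can back a stateful filter. A minimal illustrative sketch (the "status" field and key names are hypothetical):

```python
from quixstreams import State


def is_changed(value: dict, state: State):
    # Keep the message only if its "status" differs from the previous one
    # seen for the same message key
    last_status = state.get('last_status')
    state.set('last_status', value['status'])
    return value['status'] != last_status


# Drop consecutive duplicates using per-key state
sdf = sdf.filter(is_changed, stateful=True)
```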
@@ -75,11 +78,19 @@ Currently, only functions passed to `StreamingDataFrame.apply()` may use State. By default, an `Application` keeps the state in `state` directory relative to the current working directory. To change it, pass `state_dir="your-path"` to `Application` or `Application.Quix` calls: ```python -Application(state_dir="folder/path/here") +from quixstreams import Application +app = Application( + broker_address='localhost:9092', + consumer_group='consumer', + state_dir="folder/path/here", +) # or -Application.Quix(state_dir="folder/path/here") +app = Application.Quix( + consumer_group='consumer', + state_dir="folder/path/here", +) ``` ## State Guarantees @@ -105,4 +116,4 @@ We plan to add a proper recovery process in the future. #### Shared state directory In the current version, it's assumed that the state directory is shared between consumers (e.g. using Kubernetes PVC) -If consumers live on different nodes and don't have access to the same state directory, they will not be able to pickup state on rebalancing. \ No newline at end of file +If consumers live on different nodes and don't have access to the same state directory, they will not be able to pick up state on rebalancing. diff --git a/src/StreamingDataFrames/docs/streamingdataframe.md b/src/StreamingDataFrames/docs/streamingdataframe.md index 09ac8515b..817a670d9 100644 --- a/src/StreamingDataFrames/docs/streamingdataframe.md +++ b/src/StreamingDataFrames/docs/streamingdataframe.md @@ -1,18 +1,19 @@ # `StreamingDataFrame`: Detailed Overview -`StreamingDataFrame` and `Column` are the primary interface to define the stream processing pipelines. -Changes to instances of `StreamingDataFrame` and `Column` update the processing pipeline, but the actual +`StreamingDataFrame` and `StreamingSeries` are the primary objects to define the stream processing pipelines. + +Changes to instances of `StreamingDataFrame` and `StreamingSeries` update the processing pipeline, but the actual data changes happen only when it's executed via `Application.run()` Example: ```python -from quixstreams import Application, MessageContext, State +from quixstreams import Application, State # Define an application app = Application( - broker_address="localhost:9092", # Kafka broker address - consumer_group="consumer-group-name", # Kafka consumer group + broker_address="localhost:9092", # Kafka broker address + consumer_group="consumer", # Kafka consumer group ) # Define the input and output topics. 
By default, the "json" serialization will be used @@ -20,27 +21,28 @@ input_topic = app.topic("my_input_topic") output_topic = app.topic("my_output_topic") -def add_one(data: dict, ctx: MessageContext): +def add_one(data: dict): for field, value in data.items(): if isinstance(value, int): data[field] += 1 - -def count(data: dict, ctx: MessageContext, state: State): + +def count(data: dict, state: State): # Get a value from state for the current Kafka message key - total = state.get('total', default=0) + total = state.get("total", default=0) total += 1 # Set a value back to the state - state.set('total') - # Update your message data with a value from the state - data['total'] = total + state.set("total", total) + # Return result + return total + # Create a StreamingDataFrame instance # StreamingDataFrame is a primary interface to define the message processing pipeline sdf = app.dataframe(topic=input_topic) # Print the incoming messages -sdf = sdf.apply(lambda value, ctx: print('Received a message:', value)) +sdf = sdf.update(lambda value: print("Received a message:", value)) # Select fields from incoming message sdf = sdf[["field_0", "field_2", "field_8"]] @@ -48,51 +50,57 @@ sdf = sdf[["field_0", "field_2", "field_8"]] # Filter only messages with "field_0" > 10 and "field_2" != "test" sdf = sdf[(sdf["field_0"] > 10) & (sdf["field_2"] != "test")] +# You may also use a custom function to filter data +sdf = sdf.filter(lambda v: v["field_0"] > 10 and v["field_2"] != "test") + # Apply custom function to transform the message sdf = sdf.apply(add_one) -# Apply a stateful function in persist data into the state store -sdf = sdf.apply(count, stateful=True) +# Use a stateful function in persist data into the state store +# and update the message value +sdf["total"] = sdf.apply(count, stateful=True) # Print the result before producing it -sdf = sdf.apply(lambda value, ctx: print('Producing a message:', value)) +sdf = sdf.update(lambda value: print("Producing a message:", value)) # Produce the result to the output topic sdf = sdf.to_topic(output_topic) ``` -## Interacting with `Rows` +## Data Types + +`StreamingDataFrame` is agnostic of data types passed to it during processing. -Under the hood, `StreamingDataFrame` is manipulating Kafka messages via `Row` objects. +All functions passed to `StreamingDataFrame` will receive data in the same format as it's deserialized +by the `Topic` object. -Simplified, a `Row` is effectively a dictionary of the Kafka message -value, with each key equivalent to a dataframe column name. +It can also produce any types back to Kafka as long as the value can be serialized +to bytes by `value_serializer` passed to the output `Topic` object. -`StreamingDataFrame` interacts with `Row` objects via the Pandas-like -interface and user-defined functions passed to `StreamingDataFrame.apply()` +The column access like `dataframe["column"]` is supported only for dictionaries.
-## Accessing Fields/Columns +## Accessing Fields via StreamingSeries In typical Pandas dataframe fashion, you can access a column: ```python -sdf["field_a"] # "my_str" +sdf["field_a"] # returns a StreamingSeries with value from field "field_a" ``` Typically, this is done in combination with other operations. -You can also access nested objects (dicts, lists, etc): +You can also access nested objects (dicts, lists, etc.): ```python -sdf["field_c"][2] # 3 +sdf["field_c"][2] # returns a StreamingSeries with value of "field_c[2]" if "field_c" is a collection ```
-## Performing Operations with Columns +## Performing Operations with StreamingSeries You can do almost any basic operations or comparisons with columns, assuming validity: @@ -103,167 +111,270 @@ sdf["field_a"] / sdf["field_b"] sdf["field_a"] | sdf["field_b"] sdf["field_a"] & sdf["field_b"] sdf["field_a"].isnull() -sdf["field_a"].contains('string') +sdf["field_a"].contains("string") sdf["field_a"] != "woo" ```
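The results of such expressions are themselves `StreamingSeries`, so they can be stored in new fields or used as filters. A short sketch (field names are illustrative):

```python
# Store the result of a comparison as a boolean field
sdf['is_valid'] = (sdf['field_a'] > 100) & (sdf['field_b'] != 'test')

# Combine arithmetic on two fields into a new one
sdf['ratio'] = sdf['field_a'] / sdf['field_b']
```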
-## Assigning New Columns +## Assigning New Fields -You may add new columns from the results of numerous other +You may add new fields from the results of numerous other operations: ```python -sdf["a_new_int_field"] = 5 +# Set dictionary key "a_new_int_field" to 5 +sdf["a_new_int_field"] = 5 + +# Set key "a_new_str_field" to a sum of "field_a" and "field_b" sdf["a_new_str_field"] = sdf["field_a"] + sdf["field_b"] -sdf["another_new_field"] = sdf["a_new_str_field"].apply(lambda value, ctx: value + "another") -``` -See [the `.apply()` section](#user-defined-functions-apply) for more information on how that works. +# Do the same but with a custom function applied to a whole message value +sdf["another_new_field"] = sdf.apply(lambda value: value['field_a'] + value['field_b']) +# Use a custom function on StreamingSeries to update key "another_new_field" +sdf["another_new_field"] = sdf["a_new_str_field"].apply(lambda value: value + "another") +```
-## Selecting only certain Columns +## Selecting Columns -In typical Pandas fashion, you can take a subset of columns: +In typical `pandas` fashion, you can take a subset of columns: ```python -# remove "field_d" +# Select only fields "field_a", "field_b", "field_c" sdf = sdf[["field_a", "field_b", "field_c"]] ```
-## Filtering Rows (messages) - -"Filtering" is a very specific concept and operation with `StreamingDataFrames`. +## Filtering -In practice, it functions similarly to how you might filter rows with Pandas DataFrames -with conditionals. +`StreamingDataFrame` provides a similar `pandas`-like API to filter data. -When a "column" reference is actually another operation, it will be treated -as a "filter". If that result is empty or None, the row is now "filtered". +To filter data you may use: +- Conditional expressions with `StreamingSeries` (if the underlying message value is deserialized as a dictionary) +- Custom functions like `sdf[sdf.apply(lambda v: v['field'] < 0)]` +- Custom functions like `sdf = sdf.filter(lambda v: v['field'] < 0)` -When filtered, ALL downstream functions for that row are now skipped, +When the value is filtered from the stream, ALL downstream functions for that value are skipped, _including Kafka-related operations like producing_. +Example: + ```python -# This would continue onward -sdf = sdf[sdf["field_a"] == "my_str"] +# Filter values using `StreamingSeries` expressions +sdf = sdf[(sdf["field_a"] == 'my_string') | (sdf['field_b'] > 0)] -# This would filter the row, skipping further functions -sdf = sdf[(sdf["field_a"] != "woo") & (sdf["field_c"][0] > 100)] +# Filter values using `StreamingDataFrame.apply()` +sdf = sdf[sdf.apply(lambda value: value > 0)] + +# Filter values using `StreamingDataFrame.filter()` +sdf = sdf.filter(lambda value: value > 0) ``` + +## Using Custom Functions: `.apply()`, `.update()` and `.filter()` + +`StreamingDataFrame` provides a flexible mechanism to transform and filter data using +simple Python functions via the `.apply()`, `.update()` and `.filter()` methods. + +All three methods accept 2 arguments: +- A function to apply. +A stateless function should accept only one argument - the value. +A stateful function should accept two arguments - the value and `State`. + +- A `stateful` flag which can be `True` or `False` (default - `False`).
+By passing `stateful=True`, you inform a `StreamingDataFrame` to pass an extra argument of type `State` to your function +to perform stateful operations. -## User Defined Functions: `.apply()` +Read on for more details about each method. -Should you need more advanced transformations, `.apply()` allows you -to use any python function to operate on your row. -When used on a `StreamingDataFrame`, your function must accept 2 ordered arguments: -- a current Row value (as a dictionary) -- an instance of `MessageContext` that allows you to access other message metadata (key, partition, timestamap, etc). +### `StreamingDataFrame.apply()` +Use `.apply()` when you need to generate a new value based on the input. +
+When using `.apply()`, the result of the function will always be propagated downstream and will become an input for the next functions. +
Although `.apply()` can mutate the input, it's discouraged, and the `.update()` method should be used instead. + +Example: +```python +# Return a new value based on input +sdf = sdf.apply(lambda value: value + 1) +``` + +There are 2 other use cases for `.apply()`: +1. `StreamingDataFrame.apply()` can be used to assign new keys to the value if the value is a dictionary: +```python +# Set a key "field_a" to a sum of "field_b" and "field_c" +sdf['field_a'] = sdf.apply(lambda value: value['field_b'] + value['field_c']) +``` + +2. `StreamingDataFrame.apply()` can be used to filter values. +
In this case, the result of the passed function is interpreted as `bool`: ```python +# Filter values where sum of "field_b" and "field_c" is greater than 0 +sdf = sdf[sdf.apply(lambda value: (value['field_b'] + value['field_c']) > 0)] +``` -Consequently, your function **MUST either** _alter this dict in-place_ -**OR** _return a dictionary_ to directly replace the current data with. +### `StreamingDataFrame.update()` +Use `.update()` when you need to mutate the input value in place or to perform a side effect without generating a new value. +For example, use it to print data to the console or to update a counter in the State. -For example: +The result of a function passed to `.update()` is always ignored, and its input will be propagated downstream instead. +Examples: ```python -# in place example -def in_place(value: dict, ctx: MessageContext): - value['key'] = 'value' - -sdf = sdf.apply(in_place) - -# replacement example -def new_data(value: dict, ctx: MessageContext): - new_value = {'key': value['key']} - return new_value - -sdf = sdf.apply(new_data) +# Mutate a list by appending a new item to it +# The updated list will be passed downstream +sdf = sdf.update(lambda value: value.append(1)) + +# Use .update() to print a value to the console +sdf = sdf.update(lambda value: print("Received value: ", value)) ``` + +### `StreamingDataFrame.filter()` +Use `.filter()` to filter values based on the entire message content.
+The result of a function passed to `.filter()` is interpreted as boolean. +```python +# Filter out values with "field_a" <= 0 +sdf = sdf.filter(lambda value: value['field_a'] > 0) -The `.apply()` function is also valid for columns, but rather than providing a -dictionary, it instead uses the column value, and the function must return a value. +# Filter out values where "field_a" is False +sdf = sdf.filter(lambda value: value['field_a']) +``` +You may also achieve the same result with `sdf[sdf.apply()]` syntax: ```python -sdf["new_field"] = sdf["field_a"].apply(lambda value, ctx: value + "-add_me") +# Filter out values with "field_a" <= 0 using .apply() syntax +sdf = sdf[sdf.apply(lambda value: value['field_a'] > 0)] ``` -NOTE: Every `.apply()` is a _temporary_ state change, but the result can be assigned. -So, in the above example, `field_a` remains `my_str`, but `new_field == my_str-add_me` -as desired. +
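The three methods can be freely chained into one pipeline; a short illustrative sketch (the "amount" field is hypothetical):

```python
# Keep only values with a positive "amount"
sdf = sdf.filter(lambda value: value['amount'] > 0)

# Log each surviving value without changing it
sdf = sdf.update(lambda value: print('Valid amount:', value['amount']))

# Pass a trimmed-down value downstream
sdf = sdf.apply(lambda value: {'amount': value['amount']})
```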
+ +### Using custom functions with StreamingSeries +The `.apply()` function is also valid for `StreamingSeries`. +But instead of receiving an entire message value, it will receive only the value of the particular key: + +```python +# Generate a new value based on "field_b" and assign it back to "field_a" +sdf['field_a'] = sdf['field_b'].apply(lambda field_b: field_b.strip()) +``` + +It follows the same rules as `StreamingDataFrame.apply()`, and the result of the function +will be returned as is. + +`StreamingSeries` supports only the `.apply()` method.
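Assuming `StreamingSeries.apply()` returns a new `StreamingSeries`, such calls can also be chained. A small sketch (field names are illustrative):

```python
# Normalize a string field: strip whitespace first, then upper-case it
sdf['normalized'] = sdf['raw_field'].apply(lambda v: v.strip()).apply(lambda v: v.upper())
```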
-### Stateful Processing with `.apply()` +## Stateful Processing with Custom Functions + +If you want to use persistent state during processing, you can access the state for a given _message key_ via +passing `stateful=True` to `StreamingDataFrame.apply()`, `StreamingDataFrame.update()` or `StreamingDataFrame.filter()`. -If you want to use persistent state during processing, you can access the state for a given row via -a keyword argument `stateful=True`, and your function should accept a third `State` object as -an argument (you can just call it something like `state`). +In this case, your custom function should accept a second argument of type `State`. -When your function has access to state, it will receive a `State` object, which can do: +The `State` object provides a minimal API to work with the persistent state store: - `.get(key, default=None)` - `.set(key, value)` - `.delete(key)` - `.exists(key)` -`Key` and value can be anything, and you can have any number of keys. -NOTE: `key` is unrelated to the Kafka message key, which is handled behind the scenes. +You may treat `State` as a dictionary-like structure.
`Key` and `value` can be of any type as long as they are serializable to JSON (the default serialization format for State). +
You may easily store strings, numbers, lists, tuples and dictionaries. + + + +Under the hood, the `key` is always prefixed by the actual Kafka message key to ensure +that messages with different keys don't have access to the same state. + ```python -from quixstreams import MessageContext, State +from quixstreams import State -def edit_data(row, ctx: MessageContext, state: State): - msg_max = len(row["field_c"]) +# Update current value using stateful operations + +def edit_data(value, state: State): + msg_max = len(value["field_c"]) current_max = state.get("current_len_max", default=0) if current_max < msg_max: state.set("current_len_max", msg_max) current_max = msg_max - row["len_max"] = current_max - return row + value["len_max"] = current_max -sdf = sdf.apply(edit_data, stateful=True) +sdf = sdf.update(edit_data, stateful=True) ``` For more information about stateful processing in general, see [**Stateful Applications**](./stateful-processing.md). -
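The remaining `State` methods follow the same pattern. For example (the "reset" flag and counter key are hypothetical):

```python
from quixstreams import State


def reset_counter(value: dict, state: State):
    # Drop the stored counter once a "reset" flag is seen for this key
    if value.get('reset') and state.exists('total'):
        state.delete('total')


sdf = sdf.update(reset_counter, stateful=True)
```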
+## Accessing the Kafka Message Key and Metadata +`quixstreams` provides access to the metadata of the current Kafka message via the `quixstreams.context` module. -## Producing to Topics: `.to_topic()` +Information like the message key, topic, partition, offset, timestamp and more is stored globally in a `MessageContext` object, +and it's updated on each incoming message. -To send the current state of the `StreamingDataFrame` to a topic, simply call -`to_topic` with a `Topic` instance generated from `Application.topic()` +To get the current message key, use the `quixstreams.message_key` function: + +```python +from quixstreams import message_key +sdf = sdf.apply(lambda value: 1 if message_key() == b'1' else 0) +``` + +To get the whole `MessageContext` object with all attributes including keys, use `quixstreams.message_context`: +```python +from quixstreams import message_context + +# Get current message timestamp and set it to a "timestamp" key +sdf['timestamp'] = sdf.apply(lambda value: message_context().timestamp.milliseconds) +``` + +Both `quixstreams.message_key()` and `quixstreams.message_context()` should be called +only from custom functions during processing. + + +## Producing to Topics: `StreamingDataFrame.to_topic()` + +To send the current value of the `StreamingDataFrame` to a topic, simply call +`.to_topic()` with a `Topic` instance generated from `Application.topic()` as an argument. To change the outgoing message key (which defaults to the current consumed key), -you can optionally provide a key function, which operates similarly to the `.apply()` -function with a `row` (dict) and `ctx` argument, and returns a desired -(serializable) key. +you can optionally provide a key function, which operates similarly to `.apply()`.
+It should accept a message value and return a new key. + +The returned key must be compatible with `key_serializer` provided to the `Topic` object. ```python from quixstreams import Application -app = Application(...) -output_topic = app.topic("my_output_topic") +app = Application(broker_address='localhost:9092', consumer_group='consumer') + +# Incoming key is deserialized to string +input_topic = app.topic("input", key_deserializer='str') +# Outgoing key will be serialized as a string too +output_topic = app.topic("my_output_topic", key_serializer='str') # Producing a new message to a topic with the same key -sdf = sdf.to_topic(other_output_topic) +sdf = sdf.to_topic(output_topic) -# Producing a new message to a topic with a new key -sdf = sdf.to_topic(output_topic, key=lambda value, ctx: ctx.key + value['field']) +# Generate a new message key based on "value['field']" assuming it is a string +sdf = sdf.to_topic(output_topic, key=lambda value: str(value["field"])) ``` diff --git a/src/StreamingDataFrames/examples/README.md b/src/StreamingDataFrames/examples/README.md index 61aa12243..4e0f43681 100644 --- a/src/StreamingDataFrames/examples/README.md +++ b/src/StreamingDataFrames/examples/README.md @@ -1,7 +1,7 @@ -# `StreamingDataFrames` Examples +# Quix Streams Examples -This folder contains a few examples/boiler-plate applications to get you started with -the `StreamingDataFrames` library. +This folder contains a few boilerplate applications to get you started with +the Quix Streams library. ## Running an Example diff --git a/src/StreamingDataFrames/examples/bank_example/README.md b/src/StreamingDataFrames/examples/bank_example/README.md index 6c23d555d..5e5ab284a 100644 --- a/src/StreamingDataFrames/examples/bank_example/README.md +++ b/src/StreamingDataFrames/examples/bank_example/README.md @@ -18,8 +18,8 @@ account and the purchase attempt was above a certain cost (we don't want to spam This example showcases: - How to use multiple Quix kafka applications together - Producer - - Consumer via `streamingdataframes.Application` (consumes and produces) - - Basic usage of the `dataframe` object + - Consumer via `quixstreams.Application` (consumes and produces) + - Basic usage of the `StreamingDataFrame` object - Using different serializations/data structures - json - Quix serializers (more intended for Quix platform use) diff --git a/src/StreamingDataFrames/examples/bank_example/json_version/consumer.py b/src/StreamingDataFrames/examples/bank_example/json_version/consumer.py index 86ea20642..b423a0ccc 100644 --- a/src/StreamingDataFrames/examples/bank_example/json_version/consumer.py +++ b/src/StreamingDataFrames/examples/bank_example/json_version/consumer.py @@ -9,36 +9,34 @@ from dotenv import load_dotenv -from quixstreams import Application, MessageContext, State +from quixstreams import Application, State, message_key load_dotenv("./env_vars.env") -def count_transactions(value: dict, ctx: MessageContext, state: State): +def count_transactions(value: dict, state: State): """ Track the number of transactions using persistent state :param value: message value - :param ctx: message context with key, timestamp and other Kafka message metadata :param state: instance of State store """ total = state.get("total_transactions", 0) total += 1 state.set("total_transactions", total) - value["total_transactions"] = total + return total -def uppercase_source(value: dict, ctx: MessageContext): +def uppercase_source(value: dict): """ Upper-case field "transaction_source" for each processed message 
:param value: message value, a dictionary with all deserialized message data - :param ctx: message context, it contains message metadata like key, topic, timestamp - etc. + :return: this function must either return None or a new dictionary """ - print(f'Processing message with key "{ctx.key}"') + print(f'Processing message with key "{message_key()}"') value["transaction_source"] = value["transaction_source"].upper() return value @@ -62,28 +60,25 @@ def uppercase_source(value: dict, ctx: MessageContext): sdf = app.dataframe(input_topic) # Filter only messages with "account_class" == "Gold" and "transaction_amount" >= 1000 -sdf = sdf[ - (sdf["account_class"] == "Gold") - & (sdf["transaction_amount"].apply(lambda x, ctx: abs(x)) >= 1000) -] +sdf = sdf[(sdf["account_class"] == "Gold") & (sdf["transaction_amount"].abs() >= 1000)] # Drop all fields except the ones we need sdf = sdf[["account_id", "transaction_amount", "transaction_source"]] -# Update the total number of transactions in state -sdf = sdf.apply(count_transactions, stateful=True) +# Update the total number of transactions in state and save result to the message +sdf["total_transactions"] = sdf.apply(count_transactions, stateful=True) # Transform field "transaction_source" to upper-case using a custom function -sdf = sdf.apply(uppercase_source) +sdf["transaction_source"] = sdf["transaction_source"].apply(lambda v: v.upper()) # Add a new field with a notification text sdf["customer_notification"] = "A high cost purchase was attempted" # Print the transformed message to the console -sdf = sdf.apply(lambda val, ctx: print(f"Sending update: {val}")) +sdf = sdf.update(lambda val: print(f"Sending update: {val}")) # Send the message to the output topic -sdf.to_topic(output_topic) +sdf = sdf.to_topic(output_topic) if __name__ == "__main__": # Start message processing diff --git a/src/StreamingDataFrames/examples/bank_example/quix_platform_version/consumer.py b/src/StreamingDataFrames/examples/bank_example/quix_platform_version/consumer.py index 4a6ea30bf..004fdba7d 100644 --- a/src/StreamingDataFrames/examples/bank_example/quix_platform_version/consumer.py +++ b/src/StreamingDataFrames/examples/bank_example/quix_platform_version/consumer.py @@ -8,40 +8,37 @@ from dotenv import load_dotenv -from quixstreams import Application, MessageContext, State +from quixstreams import Application, State, message_key + # Reminder: the platform will have these values available by default so loading the # environment would be unnecessary there. load_dotenv("./bank_example/quix_platform_version/quix_vars.env") -def uppercase_source(value: dict, ctx: MessageContext): +def uppercase_source(value: dict): """ Upper-case field "transaction_source" for each processed message :param value: message value, a dictionary with all deserialized message data - :param ctx: message context, it contains message metadata like key, topic, timestamp - etc. 
:return: this function must either return None or a new dictionary """ - print(f'Processing message with key "{ctx.key}"') + print(f'Processing message with key "{message_key()}"') value["transaction_source"] = value["transaction_source"].upper() - return value -def count_transactions(value: dict, ctx: MessageContext, state: State): +def count_transactions(value: dict, state: State): """ Track the number of transactions using persistent state :param value: message value - :param ctx: message context with key, timestamp and other Kafka message metadata :param state: instance of State store """ total = state.get("total_transactions", 0) total += 1 state.set("total_transactions", total) - value["total_transactions"] = total + return total # Define your application and settings @@ -61,30 +58,26 @@ def count_transactions(value: dict, ctx: MessageContext, state: State): # Create a StreamingDataFrame and start building your processing pipeline sdf = app.dataframe(input_topic) - # Filter only messages with "account_class" == "Gold" and "transaction_amount" >= 1000 -sdf = sdf[ - (sdf["account_class"] == "Gold") - & (sdf["transaction_amount"].apply(lambda x, ctx: abs(x)) >= 1000) -] +sdf = sdf[(sdf["account_class"] == "Gold") & (sdf["transaction_amount"].abs() >= 1000)] # Drop all fields except the ones we need -sdf = sdf[["account_id", "transaction_amount", "transaction_source"]] +sdf = sdf[["Timestamp", "account_id", "transaction_amount", "transaction_source"]] -# Update the total number of transactions in state -sdf = sdf.apply(count_transactions, stateful=True) +# Update the total number of transactions in state and save result to the message +sdf["total_transactions"] = sdf.apply(count_transactions, stateful=True) # Transform field "transaction_source" to upper-case using a custom function -sdf = sdf.apply(uppercase_source) +sdf["transaction_source"] = sdf["transaction_source"].apply(lambda v: v.upper()) # Add a new field with a notification text sdf["customer_notification"] = "A high cost purchase was attempted" # Print the transformed message to the console -sdf = sdf.apply(lambda val, ctx: print(f"Sending update: {val}")) +sdf = sdf.update(lambda val: print(f"Sending update: {val}")) # Send the message to the output topic -sdf.to_topic(output_topic) +sdf = sdf.to_topic(output_topic) if __name__ == "__main__": # Start message processing diff --git a/src/StreamingDataFrames/examples/bank_example/quix_platform_version/producer.py b/src/StreamingDataFrames/examples/bank_example/quix_platform_version/producer.py index be93dacde..c8c96da8c 100644 --- a/src/StreamingDataFrames/examples/bank_example/quix_platform_version/producer.py +++ b/src/StreamingDataFrames/examples/bank_example/quix_platform_version/producer.py @@ -1,3 +1,4 @@ +import time import uuid from random import randint, random, choice from time import sleep @@ -9,8 +10,7 @@ QuixTimeseriesSerializer, SerializationContext, ) -from quixstreams.models.topics import Topic, TopicCreationConfigs -from quixstreams.platforms.quix import QuixKafkaConfigsBuilder +from quixstreams.platforms.quix import QuixKafkaConfigsBuilder, TopicCreationConfigs load_dotenv("./bank_example/quix_platform_version/quix_vars.env") @@ -48,6 +48,7 @@ "account_class": "Gold" if account >= 8 else "Silver", "transaction_amount": randint(-2500, -1), "transaction_source": choice(retailers), + "Timestamp": time.time_ns(), } print(f"Producing value {value}") producer.produce( diff --git a/src/StreamingDataFrames/quixstreams/__init__.py 
b/src/StreamingDataFrames/quixstreams/__init__.py index 201a8d840..ff37d95b4 100644 --- a/src/StreamingDataFrames/quixstreams/__init__.py +++ b/src/StreamingDataFrames/quixstreams/__init__.py @@ -1,5 +1,6 @@ from .app import Application from .models import MessageContext from .state import State +from .context import message_context, message_key __version__ = "2.0alpha2" diff --git a/src/StreamingDataFrames/quixstreams/app.py b/src/StreamingDataFrames/quixstreams/app.py index 54259cdb6..ebb5034ef 100644 --- a/src/StreamingDataFrames/quixstreams/app.py +++ b/src/StreamingDataFrames/quixstreams/app.py @@ -6,6 +6,8 @@ from confluent_kafka import TopicPartition from typing_extensions import Self +from .context import set_message_context, copy_context +from .core.stream import Filtered from .dataframe import StreamingDataFrame from .error_callbacks import ( ConsumerErrorCallback, @@ -335,7 +337,6 @@ def dataframe( :return: `StreamingDataFrame` object """ sdf = StreamingDataFrame(topic=topic, state_manager=self._state_manager) - sdf.consumer = self._consumer sdf.producer = self._producer return sdf @@ -404,7 +405,7 @@ def run( with exit_stack: # Subscribe to topics in Kafka and start polling self._consumer.subscribe( - list(dataframe.topics_in.values()), + [dataframe.topic], on_assign=self._on_assign, on_revoke=self._on_revoke, on_lost=self._on_lost, @@ -412,6 +413,8 @@ def run( logger.info("Waiting for incoming messages") # Start polling Kafka for messages and callbacks self._running = True + + dataframe_compiled = dataframe.compile() while self._running: # Serve producer callbacks self._producer.poll(self._producer_poll_timeout) @@ -431,13 +434,21 @@ def run( first_row.partition, first_row.offset, ) + # Create a new contextvars.Context and set the current MessageContext + # (it's the same across multiple rows) + context = copy_context() + context.run(set_message_context, first_row.context) with start_state_transaction( topic=topic_name, partition=partition, offset=offset ): for row in rows: try: - dataframe.process(row=row) + # Execute StreamingDataFrame in a context + context.run(dataframe_compiled, row.value) + except Filtered: + # The message was filtered by StreamingDataFrame + continue except Exception as exc: # TODO: This callback might be triggered because of Producer # errors too because they happen within ".process()" diff --git a/src/StreamingDataFrames/quixstreams/context.py b/src/StreamingDataFrames/quixstreams/context.py new file mode 100644 index 000000000..2e4de4dbe --- /dev/null +++ b/src/StreamingDataFrames/quixstreams/context.py @@ -0,0 +1,49 @@ +from contextvars import ContextVar, copy_context +from typing import Optional, Any + +from quixstreams.exceptions import QuixException +from quixstreams.models.messagecontext import MessageContext + +__all__ = ( + "MessageContextNotSetError", + "set_message_context", + "message_key", + "message_context", + "copy_context", +) + +_current_message_context = ContextVar("current_message_context") + + +class MessageContextNotSetError(QuixException): + ... 
+ + +def set_message_context(context: Optional[MessageContext]): + """ + Set a MessageContext for the current message in the given `contextvars.Context` + + :param context: instance of `MessageContext` + """ + _current_message_context.set(context) + + +def message_context() -> MessageContext: + """ + Get a MessageContext for the current message + :return: instance of `MessageContext` + """ + try: + return _current_message_context.get() + + except LookupError: + raise MessageContextNotSetError("Message context is not set") + + +def message_key() -> Any: + """ + Get the current message key. + + :return: a deserialized message key + """ + return message_context().key diff --git a/src/StreamingDataFrames/quixstreams/core/__init__.py b/src/StreamingDataFrames/quixstreams/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/StreamingDataFrames/quixstreams/core/stream/__init__.py b/src/StreamingDataFrames/quixstreams/core/stream/__init__.py new file mode 100644 index 000000000..0c64aa38d --- /dev/null +++ b/src/StreamingDataFrames/quixstreams/core/stream/__init__.py @@ -0,0 +1,2 @@ +from .stream import * +from .functions import * diff --git a/src/StreamingDataFrames/quixstreams/core/stream/functions.py b/src/StreamingDataFrames/quixstreams/core/stream/functions.py new file mode 100644 index 000000000..bcc67be33 --- /dev/null +++ b/src/StreamingDataFrames/quixstreams/core/stream/functions.py @@ -0,0 +1,125 @@ +import enum +import functools +from typing import TypeVar, Callable, Protocol, Optional + +__all__ = ( + "StreamCallable", + "SupportsBool", + "Filtered", + "Apply", + "Update", + "Filter", + "is_filter_function", + "is_update_function", + "is_apply_function", + "get_stream_function_type", +) + +R = TypeVar("R") +T = TypeVar("T") + +StreamCallable = Callable[[T], R] + + +class SupportsBool(Protocol): + def __bool__(self) -> bool: + ... + + +class Filtered(Exception): + ... + + +class StreamFunctionType(enum.IntEnum): + FILTER = 1 + UPDATE = 2 + APPLY = 3 + + +_STREAM_FUNC_TYPE_ATTR = "__stream_function_type__" + + +def get_stream_function_type(func: StreamCallable) -> Optional[StreamFunctionType]: + return getattr(func, _STREAM_FUNC_TYPE_ATTR, None) + + +def set_stream_function_type(func: StreamCallable, type_: StreamFunctionType): + setattr(func, _STREAM_FUNC_TYPE_ATTR, type_) + + +def is_filter_function(func: StreamCallable) -> bool: + func_type = get_stream_function_type(func) + return func_type is not None and func_type == StreamFunctionType.FILTER + + +def is_update_function(func: StreamCallable) -> bool: + func_type = get_stream_function_type(func) + return func_type is not None and func_type == StreamFunctionType.UPDATE + + +def is_apply_function(func: StreamCallable) -> bool: + func_type = get_stream_function_type(func) + return func_type is not None and func_type == StreamFunctionType.APPLY + + +def Filter(func: Callable[[T], SupportsBool]) -> StreamCallable: + """ + Wraps a function into a "Filter" function. + The result of a Filter function is interpreted as boolean. + If it's `True`, the input will be passed downstream. + If it's `False`, the `Filtered` exception will be raised to signal that the + value is filtered out.
+ + :param func: a function to filter values + :return: a Filter function + """ + + def wrapper(value: T) -> T: + result = func(value) + if not result: + raise Filtered() + return value + + wrapper = functools.update_wrapper(wrapper=wrapper, wrapped=func) + set_stream_function_type(wrapper, StreamFunctionType.FILTER) + return wrapper + + +def Update(func: StreamCallable) -> StreamCallable: + """ + Wrap a function into an "Update" function. + + The provided function is expected to mutate the value. + Its result will always be ignored, and its input is passed + downstream. + + :param func: a function to mutate values + :return: an Update function + """ + + def wrapper(value: T) -> T: + func(value) + return value + + wrapper = functools.update_wrapper(wrapper=wrapper, wrapped=func) + set_stream_function_type(wrapper, StreamFunctionType.UPDATE) + return wrapper + + +def Apply(func: StreamCallable) -> StreamCallable: + """ + Wrap a function into an "Apply" function. + + The provided function is expected to return a new value based on input, + and its result will always be passed downstream. + + :param func: a function to generate a new value + :return: an Apply function + """ + + def wrapper(value: T) -> R: + return func(value) + + wrapper = functools.update_wrapper(wrapper=wrapper, wrapped=func) + set_stream_function_type(wrapper, StreamFunctionType.APPLY) + return wrapper diff --git a/src/StreamingDataFrames/quixstreams/core/stream/stream.py b/src/StreamingDataFrames/quixstreams/core/stream/stream.py new file mode 100644 index 000000000..87280fe10 --- /dev/null +++ b/src/StreamingDataFrames/quixstreams/core/stream/stream.py @@ -0,0 +1,239 @@ +import copy +import functools +import itertools +from typing import Any, List, Callable, Optional, TypeVar + +from typing_extensions import Self + +from .functions import * + +__all__ = ("Stream",) + +R = TypeVar("R") +T = TypeVar("T") + + +class Stream: + def __init__( + self, + func: Optional[Callable[[Any], Any]] = None, + parent: Optional[Self] = None, + ): + """ + A base class for all streaming operations. + + `Stream` is an abstraction of a function pipeline. + Each Stream has a function and a parent (None by default). + When adding a new function to the stream, it creates a new `Stream` object and + sets "parent" to the previous `Stream` to maintain an order of execution. + + `Stream` supports 3 types of functions: + - "Apply" - generate new values based on the previous one. + The result of an Apply function is passed downstream to the next functions. + - "Update" - update values in-place. + The result of an Update function is always ignored, and its input is passed + downstream. + - "Filter" - to filter values from the Stream. + The result of a Filter function is interpreted as boolean. + If it's `True`, the input will be passed downstream. + If it's `False`, the `Filtered` exception will be raised to signal that the + value is filtered out. + + To execute functions accumulated on the `Stream` instance, it must be compiled + using the `.compile()` method, which returns a function closure executing all the + functions in the tree. + + :param func: a function to be called on the stream. + It is expected to be wrapped into one of "Apply", "Filter" or "Update" from + `quixstreams.core.stream.functions` package. + Default - "Apply(lambda v: v)".
+ :param parent: a parent `Stream` + """ + if func is not None and not any( + ( + is_apply_function(func), + is_update_function(func), + is_filter_function(func), + ) + ): + raise ValueError( + "Provided function must be either Apply, Filter or Update function" + ) + + self.func = func if func is not None else Apply(lambda x: x) + self.parent = parent + + def __repr__(self) -> str: + """ + Generate a nice repr with all functions in the stream and its parents. + + :return: a string of format + "<Stream [<total functions>]: <FuncType: func_name> | ... >" + """ + tree_funcs = [s.func for s in self.tree()] + funcs_repr = " | ".join( + ( + f"<{get_stream_function_type(f).name}: {f.__qualname__}>" + for f in tree_funcs + ) + ) + return f"<{self.__class__.__name__} [{len(tree_funcs)}]: {funcs_repr}>" + + def add_filter(self, func: Callable[[T], R]) -> Self: + """ + Add a function to filter values from the Stream. + + The return value of the function will be interpreted as `bool`. + If the function returns a `False`-like result, the Stream will raise `Filtered` + exception during execution. + + :param func: a function to filter values from the stream + :return: a new `Stream` derived from the current one + """ + return self._add(Filter(func)) + + def add_apply(self, func: Callable[[T], R]) -> Self: + """ + Add an "apply" function to the Stream. + + The function is supposed to return a new value, which will be passed + further during execution. + + :param func: a function to generate a new value + :return: a new `Stream` derived from the current one + """ + return self._add(Apply(func)) + + def add_update(self, func: Callable[[T], object]) -> Self: + """ + Add an "update" function to the Stream, that will mutate the input value. + + The return of this function will be ignored and its input + will be passed downstream. + + :param func: a function to mutate the value + :return: a new Stream derived from the current one + """ + return self._add(Update(func)) + + def diff( + self, + other: "Stream", + ) -> Self: + """ + Takes the difference between Streams `self` and `other` based on their last + common parent, and returns a new `Stream` that includes only this difference. + + It's impossible to calculate a diff when: + - Streams don't have a common parent. + - The `self` Stream already includes all the nodes from + the `other` Stream, and the resulting diff is empty. + + :param other: a `Stream` to take a diff from. + :raises ValueError: if Streams don't have a common parent + or if the diff is empty. + :return: new `Stream` instance including all the Streams from the diff + """ + diff = self._diff_from_last_common_parent(other) + parent = None + head = None + for node in diff: + # Copy the node to ensure we don't alter the previously created Nodes + node = copy.deepcopy(node) + node.parent = parent + parent = node + head = node + return head + + def tree(self) -> List[Self]: + """ + Return a list of all parent Streams including the node itself. + + The tree is ordered from parent to child (the current node comes last). + :return: a list of `Stream` objects + """ + tree_ = [self] + node = self + while node.parent: + tree_.insert(0, node.parent) + node = node.parent + return tree_ + + def compile( + self, + allow_filters: bool = True, + allow_updates: bool = True, + ) -> Callable[[T], R]: + """ + Compile a list of functions from this `Stream` and its parents into a single + big closure using a "compiler" function. + + Closures are more performant than calling all the functions in the + `Stream.tree()` one-by-one.
+ + :param allow_filters: If False, this function will fail with ValueError if + the stream has filter functions in the tree. Default - True. + :param allow_updates: If False, this function will fail with ValueError if + the stream has update functions in the tree. Default - True. + + :raises ValueError: if disallowed functions are present in the stream tree. + """ + compiled = None + tree = self.tree() + for node in tree: + func = node.func + if not allow_updates and is_update_function(func): + raise ValueError("Update functions are not allowed in this stream") + elif not allow_filters and is_filter_function(func): + raise ValueError("Filter functions are not allowed in this stream") + + if compiled is None: + compiled = func + else: + compiled = compiler(func)(compiled) + + return compiled + + def _diff_from_last_common_parent(self, other: Self) -> List[Self]: + nodes_self = self.tree() + nodes_other = other.tree() + + diff = [] + last_common_parent = None + for node_self, node_other in itertools.zip_longest(nodes_self, nodes_other): + if node_self is node_other: + last_common_parent = node_other + elif node_other is not None: + diff.append(node_other) + + if last_common_parent is None: + raise ValueError("Common parent not found") + if not diff: + raise ValueError("The diff is empty") + return diff + + def _add(self, func: Callable[[T], R]) -> Self: + return self.__class__(func=func, parent=self) + + +def compiler(outer_func: Callable[[T], R]) -> Callable[[T], R]: + """ + A function that wraps two other functions into a closure. + It passes the result of the inner function as an input to the outer function. + + It is used to transform (aka "compile") a list of functions into one large closure + like this: + ``` + [func, func, func] -> func(func(func())) + ``` + + :return: a function with one argument (value) + """ + + def wrapper(inner_func: Callable[[T], R]): + def _wrapper(v: T) -> R: + return outer_func(inner_func(v)) + + return functools.update_wrapper(_wrapper, inner_func) + + return functools.update_wrapper(wrapper, outer_func) diff --git a/src/StreamingDataFrames/quixstreams/dataframe/__init__.py b/src/StreamingDataFrames/quixstreams/dataframe/__init__.py index 61bb62ada..b6f5296e8 100644 --- a/src/StreamingDataFrames/quixstreams/dataframe/__init__.py +++ b/src/StreamingDataFrames/quixstreams/dataframe/__init__.py @@ -1,2 +1,2 @@ from .dataframe import StreamingDataFrame -from .exceptions import * +from .series import * diff --git a/src/StreamingDataFrames/quixstreams/dataframe/base.py b/src/StreamingDataFrames/quixstreams/dataframe/base.py new file mode 100644 index 000000000..de133fab9 --- /dev/null +++ b/src/StreamingDataFrames/quixstreams/dataframe/base.py @@ -0,0 +1,20 @@ +import abc +from typing import Optional, Any + +from quixstreams.core.stream import Stream, StreamCallable +from quixstreams.models.messagecontext import MessageContext + + +class BaseStreaming: + @property + @abc.abstractmethod + def stream(self) -> Stream: + ... + + @abc.abstractmethod + def compile(self, *args, **kwargs) -> StreamCallable: + ... + + @abc.abstractmethod + def test(self, value: Any, ctx: Optional[MessageContext] = None) -> Any: + ...
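Taken together, a compiled `Stream` behaves like a single plain function. A minimal usage sketch of these new internals (not part of the public API):

```python
from quixstreams.core.stream import Stream, Filtered

stream = Stream()
stream = stream.add_apply(lambda v: v + 1)
stream = stream.add_filter(lambda v: v > 0)
stream = stream.add_update(lambda v: print('value:', v))

func = stream.compile()
assert func(1) == 2  # prints "value: 2"

try:
    func(-5)  # (-5 + 1) fails the filter
except Filtered:
    print('value was filtered out')
```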
diff --git a/src/StreamingDataFrames/quixstreams/dataframe/column.py b/src/StreamingDataFrames/quixstreams/dataframe/column.py deleted file mode 100644 index 8794c592d..000000000 --- a/src/StreamingDataFrames/quixstreams/dataframe/column.py +++ /dev/null @@ -1,140 +0,0 @@ -import operator -from typing import Optional, Any, Callable, Container - -from typing_extensions import Self, TypeAlias, Union - -from ..models import Row, MessageContext - -ColumnApplier: TypeAlias = Callable[[Any, MessageContext], Any] - -__all__ = ("Column", "ColumnApplier") - - -def invert(value): - if isinstance(value, bool): - return operator.not_(value) - else: - return operator.invert(value) - - -class Column: - def __init__( - self, - col_name: Optional[str] = None, - _eval_func: Optional[ColumnApplier] = None, - ): - self.col_name = col_name - self._eval_func = _eval_func if _eval_func else lambda row: row[self.col_name] - - def __getitem__(self, item: Union[str, int]) -> Self: - return self.__class__(_eval_func=lambda x: self.eval(x)[item]) - - def _operation(self, other: Any, op: Callable[[Any, Any], Any]) -> Self: - return self.__class__( - _eval_func=lambda x: op( - self.eval(x), other.eval(x) if isinstance(other, Column) else other - ), - ) - - def eval(self, row: Row) -> Any: - """ - Execute all the functions accumulated on this Column. - - :param row: A Quixstreams Row - :return: A primitive type - """ - return self._eval_func(row) - - def apply(self, func: ColumnApplier) -> Self: - """ - Add a callable to the execution list for this column. - - The provided callable should accept a single argument, which will be its input. - The provided callable should similarly return one output, or None - - :param func: a callable with one argument and one output - :return: a new Column with the new callable added - """ - return Column(_eval_func=lambda row: func(self.eval(row), row.context)) - - def isin(self, other: Container) -> Self: - return self._operation(other, lambda a, b: operator.contains(b, a)) - - def contains(self, other: Any) -> Self: - return self._operation(other, operator.contains) - - def is_(self, other: Any) -> Self: - """ - Check if column value refers to the same object as `other` - :param other: object to check for "is" - :return: - """ - return self._operation(other, operator.is_) - - def isnot(self, other: Any) -> Self: - """ - Check if column value refers to the same object as `other` - :param other: object to check for "is" - :return: - """ - return self._operation(other, operator.is_not) - - def isnull(self) -> Self: - """ - Check if column value is None - """ - return self._operation(None, operator.is_) - - def notnull(self) -> Self: - """ - Check if column value is not None - """ - return self._operation(None, operator.is_not) - - def abs(self) -> Self: - """ - Get absolute value of the Column value - """ - return self.apply(lambda v, ctx: abs(v)) - - def __and__(self, other: Any) -> Self: - return self._operation(other, operator.and_) - - def __or__(self, other: Any) -> Self: - return self._operation(other, operator.or_) - - def __mod__(self, other: Any) -> Self: - return self._operation(other, operator.mod) - - def __add__(self, other: Any) -> Self: - return self._operation(other, operator.add) - - def __sub__(self, other: Any) -> Self: - return self._operation(other, operator.sub) - - def __mul__(self, other: Any) -> Self: - return self._operation(other, operator.mul) - - def __truediv__(self, other: Any) -> Self: - return self._operation(other, operator.truediv) - - def __eq__(self, 
other: Any) -> Self: - return self._operation(other, operator.eq) - - def __ne__(self, other: Any) -> Self: - return self._operation(other, operator.ne) - - def __lt__(self, other: Any) -> Self: - return self._operation(other, operator.lt) - - def __le__(self, other: Any) -> Self: - return self._operation(other, operator.le) - - def __gt__(self, other: Any) -> Self: - return self._operation(other, operator.gt) - - def __ge__(self, other: Any) -> Self: - return self._operation(other, operator.ge) - - def __invert__(self) -> Self: - return self.__class__(_eval_func=lambda x: invert(self.eval(x))) diff --git a/src/StreamingDataFrames/quixstreams/dataframe/dataframe.py b/src/StreamingDataFrames/quixstreams/dataframe/dataframe.py index 79383dd53..6ce43e9d4 100644 --- a/src/StreamingDataFrames/quixstreams/dataframe/dataframe.py +++ b/src/StreamingDataFrames/quixstreams/dataframe/dataframe.py @@ -1,229 +1,109 @@ -import uuid -from typing import Optional, Callable, Union, List, Mapping, Any - -from typing_extensions import Self, TypeAlias - -from .column import Column -from .exceptions import InvalidApplyResultType -from .pipeline import Pipeline -from ..models import Row, Topic, MessageContext -from ..rowconsumer import RowConsumerProto -from ..rowproducer import RowProducerProto -from ..state import State, StateStoreManager - -ApplyFunc: TypeAlias = Callable[ - [dict, MessageContext], Optional[Union[dict, List[dict]]] -] -StatefulApplyFunc: TypeAlias = Callable[ - [dict, MessageContext, State], Optional[Union[dict, List[dict]]] -] -KeyFunc: TypeAlias = Callable[[dict, MessageContext], Any] - -__all__ = ("StreamingDataFrame",) - - -def subset(keys: List[str], row: Row) -> Row: - row.value = row[keys] - return row - - -def setitem(k: str, v: Any, row: Row) -> Row: - row[k] = v.eval(row) if isinstance(v, Column) else v - return row - - -def apply( - row: Row, - func: Union[ApplyFunc, StatefulApplyFunc], - state_manager: Optional[StateStoreManager] = None, -) -> Union[Row, List[Row]]: - # Providing state to the function if state_manager is passed - if state_manager is not None: - transaction = state_manager.get_store_transaction() - # Prefix all the state keys by the message key - with transaction.with_prefix(prefix=row.key): - # Pass a State object with an interface limited to the key updates only - result = func(row.value, row.context, transaction.state) - else: - result = func(row.value, row.context) - - if result is None and isinstance(row.value, dict): - # Function returned None, assume it changed the incoming dict in-place - return row - if isinstance(result, dict): - # Function returned dict, assume it's a new value for the Row - row.value = result - return row - raise InvalidApplyResultType( - f"Only 'dict' or 'NoneType' (in-place modification) allowed, not {type(result)}" - ) - - -class StreamingDataFrame: - """ - Allows you to define transformations on a kafka message as if it were a Pandas - DataFrame. - Currently, it implements a small subset of the Pandas interface, along with - some differences/accommodations for kafka-specific functionality. - - A `StreamingDataFrame` expects to interact with a QuixStreams `Row`, which is - interacted with like a dictionary. - - Unlike pandas, you will not get an immediate output from any given operation; - instead, the command is permanently added to the `StreamingDataFrame`'s - "pipeline". 
You can then execute this pipeline indefinitely on a `Row` like so: - - ``` - df = StreamingDataframe(topic) - df = df.apply(lambda row: row) # do stuff; must return the row back! - for row_obj in [row_0, row_1]: - print(df.process(row_obj)) - ``` - - Note that just like Pandas, you can "filter" out rows with your operations, like: - ``` - df = df[df['column_b'] >= 5] - ``` - If a processing step nulls the Row in some way, all further processing on that - row (including kafka operations, besides committing) will be skipped. - - There is a :class:`quixstreams.app.Application` class that can manage - the kafka-specific dependencies of `StreamingDataFrame`; - it is recommended you hand your `StreamingDataFrame` instance to an `Application` - instance when interacting with Kafka. - - Below is a larger example of a `StreamingDataFrame` - (that you'd hand to an `Application`): - - ``` - # Define your processing steps - # Remove column_a, add 1 to columns b and c, skip row if b+1 >= 5, else publish row - df = StreamingDataframe() - df = df[['column_b', 'column_c']] - df = df.apply(lambda row: row[key] + 1 if key in ['column_b', 'column_c']) - df = df[df['column_b'] + 1] >= 5] - df.to_topic('my_output_topic') - - # Incomplete Rows to showcase what data you are actually interacting with - record_0 = Row(value={'column_a': 'a_string', 'column_b': 3, 'column_c': 5}) - record_1 = Row(value={'column_a': 'a_string', 'column_b': 1, 'column_c': 10}) - - # process records - df.process(record_0) # produces {'column_b': 4, 'column_c': 6} to "my_output_topic" - df.process(record_1) # filters row, does NOT produce to "my_output_topic" - ``` - """ - +import contextvars +import functools +import operator +from typing import Optional, Callable, Union, List, TypeVar, Any + +from typing_extensions import Self + +from quixstreams.context import ( + message_context, + set_message_context, + message_key, +) +from quixstreams.core.stream import StreamCallable, Stream +from quixstreams.models import Topic, Row, MessageContext +from quixstreams.rowproducer import RowProducerProto +from quixstreams.state import StateStoreManager, State +from .base import BaseStreaming +from .series import StreamingSeries + +T = TypeVar("T") +R = TypeVar("R") +DataFrameFunc = Callable[[T], R] +DataFrameStatefulFunc = Callable[[T, State], R] + + +class StreamingDataFrame(BaseStreaming): def __init__( self, topic: Topic, state_manager: StateStoreManager, + stream: Optional[Stream] = None, ): - self._id = str(uuid.uuid4()) - self._pipeline = Pipeline(_id=self.id) - self._real_consumer: Optional[RowConsumerProto] = None + self._stream: Stream = stream or Stream() + self._topic = topic self._real_producer: Optional[RowProducerProto] = None - self._topics_in = {topic.name: topic} - self._topics_out = {} self._state_manager = state_manager + @property + def stream(self) -> Stream: + return self._stream + + @property + def topic(self) -> Topic: + return self._topic + def apply( - self, - func: Union[ApplyFunc, StatefulApplyFunc], - stateful: bool = False, + self, func: Union[DataFrameFunc, DataFrameStatefulFunc], stateful: bool = False ) -> Self: """ - Apply a custom function with current value - and :py:class:`quixstreams.models.context.MessageContext` - as the expected input. 
- - :param func: a callable which accepts 2 arguments: - - value - a dict with fields and values for the current Row - - a context - an instance of :py:class:`quixstreams.models.context.MessageContext` - which contains message metadata like key, timestamp, partition, - and more. - - .. note:: if `stateful=True` is passed, a third argument of type `State` - will be provided to the function. + Apply a function to transform the value and return a new value. - The custom function may return: - - a new dict to replace the current Row value - - `None` to modify the current Row value in-place + The result will be passed downstream as an input value. - :param stateful: if `True`, the function will be provided with 3rd argument + :param func: a function to apply + :param stateful: if `True`, the function will be provided with a second argument of type `State` to perform stateful operations. - - :return: current instance of `StreamingDataFrame` """ - if stateful: - # Register the default store for each input topic - for topic in self._topics_in.values(): - self._state_manager.register_store(topic_name=topic.name) - return self._apply( - lambda row: apply(row, func, state_manager=self._state_manager) - ) - return self._apply(lambda row: apply(row, func)) + if stateful: + self._register_store() + func = _as_stateful(func=func, state_manager=self._state_manager) - def process(self, row: Row) -> Optional[Union[Row, List[Row]]]: - """ - Execute the previously defined StreamingDataframe operations on a provided Row. - :param row: a QuixStreams Row object - :return: Row, list of Rows, or None (if filtered) - """ - return self._pipeline.process(row) + stream = self.stream.add_apply(func) + return self._clone(stream=stream) - def to_topic(self, topic: Topic, key: Optional[KeyFunc] = None) -> Self: + def update( + self, func: Union[DataFrameFunc, DataFrameStatefulFunc], stateful: bool = False + ) -> Self: """ - Produce a row to a desired topic. - Note that a producer must be assigned on the `StreamingDataFrame` if not using - :class:`quixstreams.app.Application` class to facilitate - the execution of StreamingDataFrame. + Apply a function to mutate value in-place or to perform a side effect + that doesn't update the value (e.g. print a value to the console). - :param topic: A QuixStreams `Topic` - :param key: a callable to generate a new message key, optional. - If passed, the return type of this callable must be serializable - by `key_serializer` defined for this Topic object. - By default, the current message key will be used. + The result of the function will be ignored, and the original value will be + passed downstream. - :return: self (StreamingDataFrame) + :param func: function to update value + :param stateful: if `True`, the function will be provided with a second argument + of type `State` to perform stateful operations.
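+ + Example Snippet: + + ``` + # Illustrative sketch: mark the value as processed in-place; the same + # (mutated) dict is then passed downstream. + sdf = sdf.update(lambda value: value.update({"processed": True})) + ```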
""" - self._topics_out[topic.name] = topic - return self._apply( - lambda row: self._produce( - topic, row, key=key(row.value, row.context) if key else None - ) - ) + if stateful: + self._register_store() + func = _as_stateful(func=func, state_manager=self._state_manager) - @property - def id(self) -> str: - return self._id + stream = self.stream.add_update(func) + return self._clone(stream=stream) - @property - def topics_in(self) -> Mapping[str, Topic]: - """ - Get a mapping with Topics for the StreamingDataFrame input topics - :return: dict of {: } + def filter( + self, func: Union[DataFrameFunc, DataFrameStatefulFunc], stateful: bool = False + ) -> Self: """ - return self._topics_in + Filter value using provided function. - @property - def topics_out(self) -> Mapping[str, Topic]: - """ - Get a mapping with Topics for the StreamingDataFrame output topics - :return: dict of {: } + If the function returns True-like value, the original value will be + passed downstream. + Otherwise, the `Filtered` exception will be raised. + + :param func: function to filter value + :param stateful: if `True`, the function will be provided with second argument + of type `State` to perform stateful operations. """ - return self._topics_out - @property - def consumer(self) -> RowConsumerProto: - if self._real_consumer is None: - raise RuntimeError("Consumer instance has not been provided") - return self._real_consumer + if stateful: + self._register_store() + func = _as_stateful(func=func, state_manager=self._state_manager) - @consumer.setter - def consumer(self, consumer: RowConsumerProto): - self._real_consumer = consumer + stream = self.stream.add_filter(func) + return self._clone(stream=stream) @property def producer(self) -> RowProducerProto: @@ -236,7 +116,7 @@ def producer(self, producer: RowProducerProto): self._real_producer = producer @staticmethod - def contains(key: str) -> Column: + def contains(key: str) -> StreamingSeries: """ Check if the key is present in the Row value. @@ -249,38 +129,127 @@ def contains(key: str) -> Column: # This would add a new column 'has_column' which contains boolean values # indicating the presence of 'column_x' in each row. """ - return Column(_eval_func=lambda row: key in row.keys()) + return StreamingSeries.from_func(lambda value: key in value) + + def to_topic( + self, topic: Topic, key: Optional[Callable[[object], object]] = None + ) -> Self: + """ + Produce value to the topic. + + .. note:: A `RowProducer` instance must be assigned to + `StreamingDataFrame.producer` if not using :class:`quixstreams.app.Application` + to facilitate the execution of StreamingDataFrame. + + :param topic: instance of `Topic` + :param key: a callable to generate a new message key, optional. + If passed, the return type of this callable must be serializable + by `key_serializer` defined for this Topic object. + By default, the current message key will be used. + + """ + return self.update( + lambda value: self._produce(topic, value, key=key(value) if key else None) + ) + + def compile(self) -> StreamCallable: + """ + Compile all functions of this StreamingDataFrame into one big closure. + + Closures are more performant than calling all the functions in the + `StreamingDataFrame` one-by-one. 
+ + :return: a function that accepts "value" + and returns a result of StreamingDataFrame + """ + return self.stream.compile() + + def test(self, value: object, ctx: Optional[MessageContext] = None) -> Any: + """ + A shorthand to test `StreamingDataFrame` with provided value + and `MessageContext`. + + :param value: value to pass through `StreamingDataFrame` + :param ctx: instance of `MessageContext`, optional. + Provide it if the StreamingDataFrame instance calls `to_topic()`, + has stateful functions or functions calling `get_current_key()`. + Default - `None`. + + :return: result of `StreamingDataFrame` + """ + context = contextvars.copy_context() + context.run(set_message_context, ctx) + compiled = self.compile() + return context.run(compiled, value) + + def _clone(self, stream: Stream) -> Self: + clone = self.__class__( + stream=stream, topic=self._topic, state_manager=self._state_manager + ) + if self._real_producer is not None: + clone.producer = self._real_producer + return clone + + def _produce(self, topic: Topic, value: object, key: Optional[object] = None): + ctx = message_context() + key = key or ctx.key + row = Row(value=value, context=ctx) # noqa + self.producer.produce_row(row, topic, key=key) - def __setitem__(self, key: str, value: Any): - self._apply(lambda row: setitem(key, value, row)) + def _register_store(self): + """ + Register the default store for input topic in StateStoreManager + """ + self._state_manager.register_store(topic_name=self._topic.name) + + def __setitem__(self, key, value: Union[Self, object]): + if isinstance(value, self.__class__): + diff = self.stream.diff(value.stream) + diff_compiled = diff.compile(allow_filters=False, allow_updates=False) + stream = self.stream.add_update( + lambda v: operator.setitem(v, key, diff_compiled(v)) + ) + elif isinstance(value, StreamingSeries): + value_compiled = value.compile(allow_filters=False, allow_updates=False) + stream = self.stream.add_update( + lambda v: operator.setitem(v, key, value_compiled(v)) + ) + else: + stream = self.stream.add_update(lambda v: operator.setitem(v, key, value)) + self._stream = stream def __getitem__( - self, item: Union[str, List[str], Column, Self] - ) -> Union[Column, Self]: - if isinstance(item, Column): - return self._apply(lambda row: row if item.eval(row) else None) + self, item: Union[str, List[str], StreamingSeries, Self] + ) -> Union[Self, StreamingSeries]: + if isinstance(item, StreamingSeries): + # Filter SDF based on StreamingSeries + item_compiled = item.compile(allow_filters=False, allow_updates=False) + return self.filter(lambda v: item_compiled(v)) + elif isinstance(item, self.__class__): + # Filter SDF based on another SDF + diff = self.stream.diff(item.stream) + diff_compiled = diff.compile(allow_filters=False, allow_updates=False) + return self.filter(lambda v: diff_compiled(v)) elif isinstance(item, list): - return self._apply(lambda row: subset(item, row)) - elif isinstance(item, StreamingDataFrame): - # TODO: Implement filtering based on another SDF - raise ValueError( - "Filtering based on StreamingDataFrame is not supported yet." 
- ) + # Take only certain keys from the dict and return a new dict + return self.apply(lambda v: {k: v[k] for k in item}) + elif isinstance(item, str): + # Create a StreamingSeries based on key + return StreamingSeries(name=item) else: - return Column(col_name=item) + raise TypeError(f'Unsupported key type "{type(item)}"') - def _produce(self, topic: Topic, row: Row, key: Optional[Any] = None) -> Row: - self.producer.produce_row(row, topic, key=key) - return row - def _apply(self, func: Callable[[Row], Optional[Union[Row, List[Row]]]]) -> Self: - """ - Add a callable to the StreamingDataframe execution list. - The provided callable should accept and return a Quixstreams Row; exceptions - to this include a user's .apply() function, or a "filter" that returns None. +def _as_stateful( + func: DataFrameStatefulFunc, state_manager: StateStoreManager +) -> DataFrameFunc: + @functools.wraps(func) + def wrapper(value: object) -> object: + transaction = state_manager.get_store_transaction() + key = message_key() + # Prefix all the state keys by the message key + with transaction.with_prefix(prefix=key): + # Pass a State object with an interface limited to the key updates only + return func(value, transaction.state) - :param func: callable that accepts and (usually) returns a QuixStreams Row - :return: self (StreamingDataFrame) - """ - self._pipeline.apply(func) - return self + return wrapper diff --git a/src/StreamingDataFrames/quixstreams/dataframe/exceptions.py b/src/StreamingDataFrames/quixstreams/dataframe/exceptions.py deleted file mode 100644 index c05a016b4..000000000 --- a/src/StreamingDataFrames/quixstreams/dataframe/exceptions.py +++ /dev/null @@ -1,7 +0,0 @@ -from quixstreams import exceptions - -__all__ = ("InvalidApplyResultType",) - - -class InvalidApplyResultType(exceptions.QuixException): - ... diff --git a/src/StreamingDataFrames/quixstreams/dataframe/pipeline.py b/src/StreamingDataFrames/quixstreams/dataframe/pipeline.py deleted file mode 100644 index 125acfe96..000000000 --- a/src/StreamingDataFrames/quixstreams/dataframe/pipeline.py +++ /dev/null @@ -1,82 +0,0 @@ -import logging -import uuid -from typing import Optional, Callable, Any, List -from typing_extensions import Self -from ..models import Row - -logger = logging.getLogger(__name__) - -__all__ = ("PipelineFunction", "Pipeline") - - -class PipelineFunction: - def __init__(self, func: Callable): - self._id = str(uuid.uuid4()) - self._func = func - - @property - def id(self) -> str: - return self._id - - def __repr__(self): - return f'<{self.__class__.__name__} "{repr(self._func)}">' - - def __call__(self, row: Row) -> Row: - return self._func(row) - - -class Pipeline: - def __init__(self, functions: List[PipelineFunction] = None, _id: str = None): - self._id = _id or str(uuid.uuid4()) - self._functions = functions or [] - - @property - def functions(self) -> List[PipelineFunction]: - return self._functions - - @property - def id(self) -> str: - return self._id - - def apply(self, func: Callable) -> Self: - """ - Add a callable to the Pipeline execution list. - The provided callable should accept a single argument, which will be its input. - The provided callable should similarly return one output, or None - - Note that if the expected input is a list, this function will be called with - each element in that list rather than the list itself. 
- - :param func: callable that accepts and (usually) returns an object - :return: self (Pipeline) - """ - self.functions.append(PipelineFunction(func=func)) - return self - - def process(self, data: Any) -> Optional[Any]: - """ - Execute the previously defined Pipeline functions on a provided input, `data`. - - Note that if `data` is a list, each function will be called with each element - in that list rather than the list itself. - - :param data: any object, but usually a QuixStreams Row - :return: an object OR list of objects OR None (if filtered) - """ - # TODO: maybe have an arg that allows passing "blacklisted" result types - # or SDF inspects each result somehow? - result = data - for func in self.functions: - if isinstance(result, list): - result = [ - fd for fd in (func(d) for d in result) if fd is not None - ] or None - else: - result = func(result) - if result is None: - logger.debug( - "Pipeline {pid} processing step returned a None; " - "terminating processing".format(pid=self.id) - ) - break - return result diff --git a/src/StreamingDataFrames/quixstreams/dataframe/series.py b/src/StreamingDataFrames/quixstreams/dataframe/series.py new file mode 100644 index 000000000..8c5a479c8 --- /dev/null +++ b/src/StreamingDataFrames/quixstreams/dataframe/series.py @@ -0,0 +1,234 @@ +import contextvars +import operator +from typing import Optional, Union, Callable, Container, Any + +from typing_extensions import Self + +from quixstreams.context import set_message_context +from quixstreams.core.stream.functions import StreamCallable, Apply +from quixstreams.core.stream.stream import Stream +from quixstreams.models.messagecontext import MessageContext +from .base import BaseStreaming + +__all__ = ("StreamingSeries",) + + +class StreamingSeries(BaseStreaming): + def __init__( + self, + name: Optional[str] = None, + stream: Optional[Stream] = None, + ): + if not (name or stream): + raise ValueError('Either "name" or "stream" must be passed') + self._stream = stream or Stream(func=Apply(lambda v: v[name])) + + @classmethod + def from_func(cls, func: StreamCallable) -> Self: + """ + Create a StreamingSeries from a function. + + The provided function will be wrapped into `Apply`. + :param func: a function to apply + :return: instance of `StreamingSeries` + """ + return cls(stream=Stream(Apply(func))) + + @property + def stream(self) -> Stream: + return self._stream + + def apply(self, func: StreamCallable) -> Self: + """ + Add a callable to the execution list for this series. + + The provided callable should accept a single argument, which will be its input. + The provided callable should similarly return one output, or None + + :param func: a callable with one argument and one output + :return: a new `StreamingSeries` with the new callable added + """ + child = self._stream.add_apply(func) + return self.__class__(stream=child) + + def compile( + self, + allow_filters: bool = True, + allow_updates: bool = True, + ) -> StreamCallable: + """ + Compile all functions of this StreamingSeries into one big closure. + + Closures are more performant than calling all the functions in the + `StreamingSeries` one-by-one. + + :param allow_filters: If False, this function will fail with ValueError if + the stream has filter functions in the tree. Default - True. + :param allow_updates: If False, this function will fail with ValueError if + the stream has update functions in the tree. Default - True. + + :raises ValueError: if disallowed functions are present in the tree of + underlying `Stream`.
+ + :return: a function that accepts "value" + and returns a result of StreamingDataFrame + """ + + return self._stream.compile( + allow_filters=allow_filters, allow_updates=allow_updates + ) + + def test(self, value: Any, ctx: Optional[MessageContext] = None) -> Any: + """ + A shorthand to test `StreamingSeries` with provided value + and `MessageContext`. + + :param value: value to pass through `StreamingSeries` + :param ctx: instance of `MessageContext`, optional. + Provide it if the StreamingSeries instance has + functions calling `get_current_key()`. + Default - `None`. + :return: result of `StreamingSeries` + """ + context = contextvars.copy_context() + context.run(set_message_context, ctx) + compiled = self.compile() + return context.run(compiled, value) + + def _operation( + self, other: Union[Self, object], operator_: Callable[[object, object], object] + ) -> Self: + self_compiled = self.compile() + if isinstance(other, self.__class__): + other_compiled = other.compile() + return self.from_func( + func=lambda v, op=operator_: op(self_compiled(v), other_compiled(v)) + ) + else: + return self.from_func( + func=lambda v, op=operator_: op(self_compiled(v), other) + ) + + def isin(self, other: Container) -> Self: + """ + Check if series value is in "other". + Same as "StreamingSeries in other". + + :param other: a container to check + :return: new StreamingSeries + """ + return self._operation( + other, lambda a, b, contains=operator.contains: contains(b, a) + ) + + def contains(self, other: object) -> Self: + """ + Check if series value contains "other" + Same as "other in StreamingSeries". + + :param other: object to check + :return: new StreamingSeries + """ + return self._operation(other, operator.contains) + + def is_(self, other: object) -> Self: + """ + Check if series value refers to the same object as `other` + :param other: object to check for "is" + :return: new StreamingSeries + """ + return self._operation(other, operator.is_) + + def isnot(self, other: object) -> Self: + """ + Check if series value refers to the same object as `other` + :param other: object to check for "is" + :return: new StreamingSeries + """ + return self._operation(other, operator.is_not) + + def isnull(self) -> Self: + """ + Check if series value is None + + :return: new StreamingSeries + """ + return self._operation(None, operator.is_) + + def notnull(self) -> Self: + """ + Check if series value is not None + """ + return self._operation(None, operator.is_not) + + def abs(self) -> Self: + """ + Get absolute value of the series value + """ + return self.apply(func=lambda v: abs(v)) + + def __getitem__(self, item: Union[str, int]) -> Self: + return self._operation(item, operator.getitem) + + def __mod__(self, other: object) -> Self: + return self._operation(other, operator.mod) + + def __add__(self, other: object) -> Self: + return self._operation(other, operator.add) + + def __sub__(self, other: object) -> Self: + return self._operation(other, operator.sub) + + def __mul__(self, other: object) -> Self: + return self._operation(other, operator.mul) + + def __truediv__(self, other: object) -> Self: + return self._operation(other, operator.truediv) + + def __eq__(self, other: object) -> Self: + return self._operation(other, operator.eq) + + def __ne__(self, other: object) -> Self: + return self._operation(other, operator.ne) + + def __lt__(self, other: object) -> Self: + return self._operation(other, operator.lt) + + def __le__(self, other: object) -> Self: + return self._operation(other, operator.le) + 
+ def __gt__(self, other: object) -> Self: + return self._operation(other, operator.gt) + + def __ge__(self, other: object) -> Self: + return self._operation(other, operator.ge) + + def __and__(self, other: object) -> Self: + """ + Do a logical "and" comparison. + + .. note:: It behaves differently than `pandas`. `pandas` performs + a bitwise "and" if one of the arguments is a number. + This function always does a logical "and" instead. + """ + return self._operation(other, lambda x, y: x and y) + + def __or__(self, other: object) -> Self: + """ + Do a logical "or" comparison. + + .. note:: It behaves differently than `pandas`. `pandas` performs + a bitwise "or" if one of the arguments is a number. + This function always does a logical "or" instead. + """ + return self._operation(other, lambda x, y: x or y) + + def __invert__(self) -> Self: + """ + Do a logical "not". + + .. note:: It behaves differently than `pandas`. `pandas` performs + a bitwise "not" if argument is a number. + This function always does a logical "not" instead. + """ + return self.apply(lambda v: not v) diff --git a/src/StreamingDataFrames/quixstreams/models/__init__.py b/src/StreamingDataFrames/quixstreams/models/__init__.py index fe4e55c81..ed2fb1237 100644 --- a/src/StreamingDataFrames/quixstreams/models/__init__.py +++ b/src/StreamingDataFrames/quixstreams/models/__init__.py @@ -3,3 +3,4 @@ from .timestamps import * from .topics import * from .types import * +from .messagecontext import * diff --git a/src/StreamingDataFrames/quixstreams/models/context.py b/src/StreamingDataFrames/quixstreams/models/messagecontext.py similarity index 100% rename from src/StreamingDataFrames/quixstreams/models/context.py rename to src/StreamingDataFrames/quixstreams/models/messagecontext.py diff --git a/src/StreamingDataFrames/quixstreams/models/rows.py b/src/StreamingDataFrames/quixstreams/models/rows.py index f0f319a00..3d9826233 100644 --- a/src/StreamingDataFrames/quixstreams/models/rows.py +++ b/src/StreamingDataFrames/quixstreams/models/rows.py @@ -5,7 +5,7 @@ from .messages import MessageHeadersTuples from .timestamps import MessageTimestamp -from .context import MessageContext +from .messagecontext import MessageContext class Row: diff --git a/src/StreamingDataFrames/quixstreams/models/topics.py b/src/StreamingDataFrames/quixstreams/models/topics.py index b30d69aec..9d99ede77 100644 --- a/src/StreamingDataFrames/quixstreams/models/topics.py +++ b/src/StreamingDataFrames/quixstreams/models/topics.py @@ -1,7 +1,7 @@ import logging from typing import Union, List, Mapping, Optional, Any -from .context import MessageContext +from .messagecontext import MessageContext from .messages import KafkaMessage from .rows import Row from .serializers import ( diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_app.py b/src/StreamingDataFrames/tests/test_quixstreams/test_app.py index 901858bd0..5c6ed8978 100644 --- a/src/StreamingDataFrames/tests/test_quixstreams/test_app.py +++ b/src/StreamingDataFrames/tests/test_quixstreams/test_app.py @@ -7,7 +7,6 @@ import pytest from confluent_kafka import KafkaException, TopicPartition -from tests.utils import TopicPartitionStub from quixstreams.app import Application from quixstreams.models import ( @@ -16,7 +15,6 @@ JSONDeserializer, SerializationError, JSONSerializer, - MessageContext, ) from quixstreams.platforms.quix import ( QuixKafkaConfigsBuilder, @@ -28,6 +26,7 @@ RowConsumer, ) from quixstreams.state import State +from tests.utils import TopicPartitionStub def 
_stop_app_on_future(app: Application, future: Future, timeout: float): @@ -92,8 +91,8 @@ def on_message_processed(topic_, partition, offset): value_deserializer=JSONDeserializer(), ) - df = app.dataframe(topic_in) - df.to_topic(topic_out) + sdf = app.dataframe(topic_in) + sdf = sdf.to_topic(topic_out) processed_count = 0 total_messages = 3 @@ -107,7 +106,7 @@ def on_message_processed(topic_, partition, offset): # Stop app when the future is resolved executor.submit(_stop_app_on_future, app, done, 10.0) - app.run(df) + app.run(sdf) # Check that all messages have been processed assert processed_count == total_messages @@ -130,21 +129,19 @@ def on_message_processed(topic_, partition, offset): assert row.key == data["key"] assert row.value == {column_name: loads(data["value"].decode())} - def test_run_consumer_error_raised( - self, app_factory, producer, topic_factory, consumer, executor - ): + def test_run_consumer_error_raised(self, app_factory, topic_factory, executor): # Set "auto_offset_reset" to "error" to simulate errors in Consumer app = app_factory(auto_offset_reset="error") topic_name, _ = topic_factory() topic = app.topic( topic_name, value_deserializer=JSONDeserializer(column_name="root") ) - df = app.dataframe(topic) + sdf = app.dataframe(topic) # Stop app after 10s if nothing failed executor.submit(_stop_app_on_timeout, app, 10.0) with pytest.raises(KafkaMessageError): - app.run(df) + app.run(sdf) def test_run_deserialization_error_raised( self, app_factory, producer, topic_factory, consumer, executor @@ -157,12 +154,12 @@ def test_run_deserialization_error_raised( with producer: producer.produce(topic=topic_name, value=b"abc") - df = app.dataframe(topic) + sdf = app.dataframe(topic) with pytest.raises(SerializationError): # Stop app after 10s if nothing failed executor.submit(_stop_app_on_timeout, app, 10.0) - app.run(df) + app.run(sdf) def test_run_consumer_error_suppressed( self, app_factory, producer, topic_factory, consumer, executor @@ -181,14 +178,14 @@ def on_consumer_error(exc, *args): app = app_factory(on_consumer_error=on_consumer_error) topic_name, _ = topic_factory() topic = app.topic(topic_name) - df = app.dataframe(topic) + sdf = app.dataframe(topic) with patch.object(RowConsumer, "poll") as mocked: # Patch RowConsumer.poll to simulate failures mocked.side_effect = ValueError("test") # Stop app when the future is resolved executor.submit(_stop_app_on_future, app, done, 10.0) - app.run(df) + app.run(sdf) assert polled > 1 def test_run_processing_error_raised( @@ -198,19 +195,19 @@ def test_run_processing_error_raised( topic_name, _ = topic_factory() topic = app.topic(topic_name, value_deserializer=JSONDeserializer()) - df = app.dataframe(topic) + sdf = app.dataframe(topic) def fail(*args): raise ValueError("test") - df = df.apply(fail) + sdf = sdf.apply(fail) with producer: producer.produce(topic=topic.name, value=b'{"field":"value"}') with pytest.raises(ValueError): executor.submit(_stop_app_on_timeout, app, 10.0) - app.run(df) + app.run(sdf) def test_run_processing_error_suppressed( self, app_factory, topic_factory, producer, executor @@ -232,12 +229,12 @@ def on_processing_error(exc, *args): ) topic_name, _ = topic_factory() topic = app.topic(topic_name, value_deserializer=JSONDeserializer()) - df = app.dataframe(topic) + sdf = app.dataframe(topic) def fail(*args): raise ValueError("test") - df = df.apply(fail) + sdf = sdf.apply(fail) with producer: for i in range(produced): @@ -245,7 +242,7 @@ def fail(*args): # Stop app from the background thread when the 
future is resolved executor.submit(_stop_app_on_future, app, done, 10.0) - app.run(df) + app.run(sdf) assert produced == consumed def test_run_producer_error_raised( @@ -261,15 +258,15 @@ def test_run_producer_error_raised( topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer()) topic_out = app.topic(topic_out_name, value_serializer=JSONSerializer()) - df = app.dataframe(topic_in) - df.to_topic(topic_out) + sdf = app.dataframe(topic_in) + sdf = sdf.to_topic(topic_out) with producer: producer.produce(topic_in.name, dumps({"field": 1001 * "a"})) with pytest.raises(KafkaException): executor.submit(_stop_app_on_timeout, app, 10.0) - app.run(df) + app.run(sdf) def test_run_serialization_error_raised( self, app_factory, producer, topic_factory, executor @@ -281,15 +278,15 @@ def test_run_serialization_error_raised( topic_out_name, _ = topic_factory() topic_out = app.topic(topic_out_name, value_serializer=DoubleSerializer()) - df = app.dataframe(topic_in) - df.to_topic(topic_out) + sdf = app.dataframe(topic_in) + sdf = sdf.to_topic(topic_out) with producer: producer.produce(topic_in.name, b'{"field":"value"}') with pytest.raises(SerializationError): executor.submit(_stop_app_on_timeout, app, 10.0) - app.run(df) + app.run(sdf) def test_run_producer_error_suppressed( self, app_factory, producer, topic_factory, consumer, executor @@ -314,15 +311,15 @@ def on_producer_error(exc, *args): topic_out_name, _ = topic_factory() topic_out = app.topic(topic_out_name, value_serializer=DoubleSerializer()) - df = app.dataframe(topic_in) - df.to_topic(topic_out) + sdf = app.dataframe(topic_in) + sdf = sdf.to_topic(topic_out) with producer: for _ in range(produce_input): producer.produce(topic_in.name, b'{"field":"value"}') executor.submit(_stop_app_on_future, app, done, 10.0) - app.run(df) + app.run(sdf) assert produce_output_attempts == produce_input @@ -336,7 +333,6 @@ def test_streamingdataframe_init(self): app = Application(broker_address="localhost", consumer_group="test") topic = app.topic(name="test-topic") sdf = app.dataframe(topic) - assert sdf @@ -401,9 +397,7 @@ def test_topic_config(self, quix_app_factory): assert builder.create_topic_configs[expected_name].name == expected_name assert builder.create_topic_configs[expected_name].num_partitions == 5 - def test_topic_auto_create_false_topic_confirmation( - self, dataframe_factory, quix_app_factory - ): + def test_topic_auto_create_false_topic_confirmation(self, quix_app_factory): """ Topics are confirmed when auto_create_topics=False """ @@ -443,7 +437,7 @@ def test_quix_app_stateful_quix_deployment_no_state_management_warning( app = quix_app_factory(workspace_id="") topic = app.topic(topic_name) sdf = app.dataframe(topic) - sdf.apply(lambda x, state: x, stateful=True) + sdf = sdf.apply(lambda x, state: x, stateful=True) monkeypatch.setenv( QuixEnvironment.DEPLOYMENT_ID, @@ -507,15 +501,15 @@ def test_run_stateful_success( topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer()) # Define a function that counts incoming Rows using state - def count(_, ctx: MessageContext, state: State): + def count(_, state: State): total = state.get("total", 0) total += 1 state.set("total", total) if total == total_messages: total_consumed.set_result(total) - df = app.dataframe(topic_in) - df.apply(count, stateful=True) + sdf = app.dataframe(topic_in) + sdf = sdf.update(count, stateful=True) total_messages = 3 # Produce messages to the topic and flush @@ -529,7 +523,7 @@ def count(_, ctx: MessageContext, state: State): # Stop app 
when the future is resolved executor.submit(_stop_app_on_future, app, total_consumed, 10.0) - app.run(df) + app.run(sdf) # Check that the values are actually in the DB state_manager = state_manager_factory( @@ -567,7 +561,7 @@ def test_run_stateful_processing_fails( topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer()) # Define a function that counts incoming Rows using state - def count(_, ctx, state: State): + def count(_, state: State): total = state.get("total", 0) total += 1 state.set("total", total) @@ -578,9 +572,7 @@ def fail(*_): failed.set_result(True) raise ValueError("test") - df = app.dataframe(topic_in) - df.apply(count, stateful=True) - df.apply(fail) + sdf = app.dataframe(topic_in).update(count, stateful=True).update(fail) total_messages = 3 # Produce messages to the topic and flush @@ -592,7 +584,7 @@ def fail(*_): # Stop app when the future is resolved executor.submit(_stop_app_on_future, app, failed, 10.0) with pytest.raises(ValueError): - app.run(df) + app.run(sdf) # Ensure that nothing was committed to the DB state_manager = state_manager_factory( @@ -629,7 +621,7 @@ def test_run_stateful_suppress_processing_errors( topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer()) # Define a function that counts incoming Rows using state - def count(_, ctx, state: State): + def count(_, state: State): total = state.get("total", 0) total += 1 state.set("total", total) @@ -639,9 +631,7 @@ def count(_, ctx, state: State): def fail(_): raise ValueError("test") - df = app.dataframe(topic_in) - df.apply(count, stateful=True) - df.apply(fail) + sdf = app.dataframe(topic_in).update(count, stateful=True).apply(fail) total_messages = 3 message_key = b"key" @@ -656,7 +646,7 @@ def fail(_): # Stop app when the future is resolved executor.submit(_stop_app_on_future, app, total_consumed, 10.0) # Run the application - app.run(df) + app.run(sdf) # Ensure that data is committed to the DB state_manager = state_manager_factory( @@ -710,11 +700,9 @@ def test_on_assign_topic_offset_behind_warning( # Define some stateful function so the App assigns store partitions done = Future() - def count(_, ctx, state: State): - done.set_result(True) - - df = app.dataframe(topic_in) - df.apply(count, stateful=True) + sdf = app.dataframe(topic_in).update( + lambda *_: done.set_result(True), stateful=True + ) # Produce a message to the topic and flush data = {"key": b"key", "value": dumps({"key": "value"})} @@ -725,7 +713,7 @@ def count(_, ctx, state: State): executor.submit(_stop_app_on_future, app, done, 10.0) # Run the application with patch.object(logging.getLoggerClass(), "warning") as mock: - app.run(df) + app.run(sdf) assert mock.called assert "is behind the stored offset" in mock.call_args[0][0] diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_context.py b/src/StreamingDataFrames/tests/test_quixstreams/test_context.py new file mode 100644 index 000000000..f6d85479e --- /dev/null +++ b/src/StreamingDataFrames/tests/test_quixstreams/test_context.py @@ -0,0 +1,52 @@ +import contextvars + +import pytest + +from quixstreams.models import MessageTimestamp, MessageContext +from quixstreams.context import ( + message_context, + set_message_context, + message_key, + MessageContextNotSetError, +) + + +@pytest.fixture() +def message_context_factory(): + def factory(key: object = "test") -> MessageContext: + return MessageContext( + key=key, + topic="test", + partition=0, + offset=0, + size=0, + timestamp=MessageTimestamp.create(0, 0), + ) + + return factory + + 
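+# The tests below run each helper inside a copied `contextvars.Context`, +# so every test sees an isolated message context.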
+class TestContext: + def test_get_current_context_not_set_fails(self): + ctx = contextvars.copy_context() + with pytest.raises(MessageContextNotSetError): + ctx.run(message_context) + + def test_set_current_context_and_run(self, message_context_factory): + ctx = contextvars.copy_context() + message_ctx1 = message_context_factory(key="test") + message_ctx2 = message_context_factory(key="test2") + for message_ctx in [message_ctx1, message_ctx2]: + ctx.run(set_message_context, message_ctx) + assert ctx.run(lambda: message_context()) == message_ctx + + def test_get_current_key_success(self, message_context_factory): + ctx = contextvars.copy_context() + message_ctx = message_context_factory(key="test") + ctx.run(set_message_context, message_ctx) + assert ctx.run(message_key) == message_ctx.key + + def test_get_current_key_not_set_fails(self): + ctx = contextvars.copy_context() + with pytest.raises(MessageContextNotSetError): + ctx.run(message_key) diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_core/__init__.py b/src/StreamingDataFrames/tests/test_quixstreams/test_core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_functions.py b/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_functions.py new file mode 100644 index 000000000..7b50c4845 --- /dev/null +++ b/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_functions.py @@ -0,0 +1,42 @@ +import pytest + +from quixstreams.core.stream.functions import ( + Apply, + Filter, + Update, + Filtered, + is_apply_function, + is_filter_function, + is_update_function, +) + + +class TestFunctions: + def test_apply_function(self): + func = Apply(lambda v: v) + assert func(1) == 1 + assert is_apply_function(func) + + def test_update_function(self): + value = [0] + expected = [0, 1] + func = Update(lambda v: v.append(1)) + assert func(value) == expected + assert is_update_function(func) + + @pytest.mark.parametrize( + "value, filtered", + [ + (1, True), + (0, False), + ], + ) + def test_filter_function(self, value, filtered): + func = Filter(lambda v: v == 0) + assert is_filter_function(func) + + if filtered: + with pytest.raises(Filtered): + func(value) + else: + assert func(value) == value diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_stream.py b/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_stream.py new file mode 100644 index 000000000..776364e96 --- /dev/null +++ b/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_stream.py @@ -0,0 +1,107 @@ +import pytest + +from quixstreams.core.stream import ( + Stream, + Filtered, + is_filter_function, + is_update_function, + is_apply_function, +) + + +class TestStream: + def test_add_apply(self): + stream = Stream().add_apply(lambda v: v + 1) + assert stream.compile()(1) == 2 + + def test_add_update(self): + stream = Stream().add_update(lambda v: v.append(1)) + assert stream.compile()([0]) == [0, 1] + + @pytest.mark.parametrize( + "value, filtered", + [ + (1, True), + (0, False), + ], + ) + def test_add_filter(self, value, filtered): + stream = Stream().add_filter(lambda v: v == 0) + + if filtered: + with pytest.raises(Filtered): + stream.compile()(value) + else: + assert stream.compile()(value) == value + + def test_tree(self): + stream = ( + Stream() + .add_apply(lambda v: v) + .add_filter(lambda v: v) + .add_update(lambda v: v) + ) + tree = stream.tree() + assert len(tree) == 4 + assert is_apply_function(tree[0].func) + assert 
is_apply_function(tree[1].func) + assert is_filter_function(tree[2].func) + assert is_update_function(tree[3].func) + + def test_diff_success(self): + stream = Stream() + stream = stream.add_apply(lambda v: v) + stream2 = ( + stream.add_apply(lambda v: v) + .add_update(lambda v: v) + .add_filter(lambda v: v) + ) + + stream = stream.add_apply(lambda v: v) + + diff = stream.diff(stream2) + + diff_tree = diff.tree() + assert len(diff_tree) == 3 + assert is_apply_function(diff_tree[0].func) + assert is_update_function(diff_tree[1].func) + assert is_filter_function(diff_tree[2].func) + + def test_diff_empty_same_stream_fails(self): + stream = Stream() + with pytest.raises(ValueError, match="The diff is empty"): + stream.diff(stream) + + def test_diff_empty_stream_full_child_fails(self): + stream = Stream() + stream2 = stream.add_apply(lambda v: v) + with pytest.raises(ValueError, match="The diff is empty"): + stream2.diff(stream) + + def test_diff_no_common_parent_fails(self): + stream = Stream() + stream2 = Stream() + with pytest.raises(ValueError, match="Common parent not found"): + stream.diff(stream2) + + def test_compile_allow_filters_false(self): + stream = Stream().add_filter(lambda v: v) + with pytest.raises(ValueError, match="Filter functions are not allowed"): + stream.compile(allow_filters=False) + + def test_compile_allow_updates_false(self): + stream = Stream().add_update(lambda v: v) + with pytest.raises(ValueError, match="Update functions are not allowed"): + stream.compile(allow_updates=False) + + def test_repr(self): + stream = ( + Stream() + .add_apply(lambda v: v) + .add_update(lambda v: v) + .add_filter(lambda v: v) + ) + repr(stream) + + def test_init_with_unwrapped_function(self): + ... diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/fixtures.py b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/fixtures.py index 3f9729eb7..4831752bb 100644 --- a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/fixtures.py +++ b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/fixtures.py @@ -1,55 +1,22 @@ -from functools import partial from typing import Optional from unittest.mock import MagicMock import pytest from quixstreams.dataframe.dataframe import StreamingDataFrame -from quixstreams.dataframe.pipeline import Pipeline, PipelineFunction from quixstreams.models.topics import Topic from quixstreams.state import StateStoreManager -@pytest.fixture() -def pipeline_function(): - def test_func(data): - return {k: v + 1 for k, v in data.items()} - - return PipelineFunction(func=test_func) - - -@pytest.fixture() -def pipeline(pipeline_function): - return Pipeline(functions=[pipeline_function]) - - @pytest.fixture() def dataframe_factory(): - def _dataframe_factory( + def factory( topic: Optional[Topic] = None, state_manager: Optional[StateStoreManager] = None, - ): + ) -> StreamingDataFrame: return StreamingDataFrame( - topic=topic or Topic(name="test_in"), + topic=topic or Topic(name="test"), state_manager=state_manager or MagicMock(spec=StateStoreManager), ) - return _dataframe_factory - - -def row_values_plus_n(n, row, ctx): - for k, v in row.items(): - row[k] = v + n - return row - - -@pytest.fixture() -def row_plus_n_func(): - """ - This generally will be used alongside "row_plus_n" - """ - - def _row_values_plus_n(n=None): - return partial(row_values_plus_n, n) - - return _row_values_plus_n + return factory diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_column.py 
b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_column.py deleted file mode 100644 index 1eac1ca47..000000000 --- a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_column.py +++ /dev/null @@ -1,328 +0,0 @@ -import pytest - -from quixstreams.dataframe.column import Column - - -class TestColumn: - def test_apply(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}, key=123) - result = Column("x").apply(lambda v, context: v + context.key) - assert isinstance(result, Column) - assert result.eval(msg_value) == 128 - - def test_addition(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") + Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) == 25 - - def test_multi_op(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") + Column("y") + Column("z") - assert isinstance(result, Column) - assert result.eval(msg_value) == 135 - - def test_scaler_op(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") + 2 - assert isinstance(result, Column) - assert result.eval(msg_value) == 7 - - def test_subtraction(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("y") - Column("x") - assert isinstance(result, Column) - assert result.eval(msg_value) == 15 - - def test_multiplication(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") * Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) == 100 - - def test_div(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("z") / Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) == 5.5 - - def test_mod(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("y") % Column("x") - assert isinstance(result, Column) - assert result.eval(msg_value) == 0 - - def test_equality_true(self, row_factory): - msg_value = row_factory({"x": 5, "x2": 5}) - result = Column("x") == Column("x2") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_equality_false(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") == Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_inequality_true(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") != Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_inequality_false(self, row_factory): - msg_value = row_factory({"x": 5, "x2": 5}) - result = Column("x") != Column("x2") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_less_than_true(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") < Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_less_than_false(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("y") < Column("x") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_less_than_or_equal_equal_true(self, row_factory): - msg_value = row_factory({"x": 5, "x2": 5}) - result = Column("x") <= Column("x2") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - 
def test_less_than_or_equal_less_true(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") <= Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_less_than_or_equal_false(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("y") <= Column("x") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_greater_than_true(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("y") > Column("x") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_greater_than_false(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") > Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_greater_than_or_equal_equal_true(self, row_factory): - msg_value = row_factory({"x": 5, "x2": 5}) - result = Column("x") >= Column("x2") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_greater_than_or_equal_greater_true(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("y") >= Column("x") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_greater_than_or_equal_greater_false(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = Column("x") >= Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_and_true(self, row_factory): - msg_value = row_factory({"x": True, "y": True, "z": False}) - result = Column("x") & Column("y") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_and_true_multiple(self, row_factory): - msg_value = row_factory({"x": True, "y": True, "z": False}) - result = Column("x") & Column("y") & Column("x") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_and_false_multiple(self, row_factory): - msg_value = row_factory({"x": True, "y": True, "z": False}) - result = Column("x") & Column("y") & Column("z") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_and_false(self, row_factory): - msg_value = row_factory({"x": True, "y": True, "z": False}) - result = Column("x") & Column("z") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_and_true_bool(self, row_factory): - msg_value = row_factory({"x": True, "y": True, "z": False}) - result = Column("x") & True - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_and_false_bool(self, row_factory): - msg_value = row_factory({"x": True, "y": True, "z": False}) - result = Column("x") & False - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_or_true(self, row_factory): - msg_value = row_factory({"x": True, "y": True, "z": False}) - result = Column("x") | Column("z") - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_or_false(self, row_factory): - msg_value = row_factory({"x": True, "y": True, "z": False}) - result = Column("z") | Column("z") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_and_inequalities_true(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = (Column("x") <= Column("y")) & (Column("x") <= 
Column("z")) - assert isinstance(result, Column) - assert result.eval(msg_value) is True - - def test_and_inequalities_false(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = (Column("x") <= Column("y")) & (Column("x") > Column("z")) - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_invert_int(self, row_factory): - msg_value = row_factory({"x": 1}) - result = ~Column("x") - assert isinstance(result, Column) - assert result.eval(msg_value) == -2 - - def test_invert_bool(self, row_factory): - msg_value = row_factory({"x": True, "y": True, "z": False}) - result = ~Column("x") - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - def test_invert_bool_from_inequalities(self, row_factory): - msg_value = row_factory({"x": 5, "y": 20, "z": 110}) - result = ~(Column("x") <= Column("y")) - assert isinstance(result, Column) - assert result.eval(msg_value) is False - - @pytest.mark.parametrize( - "value, other, expected", - [ - ({"x": 1}, [1, 2, 3], True), - ({"x": 1}, [], False), - ({"x": 1}, {1: 456}, True), - ], - ) - def test_isin(self, row_factory, value, other, expected): - row = row_factory(value) - assert Column("x").isin(other).eval(row) == expected - - @pytest.mark.parametrize( - "value, other, expected", - [ - ({"x": [1, 2, 3]}, 1, True), - ({"x": [1, 2, 3]}, 5, False), - ({"x": "abc"}, "a", True), - ({"x": {"y": "z"}}, "y", True), - ], - ) - def test_contains(self, row_factory, value, other, expected): - row = row_factory(value) - assert Column("x").contains(other).eval(row) == expected - - @pytest.mark.parametrize( - "value, expected", - [ - ({"x": None}, True), - ({"x": [1, 2, 3]}, False), - ], - ) - def test_isnull(self, row_factory, value, expected): - row = row_factory(value) - assert Column("x").isnull().eval(row) == expected - - @pytest.mark.parametrize( - "value, expected", - [ - ({"x": None}, False), - ({"x": [1, 2, 3]}, True), - ], - ) - def test_notnull(self, row_factory, value, expected): - row = row_factory(value) - assert Column("x").notnull().eval(row) == expected - - @pytest.mark.parametrize( - "value, other, expected", - [ - ({"x": [1, 2, 3]}, None, False), - ({"x": None}, None, True), - ({"x": 1}, 1, True), - ], - ) - def test_is_(self, row_factory, value, other, expected): - row = row_factory(value) - assert Column("x").is_(other).eval(row) == expected - - @pytest.mark.parametrize( - "value, other, expected", - [ - ({"x": [1, 2, 3]}, None, True), - ({"x": None}, None, False), - ({"x": 1}, 1, False), - ], - ) - def test_isnot(self, row_factory, value, other, expected): - row = row_factory(value) - assert Column("x").isnot(other).eval(row) == expected - - @pytest.mark.parametrize( - "nested_item", - [ - {2: 110}, - {2: "a_string"}, - {2: {"another": "dict"}}, - {2: ["a", "list"]}, - ["item", "in", "this", "list"], - ], - ) - def test__get_item__(self, row_factory, nested_item): - msg_value = row_factory({"x": {"y": nested_item}, "k": 0}) - result = Column("x")["y"][2] - assert isinstance(result, Column) - assert result.eval(msg_value) == nested_item[2] - - def test_get_item_with_apply(self, row_factory): - msg_value = row_factory({"x": {"y": {"z": 110}}, "k": 0}) - result = Column("x")["y"]["z"].apply(lambda v, ctx: v + 10) - assert isinstance(result, Column) - assert result.eval(msg_value) == 120 - - def test_get_item_with_op(self, row_factory): - msg_value = row_factory({"x": {"y": 10}, "j": {"k": 5}}) - result = Column("x")["y"] + Column("j")["k"] - assert 
isinstance(result, Column) - assert result.eval(msg_value) == 15 - - @pytest.mark.parametrize("value, expected", [(10, 10), (-10, 10), (10.0, 10.0)]) - def test_abs_success(self, value, expected, row_factory): - row = row_factory({"x": value}) - result = Column("x").abs() - assert isinstance(result, Column) - assert result.eval(row) == expected - - def test_abs_not_a_number_fails(self, row_factory): - row = row_factory({"x": "string"}) - result = Column("x").abs() - assert isinstance(result, Column) - with pytest.raises(TypeError, match="bad operand type for abs()"): - assert result.eval(row) diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_dataframe.py b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_dataframe.py index 9d2c33155..e3285a162 100644 --- a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_dataframe.py +++ b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_dataframe.py @@ -1,224 +1,262 @@ +import operator + import pytest -from tests.utils import TopicPartitionStub -from quixstreams.dataframe.exceptions import InvalidApplyResultType -from quixstreams.dataframe.pipeline import Pipeline -from quixstreams.models import MessageContext +from quixstreams import MessageContext, State +from quixstreams.core.stream import Filtered +from quixstreams.models import MessageTimestamp from quixstreams.models.topics import Topic -from quixstreams.state import State - - -class TestDataframe: - def test_dataframe(self, dataframe_factory): - dataframe = dataframe_factory() - assert isinstance(dataframe._pipeline, Pipeline) - assert dataframe._pipeline.id == dataframe.id - - -class TestDataframeProcess: - def test_apply(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - row = row_factory({"x": 1, "y": 2}, key="key") - - def _apply(value: dict, ctx: MessageContext): - assert ctx.key == "key" - assert value == {"x": 1, "y": 2} - return { - "x": 3, - "y": 4, - } - - dataframe.apply(_apply) - assert dataframe.process(row).value == {"x": 3, "y": 4} - - def test_apply_no_return_value(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe = dataframe.apply(lambda row, ctx: row.update({"y": 2})) - row = row_factory({"x": 1}) - assert dataframe.process(row).value == row_factory({"x": 1, "y": 2}).value - - def test_apply_invalid_return_type(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe = dataframe.apply(lambda row, ctx: False) - row = row_factory({"x": 1, "y": 2}) - with pytest.raises(InvalidApplyResultType): - dataframe.process(row) - - def test_apply_fluent(self, dataframe_factory, row_factory, row_plus_n_func): - dataframe = dataframe_factory() - dataframe = dataframe.apply(row_plus_n_func(n=1)).apply(row_plus_n_func(n=2)) - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row).value == row_factory({"x": 4, "y": 5}).value - - def test_apply_sequential(self, dataframe_factory, row_factory, row_plus_n_func): - dataframe = dataframe_factory() - dataframe = dataframe.apply(row_plus_n_func(n=1)) - dataframe = dataframe.apply(row_plus_n_func(n=2)) - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row).value == row_factory({"x": 4, "y": 5}).value - - def test_setitem_primitive(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe["new"] = 1 - row = row_factory({"x": 1}) - assert dataframe.process(row).value == row_factory({"x": 1, "new": 1}).value - - def 
test_setitem_column_only(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe["new"] = dataframe["x"] - row = row_factory({"x": 1}) - assert dataframe.process(row).value == row_factory({"x": 1, "new": 1}).value - - def test_setitem_column_with_function(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe["new"] = dataframe["x"].apply(lambda v, ctx: v + 5) - row = row_factory({"x": 1}) - assert dataframe.process(row).value == row_factory({"x": 1, "new": 6}).value - - def test_setitem_column_with_operations(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe["new"] = ( - dataframe["x"] + dataframe["y"].apply(lambda v, ctx: v + 5) + 1 - ) - row = row_factory({"x": 1, "y": 2}) - expected = row_factory({"x": 1, "y": 2, "new": 9}) - assert dataframe.process(row).value == expected.value +from tests.utils import TopicPartitionStub - def test_setitem_from_a_nested_column( - self, dataframe_factory, row_factory, row_plus_n_func - ): - dataframe = dataframe_factory() - dataframe["a"] = dataframe["x"]["y"] - row = row_factory({"x": {"y": 1, "z": "banana"}}) - expected = row_factory({"x": {"y": 1, "z": "banana"}, "a": 1}) - assert dataframe.process(row).value == expected.value - - def test_column_subset(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe = dataframe[["x", "y"]] - row = row_factory({"x": 1, "y": 2, "z": 3}) - expected = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row).value == expected.value - - def test_column_subset_with_funcs( - self, dataframe_factory, row_factory, row_plus_n_func - ): - dataframe = dataframe_factory() - dataframe = dataframe[["x", "y"]].apply(row_plus_n_func(n=5)) - row = row_factory({"x": 1, "y": 2, "z": 3}) - expected = row_factory({"x": 6, "y": 7}) - assert dataframe.process(row).value == expected.value - - def test_inequality_filter(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe = dataframe[dataframe["x"] > 0] - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row).value == row.value - - def test_inequality_filter_is_filtered(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe = dataframe[dataframe["x"] >= 1000] - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row) is None - - def test_inequality_filter_with_operation(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe = dataframe[(dataframe["x"] - 0 + dataframe["y"]) > 0] - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row).value == row.value - - def test_inequality_filter_with_operation_is_filtered( - self, dataframe_factory, row_factory - ): - dataframe = dataframe_factory() - dataframe = dataframe[(dataframe["x"] - dataframe["y"]) > 0] - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row) is None - - def test_inequality_filtering_with_apply(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe = dataframe[dataframe["x"].apply(lambda v, ctx: v - 1) >= 0] - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row).value == row.value - - def test_inequality_filtering_with_apply_is_filtered( - self, dataframe_factory, row_factory - ): - dataframe = dataframe_factory() - dataframe = dataframe[dataframe["x"].apply(lambda v, ctx: v - 10) >= 0] - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row) is None - - def test_compound_inequality_filter(self, 
dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe = dataframe[(dataframe["x"] >= 0) & (dataframe["y"] < 10)] - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row).value == row.value - - def test_compound_inequality_filter_is_filtered( - self, dataframe_factory, row_factory - ): - dataframe = dataframe_factory() - dataframe = dataframe[(dataframe["x"] >= 0) & (dataframe["y"] < 0)] - row = row_factory({"x": 1, "y": 2}) - assert dataframe.process(row) is None - - def test_contains_on_existing_column(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe["has_column"] = dataframe.contains("x") - row = row_factory({"x": 1}) - assert ( - dataframe.process(row).value - == row_factory({"x": 1, "has_column": True}).value - ) - def test_contains_on_missing_column(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe["has_column"] = dataframe.contains("wrong_column") - row = row_factory({"x": 1}) - assert ( - dataframe.process(row).value - == row_factory({"x": 1, "has_column": False}).value +class TestStreamingDataFrame: + @pytest.mark.parametrize( + "value, expected", + [(1, 2), ("input", "return"), ([0, 1, 2], "return"), ({"key": "value"}, None)], + ) + def test_apply(self, dataframe_factory, value, expected): + sdf = dataframe_factory() + + def _apply(value_: dict): + assert value_ == value + return expected + + sdf = sdf.apply(_apply) + assert sdf.test(value) == expected + + @pytest.mark.parametrize( + "value, mutation, expected", + [ + ([0, 1, 2], lambda v: v.append(3), [0, 1, 2, 3]), + ({"a": "b"}, lambda v: operator.setitem(v, "x", "y"), {"a": "b", "x": "y"}), + ], + ) + def test_update(self, dataframe_factory, value, mutation, expected): + sdf = dataframe_factory() + sdf = sdf.update(mutation) + assert sdf.test(value) == expected + + def test_apply_multiple(self, dataframe_factory): + sdf = dataframe_factory() + value = 1 + expected = 4 + sdf = sdf.apply(lambda v: v + 1).apply(lambda v: v + 2) + assert sdf.test(value) == expected + + def test_apply_update_multiple(self, dataframe_factory): + sdf = dataframe_factory() + value = {"x": 1} + expected = {"x": 3, "y": 3} + sdf = ( + sdf.apply(lambda v: {"x": v["x"] + 1}) + .update(lambda v: operator.setitem(v, "y", 3)) + .apply(lambda v: {**v, "x": v["x"] + 1}) ) - - def test_contains_as_filter(self, dataframe_factory, row_factory): - dataframe = dataframe_factory() - dataframe = dataframe[dataframe.contains("x")] - - valid_row = row_factory({"x": 1, "y": 2}) - valid_result = dataframe.process(valid_row) - assert valid_result is not None and valid_result.value == valid_row.value - - invalid_row = row_factory({"y": 2}) - assert dataframe.process(invalid_row) is None - - -class TestDataframeKafka: + assert sdf.test(value) == expected + + def test_setitem_primitive(self, dataframe_factory): + value = {"x": 1} + expected = {"x": 2} + sdf = dataframe_factory() + sdf["x"] = 2 + assert sdf.test(value) == expected + + def test_setitem_series(self, dataframe_factory): + value = {"x": 1, "y": 2} + expected = {"x": 2, "y": 2} + sdf = dataframe_factory() + sdf["x"] = sdf["y"] + assert sdf.test(value) == expected + + def test_setitem_series_apply(self, dataframe_factory): + value = {"x": 1} + expected = {"x": 1, "y": 2} + sdf = dataframe_factory() + sdf["y"] = sdf["x"].apply(lambda v: v + 1) + assert sdf.test(value) == expected + + def test_setitem_series_with_operations(self, dataframe_factory): + value = {"x": 1, "y": 2} + expected = {"x": 1, "y": 2, "z": 
5} + sdf = dataframe_factory() + sdf["z"] = (sdf["x"] + sdf["y"]).apply(lambda v: v + 1) + 1 + assert sdf.test(value) == expected + + def test_setitem_another_dataframe_apply(self, dataframe_factory): + value = {"x": 1} + expected = {"x": 1, "y": 2} + sdf = dataframe_factory() + sdf["y"] = sdf.apply(lambda v: v["x"] + 1) + assert sdf.test(value) == expected + + def test_column_subset(self, dataframe_factory): + value = {"x": 1, "y": 2, "z": 3} + expected = {"x": 1, "y": 2} + sdf = dataframe_factory() + sdf = sdf[["x", "y"]] + assert sdf.test(value) == expected + + def test_column_subset_and_apply(self, dataframe_factory): + value = {"x": 1, "y": 2, "z": 3} + expected = 2 + sdf = dataframe_factory() + sdf = sdf[["x", "y"]] + sdf = sdf.apply(lambda v: v["y"]) + assert sdf.test(value) == expected + + @pytest.mark.parametrize( + "value, filtered", + [ + ({"x": 1, "y": 2}, False), + ({"x": 0, "y": 2}, True), + ], + ) + def test_filter_with_series(self, dataframe_factory, value, filtered): + sdf = dataframe_factory() + sdf = sdf[sdf["x"] > 0] + + if filtered: + with pytest.raises(Filtered): + assert sdf.test(value) + else: + assert sdf.test(value) == value + + @pytest.mark.parametrize( + "value, filtered", + [ + ({"x": 1, "y": 2}, False), + ({"x": 0, "y": 2}, True), + ], + ) + def test_filter_with_series_apply(self, dataframe_factory, value, filtered): + sdf = dataframe_factory() + sdf = sdf[sdf["x"].apply(lambda v: v > 0)] + + if filtered: + with pytest.raises(Filtered): + assert sdf.test(value) + else: + assert sdf.test(value) == value + + @pytest.mark.parametrize( + "value, filtered", + [ + ({"x": 1, "y": 2}, False), + ({"x": 0, "y": 2}, True), + ], + ) + def test_filter_with_multiple_series(self, dataframe_factory, value, filtered): + sdf = dataframe_factory() + sdf = sdf[(sdf["x"] > 0) & (sdf["y"] > 0)] + + if filtered: + with pytest.raises(Filtered): + assert sdf.test(value) + else: + assert sdf.test(value) == value + + @pytest.mark.parametrize( + "value, filtered", + [ + ({"x": 1, "y": 2}, False), + ({"x": 0, "y": 2}, True), + ], + ) + def test_filter_with_another_sdf_apply(self, dataframe_factory, value, filtered): + sdf = dataframe_factory() + sdf = sdf[sdf.apply(lambda v: v["x"] > 0)] + + if filtered: + with pytest.raises(Filtered): + assert sdf.test(value) + else: + assert sdf.test(value) == value + + def test_filter_with_another_sdf_with_filters_fails(self, dataframe_factory): + sdf = dataframe_factory() + sdf2 = sdf[sdf["x"] > 1].apply(lambda v: v["x"] > 0) + with pytest.raises(ValueError, match="Filter functions are not allowed"): + sdf = sdf[sdf2] + + def test_filter_with_another_sdf_with_update_fails(self, dataframe_factory): + sdf = dataframe_factory() + sdf2 = sdf.apply(lambda v: v).update(lambda v: operator.setitem(v, "x", 2)) + with pytest.raises(ValueError, match="Update functions are not allowed"): + sdf = sdf[sdf2] + + @pytest.mark.parametrize( + "value, filtered", + [ + ({"x": 1, "y": 2}, False), + ({"x": 0, "y": 2}, True), + ], + ) + def test_filter_with_function(self, dataframe_factory, value, filtered): + sdf = dataframe_factory() + sdf = sdf.filter(lambda v: v["x"] > 0) + + if filtered: + with pytest.raises(Filtered): + assert sdf.test(value) + else: + assert sdf.test(value) == value + + def test_contains_on_existing_column(self, dataframe_factory): + sdf = dataframe_factory() + sdf["has_column"] = sdf.contains("x") + assert sdf.test({"x": 1}) == {"x": 1, "has_column": True} + + def test_contains_on_missing_column(self, dataframe_factory): + sdf = dataframe_factory() 
+ sdf["has_column"] = sdf.contains("wrong_column") + + assert sdf.test({"x": 1}) == {"x": 1, "has_column": False} + + def test_contains_as_filter(self, dataframe_factory): + sdf = dataframe_factory() + sdf = sdf[sdf.contains("x")] + + valid_value = {"x": 1, "y": 2} + assert sdf.test(valid_value) == valid_value + + invalid_value = {"y": 2} + with pytest.raises(Filtered): + sdf.test(invalid_value) + + +class TestDatafTestStreamingDataFrameToTopic: def test_to_topic( self, dataframe_factory, row_consumer_factory, row_producer_factory, - row_factory, - topic_json_serdes_factory, + topic_factory, ): - topic = topic_json_serdes_factory() + topic_name, _ = topic_factory() + topic = Topic( + topic_name, + key_deserializer="str", + value_serializer="json", + value_deserializer="json", + ) producer = row_producer_factory() - dataframe = dataframe_factory() - dataframe.producer = producer - dataframe.to_topic(topic) - - assert dataframe.topics_out[topic.name] == topic - - row_to_produce = row_factory( - topic="ignore_me", - key=b"test_key", - value={"x": "1", "y": "2"}, + sdf = dataframe_factory() + sdf.producer = producer + sdf = sdf.to_topic(topic) + + value = {"x": 1, "y": 2} + ctx = MessageContext( + key="test", + topic="test", + partition=0, + offset=0, + size=0, + timestamp=MessageTimestamp.create(0, 0), ) with producer: - dataframe.process(row_to_produce) + sdf.test(value, ctx=ctx) with row_consumer_factory(auto_offset_reset="earliest") as consumer: consumer.subscribe([topic]) @@ -226,33 +264,43 @@ def test_to_topic( assert consumed_row assert consumed_row.topic == topic.name - assert row_to_produce.key == consumed_row.key - assert row_to_produce.value == consumed_row.value + assert consumed_row.key == ctx.key + assert consumed_row.value == value def test_to_topic_custom_key( self, dataframe_factory, row_consumer_factory, row_producer_factory, - row_factory, - topic_json_serdes_factory, + topic_factory, ): - topic = topic_json_serdes_factory() + topic_name, _ = topic_factory() + topic = Topic( + topic_name, + value_serializer="json", + value_deserializer="json", + key_serializer="int", + key_deserializer="int", + ) producer = row_producer_factory() - dataframe = dataframe_factory() - dataframe.producer = producer - # Using value of "x" column as a new key - dataframe.to_topic(topic, key=lambda value, ctx: value["x"]) + sdf = dataframe_factory() + sdf.producer = producer + + # Use value["x"] as a new key + sdf = sdf.to_topic(topic, key=lambda v: v["x"]) - row_to_produce = row_factory( - topic=topic.name, - key=b"test_key", - value={"x": "1", "y": "2"}, + value = {"x": 1, "y": 2} + ctx = MessageContext( + topic="test", + partition=0, + offset=0, + size=0, + timestamp=MessageTimestamp.create(0, 0), ) with producer: - dataframe.process(row_to_produce) + sdf.test(value, ctx=ctx) with row_consumer_factory(auto_offset_reset="earliest") as consumer: consumer.subscribe([topic]) @@ -260,34 +308,48 @@ def test_to_topic_custom_key( assert consumed_row assert consumed_row.topic == topic.name - assert consumed_row.value == row_to_produce.value - assert consumed_row.key == row_to_produce.value["x"].encode() + assert consumed_row.value == value + assert consumed_row.key == value["x"] def test_to_topic_multiple_topics_out( self, dataframe_factory, row_consumer_factory, row_producer_factory, - row_factory, - topic_json_serdes_factory, + topic_factory, ): - topic_0 = topic_json_serdes_factory() - topic_1 = topic_json_serdes_factory() + topic_0_name, _ = topic_factory() + topic_1_name, _ = topic_factory() + + 
topic_0 = Topic( + topic_0_name, + value_serializer="json", + value_deserializer="json", + ) + topic_1 = Topic( + topic_1_name, + value_serializer="json", + value_deserializer="json", + ) producer = row_producer_factory() - dataframe = dataframe_factory() - dataframe.producer = producer + sdf = dataframe_factory() + sdf.producer = producer - dataframe.to_topic(topic_0) - dataframe.to_topic(topic_1) + sdf = sdf.to_topic(topic_0).to_topic(topic_1) - row_to_produce = row_factory( - key=b"test_key", - value={"x": "1", "y": "2"}, + value = {"x": 1, "y": 2} + ctx = MessageContext( + key=b"test", + topic="test", + partition=0, + offset=0, + size=0, + timestamp=MessageTimestamp.create(0, 0), ) with producer: - dataframe.process(row_to_produce) + sdf.test(value, ctx=ctx) consumed_rows = [] with row_consumer_factory(auto_offset_reset="earliest") as consumer: @@ -300,53 +362,199 @@ def test_to_topic_multiple_topics_out( t.name for t in [topic_0, topic_1] } for consumed_row in consumed_rows: - assert row_to_produce.key == consumed_row.key - assert row_to_produce.value == consumed_row.value + assert consumed_row.key == ctx.key + assert consumed_row.value == value def test_to_topic_no_producer_assigned(self, dataframe_factory, row_factory): - topic = Topic("whatever") + topic = Topic("test") - dataframe = dataframe_factory() - dataframe.to_topic(topic) + sdf = dataframe_factory() + sdf = sdf.to_topic(topic) + + value = {"x": "1", "y": "2"} + ctx = MessageContext( + key=b"test", + topic="test", + partition=0, + offset=0, + size=0, + timestamp=MessageTimestamp.create(0, 0), + ) - with pytest.raises(RuntimeError): - dataframe.process( - row_factory( - topic=topic.name, key=b"test_key", value={"x": "1", "y": "2"} - ) - ) + with pytest.raises( + RuntimeError, match="Producer instance has not been provided" + ): + sdf.test(value, ctx=ctx) class TestDataframeStateful: - def test_apply_stateful(self, dataframe_factory, state_manager, row_factory): + def test_apply_stateful(self, dataframe_factory, state_manager): topic = Topic("test") - def stateful_func(value, ctx, state: State): + def stateful_func(value_: dict, state: State) -> int: current_max = state.get("max") if current_max is None: - current_max = value["number"] + current_max = value_["number"] else: - current_max = max(current_max, value["number"]) + current_max = max(current_max, value_["number"]) state.set("max", current_max) - value["max"] = current_max + return current_max sdf = dataframe_factory(topic, state_manager=state_manager) - sdf.apply(stateful_func, stateful=True) + sdf = sdf.apply(stateful_func, stateful=True) state_manager.on_partition_assign( tp=TopicPartitionStub(topic=topic.name, partition=0) ) - rows = [ - row_factory(topic=topic.name, value={"number": 1}), - row_factory(topic=topic.name, value={"number": 10}), - row_factory(topic=topic.name, value={"number": 3}), + values = [ + {"number": 1}, + {"number": 10}, + {"number": 3}, ] result = None - for row in rows: + ctx = MessageContext( + key=b"test", + topic="test", + partition=0, + offset=0, + size=0, + timestamp=MessageTimestamp.create(0, 0), + ) + for value in values: with state_manager.start_store_transaction( - topic=row.topic, partition=row.partition, offset=row.offset + topic=ctx.topic, partition=ctx.partition, offset=ctx.offset ): - result = sdf.process(row) + result = sdf.test(value, ctx) + + assert result == 10 + + def test_update_stateful(self, dataframe_factory, state_manager): + topic = Topic("test") + + def stateful_func(value_: dict, state: State): + current_max = 
state.get("max") + if current_max is None: + current_max = value_["number"] + else: + current_max = max(current_max, value_["number"]) + state.set("max", current_max) + value_["max"] = current_max + + sdf = dataframe_factory(topic, state_manager=state_manager) + sdf = sdf.update(stateful_func, stateful=True) - assert result - assert result.value["max"] == 10 + state_manager.on_partition_assign( + tp=TopicPartitionStub(topic=topic.name, partition=0) + ) + result = None + values = [ + {"number": 1}, + {"number": 10}, + {"number": 3}, + ] + ctx = MessageContext( + key=b"test", + topic="test", + partition=0, + offset=0, + size=0, + timestamp=MessageTimestamp.create(0, 0), + ) + for value in values: + with state_manager.start_store_transaction( + topic=ctx.topic, partition=ctx.partition, offset=ctx.offset + ): + result = sdf.test(value, ctx) + + assert result is not None + assert result["max"] == 10 + + def test_filter_stateful(self, dataframe_factory, state_manager): + topic = Topic("test") + + def stateful_func(value_: dict, state: State): + current_max = state.get("max") + if current_max is None: + current_max = value_["number"] + else: + current_max = max(current_max, value_["number"]) + state.set("max", current_max) + value_["max"] = current_max + + sdf = dataframe_factory(topic, state_manager=state_manager) + sdf = sdf.update(stateful_func, stateful=True) + sdf = sdf.filter(lambda v, state: state.get("max") >= 3, stateful=True) + + state_manager.on_partition_assign( + tp=TopicPartitionStub(topic=topic.name, partition=0) + ) + values = [ + {"number": 1}, + {"number": 1}, + {"number": 3}, + ] + ctx = MessageContext( + key=b"test", + topic="test", + partition=0, + offset=0, + size=0, + timestamp=MessageTimestamp.create(0, 0), + ) + results = [] + for value in values: + with state_manager.start_store_transaction( + topic=ctx.topic, partition=ctx.partition, offset=ctx.offset + ): + try: + results.append(sdf.test(value, ctx)) + except Filtered: + pass + assert len(results) == 1 + assert results[0]["max"] == 3 + + def test_filter_with_another_sdf_apply_stateful( + self, dataframe_factory, state_manager + ): + topic = Topic("test") + + def stateful_func(value_: dict, state: State): + current_max = state.get("max") + if current_max is None: + current_max = value_["number"] + else: + current_max = max(current_max, value_["number"]) + state.set("max", current_max) + value_["max"] = current_max + + sdf = dataframe_factory(topic, state_manager=state_manager) + sdf = sdf.update(stateful_func, stateful=True) + sdf = sdf[sdf.apply(lambda v, state: state.get("max") >= 3, stateful=True)] + + state_manager.on_partition_assign( + tp=TopicPartitionStub(topic=topic.name, partition=0) + ) + values = [ + {"number": 1}, + {"number": 1}, + {"number": 3}, + ] + ctx = MessageContext( + key=b"test", + topic="test", + partition=0, + offset=0, + size=0, + timestamp=MessageTimestamp.create(0, 0), + ) + results = [] + for value in values: + with state_manager.start_store_transaction( + topic=ctx.topic, partition=ctx.partition, offset=ctx.offset + ): + try: + results.append(sdf.test(value, ctx)) + except Filtered: + pass + assert len(results) == 1 + assert results[0]["max"] == 3 diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_pipeline.py b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_pipeline.py deleted file mode 100644 index 163b602f6..000000000 --- a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_pipeline.py +++ /dev/null @@ -1,32 +0,0 @@ -from 
quixstreams.dataframe.pipeline import PipelineFunction - - -class TestPipeline: - def test_pipeline(self, pipeline_function, pipeline): - assert pipeline._functions == [pipeline_function] - - def test_apply(self, pipeline): - def throwaway_func(data): - return data - - pipeline = pipeline.apply(throwaway_func) - assert isinstance(pipeline.functions[-1], PipelineFunction) - assert pipeline.functions[-1]._func == throwaway_func - - def test_process(self, pipeline): - assert pipeline.process({"a": 1, "b": 2}) == {"a": 2, "b": 3} - - def test_process_list(self, pipeline): - actual = pipeline.process([{"a": 1, "b": 2}, {"a": 10, "b": 11}]) - expected = [{"a": 2, "b": 3}, {"a": 11, "b": 12}] - assert actual == expected - - def test_process_break_and_return_none(self, pipeline): - """ - Add a new function that returns None, then reverse the _functions order - so that an exception would get thrown if the pipeline doesn't (as expected) - stop processing when attempting to execute a function with a NoneType input. - """ - pipeline = pipeline.apply(lambda d: None) - pipeline._functions.reverse() - assert pipeline.process({"a": 1, "b": 2}) is None diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_series.py b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_series.py new file mode 100644 index 000000000..45d0ac73f --- /dev/null +++ b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_series.py @@ -0,0 +1,313 @@ +import pytest + +from quixstreams.dataframe.series import StreamingSeries + + +class TestStreamingSeries: + def test_apply(self): + value = {"x": 5, "y": 20, "z": 110} + expected = {"x": 6} + result = StreamingSeries("x").apply(lambda v: {"x": v + 1}) + assert isinstance(result, StreamingSeries) + assert result.test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("y"), 25), + ({"x": 5, "y": 20}, StreamingSeries("x"), 10, 15), + ], + ) + def test_add(self, value, series, other, expected): + result = series + other + assert result.test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 20}, StreamingSeries("y"), StreamingSeries("x"), 15), + ({"x": 5, "y": 20}, StreamingSeries("x"), 10, -5), + ], + ) + def test_subtract(self, value, series, other, expected): + result = series - other + assert result.test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 20}, StreamingSeries("y"), StreamingSeries("x"), 100), + ({"x": 5, "y": 20}, StreamingSeries("x"), 10, 50), + ], + ) + def test_multiply(self, value, series, other, expected): + result = series * other + assert result.test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), 1), + ({"x": 5, "y": 20}, StreamingSeries("x"), 2, 2.5), + ], + ) + def test_div(self, value, series, other, expected): + result = series / other + assert result.test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 2}, StreamingSeries("x"), StreamingSeries("y"), 1), + ({"x": 5, "y": 20}, StreamingSeries("x"), 3, 2), + ], + ) + def test_mod(self, value, series, other, expected): + result = series % other + assert result.test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 2}, 
StreamingSeries("x"), StreamingSeries("x"), True), + ({"x": 5, "y": 2}, StreamingSeries("x"), StreamingSeries("y"), False), + ({"x": 5, "y": 20}, StreamingSeries("x"), 5, True), + ({"x": 5, "y": 20}, StreamingSeries("x"), 6, False), + ], + ) + def test_equal(self, value, series, other, expected): + result = series == other + assert result.test(value) is expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 2}, StreamingSeries("x"), StreamingSeries("x"), False), + ({"x": 5, "y": 2}, StreamingSeries("x"), StreamingSeries("y"), True), + ({"x": 5, "y": 20}, StreamingSeries("x"), 5, False), + ({"x": 5, "y": 20}, StreamingSeries("x"), 6, True), + ], + ) + def test_not_equal(self, value, series, other, expected): + result = series != other + assert result.test(value) is expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), False), + ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("y"), True), + ({"x": 5, "y": 20}, StreamingSeries("x"), 5, False), + ({"x": 5, "y": 20}, StreamingSeries("x"), 6, True), + ], + ) + def test_less_than(self, value, series, other, expected): + result = series < other + assert result.test(value) is expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), True), + ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("y"), True), + ({"x": 5, "y": 20}, StreamingSeries("x"), 4, False), + ({"x": 5, "y": 20}, StreamingSeries("x"), 5, True), + ({"x": 5, "y": 20}, StreamingSeries("x"), 6, True), + ], + ) + def test_less_than_equal(self, value, series, other, expected): + result = series <= other + assert result.test(value) is expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), False), + ({"x": 5, "y": 4}, StreamingSeries("x"), StreamingSeries("y"), True), + ({"x": 5, "y": 20}, StreamingSeries("x"), 4, True), + ({"x": 5, "y": 20}, StreamingSeries("x"), 5, False), + ({"x": 5, "y": 20}, StreamingSeries("x"), 6, False), + ], + ) + def test_greater_than(self, value, series, other, expected): + result = series > other + assert result.test(value) is expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), True), + ({"x": 5, "y": 4}, StreamingSeries("x"), StreamingSeries("y"), True), + ({"x": 5, "y": 6}, StreamingSeries("x"), StreamingSeries("y"), False), + ({"x": 5, "y": 20}, StreamingSeries("x"), 4, True), + ({"x": 5, "y": 20}, StreamingSeries("x"), 5, True), + ({"x": 5, "y": 20}, StreamingSeries("x"), 6, False), + ], + ) + def test_greater_than_equal(self, value, series, other, expected): + result = series >= other + assert result.test(value) is expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": True, "y": False}, StreamingSeries("x"), StreamingSeries("x"), True), + ( + {"x": True, "y": False}, + StreamingSeries("x"), + StreamingSeries("y"), + False, + ), + ({"x": True, "y": False}, StreamingSeries("x"), True, True), + ({"x": True, "y": False}, StreamingSeries("x"), False, False), + ({"x": True, "y": False}, StreamingSeries("x"), 0, 0), + ], + ) + def test_and(self, value, series, other, expected): + result = series & other + assert result.test(value) is expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + 
({"x": True, "y": False}, StreamingSeries("x"), StreamingSeries("y"), True), + ( + {"x": False}, + StreamingSeries("x"), + StreamingSeries("x"), + False, + ), + ( + { + "x": True, + }, + StreamingSeries("x"), + 0, + True, + ), + ({"x": False}, StreamingSeries("x"), 0, 0), + ({"x": False}, StreamingSeries("x"), True, True), + ], + ) + def test_or(self, value, series, other, expected): + result = series | other + assert result.test(value) is expected + + def test_multiple_conditions(self): + value = {"x": 5, "y": 20, "z": 110} + expected = True + result = (StreamingSeries("x") <= StreamingSeries("y")) & ( + StreamingSeries("x") <= StreamingSeries("z") + ) + + assert result.test(value) is expected + + @pytest.mark.parametrize( + "value, series, expected", + [ + ({"x": True, "y": False}, StreamingSeries("x"), False), + ({"x": 1, "y": False}, StreamingSeries("x"), False), + ], + ) + def test_invert(self, value, series, expected): + result = ~series + + assert result.test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": 1}, StreamingSeries("x"), [1, 2, 3], True), + ({"x": 1}, StreamingSeries("x"), [], False), + ({"x": 1}, StreamingSeries("x"), {1: 456}, True), + ], + ) + def test_isin(self, value, series, other, expected): + assert series.isin(other).test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": [1, 2, 3]}, StreamingSeries("x"), 1, True), + ({"x": [1, 2, 3]}, StreamingSeries("x"), 5, False), + ({"x": "abc"}, StreamingSeries("x"), "a", True), + ({"x": {"y": "z"}}, StreamingSeries("x"), "y", True), + ], + ) + def test_contains(self, series, value, other, expected): + assert series.contains(other).test(value) == expected + + @pytest.mark.parametrize( + "value, series, expected", + [ + ({"x": None}, StreamingSeries("x"), True), + ({"x": [1, 2, 3]}, StreamingSeries("x"), False), + ], + ) + def test_isnull(self, value, series, expected): + assert series.isnull().test(value) == expected + + @pytest.mark.parametrize( + "value, series, expected", + [ + ({"x": None}, StreamingSeries("x"), False), + ({"x": [1, 2, 3]}, StreamingSeries("x"), True), + ], + ) + def test_notnull(self, value, series, expected): + assert series.notnull().test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": [1, 2, 3]}, StreamingSeries("x"), None, False), + ({"x": None}, StreamingSeries("x"), None, True), + ({"x": 1}, StreamingSeries("x"), 1, True), + ], + ) + def test_is_(self, value, series, other, expected): + assert series.is_(other).test(value) == expected + + @pytest.mark.parametrize( + "value, series, other, expected", + [ + ({"x": [1, 2, 3]}, StreamingSeries("x"), None, True), + ({"x": None}, StreamingSeries("x"), None, False), + ({"x": 1}, StreamingSeries("x"), 1, False), + ], + ) + def test_isnot(self, value, series, other, expected): + assert series.isnot(other).test(value) == expected + + @pytest.mark.parametrize( + "value, item, expected", + [ + ({"x": {"y": 1}}, "y", 1), + ({"x": [0, 1, 2, 3]}, 1, 1), + ], + ) + def test_getitem(self, value, item, expected): + result = StreamingSeries("x")[item] + assert result.test(value) == expected + + def test_getitem_with_apply(self): + value = {"x": {"y": {"z": 110}}, "k": 0} + result = StreamingSeries("x")["y"]["z"].apply(lambda v: v + 10) + + assert result.test(value) == 120 + + @pytest.mark.parametrize("value, expected", [(10, 10), (-10, 10), (10.0, 10.0)]) + def test_abs_success( + self, + value, + expected, + ): + result 
= StreamingSeries("x").abs() + + assert result.test({"x": value}) == expected + + def test_abs_not_a_number_fails(self): + result = StreamingSeries("x").abs() + + with pytest.raises(TypeError, match="bad operand type for abs()"): + assert result.test({"x": "string"})