From 266776d79868e138b7dcdf70f3dda410d7f7643e Mon Sep 17 00:00:00 2001
From: Daniil Gusev <133032822+daniil-quix@users.noreply.github.com>
Date: Wed, 15 Nov 2023 10:37:41 +0100
Subject: [PATCH] Allow StreamingDataFrame.apply() to be assigned to keys and
used as a filter (#238)
---
README.md | 37 +-
src/StreamingDataFrames/docs/serialization.md | 28 +-
.../docs/stateful-processing.md | 35 +-
.../docs/streamingdataframe.md | 313 +++++---
src/StreamingDataFrames/examples/README.md | 6 +-
.../examples/bank_example/README.md | 4 +-
.../bank_example/json_version/consumer.py | 29 +-
.../quix_platform_version/consumer.py | 33 +-
.../quix_platform_version/producer.py | 5 +-
.../quixstreams/__init__.py | 1 +
src/StreamingDataFrames/quixstreams/app.py | 17 +-
.../quixstreams/context.py | 49 ++
.../quixstreams/core/__init__.py | 0
.../quixstreams/core/stream/__init__.py | 2 +
.../quixstreams/core/stream/functions.py | 125 +++
.../quixstreams/core/stream/stream.py | 239 ++++++
.../quixstreams/dataframe/__init__.py | 2 +-
.../quixstreams/dataframe/base.py | 20 +
.../quixstreams/dataframe/column.py | 140 ----
.../quixstreams/dataframe/dataframe.py | 419 +++++-----
.../quixstreams/dataframe/exceptions.py | 7 -
.../quixstreams/dataframe/pipeline.py | 82 --
.../quixstreams/dataframe/series.py | 234 ++++++
.../quixstreams/models/__init__.py | 1 +
.../models/{context.py => messagecontext.py} | 0
.../quixstreams/models/rows.py | 2 +-
.../quixstreams/models/topics.py | 2 +-
.../tests/test_quixstreams/test_app.py | 96 +--
.../tests/test_quixstreams/test_context.py | 52 ++
.../test_quixstreams/test_core/__init__.py | 0
.../test_core/test_functions.py | 42 +
.../test_quixstreams/test_core/test_stream.py | 107 +++
.../test_dataframe/fixtures.py | 41 +-
.../test_dataframe/test_column.py | 328 --------
.../test_dataframe/test_dataframe.py | 724 +++++++++++-------
.../test_dataframe/test_pipeline.py | 32 -
.../test_dataframe/test_series.py | 313 ++++++++
37 files changed, 2210 insertions(+), 1357 deletions(-)
create mode 100644 src/StreamingDataFrames/quixstreams/context.py
create mode 100644 src/StreamingDataFrames/quixstreams/core/__init__.py
create mode 100644 src/StreamingDataFrames/quixstreams/core/stream/__init__.py
create mode 100644 src/StreamingDataFrames/quixstreams/core/stream/functions.py
create mode 100644 src/StreamingDataFrames/quixstreams/core/stream/stream.py
create mode 100644 src/StreamingDataFrames/quixstreams/dataframe/base.py
delete mode 100644 src/StreamingDataFrames/quixstreams/dataframe/column.py
delete mode 100644 src/StreamingDataFrames/quixstreams/dataframe/exceptions.py
delete mode 100644 src/StreamingDataFrames/quixstreams/dataframe/pipeline.py
create mode 100644 src/StreamingDataFrames/quixstreams/dataframe/series.py
rename src/StreamingDataFrames/quixstreams/models/{context.py => messagecontext.py} (100%)
create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_context.py
create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_core/__init__.py
create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_core/test_functions.py
create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_core/test_stream.py
delete mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_column.py
delete mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_pipeline.py
create mode 100644 src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_series.py
diff --git a/README.md b/README.md
index cb55e845d..5afc20166 100644
--- a/README.md
+++ b/README.md
@@ -63,12 +63,12 @@ See [requirements.txt](./src/StreamingDataFrames/requirements.txt) for the full
Here's an example of how to process data from a Kafka Topic with Quix Streams:
```python
-from quixstreams import Application, MessageContext, State
+from quixstreams import Application, State
# Define an application
app = Application(
- broker_address="localhost:9092", # Kafka broker address
- consumer_group="consumer-group-name", # Kafka consumer group
+ broker_address="localhost:9092", # Kafka broker address
+ consumer_group="consumer-group-name", # Kafka consumer group
)
# Define the input and output topics. By default, "json" serialization will be used
@@ -76,13 +76,7 @@ input_topic = app.topic("my_input_topic")
output_topic = app.topic("my_output_topic")
-def add_one(data: dict, ctx: MessageContext):
- for field, value in data.items():
- if isinstance(value, int):
- data[field] += 1
-
-
-def count(data: dict, ctx: MessageContext, state: State):
+def count(data: dict, state: State):
# Get a value from state for the current Kafka message key
total = state.get('total', default=0)
total += 1
@@ -91,27 +85,34 @@ def count(data: dict, ctx: MessageContext, state: State):
# Update your message data with a value from the state
data['total'] = total
+
# Create a StreamingDataFrame instance
# StreamingDataFrame is a primary interface to define the message processing pipeline
sdf = app.dataframe(topic=input_topic)
# Print the incoming messages
-sdf = sdf.apply(lambda value, ctx: print('Received a message:', value))
+sdf = sdf.update(lambda value: print('Received a message:', value))
# Select fields from incoming messages
-sdf = sdf[["field_0", "field_2", "field_8"]]
+sdf = sdf[["field_1", "field_2", "field_3"]]
# Filter only messages with "field_1" > 10 and "field_2" != "test"
-sdf = sdf[(sdf["field_0"] > 10) & (sdf["field_2"] != "test")]
+sdf = sdf[(sdf["field_1"] > 10) & (sdf["field_2"] != "test")]
+
+# Filter messages using custom functions
+sdf = sdf[sdf.apply(lambda value: 0 < (value['field_1'] + value['field_3']) < 1000)]
+
+# Generate a new value based on the current one
+sdf = sdf.apply(lambda value: {**value, 'new_field': 'new_value'})
-# Apply custom function to transform the message
-sdf = sdf.apply(add_one)
+# Update a value based on the entire message content
+sdf['field_4'] = sdf.apply(lambda value: value['field_1'] + value['field_3'])
-# Apply a stateful function to persist data to the state store
-sdf = sdf.apply(count, stateful=True)
+# Use a stateful function to persist data to the state store and update the value in place
+sdf = sdf.update(count, stateful=True)
# Print the result before producing it
-sdf = sdf.apply(lambda value, ctx: print('Producing a message:', value))
+sdf = sdf.update(lambda value: print('Producing a message:', value))
# Produce the result to the output topic
sdf = sdf.to_topic(output_topic)
diff --git a/src/StreamingDataFrames/docs/serialization.md b/src/StreamingDataFrames/docs/serialization.md
index 63d8825a5..2883c6285 100644
--- a/src/StreamingDataFrames/docs/serialization.md
+++ b/src/StreamingDataFrames/docs/serialization.md
@@ -21,30 +21,31 @@ By default, message values are serialized with `JSON`, message keys are seriali
- `quix_events` & `quix_timeseries` - for serializers only.
## Using SerDes
-To set a serializer, you may either pass a string shorthand for it, or an instance of `streamingdataframes.models.serializers.Serializer` and `streamingdataframes.models.serializers.Deserializer` directly
+To set a serializer, you may either pass a string shorthand for it, or an instance of `quixstreams.models.serializers.Serializer` and `quixstreams.models.serializers.Deserializer` directly
to the `Application.topic()`.
Example with format shorthands:
```python
-from streamingdataframes.models.serializers import JSONDeserializer
-app = Application(...)
+from quixstreams import Application
+app = Application(broker_address='localhost:9092', consumer_group='consumer')
# Deserializing message values from JSON to objects and message keys as strings
-input_topic = app.topic(value_deserializer='json', key_deserializer='string')
+input_topic = app.topic('input', value_deserializer='json', key_deserializer='string')
# Serializing message values to JSON and message keys to bytes
-output_topic = app.topic(value_serializer='json', key_deserializer='bytes')
+output_topic = app.topic('output', value_serializer='json', key_deserializer='bytes')
```
Passing `Serializer` and `Deserializer` instances directly:
```python
-from streamingdataframes.models.serializers import JSONDeserializer, JSONSerializer
-app = Application(...)
-input_topic = app.topic(value_deserializer=JSONDeserializer())
-output_topic = app.topic(value_deserializer=JSONSerializer())
+from quixstreams import Application
+from quixstreams.models.serializers import JSONDeserializer, JSONSerializer
+app = Application(broker_address='localhost:9092', consumer_group='consumer')
+input_topic = app.topic('input', value_deserializer=JSONDeserializer())
+output_topic = app.topic('output', value_serializer=JSONSerializer())
```
-You can find all available serializers in `streamingdataframes.models.serializers` module.
+You can find all available serializers in `quixstreams.models.serializers` module.
We also plan on including other popular ones like Avro and Protobuf in the near future.
@@ -56,8 +57,9 @@ The Deserializer object will wrap the received value to the dictionary with `col
Example:
```python
-from streamingdataframes.models.serializers import IntegerDeserializer
-app = Application(...)
-input_topic = app.topic(value_deserializer=IntegerDeserializer(column_name='number'))
+from quixstreams import Application
+from quixstreams.models.serializers import IntegerDeserializer
+app = Application(broker_address='localhost:9092', consumer_group='consumer')
+input_topic = app.topic('input', value_deserializer=IntegerDeserializer(column_name='number'))
# Will deserialize message with value "123" to "{'number': 123}" ...
```
diff --git a/src/StreamingDataFrames/docs/stateful-processing.md b/src/StreamingDataFrames/docs/stateful-processing.md
index 9138d5cc6..9ffcff882 100644
--- a/src/StreamingDataFrames/docs/stateful-processing.md
+++ b/src/StreamingDataFrames/docs/stateful-processing.md
@@ -45,28 +45,31 @@ When another consumer reads the message with `KEY_B`, it will not be able to rea
## Using State
-The state is available in functions passed to `StreamingDataFrame.apply()` with parameter `stateful=True`:
+The state is available in functions passed to `StreamingDataFrame.apply()`, `StreamingDataFrame.update()` and `StreamingDataFrame.filter()` with parameter `stateful=True`:
```python
-from quixstreams import Application, MessageContext, State
-app = Application()
+from quixstreams import Application, State
+app = Application(
+ broker_address='localhost:9092',
+ consumer_group='consumer',
+)
topic = app.topic('topic')
sdf = app.dataframe(topic)
-def count_messages(value: dict, ctx: MessageContext, state: State):
+def count_messages(value: dict, state: State):
total = state.get('total', default=0)
total += 1
state.set('total', total)
- value['total'] = total
+ return {**value, 'total': total}
-# Apply a custom function and inform StreamingDataFrame to provide a State instance to it
-# by passing "stateful=True"
-sdf.apply(count_messages, stateful=True)
+
+# Apply a custom function and inform StreamingDataFrame to provide a State instance to it by passing "stateful=True"
+sdf = sdf.apply(count_messages, stateful=True)
```
-Currently, only functions passed to `StreamingDataFrame.apply()` may use State.
+Currently, only functions passed to `StreamingDataFrame.apply()`, `StreamingDataFrame.update()` and `StreamingDataFrame.filter()` may use State.
@@ -75,11 +78,19 @@ Currently, only functions passed to `StreamingDataFrame.apply()` may use State.
By default, an `Application` keeps the state in `state` directory relative to the current working directory.
To change it, pass `state_dir="your-path"` to `Application` or `Application.Quix` calls:
```python
-Application(state_dir="folder/path/here")
+from quixstreams import Application
+app = Application(
+ broker_address='localhost:9092',
+ consumer_group='consumer',
+ state_dir="folder/path/here",
+)
# or
-Application.Quix(state_dir="folder/path/here")
+app = Application.Quix(
+ consumer_group='consumer',
+ state_dir="folder/path/here",
+)
```
## State Guarantees
@@ -105,4 +116,4 @@ We plan to add a proper recovery process in the future.
#### Shared state directory
In the current version, it's assumed that the state directory is shared between consumers (e.g. using Kubernetes PVC)
-If consumers live on different nodes and don't have access to the same state directory, they will not be able to pickup state on rebalancing.
\ No newline at end of file
+If consumers live on different nodes and don't have access to the same state directory, they will not be able to pick up state on rebalancing.
diff --git a/src/StreamingDataFrames/docs/streamingdataframe.md b/src/StreamingDataFrames/docs/streamingdataframe.md
index 09ac8515b..817a670d9 100644
--- a/src/StreamingDataFrames/docs/streamingdataframe.md
+++ b/src/StreamingDataFrames/docs/streamingdataframe.md
@@ -1,18 +1,19 @@
# `StreamingDataFrame`: Detailed Overview
-`StreamingDataFrame` and `Column` are the primary interface to define the stream processing pipelines.
-Changes to instances of `StreamingDataFrame` and `Column` update the processing pipeline, but the actual
+`StreamingDataFrame` and `StreamingSeries` are the primary objects to define the stream processing pipelines.
+
+Changes to instances of `StreamingDataFrame` and `StreamingSeries` update the processing pipeline, but the actual
data changes happen only when it's executed via `Application.run()`
Example:
```python
-from quixstreams import Application, MessageContext, State
+from quixstreams import Application, State
# Define an application
app = Application(
- broker_address="localhost:9092", # Kafka broker address
- consumer_group="consumer-group-name", # Kafka consumer group
+ broker_address="localhost:9092", # Kafka broker address
+ consumer_group="consumer", # Kafka consumer group
)
# Define the input and output topics. By default, the "json" serialization will be used
@@ -20,27 +21,28 @@ input_topic = app.topic("my_input_topic")
output_topic = app.topic("my_output_topic")
-def add_one(data: dict, ctx: MessageContext):
+def add_one(data: dict):
for field, value in data.items():
if isinstance(value, int):
data[field] += 1
-
-def count(data: dict, ctx: MessageContext, state: State):
+
+def count(data: dict, state: State):
# Get a value from state for the current Kafka message key
- total = state.get('total', default=0)
+ total = state.get("total", default=0)
total += 1
# Set a value back to the state
- state.set('total')
- # Update your message data with a value from the state
- data['total'] = total
+ state.set("total", total)
+ # Return result
+ return total
+
# Create a StreamingDataFrame instance
# StreamingDataFrame is a primary interface to define the message processing pipeline
sdf = app.dataframe(topic=input_topic)
# Print the incoming messages
-sdf = sdf.apply(lambda value, ctx: print('Received a message:', value))
+sdf = sdf.update(lambda value: print("Received a message:", value))
# Select fields from incoming message
sdf = sdf[["field_0", "field_2", "field_8"]]
@@ -48,51 +50,57 @@ sdf = sdf[["field_0", "field_2", "field_8"]]
# Filter only messages with "field_0" > 10 and "field_2" != "test"
sdf = sdf[(sdf["field_0"] > 10) & (sdf["field_2"] != "test")]
+# You may also use a custom function to filter data
+sdf = sdf.filter(lambda v: v["field_0"] > 10 and v["field_2"] != "test")
+
# Apply custom function to transform the message
sdf = sdf.apply(add_one)
-# Apply a stateful function in persist data into the state store
-sdf = sdf.apply(count, stateful=True)
+# Use a stateful function to persist data into the state store
+# and update the message value
+sdf["total"] = sdf.apply(count, stateful=True)
# Print the result before producing it
-sdf = sdf.apply(lambda value, ctx: print('Producing a message:', value))
+sdf = sdf.update(lambda value: print("Producing a message:", value))
# Produce the result to the output topic
sdf = sdf.to_topic(output_topic)
```
-## Interacting with `Rows`
+## Data Types
+
+`StreamingDataFrame` is agnostic of data types passed to it during processing.
-Under the hood, `StreamingDataFrame` is manipulating Kafka messages via `Row` objects.
+All functions passed to `StreamingDataFrame` will receive data in the same format as it's deserialized
+by the `Topic` object.
-Simplified, a `Row` is effectively a dictionary of the Kafka message
-value, with each key equivalent to a dataframe column name.
+It can also produce any type back to Kafka as long as the value can be serialized
+to bytes by the `value_serializer` passed to the output `Topic` object.
-`StreamingDataFrame` interacts with `Row` objects via the Pandas-like
-interface and user-defined functions passed to `StreamingDataFrame.apply()`
+Column access like `dataframe["column"]` is supported only for dictionary values.
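+
+For example, if the input topic deserializes values as plain strings, you can still transform them
+with `.apply()` (a sketch, assuming `value_deserializer='str'` on the input `Topic`):
+
+```python
+# Column access is unavailable for strings, but .apply() accepts any type
+sdf = sdf.apply(lambda value: value.upper())
+```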
-## Accessing Fields/Columns
+## Accessing Fields via StreamingSeries
In typical Pandas dataframe fashion, you can access a column:
```python
-sdf["field_a"] # "my_str"
+sdf["field_a"] # returns a StreamingSeries with value from field "field_a"
```
Typically, this is done in combination with other operations.
-You can also access nested objects (dicts, lists, etc):
+You can also access nested objects (dicts, lists, etc.):
```python
-sdf["field_c"][2] # 3
+sdf["field_c"][2] # returns a StreamingSeries with value of "field_c[2]" if "field_c" is a collection
```
-## Performing Operations with Columns
+## Performing Operations with StreamingSeries
You can do almost any basic operations or
comparisons with columns, assuming validity:
@@ -103,167 +111,270 @@ sdf["field_a"] / sdf["field_b"]
sdf["field_a"] | sdf["field_b"]
sdf["field_a"] & sdf["field_b"]
sdf["field_a"].isnull()
-sdf["field_a"].contains('string')
+sdf["field_a"].contains("string")
sdf["field_a"] != "woo"
```
-## Assigning New Columns
+## Assigning New Fields
-You may add new columns from the results of numerous other
+You may add new fields from the results of numerous other
operations:
```python
-sdf["a_new_int_field"] = 5
+# Set dictionary key "a_new_int_field" to 5
+sdf["a_new_int_field"] = 5
+
+# Set key "a_new_str_field" to a sum of "field_a" and "field_b"
sdf["a_new_str_field"] = sdf["field_a"] + sdf["field_b"]
-sdf["another_new_field"] = sdf["a_new_str_field"].apply(lambda value, ctx: value + "another")
-```
-See [the `.apply()` section](#user-defined-functions-apply) for more information on how that works.
+# Do the same but with a custom function applied to a whole message value
+sdf["another_new_field"] = sdf.apply(lambda value: value['field_a'] + value['field_b'])
+# Use a custom function on StreamingSeries to update key "another_new_field"
+sdf["another_new_field"] = sdf["a_new_str_field"].apply(lambda value: value + "another")
+```
-## Selecting only certain Columns
+## Selecting Columns
-In typical Pandas fashion, you can take a subset of columns:
+In typical `pandas` fashion, you can take a subset of columns:
```python
-# remove "field_d"
+# Select only fields "field_a", "field_b", "field_c"
sdf = sdf[["field_a", "field_b", "field_c"]]
```
-## Filtering Rows (messages)
-
-"Filtering" is a very specific concept and operation with `StreamingDataFrames`.
+## Filtering
-In practice, it functions similarly to how you might filter rows with Pandas DataFrames
-with conditionals.
+`StreamingDataFrame` provides a similar `pandas`-like API to filter data.
-When a "column" reference is actually another operation, it will be treated
-as a "filter". If that result is empty or None, the row is now "filtered".
+To filter data you may use:
+- Conditional expressions with `StreamingSeries` (if the underlying message value is deserialized as a dictionary)
+- Custom functions like `sdf[sdf.apply(lambda v: v['field'] < 0)]`
+- Custom functions like `sdf = sdf.filter(lambda v: v['field'] < 0)`
-When filtered, ALL downstream functions for that row are now skipped,
+When a value is filtered from the stream, ALL downstream functions for that value are skipped,
_including Kafka-related operations like producing_.
+Example:
+
```python
-# This would continue onward
-sdf = sdf[sdf["field_a"] == "my_str"]
+# Filter values using `StreamingSeries` expressions
+sdf = sdf[(sdf["field_a"] == 'my_string') | (sdf['field_b'] > 0)]
-# This would filter the row, skipping further functions
-sdf = sdf[(sdf["field_a"] != "woo") & (sdf["field_c"][0] > 100)]
+# Filter values using `StreamingDataFrame.apply()`
+sdf = sdf[sdf.apply(lambda value: value > 0)]
+
+# Filter values using `StreamingDataFrame.filter()`
+sdf = sdf.filter(lambda value: value > 0)
```
+
+## Using Custom Functions: `.apply()`, `.update()` and `.filter()`
+
+`StreamingDataFrame` provides a flexible mechanism to transform and filter data using
+simple python functions via `.apply()`, `.update()` and `.filter()` methods.
+
+All three methods accept 2 arguments:
+- A function to apply.
+A stateless function should accept only one argument - the value.
+A stateful function should accept two arguments - the value and a `State` instance.
+
+- A `stateful` flag which can be `True` or `False` (default - `False`).
+By passing `stateful=True`, you inform a `StreamingDataFrame` to pass an extra argument of type `State` to your function
+to perform stateful operations.
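+
+For illustration, the two accepted function shapes (hypothetical names):
+
+```python
+def stateless_fn(value):
+    # Receives only the message value
+    ...
+
+def stateful_fn(value, state):
+    # Receives the value and a State instance when stateful=True
+    ...
+```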
-## User Defined Functions: `.apply()`
+Read on for more details about each method.
-Should you need more advanced transformations, `.apply()` allows you
-to use any python function to operate on your row.
-When used on a `StreamingDataFrame`, your function must accept 2 ordered arguments:
-- a current Row value (as a dictionary)
-- an instance of `MessageContext` that allows you to access other message metadata (key, partition, timestamap, etc).
+### `StreamingDataFrame.apply()`
+Use `.apply()` when you need to generate a new value based on the input.
+
+When using `.apply()`, the result of the function will always be propagated downstream and will become an input for the next functions.
+
+Although `.apply()` can mutate the input, doing so is discouraged, and the `.update()` method should be used instead.
+
+Example:
+```python
+# Return a new value based on input
+sdf = sdf.apply(lambda value: value + 1)
+```
+
+There are 2 other use cases for `.apply()`:
+1. `StreamingDataFrame.apply()` can be used to assign new keys to the value if the value is a dictionary:
+```python
+# Set a key "field_a" to a sum of "field_b" and "field_c"
+sdf['field_a'] = sdf.apply(lambda value: value['field_b'] + value['field_c'])
+```
+
+2. `StreamingDataFrame.apply()` can be used to filter values.
+
+In this case, the result of the passed function is interpreted as `bool`:
+```python
+# Filter values where sum of "field_b" and "field_c" is greater than 0
+sdf = sdf[sdf.apply(lambda value: (value['field_b'] + value['field_c']) > 0)]
+```
-Consequently, your function **MUST either** _alter this dict in-place_
-**OR** _return a dictionary_ to directly replace the current data with.
+### `StreamingDataFrame.update()`
+Use `.update()` when you need to mutate the input value in place or to perform a side effect without generating a new value.
+For example, use it to print data to the console or to update a counter in the State.
-For example:
+The result of a function passed to `.update()` is always ignored, and its input will be propagated downstream instead.
+Examples:
```python
-# in place example
-def in_place(value: dict, ctx: MessageContext):
- value['key'] = 'value'
-
-sdf = sdf.apply(in_place)
-
-# replacement example
-def new_data(value: dict, ctx: MessageContext):
- new_value = {'key': value['key']}
- return new_value
-
-sdf = sdf.apply(new_data)
+# Mutate a list by appending a new item to it
+# The updated list will be passed downstream
+sdf = sdf.update(lambda value: value.append(1))
+
+# Use .update() to print a value to the console
+sdf = sdf.update(lambda value: print("Received value: ", value))
```
+
+### `StreamingDataFrame.filter()`
+Use `.filter()` to filter values based on the entire message content.
+The result of a function passed to `.filter()` is interpreted as boolean.
+```python
+# Filter out values with "field_a" <= 0
+sdf = sdf.filter(lambda value: value['field_a'] > 0)
-The `.apply()` function is also valid for columns, but rather than providing a
-dictionary, it instead uses the column value, and the function must return a value.
+# Filter out values where "field_a" is falsy
+sdf = sdf.filter(lambda value: value['field_a'])
+```
+You may also achieve the same result with `sdf[sdf.apply()]` syntax:
```python
-sdf["new_field"] = sdf["field_a"].apply(lambda value, ctx: value + "-add_me")
+# Filter out values with "field_a" <= 0 using .apply() syntax
+sdf = sdf[sdf.apply(lambda value: value['field_a'] > 0)]
```
-NOTE: Every `.apply()` is a _temporary_ state change, but the result can be assigned.
-So, in the above example, `field_a` remains `my_str`, but `new_field == my_str-add_me`
-as desired.
+
+
+### Using custom functions with StreamingSeries
+The `.apply()` function is also valid for `StreamingSeries`.
+But instead of receiving an entire message value, it will receive only the value of the particular key:
+
+```python
+# Generate a new value based on "field_b" and assign it back to "field_a"
+sdf['field_a'] = sdf['field_b'].apply(lambda field_b: field_b.strip())
+```
+
+It follows the same rules as `StreamingDataFrame.apply()`, and the result of the function
+will be returned as is.
+
+`StreamingSeries` supports only the `.apply()` method.
-### Stateful Processing with `.apply()`
+## Stateful Processing with Custom Functions
+
+If you want to use persistent state during processing, you can access the state for a given _message key_ by
+passing `stateful=True` to `StreamingDataFrame.apply()`, `StreamingDataFrame.update()` or `StreamingDataFrame.filter()`.
-If you want to use persistent state during processing, you can access the state for a given row via
-a keyword argument `stateful=True`, and your function should accept a third `State` object as
-an argument (you can just call it something like `state`).
+In this case, your custom function should accept a second argument of type `State`.
-When your function has access to state, it will receive a `State` object, which can do:
+The `State` object provides a minimal API to work with the persistent state store:
- `.get(key, default=None)`
- `.set(key, value)`
- `.delete(key)`
- `.exists(key)`
-`Key` and value can be anything, and you can have any number of keys.
-NOTE: `key` is unrelated to the Kafka message key, which is handled behind the scenes.
+You may treat `State` as a dictionary-like structure.
+
+`Key` and `value` can be of any type as long as they are serializable to JSON (the default serialization format for the State).
+
+You may easily store strings, numbers, lists, tuples and dictionaries.
+
+
+
+Under the hood, the `key` is always prefixed by the actual Kafka message key to ensure
+that messages with different keys don't have access to the same state.
+
```python
-from quixstreams import MessageContext, State
+from quixstreams import State
-def edit_data(row, ctx: MessageContext, state: State):
- msg_max = len(row["field_c"])
+# Update current value using stateful operations
+
+def edit_data(value, state: State):
+ msg_max = len(value["field_c"])
     current_max = state.get("current_len_max", default=0)
if current_max < msg_max:
state.set("current_len_max", msg_max)
current_max = msg_max
- row["len_max"] = current_max
- return row
+ value["len_max"] = current_max
-sdf = sdf.apply(edit_data, stateful=True)
+sdf = sdf.update(edit_data, stateful=True)
```
For more information about stateful processing in general, see
[**Stateful Applications**](./stateful_processing.md).
-
+## Accessing the Kafka Message Keys and Metadata
+`quixstreams` provides access to the metadata of the current Kafka message via `quixstreams.context` module.
-## Producing to Topics: `.to_topic()`
+Information like the message key, topic, partition, offset, timestamp and more is stored globally in the `MessageContext` object,
+and it's updated on each incoming message.
-To send the current state of the `StreamingDataFrame` to a topic, simply call
-`to_topic` with a `Topic` instance generated from `Application.topic()`
+To get the current message key, use the `quixstreams.message_key` function:
+
+```python
+from quixstreams import message_key
+sdf = sdf.apply(lambda value: 1 if message_key() == b'1' else 0)
+```
+
+To get the whole `MessageContext` object with all its attributes, including the message key, use `quixstreams.message_context`:
+```python
+from quixstreams import message_context
+
+# Get current message timestamp and set it to a "timestamp" key
+sdf['timestamp'] = sdf.apply(lambda value: message_context().timestamp.milliseconds)
+```
+
+Both `quixstreams.message_key()` and `quixstreams.message_context()` should be called
+only from custom functions during processing.
+
+
+## Producing to Topics: `StreamingDataFrame.to_topic()`
+
+To send the current value of the `StreamingDataFrame` to a topic, simply call
+`.to_topic()` with a `Topic` instance generated from `Application.topic()`
as an argument.
To change the outgoing message key (which defaults to the current consumed key),
-you can optionally provide a key function, which operates similarly to the `.apply()`
-function with a `row` (dict) and `ctx` argument, and returns a desired
-(serializable) key.
+you can optionally provide a key function, which operates similarly to `.apply()`.
+
+It should accept a message value and return a new key.
+
+The returned key must be compatible with `key_serializer` provided to the `Topic` object.
```python
from quixstreams import Application
-app = Application(...)
-output_topic = app.topic("my_output_topic")
+app = Application(broker_address='localhost:9092', consumer_group='consumer')
+
+# Incoming key is deserialized to string
+input_topic = app.topic("input", key_deserializer='str')
+# Outgoing key will be serialized as a string too
+output_topic = app.topic("my_output_topic", key_serializer='str')
# Producing a new message to a topic with the same key
-sdf = sdf.to_topic(other_output_topic)
+sdf = sdf.to_topic(output_topic)
-# Producing a new message to a topic with a new key
-sdf = sdf.to_topic(output_topic, key=lambda value, ctx: ctx.key + value['field'])
+# Generate a new message key based on "value['field']" assuming it is a string
+sdf = sdf.to_topic(output_topic, key=lambda value: str(value["field"]))
```
diff --git a/src/StreamingDataFrames/examples/README.md b/src/StreamingDataFrames/examples/README.md
index 61aa12243..4e0f43681 100644
--- a/src/StreamingDataFrames/examples/README.md
+++ b/src/StreamingDataFrames/examples/README.md
@@ -1,7 +1,7 @@
-# `StreamingDataFrames` Examples
+# Quix Streams Examples
-This folder contains a few examples/boiler-plate applications to get you started with
-the `StreamingDataFrames` library.
+This folder contains a few boilerplate applications to get you started with
+the Quix Streams library.
## Running an Example
diff --git a/src/StreamingDataFrames/examples/bank_example/README.md b/src/StreamingDataFrames/examples/bank_example/README.md
index 6c23d555d..5e5ab284a 100644
--- a/src/StreamingDataFrames/examples/bank_example/README.md
+++ b/src/StreamingDataFrames/examples/bank_example/README.md
@@ -18,8 +18,8 @@ account and the purchase attempt was above a certain cost (we don't want to spam
This example showcases:
- How to use multiple Quix kafka applications together
- Producer
- - Consumer via `streamingdataframes.Application` (consumes and produces)
- - Basic usage of the `dataframe` object
+ - Consumer via `quixstreams.Application` (consumes and produces)
+ - Basic usage of the `StreamingDataFrame` object
- Using different serializations/data structures
- json
- Quix serializers (more intended for Quix platform use)
diff --git a/src/StreamingDataFrames/examples/bank_example/json_version/consumer.py b/src/StreamingDataFrames/examples/bank_example/json_version/consumer.py
index 86ea20642..b423a0ccc 100644
--- a/src/StreamingDataFrames/examples/bank_example/json_version/consumer.py
+++ b/src/StreamingDataFrames/examples/bank_example/json_version/consumer.py
@@ -9,36 +9,34 @@
from dotenv import load_dotenv
-from quixstreams import Application, MessageContext, State
+from quixstreams import Application, State, message_key
load_dotenv("./env_vars.env")
-def count_transactions(value: dict, ctx: MessageContext, state: State):
+def count_transactions(value: dict, state: State):
"""
Track the number of transactions using persistent state
:param value: message value
- :param ctx: message context with key, timestamp and other Kafka message metadata
:param state: instance of State store
"""
total = state.get("total_transactions", 0)
total += 1
state.set("total_transactions", total)
- value["total_transactions"] = total
+ return total
-def uppercase_source(value: dict, ctx: MessageContext):
+def uppercase_source(value: dict):
"""
Upper-case field "transaction_source" for each processed message
:param value: message value, a dictionary with all deserialized message data
- :param ctx: message context, it contains message metadata like key, topic, timestamp
- etc.
+
:return: this function must either return None or a new dictionary
"""
- print(f'Processing message with key "{ctx.key}"')
+ print(f'Processing message with key "{message_key()}"')
value["transaction_source"] = value["transaction_source"].upper()
return value
@@ -62,28 +60,25 @@ def uppercase_source(value: dict, ctx: MessageContext):
sdf = app.dataframe(input_topic)
# Filter only messages with "account_class" == "Gold" and "transaction_amount" >= 1000
-sdf = sdf[
- (sdf["account_class"] == "Gold")
- & (sdf["transaction_amount"].apply(lambda x, ctx: abs(x)) >= 1000)
-]
+sdf = sdf[(sdf["account_class"] == "Gold") & (sdf["transaction_amount"].abs() >= 1000)]
# Drop all fields except the ones we need
sdf = sdf[["account_id", "transaction_amount", "transaction_source"]]
-# Update the total number of transactions in state
-sdf = sdf.apply(count_transactions, stateful=True)
+# Update the total number of transactions in state and save result to the message
+sdf["total_transactions"] = sdf.apply(count_transactions, stateful=True)
# Transform field "transaction_source" to upper-case using a custom function
-sdf = sdf.apply(uppercase_source)
+sdf["transaction_source"] = sdf["transaction_source"].apply(lambda v: v.upper())
# Add a new field with a notification text
sdf["customer_notification"] = "A high cost purchase was attempted"
# Print the transformed message to the console
-sdf = sdf.apply(lambda val, ctx: print(f"Sending update: {val}"))
+sdf = sdf.update(lambda val: print(f"Sending update: {val}"))
# Send the message to the output topic
-sdf.to_topic(output_topic)
+sdf = sdf.to_topic(output_topic)
if __name__ == "__main__":
# Start message processing
diff --git a/src/StreamingDataFrames/examples/bank_example/quix_platform_version/consumer.py b/src/StreamingDataFrames/examples/bank_example/quix_platform_version/consumer.py
index 4a6ea30bf..004fdba7d 100644
--- a/src/StreamingDataFrames/examples/bank_example/quix_platform_version/consumer.py
+++ b/src/StreamingDataFrames/examples/bank_example/quix_platform_version/consumer.py
@@ -8,40 +8,37 @@
from dotenv import load_dotenv
-from quixstreams import Application, MessageContext, State
+from quixstreams import Application, State, message_key
+
# Reminder: the platform will have these values available by default so loading the
# environment would be unnecessary there.
load_dotenv("./bank_example/quix_platform_version/quix_vars.env")
-def uppercase_source(value: dict, ctx: MessageContext):
+def uppercase_source(value: dict):
"""
Upper-case field "transaction_source" for each processed message
:param value: message value, a dictionary with all deserialized message data
- :param ctx: message context, it contains message metadata like key, topic, timestamp
- etc.
:return: this function must either return None or a new dictionary
"""
- print(f'Processing message with key "{ctx.key}"')
+ print(f'Processing message with key "{message_key()}"')
value["transaction_source"] = value["transaction_source"].upper()
- return value
-def count_transactions(value: dict, ctx: MessageContext, state: State):
+def count_transactions(value: dict, state: State):
"""
Track the number of transactions using persistent state
:param value: message value
- :param ctx: message context with key, timestamp and other Kafka message metadata
:param state: instance of State store
"""
total = state.get("total_transactions", 0)
total += 1
state.set("total_transactions", total)
- value["total_transactions"] = total
+ return total
# Define your application and settings
@@ -61,30 +58,26 @@ def count_transactions(value: dict, ctx: MessageContext, state: State):
# Create a StreamingDataFrame and start building your processing pipeline
sdf = app.dataframe(input_topic)
-
# Filter only messages with "account_class" == "Gold" and "transaction_amount" >= 1000
-sdf = sdf[
- (sdf["account_class"] == "Gold")
- & (sdf["transaction_amount"].apply(lambda x, ctx: abs(x)) >= 1000)
-]
+sdf = sdf[(sdf["account_class"] == "Gold") & (sdf["transaction_amount"].abs() >= 1000)]
# Drop all fields except the ones we need
-sdf = sdf[["account_id", "transaction_amount", "transaction_source"]]
+sdf = sdf[["Timestamp", "account_id", "transaction_amount", "transaction_source"]]
-# Update the total number of transactions in state
-sdf = sdf.apply(count_transactions, stateful=True)
+# Update the total number of transactions in state and save result to the message
+sdf["total_transactions"] = sdf.apply(count_transactions, stateful=True)
# Transform field "transaction_source" to upper-case using a custom function
-sdf = sdf.apply(uppercase_source)
+sdf["transaction_source"] = sdf["transaction_source"].apply(lambda v: v.upper())
# Add a new field with a notification text
sdf["customer_notification"] = "A high cost purchase was attempted"
# Print the transformed message to the console
-sdf = sdf.apply(lambda val, ctx: print(f"Sending update: {val}"))
+sdf = sdf.update(lambda val: print(f"Sending update: {val}"))
# Send the message to the output topic
-sdf.to_topic(output_topic)
+sdf = sdf.to_topic(output_topic)
if __name__ == "__main__":
# Start message processing
diff --git a/src/StreamingDataFrames/examples/bank_example/quix_platform_version/producer.py b/src/StreamingDataFrames/examples/bank_example/quix_platform_version/producer.py
index be93dacde..c8c96da8c 100644
--- a/src/StreamingDataFrames/examples/bank_example/quix_platform_version/producer.py
+++ b/src/StreamingDataFrames/examples/bank_example/quix_platform_version/producer.py
@@ -1,3 +1,4 @@
+import time
import uuid
from random import randint, random, choice
from time import sleep
@@ -9,8 +10,7 @@
QuixTimeseriesSerializer,
SerializationContext,
)
-from quixstreams.models.topics import Topic, TopicCreationConfigs
-from quixstreams.platforms.quix import QuixKafkaConfigsBuilder
+from quixstreams.platforms.quix import QuixKafkaConfigsBuilder, TopicCreationConfigs
load_dotenv("./bank_example/quix_platform_version/quix_vars.env")
@@ -48,6 +48,7 @@
"account_class": "Gold" if account >= 8 else "Silver",
"transaction_amount": randint(-2500, -1),
"transaction_source": choice(retailers),
+ "Timestamp": time.time_ns(),
}
print(f"Producing value {value}")
producer.produce(
diff --git a/src/StreamingDataFrames/quixstreams/__init__.py b/src/StreamingDataFrames/quixstreams/__init__.py
index 201a8d840..ff37d95b4 100644
--- a/src/StreamingDataFrames/quixstreams/__init__.py
+++ b/src/StreamingDataFrames/quixstreams/__init__.py
@@ -1,5 +1,6 @@
from .app import Application
from .models import MessageContext
from .state import State
+from .context import message_context, message_key
__version__ = "2.0alpha2"
diff --git a/src/StreamingDataFrames/quixstreams/app.py b/src/StreamingDataFrames/quixstreams/app.py
index 54259cdb6..ebb5034ef 100644
--- a/src/StreamingDataFrames/quixstreams/app.py
+++ b/src/StreamingDataFrames/quixstreams/app.py
@@ -6,6 +6,8 @@
from confluent_kafka import TopicPartition
from typing_extensions import Self
+from .context import set_message_context, copy_context
+from .core.stream import Filtered
from .dataframe import StreamingDataFrame
from .error_callbacks import (
ConsumerErrorCallback,
@@ -335,7 +337,6 @@ def dataframe(
:return: `StreamingDataFrame` object
"""
sdf = StreamingDataFrame(topic=topic, state_manager=self._state_manager)
- sdf.consumer = self._consumer
sdf.producer = self._producer
return sdf
@@ -404,7 +405,7 @@ def run(
with exit_stack:
# Subscribe to topics in Kafka and start polling
self._consumer.subscribe(
- list(dataframe.topics_in.values()),
+ [dataframe.topic],
on_assign=self._on_assign,
on_revoke=self._on_revoke,
on_lost=self._on_lost,
@@ -412,6 +413,8 @@ def run(
logger.info("Waiting for incoming messages")
# Start polling Kafka for messages and callbacks
self._running = True
+
+ dataframe_compiled = dataframe.compile()
while self._running:
# Serve producer callbacks
self._producer.poll(self._producer_poll_timeout)
@@ -431,13 +434,21 @@ def run(
first_row.partition,
first_row.offset,
)
+ # Create a new contextvars.Context and set the current MessageContext
+ # (it's the same across multiple rows)
+ context = copy_context()
+ context.run(set_message_context, first_row.context)
with start_state_transaction(
topic=topic_name, partition=partition, offset=offset
):
for row in rows:
try:
- dataframe.process(row=row)
+ # Execute StreamingDataFrame in a context
+ context.run(dataframe_compiled, row.value)
+ except Filtered:
+ # The message was filtered by StreamingDataFrame
+ continue
except Exception as exc:
# TODO: This callback might be triggered because of Producer
# errors too because they happen within ".process()"
diff --git a/src/StreamingDataFrames/quixstreams/context.py b/src/StreamingDataFrames/quixstreams/context.py
new file mode 100644
index 000000000..2e4de4dbe
--- /dev/null
+++ b/src/StreamingDataFrames/quixstreams/context.py
@@ -0,0 +1,49 @@
+from contextvars import ContextVar, copy_context
+from typing import Optional, Any
+
+from quixstreams.exceptions import QuixException
+from quixstreams.models.messagecontext import MessageContext
+
+__all__ = (
+ "MessageContextNotSetError",
+ "set_message_context",
+ "message_key",
+ "message_context",
+ "copy_context",
+)
+
+_current_message_context = ContextVar("current_message_context")
+
+
+class MessageContextNotSetError(QuixException):
+ ...
+
+
+def set_message_context(context: Optional[MessageContext]):
+ """
+ Set a MessageContext for the current message in the given `contextvars.Context`
+
+ :param context: instance of `MessageContext`
+ """
+ _current_message_context.set(context)
+
+
+def message_context() -> MessageContext:
+ """
+ Get a MessageContext for the current message
+ :return: instance of `MessageContext`
+ """
+ try:
+ return _current_message_context.get()
+
+ except LookupError:
+ raise MessageContextNotSetError("Message context is not set")
+
+
+def message_key() -> Any:
+ """
+ Get current a message key.
+
+ :return: a deserialized message key
+ """
+ return message_context().key
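+
+
+# Usage sketch (illustrative): the Application sets the context per message,
+# and processing functions can then read it from anywhere:
+#   ctx = copy_context()
+#   ctx.run(set_message_context, some_message_context)  # a MessageContext instance
+#   ctx.run(message_key)  # returns some_message_context.key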
diff --git a/src/StreamingDataFrames/quixstreams/core/__init__.py b/src/StreamingDataFrames/quixstreams/core/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/StreamingDataFrames/quixstreams/core/stream/__init__.py b/src/StreamingDataFrames/quixstreams/core/stream/__init__.py
new file mode 100644
index 000000000..0c64aa38d
--- /dev/null
+++ b/src/StreamingDataFrames/quixstreams/core/stream/__init__.py
@@ -0,0 +1,2 @@
+from .stream import *
+from .functions import *
diff --git a/src/StreamingDataFrames/quixstreams/core/stream/functions.py b/src/StreamingDataFrames/quixstreams/core/stream/functions.py
new file mode 100644
index 000000000..bcc67be33
--- /dev/null
+++ b/src/StreamingDataFrames/quixstreams/core/stream/functions.py
@@ -0,0 +1,125 @@
+import enum
+import functools
+from typing import TypeVar, Callable, Protocol, Optional
+
+__all__ = (
+ "StreamCallable",
+ "SupportsBool",
+ "Filtered",
+ "Apply",
+ "Update",
+ "Filter",
+ "is_filter_function",
+ "is_update_function",
+ "is_apply_function",
+ "get_stream_function_type",
+)
+
+R = TypeVar("R")
+T = TypeVar("T")
+
+StreamCallable = Callable[[T], R]
+
+
+class SupportsBool(Protocol):
+ def __bool__(self) -> bool:
+ ...
+
+
+class Filtered(Exception):
+ ...
+
+
+class StreamFunctionType(enum.IntEnum):
+ FILTER = 1
+ UPDATE = 2
+ APPLY = 3
+
+
+_STREAM_FUNC_TYPE_ATTR = "__stream_function_type__"
+
+
+def get_stream_function_type(func: StreamCallable) -> Optional[StreamFunctionType]:
+ return getattr(func, _STREAM_FUNC_TYPE_ATTR, None)
+
+
+def set_stream_function_type(func: StreamCallable, type_: StreamFunctionType):
+ setattr(func, _STREAM_FUNC_TYPE_ATTR, type_)
+
+
+def is_filter_function(func: StreamCallable) -> bool:
+ func_type = get_stream_function_type(func)
+ return func_type is not None and func_type == StreamFunctionType.FILTER
+
+
+def is_update_function(func: StreamCallable) -> bool:
+ func_type = get_stream_function_type(func)
+ return func_type is not None and func_type == StreamFunctionType.UPDATE
+
+
+def is_apply_function(func: StreamCallable) -> bool:
+ func_type = get_stream_function_type(func)
+ return func_type is not None and func_type == StreamFunctionType.APPLY
+
+
+def Filter(func: Callable[[T], SupportsBool]) -> StreamCallable:
+ """
+    Wrap a function into a "Filter" function.
+    The result of a Filter function is interpreted as boolean.
+    If it's `True`, the input will be passed downstream.
+ If it's `False`, the `Filtered` exception will be raised to signal that the
+ value is filtered out.
+
+ :param func: a function to filter value
+ :return: a Filter function
+ """
+
+ def wrapper(value: T) -> T:
+ result = func(value)
+ if not result:
+ raise Filtered()
+ return value
+
+ wrapper = functools.update_wrapper(wrapper=wrapper, wrapped=func)
+ set_stream_function_type(wrapper, StreamFunctionType.FILTER)
+ return wrapper
+
+
+def Update(func: StreamCallable) -> StreamCallable:
+ """
+ Wrap a function into "Update" function.
+
+ The provided function is expected to mutate the value.
+ Its result will always be ignored, and its input is passed
+ downstream.
+
+ :param func: a function to mutate values
+ :return: an Update function
+ """
+
+ def wrapper(value: T) -> T:
+ func(value)
+ return value
+
+ wrapper = functools.update_wrapper(wrapper=wrapper, wrapped=func)
+ set_stream_function_type(wrapper, StreamFunctionType.UPDATE)
+ return wrapper
+
+
+def Apply(func: StreamCallable) -> StreamCallable:
+ """
+ Wrap a function into "Apply" function.
+
+ The provided function is expected to return a new value based on input,
+ and its result will always be passed downstream.
+
+ :param func: a function to generate a new value
+ :return: an Apply function
+ """
+
+ def wrapper(value: T) -> R:
+ return func(value)
+
+ wrapper = functools.update_wrapper(wrapper=wrapper, wrapped=func)
+ set_stream_function_type(wrapper, StreamFunctionType.APPLY)
+ return wrapper
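+
+
+# Usage sketch (illustrative):
+#   inc = Apply(lambda v: v + 1)
+#   positive = Filter(lambda v: v > 0)
+#   positive(inc(0))  # -> 1; Filter raises Filtered() when the result is falsy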
diff --git a/src/StreamingDataFrames/quixstreams/core/stream/stream.py b/src/StreamingDataFrames/quixstreams/core/stream/stream.py
new file mode 100644
index 000000000..87280fe10
--- /dev/null
+++ b/src/StreamingDataFrames/quixstreams/core/stream/stream.py
@@ -0,0 +1,239 @@
+import copy
+import functools
+import itertools
+from typing import Any, List, Callable, Optional, TypeVar
+
+from typing_extensions import Self
+
+from .functions import *
+
+__all__ = ("Stream",)
+
+R = TypeVar("R")
+T = TypeVar("T")
+
+
+class Stream:
+ def __init__(
+ self,
+ func: Optional[Callable[[Any], Any]] = None,
+ parent: Optional[Self] = None,
+ ):
+ """
+ A base class for all streaming operations.
+
+ `Stream` is an abstraction of a function pipeline.
+ Each Stream has a function and a parent (None by default).
+        When adding a new function to the stream, it creates a new `Stream` object and
+        sets "parent" to the previous `Stream` to maintain the order of execution.
+
+        Streams support 3 types of functions:
+ - "Apply" - generate new values based on a previous one.
+ The result of an Apply function is passed downstream to the next functions.
+ - "Update" - update values in-place.
+ The result of an Update function is always ignored, and its input is passed
+ downstream.
+ - "Filter" - to filter values from the Stream.
+ The result of a Filter function is interpreted as boolean.
+ If it's `True`, the input will be passed downstream.
+ If it's `False`, the `Filtered` exception will be raised to signal that the
+ value is filtered out.
+
+ To execute functions accumulated on the `Stream` instance, it must be compiled
+        using the `.compile()` method, which returns a function closure executing all the
+ functions in the tree.
+
+ :param func: a function to be called on the stream.
+ It is expected to be wrapped into one of "Apply", "Filter" or "Update" from
+ `quixstreams.core.stream.functions` package.
+ Default - "Apply(lambda v: v)".
+ :param parent: a parent `Stream`
+ """
+ if func is not None and not any(
+ (
+ is_apply_function(func),
+ is_update_function(func),
+ is_filter_function(func),
+ )
+ ):
+ raise ValueError(
+ "Provided function must be either Apply, Filter or Update function"
+ )
+
+ self.func = func if func is not None else Apply(lambda x: x)
+ self.parent = parent
+
+ def __repr__(self) -> str:
+ """
+ Generate a nice repr with all functions in the stream and its parents.
+
+        :return: a string of format
+            "<Stream [<total functions>]: <FuncType: func_name> | ... >"
+ """
+ tree_funcs = [s.func for s in self.tree()]
+ funcs_repr = " | ".join(
+ (
+ f"<{get_stream_function_type(f).name}: {f.__qualname__}>"
+ for f in tree_funcs
+ )
+ )
+ return f"<{self.__class__.__name__} [{len(tree_funcs)}]: {funcs_repr}>"
+
+ def add_filter(self, func: Callable[[T], R]) -> Self:
+ """
+ Add a function to filter values from the Stream.
+
+ The return value of the function will be interpreted as `bool`.
+ If the function returns `False`-like result, the Stream will raise `Filtered`
+ exception during execution.
+
+ :param func: a function to filter values from the stream
+ :return: a new `Stream` derived from the current one
+ """
+ return self._add(Filter(func))
+
+ def add_apply(self, func: Callable[[T], R]) -> Self:
+ """
+ Add an "apply" function to the Stream.
+
+ The function is supposed to return a new value, which will be passed
+ further during execution.
+
+ :param func: a function to generate a new value
+ :return: a new `Stream` derived from the current one
+ """
+ return self._add(Apply(func))
+
+ def add_update(self, func: Callable[[T], object]) -> Self:
+ """
+ Add an "update" function to the Stream, that will mutate the input value.
+
+ The return of this function will be ignored and its input
+ will be passed downstream.
+
+ :param func: a function to mutate the value
+ :return: a new Stream derived from the current one
+ """
+        return self._add(Update(func))
+
+ def diff(
+ self,
+ other: "Stream",
+ ) -> Self:
+ """
+ Takes the difference between Streams `self` and `other` based on their last
+ common parent, and returns a new `Stream` that includes only this difference.
+
+ It's impossible to calculate a diff when:
+ - Streams don't have a common parent.
+        - The `self` Stream already includes all the nodes from
+        the `other` Stream, and the resulting diff is empty.
+
+ :param other: a `Stream` to take a diff from.
+ :raises ValueError: if Streams don't have a common parent
+ or if the diff is empty.
+ :return: new `Stream` instance including all the Streams from the diff
+ """
+ diff = self._diff_from_last_common_parent(other)
+ parent = None
+ head = None
+ for node in diff:
+ # Copy the node to ensure we don't alter the previously created Nodes
+ node = copy.deepcopy(node)
+ node.parent = parent
+ parent = node
+ head = node
+ return head
+
+ def tree(self) -> List[Self]:
+ """
+ Return a list of all parent Streams including the node itself.
+
+        The tree is ordered from parent to child (the root node comes first).
+ :return: a list of `Stream` objects
+ """
+ tree_ = [self]
+ node = self
+ while node.parent:
+ tree_.insert(0, node.parent)
+ node = node.parent
+ return tree_
+
+ def compile(
+ self,
+ allow_filters: bool = True,
+ allow_updates: bool = True,
+ ) -> Callable[[T], R]:
+ """
+ Compile a list of functions from this `Stream` and its parents into a single
+ big closure using a "compiler" function.
+
+ Closures are more performant than calling all the functions in the
+ `Stream.tree()` one-by-one.
+
+ :param allow_filters: If False, this function will fail with ValueError if
+ the stream has filter functions in the tree. Default - True.
+ :param allow_updates: If False, this function will fail with ValueError if
+ the stream has update functions in the tree. Default - True.
+
+ :raises ValueError: if disallowed functions are present in the stream tree.
+ """
+ compiled = None
+ tree = self.tree()
+ for node in tree:
+ func = node.func
+ if not allow_updates and is_update_function(func):
+ raise ValueError("Update functions are not allowed in this stream")
+ elif not allow_filters and is_filter_function(func):
+ raise ValueError("Filter functions are not allowed in this stream")
+
+ if compiled is None:
+ compiled = func
+ else:
+ compiled = compiler(func)(compiled)
+
+ return compiled
+
+ def _diff_from_last_common_parent(self, other: Self) -> List[Self]:
+ nodes_self = self.tree()
+ nodes_other = other.tree()
+
+ diff = []
+ last_common_parent = None
+ for node_self, node_other in itertools.zip_longest(nodes_self, nodes_other):
+ if node_self is node_other:
+ last_common_parent = node_other
+ elif node_other is not None:
+ diff.append(node_other)
+
+ if last_common_parent is None:
+ raise ValueError("Common parent not found")
+ if not diff:
+ raise ValueError("The diff is empty")
+ return diff
+
+ def _add(self, func: Callable[[T], R]) -> Self:
+ return self.__class__(func=func, parent=self)
+
+
+def compiler(outer_func: Callable[[T], R]) -> Callable[[T], R]:
+ """
+ A function that wraps two other functions into a closure.
+    It passes the result of the inner function as an input to the outer function.
+
+ It is used to transform (aka "compile") a list of functions into one large closure
+ like this:
+ ```
+ [func, func, func] -> func(func(func()))
+ ```
+
+ :return: a function with one argument (value)
+ """
+
+ def wrapper(inner_func: Callable[[T], R]):
+ def _wrapper(v: T) -> R:
+ return outer_func(inner_func(v))
+
+ return functools.update_wrapper(_wrapper, inner_func)
+
+ return functools.update_wrapper(wrapper, outer_func)
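+
+
+# Usage sketch (illustrative): composing and compiling a Stream
+#   stream = Stream().add_apply(lambda v: v + 1).add_filter(lambda v: v > 0)
+#   compiled = stream.compile()
+#   compiled(1)  # -> 2: Apply(+1) runs first, then the Filter passes the value through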
diff --git a/src/StreamingDataFrames/quixstreams/dataframe/__init__.py b/src/StreamingDataFrames/quixstreams/dataframe/__init__.py
index 61bb62ada..b6f5296e8 100644
--- a/src/StreamingDataFrames/quixstreams/dataframe/__init__.py
+++ b/src/StreamingDataFrames/quixstreams/dataframe/__init__.py
@@ -1,2 +1,2 @@
from .dataframe import StreamingDataFrame
-from .exceptions import *
+from .series import *
diff --git a/src/StreamingDataFrames/quixstreams/dataframe/base.py b/src/StreamingDataFrames/quixstreams/dataframe/base.py
new file mode 100644
index 000000000..de133fab9
--- /dev/null
+++ b/src/StreamingDataFrames/quixstreams/dataframe/base.py
@@ -0,0 +1,20 @@
+import abc
+from typing import Optional, Any
+
+from quixstreams.core.stream import Stream, StreamCallable
+from quixstreams.models.messagecontext import MessageContext
+
+
+class BaseStreaming:
+ @property
+ @abc.abstractmethod
+ def stream(self) -> Stream:
+ ...
+
+ @abc.abstractmethod
+ def compile(self, *args, **kwargs) -> StreamCallable:
+ ...
+
+ @abc.abstractmethod
+ def test(self, value: Any, ctx: Optional[MessageContext] = None) -> Any:
+ ...
diff --git a/src/StreamingDataFrames/quixstreams/dataframe/column.py b/src/StreamingDataFrames/quixstreams/dataframe/column.py
deleted file mode 100644
index 8794c592d..000000000
--- a/src/StreamingDataFrames/quixstreams/dataframe/column.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import operator
-from typing import Optional, Any, Callable, Container
-
-from typing_extensions import Self, TypeAlias, Union
-
-from ..models import Row, MessageContext
-
-ColumnApplier: TypeAlias = Callable[[Any, MessageContext], Any]
-
-__all__ = ("Column", "ColumnApplier")
-
-
-def invert(value):
- if isinstance(value, bool):
- return operator.not_(value)
- else:
- return operator.invert(value)
-
-
-class Column:
- def __init__(
- self,
- col_name: Optional[str] = None,
- _eval_func: Optional[ColumnApplier] = None,
- ):
- self.col_name = col_name
- self._eval_func = _eval_func if _eval_func else lambda row: row[self.col_name]
-
- def __getitem__(self, item: Union[str, int]) -> Self:
- return self.__class__(_eval_func=lambda x: self.eval(x)[item])
-
- def _operation(self, other: Any, op: Callable[[Any, Any], Any]) -> Self:
- return self.__class__(
- _eval_func=lambda x: op(
- self.eval(x), other.eval(x) if isinstance(other, Column) else other
- ),
- )
-
- def eval(self, row: Row) -> Any:
- """
- Execute all the functions accumulated on this Column.
-
- :param row: A Quixstreams Row
- :return: A primitive type
- """
- return self._eval_func(row)
-
- def apply(self, func: ColumnApplier) -> Self:
- """
- Add a callable to the execution list for this column.
-
- The provided callable should accept a single argument, which will be its input.
- The provided callable should similarly return one output, or None
-
- :param func: a callable with one argument and one output
- :return: a new Column with the new callable added
- """
- return Column(_eval_func=lambda row: func(self.eval(row), row.context))
-
- def isin(self, other: Container) -> Self:
- return self._operation(other, lambda a, b: operator.contains(b, a))
-
- def contains(self, other: Any) -> Self:
- return self._operation(other, operator.contains)
-
- def is_(self, other: Any) -> Self:
- """
- Check if column value refers to the same object as `other`
- :param other: object to check for "is"
- :return:
- """
- return self._operation(other, operator.is_)
-
- def isnot(self, other: Any) -> Self:
- """
- Check if column value refers to the same object as `other`
- :param other: object to check for "is"
- :return:
- """
- return self._operation(other, operator.is_not)
-
- def isnull(self) -> Self:
- """
- Check if column value is None
- """
- return self._operation(None, operator.is_)
-
- def notnull(self) -> Self:
- """
- Check if column value is not None
- """
- return self._operation(None, operator.is_not)
-
- def abs(self) -> Self:
- """
- Get absolute value of the Column value
- """
- return self.apply(lambda v, ctx: abs(v))
-
- def __and__(self, other: Any) -> Self:
- return self._operation(other, operator.and_)
-
- def __or__(self, other: Any) -> Self:
- return self._operation(other, operator.or_)
-
- def __mod__(self, other: Any) -> Self:
- return self._operation(other, operator.mod)
-
- def __add__(self, other: Any) -> Self:
- return self._operation(other, operator.add)
-
- def __sub__(self, other: Any) -> Self:
- return self._operation(other, operator.sub)
-
- def __mul__(self, other: Any) -> Self:
- return self._operation(other, operator.mul)
-
- def __truediv__(self, other: Any) -> Self:
- return self._operation(other, operator.truediv)
-
- def __eq__(self, other: Any) -> Self:
- return self._operation(other, operator.eq)
-
- def __ne__(self, other: Any) -> Self:
- return self._operation(other, operator.ne)
-
- def __lt__(self, other: Any) -> Self:
- return self._operation(other, operator.lt)
-
- def __le__(self, other: Any) -> Self:
- return self._operation(other, operator.le)
-
- def __gt__(self, other: Any) -> Self:
- return self._operation(other, operator.gt)
-
- def __ge__(self, other: Any) -> Self:
- return self._operation(other, operator.ge)
-
- def __invert__(self) -> Self:
- return self.__class__(_eval_func=lambda x: invert(self.eval(x)))
diff --git a/src/StreamingDataFrames/quixstreams/dataframe/dataframe.py b/src/StreamingDataFrames/quixstreams/dataframe/dataframe.py
index 79383dd53..6ce43e9d4 100644
--- a/src/StreamingDataFrames/quixstreams/dataframe/dataframe.py
+++ b/src/StreamingDataFrames/quixstreams/dataframe/dataframe.py
@@ -1,229 +1,109 @@
-import uuid
-from typing import Optional, Callable, Union, List, Mapping, Any
-
-from typing_extensions import Self, TypeAlias
-
-from .column import Column
-from .exceptions import InvalidApplyResultType
-from .pipeline import Pipeline
-from ..models import Row, Topic, MessageContext
-from ..rowconsumer import RowConsumerProto
-from ..rowproducer import RowProducerProto
-from ..state import State, StateStoreManager
-
-ApplyFunc: TypeAlias = Callable[
- [dict, MessageContext], Optional[Union[dict, List[dict]]]
-]
-StatefulApplyFunc: TypeAlias = Callable[
- [dict, MessageContext, State], Optional[Union[dict, List[dict]]]
-]
-KeyFunc: TypeAlias = Callable[[dict, MessageContext], Any]
-
-__all__ = ("StreamingDataFrame",)
-
-
-def subset(keys: List[str], row: Row) -> Row:
- row.value = row[keys]
- return row
-
-
-def setitem(k: str, v: Any, row: Row) -> Row:
- row[k] = v.eval(row) if isinstance(v, Column) else v
- return row
-
-
-def apply(
- row: Row,
- func: Union[ApplyFunc, StatefulApplyFunc],
- state_manager: Optional[StateStoreManager] = None,
-) -> Union[Row, List[Row]]:
- # Providing state to the function if state_manager is passed
- if state_manager is not None:
- transaction = state_manager.get_store_transaction()
- # Prefix all the state keys by the message key
- with transaction.with_prefix(prefix=row.key):
- # Pass a State object with an interface limited to the key updates only
- result = func(row.value, row.context, transaction.state)
- else:
- result = func(row.value, row.context)
-
- if result is None and isinstance(row.value, dict):
- # Function returned None, assume it changed the incoming dict in-place
- return row
- if isinstance(result, dict):
- # Function returned dict, assume it's a new value for the Row
- row.value = result
- return row
- raise InvalidApplyResultType(
- f"Only 'dict' or 'NoneType' (in-place modification) allowed, not {type(result)}"
- )
-
-
-class StreamingDataFrame:
- """
- Allows you to define transformations on a kafka message as if it were a Pandas
- DataFrame.
- Currently, it implements a small subset of the Pandas interface, along with
- some differences/accommodations for kafka-specific functionality.
-
- A `StreamingDataFrame` expects to interact with a QuixStreams `Row`, which is
- interacted with like a dictionary.
-
- Unlike pandas, you will not get an immediate output from any given operation;
- instead, the command is permanently added to the `StreamingDataFrame`'s
- "pipeline". You can then execute this pipeline indefinitely on a `Row` like so:
-
- ```
- df = StreamingDataframe(topic)
- df = df.apply(lambda row: row) # do stuff; must return the row back!
- for row_obj in [row_0, row_1]:
- print(df.process(row_obj))
- ```
-
- Note that just like Pandas, you can "filter" out rows with your operations, like:
- ```
- df = df[df['column_b'] >= 5]
- ```
- If a processing step nulls the Row in some way, all further processing on that
- row (including kafka operations, besides committing) will be skipped.
-
- There is a :class:`quixstreams.app.Application` class that can manage
- the kafka-specific dependencies of `StreamingDataFrame`;
- it is recommended you hand your `StreamingDataFrame` instance to an `Application`
- instance when interacting with Kafka.
-
- Below is a larger example of a `StreamingDataFrame`
- (that you'd hand to an `Application`):
-
- ```
- # Define your processing steps
- # Remove column_a, add 1 to columns b and c, skip row if b+1 >= 5, else publish row
- df = StreamingDataframe()
- df = df[['column_b', 'column_c']]
- df = df.apply(lambda row: row[key] + 1 if key in ['column_b', 'column_c'])
- df = df[df['column_b'] + 1] >= 5]
- df.to_topic('my_output_topic')
-
- # Incomplete Rows to showcase what data you are actually interacting with
- record_0 = Row(value={'column_a': 'a_string', 'column_b': 3, 'column_c': 5})
- record_1 = Row(value={'column_a': 'a_string', 'column_b': 1, 'column_c': 10})
-
- # process records
- df.process(record_0) # produces {'column_b': 4, 'column_c': 6} to "my_output_topic"
- df.process(record_1) # filters row, does NOT produce to "my_output_topic"
- ```
- """
-
+import contextvars
+import functools
+import operator
+from typing import Optional, Callable, Union, List, TypeVar, Any
+
+from typing_extensions import Self
+
+from quixstreams.context import (
+ message_context,
+ set_message_context,
+ message_key,
+)
+from quixstreams.core.stream import StreamCallable, Stream
+from quixstreams.models import Topic, Row, MessageContext
+from quixstreams.rowproducer import RowProducerProto
+from quixstreams.state import StateStoreManager, State
+from .base import BaseStreaming
+from .series import StreamingSeries
+
+T = TypeVar("T")
+R = TypeVar("R")
+DataFrameFunc = Callable[[T], R]
+DataFrameStatefulFunc = Callable[[T, State], R]
+
+
+class StreamingDataFrame(BaseStreaming):
def __init__(
self,
topic: Topic,
state_manager: StateStoreManager,
+ stream: Optional[Stream] = None,
):
- self._id = str(uuid.uuid4())
- self._pipeline = Pipeline(_id=self.id)
- self._real_consumer: Optional[RowConsumerProto] = None
+ self._stream: Stream = stream or Stream()
+ self._topic = topic
self._real_producer: Optional[RowProducerProto] = None
- self._topics_in = {topic.name: topic}
- self._topics_out = {}
self._state_manager = state_manager
+ @property
+ def stream(self) -> Stream:
+ return self._stream
+
+ @property
+ def topic(self) -> Topic:
+ return self._topic
+
def apply(
- self,
- func: Union[ApplyFunc, StatefulApplyFunc],
- stateful: bool = False,
+ self, func: Union[DataFrameFunc, DataFrameStatefulFunc], stateful: bool = False
) -> Self:
"""
- Apply a custom function with current value
- and :py:class:`quixstreams.models.context.MessageContext`
- as the expected input.
-
- :param func: a callable which accepts 2 arguments:
- - value - a dict with fields and values for the current Row
- - a context - an instance of :py:class:`quixstreams.models.context.MessageContext`
- which contains message metadata like key, timestamp, partition,
- and more.
-
- .. note:: if `stateful=True` is passed, a third argument of type `State`
- will be provided to the function.
+ Apply a function to transform the value and return a new value.
- The custom function may return:
- - a new dict to replace the current Row value
- - `None` to modify the current Row value in-place
+ The result will be passed downstream as an input value.
- :param stateful: if `True`, the function will be provided with 3rd argument
+ :param func: a function to apply
+ :param stateful: if `True`, the function will be provided with a second argument
of type `State` to perform stateful operations.
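+
+        Example (a minimal sketch; assumes `sdf` is a `StreamingDataFrame`):
+
+        ```
+        # Extract a single field from the value dict
+        sdf = sdf.apply(lambda value: value["field"])
+        ```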
-
- :return: current instance of `StreamingDataFrame`
"""
-
if stateful:
- # Register the default store for each input topic
- for topic in self._topics_in.values():
- self._state_manager.register_store(topic_name=topic.name)
- return self._apply(
- lambda row: apply(row, func, state_manager=self._state_manager)
- )
- return self._apply(lambda row: apply(row, func))
+ self._register_store()
+ func = _as_stateful(func=func, state_manager=self._state_manager)
- def process(self, row: Row) -> Optional[Union[Row, List[Row]]]:
- """
- Execute the previously defined StreamingDataframe operations on a provided Row.
- :param row: a QuixStreams Row object
- :return: Row, list of Rows, or None (if filtered)
- """
- return self._pipeline.process(row)
+ stream = self.stream.add_apply(func)
+ return self._clone(stream=stream)
- def to_topic(self, topic: Topic, key: Optional[KeyFunc] = None) -> Self:
+ def update(
+ self, func: Union[DataFrameFunc, DataFrameStatefulFunc], stateful: bool = False
+ ) -> Self:
"""
- Produce a row to a desired topic.
- Note that a producer must be assigned on the `StreamingDataFrame` if not using
- :class:`quixstreams.app.Application` class to facilitate
- the execution of StreamingDataFrame.
+        Apply a function to mutate the value in-place or to perform a side effect
+ that doesn't update the value (e.g. print a value to the console).
- :param topic: A QuixStreams `Topic`
- :param key: a callable to generate a new message key, optional.
- If passed, the return type of this callable must be serializable
- by `key_serializer` defined for this Topic object.
- By default, the current message key will be used.
+ The result of the function will be ignored, and the original value will be
+ passed downstream.
- :return: self (StreamingDataFrame)
+ :param func: function to update value
+ :param stateful: if `True`, the function will be provided with a second argument
+ of type `State` to perform stateful operations.
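+
+        Example (a minimal sketch; assumes `sdf` is a `StreamingDataFrame`):
+
+        ```
+        # Print each value as a side effect; the original value is passed on
+        sdf = sdf.update(lambda value: print(value))
+        ```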
"""
- self._topics_out[topic.name] = topic
- return self._apply(
- lambda row: self._produce(
- topic, row, key=key(row.value, row.context) if key else None
- )
- )
+ if stateful:
+ self._register_store()
+ func = _as_stateful(func=func, state_manager=self._state_manager)
- @property
- def id(self) -> str:
- return self._id
+ stream = self.stream.add_update(func)
+ return self._clone(stream=stream)
- @property
- def topics_in(self) -> Mapping[str, Topic]:
- """
- Get a mapping with Topics for the StreamingDataFrame input topics
- :return: dict of {: }
+ def filter(
+ self, func: Union[DataFrameFunc, DataFrameStatefulFunc], stateful: bool = False
+ ) -> Self:
"""
- return self._topics_in
+        Filter the value using the provided function.
- @property
- def topics_out(self) -> Mapping[str, Topic]:
- """
- Get a mapping with Topics for the StreamingDataFrame output topics
- :return: dict of {: }
+        If the function returns a True-like value, the original value will be
+ passed downstream.
+ Otherwise, the `Filtered` exception will be raised.
+
+ :param func: function to filter value
+        :param stateful: if `True`, the function will be provided with a second argument
+ of type `State` to perform stateful operations.
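+
+        Example (a minimal sketch; assumes `sdf` is a `StreamingDataFrame`):
+
+        ```
+        # Keep only values where "field" is positive
+        sdf = sdf.filter(lambda value: value["field"] > 0)
+        ```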
"""
- return self._topics_out
- @property
- def consumer(self) -> RowConsumerProto:
- if self._real_consumer is None:
- raise RuntimeError("Consumer instance has not been provided")
- return self._real_consumer
+ if stateful:
+ self._register_store()
+ func = _as_stateful(func=func, state_manager=self._state_manager)
- @consumer.setter
- def consumer(self, consumer: RowConsumerProto):
- self._real_consumer = consumer
+ stream = self.stream.add_filter(func)
+ return self._clone(stream=stream)
@property
def producer(self) -> RowProducerProto:
@@ -236,7 +116,7 @@ def producer(self, producer: RowProducerProto):
self._real_producer = producer
@staticmethod
- def contains(key: str) -> Column:
+ def contains(key: str) -> StreamingSeries:
"""
Check if the key is present in the Row value.
@@ -249,38 +129,127 @@ def contains(key: str) -> Column:
# This would add a new column 'has_column' which contains boolean values
# indicating the presence of 'column_x' in each row.
"""
- return Column(_eval_func=lambda row: key in row.keys())
+ return StreamingSeries.from_func(lambda value: key in value)
+
+ def to_topic(
+ self, topic: Topic, key: Optional[Callable[[object], object]] = None
+ ) -> Self:
+ """
+        Produce the value to the topic.
+
+ .. note:: A `RowProducer` instance must be assigned to
+ `StreamingDataFrame.producer` if not using :class:`quixstreams.app.Application`
+ to facilitate the execution of StreamingDataFrame.
+
+ :param topic: instance of `Topic`
+ :param key: a callable to generate a new message key, optional.
+ If passed, the return type of this callable must be serializable
+ by `key_serializer` defined for this Topic object.
+ By default, the current message key will be used.
+
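+        Example (a minimal sketch; assumes `output_topic` is a `Topic` instance):
+
+        ```
+        # Produce each value to the output topic, keyed by its "id" field
+        sdf = sdf.to_topic(output_topic, key=lambda value: value["id"])
+        ```
+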
+ """
+ return self.update(
+ lambda value: self._produce(topic, value, key=key(value) if key else None)
+ )
+
+ def compile(self) -> StreamCallable:
+ """
+ Compile all functions of this StreamingDataFrame into one big closure.
+
+ Closures are more performant than calling all the functions in the
+ `StreamingDataFrame` one-by-one.
+
+ :return: a function that accepts "value"
+ and returns a result of StreamingDataFrame
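+
+        Example (a minimal sketch; assumes `sdf` is a `StreamingDataFrame`):
+
+        ```
+        compiled = sdf.compile()
+        result = compiled({"field": 1})
+        ```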
+ """
+ return self.stream.compile()
+
+ def test(self, value: object, ctx: Optional[MessageContext] = None) -> Any:
+ """
+        A shorthand to test `StreamingDataFrame` with the provided value
+ and `MessageContext`.
+
+ :param value: value to pass through `StreamingDataFrame`
+ :param ctx: instance of `MessageContext`, optional.
+        Provide it if the StreamingDataFrame instance calls `to_topic()`,
+        has stateful functions, or functions calling `message_key()`.
+ Default - `None`.
+
+ :return: result of `StreamingDataFrame`
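+
+        Example (a minimal sketch; assumes `sdf` applies `lambda v: v["x"] + 1`):
+
+        ```
+        assert sdf.test({"x": 1}) == 2
+        ```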
+ """
+ context = contextvars.copy_context()
+ context.run(set_message_context, ctx)
+ compiled = self.compile()
+ return context.run(compiled, value)
+
+ def _clone(self, stream: Stream) -> Self:
+ clone = self.__class__(
+ stream=stream, topic=self._topic, state_manager=self._state_manager
+ )
+ if self._real_producer is not None:
+ clone.producer = self._real_producer
+ return clone
+
+ def _produce(self, topic: Topic, value: object, key: Optional[object] = None):
+ ctx = message_context()
+ key = key or ctx.key
+ row = Row(value=value, context=ctx) # noqa
+ self.producer.produce_row(row, topic, key=key)
- def __setitem__(self, key: str, value: Any):
- self._apply(lambda row: setitem(key, value, row))
+ def _register_store(self):
+ """
+        Register the default store for the input topic in StateStoreManager
+ """
+ self._state_manager.register_store(topic_name=self._topic.name)
+
+ def __setitem__(self, key, value: Union[Self, object]):
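+        # Set a key on the value dict.
+        # The assigned value may be another StreamingDataFrame (only its
+        # diverged operations are evaluated), a StreamingSeries, or a constant.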
+ if isinstance(value, self.__class__):
+ diff = self.stream.diff(value.stream)
+ diff_compiled = diff.compile(allow_filters=False, allow_updates=False)
+ stream = self.stream.add_update(
+ lambda v: operator.setitem(v, key, diff_compiled(v))
+ )
+ elif isinstance(value, StreamingSeries):
+ value_compiled = value.compile(allow_filters=False, allow_updates=False)
+ stream = self.stream.add_update(
+ lambda v: operator.setitem(v, key, value_compiled(v))
+ )
+ else:
+ stream = self.stream.add_update(lambda v: operator.setitem(v, key, value))
+ self._stream = stream
def __getitem__(
- self, item: Union[str, List[str], Column, Self]
- ) -> Union[Column, Self]:
- if isinstance(item, Column):
- return self._apply(lambda row: row if item.eval(row) else None)
+ self, item: Union[str, List[str], StreamingSeries, Self]
+ ) -> Union[Self, StreamingSeries]:
+ if isinstance(item, StreamingSeries):
+ # Filter SDF based on StreamingSeries
+ item_compiled = item.compile(allow_filters=False, allow_updates=False)
+ return self.filter(lambda v: item_compiled(v))
+ elif isinstance(item, self.__class__):
+ # Filter SDF based on another SDF
+ diff = self.stream.diff(item.stream)
+ diff_compiled = diff.compile(allow_filters=False, allow_updates=False)
+ return self.filter(lambda v: diff_compiled(v))
elif isinstance(item, list):
- return self._apply(lambda row: subset(item, row))
- elif isinstance(item, StreamingDataFrame):
- # TODO: Implement filtering based on another SDF
- raise ValueError(
- "Filtering based on StreamingDataFrame is not supported yet."
- )
+ # Take only certain keys from the dict and return a new dict
+ return self.apply(lambda v: {k: v[k] for k in item})
+ elif isinstance(item, str):
+ # Create a StreamingSeries based on key
+ return StreamingSeries(name=item)
else:
- return Column(col_name=item)
+ raise TypeError(f'Unsupported key type "{type(item)}"')
- def _produce(self, topic: Topic, row: Row, key: Optional[Any] = None) -> Row:
- self.producer.produce_row(row, topic, key=key)
- return row
- def _apply(self, func: Callable[[Row], Optional[Union[Row, List[Row]]]]) -> Self:
- """
- Add a callable to the StreamingDataframe execution list.
- The provided callable should accept and return a Quixstreams Row; exceptions
- to this include a user's .apply() function, or a "filter" that returns None.
+def _as_stateful(
+ func: DataFrameStatefulFunc, state_manager: StateStoreManager
+) -> DataFrameFunc:
+ @functools.wraps(func)
+ def wrapper(value: object) -> object:
+ transaction = state_manager.get_store_transaction()
+ key = message_key()
+ # Prefix all the state keys by the message key
+ with transaction.with_prefix(prefix=key):
+ # Pass a State object with an interface limited to the key updates only
+ return func(value, transaction.state)
- :param func: callable that accepts and (usually) returns a QuixStreams Row
- :return: self (StreamingDataFrame)
- """
- self._pipeline.apply(func)
- return self
+ return wrapper
diff --git a/src/StreamingDataFrames/quixstreams/dataframe/exceptions.py b/src/StreamingDataFrames/quixstreams/dataframe/exceptions.py
deleted file mode 100644
index c05a016b4..000000000
--- a/src/StreamingDataFrames/quixstreams/dataframe/exceptions.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from quixstreams import exceptions
-
-__all__ = ("InvalidApplyResultType",)
-
-
-class InvalidApplyResultType(exceptions.QuixException):
- ...
diff --git a/src/StreamingDataFrames/quixstreams/dataframe/pipeline.py b/src/StreamingDataFrames/quixstreams/dataframe/pipeline.py
deleted file mode 100644
index 125acfe96..000000000
--- a/src/StreamingDataFrames/quixstreams/dataframe/pipeline.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import logging
-import uuid
-from typing import Optional, Callable, Any, List
-from typing_extensions import Self
-from ..models import Row
-
-logger = logging.getLogger(__name__)
-
-__all__ = ("PipelineFunction", "Pipeline")
-
-
-class PipelineFunction:
- def __init__(self, func: Callable):
- self._id = str(uuid.uuid4())
- self._func = func
-
- @property
- def id(self) -> str:
- return self._id
-
- def __repr__(self):
- return f'<{self.__class__.__name__} "{repr(self._func)}">'
-
- def __call__(self, row: Row) -> Row:
- return self._func(row)
-
-
-class Pipeline:
- def __init__(self, functions: List[PipelineFunction] = None, _id: str = None):
- self._id = _id or str(uuid.uuid4())
- self._functions = functions or []
-
- @property
- def functions(self) -> List[PipelineFunction]:
- return self._functions
-
- @property
- def id(self) -> str:
- return self._id
-
- def apply(self, func: Callable) -> Self:
- """
- Add a callable to the Pipeline execution list.
- The provided callable should accept a single argument, which will be its input.
- The provided callable should similarly return one output, or None
-
- Note that if the expected input is a list, this function will be called with
- each element in that list rather than the list itself.
-
- :param func: callable that accepts and (usually) returns an object
- :return: self (Pipeline)
- """
- self.functions.append(PipelineFunction(func=func))
- return self
-
- def process(self, data: Any) -> Optional[Any]:
- """
- Execute the previously defined Pipeline functions on a provided input, `data`.
-
- Note that if `data` is a list, each function will be called with each element
- in that list rather than the list itself.
-
- :param data: any object, but usually a QuixStreams Row
- :return: an object OR list of objects OR None (if filtered)
- """
- # TODO: maybe have an arg that allows passing "blacklisted" result types
- # or SDF inspects each result somehow?
- result = data
- for func in self.functions:
- if isinstance(result, list):
- result = [
- fd for fd in (func(d) for d in result) if fd is not None
- ] or None
- else:
- result = func(result)
- if result is None:
- logger.debug(
- "Pipeline {pid} processing step returned a None; "
- "terminating processing".format(pid=self.id)
- )
- break
- return result
diff --git a/src/StreamingDataFrames/quixstreams/dataframe/series.py b/src/StreamingDataFrames/quixstreams/dataframe/series.py
new file mode 100644
index 000000000..8c5a479c8
--- /dev/null
+++ b/src/StreamingDataFrames/quixstreams/dataframe/series.py
@@ -0,0 +1,234 @@
+import contextvars
+import operator
+from typing import Optional, Union, Callable, Container, Any
+
+from typing_extensions import Self
+
+from quixstreams.context import set_message_context
+from quixstreams.core.stream.functions import StreamCallable, Apply
+from quixstreams.core.stream.stream import Stream
+from quixstreams.models.messagecontext import MessageContext
+from .base import BaseStreaming
+
+__all__ = ("StreamingSeries",)
+
+
+class StreamingSeries(BaseStreaming):
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ stream: Optional[Stream] = None,
+ ):
+ if not (name or stream):
+ raise ValueError('Either "name" or "stream" must be passed')
+ self._stream = stream or Stream(func=Apply(lambda v: v[name]))
+
+ @classmethod
+ def from_func(cls, func: StreamCallable) -> Self:
+ """
+        Create a StreamingSeries from a function.
+
+        The provided function will be wrapped into `Apply`.
+ :param func: a function to apply
+ :return: instance of `StreamingSeries`
+ """
+ return cls(stream=Stream(Apply(func)))
+
+ @property
+ def stream(self) -> Stream:
+ return self._stream
+
+ def apply(self, func: StreamCallable) -> Self:
+ """
+ Add a callable to the execution list for this series.
+
+ The provided callable should accept a single argument, which will be its input.
+        The provided callable should similarly return one output, or None.
+
+ :param func: a callable with one argument and one output
+ :return: a new `StreamingSeries` with the new callable added
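+
+        Example (a minimal sketch):
+
+        ```
+        series = StreamingSeries("field").apply(lambda v: v + 1)
+        ```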
+ """
+ child = self._stream.add_apply(func)
+ return self.__class__(stream=child)
+
+ def compile(
+ self,
+ allow_filters: bool = True,
+ allow_updates: bool = True,
+ ) -> StreamCallable:
+ """
+ Compile all functions of this StreamingSeries into one big closure.
+
+ Closures are more performant than calling all the functions in the
+        `StreamingSeries` one-by-one.
+
+ :param allow_filters: If False, this function will fail with ValueError if
+ the stream has filter functions in the tree. Default - True.
+ :param allow_updates: If False, this function will fail with ValueError if
+ the stream has update functions in the tree. Default - True.
+
+ :raises ValueError: if disallowed functions are present in the tree of
+ underlying `Stream`.
+
+ :return: a function that accepts "value"
+        and returns a result of StreamingSeries
+ """
+
+ return self._stream.compile(
+ allow_filters=allow_filters, allow_updates=allow_updates
+ )
+
+ def test(self, value: Any, ctx: Optional[MessageContext] = None) -> Any:
+ """
+        A shorthand to test `StreamingSeries` with the provided value
+ and `MessageContext`.
+
+ :param value: value to pass through `StreamingSeries`
+ :param ctx: instance of `MessageContext`, optional.
+        Provide it if the StreamingSeries instance has
+        functions calling `message_key()`.
+ Default - `None`.
+ :return: result of `StreamingSeries`
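+
+        Example (a minimal sketch):
+
+        ```
+        series = StreamingSeries("x") + 1
+        assert series.test({"x": 1}) == 2
+        ```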
+ """
+ context = contextvars.copy_context()
+ context.run(set_message_context, ctx)
+ compiled = self.compile()
+ return context.run(compiled, value)
+
+ def _operation(
+ self, other: Union[Self, object], operator_: Callable[[object, object], object]
+ ) -> Self:
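+        # Compile this series (and "other", if it is also a StreamingSeries)
+        # and build a new series applying "operator_" to the results.
+        # "op=operator_" binds the operator early via a default argument.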
+ self_compiled = self.compile()
+ if isinstance(other, self.__class__):
+ other_compiled = other.compile()
+ return self.from_func(
+ func=lambda v, op=operator_: op(self_compiled(v), other_compiled(v))
+ )
+ else:
+ return self.from_func(
+ func=lambda v, op=operator_: op(self_compiled(v), other)
+ )
+
+ def isin(self, other: Container) -> Self:
+ """
+ Check if series value is in "other".
+ Same as "StreamingSeries in other".
+
+ :param other: a container to check
+ :return: new StreamingSeries
+ """
+ return self._operation(
+ other, lambda a, b, contains=operator.contains: contains(b, a)
+ )
+
+ def contains(self, other: object) -> Self:
+ """
+        Check if series value contains "other".
+ Same as "other in StreamingSeries".
+
+ :param other: object to check
+ :return: new StreamingSeries
+ """
+ return self._operation(other, operator.contains)
+
+ def is_(self, other: object) -> Self:
+ """
+ Check if series value refers to the same object as `other`
+ :param other: object to check for "is"
+ :return: new StreamingSeries
+ """
+ return self._operation(other, operator.is_)
+
+ def isnot(self, other: object) -> Self:
+ """
+        Check if series value does not refer to the same object as `other`
+        :param other: object to check for "is not"
+ :return: new StreamingSeries
+ """
+ return self._operation(other, operator.is_not)
+
+ def isnull(self) -> Self:
+ """
+ Check if series value is None
+
+ :return: new StreamingSeries
+ """
+ return self._operation(None, operator.is_)
+
+ def notnull(self) -> Self:
+ """
+ Check if series value is not None
+ """
+ return self._operation(None, operator.is_not)
+
+ def abs(self) -> Self:
+ """
+ Get absolute value of the series value
+ """
+ return self.apply(func=lambda v: abs(v))
+
+ def __getitem__(self, item: Union[str, int]) -> Self:
+ return self._operation(item, operator.getitem)
+
+ def __mod__(self, other: object) -> Self:
+ return self._operation(other, operator.mod)
+
+ def __add__(self, other: object) -> Self:
+ return self._operation(other, operator.add)
+
+ def __sub__(self, other: object) -> Self:
+ return self._operation(other, operator.sub)
+
+ def __mul__(self, other: object) -> Self:
+ return self._operation(other, operator.mul)
+
+ def __truediv__(self, other: object) -> Self:
+ return self._operation(other, operator.truediv)
+
+ def __eq__(self, other: object) -> Self:
+ return self._operation(other, operator.eq)
+
+ def __ne__(self, other: object) -> Self:
+ return self._operation(other, operator.ne)
+
+ def __lt__(self, other: object) -> Self:
+ return self._operation(other, operator.lt)
+
+ def __le__(self, other: object) -> Self:
+ return self._operation(other, operator.le)
+
+ def __gt__(self, other: object) -> Self:
+ return self._operation(other, operator.gt)
+
+ def __ge__(self, other: object) -> Self:
+ return self._operation(other, operator.ge)
+
+ def __and__(self, other: object) -> Self:
+ """
+ Do a logical "and" comparison.
+
+ .. note:: It behaves differently than `pandas`. `pandas` performs
+ a bitwise "and" if one of the arguments is a number.
+ This function always does a logical "and" instead.
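+
+        Example (a minimal sketch):
+
+        ```
+        # True only when both "x" and "y" are truthy
+        both = StreamingSeries("x") & StreamingSeries("y")
+        ```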
+ """
+ return self._operation(other, lambda x, y: x and y)
+
+ def __or__(self, other: object) -> Self:
+ """
+ Do a logical "or" comparison.
+
+ .. note:: It behaves differently than `pandas`. `pandas` performs
+ a bitwise "or" if one of the arguments is a number.
+ This function always does a logical "or" instead.
+ """
+ return self._operation(other, lambda x, y: x or y)
+
+ def __invert__(self) -> Self:
+ """
+ Do a logical "not".
+
+ .. note:: It behaves differently than `pandas`. `pandas` performs
+ a bitwise "not" if argument is a number.
+ This function always does a logical "not" instead.
+ """
+ return self.apply(lambda v: not v)
diff --git a/src/StreamingDataFrames/quixstreams/models/__init__.py b/src/StreamingDataFrames/quixstreams/models/__init__.py
index fe4e55c81..ed2fb1237 100644
--- a/src/StreamingDataFrames/quixstreams/models/__init__.py
+++ b/src/StreamingDataFrames/quixstreams/models/__init__.py
@@ -3,3 +3,4 @@
from .timestamps import *
from .topics import *
from .types import *
+from .messagecontext import *
diff --git a/src/StreamingDataFrames/quixstreams/models/context.py b/src/StreamingDataFrames/quixstreams/models/messagecontext.py
similarity index 100%
rename from src/StreamingDataFrames/quixstreams/models/context.py
rename to src/StreamingDataFrames/quixstreams/models/messagecontext.py
diff --git a/src/StreamingDataFrames/quixstreams/models/rows.py b/src/StreamingDataFrames/quixstreams/models/rows.py
index f0f319a00..3d9826233 100644
--- a/src/StreamingDataFrames/quixstreams/models/rows.py
+++ b/src/StreamingDataFrames/quixstreams/models/rows.py
@@ -5,7 +5,7 @@
from .messages import MessageHeadersTuples
from .timestamps import MessageTimestamp
-from .context import MessageContext
+from .messagecontext import MessageContext
class Row:
diff --git a/src/StreamingDataFrames/quixstreams/models/topics.py b/src/StreamingDataFrames/quixstreams/models/topics.py
index b30d69aec..9d99ede77 100644
--- a/src/StreamingDataFrames/quixstreams/models/topics.py
+++ b/src/StreamingDataFrames/quixstreams/models/topics.py
@@ -1,7 +1,7 @@
import logging
from typing import Union, List, Mapping, Optional, Any
-from .context import MessageContext
+from .messagecontext import MessageContext
from .messages import KafkaMessage
from .rows import Row
from .serializers import (
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_app.py b/src/StreamingDataFrames/tests/test_quixstreams/test_app.py
index 901858bd0..5c6ed8978 100644
--- a/src/StreamingDataFrames/tests/test_quixstreams/test_app.py
+++ b/src/StreamingDataFrames/tests/test_quixstreams/test_app.py
@@ -7,7 +7,6 @@
import pytest
from confluent_kafka import KafkaException, TopicPartition
-from tests.utils import TopicPartitionStub
from quixstreams.app import Application
from quixstreams.models import (
@@ -16,7 +15,6 @@
JSONDeserializer,
SerializationError,
JSONSerializer,
- MessageContext,
)
from quixstreams.platforms.quix import (
QuixKafkaConfigsBuilder,
@@ -28,6 +26,7 @@
RowConsumer,
)
from quixstreams.state import State
+from tests.utils import TopicPartitionStub
def _stop_app_on_future(app: Application, future: Future, timeout: float):
@@ -92,8 +91,8 @@ def on_message_processed(topic_, partition, offset):
value_deserializer=JSONDeserializer(),
)
- df = app.dataframe(topic_in)
- df.to_topic(topic_out)
+ sdf = app.dataframe(topic_in)
+ sdf = sdf.to_topic(topic_out)
processed_count = 0
total_messages = 3
@@ -107,7 +106,7 @@ def on_message_processed(topic_, partition, offset):
# Stop app when the future is resolved
executor.submit(_stop_app_on_future, app, done, 10.0)
- app.run(df)
+ app.run(sdf)
# Check that all messages have been processed
assert processed_count == total_messages
@@ -130,21 +129,19 @@ def on_message_processed(topic_, partition, offset):
assert row.key == data["key"]
assert row.value == {column_name: loads(data["value"].decode())}
- def test_run_consumer_error_raised(
- self, app_factory, producer, topic_factory, consumer, executor
- ):
+ def test_run_consumer_error_raised(self, app_factory, topic_factory, executor):
# Set "auto_offset_reset" to "error" to simulate errors in Consumer
app = app_factory(auto_offset_reset="error")
topic_name, _ = topic_factory()
topic = app.topic(
topic_name, value_deserializer=JSONDeserializer(column_name="root")
)
- df = app.dataframe(topic)
+ sdf = app.dataframe(topic)
# Stop app after 10s if nothing failed
executor.submit(_stop_app_on_timeout, app, 10.0)
with pytest.raises(KafkaMessageError):
- app.run(df)
+ app.run(sdf)
def test_run_deserialization_error_raised(
self, app_factory, producer, topic_factory, consumer, executor
@@ -157,12 +154,12 @@ def test_run_deserialization_error_raised(
with producer:
producer.produce(topic=topic_name, value=b"abc")
- df = app.dataframe(topic)
+ sdf = app.dataframe(topic)
with pytest.raises(SerializationError):
# Stop app after 10s if nothing failed
executor.submit(_stop_app_on_timeout, app, 10.0)
- app.run(df)
+ app.run(sdf)
def test_run_consumer_error_suppressed(
self, app_factory, producer, topic_factory, consumer, executor
@@ -181,14 +178,14 @@ def on_consumer_error(exc, *args):
app = app_factory(on_consumer_error=on_consumer_error)
topic_name, _ = topic_factory()
topic = app.topic(topic_name)
- df = app.dataframe(topic)
+ sdf = app.dataframe(topic)
with patch.object(RowConsumer, "poll") as mocked:
# Patch RowConsumer.poll to simulate failures
mocked.side_effect = ValueError("test")
# Stop app when the future is resolved
executor.submit(_stop_app_on_future, app, done, 10.0)
- app.run(df)
+ app.run(sdf)
assert polled > 1
def test_run_processing_error_raised(
@@ -198,19 +195,19 @@ def test_run_processing_error_raised(
topic_name, _ = topic_factory()
topic = app.topic(topic_name, value_deserializer=JSONDeserializer())
- df = app.dataframe(topic)
+ sdf = app.dataframe(topic)
def fail(*args):
raise ValueError("test")
- df = df.apply(fail)
+ sdf = sdf.apply(fail)
with producer:
producer.produce(topic=topic.name, value=b'{"field":"value"}')
with pytest.raises(ValueError):
executor.submit(_stop_app_on_timeout, app, 10.0)
- app.run(df)
+ app.run(sdf)
def test_run_processing_error_suppressed(
self, app_factory, topic_factory, producer, executor
@@ -232,12 +229,12 @@ def on_processing_error(exc, *args):
)
topic_name, _ = topic_factory()
topic = app.topic(topic_name, value_deserializer=JSONDeserializer())
- df = app.dataframe(topic)
+ sdf = app.dataframe(topic)
def fail(*args):
raise ValueError("test")
- df = df.apply(fail)
+ sdf = sdf.apply(fail)
with producer:
for i in range(produced):
@@ -245,7 +242,7 @@ def fail(*args):
# Stop app from the background thread when the future is resolved
executor.submit(_stop_app_on_future, app, done, 10.0)
- app.run(df)
+ app.run(sdf)
assert produced == consumed
def test_run_producer_error_raised(
@@ -261,15 +258,15 @@ def test_run_producer_error_raised(
topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer())
topic_out = app.topic(topic_out_name, value_serializer=JSONSerializer())
- df = app.dataframe(topic_in)
- df.to_topic(topic_out)
+ sdf = app.dataframe(topic_in)
+ sdf = sdf.to_topic(topic_out)
with producer:
producer.produce(topic_in.name, dumps({"field": 1001 * "a"}))
with pytest.raises(KafkaException):
executor.submit(_stop_app_on_timeout, app, 10.0)
- app.run(df)
+ app.run(sdf)
def test_run_serialization_error_raised(
self, app_factory, producer, topic_factory, executor
@@ -281,15 +278,15 @@ def test_run_serialization_error_raised(
topic_out_name, _ = topic_factory()
topic_out = app.topic(topic_out_name, value_serializer=DoubleSerializer())
- df = app.dataframe(topic_in)
- df.to_topic(topic_out)
+ sdf = app.dataframe(topic_in)
+ sdf = sdf.to_topic(topic_out)
with producer:
producer.produce(topic_in.name, b'{"field":"value"}')
with pytest.raises(SerializationError):
executor.submit(_stop_app_on_timeout, app, 10.0)
- app.run(df)
+ app.run(sdf)
def test_run_producer_error_suppressed(
self, app_factory, producer, topic_factory, consumer, executor
@@ -314,15 +311,15 @@ def on_producer_error(exc, *args):
topic_out_name, _ = topic_factory()
topic_out = app.topic(topic_out_name, value_serializer=DoubleSerializer())
- df = app.dataframe(topic_in)
- df.to_topic(topic_out)
+ sdf = app.dataframe(topic_in)
+ sdf = sdf.to_topic(topic_out)
with producer:
for _ in range(produce_input):
producer.produce(topic_in.name, b'{"field":"value"}')
executor.submit(_stop_app_on_future, app, done, 10.0)
- app.run(df)
+ app.run(sdf)
assert produce_output_attempts == produce_input
@@ -336,7 +333,6 @@ def test_streamingdataframe_init(self):
app = Application(broker_address="localhost", consumer_group="test")
topic = app.topic(name="test-topic")
sdf = app.dataframe(topic)
-
assert sdf
@@ -401,9 +397,7 @@ def test_topic_config(self, quix_app_factory):
assert builder.create_topic_configs[expected_name].name == expected_name
assert builder.create_topic_configs[expected_name].num_partitions == 5
- def test_topic_auto_create_false_topic_confirmation(
- self, dataframe_factory, quix_app_factory
- ):
+ def test_topic_auto_create_false_topic_confirmation(self, quix_app_factory):
"""
Topics are confirmed when auto_create_topics=False
"""
@@ -443,7 +437,7 @@ def test_quix_app_stateful_quix_deployment_no_state_management_warning(
app = quix_app_factory(workspace_id="")
topic = app.topic(topic_name)
sdf = app.dataframe(topic)
- sdf.apply(lambda x, state: x, stateful=True)
+ sdf = sdf.apply(lambda x, state: x, stateful=True)
monkeypatch.setenv(
QuixEnvironment.DEPLOYMENT_ID,
@@ -507,15 +501,15 @@ def test_run_stateful_success(
topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer())
# Define a function that counts incoming Rows using state
- def count(_, ctx: MessageContext, state: State):
+ def count(_, state: State):
total = state.get("total", 0)
total += 1
state.set("total", total)
if total == total_messages:
total_consumed.set_result(total)
- df = app.dataframe(topic_in)
- df.apply(count, stateful=True)
+ sdf = app.dataframe(topic_in)
+ sdf = sdf.update(count, stateful=True)
total_messages = 3
# Produce messages to the topic and flush
@@ -529,7 +523,7 @@ def count(_, ctx: MessageContext, state: State):
# Stop app when the future is resolved
executor.submit(_stop_app_on_future, app, total_consumed, 10.0)
- app.run(df)
+ app.run(sdf)
# Check that the values are actually in the DB
state_manager = state_manager_factory(
@@ -567,7 +561,7 @@ def test_run_stateful_processing_fails(
topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer())
# Define a function that counts incoming Rows using state
- def count(_, ctx, state: State):
+ def count(_, state: State):
total = state.get("total", 0)
total += 1
state.set("total", total)
@@ -578,9 +572,7 @@ def fail(*_):
failed.set_result(True)
raise ValueError("test")
- df = app.dataframe(topic_in)
- df.apply(count, stateful=True)
- df.apply(fail)
+ sdf = app.dataframe(topic_in).update(count, stateful=True).update(fail)
total_messages = 3
# Produce messages to the topic and flush
@@ -592,7 +584,7 @@ def fail(*_):
# Stop app when the future is resolved
executor.submit(_stop_app_on_future, app, failed, 10.0)
with pytest.raises(ValueError):
- app.run(df)
+ app.run(sdf)
# Ensure that nothing was committed to the DB
state_manager = state_manager_factory(
@@ -629,7 +621,7 @@ def test_run_stateful_suppress_processing_errors(
topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer())
# Define a function that counts incoming Rows using state
- def count(_, ctx, state: State):
+ def count(_, state: State):
total = state.get("total", 0)
total += 1
state.set("total", total)
@@ -639,9 +631,7 @@ def count(_, ctx, state: State):
def fail(_):
raise ValueError("test")
- df = app.dataframe(topic_in)
- df.apply(count, stateful=True)
- df.apply(fail)
+ sdf = app.dataframe(topic_in).update(count, stateful=True).apply(fail)
total_messages = 3
message_key = b"key"
@@ -656,7 +646,7 @@ def fail(_):
# Stop app when the future is resolved
executor.submit(_stop_app_on_future, app, total_consumed, 10.0)
# Run the application
- app.run(df)
+ app.run(sdf)
# Ensure that data is committed to the DB
state_manager = state_manager_factory(
@@ -710,11 +700,9 @@ def test_on_assign_topic_offset_behind_warning(
# Define some stateful function so the App assigns store partitions
done = Future()
- def count(_, ctx, state: State):
- done.set_result(True)
-
- df = app.dataframe(topic_in)
- df.apply(count, stateful=True)
+ sdf = app.dataframe(topic_in).update(
+ lambda *_: done.set_result(True), stateful=True
+ )
# Produce a message to the topic and flush
data = {"key": b"key", "value": dumps({"key": "value"})}
@@ -725,7 +713,7 @@ def count(_, ctx, state: State):
executor.submit(_stop_app_on_future, app, done, 10.0)
# Run the application
with patch.object(logging.getLoggerClass(), "warning") as mock:
- app.run(df)
+ app.run(sdf)
assert mock.called
assert "is behind the stored offset" in mock.call_args[0][0]
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_context.py b/src/StreamingDataFrames/tests/test_quixstreams/test_context.py
new file mode 100644
index 000000000..f6d85479e
--- /dev/null
+++ b/src/StreamingDataFrames/tests/test_quixstreams/test_context.py
@@ -0,0 +1,52 @@
+import contextvars
+
+import pytest
+
+from quixstreams.models import MessageTimestamp, MessageContext
+from quixstreams.context import (
+ message_context,
+ set_message_context,
+ message_key,
+ MessageContextNotSetError,
+)
+
+
+@pytest.fixture()
+def message_context_factory():
+ def factory(key: object = "test") -> MessageContext:
+ return MessageContext(
+ key=key,
+ topic="test",
+ partition=0,
+ offset=0,
+ size=0,
+ timestamp=MessageTimestamp.create(0, 0),
+ )
+
+ return factory
+
+
+class TestContext:
+ def test_get_current_context_not_set_fails(self):
+ ctx = contextvars.copy_context()
+ with pytest.raises(MessageContextNotSetError):
+ ctx.run(message_context)
+
+ def test_set_current_context_and_run(self, message_context_factory):
+ ctx = contextvars.copy_context()
+ message_ctx1 = message_context_factory(key="test")
+ message_ctx2 = message_context_factory(key="test2")
+ for message_ctx in [message_ctx1, message_ctx2]:
+ ctx.run(set_message_context, message_ctx)
+ assert ctx.run(lambda: message_context()) == message_ctx
+
+ def test_get_current_key_success(self, message_context_factory):
+ ctx = contextvars.copy_context()
+ message_ctx = message_context_factory(key="test")
+ ctx.run(set_message_context, message_ctx)
+ assert ctx.run(message_key) == message_ctx.key
+
+ def test_get_current_key_not_set_fails(self):
+ ctx = contextvars.copy_context()
+ with pytest.raises(MessageContextNotSetError):
+ ctx.run(message_key)
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_core/__init__.py b/src/StreamingDataFrames/tests/test_quixstreams/test_core/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_functions.py b/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_functions.py
new file mode 100644
index 000000000..7b50c4845
--- /dev/null
+++ b/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_functions.py
@@ -0,0 +1,42 @@
+import pytest
+
+from quixstreams.core.stream.functions import (
+ Apply,
+ Filter,
+ Update,
+ Filtered,
+ is_apply_function,
+ is_filter_function,
+ is_update_function,
+)
+
+
+class TestFunctions:
+ def test_apply_function(self):
+ func = Apply(lambda v: v)
+ assert func(1) == 1
+ assert is_apply_function(func)
+
+ def test_update_function(self):
+ value = [0]
+ expected = [0, 1]
+ func = Update(lambda v: v.append(1))
+ assert func(value) == expected
+ assert is_update_function(func)
+
+ @pytest.mark.parametrize(
+ "value, filtered",
+ [
+ (1, True),
+ (0, False),
+ ],
+ )
+ def test_filter_function(self, value, filtered):
+ func = Filter(lambda v: v == 0)
+ assert is_filter_function(func)
+
+ if filtered:
+ with pytest.raises(Filtered):
+ func(value)
+ else:
+ assert func(value) == value
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_stream.py b/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_stream.py
new file mode 100644
index 000000000..776364e96
--- /dev/null
+++ b/src/StreamingDataFrames/tests/test_quixstreams/test_core/test_stream.py
@@ -0,0 +1,107 @@
+import pytest
+
+from quixstreams.core.stream import (
+ Stream,
+ Filtered,
+ is_filter_function,
+ is_update_function,
+ is_apply_function,
+)
+
+
+class TestStream:
+ def test_add_apply(self):
+ stream = Stream().add_apply(lambda v: v + 1)
+ assert stream.compile()(1) == 2
+
+ def test_add_update(self):
+ stream = Stream().add_update(lambda v: v.append(1))
+ assert stream.compile()([0]) == [0, 1]
+
+ @pytest.mark.parametrize(
+ "value, filtered",
+ [
+ (1, True),
+ (0, False),
+ ],
+ )
+ def test_add_filter(self, value, filtered):
+ stream = Stream().add_filter(lambda v: v == 0)
+
+ if filtered:
+ with pytest.raises(Filtered):
+ stream.compile()(value)
+ else:
+ assert stream.compile()(value) == value
+
+ def test_tree(self):
+ stream = (
+ Stream()
+ .add_apply(lambda v: v)
+ .add_filter(lambda v: v)
+ .add_update(lambda v: v)
+ )
+ tree = stream.tree()
+ assert len(tree) == 4
+ assert is_apply_function(tree[0].func)
+ assert is_apply_function(tree[1].func)
+ assert is_filter_function(tree[2].func)
+ assert is_update_function(tree[3].func)
+
+ def test_diff_success(self):
+ stream = Stream()
+ stream = stream.add_apply(lambda v: v)
+ stream2 = (
+ stream.add_apply(lambda v: v)
+ .add_update(lambda v: v)
+ .add_filter(lambda v: v)
+ )
+
+ stream = stream.add_apply(lambda v: v)
+
+ diff = stream.diff(stream2)
+
+ diff_tree = diff.tree()
+ assert len(diff_tree) == 3
+ assert is_apply_function(diff_tree[0].func)
+ assert is_update_function(diff_tree[1].func)
+ assert is_filter_function(diff_tree[2].func)
+
+ def test_diff_empty_same_stream_fails(self):
+ stream = Stream()
+ with pytest.raises(ValueError, match="The diff is empty"):
+ stream.diff(stream)
+
+ def test_diff_empty_stream_full_child_fails(self):
+ stream = Stream()
+ stream2 = stream.add_apply(lambda v: v)
+ with pytest.raises(ValueError, match="The diff is empty"):
+ stream2.diff(stream)
+
+ def test_diff_no_common_parent_fails(self):
+ stream = Stream()
+ stream2 = Stream()
+ with pytest.raises(ValueError, match="Common parent not found"):
+ stream.diff(stream2)
+
+ def test_compile_allow_filters_false(self):
+ stream = Stream().add_filter(lambda v: v)
+ with pytest.raises(ValueError, match="Filter functions are not allowed"):
+ stream.compile(allow_filters=False)
+
+ def test_compile_allow_updates_false(self):
+ stream = Stream().add_update(lambda v: v)
+ with pytest.raises(ValueError, match="Update functions are not allowed"):
+ stream.compile(allow_updates=False)
+
+ def test_repr(self):
+ stream = (
+ Stream()
+ .add_apply(lambda v: v)
+ .add_update(lambda v: v)
+ .add_filter(lambda v: v)
+ )
+ repr(stream)
+
+ def test_init_with_unwrapped_function(self):
+ ...
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/fixtures.py b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/fixtures.py
index 3f9729eb7..4831752bb 100644
--- a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/fixtures.py
+++ b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/fixtures.py
@@ -1,55 +1,22 @@
-from functools import partial
from typing import Optional
from unittest.mock import MagicMock
import pytest
from quixstreams.dataframe.dataframe import StreamingDataFrame
-from quixstreams.dataframe.pipeline import Pipeline, PipelineFunction
from quixstreams.models.topics import Topic
from quixstreams.state import StateStoreManager
-@pytest.fixture()
-def pipeline_function():
- def test_func(data):
- return {k: v + 1 for k, v in data.items()}
-
- return PipelineFunction(func=test_func)
-
-
-@pytest.fixture()
-def pipeline(pipeline_function):
- return Pipeline(functions=[pipeline_function])
-
-
@pytest.fixture()
def dataframe_factory():
- def _dataframe_factory(
+ def factory(
topic: Optional[Topic] = None,
state_manager: Optional[StateStoreManager] = None,
- ):
+ ) -> StreamingDataFrame:
return StreamingDataFrame(
- topic=topic or Topic(name="test_in"),
+ topic=topic or Topic(name="test"),
state_manager=state_manager or MagicMock(spec=StateStoreManager),
)
- return _dataframe_factory
-
-
-def row_values_plus_n(n, row, ctx):
- for k, v in row.items():
- row[k] = v + n
- return row
-
-
-@pytest.fixture()
-def row_plus_n_func():
- """
- This generally will be used alongside "row_plus_n"
- """
-
- def _row_values_plus_n(n=None):
- return partial(row_values_plus_n, n)
-
- return _row_values_plus_n
+ return factory
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_column.py b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_column.py
deleted file mode 100644
index 1eac1ca47..000000000
--- a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_column.py
+++ /dev/null
@@ -1,328 +0,0 @@
-import pytest
-
-from quixstreams.dataframe.column import Column
-
-
-class TestColumn:
- def test_apply(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110}, key=123)
- result = Column("x").apply(lambda v, context: v + context.key)
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 128
-
- def test_addition(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") + Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 25
-
- def test_multi_op(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") + Column("y") + Column("z")
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 135
-
- def test_scaler_op(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") + 2
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 7
-
- def test_subtraction(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("y") - Column("x")
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 15
-
- def test_multiplication(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") * Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 100
-
- def test_div(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("z") / Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 5.5
-
- def test_mod(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("y") % Column("x")
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 0
-
- def test_equality_true(self, row_factory):
- msg_value = row_factory({"x": 5, "x2": 5})
- result = Column("x") == Column("x2")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_equality_false(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") == Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_inequality_true(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") != Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_inequality_false(self, row_factory):
- msg_value = row_factory({"x": 5, "x2": 5})
- result = Column("x") != Column("x2")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_less_than_true(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") < Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_less_than_false(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("y") < Column("x")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_less_than_or_equal_equal_true(self, row_factory):
- msg_value = row_factory({"x": 5, "x2": 5})
- result = Column("x") <= Column("x2")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_less_than_or_equal_less_true(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") <= Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_less_than_or_equal_false(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("y") <= Column("x")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_greater_than_true(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("y") > Column("x")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_greater_than_false(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") > Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_greater_than_or_equal_equal_true(self, row_factory):
- msg_value = row_factory({"x": 5, "x2": 5})
- result = Column("x") >= Column("x2")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_greater_than_or_equal_greater_true(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("y") >= Column("x")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_greater_than_or_equal_greater_false(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = Column("x") >= Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_and_true(self, row_factory):
- msg_value = row_factory({"x": True, "y": True, "z": False})
- result = Column("x") & Column("y")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_and_true_multiple(self, row_factory):
- msg_value = row_factory({"x": True, "y": True, "z": False})
- result = Column("x") & Column("y") & Column("x")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_and_false_multiple(self, row_factory):
- msg_value = row_factory({"x": True, "y": True, "z": False})
- result = Column("x") & Column("y") & Column("z")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_and_false(self, row_factory):
- msg_value = row_factory({"x": True, "y": True, "z": False})
- result = Column("x") & Column("z")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_and_true_bool(self, row_factory):
- msg_value = row_factory({"x": True, "y": True, "z": False})
- result = Column("x") & True
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_and_false_bool(self, row_factory):
- msg_value = row_factory({"x": True, "y": True, "z": False})
- result = Column("x") & False
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_or_true(self, row_factory):
- msg_value = row_factory({"x": True, "y": True, "z": False})
- result = Column("x") | Column("z")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_or_false(self, row_factory):
- msg_value = row_factory({"x": True, "y": True, "z": False})
- result = Column("z") | Column("z")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_and_inequalities_true(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = (Column("x") <= Column("y")) & (Column("x") <= Column("z"))
- assert isinstance(result, Column)
- assert result.eval(msg_value) is True
-
- def test_and_inequalities_false(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = (Column("x") <= Column("y")) & (Column("x") > Column("z"))
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_invert_int(self, row_factory):
- msg_value = row_factory({"x": 1})
- result = ~Column("x")
- assert isinstance(result, Column)
- assert result.eval(msg_value) == -2
-
- def test_invert_bool(self, row_factory):
- msg_value = row_factory({"x": True, "y": True, "z": False})
- result = ~Column("x")
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- def test_invert_bool_from_inequalities(self, row_factory):
- msg_value = row_factory({"x": 5, "y": 20, "z": 110})
- result = ~(Column("x") <= Column("y"))
- assert isinstance(result, Column)
- assert result.eval(msg_value) is False
-
- @pytest.mark.parametrize(
- "value, other, expected",
- [
- ({"x": 1}, [1, 2, 3], True),
- ({"x": 1}, [], False),
- ({"x": 1}, {1: 456}, True),
- ],
- )
- def test_isin(self, row_factory, value, other, expected):
- row = row_factory(value)
- assert Column("x").isin(other).eval(row) == expected
-
- @pytest.mark.parametrize(
- "value, other, expected",
- [
- ({"x": [1, 2, 3]}, 1, True),
- ({"x": [1, 2, 3]}, 5, False),
- ({"x": "abc"}, "a", True),
- ({"x": {"y": "z"}}, "y", True),
- ],
- )
- def test_contains(self, row_factory, value, other, expected):
- row = row_factory(value)
- assert Column("x").contains(other).eval(row) == expected
-
- @pytest.mark.parametrize(
- "value, expected",
- [
- ({"x": None}, True),
- ({"x": [1, 2, 3]}, False),
- ],
- )
- def test_isnull(self, row_factory, value, expected):
- row = row_factory(value)
- assert Column("x").isnull().eval(row) == expected
-
- @pytest.mark.parametrize(
- "value, expected",
- [
- ({"x": None}, False),
- ({"x": [1, 2, 3]}, True),
- ],
- )
- def test_notnull(self, row_factory, value, expected):
- row = row_factory(value)
- assert Column("x").notnull().eval(row) == expected
-
- @pytest.mark.parametrize(
- "value, other, expected",
- [
- ({"x": [1, 2, 3]}, None, False),
- ({"x": None}, None, True),
- ({"x": 1}, 1, True),
- ],
- )
- def test_is_(self, row_factory, value, other, expected):
- row = row_factory(value)
- assert Column("x").is_(other).eval(row) == expected
-
- @pytest.mark.parametrize(
- "value, other, expected",
- [
- ({"x": [1, 2, 3]}, None, True),
- ({"x": None}, None, False),
- ({"x": 1}, 1, False),
- ],
- )
- def test_isnot(self, row_factory, value, other, expected):
- row = row_factory(value)
- assert Column("x").isnot(other).eval(row) == expected
-
- @pytest.mark.parametrize(
- "nested_item",
- [
- {2: 110},
- {2: "a_string"},
- {2: {"another": "dict"}},
- {2: ["a", "list"]},
- ["item", "in", "this", "list"],
- ],
- )
- def test__get_item__(self, row_factory, nested_item):
- msg_value = row_factory({"x": {"y": nested_item}, "k": 0})
- result = Column("x")["y"][2]
- assert isinstance(result, Column)
- assert result.eval(msg_value) == nested_item[2]
-
- def test_get_item_with_apply(self, row_factory):
- msg_value = row_factory({"x": {"y": {"z": 110}}, "k": 0})
- result = Column("x")["y"]["z"].apply(lambda v, ctx: v + 10)
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 120
-
- def test_get_item_with_op(self, row_factory):
- msg_value = row_factory({"x": {"y": 10}, "j": {"k": 5}})
- result = Column("x")["y"] + Column("j")["k"]
- assert isinstance(result, Column)
- assert result.eval(msg_value) == 15
-
- @pytest.mark.parametrize("value, expected", [(10, 10), (-10, 10), (10.0, 10.0)])
- def test_abs_success(self, value, expected, row_factory):
- row = row_factory({"x": value})
- result = Column("x").abs()
- assert isinstance(result, Column)
- assert result.eval(row) == expected
-
- def test_abs_not_a_number_fails(self, row_factory):
- row = row_factory({"x": "string"})
- result = Column("x").abs()
- assert isinstance(result, Column)
- with pytest.raises(TypeError, match="bad operand type for abs()"):
- assert result.eval(row)
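
The `Column.eval(row)` expression API deleted above is superseded by `StreamingSeries.test(value)`, added in `test_series.py` later in this patch. A minimal sketch of the equivalent expression under the new API, assuming only the operator semantics those tests pin down:

```python
from quixstreams.dataframe.series import StreamingSeries

# Old style: (Column("x") <= Column("y")).eval(row) against a Row object.
# New style: expressions compose the same way but evaluate plain values.
expr = (StreamingSeries("x") <= StreamingSeries("y")) & (StreamingSeries("x") > 0)
assert expr.test({"x": 5, "y": 20}) is True
assert expr.test({"x": 0, "y": 20}) is False
```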
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_dataframe.py b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_dataframe.py
index 9d2c33155..e3285a162 100644
--- a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_dataframe.py
+++ b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_dataframe.py
@@ -1,224 +1,262 @@
+import operator
+
import pytest
-from tests.utils import TopicPartitionStub
-from quixstreams.dataframe.exceptions import InvalidApplyResultType
-from quixstreams.dataframe.pipeline import Pipeline
-from quixstreams.models import MessageContext
+from quixstreams import MessageContext, State
+from quixstreams.core.stream import Filtered
+from quixstreams.models import MessageTimestamp
from quixstreams.models.topics import Topic
-from quixstreams.state import State
-
-
-class TestDataframe:
- def test_dataframe(self, dataframe_factory):
- dataframe = dataframe_factory()
- assert isinstance(dataframe._pipeline, Pipeline)
- assert dataframe._pipeline.id == dataframe.id
-
-
-class TestDataframeProcess:
- def test_apply(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- row = row_factory({"x": 1, "y": 2}, key="key")
-
- def _apply(value: dict, ctx: MessageContext):
- assert ctx.key == "key"
- assert value == {"x": 1, "y": 2}
- return {
- "x": 3,
- "y": 4,
- }
-
- dataframe.apply(_apply)
- assert dataframe.process(row).value == {"x": 3, "y": 4}
-
- def test_apply_no_return_value(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe = dataframe.apply(lambda row, ctx: row.update({"y": 2}))
- row = row_factory({"x": 1})
- assert dataframe.process(row).value == row_factory({"x": 1, "y": 2}).value
-
- def test_apply_invalid_return_type(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe = dataframe.apply(lambda row, ctx: False)
- row = row_factory({"x": 1, "y": 2})
- with pytest.raises(InvalidApplyResultType):
- dataframe.process(row)
-
- def test_apply_fluent(self, dataframe_factory, row_factory, row_plus_n_func):
- dataframe = dataframe_factory()
- dataframe = dataframe.apply(row_plus_n_func(n=1)).apply(row_plus_n_func(n=2))
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row).value == row_factory({"x": 4, "y": 5}).value
-
- def test_apply_sequential(self, dataframe_factory, row_factory, row_plus_n_func):
- dataframe = dataframe_factory()
- dataframe = dataframe.apply(row_plus_n_func(n=1))
- dataframe = dataframe.apply(row_plus_n_func(n=2))
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row).value == row_factory({"x": 4, "y": 5}).value
-
- def test_setitem_primitive(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe["new"] = 1
- row = row_factory({"x": 1})
- assert dataframe.process(row).value == row_factory({"x": 1, "new": 1}).value
-
- def test_setitem_column_only(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe["new"] = dataframe["x"]
- row = row_factory({"x": 1})
- assert dataframe.process(row).value == row_factory({"x": 1, "new": 1}).value
-
- def test_setitem_column_with_function(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe["new"] = dataframe["x"].apply(lambda v, ctx: v + 5)
- row = row_factory({"x": 1})
- assert dataframe.process(row).value == row_factory({"x": 1, "new": 6}).value
-
- def test_setitem_column_with_operations(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe["new"] = (
- dataframe["x"] + dataframe["y"].apply(lambda v, ctx: v + 5) + 1
- )
- row = row_factory({"x": 1, "y": 2})
- expected = row_factory({"x": 1, "y": 2, "new": 9})
- assert dataframe.process(row).value == expected.value
+from tests.utils import TopicPartitionStub
- def test_setitem_from_a_nested_column(
- self, dataframe_factory, row_factory, row_plus_n_func
- ):
- dataframe = dataframe_factory()
- dataframe["a"] = dataframe["x"]["y"]
- row = row_factory({"x": {"y": 1, "z": "banana"}})
- expected = row_factory({"x": {"y": 1, "z": "banana"}, "a": 1})
- assert dataframe.process(row).value == expected.value
-
- def test_column_subset(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe = dataframe[["x", "y"]]
- row = row_factory({"x": 1, "y": 2, "z": 3})
- expected = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row).value == expected.value
-
- def test_column_subset_with_funcs(
- self, dataframe_factory, row_factory, row_plus_n_func
- ):
- dataframe = dataframe_factory()
- dataframe = dataframe[["x", "y"]].apply(row_plus_n_func(n=5))
- row = row_factory({"x": 1, "y": 2, "z": 3})
- expected = row_factory({"x": 6, "y": 7})
- assert dataframe.process(row).value == expected.value
-
- def test_inequality_filter(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe = dataframe[dataframe["x"] > 0]
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row).value == row.value
-
- def test_inequality_filter_is_filtered(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe = dataframe[dataframe["x"] >= 1000]
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row) is None
-
- def test_inequality_filter_with_operation(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe = dataframe[(dataframe["x"] - 0 + dataframe["y"]) > 0]
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row).value == row.value
-
- def test_inequality_filter_with_operation_is_filtered(
- self, dataframe_factory, row_factory
- ):
- dataframe = dataframe_factory()
- dataframe = dataframe[(dataframe["x"] - dataframe["y"]) > 0]
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row) is None
-
- def test_inequality_filtering_with_apply(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe = dataframe[dataframe["x"].apply(lambda v, ctx: v - 1) >= 0]
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row).value == row.value
-
- def test_inequality_filtering_with_apply_is_filtered(
- self, dataframe_factory, row_factory
- ):
- dataframe = dataframe_factory()
- dataframe = dataframe[dataframe["x"].apply(lambda v, ctx: v - 10) >= 0]
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row) is None
-
- def test_compound_inequality_filter(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe = dataframe[(dataframe["x"] >= 0) & (dataframe["y"] < 10)]
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row).value == row.value
-
- def test_compound_inequality_filter_is_filtered(
- self, dataframe_factory, row_factory
- ):
- dataframe = dataframe_factory()
- dataframe = dataframe[(dataframe["x"] >= 0) & (dataframe["y"] < 0)]
- row = row_factory({"x": 1, "y": 2})
- assert dataframe.process(row) is None
-
- def test_contains_on_existing_column(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe["has_column"] = dataframe.contains("x")
- row = row_factory({"x": 1})
- assert (
- dataframe.process(row).value
- == row_factory({"x": 1, "has_column": True}).value
- )
- def test_contains_on_missing_column(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe["has_column"] = dataframe.contains("wrong_column")
- row = row_factory({"x": 1})
- assert (
- dataframe.process(row).value
- == row_factory({"x": 1, "has_column": False}).value
+class TestStreamingDataFrame:
+ @pytest.mark.parametrize(
+ "value, expected",
+ [(1, 2), ("input", "return"), ([0, 1, 2], "return"), ({"key": "value"}, None)],
+ )
+ def test_apply(self, dataframe_factory, value, expected):
+ sdf = dataframe_factory()
+
+ def _apply(value_: dict):
+ assert value_ == value
+ return expected
+
+ sdf = sdf.apply(_apply)
+ assert sdf.test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, mutation, expected",
+ [
+ ([0, 1, 2], lambda v: v.append(3), [0, 1, 2, 3]),
+ ({"a": "b"}, lambda v: operator.setitem(v, "x", "y"), {"a": "b", "x": "y"}),
+ ],
+ )
+ def test_update(self, dataframe_factory, value, mutation, expected):
+ sdf = dataframe_factory()
+ sdf = sdf.update(mutation)
+ assert sdf.test(value) == expected
+
+ def test_apply_multiple(self, dataframe_factory):
+ sdf = dataframe_factory()
+ value = 1
+ expected = 4
+ sdf = sdf.apply(lambda v: v + 1).apply(lambda v: v + 2)
+ assert sdf.test(value) == expected
+
+ def test_apply_update_multiple(self, dataframe_factory):
+ sdf = dataframe_factory()
+ value = {"x": 1}
+ expected = {"x": 3, "y": 3}
+ sdf = (
+ sdf.apply(lambda v: {"x": v["x"] + 1})
+ .update(lambda v: operator.setitem(v, "y", 3))
+ .apply(lambda v: {**v, "x": v["x"] + 1})
)
-
- def test_contains_as_filter(self, dataframe_factory, row_factory):
- dataframe = dataframe_factory()
- dataframe = dataframe[dataframe.contains("x")]
-
- valid_row = row_factory({"x": 1, "y": 2})
- valid_result = dataframe.process(valid_row)
- assert valid_result is not None and valid_result.value == valid_row.value
-
- invalid_row = row_factory({"y": 2})
- assert dataframe.process(invalid_row) is None
-
-
-class TestDataframeKafka:
+ assert sdf.test(value) == expected
+
+ def test_setitem_primitive(self, dataframe_factory):
+ value = {"x": 1}
+ expected = {"x": 2}
+ sdf = dataframe_factory()
+ sdf["x"] = 2
+ assert sdf.test(value) == expected
+
+ def test_setitem_series(self, dataframe_factory):
+ value = {"x": 1, "y": 2}
+ expected = {"x": 2, "y": 2}
+ sdf = dataframe_factory()
+ sdf["x"] = sdf["y"]
+ assert sdf.test(value) == expected
+
+ def test_setitem_series_apply(self, dataframe_factory):
+ value = {"x": 1}
+ expected = {"x": 1, "y": 2}
+ sdf = dataframe_factory()
+ sdf["y"] = sdf["x"].apply(lambda v: v + 1)
+ assert sdf.test(value) == expected
+
+ def test_setitem_series_with_operations(self, dataframe_factory):
+ value = {"x": 1, "y": 2}
+ expected = {"x": 1, "y": 2, "z": 5}
+ sdf = dataframe_factory()
+ sdf["z"] = (sdf["x"] + sdf["y"]).apply(lambda v: v + 1) + 1
+ assert sdf.test(value) == expected
+
+ def test_setitem_another_dataframe_apply(self, dataframe_factory):
+ value = {"x": 1}
+ expected = {"x": 1, "y": 2}
+ sdf = dataframe_factory()
+ sdf["y"] = sdf.apply(lambda v: v["x"] + 1)
+ assert sdf.test(value) == expected
+
+ def test_column_subset(self, dataframe_factory):
+ value = {"x": 1, "y": 2, "z": 3}
+ expected = {"x": 1, "y": 2}
+ sdf = dataframe_factory()
+ sdf = sdf[["x", "y"]]
+ assert sdf.test(value) == expected
+
+ def test_column_subset_and_apply(self, dataframe_factory):
+ value = {"x": 1, "y": 2, "z": 3}
+ expected = 2
+ sdf = dataframe_factory()
+ sdf = sdf[["x", "y"]]
+ sdf = sdf.apply(lambda v: v["y"])
+ assert sdf.test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, filtered",
+ [
+ ({"x": 1, "y": 2}, False),
+ ({"x": 0, "y": 2}, True),
+ ],
+ )
+ def test_filter_with_series(self, dataframe_factory, value, filtered):
+ sdf = dataframe_factory()
+ sdf = sdf[sdf["x"] > 0]
+
+ if filtered:
+ with pytest.raises(Filtered):
+ assert sdf.test(value)
+ else:
+ assert sdf.test(value) == value
+
+ @pytest.mark.parametrize(
+ "value, filtered",
+ [
+ ({"x": 1, "y": 2}, False),
+ ({"x": 0, "y": 2}, True),
+ ],
+ )
+ def test_filter_with_series_apply(self, dataframe_factory, value, filtered):
+ sdf = dataframe_factory()
+ sdf = sdf[sdf["x"].apply(lambda v: v > 0)]
+
+ if filtered:
+ with pytest.raises(Filtered):
+ assert sdf.test(value)
+ else:
+ assert sdf.test(value) == value
+
+ @pytest.mark.parametrize(
+ "value, filtered",
+ [
+ ({"x": 1, "y": 2}, False),
+ ({"x": 0, "y": 2}, True),
+ ],
+ )
+ def test_filter_with_multiple_series(self, dataframe_factory, value, filtered):
+ sdf = dataframe_factory()
+ sdf = sdf[(sdf["x"] > 0) & (sdf["y"] > 0)]
+
+ if filtered:
+ with pytest.raises(Filtered):
+ assert sdf.test(value)
+ else:
+ assert sdf.test(value) == value
+
+ @pytest.mark.parametrize(
+ "value, filtered",
+ [
+ ({"x": 1, "y": 2}, False),
+ ({"x": 0, "y": 2}, True),
+ ],
+ )
+ def test_filter_with_another_sdf_apply(self, dataframe_factory, value, filtered):
+ sdf = dataframe_factory()
+ sdf = sdf[sdf.apply(lambda v: v["x"] > 0)]
+
+ if filtered:
+ with pytest.raises(Filtered):
+ assert sdf.test(value)
+ else:
+ assert sdf.test(value) == value
+
+ def test_filter_with_another_sdf_with_filters_fails(self, dataframe_factory):
+ sdf = dataframe_factory()
+ sdf2 = sdf[sdf["x"] > 1].apply(lambda v: v["x"] > 0)
+ with pytest.raises(ValueError, match="Filter functions are not allowed"):
+ sdf = sdf[sdf2]
+
+ def test_filter_with_another_sdf_with_update_fails(self, dataframe_factory):
+ sdf = dataframe_factory()
+ sdf2 = sdf.apply(lambda v: v).update(lambda v: operator.setitem(v, "x", 2))
+ with pytest.raises(ValueError, match="Update functions are not allowed"):
+ sdf = sdf[sdf2]
+
+ @pytest.mark.parametrize(
+ "value, filtered",
+ [
+ ({"x": 1, "y": 2}, False),
+ ({"x": 0, "y": 2}, True),
+ ],
+ )
+ def test_filter_with_function(self, dataframe_factory, value, filtered):
+ sdf = dataframe_factory()
+ sdf = sdf.filter(lambda v: v["x"] > 0)
+
+ if filtered:
+ with pytest.raises(Filtered):
+ assert sdf.test(value)
+ else:
+ assert sdf.test(value) == value
+
+ def test_contains_on_existing_column(self, dataframe_factory):
+ sdf = dataframe_factory()
+ sdf["has_column"] = sdf.contains("x")
+ assert sdf.test({"x": 1}) == {"x": 1, "has_column": True}
+
+ def test_contains_on_missing_column(self, dataframe_factory):
+ sdf = dataframe_factory()
+ sdf["has_column"] = sdf.contains("wrong_column")
+
+ assert sdf.test({"x": 1}) == {"x": 1, "has_column": False}
+
+ def test_contains_as_filter(self, dataframe_factory):
+ sdf = dataframe_factory()
+ sdf = sdf[sdf.contains("x")]
+
+ valid_value = {"x": 1, "y": 2}
+ assert sdf.test(valid_value) == valid_value
+
+ invalid_value = {"y": 2}
+ with pytest.raises(Filtered):
+ sdf.test(invalid_value)
+
+
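The class above pins down this patch's headline change: `StreamingDataFrame.apply()` now returns an object that can be assigned to a key or used as a filter. A minimal sketch, written the way it would appear in this suite with the `dataframe_factory` fixture:

```python
def test_apply_assign_and_filter(dataframe_factory):
    sdf = dataframe_factory()
    sdf["y"] = sdf.apply(lambda v: v["x"] + 1)  # apply() assigned to a key
    sdf = sdf[sdf.apply(lambda v: v["y"] > 1)]  # apply() used as a filter
    assert sdf.test({"x": 1}) == {"x": 1, "y": 2}
```
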
+class TestStreamingDataFrameToTopic:
def test_to_topic(
self,
dataframe_factory,
row_consumer_factory,
row_producer_factory,
- row_factory,
- topic_json_serdes_factory,
+ topic_factory,
):
- topic = topic_json_serdes_factory()
+ topic_name, _ = topic_factory()
+ topic = Topic(
+ topic_name,
+ key_deserializer="str",
+ value_serializer="json",
+ value_deserializer="json",
+ )
producer = row_producer_factory()
- dataframe = dataframe_factory()
- dataframe.producer = producer
- dataframe.to_topic(topic)
-
- assert dataframe.topics_out[topic.name] == topic
-
- row_to_produce = row_factory(
- topic="ignore_me",
- key=b"test_key",
- value={"x": "1", "y": "2"},
+ sdf = dataframe_factory()
+ sdf.producer = producer
+ sdf = sdf.to_topic(topic)
+
+ value = {"x": 1, "y": 2}
+ ctx = MessageContext(
+ key="test",
+ topic="test",
+ partition=0,
+ offset=0,
+ size=0,
+ timestamp=MessageTimestamp.create(0, 0),
)
with producer:
- dataframe.process(row_to_produce)
+ sdf.test(value, ctx=ctx)
with row_consumer_factory(auto_offset_reset="earliest") as consumer:
consumer.subscribe([topic])
@@ -226,33 +264,43 @@ def test_to_topic(
assert consumed_row
assert consumed_row.topic == topic.name
- assert row_to_produce.key == consumed_row.key
- assert row_to_produce.value == consumed_row.value
+ assert consumed_row.key == ctx.key
+ assert consumed_row.value == value
def test_to_topic_custom_key(
self,
dataframe_factory,
row_consumer_factory,
row_producer_factory,
- row_factory,
- topic_json_serdes_factory,
+ topic_factory,
):
- topic = topic_json_serdes_factory()
+ topic_name, _ = topic_factory()
+ topic = Topic(
+ topic_name,
+ value_serializer="json",
+ value_deserializer="json",
+ key_serializer="int",
+ key_deserializer="int",
+ )
producer = row_producer_factory()
- dataframe = dataframe_factory()
- dataframe.producer = producer
- # Using value of "x" column as a new key
- dataframe.to_topic(topic, key=lambda value, ctx: value["x"])
+ sdf = dataframe_factory()
+ sdf.producer = producer
+
+ # Use value["x"] as a new key
+ sdf = sdf.to_topic(topic, key=lambda v: v["x"])
- row_to_produce = row_factory(
- topic=topic.name,
- key=b"test_key",
- value={"x": "1", "y": "2"},
+ value = {"x": 1, "y": 2}
+ ctx = MessageContext(
+ topic="test",
+ partition=0,
+ offset=0,
+ size=0,
+ timestamp=MessageTimestamp.create(0, 0),
)
with producer:
- dataframe.process(row_to_produce)
+ sdf.test(value, ctx=ctx)
with row_consumer_factory(auto_offset_reset="earliest") as consumer:
consumer.subscribe([topic])
@@ -260,34 +308,48 @@ def test_to_topic_custom_key(
assert consumed_row
assert consumed_row.topic == topic.name
- assert consumed_row.value == row_to_produce.value
- assert consumed_row.key == row_to_produce.value["x"].encode()
+ assert consumed_row.value == value
+ assert consumed_row.key == value["x"]
def test_to_topic_multiple_topics_out(
self,
dataframe_factory,
row_consumer_factory,
row_producer_factory,
- row_factory,
- topic_json_serdes_factory,
+ topic_factory,
):
- topic_0 = topic_json_serdes_factory()
- topic_1 = topic_json_serdes_factory()
+ topic_0_name, _ = topic_factory()
+ topic_1_name, _ = topic_factory()
+
+ topic_0 = Topic(
+ topic_0_name,
+ value_serializer="json",
+ value_deserializer="json",
+ )
+ topic_1 = Topic(
+ topic_1_name,
+ value_serializer="json",
+ value_deserializer="json",
+ )
producer = row_producer_factory()
- dataframe = dataframe_factory()
- dataframe.producer = producer
+ sdf = dataframe_factory()
+ sdf.producer = producer
- dataframe.to_topic(topic_0)
- dataframe.to_topic(topic_1)
+ sdf = sdf.to_topic(topic_0).to_topic(topic_1)
- row_to_produce = row_factory(
- key=b"test_key",
- value={"x": "1", "y": "2"},
+ value = {"x": 1, "y": 2}
+ ctx = MessageContext(
+ key=b"test",
+ topic="test",
+ partition=0,
+ offset=0,
+ size=0,
+ timestamp=MessageTimestamp.create(0, 0),
)
with producer:
- dataframe.process(row_to_produce)
+ sdf.test(value, ctx=ctx)
consumed_rows = []
with row_consumer_factory(auto_offset_reset="earliest") as consumer:
@@ -300,53 +362,199 @@ def test_to_topic_multiple_topics_out(
t.name for t in [topic_0, topic_1]
}
for consumed_row in consumed_rows:
- assert row_to_produce.key == consumed_row.key
- assert row_to_produce.value == consumed_row.value
+ assert consumed_row.key == ctx.key
+ assert consumed_row.value == value
-    def test_to_topic_no_producer_assigned(self, dataframe_factory, row_factory):
+    def test_to_topic_no_producer_assigned(self, dataframe_factory):
- topic = Topic("whatever")
+ topic = Topic("test")
- dataframe = dataframe_factory()
- dataframe.to_topic(topic)
+ sdf = dataframe_factory()
+ sdf = sdf.to_topic(topic)
+
+ value = {"x": "1", "y": "2"}
+ ctx = MessageContext(
+ key=b"test",
+ topic="test",
+ partition=0,
+ offset=0,
+ size=0,
+ timestamp=MessageTimestamp.create(0, 0),
+ )
- with pytest.raises(RuntimeError):
- dataframe.process(
- row_factory(
- topic=topic.name, key=b"test_key", value={"x": "1", "y": "2"}
- )
- )
+ with pytest.raises(
+ RuntimeError, match="Producer instance has not been provided"
+ ):
+ sdf.test(value, ctx=ctx)
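
The tests above also encode a signature change for `to_topic()`: the `key` callback now receives only the value, where it previously received `(value, ctx)`. A hedged sketch with a hypothetical output topic (serializer kwargs mirror those used above):

```python
from quixstreams.models.topics import Topic

def route_by_x(sdf):
    # Hypothetical topic name; kwargs follow the tests in this class.
    topic = Topic("output", key_serializer="int", value_serializer="json")
    # The key callback takes only the value now (was: lambda value, ctx: ...).
    return sdf.to_topic(topic, key=lambda v: v["x"])
```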
class TestDataframeStateful:
- def test_apply_stateful(self, dataframe_factory, state_manager, row_factory):
+ def test_apply_stateful(self, dataframe_factory, state_manager):
topic = Topic("test")
- def stateful_func(value, ctx, state: State):
+ def stateful_func(value_: dict, state: State) -> int:
current_max = state.get("max")
if current_max is None:
- current_max = value["number"]
+ current_max = value_["number"]
else:
- current_max = max(current_max, value["number"])
+ current_max = max(current_max, value_["number"])
state.set("max", current_max)
- value["max"] = current_max
+ return current_max
sdf = dataframe_factory(topic, state_manager=state_manager)
- sdf.apply(stateful_func, stateful=True)
+ sdf = sdf.apply(stateful_func, stateful=True)
state_manager.on_partition_assign(
tp=TopicPartitionStub(topic=topic.name, partition=0)
)
- rows = [
- row_factory(topic=topic.name, value={"number": 1}),
- row_factory(topic=topic.name, value={"number": 10}),
- row_factory(topic=topic.name, value={"number": 3}),
+ values = [
+ {"number": 1},
+ {"number": 10},
+ {"number": 3},
]
result = None
- for row in rows:
+ ctx = MessageContext(
+ key=b"test",
+ topic="test",
+ partition=0,
+ offset=0,
+ size=0,
+ timestamp=MessageTimestamp.create(0, 0),
+ )
+ for value in values:
with state_manager.start_store_transaction(
- topic=row.topic, partition=row.partition, offset=row.offset
+ topic=ctx.topic, partition=ctx.partition, offset=ctx.offset
):
- result = sdf.process(row)
+ result = sdf.test(value, ctx)
+
+ assert result == 10
+
+ def test_update_stateful(self, dataframe_factory, state_manager):
+ topic = Topic("test")
+
+ def stateful_func(value_: dict, state: State):
+ current_max = state.get("max")
+ if current_max is None:
+ current_max = value_["number"]
+ else:
+ current_max = max(current_max, value_["number"])
+ state.set("max", current_max)
+ value_["max"] = current_max
+
+ sdf = dataframe_factory(topic, state_manager=state_manager)
+ sdf = sdf.update(stateful_func, stateful=True)
- assert result
- assert result.value["max"] == 10
+ state_manager.on_partition_assign(
+ tp=TopicPartitionStub(topic=topic.name, partition=0)
+ )
+ result = None
+ values = [
+ {"number": 1},
+ {"number": 10},
+ {"number": 3},
+ ]
+ ctx = MessageContext(
+ key=b"test",
+ topic="test",
+ partition=0,
+ offset=0,
+ size=0,
+ timestamp=MessageTimestamp.create(0, 0),
+ )
+ for value in values:
+ with state_manager.start_store_transaction(
+ topic=ctx.topic, partition=ctx.partition, offset=ctx.offset
+ ):
+ result = sdf.test(value, ctx)
+
+ assert result is not None
+ assert result["max"] == 10
+
+ def test_filter_stateful(self, dataframe_factory, state_manager):
+ topic = Topic("test")
+
+ def stateful_func(value_: dict, state: State):
+ current_max = state.get("max")
+ if current_max is None:
+ current_max = value_["number"]
+ else:
+ current_max = max(current_max, value_["number"])
+ state.set("max", current_max)
+ value_["max"] = current_max
+
+ sdf = dataframe_factory(topic, state_manager=state_manager)
+ sdf = sdf.update(stateful_func, stateful=True)
+ sdf = sdf.filter(lambda v, state: state.get("max") >= 3, stateful=True)
+
+ state_manager.on_partition_assign(
+ tp=TopicPartitionStub(topic=topic.name, partition=0)
+ )
+ values = [
+ {"number": 1},
+ {"number": 1},
+ {"number": 3},
+ ]
+ ctx = MessageContext(
+ key=b"test",
+ topic="test",
+ partition=0,
+ offset=0,
+ size=0,
+ timestamp=MessageTimestamp.create(0, 0),
+ )
+ results = []
+ for value in values:
+ with state_manager.start_store_transaction(
+ topic=ctx.topic, partition=ctx.partition, offset=ctx.offset
+ ):
+ try:
+ results.append(sdf.test(value, ctx))
+ except Filtered:
+ pass
+ assert len(results) == 1
+ assert results[0]["max"] == 3
+
+ def test_filter_with_another_sdf_apply_stateful(
+ self, dataframe_factory, state_manager
+ ):
+ topic = Topic("test")
+
+ def stateful_func(value_: dict, state: State):
+ current_max = state.get("max")
+ if current_max is None:
+ current_max = value_["number"]
+ else:
+ current_max = max(current_max, value_["number"])
+ state.set("max", current_max)
+ value_["max"] = current_max
+
+ sdf = dataframe_factory(topic, state_manager=state_manager)
+ sdf = sdf.update(stateful_func, stateful=True)
+ sdf = sdf[sdf.apply(lambda v, state: state.get("max") >= 3, stateful=True)]
+
+ state_manager.on_partition_assign(
+ tp=TopicPartitionStub(topic=topic.name, partition=0)
+ )
+ values = [
+ {"number": 1},
+ {"number": 1},
+ {"number": 3},
+ ]
+ ctx = MessageContext(
+ key=b"test",
+ topic="test",
+ partition=0,
+ offset=0,
+ size=0,
+ timestamp=MessageTimestamp.create(0, 0),
+ )
+ results = []
+ for value in values:
+ with state_manager.start_store_transaction(
+ topic=ctx.topic, partition=ctx.partition, offset=ctx.offset
+ ):
+ try:
+ results.append(sdf.test(value, ctx))
+ except Filtered:
+ pass
+ assert len(results) == 1
+ assert results[0]["max"] == 3
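
Every stateful test above shares one callback shape: with `stateful=True`, the function receives `(value, state)` and persists data through `State.get()`/`State.set()`. A condensed sketch of that pattern:

```python
from quixstreams import State

def track_max(value: dict, state: State) -> int:
    # Read the running maximum, fold in the new value, and write it back.
    current = state.get("max")
    current = value["number"] if current is None else max(current, value["number"])
    state.set("max", current)
    return current

# As exercised above: sdf = sdf.apply(track_max, stateful=True)
```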
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_pipeline.py b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_pipeline.py
deleted file mode 100644
index 163b602f6..000000000
--- a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_pipeline.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from quixstreams.dataframe.pipeline import PipelineFunction
-
-
-class TestPipeline:
- def test_pipeline(self, pipeline_function, pipeline):
- assert pipeline._functions == [pipeline_function]
-
- def test_apply(self, pipeline):
- def throwaway_func(data):
- return data
-
- pipeline = pipeline.apply(throwaway_func)
- assert isinstance(pipeline.functions[-1], PipelineFunction)
- assert pipeline.functions[-1]._func == throwaway_func
-
- def test_process(self, pipeline):
- assert pipeline.process({"a": 1, "b": 2}) == {"a": 2, "b": 3}
-
- def test_process_list(self, pipeline):
- actual = pipeline.process([{"a": 1, "b": 2}, {"a": 10, "b": 11}])
- expected = [{"a": 2, "b": 3}, {"a": 11, "b": 12}]
- assert actual == expected
-
- def test_process_break_and_return_none(self, pipeline):
- """
- Add a new function that returns None, then reverse the _functions order
- so that an exception would get thrown if the pipeline doesn't (as expected)
- stop processing when attempting to execute a function with a NoneType input.
- """
- pipeline = pipeline.apply(lambda d: None)
- pipeline._functions.reverse()
- assert pipeline.process({"a": 1, "b": 2}) is None
diff --git a/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_series.py b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_series.py
new file mode 100644
index 000000000..45d0ac73f
--- /dev/null
+++ b/src/StreamingDataFrames/tests/test_quixstreams/test_dataframe/test_series.py
@@ -0,0 +1,313 @@
+import pytest
+
+from quixstreams.dataframe.series import StreamingSeries
+
+
+class TestStreamingSeries:
+ def test_apply(self):
+ value = {"x": 5, "y": 20, "z": 110}
+ expected = {"x": 6}
+ result = StreamingSeries("x").apply(lambda v: {"x": v + 1})
+ assert isinstance(result, StreamingSeries)
+ assert result.test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("y"), 25),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 10, 15),
+ ],
+ )
+ def test_add(self, value, series, other, expected):
+ result = series + other
+ assert result.test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 20}, StreamingSeries("y"), StreamingSeries("x"), 15),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 10, -5),
+ ],
+ )
+ def test_subtract(self, value, series, other, expected):
+ result = series - other
+ assert result.test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 20}, StreamingSeries("y"), StreamingSeries("x"), 100),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 10, 50),
+ ],
+ )
+ def test_multiply(self, value, series, other, expected):
+ result = series * other
+ assert result.test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), 1),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 2, 2.5),
+ ],
+ )
+ def test_div(self, value, series, other, expected):
+ result = series / other
+ assert result.test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 2}, StreamingSeries("x"), StreamingSeries("y"), 1),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 3, 2),
+ ],
+ )
+ def test_mod(self, value, series, other, expected):
+ result = series % other
+ assert result.test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 2}, StreamingSeries("x"), StreamingSeries("x"), True),
+ ({"x": 5, "y": 2}, StreamingSeries("x"), StreamingSeries("y"), False),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 5, True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 6, False),
+ ],
+ )
+ def test_equal(self, value, series, other, expected):
+ result = series == other
+ assert result.test(value) is expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 2}, StreamingSeries("x"), StreamingSeries("x"), False),
+ ({"x": 5, "y": 2}, StreamingSeries("x"), StreamingSeries("y"), True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 5, False),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 6, True),
+ ],
+ )
+ def test_not_equal(self, value, series, other, expected):
+ result = series != other
+ assert result.test(value) is expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), False),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("y"), True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 5, False),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 6, True),
+ ],
+ )
+ def test_less_than(self, value, series, other, expected):
+ result = series < other
+ assert result.test(value) is expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("y"), True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 4, False),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 5, True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 6, True),
+ ],
+ )
+ def test_less_than_equal(self, value, series, other, expected):
+ result = series <= other
+ assert result.test(value) is expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), False),
+ ({"x": 5, "y": 4}, StreamingSeries("x"), StreamingSeries("y"), True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 4, True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 5, False),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 6, False),
+ ],
+ )
+ def test_greater_than(self, value, series, other, expected):
+ result = series > other
+ assert result.test(value) is expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 5, "y": 20}, StreamingSeries("x"), StreamingSeries("x"), True),
+ ({"x": 5, "y": 4}, StreamingSeries("x"), StreamingSeries("y"), True),
+ ({"x": 5, "y": 6}, StreamingSeries("x"), StreamingSeries("y"), False),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 4, True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 5, True),
+ ({"x": 5, "y": 20}, StreamingSeries("x"), 6, False),
+ ],
+ )
+ def test_greater_than_equal(self, value, series, other, expected):
+ result = series >= other
+ assert result.test(value) is expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": True, "y": False}, StreamingSeries("x"), StreamingSeries("x"), True),
+ (
+ {"x": True, "y": False},
+ StreamingSeries("x"),
+ StreamingSeries("y"),
+ False,
+ ),
+ ({"x": True, "y": False}, StreamingSeries("x"), True, True),
+ ({"x": True, "y": False}, StreamingSeries("x"), False, False),
+ ({"x": True, "y": False}, StreamingSeries("x"), 0, 0),
+ ],
+ )
+ def test_and(self, value, series, other, expected):
+ result = series & other
+ assert result.test(value) is expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": True, "y": False}, StreamingSeries("x"), StreamingSeries("y"), True),
+            ({"x": False}, StreamingSeries("x"), StreamingSeries("x"), False),
+            ({"x": True}, StreamingSeries("x"), 0, True),
+ ({"x": False}, StreamingSeries("x"), 0, 0),
+ ({"x": False}, StreamingSeries("x"), True, True),
+ ],
+ )
+ def test_or(self, value, series, other, expected):
+ result = series | other
+ assert result.test(value) is expected
+
+ def test_multiple_conditions(self):
+ value = {"x": 5, "y": 20, "z": 110}
+ expected = True
+ result = (StreamingSeries("x") <= StreamingSeries("y")) & (
+ StreamingSeries("x") <= StreamingSeries("z")
+ )
+
+ assert result.test(value) is expected
+
+ @pytest.mark.parametrize(
+ "value, series, expected",
+ [
+ ({"x": True, "y": False}, StreamingSeries("x"), False),
+ ({"x": 1, "y": False}, StreamingSeries("x"), False),
+ ],
+ )
+ def test_invert(self, value, series, expected):
+ result = ~series
+
+ assert result.test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": 1}, StreamingSeries("x"), [1, 2, 3], True),
+ ({"x": 1}, StreamingSeries("x"), [], False),
+ ({"x": 1}, StreamingSeries("x"), {1: 456}, True),
+ ],
+ )
+ def test_isin(self, value, series, other, expected):
+ assert series.isin(other).test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": [1, 2, 3]}, StreamingSeries("x"), 1, True),
+ ({"x": [1, 2, 3]}, StreamingSeries("x"), 5, False),
+ ({"x": "abc"}, StreamingSeries("x"), "a", True),
+ ({"x": {"y": "z"}}, StreamingSeries("x"), "y", True),
+ ],
+ )
+ def test_contains(self, series, value, other, expected):
+ assert series.contains(other).test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, expected",
+ [
+ ({"x": None}, StreamingSeries("x"), True),
+ ({"x": [1, 2, 3]}, StreamingSeries("x"), False),
+ ],
+ )
+ def test_isnull(self, value, series, expected):
+ assert series.isnull().test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, expected",
+ [
+ ({"x": None}, StreamingSeries("x"), False),
+ ({"x": [1, 2, 3]}, StreamingSeries("x"), True),
+ ],
+ )
+ def test_notnull(self, value, series, expected):
+ assert series.notnull().test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": [1, 2, 3]}, StreamingSeries("x"), None, False),
+ ({"x": None}, StreamingSeries("x"), None, True),
+ ({"x": 1}, StreamingSeries("x"), 1, True),
+ ],
+ )
+ def test_is_(self, value, series, other, expected):
+ assert series.is_(other).test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, series, other, expected",
+ [
+ ({"x": [1, 2, 3]}, StreamingSeries("x"), None, True),
+ ({"x": None}, StreamingSeries("x"), None, False),
+ ({"x": 1}, StreamingSeries("x"), 1, False),
+ ],
+ )
+ def test_isnot(self, value, series, other, expected):
+ assert series.isnot(other).test(value) == expected
+
+ @pytest.mark.parametrize(
+ "value, item, expected",
+ [
+ ({"x": {"y": 1}}, "y", 1),
+ ({"x": [0, 1, 2, 3]}, 1, 1),
+ ],
+ )
+ def test_getitem(self, value, item, expected):
+ result = StreamingSeries("x")[item]
+ assert result.test(value) == expected
+
+ def test_getitem_with_apply(self):
+ value = {"x": {"y": {"z": 110}}, "k": 0}
+ result = StreamingSeries("x")["y"]["z"].apply(lambda v: v + 10)
+
+ assert result.test(value) == 120
+
+ @pytest.mark.parametrize("value, expected", [(10, 10), (-10, 10), (10.0, 10.0)])
+    def test_abs_success(self, value, expected):
+ result = StreamingSeries("x").abs()
+
+ assert result.test({"x": value}) == expected
+
+ def test_abs_not_a_number_fails(self):
+ result = StreamingSeries("x").abs()
+
+ with pytest.raises(TypeError, match="bad operand type for abs()"):
+ assert result.test({"x": "string"})
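
Taken together, these tests pin down a composable expression API. A short sketch combining nested item access, arithmetic, and `apply()` under the semantics asserted above:

```python
from quixstreams.dataframe.series import StreamingSeries

# Nested lookups chain like dict/list indexing, then compose with + and apply().
total = (StreamingSeries("order")["items"][0] + StreamingSeries("tax")).apply(round)
assert total.test({"order": {"items": [100.25]}, "tax": 7}) == 107
```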