Skip to content

Commit

Permalink
Feature: Collection Aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
gwaramadze committed Dec 16, 2024
1 parent b121c79 commit 4aa6fb3
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 14 deletions.
15 changes: 13 additions & 2 deletions quixstreams/dataframe/windows/sliding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def process_window(
grace = self._grace_ms
aggregate = self._aggregate_func
default = self._aggregate_default
collect = self._aggregate_collection

# Sliding windows are inclusive on both ends, so values with
# timestamps equal to latest_timestamp - duration - grace
Expand Down Expand Up @@ -206,17 +207,27 @@ def process_window(
)
)

if collect:
state.collect_value(value=value, timestamp_ms=timestamp_ms)

expired_windows = [
{"start": start, "end": end, "value": self._merge_func(aggregation)}
for (start, end), (max_timestamp, aggregation) in state.expire_windows(
max_start_time=max_expired_window_start,
delete=False,
collect=collect,
end_inclusive=True,
)
if end == max_timestamp # Emit only left windows
]

state.delete_windows(max_start_time=max_deleted_window_start)
return reversed(updated_windows), expired_windows
state.delete_windows(
max_start_time=max_deleted_window_start,
delete_values=collect,
)

updated_windows = [] if collect else reversed(updated_windows)
return updated_windows, expired_windows

def _update_window(
self,
Expand Down
14 changes: 13 additions & 1 deletion quixstreams/dataframe/windows/time_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def process_window(
)
continue

if self._aggregate_collection:
state.update_window(start, end, value=None, timestamp_ms=timestamp_ms)
continue

current_value = state.get_window(start, end, default=default)
aggregated = self._aggregate_func(current_value, value)
state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms)
Expand All @@ -108,9 +112,13 @@ def process_window(
}
)

if self._aggregate_collection:
state.collect_value(value=value, timestamp_ms=timestamp_ms)

expired_windows = []
for (start, end), aggregated in state.expire_windows(
max_start_time=max_expired_window_start
max_start_time=max_expired_window_start,
collect=self._aggregate_collection,
):
expired_windows.append(
{"start": start, "end": end, "value": self._merge_func(aggregated)}
Expand Down Expand Up @@ -183,6 +191,10 @@ def current(self) -> "StreamingDataFrame":
This method processes streaming data and returns results as they come,
regardless of whether the window is closed or not.
"""
if self._aggregate_collection:
raise ValueError(
"`current` is not supported in combination with `collect`."
)

def window_callback(
value: Any, key: Any, timestamp_ms: int, _headers: Any, state: WindowedState
Expand Down
5 changes: 5 additions & 0 deletions quixstreams/state/rocksdb/windowed/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,8 @@

LATEST_TIMESTAMPS_CF_NAME = "__latest-timestamps__"
LATEST_TIMESTAMP_KEY = b"__latest_timestamp__"

GLOBAL_COUNTER_CF_NAME = "__global-counter__"
GLOBAL_COUNTER_KEY = b"__global_counter__"

VALUES_CF_NAME = "__values__"
4 changes: 4 additions & 0 deletions quixstreams/state/rocksdb/windowed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from ..partition import RocksDBStorePartition
from ..types import RocksDBOptionsType
from .metadata import (
GLOBAL_COUNTER_CF_NAME,
LATEST_DELETED_WINDOW_CF_NAME,
LATEST_EXPIRED_WINDOW_CF_NAME,
LATEST_TIMESTAMPS_CF_NAME,
VALUES_CF_NAME,
)
from .transaction import WindowedRocksDBPartitionTransaction

Expand Down Expand Up @@ -42,6 +44,8 @@ def __init__(
self._ensure_column_family(LATEST_EXPIRED_WINDOW_CF_NAME)
self._ensure_column_family(LATEST_DELETED_WINDOW_CF_NAME)
self._ensure_column_family(LATEST_TIMESTAMPS_CF_NAME)
self._ensure_column_family(GLOBAL_COUNTER_CF_NAME)
self._ensure_column_family(VALUES_CF_NAME)

def iter_items(
self, from_key: bytes, read_opt: ReadOptions, cf_name: str = "default"
Expand Down
34 changes: 30 additions & 4 deletions quixstreams/state/rocksdb/windowed/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ def update_window(
prefix=self._prefix,
)

def collect_value(self, value: Any, timestamp_ms: int) -> None:
    """
    Store a single value through the underlying transaction.

    The call is forwarded to the transaction's `collect_value` together
    with the prefix this state object is bound to.

    :param value: value to collect
    :param timestamp_ms: current message timestamp in milliseconds
    """
    transaction = self._transaction
    return transaction.collect_value(
        prefix=self._prefix,
        timestamp_ms=timestamp_ms,
        value=value,
    )

def get_latest_timestamp(self) -> Optional[int]:
"""
Get the latest observed timestamp for the current message key.
Expand All @@ -76,7 +89,11 @@ def get_latest_timestamp(self) -> Optional[int]:
return self._transaction.get_latest_timestamp(prefix=self._prefix)

def expire_windows(
self, max_start_time: int, delete: bool = True
self,
max_start_time: int,
delete: bool = True,
collect: bool = False,
end_inclusive: bool = False,
) -> list[tuple[tuple[int, int], Any]]:
"""
Get all expired windows from RocksDB up to the specified `max_start_time` timestamp.
Expand All @@ -86,10 +103,16 @@ def expire_windows(
:param max_start_time: The timestamp up to which windows are considered expired, inclusive.
:param delete: If True, expired windows will be deleted.
:param collect: If True, scattered values will be collected into a single window.
:param end_inclusive: If True, the end of the window will be inclusive.
:return: A sorted list of tuples in the format `((start, end), value)`.
"""
return self._transaction.expire_windows(
max_start_time=max_start_time, prefix=self._prefix, delete=delete
max_start_time=max_start_time,
prefix=self._prefix,
delete=delete,
collect=collect,
end_inclusive=end_inclusive,
)

def get_windows(
Expand All @@ -110,7 +133,7 @@ def get_windows(
backwards=backwards,
)

def delete_windows(self, max_start_time: int) -> None:
def delete_windows(self, max_start_time: int, delete_values: bool) -> None:
"""
Delete windows from RocksDB up to the specified `max_start_time` timestamp.
Expand All @@ -119,7 +142,10 @@ def delete_windows(self, max_start_time: int) -> None:
unexpired windows.
:param max_start_time: The timestamp up to which windows should be deleted, inclusive.
:param delete_values: If True, values will be deleted.
"""
return self._transaction.delete_windows(
max_start_time=max_start_time, prefix=self._prefix
max_start_time=max_start_time,
delete_values=delete_values,
prefix=self._prefix,
)
88 changes: 86 additions & 2 deletions quixstreams/state/rocksdb/windowed/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
from quixstreams.state.serialization import DumpsFunc, LoadsFunc, serialize

from .metadata import (
GLOBAL_COUNTER_CF_NAME,
GLOBAL_COUNTER_KEY,
LATEST_DELETED_WINDOW_CF_NAME,
LATEST_DELETED_WINDOW_TIMESTAMP_KEY,
LATEST_EXPIRED_WINDOW_CF_NAME,
LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
LATEST_TIMESTAMP_KEY,
LATEST_TIMESTAMPS_CF_NAME,
VALUES_CF_NAME,
)
from .serialization import append_integer, encode_integer_pair, parse_window_key
from .state import WindowedTransactionState
Expand All @@ -31,6 +34,13 @@ class TimestampsCache:
timestamps: dict[bytes, Optional[int]] = field(default_factory=dict)


@dataclass
class CounterCache:
    """
    In-memory cache for a counter stored under a single key.

    `counter` is `None` until a value has been loaded or assigned.
    """

    # Key under which the counter is stored
    key: bytes
    # Name of the column family that holds the counter
    cf_name: str
    # Cached counter value; `None` means nothing is cached yet
    counter: Optional[int] = field(default=None)


class WindowedRocksDBPartitionTransaction(PartitionTransaction):
def __init__(
self,
Expand Down Expand Up @@ -63,6 +73,10 @@ def __init__(
key=LATEST_DELETED_WINDOW_TIMESTAMP_KEY,
cf_name=LATEST_DELETED_WINDOW_CF_NAME,
)
self._global_counter: CounterCache = CounterCache(
key=GLOBAL_COUNTER_KEY,
cf_name=GLOBAL_COUNTER_CF_NAME,
)

def as_state(self, prefix: Any = DEFAULT_PREFIX) -> WindowedTransactionState: # type: ignore [override]
return WindowedTransactionState(
Expand Down Expand Up @@ -117,13 +131,27 @@ def update_window(
timestamp_ms=updated_timestamp_ms,
)

def collect_value(
    self,
    timestamp_ms: int,
    value: Any,
    prefix: bytes,
) -> None:
    """
    Store a value in the values column family.

    The key is built from the message timestamp and the next value of
    the global counter, so values sharing a timestamp get distinct keys.

    :param timestamp_ms: message timestamp in milliseconds
    :param value: value to store
    :param prefix: key prefix under which the value is stored
    """
    count = self._get_next_count()
    values_key = encode_integer_pair(timestamp_ms, count)
    self.set(cf_name=VALUES_CF_NAME, prefix=prefix, key=values_key, value=value)

def delete_window(self, start_ms: int, end_ms: int, prefix: bytes):
    """
    Delete the window stored for the `(start_ms, end_ms)` pair.

    :param start_ms: start of the window in milliseconds
    :param end_ms: end of the window in milliseconds
    :param prefix: key prefix of the window
    """
    # Validate the bounds before the key is encoded
    self._validate_duration(start_ms=start_ms, end_ms=end_ms)
    self.delete(prefix=prefix, key=encode_integer_pair(start_ms, end_ms))

def expire_windows(
self, max_start_time: int, prefix: bytes, delete: bool = True
self,
max_start_time: int,
prefix: bytes,
delete: bool = True,
collect: bool = False,
end_inclusive: bool = False,
) -> list[tuple[tuple[int, int], Any]]:
"""
Get all expired windows from RocksDB up to the specified `max_start_time` timestamp.
Expand All @@ -143,6 +171,8 @@ def expire_windows(
:param max_start_time: The timestamp up to which windows are considered expired, inclusive.
:param prefix: The key prefix for filtering windows.
:param delete: If True, expired windows will be deleted.
:param collect: If True, scattered values will be collected into a single window.
:param end_inclusive: If True, the end timestamp will be inclusive.
:return: A sorted list of tuples in the format `((start, end), value)`.
"""
start_from = -1
Expand Down Expand Up @@ -174,14 +204,41 @@ def expire_windows(
timestamp_ms=last_expired__gt,
)

# Collect scattered values into windows
if collect:
collected_expired_windows = []
for (start, end), value in expired_windows:
collection = self._get_values(
start=start,
# Sliding windows are inclusive on both ends
# (including timestamps of messages equal to `end`).
# Since RocksDB range queries are exclusive on the
# `end` boundary, we add +1 to include it.
end=end + 1 if end_inclusive else end,
prefix=prefix,
)
if value is None:
value = collection
else:
# Sliding windows are timestamped:
# value is [max_timestamp, value] where max_timestamp
# is the timestamp of the latest message in the window
value[1] = collection
collected_expired_windows.append(((start, end), value))
expired_windows = collected_expired_windows

# Delete expired windows from the state
if delete:
for (start, end), _ in expired_windows:
self.delete_window(start, end, prefix=prefix)
if collect:
self._delete_values(max_timestamp=start, prefix=prefix)

return expired_windows

def delete_windows(self, max_start_time: int, prefix: bytes) -> None:
def delete_windows(
self, max_start_time: int, delete_values: bool, prefix: bytes
) -> None:
"""
Delete windows from RocksDB up to the specified `max_start_time` timestamp.
Expand Down Expand Up @@ -228,6 +285,15 @@ def delete_windows(self, max_start_time: int, prefix: bytes) -> None:
timestamp_ms=last_deleted__gt,
)

if delete_values:
self._delete_values(max_timestamp=max_start_time, prefix=prefix)

def _delete_values(self, max_timestamp: int, prefix: bytes) -> None:
    """
    Delete collected values returned by a range lookup from 0 up to
    `max_timestamp` in the values column family.

    :param max_timestamp: upper bound passed as `end` to the range lookup
    :param prefix: key prefix of the values to delete
    """
    stale_items = self._get_items(
        start=0, end=max_timestamp, prefix=prefix, cf_name=VALUES_CF_NAME
    )
    for values_key, _unused in stale_items:
        self.delete(key=values_key, prefix=prefix, cf_name=VALUES_CF_NAME)

def get_windows(
self,
start_from_ms: int,
Expand Down Expand Up @@ -269,6 +335,12 @@ def get_windows(

return result

def _get_values(self, start: int, end: int, prefix: bytes) -> list[Any]:
    """
    Load the collected values found between `start` and `end` and
    return them deserialized, in the order the lookup yields them.

    :param start: lower bound passed to the range lookup
    :param end: upper bound passed to the range lookup
    :param prefix: key prefix of the values to load
    :return: list of deserialized values
    """
    collected: list[Any] = []
    for _, raw_value in self._get_items(
        start=start, end=end, prefix=prefix, cf_name=VALUES_CF_NAME
    ):
        collected.append(self._deserialize_value(raw_value))
    return collected

def _get_items(
self,
start: int,
Expand Down Expand Up @@ -362,3 +434,15 @@ def _serialize_key(self, key: Any, prefix: bytes) -> bytes:
# Allow bytes keys in WindowedStore
key_bytes = key if isinstance(key, bytes) else serialize(key, dumps=self._dumps)
return prefix + SEPARATOR + key_bytes

def _get_next_count(self) -> int:
    """
    Increment the global counter, persist it and return the new value.

    The counter is read from the state only when nothing is cached yet;
    if no value is stored, counting starts at 0.

    :return: the incremented counter value
    """
    counter_cache = self._global_counter
    key, cf_name = counter_cache.key, counter_cache.cf_name

    if counter_cache.counter is None:
        # Default of -1 makes the first returned count 0
        counter_cache.counter = self.get(
            default=-1, key=key, prefix=b"", cf_name=cf_name
        )

    next_count = counter_cache.counter + 1
    counter_cache.counter = next_count
    self.set(value=next_count, key=key, prefix=b"", cf_name=cf_name)
    return next_count
Loading

0 comments on commit 4aa6fb3

Please sign in to comment.