From b74cbda0da26a3c33da3665c1b51d2f7a1d73b3e Mon Sep 17 00:00:00 2001 From: imDema Date: Thu, 6 Jun 2024 12:17:44 +0200 Subject: [PATCH] Add keyed fold scan --- src/operator/iteration/iterate.rs | 2 +- src/operator/mod.rs | 87 +++++++++++++++++++++++++++++++ src/operator/sink/avro.rs | 27 ---------- tests/fold_scan.rs | 52 ++++++++++++++++++ 4 files changed, 140 insertions(+), 28 deletions(-) diff --git a/src/operator/iteration/iterate.rs b/src/operator/iteration/iterate.rs index 969e906..b3712e3 100644 --- a/src/operator/iteration/iterate.rs +++ b/src/operator/iteration/iterate.rs @@ -297,7 +297,7 @@ impl Operator for Iterate Display for Iterate { +impl Display for Iterate { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Iterate<{}>", std::any::type_name::()) } diff --git a/src/operator/mod.rs b/src/operator/mod.rs index 52c16dd..ac24334 100644 --- a/src/operator/mod.rs +++ b/src/operator/mod.rs @@ -3,6 +3,7 @@ //! The actual operator list can be found from the implemented methods of [`Stream`], //! [`KeyedStream`], [`crate::WindowedStream`] +use std::collections::HashMap; use std::fmt::Display; use std::hash::Hash; use std::ops::{AddAssign, Div}; @@ -2685,6 +2686,92 @@ where self.0.split_block(End::new, NextStrategy::random()) } + pub fn fold_scan( + self, + init: S, + fold: L, + map: F, + ) -> KeyedStream> + where + Op::Out: ExchangeData, + I: Send, + K: ExchangeDataKey + Sync, + L: Fn(&K, &mut S, I) + Send + Clone + 'static, + F: Fn(&K, &S, I) -> O + Send + Clone + 'static, + S: ExchangeData + Sync, + O: ExchangeData, + { + #[derive(Serialize, Deserialize, Clone)] + enum TwoPass { + First(I), + Second(I), + Output(O), + } + + let (state, s) = self.map(|el| TwoPass::First(el.1)).unkey().iterate( + 2, + HashMap::::default(), + |s, state| { + s.to_keyed() + .map(move |(k, el)| match el { + TwoPass::First(el) => TwoPass::Second(el), + TwoPass::Second(el) => { + TwoPass::Output((map)(k, state.get().get(k).unwrap(), el)) + } + TwoPass::Output(_) => unreachable!(), + }) + .unkey() + }, + move |local: &mut HashMap, (k, el)| match el { + TwoPass::First(_) => {} + TwoPass::Second(el) => fold( + &k, + local.entry(k.clone()).or_insert_with(|| init.clone()), + el, + ), + TwoPass::Output(_) => {} + }, + move |global: &mut HashMap, mut local| { + global.extend(local.drain()); + }, + |_| true, + ); + + state.for_each(std::mem::drop); + s.to_keyed().map(|(_, t)| match t { + TwoPass::First(_) | TwoPass::Second(_) => unreachable!(), + TwoPass::Output(o) => o, + }) + } + + pub fn reduce_scan( + self, + first_map: F1, + reduce: R, + second_map: F2, + ) -> KeyedStream> + where + Op::Out: ExchangeData, + F1: Fn(&K, I) -> S + Send + Clone + 'static, + F2: Fn(&K, &S, I) -> O + Send + Clone + 'static, + R: Fn(&K, S, S) -> S + Send + Clone + 'static, + K: Sync, + S: ExchangeData + Sync, + O: ExchangeData, + { + self.fold_scan( + None, + move |k, acc: &mut Option, x| { + let map = (first_map)(k, x); + *acc = Some(match acc.take() { + Some(v) => (reduce)(k, v, map), + None => map, + }); + }, + move |k, state, x| (second_map)(k, state.as_ref().unwrap(), x), + ) + } + /// Close the stream and send resulting items to a channel on a single host. /// /// If the stream is distributed among multiple replicas, parallelism will diff --git a/src/operator/sink/avro.rs b/src/operator/sink/avro.rs index 489cabe..edb68f7 100644 --- a/src/operator/sink/avro.rs +++ b/src/operator/sink/avro.rs @@ -134,30 +134,3 @@ where .finalize_block(); } } - -// #[cfg(test)] -// mod qtests { -// use std::AvroSinkions::HashSet; - -// use crate::config::RuntimeConfig; -// use crate::environment::StreamContext; -// use crate::operator::source; - -// #[test] -// fn AvroSink_vec() { -// let env = StreamContext::new(RuntimeConfig::local(4).unwrap()); -// let source = source::IteratorSource::new(0..10u8); -// let res = env.stream(source).AvroSink::>(); -// env.execute_blocking(); -// assert_eq!(res.get().unwrap(), (0..10).AvroSink::>()); -// } - -// #[test] -// fn AvroSink_set() { -// let env = StreamContext::new(RuntimeConfig::local(4).unwrap()); -// let source = source::IteratorSource::new(0..10u8); -// let res = env.stream(source).AvroSink::>(); -// env.execute_blocking(); -// assert_eq!(res.get().unwrap(), (0..10).AvroSink::>()); -// } -// } diff --git a/tests/fold_scan.rs b/tests/fold_scan.rs index c9c7c2d..0e936a5 100644 --- a/tests/fold_scan.rs +++ b/tests/fold_scan.rs @@ -56,3 +56,55 @@ fn reduce_scan() { } }); } + +#[test] +fn keyed_fold_scan() { + TestHelper::local_remote_env(|ctx| { + let res = ctx + .stream_par_iter(0..100i32) + .group_by(|e| e % 5) + .fold_scan( + 0, + |_k, acc: &mut i32, x| { + *acc += x; + }, + |_k, acc, x| (x, *acc), + ) + .unkey() + .map(|t| (t.0, t.1 .0, t.1 .1)) + .collect_vec(); + + ctx.execute_blocking(); + if let Some(mut res) = res.get() { + let mut expected = (0..100) + .map(|x| (x % 5, x, (0..100).filter(|m| m % 5 == x % 5).sum::())) + .collect::>(); + res.sort(); + expected.sort(); + assert_eq!(expected, res); + } + }); +} + +#[test] +fn keyed_reduce_scan() { + TestHelper::local_remote_env(|ctx| { + let res = ctx + .stream_par_iter(0..100i32) + .group_by(|e| e % 5) + .reduce_scan(|_k, x| x, |_k, a, b| a + b, |_k, acc, x| (x, *acc)) + .unkey() + .map(|t| (t.0, t.1 .0, t.1 .1)) + .collect_vec(); + + ctx.execute_blocking(); + if let Some(mut res) = res.get() { + let mut expected = (0..100) + .map(|x| (x % 5, x, (0..100).filter(|m| m % 5 == x % 5).sum::())) + .collect::>(); + res.sort(); + expected.sort(); + assert_eq!(expected, res); + } + }); +}