Skip to content

Commit

Permalink
Merge pull request #340 from chirino/distributed-bytes-keys
Browse files Browse the repository at this point in the history
[distributed store] use a single map Vec<u8> -> Counters map
  • Loading branch information
chirino authored May 23, 2024
2 parents 386ce89 + 1571819 commit 8ce12c8
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 262 deletions.
53 changes: 32 additions & 21 deletions limitador/src/storage/distributed/cr_counter_value.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,36 @@
use crate::storage::atomic_expiring_value::AtomicExpiryTime;
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use std::time::{Duration, SystemTime};

use crate::storage::atomic_expiring_value::AtomicExpiryTime;

#[derive(Debug)]
pub struct CrCounterValue<A: Ord> {
ourselves: A,
max_value: u64,
value: AtomicU64,
others: RwLock<BTreeMap<A, u64>>,
expiry: AtomicExpiryTime,
}

#[allow(dead_code)]
impl<A: Ord> CrCounterValue<A> {
pub fn new(actor: A, time_window: Duration) -> Self {
pub fn new(actor: A, max_value: u64, time_window: Duration) -> Self {
Self {
ourselves: actor,
max_value,
value: Default::default(),
others: RwLock::default(),
expiry: AtomicExpiryTime::new(SystemTime::now() + time_window),
}
}

pub fn max_value(&self) -> u64 {
self.max_value
}

pub fn read(&self) -> u64 {
self.read_at(SystemTime::now())
}
Expand Down Expand Up @@ -116,6 +123,7 @@ impl<A: Ord> CrCounterValue<A> {
pub fn into_inner(self) -> (SystemTime, BTreeMap<A, u64>) {
let Self {
ourselves,
max_value: _,
value,
others,
expiry,
Expand All @@ -137,6 +145,7 @@ impl<A: Clone + Ord> Clone for CrCounterValue<A> {
fn clone(&self) -> Self {
Self {
ourselves: self.ourselves.clone(),
max_value: self.max_value,
value: AtomicU64::new(self.value.load(Ordering::SeqCst)),
others: RwLock::new(self.others.read().unwrap().clone()),
expiry: self.expiry.clone(),
Expand All @@ -148,6 +157,7 @@ impl<A: Clone + Ord + Default> From<(SystemTime, BTreeMap<A, u64>)> for CrCounte
fn from(value: (SystemTime, BTreeMap<A, u64>)) -> Self {
Self {
ourselves: A::default(),
max_value: 0,
value: Default::default(),
others: RwLock::new(value.1),
expiry: value.0.into(),
Expand All @@ -157,13 +167,14 @@ impl<A: Clone + Ord + Default> From<(SystemTime, BTreeMap<A, u64>)> for CrCounte

#[cfg(test)]
mod tests {
use crate::storage::distributed::cr_counter_value::CrCounterValue;
use std::time::{Duration, SystemTime};

use crate::storage::distributed::cr_counter_value::CrCounterValue;

#[test]
fn local_increments_are_readable() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let a = CrCounterValue::new('A', u64::MAX, window);
a.inc(3, window);
assert_eq!(3, a.read());
a.inc(2, window);
Expand All @@ -173,7 +184,7 @@ mod tests {
#[test]
fn local_increments_expire() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let now = SystemTime::now();
a.inc_at(3, window, now);
assert_eq!(3, a.read());
Expand All @@ -184,7 +195,7 @@ mod tests {
#[test]
fn other_increments_are_readable() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let a = CrCounterValue::new('A', u64::MAX, window);
a.inc_actor('B', 3, window);
assert_eq!(3, a.read());
a.inc_actor('B', 2, window);
Expand All @@ -194,7 +205,7 @@ mod tests {
#[test]
fn other_increments_expire() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let now = SystemTime::now();
a.inc_actor_at('B', 3, window, now);
assert_eq!(3, a.read());
Expand All @@ -205,8 +216,8 @@ mod tests {
#[test]
fn merges() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let b = CrCounterValue::new('B', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let b = CrCounterValue::new('B', u64::MAX, window);
a.inc(3, window);
b.inc(2, window);
a.merge(b);
Expand All @@ -216,8 +227,8 @@ mod tests {
#[test]
fn merges_symetric() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let b = CrCounterValue::new('B', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let b = CrCounterValue::new('B', u64::MAX, window);
a.inc(3, window);
b.inc(2, window);
b.merge(a);
Expand All @@ -227,8 +238,8 @@ mod tests {
#[test]
fn merges_overrides_with_larger_value() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let b = CrCounterValue::new('B', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let b = CrCounterValue::new('B', u64::MAX, window);
a.inc(3, window);
b.inc(2, window);
b.inc_actor('A', 2, window); // older value!
Expand All @@ -239,8 +250,8 @@ mod tests {
#[test]
fn merges_ignore_lesser_values() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', window);
let b = CrCounterValue::new('B', window);
let a = CrCounterValue::new('A', u64::MAX, window);
let b = CrCounterValue::new('B', u64::MAX, window);
a.inc(3, window);
b.inc(2, window);
b.inc_actor('A', 5, window); // newer value!
Expand All @@ -251,9 +262,9 @@ mod tests {
#[test]
fn merge_ignores_expired_sets() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', Duration::ZERO);
let a = CrCounterValue::new('A', u64::MAX, Duration::ZERO);
a.inc(3, Duration::ZERO);
let b = CrCounterValue::new('B', window);
let b = CrCounterValue::new('B', u64::MAX, window);
b.inc(2, window);
b.merge(a);
assert_eq!(b.read(), 2);
Expand All @@ -262,9 +273,9 @@ mod tests {
#[test]
fn merge_ignores_expired_sets_symmetric() {
let window = Duration::from_secs(1);
let a = CrCounterValue::new('A', Duration::ZERO);
let a = CrCounterValue::new('A', u64::MAX, Duration::ZERO);
a.inc(3, Duration::ZERO);
let b = CrCounterValue::new('B', window);
let b = CrCounterValue::new('B', u64::MAX, window);
b.inc(2, window);
a.merge(b);
assert_eq!(a.read(), 2);
Expand All @@ -273,9 +284,9 @@ mod tests {
#[test]
fn merge_uses_earliest_expiry() {
let later = Duration::from_secs(1);
let a = CrCounterValue::new('A', later);
let a = CrCounterValue::new('A', u64::MAX, later);
let sooner = Duration::from_millis(200);
let b = CrCounterValue::new('B', sooner);
let b = CrCounterValue::new('B', u64::MAX, sooner);
a.inc(3, later);
b.inc(2, later);
a.merge(b);
Expand Down
46 changes: 7 additions & 39 deletions limitador/src/storage/distributed/grpc/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::{error::Error, io::ErrorKind, pin::Pin};

use moka::sync::Cache;
use tokio::sync::mpsc::Sender;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time::sleep;
Expand All @@ -14,15 +13,12 @@ use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
use tonic::{Code, Request, Response, Status, Streaming};
use tracing::debug;

use crate::counter::Counter;
use crate::storage::distributed::cr_counter_value::CrCounterValue;
use crate::storage::distributed::grpc::v1::packet::Message;
use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient;
use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer};
use crate::storage::distributed::grpc::v1::{
CounterUpdate, Hello, MembershipUpdate, Packet, Peer, Pong,
};
use crate::storage::distributed::CounterKey;

// clippy will barf on protobuff generated code for enum variants in
// v3::socket_option::SocketState, so allow this lint
Expand Down Expand Up @@ -187,34 +183,7 @@ impl Session {
}
Some(Message::CounterUpdate(update)) => {
debug!("peer: '{}': CounterUpdate", self.peer_id);

let counter_key = postcard::from_bytes::<CounterKey>(update.key.as_slice())
.map_err(|err| {
Status::internal(format!("failed to decode counter key: {:?}", err))
})?;

let values = BTreeMap::from_iter(
update
.values
.iter()
.map(|(k, v)| (k.to_owned(), v.to_owned())),
);

let counter = <CounterKey as Into<Counter>>::into(counter_key);
if counter.is_qualified() {
if let Some(counter) = self.broker_state.qualified_counters.get(&counter) {
counter.merge(
(UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into(),
);
}
} else {
let counters = self.broker_state.limits_for_namespace.read().unwrap();
let limits = counters.get(counter.namespace()).unwrap();
let value = limits.get(counter.limit()).unwrap();
value.merge(
(UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into(),
);
};
(self.broker_state.on_counter_update)(update);
}
_ => {
debug!("peer: '{}': unsupported packet: {:?}", self.peer_id, packet);
Expand Down Expand Up @@ -348,12 +317,13 @@ impl MessageSender {
}
}

type CounterUpdateFn = Pin<Box<dyn Fn(CounterUpdate) + Sync + Send>>;

#[derive(Clone)]
struct BrokerState {
id: String,
limits_for_namespace: Arc<std::sync::RwLock<super::LimitsMap>>,
qualified_counters: Arc<Cache<Counter, Arc<CrCounterValue<String>>>>,
publisher: broadcast::Sender<CounterUpdate>,
on_counter_update: Arc<CounterUpdateFn>,
}

#[derive(Clone)]
Expand All @@ -369,8 +339,7 @@ impl Broker {
id: String,
listen_address: SocketAddr,
peer_urls: Vec<String>,
limits_for_namespace: Arc<std::sync::RwLock<super::LimitsMap>>,
qualified_counters: Arc<Cache<Counter, Arc<CrCounterValue<String>>>>,
on_counter_update: CounterUpdateFn,
) -> Broker {
let (tx, _) = broadcast::channel(16);
let publisher: broadcast::Sender<CounterUpdate> = tx;
Expand All @@ -381,8 +350,7 @@ impl Broker {
broker_state: BrokerState {
id,
publisher,
limits_for_namespace,
qualified_counters,
on_counter_update: Arc::new(on_counter_update),
},
replication_state: Arc::new(RwLock::new(ReplicationState {
discovered_urls: HashSet::new(),
Expand Down
Loading

0 comments on commit 8ce12c8

Please sign in to comment.