From 59a9f85918591f907e11e62e9f358385a62705b6 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Thu, 9 May 2024 13:10:38 -0400 Subject: [PATCH] Revert on transient redis error --- limitador/src/storage/redis/counters_cache.rs | 90 ++++++++++++++----- limitador/src/storage/redis/redis_cached.rs | 90 +++++++++++++++++-- 2 files changed, 150 insertions(+), 30 deletions(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index d527bd1f..67591f44 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -104,6 +104,22 @@ impl CachedCounterValue { value - start == 0 } + fn revert_writes(&self, writes: u64) -> Result<(), ()> { + let newer = self.initial_value.load(Ordering::SeqCst); + if newer > writes { + return match self.initial_value.compare_exchange( + newer, + newer - writes, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Err(expiry) if expiry != 0 => Err(()), + _ => Ok(()), + }; + } + Ok(()) + } + pub fn hits(&self, _: &Counter) -> u64 { self.value.value_at(SystemTime::now()) } @@ -165,10 +181,10 @@ impl Batcher { self.notifier.notify_one(); } - pub async fn consume(&self, max: usize, consumer: F) -> O + pub async fn consume(&self, max: usize, consumer: F) -> Result where F: FnOnce(HashMap>) -> Fut, - Fut: Future, + Fut: Future>, { let mut ready = self.batch_ready(max); loop { @@ -191,15 +207,17 @@ impl Batcher { } histogram!("batcher_flush_size").record(result.len() as f64); let result = consumer(result).await; - batch.iter().for_each(|counter| { - let prev = self - .updates - .remove_if(counter, |_, v| v.no_pending_writes()); - if prev.is_some() { - self.limiter.add_permits(1); - gauge!("batcher_size").decrement(1); - } - }); + if result.is_ok() { + batch.iter().for_each(|counter| { + let prev = self + .updates + .remove_if(counter, |_, v| v.no_pending_writes()); + if prev.is_some() { + self.limiter.add_permits(1); + gauge!("batcher_size").decrement(1); + } + }); + } return result; } else { ready = select! { @@ -252,6 +270,31 @@ impl CountersCache { &self.batcher } + pub fn return_pending_writes( + &self, + counter: &Counter, + value: u64, + writes: u64, + ) -> Result<(), ()> { + if writes != 0 { + let mut miss = false; + let value = self.cache.get_with_by_ref(counter, || { + if let Some(entry) = self.batcher.updates.get(counter) { + entry.value().clone() + } else { + miss = true; + let value = Arc::new(CachedCounterValue::from_authority(counter, value)); + value.delta(counter, writes); + value + } + }); + if miss.not() { + return value.revert_writes(writes); + } + } + Ok(()) + } + pub fn apply_remote_delta( &self, counter: Counter, @@ -456,9 +499,10 @@ mod tests { .consume(2, |items| { assert!(items.is_empty()); assert!(SystemTime::now().duration_since(start).unwrap() >= duration); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } #[tokio::test] @@ -482,9 +526,10 @@ mod tests { SystemTime::now().duration_since(start).unwrap() >= Duration::from_millis(100) ); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } #[tokio::test] @@ -507,9 +552,10 @@ mod tests { let wait_period = SystemTime::now().duration_since(start).unwrap(); assert!(wait_period >= Duration::from_millis(40)); assert!(wait_period < Duration::from_millis(50)); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } #[tokio::test] @@ -528,9 +574,10 @@ mod tests { assert!( SystemTime::now().duration_since(start).unwrap() < Duration::from_millis(5) ); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } #[tokio::test] @@ -553,9 +600,10 @@ mod tests { let wait_period = SystemTime::now().duration_since(start).unwrap(); assert!(wait_period >= Duration::from_millis(40)); assert!(wait_period < Duration::from_millis(50)); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 992f2a81..1a5e036b 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -284,7 +284,7 @@ impl CachedRedisStorageBuilder { async fn update_counters( redis_conn: &mut C, counters_and_deltas: HashMap>, -) -> Result, StorageErr> { +) -> Result, (Vec<(Counter, u64, u64, i64)>, StorageErr)> { let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS); let mut script_invocation = redis_script.prepare_invoke(); @@ -303,15 +303,21 @@ async fn update_counters( script_invocation.arg(counter.seconds()); script_invocation.arg(delta); // We need to store the counter in the actual order we are sending it to the script - res.push((counter, 0, last_value_from_redis, 0)); + res.push((counter, last_value_from_redis, delta, 0)); } } // The redis crate is not working with tables, thus the response will be a Vec of counter values - let script_res: Vec = script_invocation + let script_res: Vec = match script_invocation .invoke_async(redis_conn) .instrument(debug_span!("datastore")) - .await?; + .await + { + Ok(res) => res, + Err(err) => { + return Err((res, err.into())); + } + }; // We need to update the values and ttls returned by redis let counters_range = 0..res.len(); @@ -319,10 +325,10 @@ async fn update_counters( for (i, j) in counters_range.zip(script_res_range) { let (_, val, delta, expires_at) = &mut res[i]; - *val = u64::try_from(script_res[j]).unwrap_or(0); *delta = u64::try_from(script_res[j]) .unwrap_or(0) - .saturating_sub(*delta); + .saturating_sub(*val); // new value - previous one = remote writes + *val = u64::try_from(script_res[j]).unwrap_or(0); // update to value to newest *expires_at = script_res[j + 1]; } res @@ -349,9 +355,22 @@ async fn flush_batcher_and_update_counters( update_counters(&mut redis_conn, counters) }) .await - .or_else(|err| { + .or_else(|(data, err)| { if err.is_transient() { flip_partitioned(&partitioned, true); + let counters = data.len(); + let mut reverted = 0; + for (counter, old_value, pending_writes, _) in data { + if cached_counters + .return_pending_writes(&counter, old_value, pending_writes) + .is_err() + { + error!("Couldn't revert writes back to {:?}", &counter); + } else { + reverted += 1; + } + } + warn!("Reverted {} of {} counter increments", reverted, counters); Ok(Vec::new()) } else { Err(err) @@ -375,9 +394,10 @@ mod tests { }; use crate::storage::redis::redis_cached::{flush_batcher_and_update_counters, update_counters}; use crate::storage::redis::CachedRedisStorage; - use redis::{ErrorKind, Value}; + use redis::{Cmd, ErrorKind, RedisError, Value}; use redis_test::{MockCmd, MockRedisConnection}; use std::collections::HashMap; + use std::io; use std::ops::Add; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -518,8 +538,60 @@ mod tests { ) .await; + let c = cached_counters.get(&counter).unwrap(); + assert_eq!(c.hits(&counter), 8); + assert_eq!(c.pending_writes(), Ok(0)); + } + + #[tokio::test] + async fn flush_batcher_reverts_on_err() { + let counter = Counter::new( + Limit::new( + "test_namespace", + 10, + 60, + vec!["req.method == 'POST'"], + vec!["app_id"], + ), + Default::default(), + ); + + let error: RedisError = io::Error::new(io::ErrorKind::TimedOut, "That was long!").into(); + assert!(error.is_timeout()); + let mock_client = MockRedisConnection::new(vec![MockCmd::new::<&mut Cmd, Value>( + redis::cmd("EVALSHA") + .arg("95a717e821d8fbdd667b5e4c6fede4c9cad16006") + .arg("2") + .arg(key_for_counter(&counter)) + .arg(key_for_counters_of_limit(counter.limit())) + .arg(60) + .arg(3), + Err(error), + )]); + + let cache = CountersCacheBuilder::new().build(Duration::from_millis(10)); + let value = Arc::new(CachedCounterValue::from_authority(&counter, 2)); + value.delta(&counter, 3); + cache.batcher().add(counter.clone(), value).await; + + let cached_counters: Arc = Arc::new(cache); + let partitioned = Arc::new(AtomicBool::new(false)); + if let Some(c) = cached_counters.get(&counter) { - assert_eq!(c.hits(&counter), 8); + assert_eq!(c.hits(&counter), 5); } + + flush_batcher_and_update_counters( + mock_client, + true, + cached_counters.clone(), + partitioned, + 100, + ) + .await; + + let c = cached_counters.get(&counter).unwrap(); + assert_eq!(c.hits(&counter), 5); + assert_eq!(c.pending_writes(), Ok(3)); } }