Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert on transient redis error #319

Merged
Merged 1 commit on May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 69 additions & 21 deletions limitador/src/storage/redis/counters_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,22 @@ impl CachedCounterValue {
value - start == 0
}

// Puts `writes` previously-consumed pending writes back onto this cached
// value after a failed (transient) flush to Redis, by subtracting them from
// `initial_value`.
//
// Returns `Err(())` when the revert could not be applied because another
// thread concurrently changed `initial_value` to a non-zero value; a CAS
// failure observing `0` is treated as success — presumably the entry was
// reset/expired in the meantime, so there is nothing left to revert
// (TODO confirm against `CachedCounterValue` invariants).
fn revert_writes(&self, writes: u64) -> Result<(), ()> {
// Snapshot the current accumulated value.
let newer = self.initial_value.load(Ordering::SeqCst);
if newer > writes {
// Only attempt the subtraction if the snapshot still holds: a CAS
// guards against racing with other writers.
return match self.initial_value.compare_exchange(
newer,
newer - writes,
Ordering::SeqCst,
Ordering::SeqCst,
) {
// CAS lost to a concurrent non-zero update: report failure so the
// caller can log that the revert did not happen.
Err(expiry) if expiry != 0 => Err(()),
// Success, or the value raced to 0 — nothing more to do.
_ => Ok(()),
};
}
// newer <= writes: the counter was reset (or never exceeded the writes)
// since the flush started; treat as nothing to revert.
Ok(())
}

/// Returns the counter's current value, evaluated at the present instant.
/// The `Counter` argument is ignored; it is only kept so the signature
/// lines up with sibling accessors.
pub fn hits(&self, _: &Counter) -> u64 {
    let now = SystemTime::now();
    self.value.value_at(now)
}
Expand Down Expand Up @@ -165,10 +181,10 @@ impl Batcher {
self.notifier.notify_one();
}

pub async fn consume<F, Fut, O>(&self, max: usize, consumer: F) -> O
pub async fn consume<F, Fut, T, E>(&self, max: usize, consumer: F) -> Result<T, E>
where
F: FnOnce(HashMap<Counter, Arc<CachedCounterValue>>) -> Fut,
Fut: Future<Output = O>,
Fut: Future<Output = Result<T, E>>,
Comment on lines +184 to +187
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This signature is getting somewhat ugly. But since async closures aren't stable yet, nothing can call `await` inside the closure passed to `consume` — so this is the workaround I went with instead.

{
let mut ready = self.batch_ready(max);
loop {
Expand All @@ -191,15 +207,17 @@ impl Batcher {
}
histogram!("batcher_flush_size").record(result.len() as f64);
let result = consumer(result).await;
batch.iter().for_each(|counter| {
let prev = self
.updates
.remove_if(counter, |_, v| v.no_pending_writes());
if prev.is_some() {
self.limiter.add_permits(1);
gauge!("batcher_size").decrement(1);
}
});
if result.is_ok() {
batch.iter().for_each(|counter| {
let prev = self
.updates
.remove_if(counter, |_, v| v.no_pending_writes());
if prev.is_some() {
self.limiter.add_permits(1);
gauge!("batcher_size").decrement(1);
}
});
}
return result;
} else {
ready = select! {
Expand Down Expand Up @@ -252,6 +270,31 @@ impl CountersCache {
&self.batcher
}

// Re-credits `writes` pending increments for `counter` after a transient
// flush failure, so they will be retried on the next batch.
//
// `value` is the last authoritative value known for the counter and is only
// used to rebuild a cache entry when the counter fell out of both the cache
// and the batcher while the flush was in flight.
//
// Returns `Err(())` when an existing entry refused the revert (see
// `CachedCounterValue::revert_writes`).
pub fn return_pending_writes(
&self,
counter: &Counter,
value: u64,
writes: u64,
) -> Result<(), ()> {
if writes != 0 {
// Set by the closure below when neither the cache nor the batcher had
// an entry; in that case a fresh entry is created with the writes
// already re-applied, so no revert is needed afterwards.
let mut miss = false;
let value = self.cache.get_with_by_ref(counter, || {
// Closure only runs on a cache miss — presumably `get_with_by_ref`
// inserts and returns its result atomically (TODO confirm cache API).
if let Some(entry) = self.batcher.updates.get(counter) {
// Batcher still holds the live entry: reuse it so the revert
// below lands on the value everyone else sees.
entry.value().clone()
} else {
miss = true;
// Entry vanished entirely: rebuild from the authoritative value
// and put the failed writes straight back as pending deltas.
let value = Arc::new(CachedCounterValue::from_authority(counter, value));
value.delta(counter, writes);
value
}
});
// Pre-existing entry (cache hit or batcher hit): the failed writes were
// already folded into it once, so subtract them back out.
if miss.not() {
return value.revert_writes(writes);
}
}
Ok(())
}

pub fn apply_remote_delta(
&self,
counter: Counter,
Expand Down Expand Up @@ -456,9 +499,10 @@ mod tests {
.consume(2, |items| {
assert!(items.is_empty());
assert!(SystemTime::now().duration_since(start).unwrap() >= duration);
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}

#[tokio::test]
Expand All @@ -482,9 +526,10 @@ mod tests {
SystemTime::now().duration_since(start).unwrap()
>= Duration::from_millis(100)
);
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}

#[tokio::test]
Expand All @@ -507,9 +552,10 @@ mod tests {
let wait_period = SystemTime::now().duration_since(start).unwrap();
assert!(wait_period >= Duration::from_millis(40));
assert!(wait_period < Duration::from_millis(50));
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}

#[tokio::test]
Expand All @@ -528,9 +574,10 @@ mod tests {
assert!(
SystemTime::now().duration_since(start).unwrap() < Duration::from_millis(5)
);
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}

#[tokio::test]
Expand All @@ -553,9 +600,10 @@ mod tests {
let wait_period = SystemTime::now().duration_since(start).unwrap();
assert!(wait_period >= Duration::from_millis(40));
assert!(wait_period < Duration::from_millis(50));
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}
}

Expand Down
90 changes: 81 additions & 9 deletions limitador/src/storage/redis/redis_cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ impl CachedRedisStorageBuilder {
async fn update_counters<C: ConnectionLike>(
redis_conn: &mut C,
counters_and_deltas: HashMap<Counter, Arc<CachedCounterValue>>,
) -> Result<Vec<(Counter, u64, u64, i64)>, StorageErr> {
) -> Result<Vec<(Counter, u64, u64, i64)>, (Vec<(Counter, u64, u64, i64)>, StorageErr)> {
let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS);
let mut script_invocation = redis_script.prepare_invoke();

Expand All @@ -303,26 +303,32 @@ async fn update_counters<C: ConnectionLike>(
script_invocation.arg(counter.seconds());
script_invocation.arg(delta);
// We need to store the counter in the actual order we are sending it to the script
res.push((counter, 0, last_value_from_redis, 0));
res.push((counter, last_value_from_redis, delta, 0));
}
}

// The redis crate is not working with tables, thus the response will be a Vec of counter values
let script_res: Vec<i64> = script_invocation
let script_res: Vec<i64> = match script_invocation
.invoke_async(redis_conn)
.instrument(debug_span!("datastore"))
.await?;
.await
{
Ok(res) => res,
Err(err) => {
return Err((res, err.into()));
}
};

// We need to update the values and ttls returned by redis
let counters_range = 0..res.len();
let script_res_range = (0..script_res.len()).step_by(2);

for (i, j) in counters_range.zip(script_res_range) {
let (_, val, delta, expires_at) = &mut res[i];
*val = u64::try_from(script_res[j]).unwrap_or(0);
*delta = u64::try_from(script_res[j])
.unwrap_or(0)
.saturating_sub(*delta);
.saturating_sub(*val); // new value - previous one = remote writes
*val = u64::try_from(script_res[j]).unwrap_or(0); // update to value to newest
*expires_at = script_res[j + 1];
}
res
Expand All @@ -349,9 +355,22 @@ async fn flush_batcher_and_update_counters<C: ConnectionLike>(
update_counters(&mut redis_conn, counters)
})
.await
.or_else(|err| {
.or_else(|(data, err)| {
if err.is_transient() {
flip_partitioned(&partitioned, true);
let counters = data.len();
let mut reverted = 0;
for (counter, old_value, pending_writes, _) in data {
if cached_counters
.return_pending_writes(&counter, old_value, pending_writes)
.is_err()
{
error!("Couldn't revert writes back to {:?}", &counter);
} else {
reverted += 1;
}
}
warn!("Reverted {} of {} counter increments", reverted, counters);
Ok(Vec::new())
} else {
Err(err)
Expand All @@ -375,9 +394,10 @@ mod tests {
};
use crate::storage::redis::redis_cached::{flush_batcher_and_update_counters, update_counters};
use crate::storage::redis::CachedRedisStorage;
use redis::{ErrorKind, Value};
use redis::{Cmd, ErrorKind, RedisError, Value};
use redis_test::{MockCmd, MockRedisConnection};
use std::collections::HashMap;
use std::io;
use std::ops::Add;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
Expand Down Expand Up @@ -518,8 +538,60 @@ mod tests {
)
.await;

let c = cached_counters.get(&counter).unwrap();
assert_eq!(c.hits(&counter), 8);
assert_eq!(c.pending_writes(), Ok(0));
}

/// Verifies that a transient Redis error (here: a timeout) during a batched
/// flush does NOT lose local increments: the pending writes are reverted back
/// into the cache and remain pending for the next flush attempt.
///
/// Fix: the source contained two contradictory adjacent assertions
/// (`hits == 8` and `hits == 5`) — a leftover removed diff line. With an
/// authoritative value of 2 and a local delta of 3, the observable hit count
/// is 5, so the stale `8` assertion is dropped.
#[tokio::test]
async fn flush_batcher_reverts_on_err() {
    let counter = Counter::new(
        Limit::new(
            "test_namespace",
            10,
            60,
            vec!["req.method == 'POST'"],
            vec!["app_id"],
        ),
        Default::default(),
    );

    // A timeout is classified as transient, which is what triggers the revert path.
    let error: RedisError = io::Error::new(io::ErrorKind::TimedOut, "That was long!").into();
    assert!(error.is_timeout());
    let mock_client = MockRedisConnection::new(vec![MockCmd::new::<&mut Cmd, Value>(
        redis::cmd("EVALSHA")
            .arg("95a717e821d8fbdd667b5e4c6fede4c9cad16006")
            .arg("2")
            .arg(key_for_counter(&counter))
            .arg(key_for_counters_of_limit(counter.limit()))
            .arg(60)
            .arg(3),
        Err(error),
    )]);

    let cache = CountersCacheBuilder::new().build(Duration::from_millis(10));
    // Authoritative value 2 plus 3 pending local increments => 5 observable hits.
    let value = Arc::new(CachedCounterValue::from_authority(&counter, 2));
    value.delta(&counter, 3);
    cache.batcher().add(counter.clone(), value).await;

    let cached_counters: Arc<CountersCache> = Arc::new(cache);
    let partitioned = Arc::new(AtomicBool::new(false));

    // Sanity-check the cached state before flushing.
    if let Some(c) = cached_counters.get(&counter) {
        assert_eq!(c.hits(&counter), 5);
    }

    flush_batcher_and_update_counters(
        mock_client,
        true,
        cached_counters.clone(),
        partitioned,
        100,
    )
    .await;

    // After the failed flush: hits unchanged and the 3 writes pending again.
    let c = cached_counters.get(&counter).unwrap();
    assert_eq!(c.hits(&counter), 5);
    assert_eq!(c.pending_writes(), Ok(3));
}
}
Loading