Skip to content

Commit

Permalink
Fix rate limiter, add tests for token credit
Browse files Browse the repository at this point in the history
  • Loading branch information
dcadenas committed Sep 26, 2024
1 parent a00b1ce commit ad1920a
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 49 deletions.
24 changes: 6 additions & 18 deletions src/domain/followee_notification_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use tokio::time::Instant;
type Follower = PublicKey;
type Followee = PublicKey;

static ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60);
static TWELVE_HOURS: Duration = Duration::from_secs(12 * 60 * 60);
/// Accumulates messages for a followee and flushes them in batches
pub struct FolloweeNotificationFactory {
pub follow_changes: OrderMap<Follower, Box<FollowChange>>,
Expand All @@ -22,7 +24,7 @@ impl FolloweeNotificationFactory {
pub fn new(min_time_between_messages: Duration) -> Self {
// Rate limiter for 1 message every 12 hours, bursts of 10
let capacity = 10.0;
let rate_limiter = RateLimiter::new(capacity, Duration::from_secs(12 * 60 * 60));
let rate_limiter = RateLimiter::new(capacity, TWELVE_HOURS);

Self {
follow_changes: OrderMap::with_capacity(100),
Expand Down Expand Up @@ -78,7 +80,7 @@ impl FolloweeNotificationFactory {
}

let one_day_elapsed = match self.emptied_at {
Some(emptied_at) => now.duration_since(emptied_at) >= Duration::from_secs(24 * 60 * 60),
Some(emptied_at) => now.duration_since(emptied_at) >= ONE_DAY,
None => true,
};

Expand Down Expand Up @@ -113,15 +115,7 @@ impl FolloweeNotificationFactory {
}

if self.should_flush() {
let now = Instant::now();
let one_day_elapsed = match self.emptied_at {
Some(emptied_at) => {
now.duration_since(emptied_at) >= Duration::from_secs(24 * 60 * 60)
}
None => true,
};

self.emptied_at = Some(now);
self.emptied_at = Some(Instant::now());

let followers = self
.follow_changes
Expand All @@ -136,13 +130,7 @@ impl FolloweeNotificationFactory {
.collect();

let tokens_needed = messages.len() as f64;

if one_day_elapsed {
// Overcharge the rate limiter to consume tokens regardless of availability
self.rate_limiter.overcharge(tokens_needed);
} else if !self.rate_limiter.consume(tokens_needed) {
return vec![];
}
self.rate_limiter.overcharge(tokens_needed);

return messages;
}
Expand Down
92 changes: 77 additions & 15 deletions src/domain/notification_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ mod tests {
}

#[test]
fn test_single_item_batch_before_rate_limit_is_hit() {
fn test_first_single_item_batch() {
let min_seconds_between_messages = nonzero!(3600usize);

let mut notification_factory = NotificationFactory::new(min_seconds_between_messages);
Expand All @@ -265,8 +265,8 @@ mod tests {
}

#[tokio::test(start_paused = true)]
async fn test_no_message_after_rate_limit_is_hit_but_retention_not_elapsed() {
// After one single follow change the rate limit will be hit
async fn test_no_second_message_before_min_seconds_elapse() {
// After one single follow change, we need to wait min period
let min_seconds_between_messages = nonzero!(3600usize);

let mut notification_factory = NotificationFactory::new(min_seconds_between_messages);
Expand All @@ -283,23 +283,23 @@ mod tests {

advance(Duration::from_secs(1)).await;

// We hit the limit so the rest of the messages are retained
// We didn't wait min seconds
let messages = notification_factory.flush();
assert_batches_eq(&messages, &[]);
assert_eq!(notification_factory.follow_changes_len(), 2);

// We pass the number of minimum seconds between messages so all are sent
advance(Duration::from_secs(
min_seconds_between_messages.get() as u64 * 60,
min_seconds_between_messages.get() as u64
))
.await;
let messages = notification_factory.flush();
assert_batches_eq(&messages, &[(followee, &[change2, change3])]);
}

#[tokio::test(start_paused = true)]
async fn test_batch_sizes_after_rate_limit_and_retention_period() {
// After one single follow change, the rate limit will be hit
async fn test_batch_sizes_after_min_seconds() {
// After one single follow change, we need to wait min period
let min_seconds_between_messages = nonzero!(3600usize);
const MAX_FOLLOWERS_TRIPLED: usize = 3 * MAX_FOLLOWERS_PER_BATCH;

Expand All @@ -309,15 +309,15 @@ mod tests {
insert_new_follower(&mut notification_factory, followee);
let messages = notification_factory.flush();
advance(Duration::from_secs(1)).await;
// Sent and hit the rate limit
// One is accepted
assert_eq!(messages.len(), 1,);
assert!(messages[0].is_single());

repeat(()).take(MAX_FOLLOWERS_TRIPLED).for_each(|_| {
insert_new_follower(&mut notification_factory, followee);
});

// All inserted MAX_FOLLOWERS_TRIPLED changes wait for next available flush time
// All inserted MAX_FOLLOWERS_TRIPLED changes wait for min period
let messages = notification_factory.flush();
assert_eq!(messages.len(), 0);
assert_eq!(
Expand All @@ -327,7 +327,7 @@ mod tests {

// Before the max_retention time elapses..
advance(Duration::from_secs(
(min_seconds_between_messages.get() - 5) as u64 * 60,
(min_seconds_between_messages.get() - 5) as u64,
))
.await;

Expand All @@ -339,9 +339,14 @@ mod tests {
MAX_FOLLOWERS_TRIPLED + 1,
);

let messages = notification_factory.flush();
assert_eq!(messages.len(), 0);

// And we finally hit the min period
advance(Duration::from_secs(6u64)).await;
let messages = notification_factory.flush();

// But it will be included in the next batch anyways, regardless of the rate limit
// But it will be included in the next batch anyways, regardless of the min period
assert_eq!(messages.len(), 4);
assert_eq!(messages[0].len(), MAX_FOLLOWERS_PER_BATCH);
assert_eq!(messages[1].len(), MAX_FOLLOWERS_PER_BATCH);
Expand All @@ -360,7 +365,7 @@ mod tests {
assert_eq!(notification_factory.follow_changes_len(), 1);

advance(Duration::from_secs(
min_seconds_between_messages.get() as u64 * 60 + 1,
min_seconds_between_messages.get() as u64 + 1,
))
.await;
let messages = notification_factory.flush();
Expand All @@ -375,7 +380,7 @@ mod tests {
assert_eq!(notification_factory.follow_changes_len(), 0);

advance(Duration::from_secs(
min_seconds_between_messages.get() as u64 * 60 + 1,
min_seconds_between_messages.get() as u64 + 1,
))
.await;
let messages = notification_factory.flush();
Expand All @@ -386,6 +391,63 @@ mod tests {
assert_eq!(messages.len(), 0);
}

#[tokio::test(start_paused = true)]
async fn test_after_burst_of_10_messages_wait_12_hours() {
    let min_seconds_between_messages = nonzero!(15 * 60usize);
    let mut notification_factory = NotificationFactory::new(min_seconds_between_messages);

    let followee = Keys::generate().public_key();

    // Spend the whole initial burst: ten notifications, one per minimum period.
    for _round in 0..10 {
        insert_new_follower(&mut notification_factory, followee);
        advance(Duration::from_secs(min_seconds_between_messages.get() as u64)).await;

        let batch = notification_factory.flush();
        assert_eq!(batch.len(), 1);
        assert_eq!(notification_factory.follow_changes_len(), 0);
        assert_eq!(notification_factory.followees_len(), 1);
    }

    // The burst credit is exhausted, so the next change is retained.
    insert_new_follower(&mut notification_factory, followee);

    let batch = notification_factory.flush();
    assert!(batch.is_empty());
    assert_eq!(notification_factory.follow_changes_len(), 1);
    assert_eq!(notification_factory.followees_len(), 1);

    // Twelve hours later one fresh token is available, releasing the change.
    advance(Duration::from_secs(12 * 3600)).await;
    let batch = notification_factory.flush();
    assert_eq!(batch.len(), 1);
    assert_eq!(notification_factory.follow_changes_len(), 0);
    assert_eq!(notification_factory.followees_len(), 1);

    // A full day with nothing pending removes the followee entry together
    // with its rate limiter.
    advance(Duration::from_secs(24 * 3600)).await;
    let batch = notification_factory.flush();
    assert!(batch.is_empty());
    assert_eq!(notification_factory.follow_changes_len(), 0);
    assert_eq!(notification_factory.followees_len(), 0);

    // A brand-new entry starts over with the initial credit of ten messages.
    advance(Duration::from_secs(24 * 3600)).await;
    for _round in 0..10 {
        insert_new_follower(&mut notification_factory, followee);
        advance(Duration::from_secs(min_seconds_between_messages.get() as u64)).await;

        let batch = notification_factory.flush();
        assert_eq!(batch.len(), 1);
        assert_eq!(notification_factory.follow_changes_len(), 0);
        assert_eq!(notification_factory.followees_len(), 1);
    }
}

#[test]
fn test_is_empty_and_len() {
let mut notification_factory = NotificationFactory::new(nonzero!(60usize));
Expand Down Expand Up @@ -416,7 +478,7 @@ mod tests {
insert_new_follower(&mut notification_factory, followee);

advance(Duration::from_secs(
min_seconds_between_messages.get() as u64 * 60 + 1,
min_seconds_between_messages.get() as u64 + 1,
))
.await;

Expand All @@ -426,7 +488,7 @@ mod tests {
insert_new_unfollower(&mut notification_factory, followee);

advance(Duration::from_secs(
min_seconds_between_messages.get() as u64 * 60 + 1,
min_seconds_between_messages.get() as u64 + 1,
))
.await;

Expand Down
32 changes: 16 additions & 16 deletions src/rate_limiter.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use tokio::time::{Duration, Instant};
use tracing::debug;

/// Token bucket rate limiter.
pub struct RateLimiter {
burst_capacity: f64, // Maximum number of tokens
capacity: f64, // Initial capacity, credit in the bucket
tokens: f64, // Current number of tokens (can be negative)
refill_rate_per_sec: f64, // Tokens added per second
last_refill: Instant, // Last time tokens were refilled
Expand All @@ -11,13 +12,13 @@ pub struct RateLimiter {

impl RateLimiter {
/// Creates a new RateLimiter.
pub fn new(burst_capacity: f64, refill_duration: Duration) -> Self {
let refill_rate_per_sec = burst_capacity / refill_duration.as_secs_f64();
let tokens = burst_capacity;
let max_negative_tokens = burst_capacity * 1000.0;
pub fn new(initial_capacity: f64, refill_duration: Duration) -> Self {
let refill_rate_per_sec = refill_duration.as_secs_f64();
let tokens = initial_capacity;
let max_negative_tokens = initial_capacity * 1000.0;

Self {
burst_capacity,
capacity: initial_capacity,
tokens,
refill_rate_per_sec,
last_refill: Instant::now(),
Expand All @@ -31,8 +32,9 @@ impl RateLimiter {
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
self.last_refill = now;

let tokens_to_add = elapsed * self.refill_rate_per_sec;
self.tokens = (self.tokens + tokens_to_add).min(self.burst_capacity);
let tokens_to_add = elapsed / self.refill_rate_per_sec;
debug!("Tokens to add: {}", tokens_to_add);
self.tokens = (self.tokens + tokens_to_add).min(self.capacity);
}

/// Checks if the specified number of tokens are available without consuming them.
Expand Down Expand Up @@ -127,7 +129,7 @@ mod tests {
#[tokio::test(start_paused = true)]
async fn test_refill_tokens() {
let capacity = 10.0;
let refill_duration = Duration::from_secs(10); // Short duration for testing
let refill_duration = Duration::from_secs(10);
let mut rate_limiter = RateLimiter::new(capacity, refill_duration);

// Consume all tokens
Expand All @@ -138,30 +140,28 @@ mod tests {
// Advance time by half of the refill duration
time::advance(Duration::from_secs(5)).await;
rate_limiter.refill_tokens();
// Should have refilled half the tokens
assert_eq!(rate_limiter.tokens, 5.0);
assert_eq!(rate_limiter.tokens, 0.5);

// Advance time to complete the refill duration
time::advance(Duration::from_secs(5)).await;
rate_limiter.refill_tokens();
assert_eq!(rate_limiter.tokens, capacity);
assert_eq!(rate_limiter.tokens, 1.0);
}

#[tokio::test(start_paused = true)]
async fn test_overcharge_and_refill() {
    // Overcharging drives the token balance negative; subsequent refills must
    // pay the deficit back before any credit becomes available again.
    let capacity = 10.0;
    let refill_duration = Duration::from_secs(10);
    let mut rate_limiter = RateLimiter::new(capacity, refill_duration);

    // Overcharge by 15 tokens: 10 available minus 15 leaves a deficit of 5
    rate_limiter.overcharge(15.0);
    assert_eq!(rate_limiter.tokens, -5.0);

    // 60 seconds at one token per refill_duration (10s) adds 6 tokens,
    // so the balance climbs from -5 to exactly 1 — not back to capacity.
    time::advance(Duration::from_secs(6 * 10)).await;
    rate_limiter.refill_tokens();
    assert_eq!(rate_limiter.tokens, 1.0);
}

#[tokio::test]
Expand Down

0 comments on commit ad1920a

Please sign in to comment.