Skip to content

Commit

Permalink
Fix async filter matching strategy (#192)
Browse files Browse the repository at this point in the history
The RateLimitConditionMatchingStrategy.FIRST check was done after the ratelimit checks are all executed. This caused tokens to be consumed from all ratelimitcheck buckets, while only the first ratelimitcheck should have been consumed. This should now be fixed.

Co-authored-by: Edwin Heuver <[email protected]>
  • Loading branch information
Edwin9292 and Edwin-192 authored Dec 1, 2023
1 parent 627d3ea commit e5555c1
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import static java.nio.charset.StandardCharsets.UTF_8;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import com.giffing.bucket4j.spring.boot.starter.context.ConsumptionProbeHolder;
import com.giffing.bucket4j.spring.boot.starter.context.RateLimitCheck;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
Expand All @@ -22,47 +25,38 @@
@Data
@Slf4j
public class AbstractReactiveFilter {

private final FilterConfiguration<ServerHttpRequest> filterConfig;

public AbstractReactiveFilter(FilterConfiguration<ServerHttpRequest> filterConfig) {
this.filterConfig = filterConfig;
}

protected boolean urlMatches(ServerHttpRequest request) {
return request.getURI().getPath().matches(filterConfig.getUrl());
}

protected Mono<Void> chainWithRateLimitCheck(ServerWebExchange exchange, ReactiveFilterChain chain) {
log.debug("reate-limit-check;method:{};uri:{}", exchange.getRequest().getMethod(), exchange.getRequest().getURI());
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
List<Mono<ConsumptionProbe>> asyncConsumptionProbes = filterConfig.getRateLimitChecks()
.stream()
.map(rl -> rl.rateLimit(request))
.filter(cph -> cph != null && cph.getConsumptionProbeCompletableFuture() != null)
.map(cph -> Mono.fromFuture(cph.getConsumptionProbeCompletableFuture()))
.toList();
List<Mono<ConsumptionProbe>> asyncConsumptionProbes = new ArrayList<>();
for (RateLimitCheck<ServerHttpRequest> rlc : filterConfig.getRateLimitChecks()) {
ConsumptionProbeHolder cph = rlc.rateLimit(request);
if(cph != null && cph.getConsumptionProbeCompletableFuture() != null){
asyncConsumptionProbes.add(Mono.fromFuture(cph.getConsumptionProbeCompletableFuture()));
if(filterConfig.getStrategy() == RateLimitConditionMatchingStrategy.FIRST){
break;
}
}
}
if(asyncConsumptionProbes.isEmpty()) {
return chain.apply(exchange);
}
AtomicInteger consumptionProbeCounter = new AtomicInteger(0);
return Flux
.concat(asyncConsumptionProbes)
//.takeWhile(Objects::nonNull)
.doOnNext(cp -> consumptionProbeCounter.incrementAndGet())
.takeWhile(cp -> shouldTakeMoreConsumptionProbe(consumptionProbeCounter))
.reduce(this::reduceConsumptionProbe)
.flatMap(consumptionProbe -> handleConsumptionProbe(exchange, chain, response, consumptionProbe));

}

protected boolean shouldTakeMoreConsumptionProbe(AtomicInteger consumptionProbeCounter) {
boolean shouldTakeMore = filterConfig.getStrategy().equals(RateLimitConditionMatchingStrategy.ALL)
||
(filterConfig.getStrategy().equals(RateLimitConditionMatchingStrategy.FIRST) && consumptionProbeCounter.get() == 1);
log.debug("take-more-probes:{};probe-index:{};matching-strategy:{}", shouldTakeMore, consumptionProbeCounter.get(), filterConfig.getStrategy());
return shouldTakeMore;
.concat(asyncConsumptionProbes)
.reduce(this::reduceConsumptionProbe)
.flatMap(consumptionProbe -> handleConsumptionProbe(exchange, chain, response, consumptionProbe));
}

protected ConsumptionProbe reduceConsumptionProbe(ConsumptionProbe x, ConsumptionProbe y) {
Expand All @@ -72,23 +66,23 @@ protected ConsumptionProbe reduceConsumptionProbe(ConsumptionProbe x, Consumptio
} else if(!y.isConsumed()) {
result = y;
} else {
result = x.getRemainingTokens() < y.getRemainingTokens() ? x : y;
result = x.getRemainingTokens() < y.getRemainingTokens() ? x : y;
}
log.debug("reduce-probes;result-isConsumed:{};result-getremainingTokens:{};x-isConsumed:{};x-getremainingTokens{};y-isConsumed:{};y-getremainingTokens{}",
result.isConsumed(), result.getRemainingTokens(),
x.isConsumed(), x.getRemainingTokens(),
y.isConsumed(), y.getRemainingTokens());
return result;
}

protected Mono<Void> handleConsumptionProbe(ServerWebExchange exchange, ReactiveFilterChain chain,
ServerHttpResponse response, ConsumptionProbe cp) {
log.debug("probe-results;isConsumed:{};remainingTokens:{};nanosToWaitForRefill:{};nanosToWaitForReset:{}",
cp.isConsumed(),
cp.getRemainingTokens(),
cp.getNanosToWaitForRefill(),
log.debug("probe-results;isConsumed:{};remainingTokens:{};nanosToWaitForRefill:{};nanosToWaitForReset:{}",
cp.isConsumed(),
cp.getRemainingTokens(),
cp.getNanosToWaitForRefill(),
cp.getNanosToWaitForReset());

if(!cp.isConsumed()) {
if(Boolean.FALSE.equals(filterConfig.getHideHttpResponseHeaders())) {
filterConfig.getHttpResponseHeaders().forEach(response.getHeaders()::addIfAbsent);
Expand All @@ -97,7 +91,7 @@ protected Mono<Void> handleConsumptionProbe(ServerWebExchange exchange, Reactive
response.setStatusCode(filterConfig.getHttpStatusCode());
response.getHeaders().set("Content-Type", filterConfig.getHttpContentType());
DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(filterConfig.getHttpResponseBody().getBytes(UTF_8));
return response.writeWith(Flux.just(buffer));
return response.writeWith(Flux.just(buffer));
} else {
return Mono.error(new ReactiveRateLimitException("HTTP ResponseBody is null"));
}
Expand All @@ -108,6 +102,4 @@ protected Mono<Void> handleConsumptionProbe(ServerWebExchange exchange, Reactive
}
return chain.apply(exchange);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -39,52 +39,52 @@
class SpringCloudGatewayRateLimitFilterTest {

private GlobalFilter filter;
private FilterConfiguration configuration;
private RateLimitCheck rateLimitCheck1;
private RateLimitCheck rateLimitCheck2;
private RateLimitCheck rateLimitCheck3;
private FilterConfiguration<ServerHttpRequest> configuration;
private RateLimitCheck<ServerHttpRequest> rateLimitCheck1;
private RateLimitCheck<ServerHttpRequest> rateLimitCheck2;
private RateLimitCheck<ServerHttpRequest> rateLimitCheck3;

private ServerWebExchange exchange;
private GatewayFilterChain chain;


private ServerHttpResponse serverHttpResponse;

@BeforeEach
public void setup() throws URISyntaxException {
rateLimitCheck1 = mock(RateLimitCheck.class);
rateLimitCheck2 = mock(RateLimitCheck.class);
rateLimitCheck3 = mock(RateLimitCheck.class);

exchange = Mockito.mock(ServerWebExchange.class);
ServerHttpRequest serverHttpRequest = Mockito.mock(ServerHttpRequest.class);
URI uri = new URI("url");
when(serverHttpRequest.getURI()).thenReturn(uri);
public void setup() throws URISyntaxException {
rateLimitCheck1 = mock(RateLimitCheck.class);
rateLimitCheck2 = mock(RateLimitCheck.class);
rateLimitCheck3 = mock(RateLimitCheck.class);

exchange = Mockito.mock(ServerWebExchange.class);

ServerHttpRequest serverHttpRequest = Mockito.mock(ServerHttpRequest.class);
URI uri = new URI("url");
when(serverHttpRequest.getURI()).thenReturn(uri);
when(exchange.getRequest()).thenReturn(serverHttpRequest);

serverHttpResponse = Mockito.mock(ServerHttpResponse.class);
when(exchange.getResponse()).thenReturn(serverHttpResponse);
when(exchange.getResponse()).thenReturn(serverHttpResponse);

chain = Mockito.mock(GatewayFilterChain.class);
when(chain.filter(exchange)).thenReturn(Mono.empty());
configuration = new FilterConfiguration();
configuration.setRateLimitChecks(Arrays.asList(rateLimitCheck1, rateLimitCheck2, rateLimitCheck3));
configuration.setUrl(".*");
filter = new SpringCloudGatewayRateLimitFilter(configuration);
}

@Test

configuration = new FilterConfiguration<>();
configuration.setRateLimitChecks(Arrays.asList(rateLimitCheck1, rateLimitCheck2, rateLimitCheck3));
configuration.setUrl(".*");
filter = new SpringCloudGatewayRateLimitFilter(configuration);
}

@Test
void should_throw_rate_limit_exception_with_no_remaining_tokens() {

configuration.setStrategy(RateLimitConditionMatchingStrategy.FIRST);

rateLimitConfig(0L, rateLimitCheck1);
HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
AtomicBoolean hasRateLimitError = new AtomicBoolean(false);
rateLimitConfig(0L, rateLimitCheck1);
HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);

AtomicBoolean hasRateLimitError = new AtomicBoolean(false);
Mono<Void> result = filter.filter(exchange, chain)
.onErrorResume(ReactiveRateLimitException.class, (e) -> {
hasRateLimitError.set(true);
Expand All @@ -93,63 +93,62 @@ void should_throw_rate_limit_exception_with_no_remaining_tokens() {
result.subscribe();
Assertions.assertTrue(hasRateLimitError.get());
}

@Test
void should_execute_all_checks_when_using_RateLimitConditionMatchingStrategy_All() throws URISyntaxException {
configuration.setStrategy(RateLimitConditionMatchingStrategy.ALL);

rateLimitConfig(30L, rateLimitCheck1);
rateLimitConfig(0L, rateLimitCheck2);
rateLimitConfig(0L, rateLimitCheck3);

HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
final ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
Mono<Void> result = filter.filter(exchange, chain);
assertThrows(ReactiveRateLimitException.class, () -> {
result.block();
});

configuration.setStrategy(RateLimitConditionMatchingStrategy.ALL);

rateLimitConfig(30L, rateLimitCheck1);
rateLimitConfig(0L, rateLimitCheck2);
rateLimitConfig(0L, rateLimitCheck3);

HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
final ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);

Mono<Void> result = filter.filter(exchange, chain);
assertThrows(ReactiveRateLimitException.class, () -> {
result.block();
});

verify(rateLimitCheck1, times(1)).rateLimit(any());
verify(rateLimitCheck2, times(1)).rateLimit(any());
verify(rateLimitCheck3, times(1)).rateLimit(any());
verify(rateLimitCheck2, times(1)).rateLimit(any());
verify(rateLimitCheck3, times(1)).rateLimit(any());
}

@Test
void should_execute_only_one_check_when_using_RateLimitConditionMatchingStrategy_FIRST() {
configuration.setStrategy(RateLimitConditionMatchingStrategy.FIRST);

rateLimitConfig(30L, rateLimitCheck1);
rateLimitConfig(0L, rateLimitCheck2);
rateLimitConfig(10L, rateLimitCheck3);
HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
final ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
configuration.setStrategy(RateLimitConditionMatchingStrategy.FIRST);

rateLimitConfig(30L, rateLimitCheck1);
rateLimitConfig(0L, rateLimitCheck2);
rateLimitConfig(10L, rateLimitCheck3);

HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
final ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);

Mono<Void> result = filter.filter(exchange, chain);
result.block();
verify(httpHeaders, times(1)).set(any(), captor.capture());

List<String> values = captor.getAllValues();
Assertions.assertEquals("30", values.stream().findFirst().get());
verify(rateLimitCheck1, times(1)).rateLimit(any());
verify(rateLimitCheck2, times(1)).rateLimit(any());
verify(rateLimitCheck3, times(1)).rateLimit(any());

verify(httpHeaders, times(1)).set(any(), captor.capture());

List<String> values = captor.getAllValues();
Assertions.assertEquals("30", values.stream().findFirst().get());

verify(rateLimitCheck1, times(1)).rateLimit(any());
verify(rateLimitCheck2, times(0)).rateLimit(any());
verify(rateLimitCheck3, times(0)).rateLimit(any());
}

private void rateLimitConfig(Long remainingTokens, RateLimitCheck rateLimitCheck) {
private void rateLimitConfig(Long remainingTokens, RateLimitCheck<ServerHttpRequest> rateLimitCheck) {
ConsumptionProbeHolder consumptionHolder = Mockito.mock(ConsumptionProbeHolder.class);
ConsumptionProbe probe = Mockito.mock(ConsumptionProbe.class);
when(probe.isConsumed()).thenReturn(remainingTokens > 0 ? true : false);
ConsumptionProbe probe = Mockito.mock(ConsumptionProbe.class);
when(probe.isConsumed()).thenReturn(remainingTokens > 0);
when(probe.getRemainingTokens()).thenReturn(remainingTokens);
when(consumptionHolder.getConsumptionProbeCompletableFuture())
.thenReturn(CompletableFuture.completedFuture(probe));
when(rateLimitCheck.rateLimit(any())).thenReturn(consumptionHolder);
.thenReturn(CompletableFuture.completedFuture(probe));
when(rateLimitCheck.rateLimit(any())).thenReturn(consumptionHolder);
}

}
Loading

0 comments on commit e5555c1

Please sign in to comment.