From 1ff618b0e39ee174a80f4c04db8c95ca68316f85 Mon Sep 17 00:00:00 2001 From: Tharsanan1 Date: Tue, 10 Sep 2024 20:35:50 +0530 Subject: [PATCH] Add sub based ai ratelimit feature related changes --- .../oasparser/envoyconf/routes_configs.go | 5 +++- .../envoyconf/routes_with_clusters.go | 1 + .../operator/controllers/dp/api_controller.go | 14 ++++++--- adapter/internal/operator/operator.go | 2 +- .../operator/synchronizer/data_store.go | 5 ++++ .../internal/xds/ratelimiter_cache.go | 3 +- .../apk/enforcer/grpc/ExtAuthService.java | 1 + .../grpc/ExternalProcessorService.java | 29 ++++++++++++++++--- .../enforcer/grpc/client/RatelimitClient.java | 24 ++++++++++++--- .../security/jwt/Oauth2Authenticator.java | 2 ++ 10 files changed, 71 insertions(+), 15 deletions(-) diff --git a/adapter/internal/oasparser/envoyconf/routes_configs.go b/adapter/internal/oasparser/envoyconf/routes_configs.go index 8913a1e5d..4798a48cd 100644 --- a/adapter/internal/oasparser/envoyconf/routes_configs.go +++ b/adapter/internal/oasparser/envoyconf/routes_configs.go @@ -152,6 +152,9 @@ func generateRouteAction(apiType string, routeConfig *model.EndpointConfig, rate if isBackendBasedAIRatelimitEnabled { action.Route.RateLimits = append(action.Route.RateLimits, generateBackendBasedAIRatelimit(descriptorValueForBackendBasedAIRatelimit)...) } + if isSubscriptionBasedAIRatelimitEnabled { + action.Route.RateLimits = append(action.Route.RateLimits, generateSubscriptionBasedAIRatelimit()...) + } // Add request mirroring configurations if mirrorClusterNames != nil && len(mirrorClusterNames) > 0 { @@ -260,7 +263,7 @@ func generateBackendBasedAIRatelimit(descValue string) []*routev3.RateLimit { } -func generateSubscriptionBasedAIRatelimit(descValue string) []*routev3.RateLimit { +func generateSubscriptionBasedAIRatelimit() []*routev3.RateLimit { rateLimitForRequestTokenCount := routev3.RateLimit{ Actions: []*routev3.RateLimit_Action{ { diff --git a/adapter/internal/oasparser/envoyconf/routes_with_clusters.go b/adapter/internal/oasparser/envoyconf/routes_with_clusters.go index 674f1c808..4d0928bb7 100644 --- a/adapter/internal/oasparser/envoyconf/routes_with_clusters.go +++ b/adapter/internal/oasparser/envoyconf/routes_with_clusters.go @@ -930,6 +930,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error } routeConfig := resource.GetEndpoints().Config metaData := &corev3.Metadata{} + logger.LoggerAPI.Infof("Is backend based rl enabled: %+v, Is subs based rl enabled: %+v", resource.GetEnableBackendBasedAIRatelimit(), resource.GetEnableSubscriptionBasedAIRatelimit()) if resource.GetEnableBackendBasedAIRatelimit() || resource.GetEnableSubscriptionBasedAIRatelimit() { metaData = &corev3.Metadata{ FilterMetadata: map[string]*structpb.Struct{ diff --git a/adapter/internal/operator/controllers/dp/api_controller.go b/adapter/internal/operator/controllers/dp/api_controller.go index 121265ef8..19ba7db98 100644 --- a/adapter/internal/operator/controllers/dp/api_controller.go +++ b/adapter/internal/operator/controllers/dp/api_controller.go @@ -865,12 +865,18 @@ func (apiReconciler *APIReconciler) resolveAiSubscriptionRatelimitPolicies(ctx c } for _, subscription := range subscriptionList.Items { aiRatelimitPolicy := &dpv1alpha3.AIRateLimitPolicy{} - if err := apiReconciler.client.Get(ctx, utils.NamespacedName(&subscription), aiRatelimitPolicy, ); err != nil { + nn:= types.NamespacedName{ + Name: subscription.Spec.RatelimitRef.Name, + Namespace: subscription.GetNamespace(), + } + if err := apiReconciler.client.Get(ctx, nn, aiRatelimitPolicy, ); err != nil { loggers.LoggerAPKOperator.Infof("No associated aiRatelimitPolicy found for Subscription: %s", utils.NamespacedName(&subscription)) - return + continue + } else { + loggers.LoggerAPKOperator.Infof("API state set as AI subscription enabled") + apiState.IsAiSubscriptionRatelimitEnabled = true + break } - apiState.IsAiSubscriptionRatelimitEnabled = true - return } } diff --git a/adapter/internal/operator/operator.go b/adapter/internal/operator/operator.go index b7236ce5e..ece07386a 100644 --- a/adapter/internal/operator/operator.go +++ b/adapter/internal/operator/operator.go @@ -51,7 +51,7 @@ import ( dpv1alpha2 "github.com/wso2/apk/common-go-libs/apis/dp/v1alpha2" dpv1alpha3 "github.com/wso2/apk/common-go-libs/apis/dp/v1alpha3" cpv1alpha2 "github.com/wso2/apk/common-go-libs/apis/cp/v1alpha2" - cpv1alpha3 "github.com/wso2/apk/common-go-libs/apis/cp/v1alpha2" + cpv1alpha3 "github.com/wso2/apk/common-go-libs/apis/cp/v1alpha3" //+kubebuilder:scaffold:imports ) diff --git a/adapter/internal/operator/synchronizer/data_store.go b/adapter/internal/operator/synchronizer/data_store.go index 1ecb67f1c..23183f949 100644 --- a/adapter/internal/operator/synchronizer/data_store.go +++ b/adapter/internal/operator/synchronizer/data_store.go @@ -75,6 +75,11 @@ func (ods *OperatorDataStore) processAPIState(apiNamespacedName types.Namespaced var updated bool events := []string{} cachedAPI := ods.apiStore[apiNamespacedName] + if cachedAPI.IsAiSubscriptionRatelimitEnabled != apiState.IsAiSubscriptionRatelimitEnabled { + cachedAPI.IsAiSubscriptionRatelimitEnabled = apiState.IsAiSubscriptionRatelimitEnabled + updated = true + events = append(events, "Subscription based AI RatelimitPolicy") + } if apiState.APIDefinition.Generation > cachedAPI.APIDefinition.Generation { cachedAPI.APIDefinition = apiState.APIDefinition diff --git a/common-controller/internal/xds/ratelimiter_cache.go b/common-controller/internal/xds/ratelimiter_cache.go index 7f7d918a1..53827e723 100644 --- a/common-controller/internal/xds/ratelimiter_cache.go +++ b/common-controller/internal/xds/ratelimiter_cache.go @@ -448,7 +448,8 @@ func (r *rateLimitPolicyCache) ProcessAIRatelimitPolicySpecsAndUpdateCache(aiRat } func prepareSubscriptionBasedAIRatelimitIdentifier(org string, namespacedName types.NamespacedName) string { - return fmt.Sprintf("%s-%s-%s", org, string(namespacedName.Namespace), string(namespacedName.Name)) + // return fmt.Sprintf("%s-%s-%s", org, string(namespacedName.Namespace), string(namespacedName.Name)) + return fmt.Sprintf("%s-%s", org, string(namespacedName.Name)) } func prepareAIRatelimitIdentifier(org string, namespacedName types.NamespacedName, spec *dpv1alpha3.AIRateLimitPolicySpec) string { diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExtAuthService.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExtAuthService.java index eb2b59a97..267a0b730 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExtAuthService.java +++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExtAuthService.java @@ -264,6 +264,7 @@ private String constructQueryParamString(boolean removeAllQueryParams, String re * @param value */ private void addMetadata(Struct.Builder structBuilder, String key, String value) { + System.out.println("Key: " + key + " value: " + value); structBuilder.putFields(key, Value.newBuilder().setStringValue(value).build()); } diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java index bda06b8e8..1678eea94 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java +++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java @@ -18,10 +18,13 @@ package org.wso2.apk.enforcer.grpc; +import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.protobuf.Struct; import com.google.protobuf.Value; +import io.envoyproxy.envoy.config.core.v3.Metadata; import io.envoyproxy.envoy.service.ext_proc.v3.BodyMutation; import io.envoyproxy.envoy.service.ext_proc.v3.BodyResponse; import io.envoyproxy.envoy.service.ext_proc.v3.CommonResponse; @@ -47,6 +50,12 @@ public class ExternalProcessorService extends ExternalProcessorGrpc.ExternalProc private static final String DESCRIPTOR_KEY_FOR_AI_REQUEST_TOKEN_COUNT = "airequesttokencount"; private static final String DESCRIPTOR_KEY_FOR_AI_RESPONSE_TOKEN_COUNT = "airesponsetokencount"; private static final String DESCRIPTOR_KEY_FOR_AI_TOTAL_TOKEN_COUNT = "aitotaltokencount"; + private static final String DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_REQUEST_TOKEN_COUNT = "airequesttokencountsubs"; + private static final String DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_RESPONSE_TOKEN_COUNT = "airesponsetokencountsubs"; + private static final String DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_TOTAL_TOKEN_COUNT = "aitotaltokencountsubs"; + private static final String DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION = "subscription"; + private static final String DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY = "ratelimit:organization-and-rlpolicy"; + private static final String DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION = "ratelimit:subscription"; RatelimitClient ratelimitClient = new RatelimitClient(); @Override public StreamObserver process( @@ -92,13 +101,25 @@ public void onNext(ProcessingRequest request) { return; } System.out.println("body: " +request.getResponseBody().getBody().toStringUtf8()); + List configs = new ArrayList<>(); if (filterMetadata.enableBackendBasedAIRatelimit) { - List configs = new ArrayList<>(); configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_REQUEST_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getPrompt_tokens())); configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_RESPONSE_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getCompletion_tokens())); configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_TOTAL_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getTotal_tokens())); - ratelimitClient.shouldRatelimit(configs); } + if (filterMetadata.enableSubscriptionBasedAIRatelimit) { + if (request.hasMetadataContext()) { + Struct filterMetadataFromAuthZ = request.getMetadataContext().getFilterMetadataOrDefault("envoy.filters.http.ext_authz", null); + if (filterMetadataFromAuthZ != null) { + String orgAndAIRLPolicyValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY).getStringValue(); + String aiRLSubsValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION).getStringValue(); + configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_REQUEST_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getPrompt_tokens()))); + configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_RESPONSE_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getPrompt_tokens()))); + configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_TOTAL_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getPrompt_tokens()))); + } + } + } + ratelimitClient.shouldRatelimit(configs); responseObserver.onCompleted(); } else { System.out.println("Request does not have response body"); @@ -258,8 +279,8 @@ private static Usage extractUsageFromBody(String body, String completionTokenPat System.out.println("Usage extracted: "+ usage); return usage; - } catch (JsonProcessingException e) { - logger.error(String.format("Unexpected error while extracting usage from the body: %s", body), e); + } catch (Exception e) { + System.out.println(String.format("Unexpected error while extracting usage from the body: %s", body) + " \n" + e); return null; } } diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/client/RatelimitClient.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/client/RatelimitClient.java index 303429606..fb528ffbd 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/client/RatelimitClient.java +++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/client/RatelimitClient.java @@ -52,14 +52,22 @@ public void shouldRatelimit(List configs) { executorService.submit(() -> { System.out.println("Ratelimitclient test"); for (KeyValueHitsAddend config : configs) { + System.out.println("For: " + config.getKey()); - RateLimitDescriptor descriptor = RateLimitDescriptor.newBuilder() - .addEntries(RateLimitDescriptor.Entry.newBuilder().setKey(config.getKey()).setValue(config.getValue()).build()) - .build(); + RateLimitDescriptor.Builder builder = RateLimitDescriptor.newBuilder() + .addEntries(RateLimitDescriptor.Entry.newBuilder().setKey(config.getKey()).setValue(config.getValue()).build()); + KeyValueHitsAddend internalKeyValueHitsAddend = config.keyValueHitsAddend; + int hitsAddend = config.getHitsAddend(); + while (internalKeyValueHitsAddend != null) { + builder.addEntries(RateLimitDescriptor.Entry.newBuilder().setKey(internalKeyValueHitsAddend.getKey()).setValue(internalKeyValueHitsAddend.getValue()).build()); + hitsAddend = internalKeyValueHitsAddend.getHitsAddend(); + internalKeyValueHitsAddend = internalKeyValueHitsAddend.keyValueHitsAddend; + } + RateLimitDescriptor descriptor = builder.build(); RateLimitRequest rateLimitRequest = RateLimitRequest.newBuilder() .addDescriptors(descriptor) .setDomain("Default") - .setHitsAddend(config.getHitsAddend()) + .setHitsAddend(hitsAddend) .build(); RateLimitResponse rateLimitResponse = stub.shouldRateLimit(rateLimitRequest); System.out.println(rateLimitResponse.getOverallCode()); @@ -73,11 +81,19 @@ public static class KeyValueHitsAddend { private String key; private String value; private int hitsAddend; + private KeyValueHitsAddend keyValueHitsAddend; public KeyValueHitsAddend(String key, String value, int hitsAddend) { this.key = key; this.value = value; this.hitsAddend = hitsAddend; + this.keyValueHitsAddend = null; + } + public KeyValueHitsAddend(String key, String value, KeyValueHitsAddend keyValueHitsAddend) { + this.key = key; + this.value = value; + this.hitsAddend = -1; + this.keyValueHitsAddend = keyValueHitsAddend; } public String getKey() { diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/security/jwt/Oauth2Authenticator.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/security/jwt/Oauth2Authenticator.java index 9a29fbf45..b12dda1c0 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/security/jwt/Oauth2Authenticator.java +++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/security/jwt/Oauth2Authenticator.java @@ -268,6 +268,8 @@ public AuthenticationContext authenticate(RequestContext requestContext) throws requestContext.addMetadataToMap("ratelimit:subscription", subscriptionId); requestContext.addMetadataToMap("ratelimit:usage-policy", subscription.getRatelimitTier()); requestContext.addMetadataToMap("ratelimit:organization", subscription.getOrganization()); + System.out.println("Value: "+ String.format("%s-%s", subscription.getOrganization(), subscription.getRatelimitTier())); + requestContext.addMetadataToMap("ratelimit:organization-and-rlpolicy", String.format("%s-%s", subscription.getOrganization(), subscription.getRatelimitTier())); } } }