From b0d994441912ceb98ac807f69658a66b7cf660c3 Mon Sep 17 00:00:00 2001 From: Tharsanan1 Date: Wed, 11 Sep 2024 15:12:44 +0530 Subject: [PATCH] Use AIProvider to extract token details from response --- .../oasparser/envoyconf/internal_dtos.go | 1 + .../envoyconf/routes_with_clusters.go | 23 ++++--- .../operator/controllers/dp/api_controller.go | 6 +- .../operator/synchronizer/data_store.go | 14 ++++ .../grpc/ExternalProcessorService.java | 67 ++++++++++++------- .../apk/enforcer/security/AuthFilter.java | 6 ++ 6 files changed, 77 insertions(+), 40 deletions(-) diff --git a/adapter/internal/oasparser/envoyconf/internal_dtos.go b/adapter/internal/oasparser/envoyconf/internal_dtos.go index 1ca8d5f21..0cb5f6b2b 100644 --- a/adapter/internal/oasparser/envoyconf/internal_dtos.go +++ b/adapter/internal/oasparser/envoyconf/internal_dtos.go @@ -45,6 +45,7 @@ type routeCreateParams struct { environment string envType string mirrorClusterNames map[string][]string + isAiAPI bool } // RatelimitCriteria criterias of rate limiting diff --git a/adapter/internal/oasparser/envoyconf/routes_with_clusters.go b/adapter/internal/oasparser/envoyconf/routes_with_clusters.go index 4d0928bb7..24abc8ceb 100644 --- a/adapter/internal/oasparser/envoyconf/routes_with_clusters.go +++ b/adapter/internal/oasparser/envoyconf/routes_with_clusters.go @@ -865,7 +865,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error LuaLocal: luaFilter, wellknown.CORS: corsFilter, } - if !resource.GetEnableBackendBasedAIRatelimit() && !resource.GetEnableSubscriptionBasedAIRatelimit() { + if !params.isAiAPI || (!resource.GetEnableBackendBasedAIRatelimit() && !resource.GetEnableSubscriptionBasedAIRatelimit()) { perFilterConfigExtProc := extProcessorv3.ExtProcPerRoute{ Override: &extProcessorv3.ExtProcPerRoute_Disabled{ Disabled: true, @@ -878,7 +878,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error } perRouteFilterConfigs[HTTPExternalProcessor] = filterExtProc } - + logger.LoggerOasparser.Debugf("adding route : %s for API : %s", resourcePath, title) rateLimitPolicyLevel := "" @@ -931,7 +931,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error routeConfig := resource.GetEndpoints().Config metaData := &corev3.Metadata{} logger.LoggerAPI.Infof("Is backend based rl enabled: %+v, Is subs based rl enabled: %+v", resource.GetEnableBackendBasedAIRatelimit(), resource.GetEnableSubscriptionBasedAIRatelimit()) - if resource.GetEnableBackendBasedAIRatelimit() || resource.GetEnableSubscriptionBasedAIRatelimit() { + if params.isAiAPI && (resource.GetEnableBackendBasedAIRatelimit() || resource.GetEnableSubscriptionBasedAIRatelimit()) { metaData = &corev3.Metadata{ FilterMetadata: map[string]*structpb.Struct{ "envoy.filters.http.ext_proc": &structpb.Struct{ @@ -1092,8 +1092,8 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error metadataValue := operation.GetMethod() + "_to_" + newMethod match2.DynamicMetadata = generateMetadataMatcherForInternalRoutes(metadataValue) - action1 := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit(), resource.GetEnableBackendBasedAIRatelimit(), resource.GetBackendBasedAIRatelimitDescriptorValue()) - action2 := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit(), resource.GetEnableBackendBasedAIRatelimit(), resource.GetBackendBasedAIRatelimitDescriptorValue()) + action1 := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit() && params.isAiAPI, resource.GetEnableBackendBasedAIRatelimit() && params.isAiAPI, resource.GetBackendBasedAIRatelimitDescriptorValue()) + action2 := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit() && params.isAiAPI, resource.GetEnableBackendBasedAIRatelimit() && params.isAiAPI, resource.GetBackendBasedAIRatelimitDescriptorValue()) // Create route1 for current method. // Do not add policies to route config. Send via enforcer @@ -1116,7 +1116,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error } else { var action *routev3.Route_Route if requestRedirectAction == nil { - action = generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit(), resource.GetEnableBackendBasedAIRatelimit(), resource.GetBackendBasedAIRatelimitDescriptorValue()) + action = generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, mirrorClusterNames[operation.GetID()], resource.GetEnableSubscriptionBasedAIRatelimit() && params.isAiAPI, resource.GetEnableBackendBasedAIRatelimit() && params.isAiAPI, resource.GetBackendBasedAIRatelimitDescriptorValue()) } logger.LoggerOasparser.Debug("Creating routes for resource with policies", resourcePath, operation.GetMethod()) // create route for current method. Add policies to route config. Send via enforcer @@ -1145,7 +1145,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error action := generateRouteAction(apiType, routeConfig, rateLimitPolicyCriteria, nil, resource.GetEnableSubscriptionBasedAIRatelimit(), resource.GetEnableBackendBasedAIRatelimit(), resource.GetBackendBasedAIRatelimitDescriptorValue()) rewritePath := generateRoutePathForReWrite(basePath, resourcePath, pathMatchType) action.Route.RegexRewrite = generateRegexMatchAndSubstitute(rewritePath, resourcePath, pathMatchType) - + route := generateRouteConfig(xWso2Basepath, match, action, nil, metaData, decorator, perRouteFilterConfigs, nil, nil, nil, nil) // general headers to add and remove are included in this methods routes = append(routes, route) @@ -1284,7 +1284,7 @@ func CreateAPIDefinitionRoute(basePath string, vHost string, methods []string, i Decorator: decorator, TypedPerFilterConfig: map[string]*any.Any{ wellknown.HTTPExternalAuthorization: filter, - HTTPExternalProcessor : filterExtProc, + HTTPExternalProcessor: filterExtProc, }, } return &router @@ -1378,7 +1378,7 @@ func CreateAPIDefinitionEndpoint(adapterInternalAPI *model.AdapterInternalAPI, v Decorator: decorator, TypedPerFilterConfig: map[string]*any.Any{ wellknown.HTTPExternalAuthorization: filter, - HTTPExternalProcessor : filterExtProc, + HTTPExternalProcessor: filterExtProc, }, } return router @@ -1443,7 +1443,7 @@ func CreateHealthEndpoint() *routev3.Route { Decorator: decorator, TypedPerFilterConfig: map[string]*any.Any{ wellknown.HTTPExternalAuthorization: filter, - HTTPExternalProcessor : filterExtProc, + HTTPExternalProcessor: filterExtProc, }, } return &router @@ -1493,7 +1493,7 @@ func CreateReadyEndpoint() *routev3.Route { Metadata: nil, Decorator: decorator, TypedPerFilterConfig: map[string]*any.Any{ - HTTPExternalProcessor : filterExtProc, + HTTPExternalProcessor: filterExtProc, }, } return &router @@ -1691,6 +1691,7 @@ func genRouteCreateParams(swagger *model.AdapterInternalAPI, resource *model.Res environment: swagger.GetEnvironment(), envType: swagger.EnvType, mirrorClusterNames: mirrorClusterNames, + isAiAPI: swagger.AIProvider.Enabled, } return params } diff --git a/adapter/internal/operator/controllers/dp/api_controller.go b/adapter/internal/operator/controllers/dp/api_controller.go index 19ba7db98..e14d43fcb 100644 --- a/adapter/internal/operator/controllers/dp/api_controller.go +++ b/adapter/internal/operator/controllers/dp/api_controller.go @@ -869,14 +869,12 @@ func (apiReconciler *APIReconciler) resolveAiSubscriptionRatelimitPolicies(ctx c Name: subscription.Spec.RatelimitRef.Name, Namespace: subscription.GetNamespace(), } - if err := apiReconciler.client.Get(ctx, nn, aiRatelimitPolicy, ); err != nil { - loggers.LoggerAPKOperator.Infof("No associated aiRatelimitPolicy found for Subscription: %s", utils.NamespacedName(&subscription)) - continue - } else { + if err := apiReconciler.client.Get(ctx, nn, aiRatelimitPolicy, ); err == nil { loggers.LoggerAPKOperator.Infof("API state set as AI subscription enabled") apiState.IsAiSubscriptionRatelimitEnabled = true break } + loggers.LoggerAPKOperator.Infof("No associated aiRatelimitPolicy found for Subscription: %s", utils.NamespacedName(&subscription)) } } diff --git a/adapter/internal/operator/synchronizer/data_store.go b/adapter/internal/operator/synchronizer/data_store.go index 23183f949..8d8f4bb50 100644 --- a/adapter/internal/operator/synchronizer/data_store.go +++ b/adapter/internal/operator/synchronizer/data_store.go @@ -81,6 +81,20 @@ func (ods *OperatorDataStore) processAPIState(apiNamespacedName types.Namespaced events = append(events, "Subscription based AI RatelimitPolicy") } + if cachedAPI.AIProvider == nil && apiState.AIProvider != nil { + cachedAPI.AIProvider = apiState.AIProvider + updated = true + events = append(events, "API provider") + } else if cachedAPI.AIProvider != nil && apiState.AIProvider == nil{ + cachedAPI.AIProvider = nil + updated = true + events = append(events, "API provider") + } else if cachedAPI.AIProvider.Generation != apiState.AIProvider.Generation { + cachedAPI.AIProvider = apiState.AIProvider + updated = true + events = append(events, "API provider") + } + if apiState.APIDefinition.Generation > cachedAPI.APIDefinition.Generation { cachedAPI.APIDefinition = apiState.APIDefinition updated = true diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java index 402fa76e0..ade43dfa0 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java +++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java @@ -56,6 +56,10 @@ public class ExternalProcessorService extends ExternalProcessorGrpc.ExternalProc private static final String DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION = "subscription"; private static final String DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY = "ratelimit:organization-and-rlpolicy"; private static final String DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION = "ratelimit:subscription"; + private static final String DYNAMIC_METADATA_KEY_FOR_EXTRACT_TOKEN_FROM = "aitoken:extracttokenfrom"; + private static final String DYNAMIC_METADATA_KEY_FOR_PROMPT_TOKEN_ID = "aitoken:prompttokenid"; + private static final String DYNAMIC_METADATA_KEY_FOR_COMPLETION_TOKEN_ID = "aitoken:completiontokenid"; + private static final String DYNAMIC_METADATA_KEY_FOR_TOTAL_TOKEN_ID = "aitoken:totaltokenid"; RatelimitClient ratelimitClient = new RatelimitClient(); @Override public StreamObserver process( @@ -92,34 +96,47 @@ public void onNext(ProcessingRequest request) { System.out.println("In the response flow metadata descirtor:" + filterMetadata.backendBasedAIRatelimitDescriptorValue); if (request.hasResponseBody()) { String body = request.getResponseBody().getBody().toStringUtf8(); -// System.out.println("Body: " + body); - Usage usage = extractUsageFromBody(body, "usage.completion_tokens", "usage.prompt_tokens", "usage.total_tokens"); - if (usage == null) { - logger.error("Usage details not found.."); - System.out.println("Usage details not found.."); - responseObserver.onCompleted(); - return; - } - System.out.println("body: " +request.getResponseBody().getBody().toStringUtf8()); - List configs = new ArrayList<>(); - if (filterMetadata.enableBackendBasedAIRatelimit) { - configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_REQUEST_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getPrompt_tokens())); - configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_RESPONSE_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getCompletion_tokens())); - configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_TOTAL_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getTotal_tokens())); - } - if (filterMetadata.enableSubscriptionBasedAIRatelimit) { - if (request.hasMetadataContext()) { - Struct filterMetadataFromAuthZ = request.getMetadataContext().getFilterMetadataOrDefault("envoy.filters.http.ext_authz", null); - if (filterMetadataFromAuthZ != null) { - String orgAndAIRLPolicyValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY).getStringValue(); - String aiRLSubsValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION).getStringValue(); - configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_REQUEST_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getPrompt_tokens()))); - configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_RESPONSE_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getCompletion_tokens()))); - configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_TOTAL_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getTotal_tokens()))); + Struct filterMetadataFromAuthZ = request.getMetadataContext().getFilterMetadataOrDefault("envoy.filters.http.ext_authz", null); + if (filterMetadataFromAuthZ != null) { + String extractTokenFrom = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_EXTRACT_TOKEN_FROM).getStringValue(); + System.out.println("Extract Token From: " + extractTokenFrom); + + String promptTokenID = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_PROMPT_TOKEN_ID).getStringValue(); + System.out.println("Prompt Token ID: " + promptTokenID); + + String completionTokenID = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_COMPLETION_TOKEN_ID).getStringValue(); + System.out.println("Completion Token ID: " + completionTokenID); + + String totalTokenID = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_TOTAL_TOKEN_ID).getStringValue(); + System.out.println("Total Token ID: " + totalTokenID); + + Usage usage = extractUsageFromBody(body, completionTokenID, promptTokenID, totalTokenID); + if (usage == null) { + logger.error("Usage details not found.."); + System.out.println("Usage details not found.."); + responseObserver.onCompleted(); + return; + } + System.out.println("body: " +request.getResponseBody().getBody().toStringUtf8()); + List configs = new ArrayList<>(); + if (filterMetadata.enableBackendBasedAIRatelimit) { + configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_REQUEST_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getPrompt_tokens())); + configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_RESPONSE_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getCompletion_tokens())); + configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_TOTAL_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getTotal_tokens())); + } + if (filterMetadata.enableSubscriptionBasedAIRatelimit) { + if (request.hasMetadataContext()) { + if (filterMetadataFromAuthZ != null) { + String orgAndAIRLPolicyValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY).getStringValue(); + String aiRLSubsValue = filterMetadataFromAuthZ.getFieldsMap().get(DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION).getStringValue(); + configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_REQUEST_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getPrompt_tokens()))); + configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_RESPONSE_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getCompletion_tokens()))); + configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_TOTAL_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getTotal_tokens()))); + } } } + ratelimitClient.shouldRatelimit(configs); } - ratelimitClient.shouldRatelimit(configs); responseObserver.onCompleted(); } else { System.out.println("Request does not have response body"); diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/security/AuthFilter.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/security/AuthFilter.java index 3982396f8..b6d62e667 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/security/AuthFilter.java +++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/security/AuthFilter.java @@ -127,6 +127,12 @@ public boolean handleRequest(RequestContext requestContext) { boolean authenticated = false; // Any auth token has been provided for application-level security or not boolean canAuthenticated = false; + if (requestContext.getMatchedAPI() != null && requestContext.getMatchedAPI().getAiProvider() != null) { + requestContext.addMetadataToMap("aitoken:prompttokenid", requestContext.getMatchedAPI().getAiProvider().getPromptTokens().getValue()); + requestContext.addMetadataToMap("aitoken:completiontokenid", requestContext.getMatchedAPI().getAiProvider().getCompletionToken().getValue()); + requestContext.addMetadataToMap("aitoken:totaltokenid", requestContext.getMatchedAPI().getAiProvider().getTotalToken().getValue()); + requestContext.addMetadataToMap("aitoken:extracttokenfrom", requestContext.getMatchedAPI().getAiProvider().getCompletionToken().getIn()); + } for (Authenticator authenticator : authenticators) { if (authenticator.canAuthenticate(requestContext)) { // For transport level securities (mTLS), canAuthenticated will not be applied