Skip to content

Commit

Permalink
Add sub based ai ratelimit feature related changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Tharsanan1 committed Sep 11, 2024
1 parent 49708f9 commit 1ff618b
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 15 deletions.
5 changes: 4 additions & 1 deletion adapter/internal/oasparser/envoyconf/routes_configs.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ func generateRouteAction(apiType string, routeConfig *model.EndpointConfig, rate
if isBackendBasedAIRatelimitEnabled {
action.Route.RateLimits = append(action.Route.RateLimits, generateBackendBasedAIRatelimit(descriptorValueForBackendBasedAIRatelimit)...)
}
if isSubscriptionBasedAIRatelimitEnabled {
action.Route.RateLimits = append(action.Route.RateLimits, generateSubscriptionBasedAIRatelimit()...)
}

// Add request mirroring configurations
if mirrorClusterNames != nil && len(mirrorClusterNames) > 0 {
Expand Down Expand Up @@ -260,7 +263,7 @@ func generateBackendBasedAIRatelimit(descValue string) []*routev3.RateLimit {
}


func generateSubscriptionBasedAIRatelimit(descValue string) []*routev3.RateLimit {
func generateSubscriptionBasedAIRatelimit() []*routev3.RateLimit {
rateLimitForRequestTokenCount := routev3.RateLimit{
Actions: []*routev3.RateLimit_Action{
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ func createRoutes(params *routeCreateParams) (routes []*routev3.Route, err error
}
routeConfig := resource.GetEndpoints().Config
metaData := &corev3.Metadata{}
logger.LoggerAPI.Infof("Is backend based rl enabled: %+v, Is subs based rl enabled: %+v", resource.GetEnableBackendBasedAIRatelimit(), resource.GetEnableSubscriptionBasedAIRatelimit())
if resource.GetEnableBackendBasedAIRatelimit() || resource.GetEnableSubscriptionBasedAIRatelimit() {
metaData = &corev3.Metadata{
FilterMetadata: map[string]*structpb.Struct{
Expand Down
14 changes: 10 additions & 4 deletions adapter/internal/operator/controllers/dp/api_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,12 +865,18 @@ func (apiReconciler *APIReconciler) resolveAiSubscriptionRatelimitPolicies(ctx c
}
for _, subscription := range subscriptionList.Items {
aiRatelimitPolicy := &dpv1alpha3.AIRateLimitPolicy{}
if err := apiReconciler.client.Get(ctx, utils.NamespacedName(&subscription), aiRatelimitPolicy, ); err != nil {
nn:= types.NamespacedName{
Name: subscription.Spec.RatelimitRef.Name,
Namespace: subscription.GetNamespace(),
}
if err := apiReconciler.client.Get(ctx, nn, aiRatelimitPolicy, ); err != nil {
loggers.LoggerAPKOperator.Infof("No associated aiRatelimitPolicy found for Subscription: %s", utils.NamespacedName(&subscription))
return
continue
} else {
loggers.LoggerAPKOperator.Infof("API state set as AI subscription enabled")
apiState.IsAiSubscriptionRatelimitEnabled = true
break
}
apiState.IsAiSubscriptionRatelimitEnabled = true
return
}
}

Expand Down
2 changes: 1 addition & 1 deletion adapter/internal/operator/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import (
dpv1alpha2 "github.com/wso2/apk/common-go-libs/apis/dp/v1alpha2"
dpv1alpha3 "github.com/wso2/apk/common-go-libs/apis/dp/v1alpha3"
cpv1alpha2 "github.com/wso2/apk/common-go-libs/apis/cp/v1alpha2"
cpv1alpha3 "github.com/wso2/apk/common-go-libs/apis/cp/v1alpha2"
cpv1alpha3 "github.com/wso2/apk/common-go-libs/apis/cp/v1alpha3"
//+kubebuilder:scaffold:imports
)

Expand Down
5 changes: 5 additions & 0 deletions adapter/internal/operator/synchronizer/data_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ func (ods *OperatorDataStore) processAPIState(apiNamespacedName types.Namespaced
var updated bool
events := []string{}
cachedAPI := ods.apiStore[apiNamespacedName]
if cachedAPI.IsAiSubscriptionRatelimitEnabled != apiState.IsAiSubscriptionRatelimitEnabled {
cachedAPI.IsAiSubscriptionRatelimitEnabled = apiState.IsAiSubscriptionRatelimitEnabled
updated = true
events = append(events, "Subscription based AI RatelimitPolicy")
}

if apiState.APIDefinition.Generation > cachedAPI.APIDefinition.Generation {
cachedAPI.APIDefinition = apiState.APIDefinition
Expand Down
3 changes: 2 additions & 1 deletion common-controller/internal/xds/ratelimiter_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ func (r *rateLimitPolicyCache) ProcessAIRatelimitPolicySpecsAndUpdateCache(aiRat
}

// prepareSubscriptionBasedAIRatelimitIdentifier formats the identifier used for a
// subscription based AI rate limit from the organization and the namespaced name.
//
// NOTE(review): this span looks like a rendered diff hunk in which the removed
// line and the added lines are interleaved. Read literally, only the first
// return ("org-namespace-name") executes and the second return ("org-name") is
// unreachable. Confirm which format is intended before relying on either.
func prepareSubscriptionBasedAIRatelimitIdentifier(org string, namespacedName types.NamespacedName) string {
return fmt.Sprintf("%s-%s-%s", org, string(namespacedName.Namespace), string(namespacedName.Name))
// return fmt.Sprintf("%s-%s-%s", org, string(namespacedName.Namespace), string(namespacedName.Name))
return fmt.Sprintf("%s-%s", org, string(namespacedName.Name))
}

func prepareAIRatelimitIdentifier(org string, namespacedName types.NamespacedName, spec *dpv1alpha3.AIRateLimitPolicySpec) string {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ private String constructQueryParamString(boolean removeAllQueryParams, String re
* @param value
*/
private void addMetadata(Struct.Builder structBuilder, String key, String value) {
    // Store the value as a protobuf string Value under the given key.
    // The key/value pair is intentionally not echoed to stdout: this runs on the
    // request path and the values can carry subscription and organization
    // identifiers, which should not be written to unmanaged console output.
    structBuilder.putFields(key, Value.newBuilder().setStringValue(value).build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@

package org.wso2.apk.enforcer.grpc;

import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import io.envoyproxy.envoy.config.core.v3.Metadata;
import io.envoyproxy.envoy.service.ext_proc.v3.BodyMutation;
import io.envoyproxy.envoy.service.ext_proc.v3.BodyResponse;
import io.envoyproxy.envoy.service.ext_proc.v3.CommonResponse;
Expand All @@ -47,6 +50,12 @@ public class ExternalProcessorService extends ExternalProcessorGrpc.ExternalProc
// Rate limit descriptor keys used for backend based AI rate limiting. They are
// paired with the prompt, completion and total token counts respectively.
private static final String DESCRIPTOR_KEY_FOR_AI_REQUEST_TOKEN_COUNT = "airequesttokencount";
private static final String DESCRIPTOR_KEY_FOR_AI_RESPONSE_TOKEN_COUNT = "airesponsetokencount";
private static final String DESCRIPTOR_KEY_FOR_AI_TOTAL_TOKEN_COUNT = "aitotaltokencount";
// Rate limit descriptor keys used for subscription based AI rate limiting. Each
// one is the outer descriptor entry and is followed by a nested
// DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION entry.
private static final String DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_REQUEST_TOKEN_COUNT = "airequesttokencountsubs";
private static final String DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_RESPONSE_TOKEN_COUNT = "airesponsetokencountsubs";
private static final String DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_TOTAL_TOKEN_COUNT = "aitotaltokencountsubs";
// Key of the nested descriptor entry whose value is the subscription taken from
// the dynamic metadata.
private static final String DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION = "subscription";
// Field names looked up in the "envoy.filters.http.ext_authz" filter metadata of
// the processing request.
private static final String DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY = "ratelimit:organization-and-rlpolicy";
private static final String DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION = "ratelimit:subscription";
RatelimitClient ratelimitClient = new RatelimitClient();
@Override
public StreamObserver<ProcessingRequest> process(
Expand Down Expand Up @@ -92,13 +101,25 @@ public void onNext(ProcessingRequest request) {
return;
}
System.out.println("body: " +request.getResponseBody().getBody().toStringUtf8());
List<RatelimitClient.KeyValueHitsAddend> configs = new ArrayList<>();
if (filterMetadata.enableBackendBasedAIRatelimit) {
List<RatelimitClient.KeyValueHitsAddend> configs = new ArrayList<>();
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_REQUEST_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getPrompt_tokens()));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_RESPONSE_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getCompletion_tokens()));
configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_TOTAL_TOKEN_COUNT, filterMetadata.backendBasedAIRatelimitDescriptorValue, usage.getTotal_tokens()));
ratelimitClient.shouldRatelimit(configs);
}
// Subscription based AI rate limiting: report this response's token usage under
// the organization/policy and subscription values found in the ext_authz filter
// metadata of the request.
if (filterMetadata.enableSubscriptionBasedAIRatelimit) {
    if (request.hasMetadataContext()) {
        Struct filterMetadataFromAuthZ = request.getMetadataContext()
                .getFilterMetadataOrDefault("envoy.filters.http.ext_authz", null);
        if (filterMetadataFromAuthZ != null) {
            // Look the fields up with a null default instead of dereferencing
            // getFieldsMap().get(...) directly, which throws a
            // NullPointerException when either key is absent.
            Value orgAndAIRLPolicy = filterMetadataFromAuthZ
                    .getFieldsOrDefault(DYNAMIC_METADATA_KEY_FOR_ORGANIZATION_AND_AIRL_POLICY, null);
            Value aiRLSubs = filterMetadataFromAuthZ
                    .getFieldsOrDefault(DYNAMIC_METADATA_KEY_FOR_SUBSCRIPTION, null);
            if (orgAndAIRLPolicy != null && aiRLSubs != null) {
                String orgAndAIRLPolicyValue = orgAndAIRLPolicy.getStringValue();
                String aiRLSubsValue = aiRLSubs.getStringValue();
                // Each descriptor must carry its own counter: prompt tokens for the
                // request key, completion tokens for the response key and total
                // tokens for the total key, mirroring the backend based branch.
                configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_REQUEST_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getPrompt_tokens())));
                configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_RESPONSE_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getCompletion_tokens())));
                configs.add(new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_SUBSCRIPTION_BASED_AI_TOTAL_TOKEN_COUNT, orgAndAIRLPolicyValue, new RatelimitClient.KeyValueHitsAddend(DESCRIPTOR_KEY_FOR_AI_SUBSCRIPTION, aiRLSubsValue, usage.getTotal_tokens())));
            }
        }
    }
}
ratelimitClient.shouldRatelimit(configs);
responseObserver.onCompleted();
} else {
System.out.println("Request does not have response body");
Expand Down Expand Up @@ -258,8 +279,8 @@ private static Usage extractUsageFromBody(String body, String completionTokenPat
System.out.println("Usage extracted: "+ usage);
return usage;

} catch (JsonProcessingException e) {
logger.error(String.format("Unexpected error while extracting usage from the body: %s", body), e);
} catch (Exception e) {
System.out.println(String.format("Unexpected error while extracting usage from the body: %s", body) + " \n" + e);
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,22 @@ public void shouldRatelimit(List<KeyValueHitsAddend> configs) {
executorService.submit(() -> {
System.out.println("Ratelimitclient test");
for (KeyValueHitsAddend config : configs) {

System.out.println("For: " + config.getKey());
RateLimitDescriptor descriptor = RateLimitDescriptor.newBuilder()
.addEntries(RateLimitDescriptor.Entry.newBuilder().setKey(config.getKey()).setValue(config.getValue()).build())
.build();
RateLimitDescriptor.Builder builder = RateLimitDescriptor.newBuilder()
.addEntries(RateLimitDescriptor.Entry.newBuilder().setKey(config.getKey()).setValue(config.getValue()).build());
KeyValueHitsAddend internalKeyValueHitsAddend = config.keyValueHitsAddend;
int hitsAddend = config.getHitsAddend();
while (internalKeyValueHitsAddend != null) {
builder.addEntries(RateLimitDescriptor.Entry.newBuilder().setKey(internalKeyValueHitsAddend.getKey()).setValue(internalKeyValueHitsAddend.getValue()).build());
hitsAddend = internalKeyValueHitsAddend.getHitsAddend();
internalKeyValueHitsAddend = internalKeyValueHitsAddend.keyValueHitsAddend;
}
RateLimitDescriptor descriptor = builder.build();
RateLimitRequest rateLimitRequest = RateLimitRequest.newBuilder()
.addDescriptors(descriptor)
.setDomain("Default")
.setHitsAddend(config.getHitsAddend())
.setHitsAddend(hitsAddend)
.build();
RateLimitResponse rateLimitResponse = stub.shouldRateLimit(rateLimitRequest);
System.out.println(rateLimitResponse.getOverallCode());
Expand All @@ -73,11 +81,19 @@ public static class KeyValueHitsAddend {
private String key;
private String value;
private int hitsAddend;
private KeyValueHitsAddend keyValueHitsAddend;

/**
 * Creates an entry that carries its own hit count and has no nested entry.
 *
 * @param key        key of the descriptor entry
 * @param value      value of the descriptor entry
 * @param hitsAddend hit count stored for this entry
 */
public KeyValueHitsAddend(String key, String value, int hitsAddend) {
    // No nested entry follows this one.
    this.keyValueHitsAddend = null;
    this.hitsAddend = hitsAddend;
    this.value = value;
    this.key = key;
}
/**
 * Creates an entry that is followed by a nested entry instead of carrying a
 * hit count of its own.
 *
 * @param key                key of the descriptor entry
 * @param value              value of the descriptor entry
 * @param keyValueHitsAddend nested entry stored after this one
 */
public KeyValueHitsAddend(String key, String value, KeyValueHitsAddend keyValueHitsAddend) {
    this.keyValueHitsAddend = keyValueHitsAddend;
    // -1 is stored as this entry's own hit count; no count is passed to this constructor.
    this.hitsAddend = -1;
    this.value = value;
    this.key = key;
}

public String getKey() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ public AuthenticationContext authenticate(RequestContext requestContext) throws
requestContext.addMetadataToMap("ratelimit:subscription", subscriptionId);
requestContext.addMetadataToMap("ratelimit:usage-policy", subscription.getRatelimitTier());
requestContext.addMetadataToMap("ratelimit:organization", subscription.getOrganization());
System.out.println("Value: "+ String.format("%s-%s", subscription.getOrganization(), subscription.getRatelimitTier()));
requestContext.addMetadataToMap("ratelimit:organization-and-rlpolicy", String.format("%s-%s", subscription.getOrganization(), subscription.getRatelimitTier()));
}
}
}
Expand Down

0 comments on commit 1ff618b

Please sign in to comment.