Skip to content

Commit

Permalink
Merge pull request #2498 from Tharsanan1/apk-agent-airl
Browse files Browse the repository at this point in the history
AIRatelimitpolicy change should trigger an update
  • Loading branch information
Tharsanan1 authored Sep 27, 2024
2 parents 44efbda + cebc399 commit b74b689
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
20 changes: 13 additions & 7 deletions adapter/internal/operator/controllers/dp/api_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ func (apiReconciler *APIReconciler) resolveAPIRefs(ctx context.Context, api dpv1
prodRouteRefs, namespace)
}
}
apiState.ProdAIRL = prodAirl
var sandAirl *dpv1alpha3.AIRateLimitPolicy
if len(sandRouteRefs) > 0 && apiState.APIDefinition.Spec.APIType == "REST" {
apiState.SandHTTPRoute = &synchronizer.HTTPRouteState{}
Expand All @@ -447,7 +448,7 @@ func (apiReconciler *APIReconciler) resolveAPIRefs(ctx context.Context, api dpv1
sandRouteRefs, namespace)
}
}

apiState.SandAIRL = sandAirl
// handle gql apis
if len(prodRouteRefs) > 0 && apiState.APIDefinition.Spec.APIType == "GraphQL" {
if apiState.ProdGQLRoute, err = apiReconciler.resolveGQLRouteRefs(ctx, prodRouteRefs, namespace,
Expand Down Expand Up @@ -2879,8 +2880,8 @@ func (apiReconciler *APIReconciler) convertAPIStateToAPICp(ctx context.Context,
sandVhost := geSandVhost(&apiState)
securityScheme, authHeader, apiKeyHeader := prepareSecuritySchemeForCP(&apiState)
operations := prepareOperations(&apiState)
var sandAIRLToAgent controlplane.AIRL
var prodAIRLToAgent controlplane.AIRL
var sandAIRLToAgent *controlplane.AIRL
var prodAIRLToAgent *controlplane.AIRL
if prodAIRL != nil {
var promptTC, completionTC, totalTC, requestC *uint32
var timeUnit string
Expand All @@ -2894,7 +2895,7 @@ func (apiReconciler *APIReconciler) convertAPIStateToAPICp(ctx context.Context,
timeUnit = prodAIRL.Spec.Override.RequestCount.Unit
requestC = &prodAIRL.Spec.Override.RequestCount.RequestsPerUnit
}
prodAIRLToAgent = controlplane.AIRL{
prodAIRLToAgent = &controlplane.AIRL{
PromptTokenCount: promptTC,
CompletionTokenCount: completionTC,
TotalTokenCount: totalTC,
Expand All @@ -2915,7 +2916,7 @@ func (apiReconciler *APIReconciler) convertAPIStateToAPICp(ctx context.Context,
timeUnit = sandAIRL.Spec.Override.RequestCount.Unit
requestC = &sandAIRL.Spec.Override.RequestCount.RequestsPerUnit
}
sandAIRLToAgent = controlplane.AIRL{
sandAIRLToAgent = &controlplane.AIRL{
PromptTokenCount: promptTC,
CompletionTokenCount: completionTC,
TotalTokenCount: totalTC,
Expand Down Expand Up @@ -2960,8 +2961,8 @@ func (apiReconciler *APIReconciler) convertAPIStateToAPICp(ctx context.Context,
Operations: operations,
APIHash: apiHash,
APIKeyHeader: apiKeyHeader,
SandAIRL: &sandAIRLToAgent,
ProdAIRL: &prodAIRLToAgent,
SandAIRL: sandAIRLToAgent,
ProdAIRL: prodAIRLToAgent,
AIConfiguration: aiConfiguration,
}
apiCPEvent.API = api
Expand Down Expand Up @@ -3043,6 +3044,9 @@ func (apiReconciler *APIReconciler) getAPIHash(apiState *synchronizer.APIState)
loggers.LoggerAPK.Infof("Error occured while extracting values using reflection. Error: %+v", r)
}
}()
if obj == nil {
return "nil"
}
var sb strings.Builder
objValue := reflect.ValueOf(obj)
if objValue.Kind() == reflect.Ptr {
Expand All @@ -3069,6 +3073,8 @@ func (apiReconciler *APIReconciler) getAPIHash(apiState *synchronizer.APIState)

uniqueIDs := make([]string, 0)
uniqueIDs = append(uniqueIDs, getUniqueID(apiState.APIDefinition, "ObjectMeta.Name", "ObjectMeta.Namespace", "ObjectMeta.Generation"))
uniqueIDs = append(uniqueIDs, getUniqueID(apiState.SandAIRL, "ObjectMeta.Name", "ObjectMeta.Namespace", "ObjectMeta.Generation"))
uniqueIDs = append(uniqueIDs, getUniqueID(apiState.ProdAIRL, "ObjectMeta.Name", "ObjectMeta.Namespace", "ObjectMeta.Generation"))
for _, auth := range apiState.Authentications {
uniqueIDs = append(uniqueIDs, getUniqueID(auth, "ObjectMeta.Name", "ObjectMeta.Namespace", "ObjectMeta.Generation"))
}
Expand Down
10 changes: 6 additions & 4 deletions adapter/internal/operator/synchronizer/api_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,18 @@ type APIState struct {
APIDefinitionFile []byte
SubscriptionValidation bool
MutualSSL *v1alpha2.MutualSSL
ProdAIRL *v1alpha3.AIRateLimitPolicy
SandAIRL *v1alpha3.AIRateLimitPolicy
}

// HTTPRouteState holds the state of the deployed httpRoutes. This state is compared with
// the state of the Kubernetes controller cache to detect updates.
// +k8s:deepcopy-gen=true
type HTTPRouteState struct {
HTTPRouteCombined *gwapiv1.HTTPRoute
HTTPRoutePartitions map[string]*gwapiv1.HTTPRoute
BackendMapping map[string]*v1alpha2.ResolvedBackend
Scopes map[string]v1alpha1.Scope
HTTPRouteCombined *gwapiv1.HTTPRoute
HTTPRoutePartitions map[string]*gwapiv1.HTTPRoute
BackendMapping map[string]*v1alpha2.ResolvedBackend
Scopes map[string]v1alpha1.Scope
RuleIdxToAiRatelimitPolicyMapping map[int]*v1alpha3.AIRateLimitPolicy
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -309,6 +310,9 @@ private static Usage extractUsageFromBody(String body, String completionTokenPat
JsonNode rootNode = mapper.readTree(body);
// Extract prompt token count
String[] keysForPromtTokens = promptTokenPath.split("\\.");
if (keysForPromtTokens.length > 0 && "$".equals(keysForPromtTokens[0])) {
keysForPromtTokens = Arrays.copyOfRange(keysForPromtTokens, 1, keysForPromtTokens.length);
}
JsonNode currentNodeForPromtToken = null;
if (rootNode.has(keysForPromtTokens[0])) {
currentNodeForPromtToken = rootNode.get(keysForPromtTokens[0]);
Expand Down

0 comments on commit b74b689

Please sign in to comment.