diff --git a/adapter/internal/controlplane/eventPublisher.go b/adapter/internal/controlplane/eventPublisher.go index 3cad85379..9aa28a3a6 100644 --- a/adapter/internal/controlplane/eventPublisher.go +++ b/adapter/internal/controlplane/eventPublisher.go @@ -82,6 +82,7 @@ type API struct { IsDefaultVersion bool `json:"isDefaultVersion"` Definition string `json:"definition"` APIType string `json:"apiType"` + APISubType string `json:"apiSubType"` BasePath string `json:"basePath"` Organization string `json:"organization"` SystemAPI bool `json:"systemAPI"` @@ -98,18 +99,25 @@ type API struct { AuthHeader string `json:"authHeader"` APIKeyHeader string `json:"apiKeyHeader"` Operations []Operation `json:"operations"` + AIConfiguration AIConfiguration `json:"aiConfiguration"` APIHash string `json:"-"` - SandAIRL *AIRL `json:"sandAIRL"` - ProdAIRL *AIRL `json:"prodAIRL"` + SandAIRL *AIRL `json:"sandAIRL"` + ProdAIRL *AIRL `json:"prodAIRL"` } // AIRL holds AI ratelimit related data type AIRL struct { - PromptTokenCount *uint32 `json:"promptTokenCount"` - CompletionTokenCount *uint32 `json:"CompletionTokenCount"` - TotalTokenCount *uint32 `json:"totalTokenCount"` - TimeUnit string `json:"timeUnit"` - RequestCount *uint32 `json:"requestCount"` + PromptTokenCount *uint32 `json:"promptTokenCount"` + CompletionTokenCount *uint32 `json:"CompletionTokenCount"` + TotalTokenCount *uint32 `json:"totalTokenCount"` + TimeUnit string `json:"timeUnit"` + RequestCount *uint32 `json:"requestCount"` +} + +// AIConfiguration holds the AI configuration +type AIConfiguration struct { + LLMProviderName string `json:"llmProviderName"` + LLMProviderAPIVersion string `json:"llmProviderAPIVersion"` } // Headers contains the request and response header modifier information diff --git a/adapter/internal/operator/controllers/dp/api_controller.go b/adapter/internal/operator/controllers/dp/api_controller.go index b4646de29..f647a8a0e 100644 --- a/adapter/internal/operator/controllers/dp/api_controller.go +++ b/adapter/internal/operator/controllers/dp/api_controller.go @@ -430,6 +430,7 @@ func (apiReconciler *APIReconciler) resolveAPIRefs(ctx context.Context, api dpv1 prodRouteRefs, namespace) } } + apiState.ProdAIRL = prodAirl var sandAirl *dpv1alpha3.AIRateLimitPolicy if len(sandRouteRefs) > 0 && apiState.APIDefinition.Spec.APIType == "REST" { apiState.SandHTTPRoute = &synchronizer.HTTPRouteState{} @@ -447,7 +448,7 @@ func (apiReconciler *APIReconciler) resolveAPIRefs(ctx context.Context, api dpv1 sandRouteRefs, namespace) } } - + apiState.SandAIRL = sandAirl // handle gql apis if len(prodRouteRefs) > 0 && apiState.APIDefinition.Spec.APIType == "GraphQL" { if apiState.ProdGQLRoute, err = apiReconciler.resolveGQLRouteRefs(ctx, prodRouteRefs, namespace, @@ -2879,8 +2880,8 @@ func (apiReconciler *APIReconciler) convertAPIStateToAPICp(ctx context.Context, sandVhost := geSandVhost(&apiState) securityScheme, authHeader, apiKeyHeader := prepareSecuritySchemeForCP(&apiState) operations := prepareOperations(&apiState) - var sandAIRLToAgent controlplane.AIRL - var prodAIRLToAgent controlplane.AIRL + var sandAIRLToAgent *controlplane.AIRL + var prodAIRLToAgent *controlplane.AIRL if prodAIRL != nil { var promptTC, completionTC, totalTC, requestC *uint32 var timeUnit string @@ -2894,7 +2895,7 @@ func (apiReconciler *APIReconciler) convertAPIStateToAPICp(ctx context.Context, timeUnit = prodAIRL.Spec.Override.RequestCount.Unit requestC = &prodAIRL.Spec.Override.RequestCount.RequestsPerUnit } - prodAIRLToAgent = controlplane.AIRL{ + prodAIRLToAgent = &controlplane.AIRL{ PromptTokenCount: promptTC, CompletionTokenCount: completionTC, TotalTokenCount: totalTC, @@ -2915,7 +2916,7 @@ func (apiReconciler *APIReconciler) convertAPIStateToAPICp(ctx context.Context, timeUnit = sandAIRL.Spec.Override.RequestCount.Unit requestC = &sandAIRL.Spec.Override.RequestCount.RequestsPerUnit } - sandAIRLToAgent = controlplane.AIRL{ + sandAIRLToAgent = &controlplane.AIRL{ PromptTokenCount: promptTC, CompletionTokenCount: completionTC, TotalTokenCount: totalTC, @@ -2923,11 +2924,24 @@ func (apiReconciler *APIReconciler) convertAPIStateToAPICp(ctx context.Context, RequestCount: requestC, } } + subType := "DEFAULT" + aiConfiguration := controlplane.AIConfiguration{} + if apiState.AIProvider != nil { + loggers.LoggerAPKOperator.Debugf("AIProvider is found") + subType = "AIAPI" + aiConfiguration = controlplane.AIConfiguration{ + LLMProviderName: apiState.AIProvider.Spec.ProviderName, + LLMProviderAPIVersion: apiState.AIProvider.Spec.ProviderAPIVersion, + } + } + loggers.LoggerAPKOperator.Debugf("Resolved aiConfiguration: %+v", aiConfiguration) + api := controlplane.API{ APIName: spec.APIName, APIVersion: spec.APIVersion, IsDefaultVersion: spec.IsDefaultVersion, APIType: spec.APIType, + APISubType: subType, BasePath: spec.BasePath, Organization: spec.Organization, Environment: spec.Environment, @@ -2947,8 +2961,9 @@ func (apiReconciler *APIReconciler) convertAPIStateToAPICp(ctx context.Context, Operations: operations, APIHash: apiHash, APIKeyHeader: apiKeyHeader, - SandAIRL: &sandAIRLToAgent, - ProdAIRL: &prodAIRLToAgent, + SandAIRL: sandAIRLToAgent, + ProdAIRL: prodAIRLToAgent, + AIConfiguration: aiConfiguration, } apiCPEvent.API = api apiCPEvent.CRName = apiState.APIDefinition.ObjectMeta.Name @@ -3029,6 +3044,9 @@ func (apiReconciler *APIReconciler) getAPIHash(apiState *synchronizer.APIState) loggers.LoggerAPK.Infof("Error occured while extracting values using reflection. Error: %+v", r) } }() + if obj == nil { + return "nil" + } var sb strings.Builder objValue := reflect.ValueOf(obj) if objValue.Kind() == reflect.Ptr { @@ -3055,6 +3073,8 @@ func (apiReconciler *APIReconciler) getAPIHash(apiState *synchronizer.APIState) uniqueIDs := make([]string, 0) uniqueIDs = append(uniqueIDs, getUniqueID(apiState.APIDefinition, "ObjectMeta.Name", "ObjectMeta.Namespace", "ObjectMeta.Generation")) + uniqueIDs = append(uniqueIDs, getUniqueID(apiState.SandAIRL, "ObjectMeta.Name", "ObjectMeta.Namespace", "ObjectMeta.Generation")) + uniqueIDs = append(uniqueIDs, getUniqueID(apiState.ProdAIRL, "ObjectMeta.Name", "ObjectMeta.Namespace", "ObjectMeta.Generation")) for _, auth := range apiState.Authentications { uniqueIDs = append(uniqueIDs, getUniqueID(auth, "ObjectMeta.Name", "ObjectMeta.Namespace", "ObjectMeta.Generation")) } diff --git a/adapter/internal/operator/synchronizer/api_state.go b/adapter/internal/operator/synchronizer/api_state.go index c0a25cbd1..b0a03e2ad 100644 --- a/adapter/internal/operator/synchronizer/api_state.go +++ b/adapter/internal/operator/synchronizer/api_state.go @@ -48,6 +48,8 @@ type APIState struct { APIDefinitionFile []byte SubscriptionValidation bool MutualSSL *v1alpha2.MutualSSL + ProdAIRL *v1alpha3.AIRateLimitPolicy + SandAIRL *v1alpha3.AIRateLimitPolicy } // HTTPRouteState holds the state of the deployed httpRoutes. This state is compared with diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExtAuthService.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExtAuthService.java index 060a441b0..e582beafe 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExtAuthService.java +++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExtAuthService.java @@ -191,8 +191,11 @@ private CheckResponse buildResponse(CheckRequest request, ResponseObject respons Struct.Builder structBuilder = Struct.newBuilder(); if (responseObject.getMetaDataMap() != null) { - responseObject.getMetaDataMap().forEach((key, value) -> - addMetadata(structBuilder, key, value)); + responseObject.getMetaDataMap().forEach((key, value) -> { + if (value != null) { + addMetadata(structBuilder, key, value); + } + }); } //Adds original request path header without params as a metadata for access logging. diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java index 0a2dbaf34..f6b4c87e4 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java +++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/grpc/ExternalProcessorService.java @@ -44,6 +44,7 @@ import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -309,6 +310,9 @@ private static Usage extractUsageFromBody(String body, String completionTokenPat JsonNode rootNode = mapper.readTree(body); // Extract prompt token count String[] keysForPromtTokens = promptTokenPath.split("\\."); + if (keysForPromtTokens.length > 0 && "$".equals(keysForPromtTokens[0])) { + keysForPromtTokens = Arrays.copyOfRange(keysForPromtTokens, 1, keysForPromtTokens.length); + } JsonNode currentNodeForPromtToken = null; if (rootNode.has(keysForPromtTokens[0])) { currentNodeForPromtToken = rootNode.get(keysForPromtTokens[0]); diff --git a/helm-charts/values.yaml b/helm-charts/values.yaml index 27930447a..12381fef0 100644 --- a/helm-charts/values.yaml +++ b/helm-charts/values.yaml @@ -120,8 +120,8 @@ wso2: failureThreshold: 5 strategy: RollingUpdate replicas: 1 - imagePullPolicy: IfNotPresent - image: apk-adapter:1.2.0-SNAPSHOT + imagePullPolicy: Always + image: wso2/apk-adapter:latest security: sslHostname: "adapter" logging: @@ -154,8 +154,8 @@ wso2: failureThreshold: 5 strategy: RollingUpdate replicas: 1 - imagePullPolicy: IfNotPresent - image: apk-common-controller:1.2.0-SNAPSHOT + imagePullPolicy: Always + image: wso2/apk-common-controller:latest security: sslHostname: "commoncontroller" # configs: @@ -242,8 +242,8 @@ wso2: periodSeconds: 20 failureThreshold: 5 strategy: RollingUpdate - imagePullPolicy: IfNotPresent - image: apk-enforcer:1.2.0-SNAPSHOT + imagePullPolicy: Always + image: wso2/apk-enforcer:latest security: sslHostname: "enforcer" # logging: