diff --git a/gateway/enforcer/org.wso2.apk.enforcer.commons/src/main/java/org/wso2/apk/enforcer/commons/analytics/publishers/dto/AITokenUsage.java b/gateway/enforcer/org.wso2.apk.enforcer.commons/src/main/java/org/wso2/apk/enforcer/commons/analytics/publishers/dto/AITokenUsage.java index b381fd271..5607fd8bc 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer.commons/src/main/java/org/wso2/apk/enforcer/commons/analytics/publishers/dto/AITokenUsage.java +++ b/gateway/enforcer/org.wso2.apk.enforcer.commons/src/main/java/org/wso2/apk/enforcer/commons/analytics/publishers/dto/AITokenUsage.java @@ -24,18 +24,18 @@ */ public class AITokenUsage { @JsonProperty("totalTokens") - private Double totalTokens; + private Integer totalTokens; @JsonProperty("promptTokens") - private Double promptTokens; + private Integer promptTokens; @JsonProperty("completionTokens") - private Double completionTokens; + private Integer completionTokens; @JsonProperty("hour") private Integer hour; - public Double getTotalTokens() { + public Integer getTotalTokens() { return totalTokens; } @@ -50,27 +50,27 @@ public void setHour(Integer hour) { this.hour = hour; } - public void setTotalTokens(Double totalTokens) { + public void setTotalTokens(Integer totalTokens) { this.totalTokens = totalTokens; } - public Double getPromptTokens() { + public Integer getPromptTokens() { return promptTokens; } - public void setPromptTokens(Double promptTokens) { + public void setPromptTokens(Integer promptTokens) { this.promptTokens = promptTokens; } - public Double getCompletionTokens() { + public Integer getCompletionTokens() { return completionTokens; } - public void setCompletionTokens(Double completionTokens) { + public void setCompletionTokens(Integer completionTokens) { this.completionTokens = completionTokens; } diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/analytics/ChoreoAnalyticsProvider.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/analytics/ChoreoAnalyticsProvider.java index 40784b448..bf92c8927 100644 --- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/analytics/ChoreoAnalyticsProvider.java +++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/analytics/ChoreoAnalyticsProvider.java @@ -23,7 +23,6 @@ import com.google.protobuf.Value; import io.envoyproxy.envoy.data.accesslog.v3.AccessLogCommon; import io.envoyproxy.envoy.data.accesslog.v3.HTTPAccessLogEntry; -import io.envoyproxy.envoy.service.ext_proc.v3.ProcessingResponse; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.wso2.apk.enforcer.commons.analytics.collectors.AnalyticsCustomDataProvider; @@ -260,9 +259,9 @@ public Map getProperties() { Map map = new HashMap(); Map fieldsMap = getFieldsMapFromLogEntry(); String gwURL = getValueAsString(fieldsMap, MetadataConstants.GATEWAY_URL); - Double totalTokenCount = getValueAsDouble(fieldsMap, MetadataConstants.TOTAL_TOKEN_COUNT); - Double completionTokenCount = getValueAsDouble(fieldsMap, MetadataConstants.COMPLETION_TOKEN_COUNT); - Double promptTokenCount = getValueAsDouble(fieldsMap, MetadataConstants.PROMPT_TOKEN_COUNT); + Integer totalTokenCount = getValueAsInteger(fieldsMap, MetadataConstants.TOTAL_TOKEN_COUNT); + Integer completionTokenCount = getValueAsInteger(fieldsMap, MetadataConstants.COMPLETION_TOKEN_COUNT); + Integer promptTokenCount = getValueAsInteger(fieldsMap, MetadataConstants.PROMPT_TOKEN_COUNT); String model = getValueAsString(fieldsMap, MetadataConstants.MODEL); String providerName = getValueAsString(fieldsMap, MetadataConstants.AI_PROVIDER_NAME); String providerAPIVersion = getValueAsString(fieldsMap, MetadataConstants.AI_PROVIDER_API_VERSION); @@ -308,12 +307,13 @@ private String getValueAsString(Map fieldsMap, String key) { return fieldsMap.get(key).getStringValue(); } - private Double getValueAsDouble(Map fieldsMap, String key) { + private Integer getValueAsInteger(Map fieldsMap, String key) { if (fieldsMap == null || !fieldsMap.containsKey(key)) { return null; } - return fieldsMap.get(key).getNumberValue(); + Double d = fieldsMap.get(key).getNumberValue(); + return d.intValue(); } private Map getFieldsMapFromLogEntry() { @@ -325,10 +325,18 @@ private Map getFieldsMapFromLogEntry() { .containsKey(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY)) { return new HashMap<>(0); } - Map metadataFromExtProc = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap() - .get(MetadataConstants.EXT_PROC_METADATA_CONTEXT_KEY).getFieldsMap(); - Map metadataFromExtAuthz = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap() - .get(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY).getFieldsMap(); + Map metadataFromExtProc = new HashMap<>(); + if (logEntry.getCommonProperties().getMetadata().getFilterMetadataMap() + .get(MetadataConstants.EXT_PROC_METADATA_CONTEXT_KEY) != null) { + metadataFromExtProc = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap() + .get(MetadataConstants.EXT_PROC_METADATA_CONTEXT_KEY).getFieldsMap(); + } + Map metadataFromExtAuthz = new HashMap<>(); + if (logEntry.getCommonProperties().getMetadata().getFilterMetadataMap() + .get(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY) != null) { + metadataFromExtAuthz = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap() + .get(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY).getFieldsMap(); + } Map mergedMetadata = new HashMap<>(metadataFromExtProc); mergedMetadata.putAll(metadataFromExtAuthz); return mergedMetadata; diff --git a/test/cucumber-tests/build.gradle b/test/cucumber-tests/build.gradle index d7e3821d3..1c35a066b 100644 --- a/test/cucumber-tests/build.gradle +++ b/test/cucumber-tests/build.gradle @@ -51,6 +51,7 @@ dependencies { implementation 'io.grpc:grpc-stub:1.57.0' implementation 'io.grpc:grpc-stub:1.57.0' implementation 'com.google.protobuf:protobuf-java:4.28.2' + implementation group: 'io.kubernetes', name: 'client-java', version: '21.0.1' } test { diff --git a/test/cucumber-tests/src/test/java/org/wso2/apk/integration/api/BaseSteps.java b/test/cucumber-tests/src/test/java/org/wso2/apk/integration/api/BaseSteps.java index 9d0d09795..e7245f198 100644 --- a/test/cucumber-tests/src/test/java/org/wso2/apk/integration/api/BaseSteps.java +++ b/test/cucumber-tests/src/test/java/org/wso2/apk/integration/api/BaseSteps.java @@ -39,6 +39,12 @@ import io.cucumber.java.Before; import io.cucumber.java.en.Given; import io.cucumber.java.en.Then; +import io.kubernetes.client.openapi.ApiClient; +import io.kubernetes.client.openapi.ApiException; +import io.kubernetes.client.openapi.Configuration; +import io.kubernetes.client.openapi.apis.CoreV1Api; +import io.kubernetes.client.openapi.models.V1Pod; +import io.kubernetes.client.util.Config; import org.apache.commons.io.IOUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -71,6 +77,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; + import io.grpc.Status; import io.grpc.StatusRuntimeException; @@ -305,6 +313,31 @@ public void eventualSuccess(int statusCode, DataTable dataTable) throws IOExcept } } + @Then("I see following strings in the enforcer logs") + public void checkEnforcerLogs(DataTable dataTable) throws IOException, InterruptedException, ApiException { + List stringsToCheck = dataTable.asList(String.class); + ApiClient client = Config.defaultClient(); + Configuration.setDefaultApiClient(client); + CoreV1Api api = new CoreV1Api(); + String namespace = "apk-integration-test"; + String podName = "your-pod-name"; + String labelSelector = "app.kubernetes.io/app=gateway"; + + List podList = api.listNamespacedPod(namespace).labelSelector(labelSelector).execute().getItems(); + if (!podList.isEmpty()) { + podName = Objects.requireNonNull(podList.get(0).getMetadata()).getName(); + } + try { + String logs = api.readNamespacedPodLog(podName, namespace).container("enforcer").sinceSeconds(60).execute(); + Assert.assertNotNull(logs, String.format("Could not find any logs in the last 60 seconds. PodName: %s, namespace: %s", podName, namespace)); + for(String word : stringsToCheck) { + Assert.assertTrue(logs.contains(word), "Expected word '" + word + "' not found in logs"); + } + } catch (ApiException e) { + System.out.println(e); + } + } + @Then("I wait for next minute") public void waitForNextMinute() throws InterruptedException { LocalDateTime now = LocalDateTime.now(); diff --git a/test/cucumber-tests/src/test/resources/tests/api/APIBackendBasedAIRatelimit.feature b/test/cucumber-tests/src/test/resources/tests/api/APIBackendBasedAIRatelimit.feature index 13fb3c024..5c18d0e09 100644 --- a/test/cucumber-tests/src/test/resources/tests/api/APIBackendBasedAIRatelimit.feature +++ b/test/cucumber-tests/src/test/resources/tests/api/APIBackendBasedAIRatelimit.feature @@ -10,6 +10,21 @@ Feature: API backend based AI ratelimit Feature Then the response status code should be 200 And the response headers should contain | x-ratelimit-remaining | 4999 | + Then I see following strings in the enforcer logs + |aiMetadata| + |gpt-35-turbo| + |AzureAI| + |2024-06-01| + |aiTokenUsage| + |1000| + |300| + |500| + |hour| + |vendor_name| + |vendor_version| + |totalTokens| + |promptTokens| + |completionTokens| And I wait for 3 seconds And I send "GET" request to "https://default.gw.wso2.com:9095/llm-api/v1.0.0/3.14/employee?send=body" with body "" Then the response status code should be 200