Skip to content

Commit

Permalink
Merge pull request #2541 from Tharsanan1/ai-analytics
Browse files Browse the repository at this point in the history
Add AI analytics integration tests
  • Loading branch information
Tharsanan1 authored Oct 7, 2024
2 parents f3a7349 + bc720d9 commit cbcad45
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@
*/
public class AITokenUsage {
@JsonProperty("totalTokens")
private Double totalTokens;
private Integer totalTokens;

@JsonProperty("promptTokens")
private Double promptTokens;
private Integer promptTokens;

@JsonProperty("completionTokens")
private Double completionTokens;
private Integer completionTokens;

@JsonProperty("hour")
private Integer hour;

public Double getTotalTokens() {
public Integer getTotalTokens() {

return totalTokens;
}
Expand All @@ -50,27 +50,27 @@ public void setHour(Integer hour) {
this.hour = hour;
}

public void setTotalTokens(Double totalTokens) {
public void setTotalTokens(Integer totalTokens) {

this.totalTokens = totalTokens;
}

public Double getPromptTokens() {
public Integer getPromptTokens() {

return promptTokens;
}

public void setPromptTokens(Double promptTokens) {
public void setPromptTokens(Integer promptTokens) {

this.promptTokens = promptTokens;
}

public Double getCompletionTokens() {
public Integer getCompletionTokens() {

return completionTokens;
}

public void setCompletionTokens(Double completionTokens) {
public void setCompletionTokens(Integer completionTokens) {

this.completionTokens = completionTokens;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.google.protobuf.Value;
import io.envoyproxy.envoy.data.accesslog.v3.AccessLogCommon;
import io.envoyproxy.envoy.data.accesslog.v3.HTTPAccessLogEntry;
import io.envoyproxy.envoy.service.ext_proc.v3.ProcessingResponse;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.wso2.apk.enforcer.commons.analytics.collectors.AnalyticsCustomDataProvider;
Expand Down Expand Up @@ -260,9 +259,9 @@ public Map<String, Object> getProperties() {
Map<String,Object> map = new HashMap();
Map<String, Value> fieldsMap = getFieldsMapFromLogEntry();
String gwURL = getValueAsString(fieldsMap, MetadataConstants.GATEWAY_URL);
Double totalTokenCount = getValueAsDouble(fieldsMap, MetadataConstants.TOTAL_TOKEN_COUNT);
Double completionTokenCount = getValueAsDouble(fieldsMap, MetadataConstants.COMPLETION_TOKEN_COUNT);
Double promptTokenCount = getValueAsDouble(fieldsMap, MetadataConstants.PROMPT_TOKEN_COUNT);
Integer totalTokenCount = getValueAsInteger(fieldsMap, MetadataConstants.TOTAL_TOKEN_COUNT);
Integer completionTokenCount = getValueAsInteger(fieldsMap, MetadataConstants.COMPLETION_TOKEN_COUNT);
Integer promptTokenCount = getValueAsInteger(fieldsMap, MetadataConstants.PROMPT_TOKEN_COUNT);
String model = getValueAsString(fieldsMap, MetadataConstants.MODEL);
String providerName = getValueAsString(fieldsMap, MetadataConstants.AI_PROVIDER_NAME);
String providerAPIVersion = getValueAsString(fieldsMap, MetadataConstants.AI_PROVIDER_API_VERSION);
Expand Down Expand Up @@ -308,12 +307,13 @@ private String getValueAsString(Map<String, Value> fieldsMap, String key) {
return fieldsMap.get(key).getStringValue();
}

private Double getValueAsDouble(Map<String, Value> fieldsMap, String key) {
private Integer getValueAsInteger(Map<String, Value> fieldsMap, String key) {

if (fieldsMap == null || !fieldsMap.containsKey(key)) {
return null;
}
return fieldsMap.get(key).getNumberValue();
Double d = fieldsMap.get(key).getNumberValue();
return d.intValue();
}

private Map<String, Value> getFieldsMapFromLogEntry() {
Expand All @@ -325,10 +325,18 @@ private Map<String, Value> getFieldsMapFromLogEntry() {
.containsKey(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY)) {
return new HashMap<>(0);
}
Map<String, Value> metadataFromExtProc = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
.get(MetadataConstants.EXT_PROC_METADATA_CONTEXT_KEY).getFieldsMap();
Map<String, Value> metadataFromExtAuthz = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
.get(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY).getFieldsMap();
Map<String, Value> metadataFromExtProc = new HashMap<>();
if (logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
.get(MetadataConstants.EXT_PROC_METADATA_CONTEXT_KEY) != null) {
metadataFromExtProc = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
.get(MetadataConstants.EXT_PROC_METADATA_CONTEXT_KEY).getFieldsMap();
}
Map<String, Value> metadataFromExtAuthz = new HashMap<>();
if (logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
.get(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY) != null) {
metadataFromExtAuthz = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
.get(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY).getFieldsMap();
}
Map<String, Value> mergedMetadata = new HashMap<>(metadataFromExtProc);
mergedMetadata.putAll(metadataFromExtAuthz);
return mergedMetadata;
Expand Down
1 change: 1 addition & 0 deletions test/cucumber-tests/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ dependencies {
implementation 'io.grpc:grpc-stub:1.57.0'
implementation 'io.grpc:grpc-stub:1.57.0'
implementation 'com.google.protobuf:protobuf-java:4.28.2'
implementation group: 'io.kubernetes', name: 'client-java', version: '21.0.1'
}

test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@
import io.cucumber.java.Before;
import io.cucumber.java.en.Given;
import io.cucumber.java.en.Then;
import io.kubernetes.client.openapi.ApiClient;
import io.kubernetes.client.openapi.ApiException;
import io.kubernetes.client.openapi.Configuration;
import io.kubernetes.client.openapi.apis.CoreV1Api;
import io.kubernetes.client.openapi.models.V1Pod;
import io.kubernetes.client.util.Config;
import org.apache.commons.io.IOUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -71,6 +77,8 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import io.grpc.Status;
import io.grpc.StatusRuntimeException;

Expand Down Expand Up @@ -305,6 +313,31 @@ public void eventualSuccess(int statusCode, DataTable dataTable) throws IOExcept
}
}

@Then("I see following strings in the enforcer logs")
public void checkEnforcerLogs(DataTable dataTable) throws IOException, InterruptedException, ApiException {
List<String> stringsToCheck = dataTable.asList(String.class);
ApiClient client = Config.defaultClient();
Configuration.setDefaultApiClient(client);
CoreV1Api api = new CoreV1Api();
String namespace = "apk-integration-test";
String podName = "your-pod-name";
String labelSelector = "app.kubernetes.io/app=gateway";

List<V1Pod> podList = api.listNamespacedPod(namespace).labelSelector(labelSelector).execute().getItems();
if (!podList.isEmpty()) {
podName = Objects.requireNonNull(podList.get(0).getMetadata()).getName();
}
try {
String logs = api.readNamespacedPodLog(podName, namespace).container("enforcer").sinceSeconds(60).execute();
Assert.assertNotNull(logs, String.format("Could not find any logs in the last 60 seconds. PodName: %s, namespace: %s", podName, namespace));
for(String word : stringsToCheck) {
Assert.assertTrue(logs.contains(word), "Expected word '" + word + "' not found in logs");
}
} catch (ApiException e) {
System.out.println(e);
}
}

@Then("I wait for next minute")
public void waitForNextMinute() throws InterruptedException {
LocalDateTime now = LocalDateTime.now();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@ Feature: API backend based AI ratelimit Feature
Then the response status code should be 200
And the response headers should contain
| x-ratelimit-remaining | 4999 |
Then I see following strings in the enforcer logs
|aiMetadata|
|gpt-35-turbo|
|AzureAI|
|2024-06-01|
|aiTokenUsage|
|1000|
|300|
|500|
|hour|
|vendor_name|
|vendor_version|
|totalTokens|
|promptTokens|
|completionTokens|
And I wait for 3 seconds
And I send "GET" request to "https://default.gw.wso2.com:9095/llm-api/v1.0.0/3.14/employee?send=body" with body ""
Then the response status code should be 200
Expand Down

0 comments on commit cbcad45

Please sign in to comment.