Skip to content

Commit

Permalink
[OPIK-486] Add NOT_EQUAL operator for filtering (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisTkachenko authored Dec 5, 2024
1 parent b3fcc74 commit 98c360c
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public enum Operator {
STARTS_WITH("starts_with"),
ENDS_WITH("ends_with"),
EQUAL("="),
NOT_EQUAL("!="),
GREATER_THAN(">"),
GREATER_THAN_EQUAL(">="),
LESS_THAN("<"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ public class FilterQueryBuilder {
"has(groupArray(tuple(lower(name), %1$s)), tuple(lower(:filterKey%2$d), toDecimal64(:filter%2$d, 9))) = 1",
FieldType.DICTIONARY,
"lower(JSON_VALUE(%1$s, :filterKey%2$d)) = lower(:filter%2$d)")),
Operator.NOT_EQUAL, new EnumMap<>(Map.of(
FieldType.STRING, "lower(%1$s) != lower(:filter%2$d)",
FieldType.DATE_TIME, "%1$s != parseDateTime64BestEffort(:filter%2$d, 9)",
FieldType.NUMBER, "%1$s != :filter%2$d",
FieldType.FEEDBACK_SCORES_NUMBER,
"has(groupArray(tuple(lower(name), %1$s)), tuple(lower(:filterKey%2$d), toDecimal64(:filter%2$d, 9))) = 0",
FieldType.DICTIONARY,
"lower(JSON_VALUE(%1$s, :filterKey%2$d)) != lower(:filter%2$d)")),
Operator.GREATER_THAN, new EnumMap<>(Map.of(
FieldType.DATE_TIME, "%1$s > parseDateTime64BestEffort(:filter%2$d, 9)",
FieldType.NUMBER, "%1$s > :filter%2$d",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
import java.util.Set;
import java.util.UUID;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
Expand Down Expand Up @@ -1109,15 +1110,56 @@ void getByProjectName__whenFilterByCorrespondingField__thenReturnSpansFiltered(S
apiKey);
}

@ParameterizedTest
@MethodSource("equalAndNotEqualFilters")
void getByProjectName__whenFilterTotalEstimatedCostEqual_NotEqual__thenReturnSpansFiltered(Operator operator,
Function<List<Span>, List<Span>> getUnexpectedSpans,
Function<List<Span>, List<Span>> getExpectedSpans) {
String workspaceName = UUID.randomUUID().toString();
String workspaceId = UUID.randomUUID().toString();
String apiKey = UUID.randomUUID().toString();

mockTargetWorkspace(apiKey, workspaceName, workspaceId);

var projectName = generator.generate().toString();
var spans = PodamFactoryUtils.manufacturePojoList(podamFactory, Span.class)
.stream()
.map(span -> span.toBuilder()
.projectId(null)
.projectName(projectName)
.feedbackScores(null)
.build())
.collect(Collectors.toCollection(ArrayList::new));
spans.set(0, spans.getFirst().toBuilder()
.model("gpt-3.5-turbo-1106")
.usage(Map.of("completion_tokens", Math.abs(podamFactory.manufacturePojo(Integer.class)),
"prompt_tokens", Math.abs(podamFactory.manufacturePojo(Integer.class))))
.build());

spans.forEach(expectedSpan -> SpansResourceTest.this.createAndAssert(expectedSpan, apiKey, workspaceName));
var expectedSpans = getExpectedSpans.apply(spans);
var unexpectedSpans = getUnexpectedSpans.apply(spans);

var filters = List.of(SpanFilter.builder()
.field(SpanField.TOTAL_ESTIMATED_COST)
.operator(operator)
.value("0")
.build());
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey);
}

static Stream<Arguments> getByProjectName__whenFilterByCorrespondingField__thenReturnSpansFiltered() {
return Stream.of(
Arguments.of(SpanField.TOTAL_ESTIMATED_COST, Operator.GREATER_THAN, "0"),
Arguments.of(SpanField.MODEL, Operator.EQUAL, "gpt-3.5-turbo-1106"),
Arguments.of(SpanField.PROVIDER, Operator.EQUAL, null));
}

@Test
void getByProjectName__whenFilterNameEqual__thenReturnSpansFiltered() {
@ParameterizedTest
@MethodSource("equalAndNotEqualFilters")
void getByProjectName__whenFilterNameEqual_NotEqual__thenReturnSpansFiltered(Operator operator,
Function<List<Span>, List<Span>> getExpectedSpans,
Function<List<Span>, List<Span>> getUnexpectedSpans) {
String workspaceName = UUID.randomUUID().toString();
String workspaceId = UUID.randomUUID().toString();
String apiKey = UUID.randomUUID().toString();
Expand All @@ -1134,19 +1176,25 @@ void getByProjectName__whenFilterNameEqual__thenReturnSpansFiltered() {
.build())
.collect(Collectors.toCollection(ArrayList::new));
spans.forEach(expectedSpan -> SpansResourceTest.this.createAndAssert(expectedSpan, apiKey, workspaceName));
var expectedSpans = List.of(spans.getFirst());
var unexpectedSpans = List.of(podamFactory.manufacturePojo(Span.class).toBuilder()
.projectId(null)
.build());
unexpectedSpans.forEach(
expectedSpan -> SpansResourceTest.this.createAndAssert(expectedSpan, apiKey, workspaceName));
var expectedSpans = getExpectedSpans.apply(spans);
var unexpectedSpans = getUnexpectedSpans.apply(spans);

var filters = List.of(SpanFilter.builder()
.field(SpanField.NAME)
.operator(Operator.EQUAL)
.operator(operator)
.value(spans.getFirst().name().toUpperCase())
.build());
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans, unexpectedSpans, apiKey);
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey);
}

private Stream<Arguments> equalAndNotEqualFilters() {
return Stream.of(
Arguments.of(Operator.EQUAL,
(Function<List<Span>, List<Span>>) spans -> List.of(spans.getFirst()),
(Function<List<Span>, List<Span>>) spans -> spans.subList(1, spans.size())),
Arguments.of(Operator.NOT_EQUAL,
(Function<List<Span>, List<Span>>) spans -> spans.subList(1, spans.size()),
(Function<List<Span>, List<Span>>) spans -> List.of(spans.getFirst())));
}

@Test
Expand Down Expand Up @@ -1286,8 +1334,11 @@ void getByProjectName__whenFilterNameNotContains__thenReturnSpansFiltered() {
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans, unexpectedSpans, apiKey);
}

@Test
void getByProjectName__whenFilterStartTimeEqual__thenReturnSpansFiltered() {
@ParameterizedTest
@MethodSource("equalAndNotEqualFilters")
void getByProjectName__whenFilterStartTimeEqual_NotEqual__thenReturnSpansFiltered(Operator operator,
Function<List<Span>, List<Span>> getExpectedSpans,
Function<List<Span>, List<Span>> getUnexpectedSpans) {
String workspaceName = UUID.randomUUID().toString();
String workspaceId = UUID.randomUUID().toString();
String apiKey = UUID.randomUUID().toString();
Expand All @@ -1304,19 +1355,15 @@ void getByProjectName__whenFilterStartTimeEqual__thenReturnSpansFiltered() {
.build())
.collect(Collectors.toCollection(ArrayList::new));
spans.forEach(expectedSpan -> SpansResourceTest.this.createAndAssert(expectedSpan, apiKey, workspaceName));
var expectedSpans = List.of(spans.getFirst());
var unexpectedSpans = List.of(podamFactory.manufacturePojo(Span.class).toBuilder()
.projectId(null)
.build());
unexpectedSpans.forEach(
expectedSpan -> SpansResourceTest.this.createAndAssert(expectedSpan, apiKey, workspaceName));
var expectedSpans = getExpectedSpans.apply(spans);
var unexpectedSpans = getUnexpectedSpans.apply(spans);

var filters = List.of(SpanFilter.builder()
.field(SpanField.START_TIME)
.operator(Operator.EQUAL)
.operator(operator)
.value(spans.getFirst().startTime().toString())
.build());
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans, unexpectedSpans, apiKey);
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey);
}

@Test
Expand Down Expand Up @@ -1566,8 +1613,11 @@ void getByProjectName__whenFilterOutputEqual__thenReturnSpansFiltered() {
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans, unexpectedSpans, apiKey);
}

@Test
void getByProjectName__whenFilterMetadataEqualString__thenReturnSpansFiltered() {
@ParameterizedTest
@MethodSource("equalAndNotEqualFilters")
void getByProjectName__whenFilterMetadataEqualString__thenReturnSpansFiltered(Operator operator,
Function<List<Span>, List<Span>> getExpectedSpans,
Function<List<Span>, List<Span>> getUnexpectedSpans) {
String workspaceName = UUID.randomUUID().toString();
String workspaceId = UUID.randomUUID().toString();
String apiKey = UUID.randomUUID().toString();
Expand All @@ -1590,20 +1640,16 @@ void getByProjectName__whenFilterMetadataEqualString__thenReturnSpansFiltered()
"Chat-GPT 4.0\"}]}"))
.build());
spans.forEach(expectedSpan -> SpansResourceTest.this.createAndAssert(expectedSpan, apiKey, workspaceName));
var expectedSpans = List.of(spans.getFirst());
var unexpectedSpans = List.of(podamFactory.manufacturePojo(Span.class).toBuilder()
.projectId(null)
.build());
unexpectedSpans.forEach(
expectedSpan -> SpansResourceTest.this.createAndAssert(expectedSpan, apiKey, workspaceName));
var expectedSpans = getExpectedSpans.apply(spans);
var unexpectedSpans = getUnexpectedSpans.apply(spans);

var filters = List.of(SpanFilter.builder()
.field(SpanField.METADATA)
.operator(Operator.EQUAL)
.operator(operator)
.key("$.model[0].version")
.value("OPENAI, CHAT-GPT 4.0")
.build());
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans, unexpectedSpans, apiKey);
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey);
}

@Test
Expand Down Expand Up @@ -2439,8 +2485,11 @@ void getByProjectName__whenFilterUsageLessThanEqual__thenReturnSpansFiltered(Str
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans, unexpectedSpans, apiKey);
}

@Test
void getByProjectName__whenFilterFeedbackScoresEqual__thenReturnSpansFiltered() {
@ParameterizedTest
@MethodSource
void getByProjectName__whenFilterFeedbackScoresEqual_NotEqual__thenReturnSpansFiltered(Operator operator,
Function<List<Span>, List<Span>> getExpectedSpans,
Function<List<Span>, List<Span>> getUnexpectedSpans) {

String workspaceName = UUID.randomUUID().toString();
String workspaceId = UUID.randomUUID().toString();
Expand Down Expand Up @@ -2474,31 +2523,33 @@ void getByProjectName__whenFilterFeedbackScoresEqual__thenReturnSpansFiltered()
.forEach(
feedbackScore -> createAndAssert(span.id(), feedbackScore, workspaceName, apiKey)));

var expectedSpans = List.of(spans.getFirst());
var unexpectedSpans = List.of(podamFactory.manufacturePojo(Span.class).toBuilder()
.projectId(null)
.build());
unexpectedSpans.forEach(
expectedSpan -> SpansResourceTest.this.createAndAssert(expectedSpan, apiKey, workspaceName));
unexpectedSpans.forEach(
span -> span.feedbackScores()
.forEach(
feedbackScore -> createAndAssert(span.id(), feedbackScore, workspaceName, apiKey)));
var expectedSpans = getExpectedSpans.apply(spans);
var unexpectedSpans = getUnexpectedSpans.apply(spans);

var filters = List.of(
SpanFilter.builder()
.field(SpanField.FEEDBACK_SCORES)
.operator(Operator.EQUAL)
.operator(operator)
.key(spans.getFirst().feedbackScores().get(1).name().toUpperCase())
.value(spans.getFirst().feedbackScores().get(1).value().toString())
.build(),
SpanFilter.builder()
.field(SpanField.FEEDBACK_SCORES)
.operator(Operator.EQUAL)
.operator(operator)
.key(spans.getFirst().feedbackScores().get(2).name().toUpperCase())
.value(spans.getFirst().feedbackScores().get(2).value().toString())
.build());
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans, unexpectedSpans, apiKey);
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey);
}

private Stream<Arguments> getByProjectName__whenFilterFeedbackScoresEqual_NotEqual__thenReturnSpansFiltered() {
return Stream.of(
Arguments.of(Operator.EQUAL,
(Function<List<Span>, List<Span>>) spans -> List.of(spans.getFirst()),
(Function<List<Span>, List<Span>>) spans -> spans.subList(1, spans.size())),
Arguments.of(Operator.NOT_EQUAL,
(Function<List<Span>, List<Span>>) spans -> spans.subList(2, spans.size()),
(Function<List<Span>, List<Span>>) spans -> spans.subList(0, 2)));
}

@Test
Expand Down

0 comments on commit 98c360c

Please sign in to comment.