Skip to content

Commit

Permalink
Fix comments and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Borys Tkachenko authored and Borys Tkachenko committed Dec 30, 2024
1 parent 5a4222b commit 888ed03
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;

import java.util.List;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
Expand Down Expand Up @@ -282,8 +282,7 @@ public Response findFeedbackScoreNames(@QueryParam("experiment_ids") String expe

var experimentIds = Optional.ofNullable(experimentIdsQueryParam)
.map(IdParamsValidator::getIds)
.map(List::copyOf)
.orElse(null);
.orElse(Collections.emptySet());

String workspaceId = requestContext.get().getWorkspaceId();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
Expand Down Expand Up @@ -246,8 +247,7 @@ public Response findFeedbackScoreNames(@QueryParam("project_ids") String project

var projectIds = Optional.ofNullable(projectIdsQueryParam)
.map(IdParamsValidator::getIds)
.map(List::copyOf)
.orElse(null);
.orElse(Collections.emptySet());

String workspaceId = requestContext.get().getWorkspaceId();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ Mono<Long> scoreEntity(EntityType entityType, UUID entityId, FeedbackScore score

Mono<List<String>> getSpanFeedbackScoreNames(@NonNull UUID projectId, SpanType type);

Mono<List<String>> getExperimentsFeedbackScoreNames(List<UUID> experimentIds);
Mono<List<String>> getExperimentsFeedbackScoreNames(Set<UUID> experimentIds);

Mono<List<String>> getProjectsFeedbackScoreNames(List<UUID> projectIds);
Mono<List<String>> getProjectsFeedbackScoreNames(Set<UUID> projectIds);
}

@Singleton
Expand Down Expand Up @@ -434,7 +434,7 @@ public Mono<List<String>> getTraceFeedbackScoreNames(@NonNull UUID projectId) {

@Override
@WithSpan
public Mono<List<String>> getExperimentsFeedbackScoreNames(List<UUID> experimentIds) {
public Mono<List<String>> getExperimentsFeedbackScoreNames(Set<UUID> experimentIds) {
return asyncTemplate.nonTransaction(connection -> {

ST template = new ST(SELECT_TRACE_FEEDBACK_SCORE_NAMES);
Expand All @@ -454,7 +454,7 @@ public Mono<List<String>> getExperimentsFeedbackScoreNames(List<UUID> experiment

@Override
@WithSpan
public Mono<List<String>> getProjectsFeedbackScoreNames(List<UUID> projectIds) {
public Mono<List<String>> getProjectsFeedbackScoreNames(Set<UUID> projectIds) {
return asyncTemplate.nonTransaction(connection -> {

ST template = new ST(SELECT_PROJECTS_FEEDBACK_SCORE_NAMES);
Expand All @@ -466,16 +466,11 @@ public Mono<List<String>> getProjectsFeedbackScoreNames(List<UUID> projectIds) {
var statement = connection.createStatement(template.render());

if (CollectionUtils.isNotEmpty(projectIds)) {
template.add("project_ids", projectIds);
}

if (CollectionUtils.isNotEmpty(projectIds)) {
statement.bind("project_ids", projectIds.toArray(UUID[]::new));
statement.bind("project_ids", projectIds);
}

return makeMonoContextAware(bindWorkspaceIdToMono(statement))
.flatMapMany(result -> result.map((row, rowMetadata) -> row.get("name", String.class)))
.distinct()
.collect(Collectors.toList());
});
}
Expand Down Expand Up @@ -510,17 +505,17 @@ private Mono<List<String>> getNames(Statement statement) {
.collect(Collectors.toList());
}

private void bindStatementParam(UUID projectId, List<UUID> experimentIds, Statement statement) {
private void bindStatementParam(UUID projectId, Set<UUID> experimentIds, Statement statement) {
if (projectId != null) {
statement.bind("project_id", projectId);
}

if (CollectionUtils.isNotEmpty(experimentIds)) {
statement.bind("experiment_ids", experimentIds.toArray(UUID[]::new));
statement.bind("experiment_ids", experimentIds);
}
}

private void bindTemplateParam(UUID projectId, boolean withExperimentsOnly, List<UUID> experimentIds, ST template) {
private void bindTemplateParam(UUID projectId, boolean withExperimentsOnly, Set<UUID> experimentIds, ST template) {
if (projectId != null) {
template.add("project_id", projectId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ public interface FeedbackScoreService {

Mono<FeedbackScoreNames> getSpanFeedbackScoreNames(UUID projectId, SpanType type);

Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(List<UUID> experimentIds);
Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(Set<UUID> experimentIds);

Mono<FeedbackScoreNames> getProjectsFeedbackScoreNames(List<UUID> projectIds);
Mono<FeedbackScoreNames> getProjectsFeedbackScoreNames(Set<UUID> projectIds);
}

@Slf4j
Expand Down Expand Up @@ -252,14 +252,14 @@ public Mono<FeedbackScoreNames> getSpanFeedbackScoreNames(@NonNull UUID projectI
}

@Override
public Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(List<UUID> experimentIds) {
public Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(Set<UUID> experimentIds) {
return dao.getExperimentsFeedbackScoreNames(experimentIds)
.map(names -> names.stream().map(FeedbackScoreNames.ScoreName::new).toList())
.map(FeedbackScoreNames::new);
}

@Override
public Mono<FeedbackScoreNames> getProjectsFeedbackScoreNames(List<UUID> projectIds) {
public Mono<FeedbackScoreNames> getProjectsFeedbackScoreNames(Set<UUID> projectIds) {
return dao.getProjectsFeedbackScoreNames(projectIds)
.map(names -> names.stream().map(FeedbackScoreNames.ScoreName::new).toList())
.map(FeedbackScoreNames::new);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.comet.opik.api.resources.utils;

import com.comet.opik.api.FeedbackScoreNames;
import lombok.experimental.UtilityClass;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

@UtilityClass
public class AssertionUtils {

public static void assertFeedbackScoreNames(FeedbackScoreNames actual, List<String> expectedNames) {
assertThat(actual.scores()).hasSize(expectedNames.size());
assertThat(actual
.scores()
.stream()
.map(FeedbackScoreNames.ScoreName::name)
.toList()).containsExactlyInAnyOrderElementsOf(expectedNames);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,8 @@ public Project getByName(String projectName, String apiKey, String workspaceName
public FeedbackScoreNames findFeedbackScoreNames(String projectIdsQueryParam, String apiKey, String workspaceName) {
WebTarget webTarget = client.target(RESOURCE_PATH.formatted(baseURI))
.path("feedback-scores")
.path("names");

if (projectIdsQueryParam != null) {
webTarget = webTarget.queryParam("project_ids", projectIdsQueryParam);
}
.path("names")
.queryParam("project_ids", projectIdsQueryParam);

try (var actualResponse = webTarget
.request()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@
import com.comet.opik.api.FeedbackScore;
import com.comet.opik.api.FeedbackScoreBatch;
import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.Project;
import com.comet.opik.api.Trace;
import com.comet.opik.api.TraceBatch;
import com.comet.opik.api.TraceUpdate;
import com.comet.opik.api.resources.utils.TestUtils;
import com.comet.opik.podam.PodamFactoryUtils;
import jakarta.ws.rs.HttpMethod;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import lombok.RequiredArgsConstructor;
import org.apache.http.HttpStatus;
import ru.vyarus.dropwizard.guice.test.ClientSupport;
import uk.co.jemos.podam.api.PodamFactory;

import java.util.List;
import java.util.UUID;
import java.util.stream.IntStream;

import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER;
import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -29,6 +33,7 @@ public class TraceResourceClient {

private final ClientSupport client;
private final String baseURI;
private final PodamFactory podamFactory = PodamFactoryUtils.newPodamFactory();

public UUID createTrace(Trace trace, String apiKey, String workspaceName) {
try (var response = client.target(RESOURCE_PATH.formatted(baseURI))
Expand Down Expand Up @@ -140,4 +145,29 @@ public void updateTrace(UUID id, TraceUpdate traceUpdate, String apiKey, String
assertThat(actualResponse.hasEntity()).isFalse();
}
}

public List<List<FeedbackScoreBatchItem>> createMultiValueScores(List<String> multipleValuesFeedbackScores,
Project project, String apiKey, String workspaceName) {
return IntStream.range(0, multipleValuesFeedbackScores.size())
.mapToObj(i -> {

Trace trace = podamFactory.manufacturePojo(Trace.class).toBuilder()
.name(project.name())
.build();

createTrace(trace, apiKey, workspaceName);

List<FeedbackScoreBatchItem> scores = multipleValuesFeedbackScores.stream()
.map(name -> podamFactory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder()
.name(name)
.projectName(project.name())
.id(trace.id())
.build())
.toList();

feedbackScores(scores, apiKey, workspaceName);

return scores;
}).toList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.comet.opik.api.resources.utils.AssertionUtils.assertFeedbackScoreNames;
import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME;
import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE;
import static com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils.AppContextConfig;
Expand Down Expand Up @@ -2642,10 +2643,12 @@ void getFeedbackScoreNames__whenGetFeedbackScoreNames__thenReturnFeedbackScoreNa
// Create multiple values feedback scores
List<String> multipleValuesFeedbackScores = names.subList(0, names.size() - 1);

List<List<FeedbackScoreBatchItem>> multipleValuesFeedbackScoreList = createMultiValueScores(
multipleValuesFeedbackScores, project, apiKey, workspaceName);
List<List<FeedbackScoreBatchItem>> multipleValuesFeedbackScoreList = traceResourceClient
.createMultiValueScores(
multipleValuesFeedbackScores, project, apiKey, workspaceName);

List<List<FeedbackScoreBatchItem>> singleValueScores = createMultiValueScores(List.of(names.getLast()),
List<List<FeedbackScoreBatchItem>> singleValueScores = traceResourceClient.createMultiValueScores(
List.of(names.getLast()),
project, apiKey, workspaceName);

UUID experimentId = createExperimentsItems(apiKey, workspaceName, multipleValuesFeedbackScoreList,
Expand All @@ -2654,27 +2657,26 @@ void getFeedbackScoreNames__whenGetFeedbackScoreNames__thenReturnFeedbackScoreNa
// Create unexpected feedback scores
var unexpectedProject = podamFactory.manufacturePojo(Project.class);

List<List<FeedbackScoreBatchItem>> unexpectedScores = createMultiValueScores(otherNames, unexpectedProject,
List<List<FeedbackScoreBatchItem>> unexpectedScores = traceResourceClient.createMultiValueScores(otherNames,
unexpectedProject,
apiKey, workspaceName);

createExperimentsItems(apiKey, workspaceName, unexpectedScores, List.of());

fetchAndAssertResponse(userExperimentId, experimentId, projectId, names, otherNames, apiKey, workspaceName);
fetchAndAssertResponse(userExperimentId, experimentId, names, otherNames, apiKey, workspaceName);
}
}

private void fetchAndAssertResponse(boolean userExperimentId, UUID experimentId, UUID projectId, List<String> names,
private void fetchAndAssertResponse(boolean userExperimentId, UUID experimentId, List<String> names,
List<String> otherNames, String apiKey, String workspaceName) {

WebTarget webTarget = client.target(URL_TEMPLATE.formatted(baseURI))
.path("feedback-scores")
.path("names");

String projectIdsQueryParam = null;
if (userExperimentId) {
var ids = JsonUtils.writeValueAsString(List.of(experimentId));
webTarget = webTarget.queryParam("experiment_ids", ids);
projectIdsQueryParam = JsonUtils.writeValueAsString(List.of(projectId));
}

List<String> expectedNames = userExperimentId
Expand All @@ -2692,43 +2694,6 @@ private void fetchAndAssertResponse(boolean userExperimentId, UUID experimentId,
var actualEntity = actualResponse.readEntity(FeedbackScoreNames.class);
assertFeedbackScoreNames(actualEntity, expectedNames);
}

var feedbackScoreNamesByProjectId = projectResourceClient.findFeedbackScoreNames(projectIdsQueryParam, apiKey, workspaceName);
assertFeedbackScoreNames(feedbackScoreNamesByProjectId, expectedNames);
}

private void assertFeedbackScoreNames(FeedbackScoreNames actual, List<String> expectedNames) {
assertThat(actual.scores()).hasSize(expectedNames.size());
assertThat(actual
.scores()
.stream()
.map(FeedbackScoreNames.ScoreName::name)
.toList()).containsExactlyInAnyOrderElementsOf(expectedNames);
}

private List<List<FeedbackScoreBatchItem>> createMultiValueScores(List<String> multipleValuesFeedbackScores,
Project project, String apiKey, String workspaceName) {
return IntStream.range(0, multipleValuesFeedbackScores.size())
.mapToObj(i -> {

Trace trace = podamFactory.manufacturePojo(Trace.class).toBuilder()
.name(project.name())
.build();

traceResourceClient.createTrace(trace, apiKey, workspaceName);

List<FeedbackScoreBatchItem> scores = multipleValuesFeedbackScores.stream()
.map(name -> podamFactory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder()
.name(name)
.projectName(project.name())
.id(trace.id())
.build())
.toList();

traceResourceClient.feedbackScores(scores, apiKey, workspaceName);

return scores;
}).toList();
}

private UUID createExperimentsItems(String apiKey, String workspaceName,
Expand Down
Loading

0 comments on commit 888ed03

Please sign in to comment.