Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OPIK-645 Return all feedback score names for project ids endpoint #947

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import com.comet.opik.api.PageColumns;
import com.comet.opik.api.filter.ExperimentsComparisonFilter;
import com.comet.opik.api.filter.FiltersFactory;
import com.comet.opik.api.resources.v1.priv.validate.ExperimentParamsValidator;
import com.comet.opik.api.resources.v1.priv.validate.IdParamsValidator;
import com.comet.opik.api.sorting.SortingFactoryDatasets;
import com.comet.opik.api.sorting.SortingField;
import com.comet.opik.domain.DatasetItemService;
Expand Down Expand Up @@ -375,7 +375,7 @@ public Response findDatasetItemsWithExperimentItems(
@QueryParam("filters") String filters,
@QueryParam("truncate") boolean truncate) {

var experimentIds = ExperimentParamsValidator.getExperimentIds(experimentIdsQueryParam);
var experimentIds = IdParamsValidator.getIds(experimentIdsQueryParam);

var queryFilters = filtersFactory.newFilters(filters, ExperimentsComparisonFilter.LIST_TYPE_REFERENCE);

Expand Down Expand Up @@ -413,7 +413,7 @@ public Response getDatasetItemsOutputColumns(

var experimentIds = Optional.ofNullable(experimentIdsQueryParam)
.filter(Predicate.not(String::isEmpty))
.map(ExperimentParamsValidator::getExperimentIds)
.map(IdParamsValidator::getIds)
.orElse(null);

String workspaceId = requestContext.get().getWorkspaceId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import com.comet.opik.api.FeedbackDefinition;
import com.comet.opik.api.FeedbackScoreNames;
import com.comet.opik.api.Identifier;
import com.comet.opik.api.resources.v1.priv.validate.ExperimentParamsValidator;
import com.comet.opik.api.resources.v1.priv.validate.IdParamsValidator;
import com.comet.opik.domain.ExperimentItemService;
import com.comet.opik.domain.ExperimentService;
import com.comet.opik.domain.FeedbackScoreService;
Expand Down Expand Up @@ -53,7 +53,7 @@
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;

import java.util.List;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
Expand Down Expand Up @@ -281,9 +281,8 @@ public Response deleteExperimentItems(
public Response findFeedbackScoreNames(@QueryParam("experiment_ids") String experimentIdsQueryParam) {

var experimentIds = Optional.ofNullable(experimentIdsQueryParam)
.map(ExperimentParamsValidator::getExperimentIds)
.map(List::copyOf)
.orElse(null);
.map(IdParamsValidator::getIds)
.orElse(Collections.emptySet());

String workspaceId = requestContext.get().getWorkspaceId();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.codahale.metrics.annotation.Timed;
import com.comet.opik.api.BatchDelete;
import com.comet.opik.api.FeedbackScoreNames;
import com.comet.opik.api.Page;
import com.comet.opik.api.Project;
import com.comet.opik.api.ProjectCriteria;
Expand All @@ -10,8 +11,10 @@
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.api.metrics.ProjectMetricRequest;
import com.comet.opik.api.metrics.ProjectMetricResponse;
import com.comet.opik.api.resources.v1.priv.validate.IdParamsValidator;
import com.comet.opik.api.sorting.SortingFactoryProjects;
import com.comet.opik.api.sorting.SortingField;
import com.comet.opik.domain.FeedbackScoreService;
import com.comet.opik.domain.ProjectMetricsService;
import com.comet.opik.domain.ProjectService;
import com.comet.opik.infrastructure.auth.RequestContext;
Expand Down Expand Up @@ -48,7 +51,9 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;

import static com.comet.opik.domain.ProjectMetricsService.ERR_START_BEFORE_END;
Expand All @@ -67,6 +72,7 @@ public class ProjectsResource {
private final @NonNull Provider<RequestContext> requestContext;
private final @NonNull SortingFactoryProjects sortingFactory;
private final @NonNull ProjectMetricsService metricsService;
private final @NonNull FeedbackScoreService feedbackScoreService;

@GET
@Operation(operationId = "findProjects", summary = "Find projects", description = "Find projects", responses = {
Expand Down Expand Up @@ -232,6 +238,31 @@ public Response getProjectMetrics(
return Response.ok().entity(response).build();
}

@GET
@Path("/feedback-scores/names")
@Operation(operationId = "findFeedbackScoreNamesByProjectIds", summary = "Find Feedback Score names By Project Ids", description = "Find Feedback Score names By Project Ids", responses = {
@ApiResponse(responseCode = "200", description = "Feedback Scores resource", content = @Content(schema = @Schema(implementation = FeedbackScoreNames.class)))
})
public Response findFeedbackScoreNames(@QueryParam("project_ids") String projectIdsQueryParam) {

var projectIds = Optional.ofNullable(projectIdsQueryParam)
.map(IdParamsValidator::getIds)
.orElse(Collections.emptySet());

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Find feedback score names by project_ids '{}', on workspaceId '{}'",
projectIds, workspaceId);
FeedbackScoreNames feedbackScoreNames = feedbackScoreService
.getProjectsFeedbackScoreNames(projectIds)
.contextWrite(ctx -> setRequestContext(ctx, requestContext))
.block();
log.info("Found feedback score names '{}' by project_ids '{}', on workspaceId '{}'",
feedbackScoreNames.scores().size(), projectIds, workspaceId);

return Response.ok(feedbackScoreNames).build();
}

private void validate(ProjectMetricRequest request) {
if (!request.intervalStart().isBefore(request.intervalEnd())) {
throw new BadRequestException(ERR_START_BEFORE_END);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@

@UtilityClass
@Slf4j
public class ExperimentParamsValidator {
public class IdParamsValidator {

private static final TypeReference<List<UUID>> LIST_UUID_TYPE_REFERENCE = new TypeReference<>() {
};

public static Set<UUID> getExperimentIds(String experimentIds) {
var message = "Invalid query param experiment ids '%s'".formatted(experimentIds);
public static Set<UUID> getIds(String idsQueryParam) {
var message = "Invalid query param ids '%s'".formatted(idsQueryParam);
try {
return JsonUtils.readValue(experimentIds, LIST_UUID_TYPE_REFERENCE)
return JsonUtils.readValue(idsQueryParam, LIST_UUID_TYPE_REFERENCE)
.stream()
.collect(Collectors.toUnmodifiableSet());
} catch (RuntimeException exception) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ Mono<Long> scoreEntity(EntityType entityType, UUID entityId, FeedbackScore score

Mono<List<String>> getSpanFeedbackScoreNames(@NonNull UUID projectId, SpanType type);

Mono<List<String>> getExperimentsFeedbackScoreNames(List<UUID> experimentIds);
Mono<List<String>> getExperimentsFeedbackScoreNames(Set<UUID> experimentIds);

Mono<List<String>> getProjectsFeedbackScoreNames(Set<UUID> projectIds);
}

@Singleton
Expand Down Expand Up @@ -203,6 +205,23 @@ INNER JOIN (
;
""";

private static final String SELECT_PROJECTS_FEEDBACK_SCORE_NAMES = """
SELECT
distinct name
FROM (
SELECT
name
FROM feedback_scores
WHERE workspace_id = :workspace_id
<if(project_ids)>
AND project_id IN :project_ids
<endif>
ORDER BY entity_id DESC, last_updated_at DESC
LIMIT 1 BY entity_id, name
) AS names
;
""";

private final static String SELECT_SPAN_FEEDBACK_SCORE_NAMES = """
SELECT
distinct name
Expand Down Expand Up @@ -415,7 +434,7 @@ public Mono<List<String>> getTraceFeedbackScoreNames(@NonNull UUID projectId) {

@Override
@WithSpan
public Mono<List<String>> getExperimentsFeedbackScoreNames(List<UUID> experimentIds) {
public Mono<List<String>> getExperimentsFeedbackScoreNames(Set<UUID> experimentIds) {
return asyncTemplate.nonTransaction(connection -> {

ST template = new ST(SELECT_TRACE_FEEDBACK_SCORE_NAMES);
Expand All @@ -433,6 +452,29 @@ public Mono<List<String>> getExperimentsFeedbackScoreNames(List<UUID> experiment
});
}

@Override
@WithSpan
public Mono<List<String>> getProjectsFeedbackScoreNames(Set<UUID> projectIds) {
return asyncTemplate.nonTransaction(connection -> {

ST template = new ST(SELECT_PROJECTS_FEEDBACK_SCORE_NAMES);

if (CollectionUtils.isNotEmpty(projectIds)) {
template.add("project_ids", projectIds);
}

var statement = connection.createStatement(template.render());

if (CollectionUtils.isNotEmpty(projectIds)) {
statement.bind("project_ids", projectIds);
}

return makeMonoContextAware(bindWorkspaceIdToMono(statement))
.flatMapMany(result -> result.map((row, rowMetadata) -> row.get("name", String.class)))
.collect(Collectors.toList());
BorisTkachenko marked this conversation as resolved.
Show resolved Hide resolved
});
}

@Override
@WithSpan
public Mono<List<String>> getSpanFeedbackScoreNames(@NonNull UUID projectId, SpanType type) {
Expand Down Expand Up @@ -463,17 +505,17 @@ private Mono<List<String>> getNames(Statement statement) {
.collect(Collectors.toList());
}

private void bindStatementParam(UUID projectId, List<UUID> experimentIds, Statement statement) {
private void bindStatementParam(UUID projectId, Set<UUID> experimentIds, Statement statement) {
if (projectId != null) {
statement.bind("project_id", projectId);
}

if (CollectionUtils.isNotEmpty(experimentIds)) {
statement.bind("experiment_ids", experimentIds.toArray(UUID[]::new));
statement.bind("experiment_ids", experimentIds);
}
}

private void bindTemplateParam(UUID projectId, boolean withExperimentsOnly, List<UUID> experimentIds, ST template) {
private void bindTemplateParam(UUID projectId, boolean withExperimentsOnly, Set<UUID> experimentIds, ST template) {
if (projectId != null) {
template.add("project_id", projectId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ public interface FeedbackScoreService {

Mono<FeedbackScoreNames> getSpanFeedbackScoreNames(UUID projectId, SpanType type);

Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(List<UUID> experimentIds);
Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(Set<UUID> experimentIds);

Mono<FeedbackScoreNames> getProjectsFeedbackScoreNames(Set<UUID> projectIds);
}

@Slf4j
Expand Down Expand Up @@ -250,12 +252,19 @@ public Mono<FeedbackScoreNames> getSpanFeedbackScoreNames(@NonNull UUID projectI
}

@Override
public Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(List<UUID> experimentIds) {
public Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(Set<UUID> experimentIds) {
return dao.getExperimentsFeedbackScoreNames(experimentIds)
.map(names -> names.stream().map(FeedbackScoreNames.ScoreName::new).toList())
.map(FeedbackScoreNames::new);
}

@Override
public Mono<FeedbackScoreNames> getProjectsFeedbackScoreNames(Set<UUID> projectIds) {
return dao.getProjectsFeedbackScoreNames(projectIds)
.map(names -> names.stream().map(FeedbackScoreNames.ScoreName::new).toList())
.map(FeedbackScoreNames::new);
}

private Mono<Long> failWithNotFound(String errorMessage) {
log.info(errorMessage);
return Mono.error(new NotFoundException(Response.status(404)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.comet.opik.api.resources.utils;

import com.comet.opik.api.FeedbackScoreNames;
import lombok.experimental.UtilityClass;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

@UtilityClass
public class AssertionUtils {

public static void assertFeedbackScoreNames(FeedbackScoreNames actual, List<String> expectedNames) {
assertThat(actual.scores()).hasSize(expectedNames.size());
assertThat(actual
.scores()
.stream()
.map(FeedbackScoreNames.ScoreName::name)
.toList()).containsExactlyInAnyOrderElementsOf(expectedNames);
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.comet.opik.api.resources.utils.resources;

import com.comet.opik.api.FeedbackScoreNames;
import com.comet.opik.api.Project;
import com.comet.opik.api.resources.utils.TestUtils;
import com.comet.opik.infrastructure.auth.RequestContext;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.client.WebTarget;
import jakarta.ws.rs.core.HttpHeaders;
import lombok.RequiredArgsConstructor;
import org.apache.hc.core5.http.HttpStatus;
Expand Down Expand Up @@ -74,4 +76,22 @@ public Project getByName(String projectName, String apiKey, String workspaceName
}
}

public FeedbackScoreNames findFeedbackScoreNames(String projectIdsQueryParam, String apiKey, String workspaceName) {
WebTarget webTarget = client.target(RESOURCE_PATH.formatted(baseURI))
.path("feedback-scores")
.path("names")
.queryParam("project_ids", projectIdsQueryParam);

try (var actualResponse = webTarget
.request()
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(RequestContext.WORKSPACE_HEADER, workspaceName)
.get()) {

// then
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK);

return actualResponse.readEntity(FeedbackScoreNames.class);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@
import com.comet.opik.api.FeedbackScore;
import com.comet.opik.api.FeedbackScoreBatch;
import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.Project;
import com.comet.opik.api.Trace;
import com.comet.opik.api.TraceBatch;
import com.comet.opik.api.TraceUpdate;
import com.comet.opik.api.resources.utils.TestUtils;
import com.comet.opik.podam.PodamFactoryUtils;
import jakarta.ws.rs.HttpMethod;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import lombok.RequiredArgsConstructor;
import org.apache.http.HttpStatus;
import ru.vyarus.dropwizard.guice.test.ClientSupport;
import uk.co.jemos.podam.api.PodamFactory;

import java.util.List;
import java.util.UUID;
import java.util.stream.IntStream;

import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER;
import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -29,6 +33,7 @@ public class TraceResourceClient {

private final ClientSupport client;
private final String baseURI;
private final PodamFactory podamFactory = PodamFactoryUtils.newPodamFactory();

public UUID createTrace(Trace trace, String apiKey, String workspaceName) {
try (var response = client.target(RESOURCE_PATH.formatted(baseURI))
Expand Down Expand Up @@ -140,4 +145,29 @@ public void updateTrace(UUID id, TraceUpdate traceUpdate, String apiKey, String
assertThat(actualResponse.hasEntity()).isFalse();
}
}

public List<List<FeedbackScoreBatchItem>> createMultiValueScores(List<String> multipleValuesFeedbackScores,
Project project, String apiKey, String workspaceName) {
return IntStream.range(0, multipleValuesFeedbackScores.size())
.mapToObj(i -> {

Trace trace = podamFactory.manufacturePojo(Trace.class).toBuilder()
.name(project.name())
.build();

createTrace(trace, apiKey, workspaceName);

List<FeedbackScoreBatchItem> scores = multipleValuesFeedbackScores.stream()
.map(name -> podamFactory.manufacturePojo(FeedbackScoreBatchItem.class).toBuilder()
.name(name)
.projectName(project.name())
.id(trace.id())
.build())
.toList();

feedbackScores(scores, apiKey, workspaceName);

return scores;
}).toList();
}
}
Loading
Loading