diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatchUpdate.java b/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatchUpdate.java
new file mode 100644
index 00000000000..0671b2da6e7
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatchUpdate.java
@@ -0,0 +1,24 @@
+package com.comet.opik.api;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.databind.PropertyNamingStrategies;
+import com.fasterxml.jackson.databind.annotation.JsonNaming;
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.Valid;
+import jakarta.validation.constraints.NotEmpty;
+import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Size;
+import lombok.Builder;
+
+import java.util.Set;
+import java.util.UUID;
+
+@Builder(toBuilder = true)
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
+@Schema(description = "Request to batch update multiple spans")
+public record SpanBatchUpdate(
+        @NotNull @NotEmpty @Size(min = 1, max = 1000) @Schema(description = "Set of span IDs to update (max 1000)") Set<UUID> ids,
+        @NotNull @Valid @Schema(description = "Update to apply to all spans") SpanUpdate update,
+        @Schema(description = "If true, merge tags with existing tags instead of replacing them. Default: false") Boolean mergeTags) {
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatchUpdate.java b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatchUpdate.java
new file mode 100644
index 00000000000..ebaf4493854
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatchUpdate.java
@@ -0,0 +1,24 @@
+package com.comet.opik.api;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.databind.PropertyNamingStrategies;
+import com.fasterxml.jackson.databind.annotation.JsonNaming;
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.Valid;
+import jakarta.validation.constraints.NotEmpty;
+import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Size;
+import lombok.Builder;
+
+import java.util.Set;
+import java.util.UUID;
+
+@Builder(toBuilder = true)
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
+@Schema(description = "Request to batch update multiple traces")
+public record TraceBatchUpdate(
+        @NotNull @NotEmpty @Size(min = 1, max = 1000) @Schema(description = "Set of trace IDs to update (max 1000)") Set<UUID> ids,
+        @NotNull @Valid @Schema(description = "Update to apply to all traces") TraceUpdate update,
+        @Schema(description = "If true, merge tags with existing tags instead of replacing them. Default: false") Boolean mergeTags) {
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/TraceThreadBatchUpdate.java b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceThreadBatchUpdate.java
new file mode 100644
index 00000000000..914e4bf624b
--- /dev/null
+++ b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceThreadBatchUpdate.java
@@ -0,0 +1,24 @@
+package com.comet.opik.api;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.databind.PropertyNamingStrategies;
+import com.fasterxml.jackson.databind.annotation.JsonNaming;
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.Valid;
+import jakarta.validation.constraints.NotEmpty;
+import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Size;
+import lombok.Builder;
+
+import java.util.Set;
+import java.util.UUID;
+
+@Builder(toBuilder = true)
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
+@Schema(description = "Request to batch update multiple trace threads")
+public record TraceThreadBatchUpdate(
+        @NotNull @NotEmpty @Size(min = 1, max = 1000) @Schema(description = "Set of thread model IDs to update (max 1000)") Set<UUID> ids,
+        @NotNull @Valid @Schema(description = "Update to apply to all threads") TraceThreadUpdate update,
+        @Schema(description = "If true, merge tags with existing tags instead of replacing them. Default: false") Boolean mergeTags) {
+}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java
index 299d8905311..d543d452191 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java
@@ -12,6 +12,7 @@
 import com.comet.opik.api.ProjectStats;
 import com.comet.opik.api.Span;
 import com.comet.opik.api.SpanBatch;
+import com.comet.opik.api.SpanBatchUpdate;
 import com.comet.opik.api.SpanSearchStreamRequest;
 import com.comet.opik.api.SpanUpdate;
 import com.comet.opik.api.filter.FiltersFactory;
@@ -232,6 +233,28 @@ public Response createSpans(
         return Response.noContent().build();
     }
 
+    @PATCH
+    @Path("/batch")
+    @Operation(operationId = "batchUpdateSpans", summary = "Batch update spans", description = "Update multiple spans", responses = {
+            @ApiResponse(responseCode = "204", description = "No Content"),
+            @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))})
+    @RateLimited
+    public Response batchUpdate(
+            @RequestBody(content = @Content(schema = @Schema(implementation = SpanBatchUpdate.class))) @Valid @NotNull SpanBatchUpdate batchUpdate) {
+
+        String workspaceId = requestContext.get().getWorkspaceId();
+
+        log.info("Batch updating '{}' spans on workspaceId '{}'", batchUpdate.ids().size(), workspaceId);
+
+        spanService.batchUpdate(batchUpdate)
+                .contextWrite(ctx -> setRequestContext(ctx, requestContext))
+                .block();
+
+        log.info("Batch updated '{}' spans on workspaceId '{}'", batchUpdate.ids().size(), workspaceId);
+
+        return Response.noContent().build();
+    }
+
     @PATCH
     @Path("{id}")
     @Operation(operationId = "updateSpan", summary = "Update span by id", description = "Update span by id", responses = {
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java
index adb771b5e37..1580638c75e 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TracesResource.java
@@ -16,9 +16,11 @@
 import com.comet.opik.api.Trace;
 import com.comet.opik.api.Trace.TracePage;
 import com.comet.opik.api.TraceBatch;
+import com.comet.opik.api.TraceBatchUpdate;
 import com.comet.opik.api.TraceSearchStreamRequest;
 import com.comet.opik.api.TraceThread;
 import com.comet.opik.api.TraceThreadBatchIdentifier;
+import com.comet.opik.api.TraceThreadBatchUpdate;
 import com.comet.opik.api.TraceThreadIdentifier;
 import com.comet.opik.api.TraceThreadSearchStreamRequest;
 import com.comet.opik.api.TraceThreadUpdate;
@@ -305,6 +307,28 @@ public Response createTraces(
         return Response.noContent().build();
     }
 
+    @PATCH
+    @Path("/batch")
+    @Operation(operationId = "batchUpdateTraces", summary = "Batch update traces", description = "Update multiple traces", responses = {
+            @ApiResponse(responseCode = "204", description = "No Content"),
+            @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))})
+    @RateLimited
+    public Response batchUpdate(
+            @RequestBody(content = @Content(schema = @Schema(implementation = TraceBatchUpdate.class))) @Valid @NotNull TraceBatchUpdate batchUpdate) {
+
+        String workspaceId = requestContext.get().getWorkspaceId();
+
+        log.info("Batch updating '{}' traces on workspaceId '{}'", batchUpdate.ids().size(), workspaceId);
+
+        service.batchUpdate(batchUpdate)
+                .contextWrite(ctx -> setRequestContext(ctx, requestContext))
+                .block();
+
+        log.info("Batch updated '{}' traces on workspaceId '{}'", batchUpdate.ids().size(), workspaceId);
+
+        return Response.noContent().build();
+    }
+
     @PATCH
     @Path("{id}")
     @Operation(operationId = "updateTrace", summary = "Update trace by id", description = "Update trace by id", responses = {
@@ -787,6 +811,28 @@ public Response closeTraceThread(
         return Response.noContent().build();
     }
 
+    @PATCH
+    @Path("/threads/batch")
+    @Operation(operationId = "batchUpdateThreads", summary = "Batch update threads", description = "Update multiple threads", responses = {
+            @ApiResponse(responseCode = "204", description = "No Content"),
+            @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))})
+    @RateLimited
+    public Response batchUpdateThreads(
+            @RequestBody(content = @Content(schema = @Schema(implementation = TraceThreadBatchUpdate.class))) @Valid @NotNull TraceThreadBatchUpdate batchUpdate) {
+
+        String workspaceId = requestContext.get().getWorkspaceId();
+
+        log.info("Batch updating '{}' threads on workspaceId '{}'", batchUpdate.ids().size(), workspaceId);
+
+        traceThreadService.batchUpdate(batchUpdate)
+                .contextWrite(ctx -> setRequestContext(ctx, requestContext))
+                .block();
+
+        log.info("Batch updated '{}' threads on workspaceId '{}'", batchUpdate.ids().size(), workspaceId);
+
+        return Response.noContent().build();
+    }
+
     @PATCH
     @Path("/threads/{threadModelId}")
     @Operation(operationId = "updateThread", summary = "Update thread", description = "Update thread", responses = {
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java
index 0c6e54d3ff2..e2c6bcc6aed 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java
@@ -2253,6 +2253,168 @@ private void bindCost(Span span, Statement statement, String index) {
         }
     }
 
+    private static final String BULK_UPDATE = """
+            INSERT INTO spans (
+                id,
+                project_id,
+                workspace_id,
+                trace_id,
+                parent_span_id,
+                name,
+                type,
+                start_time,
+                end_time,
+                input,
+                output,
+                metadata,
+                model,
+                provider,
+                total_estimated_cost,
+                total_estimated_cost_version,
+                tags,
+                usage,
+                error_info,
+                created_at,
+                created_by,
+                last_updated_by,
+                truncation_threshold
+            )
+            SELECT
+                s.id,
+                s.project_id,
+                s.workspace_id,
+                s.trace_id,
+                s.parent_span_id,
+                <if(name)> :name <else> s.name <endif> as name,
+                <if(type)> :type <else> s.type <endif> as type,
+                s.start_time,
+                <if(end_time)> parseDateTime64BestEffort(:end_time, 9) <else> s.end_time <endif> as end_time,
+                <if(input)> :input <else> s.input <endif> as input,
+                <if(output)> :output <else> s.output <endif> as output,
+                <if(metadata)> :metadata <else> s.metadata <endif> as metadata,
+                <if(model)> :model <else> s.model <endif> as model,
+                <if(provider)> :provider <else> s.provider <endif> as provider,
+                <if(total_estimated_cost)> toDecimal128(:total_estimated_cost, 12) <else> s.total_estimated_cost <endif> as total_estimated_cost,
+                <if(total_estimated_cost_version)> :total_estimated_cost_version <else> s.total_estimated_cost_version <endif> as total_estimated_cost_version,
+                <if(tags)><if(merge_tags)>arrayConcat(s.tags, :tags)<else>:tags<endif><else>s.tags<endif> as tags,
+                <if(usage)> CAST((:usageKeys, :usageValues), 'Map(String, Int64)') <else> s.usage <endif> as usage,
+                <if(error_info)> :error_info <else> s.error_info <endif> as error_info,
+                s.created_at,
+                s.created_by,
+                :user_name as last_updated_by,
+                :truncation_threshold
+            FROM spans s
+            WHERE s.id IN :ids AND s.workspace_id = :workspace_id
+            ORDER BY (s.workspace_id, s.project_id, s.trace_id, s.parent_span_id, s.id) DESC, s.last_updated_at DESC
+            LIMIT 1 BY s.id;
+            """;
+
+    @WithSpan
+    public Mono<Void> bulkUpdate(@NonNull Set<UUID> ids, @NonNull SpanUpdate update, boolean mergeTags) {
+        Preconditions.checkArgument(!ids.isEmpty(), "ids must not be empty");
+        log.info("Bulk updating '{}' spans", ids.size());
+
+        var template = newBulkUpdateTemplate(update, BULK_UPDATE, mergeTags);
+        var query = template.render();
+
+        return Mono.from(connectionFactory.create())
+                .flatMapMany(connection -> {
+                    var statement = connection.createStatement(query)
+                            .bind("ids", ids);
+
+                    bindBulkUpdateParams(update, statement);
+                    TruncationUtils.bindTruncationThreshold(statement, "truncation_threshold", configuration);
+
+                    Segment segment = startSegment("spans", "Clickhouse", "bulk_update");
+
+                    return makeFluxContextAware(bindUserNameAndWorkspaceContextToStream(statement))
+                            .doFinally(signalType -> endSegment(segment));
+                })
+                .then()
+                .doOnSuccess(__ -> log.info("Completed bulk update for '{}' spans", ids.size()));
+    }
+
+    private ST newBulkUpdateTemplate(SpanUpdate spanUpdate, String sql, boolean mergeTags) {
+        var template = TemplateUtils.newST(sql);
+
+        if (StringUtils.isNotBlank(spanUpdate.name())) {
+            template.add("name", spanUpdate.name());
+        }
+        Optional.ofNullable(spanUpdate.type())
+                .ifPresent(type -> template.add("type", type.toString()));
+        Optional.ofNullable(spanUpdate.input())
+                .ifPresent(input -> template.add("input", input.toString()));
+        Optional.ofNullable(spanUpdate.output())
+                .ifPresent(output -> template.add("output", output.toString()));
+        Optional.ofNullable(spanUpdate.tags())
+                .ifPresent(tags -> {
+                    template.add("tags", tags.toString());
+                    template.add("merge_tags", mergeTags);
+                });
+        Optional.ofNullable(spanUpdate.metadata())
+                .ifPresent(metadata -> template.add("metadata", metadata.toString()));
+        if (StringUtils.isNotBlank(spanUpdate.model())) {
+            template.add("model", spanUpdate.model());
+        }
+        if (StringUtils.isNotBlank(spanUpdate.provider())) {
+            template.add("provider", spanUpdate.provider());
+        }
+        Optional.ofNullable(spanUpdate.endTime())
+                .ifPresent(endTime -> template.add("end_time", endTime.toString()));
+        Optional.ofNullable(spanUpdate.usage())
+                .ifPresent(usage -> template.add("usage", usage.toString()));
+        Optional.ofNullable(spanUpdate.errorInfo())
+                .ifPresent(errorInfo -> template.add("error_info", JsonUtils.readTree(errorInfo).toString()));
+
+        if (spanUpdate.totalEstimatedCost() != null) {
+            template.add("total_estimated_cost", "total_estimated_cost");
+            template.add("total_estimated_cost_version", "total_estimated_cost_version");
+        }
+        return template;
+    }
+
+    private void bindBulkUpdateParams(SpanUpdate spanUpdate, Statement statement) {
+        if (StringUtils.isNotBlank(spanUpdate.name())) {
+            statement.bind("name", spanUpdate.name());
+        }
+        Optional.ofNullable(spanUpdate.type())
+                .ifPresent(type -> statement.bind("type", type.toString()));
+        Optional.ofNullable(spanUpdate.input())
+                .ifPresent(input -> statement.bind("input", input.toString()));
+        Optional.ofNullable(spanUpdate.output())
+                .ifPresent(output -> statement.bind("output", output.toString()));
+        Optional.ofNullable(spanUpdate.tags())
+                .ifPresent(tags -> statement.bind("tags", tags.toArray(String[]::new)));
+        Optional.ofNullable(spanUpdate.usage())
+                .ifPresent(usage -> {
+                    var usageKeys = new ArrayList<String>();
+                    var usageValues = new ArrayList<Integer>();
+                    for (var entry : usage.entrySet()) {
+                        usageKeys.add(entry.getKey());
+                        usageValues.add(entry.getValue());
+                    }
+                    statement.bind("usageKeys", usageKeys.toArray(String[]::new));
+                    statement.bind("usageValues", usageValues.toArray(Integer[]::new));
+                });
+        Optional.ofNullable(spanUpdate.endTime())
+                .ifPresent(endTime -> statement.bind("end_time", endTime.toString()));
+        Optional.ofNullable(spanUpdate.metadata())
+                .ifPresent(metadata -> statement.bind("metadata", metadata.toString()));
+        if (StringUtils.isNotBlank(spanUpdate.model())) {
+            statement.bind("model", spanUpdate.model());
+        }
+        if (StringUtils.isNotBlank(spanUpdate.provider())) {
+            statement.bind("provider", spanUpdate.provider());
+        }
+        Optional.ofNullable(spanUpdate.errorInfo())
+                .ifPresent(errorInfo -> statement.bind("error_info", JsonUtils.readTree(errorInfo).toString()));
+
+        if (spanUpdate.totalEstimatedCost() != null) {
+            statement.bind("total_estimated_cost", spanUpdate.totalEstimatedCost().toString());
+            statement.bind("total_estimated_cost_version", "");
+        }
+    }
+
     private JsonNode getMetadataWithProvider(Row row, Set<String> exclude, String provider) {
         // Parse base metadata from database
         JsonNode baseMetadata = Optional
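The BULK_UPDATE statement relies on StringTemplate conditionals: for each column, the template renders either the bound parameter or the current row value, depending on which attributes were added via template.add(...). A self-contained illustration of that mechanism (using plain org.stringtemplate.v4.ST; the project's TemplateUtils presumably wraps the same engine):

import org.stringtemplate.v4.ST;

class BulkUpdateTemplateDemo {
    public static void main(String[] args) {
        var sql = "<if(name)> :name <else> s.name <endif> as name";

        var withName = new ST(sql);
        withName.add("name", "updated-name");     // attribute set -> parameter branch
        System.out.println(withName.render());    // " :name  as name"

        var withoutName = new ST(sql);            // attribute absent -> keep existing column value
        System.out.println(withoutName.render()); // " s.name  as name"
    }
}

The same applies to merge_tags: it is added as a Boolean, and ST treats Boolean.FALSE as falsy, so the arrayConcat branch only renders when merge mode is explicitly on.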
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java
index cc202d06234..63add66f87a 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java
@@ -6,6 +6,7 @@
 import com.comet.opik.api.ProjectStats;
 import com.comet.opik.api.Span;
 import com.comet.opik.api.SpanBatch;
+import com.comet.opik.api.SpanBatchUpdate;
 import com.comet.opik.api.SpanUpdate;
 import com.comet.opik.api.SpansCountResponse;
 import com.comet.opik.api.attachment.AttachmentInfo;
@@ -204,6 +205,15 @@ public Mono<Void> update(@NonNull UUID id, @NonNull SpanUpdate spanUpdate) {
                 .then()))));
     }
 
+    @WithSpan
+    public Mono<Void> batchUpdate(@NonNull SpanBatchUpdate batchUpdate) {
+        log.info("Batch updating '{}' spans", batchUpdate.ids().size());
+
+        boolean mergeTags = Boolean.TRUE.equals(batchUpdate.mergeTags());
+        return spanDAO.bulkUpdate(batchUpdate.ids(), batchUpdate.update(), mergeTags)
+                .doOnSuccess(__ -> log.info("Completed batch update for '{}' spans", batchUpdate.ids().size()));
+    }
+
     private Mono<Void> insertUpdate(Project project, SpanUpdate spanUpdate, UUID id) {
         return IdGenerator
                 .validateVersionAsync(id, SPAN_KEY)
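One detail worth noting here: mergeTags is a nullable Boolean on the request, and Boolean.TRUE.equals(...) is what makes an omitted flag default to replace semantics. A tiny sanity check of that null handling:

class MergeTagsDefaultDemo {
    public static void main(String[] args) {
        System.out.println(Boolean.TRUE.equals(Boolean.TRUE));  // true  -> merge: arrayConcat(existing, :tags)
        System.out.println(Boolean.TRUE.equals(Boolean.FALSE)); // false -> replace tags
        System.out.println(Boolean.TRUE.equals(null));          // false -> omitted flag also means replace
    }
}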
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java
index 823c696de75..b4797dd6c74 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java
@@ -31,6 +31,7 @@
 import com.google.inject.ImplementedBy;
 import io.opentelemetry.instrumentation.annotations.WithSpan;
 import io.r2dbc.spi.Connection;
+import io.r2dbc.spi.ConnectionFactory;
 import io.r2dbc.spi.Result;
 import io.r2dbc.spi.Row;
 import io.r2dbc.spi.RowMetadata;
@@ -65,6 +66,7 @@
 import static com.comet.opik.api.TraceCountResponse.WorkspaceTraceCount;
 import static com.comet.opik.api.TraceThread.TraceThreadPage;
 import static com.comet.opik.domain.AsyncContextUtils.bindUserNameAndWorkspaceContext;
+import static com.comet.opik.domain.AsyncContextUtils.bindUserNameAndWorkspaceContextToStream;
 import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux;
 import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToMono;
 import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.Segment;
@@ -127,6 +129,8 @@ interface TraceDAO {
     Flux<TraceThread> threadsSearch(int limit, TraceSearchCriteria criteria);
 
     Mono<List<TraceThread>> getMinimalThreadInfoByIds(UUID projectId, Set<UUID> threadId);
+
+    Mono<Void> bulkUpdate(@NonNull Set<UUID> ids, @NonNull TraceUpdate update, boolean mergeTags);
 }
 
 @Slf4j
@@ -2412,6 +2416,7 @@ ORDER BY (workspace_id, project_id, id) DESC, last_updated_at DESC
     private final @NonNull TraceSortingFactory sortingFactory;
     private final @NonNull TraceThreadSortingFactory traceThreadSortingFactory;
     private final @NonNull OpikConfiguration configuration;
+    private final @NonNull ConnectionFactory connectionFactory;
 
     @Override
     @WithSpan
@@ -3525,6 +3530,124 @@ private Flux<Trace> findTraceStream(int limit, @NonNull TraceSearchCr
                 });
     }
 
+    private static final String BULK_UPDATE = """
+            INSERT INTO traces (
+                id,
+                project_id,
+                workspace_id,
+                name,
+                start_time,
+                end_time,
+                input,
+                output,
+                metadata,
+                tags,
+                error_info,
+                created_at,
+                created_by,
+                last_updated_by,
+                thread_id,
+                visibility_mode,
+                truncation_threshold
+            )
+            SELECT
+                t.id,
+                t.project_id,
+                t.workspace_id,
+                <if(name)> :name <else> t.name <endif> as name,
+                t.start_time,
+                <if(end_time)> parseDateTime64BestEffort(:end_time, 9) <else> t.end_time <endif> as end_time,
+                <if(input)> :input <else> t.input <endif> as input,
+                <if(output)> :output <else> t.output <endif> as output,
+                <if(metadata)> :metadata <else> t.metadata <endif> as metadata,
+                <if(tags)><if(merge_tags)>arrayConcat(t.tags, :tags)<else>:tags<endif><else>t.tags<endif> as tags,
+                <if(error_info)> :error_info <else> t.error_info <endif> as error_info,
+                t.created_at,
+                t.created_by,
+                :user_name as last_updated_by,
+                <if(thread_id)> :thread_id <else> t.thread_id <endif> as thread_id,
+                t.visibility_mode,
+                :truncation_threshold as truncation_threshold
+            FROM traces t
+            WHERE t.id IN :ids AND t.workspace_id = :workspace_id
+            ORDER BY t.last_updated_at DESC
+            LIMIT 1 BY t.id;""";
+
+    @Override
+    @WithSpan
+    public Mono<Void> bulkUpdate(@NonNull Set<UUID> ids, @NonNull TraceUpdate update, boolean mergeTags) {
+        Preconditions.checkArgument(!ids.isEmpty(), "ids must not be empty");
+        log.info("Bulk updating '{}' traces", ids.size());
+
+        var template = newBulkUpdateTemplate(update, BULK_UPDATE, mergeTags);
+        var query = template.render();
+
+        return Mono.from(connectionFactory.create())
+                .flatMapMany(connection -> {
+                    var statement = connection.createStatement(query)
+                            .bind("ids", ids);
+
+                    bindBulkUpdateParams(update, statement);
+                    TruncationUtils.bindTruncationThreshold(statement, "truncation_threshold", configuration);
+
+                    Segment segment = startSegment("traces", "Clickhouse", "bulk_update");
+
+                    return makeFluxContextAware(bindUserNameAndWorkspaceContextToStream(statement))
+                            .doFinally(signalType -> endSegment(segment));
+                })
+                .then()
+                .doOnSuccess(__ -> log.info("Completed bulk update for '{}' traces", ids.size()));
+    }
+
+    private ST newBulkUpdateTemplate(TraceUpdate traceUpdate, String sql, boolean mergeTags) {
+        var template = TemplateUtils.newST(sql);
+
+        if (StringUtils.isNotBlank(traceUpdate.name())) {
+            template.add("name", traceUpdate.name());
+        }
+        Optional.ofNullable(traceUpdate.input())
+                .ifPresent(input -> template.add("input", input.toString()));
+        Optional.ofNullable(traceUpdate.output())
+                .ifPresent(output -> template.add("output", output.toString()));
+        Optional.ofNullable(traceUpdate.tags())
+                .ifPresent(tags -> {
+                    template.add("tags", tags.toString());
+                    template.add("merge_tags", mergeTags);
+                });
+        Optional.ofNullable(traceUpdate.metadata())
+                .ifPresent(metadata -> template.add("metadata", metadata.toString()));
+        Optional.ofNullable(traceUpdate.endTime())
+                .ifPresent(endTime -> template.add("end_time", endTime.toString()));
+        Optional.ofNullable(traceUpdate.errorInfo())
+                .ifPresent(errorInfo -> template.add("error_info", JsonUtils.readTree(errorInfo).toString()));
+        if (StringUtils.isNotBlank(traceUpdate.threadId())) {
+            template.add("thread_id", traceUpdate.threadId());
+        }
+
+        return template;
+    }
+
+    private void bindBulkUpdateParams(TraceUpdate traceUpdate, Statement statement) {
+        if (StringUtils.isNotBlank(traceUpdate.name())) {
+            statement.bind("name", traceUpdate.name());
+        }
+        Optional.ofNullable(traceUpdate.input())
+                .ifPresent(input -> statement.bind("input", input.toString()));
+        Optional.ofNullable(traceUpdate.output())
+                .ifPresent(output -> statement.bind("output", output.toString()));
+        Optional.ofNullable(traceUpdate.tags())
+                .ifPresent(tags -> statement.bind("tags", tags.toArray(String[]::new)));
+        Optional.ofNullable(traceUpdate.endTime())
+                .ifPresent(endTime -> statement.bind("end_time", endTime.toString()));
+        Optional.ofNullable(traceUpdate.metadata())
+                .ifPresent(metadata -> statement.bind("metadata", metadata.toString()));
+        Optional.ofNullable(traceUpdate.errorInfo())
+                .ifPresent(errorInfo -> statement.bind("error_info", JsonUtils.readTree(errorInfo).toString()));
+        if (StringUtils.isNotBlank(traceUpdate.threadId())) {
+            statement.bind("thread_id", traceUpdate.threadId());
+        }
+    }
+
     private JsonNode getMetadataWithProviders(Row row, Set<String> exclude, List<String> providers) {
         // Parse base metadata from database
         JsonNode baseMetadata = Optional
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java
index 902e88c4c97..ca12b69094f 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java
@@ -7,6 +7,7 @@
 import com.comet.opik.api.ProjectStats;
 import com.comet.opik.api.Trace;
 import com.comet.opik.api.TraceBatch;
+import com.comet.opik.api.TraceBatchUpdate;
 import com.comet.opik.api.TraceCountResponse;
 import com.comet.opik.api.TraceDetails;
 import com.comet.opik.api.TraceThread;
@@ -79,6 +80,8 @@ public interface TraceService {
 
     Mono<Void> update(TraceUpdate trace, UUID id);
 
+    Mono<Void> batchUpdate(TraceBatchUpdate batchUpdate);
+
     Mono<Trace> get(UUID id);
 
     Mono<Trace> get(UUID id, boolean stripAttachments);
@@ -382,6 +385,16 @@ public Mono<Void> update(@NonNull TraceUpdate traceUpdate, @NonNull UUID id) {
                 .then());
     }
 
+    @Override
+    @WithSpan
+    public Mono<Void> batchUpdate(@NonNull TraceBatchUpdate batchUpdate) {
+        log.info("Batch updating '{}' traces", batchUpdate.ids().size());
+
+        boolean mergeTags = Boolean.TRUE.equals(batchUpdate.mergeTags());
+        return dao.bulkUpdate(batchUpdate.ids(), batchUpdate.update(), mergeTags)
+                .doOnSuccess(__ -> log.info("Completed batch update for '{}' traces", batchUpdate.ids().size()));
+    }
+
     private Mono<Void> insertUpdate(Project project, TraceUpdate traceUpdate, UUID id) {
         return IdGenerator
                 .validateVersionAsync(id, TRACE_KEY)
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadDAO.java
index b7d3b864a6b..a586e506e04 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadDAO.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadDAO.java
@@ -4,11 +4,14 @@
 import com.comet.opik.api.TraceThreadStatus;
 import com.comet.opik.api.TraceThreadUpdate;
 import com.comet.opik.api.events.ProjectWithPendingClosureTraceThreads;
+import com.comet.opik.infrastructure.OpikConfiguration;
 import com.comet.opik.infrastructure.db.TransactionTemplateAsync;
 import com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils;
 import com.comet.opik.utils.template.TemplateUtils;
+import com.google.common.base.Preconditions;
 import com.google.inject.ImplementedBy;
 import io.r2dbc.spi.Connection;
+import io.r2dbc.spi.ConnectionFactory;
 import io.r2dbc.spi.Result;
 import io.r2dbc.spi.Statement;
 import jakarta.inject.Inject;
@@ -33,6 +36,7 @@
 import static com.comet.opik.domain.AsyncContextUtils.bindUserNameAndWorkspaceContext;
 import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux;
 import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToMono;
+import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.Segment;
 import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.endSegment;
 import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.startSegment;
 import static com.comet.opik.utils.AsyncUtils.makeFluxContextAware;
@@ -64,6 +68,11 @@ Flux<ProjectWithPendingClosureTraceThreads> findProjectsWithPendingClosureThread
 
     Mono<Long> setScoredAt(UUID projectId, List<String> threadIds, Instant scoredAt);
 
     Flux<List<TraceThreadModel>> streamPendingClosureThreads(UUID projectId, Instant lastUpdatedAt);
+
+    record ThreadIdWithTagsAndMetadata(UUID id, Set<String> tags, UUID projectId) {
+    }
+
+    Mono<Void> bulkUpdate(@NonNull List<UUID> ids, @NonNull TraceThreadUpdate update, boolean mergeTags);
 }
 
 @Singleton
@@ -258,6 +267,8 @@ AND last_updated_at < parseDateTime64BestEffort(:last_updated_at, 6)
             """;
 
     private final @NonNull TransactionTemplateAsync asyncTemplate;
+    private final @NonNull ConnectionFactory connectionFactory;
+    private final @NonNull OpikConfiguration configuration;
 
     @Override
     public Mono<Long> save(@NonNull List<TraceThreadModel> traceThreads) {
@@ -583,4 +594,76 @@ private void bindStatementParam(TraceThreadCriteria criteria, Statement statemen
             statement.bind("status", criteria.status().getValue());
         }
     }
+
+    private static final String BULK_UPDATE = """
+            INSERT INTO trace_threads (
+                workspace_id,
+                project_id,
+                thread_id,
+                id,
+                status,
+                created_by,
+                last_updated_by,
+                created_at,
+                last_updated_at,
+                tags,
+                sampling_per_rule,
+                scored_at
+            )
+            SELECT
+                tt.workspace_id,
+                tt.project_id,
+                tt.thread_id,
+                tt.id,
+                tt.status,
+                tt.created_by,
+                tt.last_updated_by,
+                tt.created_at,
+                now64(6) as last_updated_at,
+                <if(tags)><if(merge_tags)>arrayConcat(tt.tags, :tags)<else>:tags<endif><else>tt.tags<endif> as tags,
+                tt.sampling_per_rule,
+                tt.scored_at
+            FROM trace_threads tt final
+            WHERE tt.id IN :ids AND tt.workspace_id = :workspace_id;""";
+
+    @Override
+    public Mono<Void> bulkUpdate(@NonNull List<UUID> ids, @NonNull TraceThreadUpdate update, boolean mergeTags) {
+        Preconditions.checkArgument(!ids.isEmpty(), "ids must not be empty");
+        log.info("Bulk updating '{}' thread models", ids.size());
+
+        var template = newBulkUpdateTemplate(update, BULK_UPDATE, mergeTags);
+        var query = template.render();
+
+        return Mono.from(connectionFactory.create())
+                .flatMapMany(connection -> {
+                    var statement = connection.createStatement(query)
+                            .bind("ids", ids);
+
+                    bindBulkUpdateParams(update, statement);
+
+                    Segment segment = startSegment("trace_threads", "Clickhouse", "bulk_update");
+
+                    return makeFluxContextAware(bindWorkspaceIdToFlux(statement))
+                            .doFinally(signalType -> endSegment(segment));
+                })
+                .then()
+                .doOnSuccess(__ -> log.info("Completed bulk update for '{}' thread models", ids.size()));
+    }
+
+    private ST newBulkUpdateTemplate(TraceThreadUpdate update, String sql, boolean mergeTags) {
+        var template = TemplateUtils.newST(sql);
+
+        Optional.ofNullable(update.tags())
+                .ifPresent(tags -> {
+                    template.add("tags", tags.toString());
+                    template.add("merge_tags", mergeTags);
+                });
+
+        return template;
+    }
+
+    private void bindBulkUpdateParams(TraceThreadUpdate update, Statement statement) {
+        Optional.ofNullable(update.tags())
+                .ifPresent(tags -> statement.bind("tags", tags.toArray(String[]::new)));
+    }
 }
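To make the two tag modes concrete, a small sketch of the expected outcome when tags {new-tag-1, new-tag-2} are applied to a thread (or span, or trace) that already carries {existing-tag-1, existing-tag-2}. It mirrors the arrayConcat branch of the SQL above using Java sets; note that arrayConcat itself does not deduplicate, and the tests in this PR use disjoint tag sets, so overlap is a corner case not exercised here:

import java.util.LinkedHashSet;
import java.util.Set;

class TagMergeDemo {
    public static void main(String[] args) {
        Set<String> existing = Set.of("existing-tag-1", "existing-tag-2");
        Set<String> incoming = Set.of("new-tag-1", "new-tag-2");

        // merge_tags = true: behaves like arrayConcat(existing, incoming)
        var merged = new LinkedHashSet<>(existing);
        merged.addAll(incoming);
        System.out.println(merged);   // all four tags

        // merge_tags = false (or omitted): the incoming set replaces the old one
        System.out.println(incoming); // only the new tags
    }
}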
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadIdService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadIdService.java
index 56e6bd044c7..c3e9ded3f17 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadIdService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadIdService.java
@@ -37,6 +37,8 @@ Mono<TraceThreadIdModel> getOrCreateTraceThreadId(String workspaceId, UUID proje
 
     Mono<Map<UUID, String>> getTraceThreadIdsByThreadModelIds(List<UUID> threadModelIds);
 
+    Mono<List<TraceThreadIdModel>> getTraceThreadIdModelsByThreadModelIds(List<UUID> threadModelIds);
+
 }
@@ -106,6 +108,22 @@ public Mono<Map<UUID, String>> getTraceThreadIdsByThreadModelIds(@NonNull List<
 
+    @Override
+    public Mono<List<TraceThreadIdModel>> getTraceThreadIdModelsByThreadModelIds(@NonNull List<UUID> threadModelIds) {
+        Preconditions.checkArgument(!threadModelIds.isEmpty(),
+                "Thread model IDs cannot be null or empty");
+
+        return Mono.fromCallable(() -> {
+            var threadModels = transactionTemplate.inTransaction(TransactionTemplateAsync.READ_ONLY,
+                    handle -> handle.attach(TraceThreadIdDAO.class).findByThreadModelIds(threadModelIds));
+
+            log.info("Fetched '{}' thread ID models for '{}' thread model IDs", threadModels.size(),
+                    threadModelIds.size());
+
+            return threadModels;
+        }).subscribeOn(Schedulers.boundedElastic());
+    }
+
     private Mono<TraceThreadIdModel> getTraceThreadId(String threadId, UUID projectId) {
         return Mono.fromCallable(() -> transactionTemplate.inTransaction(TransactionTemplateAsync.READ_ONLY,
                 handle -> handle.attach(TraceThreadIdDAO.class).findByProjectIdAndThreadId(projectId, threadId)));
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadService.java
index a7e46b3d283..335486d6efb 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/threads/TraceThreadService.java
@@ -2,6 +2,7 @@
 
 import com.comet.opik.api.ThreadTimestamps;
 import com.comet.opik.api.TraceThread;
+import com.comet.opik.api.TraceThreadBatchUpdate;
 import com.comet.opik.api.TraceThreadSampling;
 import com.comet.opik.api.TraceThreadStatus;
 import com.comet.opik.api.TraceThreadUpdate;
@@ -31,6 +32,7 @@
 
 import java.time.Duration;
 import java.time.Instant;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -72,6 +74,8 @@ Mono<Void> processProjectWithTraceThreadsPendingClosure(UUID projectId, Instant
 
     Mono<Void> update(UUID threadModelId, TraceThreadUpdate threadUpdate);
 
+    Mono<Void> batchUpdate(TraceThreadBatchUpdate batchUpdate);
+
     Mono<Void> setScoredAt(UUID projectId, List<String> threadIds, Instant scoredAt);
 
     Mono<Map<UUID, String>> getThreadIdsByThreadModelIds(List<UUID> threadModelIds);
@@ -167,6 +171,16 @@ public Mono<Void> update(@NonNull UUID threadModelId, @NonNull TraceThreadUpdate
                         traceThreadIdModel.projectId(), threadUpdate));
     }
 
+    @Override
+    public Mono<Void> batchUpdate(@NonNull TraceThreadBatchUpdate batchUpdate) {
+        log.info("Batch updating '{}' threads", batchUpdate.ids().size());
+
+        boolean mergeTags = Boolean.TRUE.equals(batchUpdate.mergeTags());
+        List<UUID> threadModelIds = new ArrayList<>(batchUpdate.ids());
+        return traceThreadDAO.bulkUpdate(threadModelIds, batchUpdate.update(), mergeTags)
+                .doOnSuccess(__ -> log.info("Completed batch update for '{}' threads", batchUpdate.ids().size()));
+    }
+
     @Override
     public Mono<Void> setScoredAt(@NonNull UUID projectId, @NonNull List<String> threadIds, @NonNull Instant scoredAt) {
         if (threadIds.isEmpty()) {
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/SpanResourceClient.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/SpanResourceClient.java
index 82754f58020..b3277549d64 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/SpanResourceClient.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/SpanResourceClient.java
@@ -7,6 +7,7 @@
 import com.comet.opik.api.ProjectStats;
 import com.comet.opik.api.Span;
 import com.comet.opik.api.SpanBatch;
+import com.comet.opik.api.SpanBatchUpdate;
 import com.comet.opik.api.SpanSearchStreamRequest;
 import com.comet.opik.api.SpanUpdate;
 import com.comet.opik.api.filter.SpanFilter;
@@ -438,4 +439,20 @@ public ProjectStats getSpansStats(String projectName,
         return null;
     }
 
+    public void batchUpdateSpans(SpanBatchUpdate batchUpdate, String apiKey, String workspaceName) {
+        try (var actualResponse = callBatchUpdateSpans(batchUpdate, apiKey, workspaceName)) {
+            assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NO_CONTENT);
+            assertThat(actualResponse.hasEntity()).isFalse();
+        }
+    }
+
+    public Response callBatchUpdateSpans(SpanBatchUpdate batchUpdate, String apiKey, String workspaceName) {
+        return client.target(RESOURCE_PATH.formatted(baseURI))
+                .path("batch")
+                .request()
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .method(HttpMethod.PATCH, Entity.json(batchUpdate));
+    }
+
 }
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java
index e9a5ca2038c..c642fbd3da8 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java
@@ -10,9 +10,11 @@
 import com.comet.opik.api.ProjectStats;
 import com.comet.opik.api.Trace;
 import com.comet.opik.api.TraceBatch;
+import com.comet.opik.api.TraceBatchUpdate;
 import com.comet.opik.api.TraceSearchStreamRequest;
 import com.comet.opik.api.TraceThread;
 import com.comet.opik.api.TraceThreadBatchIdentifier;
+import com.comet.opik.api.TraceThreadBatchUpdate;
 import com.comet.opik.api.TraceThreadIdentifier;
 import com.comet.opik.api.TraceThreadSearchStreamRequest;
 import com.comet.opik.api.TraceThreadUpdate;
@@ -675,6 +677,39 @@ public Response callBatchCreateTracesWithCookie(List<Trace> traces, String sessi
                 .post(Entity.json(new TraceBatch(traces)));
     }
 
+    public void batchUpdateTraces(TraceBatchUpdate batchUpdate, String apiKey, String workspaceName) {
+        try (var actualResponse = callBatchUpdateTraces(batchUpdate, apiKey, workspaceName)) {
+            assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NO_CONTENT);
+            assertThat(actualResponse.hasEntity()).isFalse();
+        }
+    }
+
+    public Response callBatchUpdateTraces(TraceBatchUpdate batchUpdate, String apiKey, String workspaceName) {
+        return client.target(RESOURCE_PATH.formatted(baseURI))
+                .path("batch")
+                .request()
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .method(HttpMethod.PATCH, Entity.json(batchUpdate));
+    }
+
+    public void batchUpdateThreads(TraceThreadBatchUpdate batchUpdate, String apiKey, String workspaceName) {
+        try (var actualResponse = callBatchUpdateThreads(batchUpdate, apiKey, workspaceName)) {
+            assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NO_CONTENT);
+            assertThat(actualResponse.hasEntity()).isFalse();
+        }
+    }
+
+    public Response callBatchUpdateThreads(TraceThreadBatchUpdate batchUpdate, String apiKey, String workspaceName) {
+        return client.target(RESOURCE_PATH.formatted(baseURI))
+                .path("threads")
+                .path("batch")
+                .request()
+                .header(HttpHeaders.AUTHORIZATION, apiKey)
+                .header(WORKSPACE_HEADER, workspaceName)
+                .method(HttpMethod.PATCH, Entity.json(batchUpdate));
+    }
+
     public Response callPostWithCookie(Object body, String sessionToken, String workspaceName) {
         return client.target(RESOURCE_PATH.formatted(baseURI))
                 .request()
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansBatchUpdateResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansBatchUpdateResourceTest.java
new file mode 100644
index 00000000000..dc3b1f91b94
--- /dev/null
+++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansBatchUpdateResourceTest.java
@@ -0,0 +1,773 @@
+package com.comet.opik.api.resources.v1.priv;
+
+import com.comet.opik.api.ErrorInfo;
+import com.comet.opik.api.Span;
+import com.comet.opik.api.SpanBatchUpdate;
+import com.comet.opik.api.SpanUpdate;
+import com.comet.opik.api.Trace;
+import com.comet.opik.api.error.ErrorMessage;
+import com.comet.opik.api.resources.utils.AuthTestUtils;
+import com.comet.opik.api.resources.utils.ClickHouseContainerUtils;
+import com.comet.opik.api.resources.utils.ClientSupportUtils;
+import com.comet.opik.api.resources.utils.MigrationUtils;
+import com.comet.opik.api.resources.utils.MinIOContainerUtils;
+import com.comet.opik.api.resources.utils.MySQLContainerUtils;
+import com.comet.opik.api.resources.utils.RedisContainerUtils;
+import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils;
+import com.comet.opik.api.resources.utils.TestUtils;
+import com.comet.opik.api.resources.utils.WireMockUtils;
+import com.comet.opik.api.resources.utils.resources.ProjectResourceClient;
+import com.comet.opik.api.resources.utils.resources.SpanResourceClient;
+import com.comet.opik.api.resources.utils.resources.TraceResourceClient;
+import com.comet.opik.domain.SpanType;
+import com.comet.opik.extensions.DropwizardAppExtensionProvider;
+import com.comet.opik.extensions.RegisterApp;
+import com.comet.opik.podam.PodamFactoryUtils;
+import com.comet.opik.utils.JsonUtils;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.uuid.Generators;
+import com.fasterxml.uuid.impl.TimeBasedEpochGenerator;
+import com.redis.testcontainers.RedisContainer;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Nested;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.testcontainers.clickhouse.ClickHouseContainer;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.lifecycle.Startables;
+import org.testcontainers.mysql.MySQLContainer;
+import ru.vyarus.dropwizard.guice.test.ClientSupport;
+import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension;
+import uk.co.jemos.podam.api.PodamFactory;
+
+import java.math.BigDecimal;
+import java.sql.SQLException;
+import java.time.Instant;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.stream.Stream;
+
+import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME;
+import static com.comet.opik.domain.ProjectService.DEFAULT_PROJECT;
+import static org.assertj.core.api.Assertions.assertThat;
+
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+@ExtendWith(DropwizardAppExtensionProvider.class)
+@DisplayName("Spans Batch Update Resource Test")
+class SpansBatchUpdateResourceTest {
+
+    private static final String API_KEY = UUID.randomUUID().toString();
+    private static final String USER = UUID.randomUUID().toString();
+    private static final String WORKSPACE_ID = UUID.randomUUID().toString();
+    private static final String TEST_WORKSPACE = UUID.randomUUID().toString();
+
+    private final RedisContainer redisContainer = RedisContainerUtils.newRedisContainer();
+    private final MySQLContainer<?> mySqlContainer = MySQLContainerUtils.newMySQLContainer();
+    private final GenericContainer<?> zookeeperContainer = ClickHouseContainerUtils.newZookeeperContainer();
+    private final ClickHouseContainer clickHouseContainer = ClickHouseContainerUtils
+            .newClickHouseContainer(zookeeperContainer);
+    private final GenericContainer<?> minIOContainer = MinIOContainerUtils.newMinIOContainer();
+    private final WireMockUtils.WireMockRuntime wireMock;
+
+    @RegisterApp
+    private final TestDropwizardAppExtension app;
+
+    {
+        Startables.deepStart(redisContainer, mySqlContainer, clickHouseContainer, zookeeperContainer, minIOContainer)
+                .join();
+        String minioUrl = "http://%s:%d".formatted(minIOContainer.getHost(), minIOContainer.getMappedPort(9000));
+
+        wireMock = WireMockUtils.startWireMock();
+
+        var databaseAnalyticsFactory = ClickHouseContainerUtils.newDatabaseAnalyticsFactory(
+                clickHouseContainer, DATABASE_NAME);
+
+        MigrationUtils.runMysqlDbMigration(mySqlContainer);
+        MigrationUtils.runClickhouseDbMigration(clickHouseContainer);
+        MinIOContainerUtils.setupBucketAndCredentials(minioUrl);
+
+        app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension(
+                TestDropwizardAppExtensionUtils.AppContextConfig.builder()
+                        .jdbcUrl(mySqlContainer.getJdbcUrl())
+                        .databaseAnalyticsFactory(databaseAnalyticsFactory)
+                        .redisUrl(redisContainer.getRedisURI())
+                        .runtimeInfo(wireMock.runtimeInfo())
+                        .isMinIO(true)
+                        .minioUrl(minioUrl)
+                        .build());
+    }
+
+    private final PodamFactory podamFactory = PodamFactoryUtils.newPodamFactory();
+    private final TimeBasedEpochGenerator generator = Generators.timeBasedEpochGenerator();
+
+    private String baseURI;
+    private ClientSupport client;
+    private ProjectResourceClient projectResourceClient;
+    private TraceResourceClient traceResourceClient;
+    private SpanResourceClient spanResourceClient;
+
+    @BeforeAll
+    void setUpAll(ClientSupport client) throws SQLException {
+        this.baseURI = TestUtils.getBaseUrl(client);
+        this.client = client;
+
+        ClientSupportUtils.config(client);
+
+        mockTargetWorkspace(API_KEY, TEST_WORKSPACE, WORKSPACE_ID);
+
+        this.projectResourceClient = new ProjectResourceClient(this.client, baseURI, podamFactory);
+        this.traceResourceClient = new TraceResourceClient(this.client, baseURI);
+        this.spanResourceClient = new SpanResourceClient(this.client, baseURI);
+    }
+
+    private void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) {
+        AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER);
+    }
+
+    @Nested
+    @DisplayName("Batch Update Tags:")
+    @TestInstance(TestInstance.Lifecycle.PER_CLASS)
+    class BatchUpdateTags {
+
+        private UUID traceId;
+
+        @BeforeEach
+        void setUp() {
+            var trace = podamFactory.manufacturePojo(Trace.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .feedbackScores(null)
+                    .build();
+            traceId = traceResourceClient.createTrace(trace, API_KEY, TEST_WORKSPACE);
+        }
+
+        Stream<Arguments> mergeTagsTestCases() {
+            return Stream.of(
+                    Arguments.of(true, "merge"),
+                    Arguments.of(false, "replace"));
+        }
+
+        @ParameterizedTest(name = "Success: batch update tags with {1} mode")
+        @MethodSource("mergeTagsTestCases")
+        @DisplayName("Success: batch update tags for multiple spans")
+        void batchUpdate__success(boolean mergeTags, String mode) {
+            // Create spans with existing tags
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .tags(mergeTags ? Set.of("existing-tag-1", "existing-tag-2") : Set.of("old-tag-1", "old-tag-2"))
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .tags(mergeTags ? Set.of("existing-tag-3") : Set.of("old-tag-3"))
+                    .build();
+            var span3 = mergeTags
+                    ? podamFactory.manufacturePojo(Span.class).toBuilder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .parentSpanId(null)
+                            .tags(null)
+                            .build()
+                    : null;
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+            var id3 = mergeTags ? spanResourceClient.createSpan(span3, API_KEY, TEST_WORKSPACE) : null;
+
+            // Batch update with new tags
+            var newTags = mergeTags ? Set.of("new-tag-1", "new-tag-2") : Set.of("new-tag");
+            var ids = mergeTags ? Set.of(id1, id2, id3) : Set.of(id1, id2);
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(ids)
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .tags(newTags)
+                            .build())
+                    .mergeTags(mergeTags)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            if (mergeTags) {
+                assertThat(updatedSpan1.tags()).containsExactlyInAnyOrder(
+                        "existing-tag-1", "existing-tag-2", "new-tag-1", "new-tag-2");
+            } else {
+                assertThat(updatedSpan1.tags()).containsExactly("new-tag");
+            }
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            if (mergeTags) {
+                assertThat(updatedSpan2.tags()).containsExactlyInAnyOrder("existing-tag-3", "new-tag-1", "new-tag-2");
+            } else {
+                assertThat(updatedSpan2.tags()).containsExactly("new-tag");
+            }
+
+            if (mergeTags) {
+                var updatedSpan3 = spanResourceClient.getById(id3, TEST_WORKSPACE, API_KEY);
+                assertThat(updatedSpan3.tags()).containsExactlyInAnyOrder("new-tag-1", "new-tag-2");
+            }
+        }
+
+        @Test
+        @DisplayName("when batch update with empty IDs, then return 422")
+        void batchUpdate__whenEmptyIds__thenReturn422() {
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of())
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .tags(Set.of("tag"))
+                            .build())
+                    .mergeTags(true)
+                    .build();
+
+            try (var actualResponse = spanResourceClient.callBatchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE)) {
+                assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422);
+                assertThat(actualResponse.hasEntity()).isTrue();
+                var error = actualResponse.readEntity(ErrorMessage.class);
+                assertThat(error.errors()).anySatisfy(msg -> assertThat(msg).contains("ids"));
+            }
+        }
+
+        @Test
+        @DisplayName("when batch update with too many IDs, then return 422")
+        void batchUpdate__whenTooManyIds__thenReturn422() {
+            // Create 1001 IDs (exceeds max of 1000)
+            var ids = new HashSet<UUID>();
+            for (int i = 0; i < 1001; i++) {
+                ids.add(generator.generate());
+            }
+
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(ids)
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .tags(Set.of("tag"))
+                            .build())
+                    .mergeTags(true)
+                    .build();
+
+            try (var actualResponse = spanResourceClient.callBatchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE)) {
+                assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422);
+                assertThat(actualResponse.hasEntity()).isTrue();
+                var error = actualResponse.readEntity(ErrorMessage.class);
+                assertThat(error.errors()).anySatisfy(msg -> assertThat(msg).contains("ids"));
+            }
+        }
+
+        @Test
+        @DisplayName("when batch update with null update, then return 422")
+        void batchUpdate__whenNullUpdate__thenReturn422() {
+            var span = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .build();
+            var id = spanResourceClient.createSpan(span, API_KEY, TEST_WORKSPACE);
+
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id))
+                    .update(null)
+                    .mergeTags(true)
+                    .build();
+
+            try (var actualResponse = spanResourceClient.callBatchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE)) {
+                assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422);
+                assertThat(actualResponse.hasEntity()).isTrue();
+            }
+        }
+    }
+
+    @Nested
+    @DisplayName("Batch Update All Fields:")
+    @TestInstance(TestInstance.Lifecycle.PER_CLASS)
+    class BatchUpdateAllFields {
+
+        private UUID traceId;
+
+        @BeforeEach
+        void setUp() {
+            var trace = podamFactory.manufacturePojo(Trace.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .feedbackScores(null)
+                    .build();
+            traceId = traceResourceClient.createTrace(trace, API_KEY, TEST_WORKSPACE);
+        }
+
+        @Test
+        @DisplayName("Success: batch update name field")
+        void batchUpdate__updateName__success() {
+            // Create spans
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .name("original-name-1")
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .name("original-name-2")
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with new name
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .name("updated-name")
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.name()).isEqualTo("updated-name");
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.name()).isEqualTo("updated-name");
+        }
+
+        @Test
+        @DisplayName("Success: batch update type field")
+        void batchUpdate__updateType__success() {
+            // Create spans
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .type(SpanType.general)
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .type(SpanType.general)
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with new type
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .type(SpanType.llm)
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.type()).isEqualTo(SpanType.llm);
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.type()).isEqualTo(SpanType.llm);
+        }
+
+        @Test
+        @DisplayName("Success: batch update input and output fields")
+        void batchUpdate__updateInputOutput__success() {
+            // Create spans
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with new input/output
+            JsonNode newInput = JsonUtils.readTree(Map.of("prompt", "updated prompt"));
+            JsonNode newOutput = JsonUtils.readTree(Map.of("response", "updated response"));
+
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .input(newInput)
+                            .output(newOutput)
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.input().get("prompt").asText()).isEqualTo("updated prompt");
+            assertThat(updatedSpan1.output().get("response").asText()).isEqualTo("updated response");
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.input().get("prompt").asText()).isEqualTo("updated prompt");
+            assertThat(updatedSpan2.output().get("response").asText()).isEqualTo("updated response");
+        }
+
+        @Test
+        @DisplayName("Success: batch update metadata field")
+        void batchUpdate__updateMetadata__success() {
+            // Create spans
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with new metadata
+            JsonNode newMetadata = JsonUtils.readTree(Map.of("key1", "value1", "key2", "value2"));
+
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .metadata(newMetadata)
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.metadata().get("key1").asText()).isEqualTo("value1");
+            assertThat(updatedSpan1.metadata().get("key2").asText()).isEqualTo("value2");
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.metadata().get("key1").asText()).isEqualTo("value1");
+            assertThat(updatedSpan2.metadata().get("key2").asText()).isEqualTo("value2");
+        }
+
+        @Test
+        @DisplayName("Success: batch update model and provider fields")
+        void batchUpdate__updateModelProvider__success() {
+            // Create spans
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .model("gpt-3.5-turbo")
+                    .provider("openai")
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .model("gpt-3.5-turbo")
+                    .provider("openai")
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with new model and provider
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .model("gpt-4")
+                            .provider("openai")
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.model()).isEqualTo("gpt-4");
+            assertThat(updatedSpan1.provider()).isEqualTo("openai");
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.model()).isEqualTo("gpt-4");
+            assertThat(updatedSpan2.provider()).isEqualTo("openai");
+        }
+
+        @Test
+        @DisplayName("Success: batch update usage field")
+        void batchUpdate__updateUsage__success() {
+            // Create spans
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .usage(Map.of("prompt_tokens", 100, "completion_tokens", 50))
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .usage(Map.of("prompt_tokens", 200, "completion_tokens", 100))
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with new usage
+            var newUsage = Map.of("prompt_tokens", 500, "completion_tokens", 250);
+
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .usage(newUsage)
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.usage()).containsEntry("prompt_tokens", 500);
+            assertThat(updatedSpan1.usage()).containsEntry("completion_tokens", 250);
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.usage()).containsEntry("prompt_tokens", 500);
+            assertThat(updatedSpan2.usage()).containsEntry("completion_tokens", 250);
+        }
+
+        @Test
+        @DisplayName("Success: batch update end_time field")
+        void batchUpdate__updateEndTime__success() {
+            // Create spans with start time
+            Instant startTime = Instant.now().minusSeconds(3600);
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .startTime(startTime)
+                    .endTime(null)
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .startTime(startTime)
+                    .endTime(null)
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with end time
+            Instant endTime = Instant.now();
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .endTime(endTime)
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.endTime()).isNotNull();
+            assertThat(updatedSpan1.endTime().toEpochMilli()).isEqualTo(endTime.toEpochMilli());
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.endTime()).isNotNull();
+            assertThat(updatedSpan2.endTime().toEpochMilli()).isEqualTo(endTime.toEpochMilli());
+        }
+
+        @Test
+        @DisplayName("Success: batch update totalEstimatedCost field")
+        void batchUpdate__updateTotalEstimatedCost__success() {
+            // Create spans
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with new cost
+            BigDecimal newCost = new BigDecimal("0.005");
+
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .totalEstimatedCost(newCost)
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.totalEstimatedCost()).isEqualByComparingTo(newCost);
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.totalEstimatedCost()).isEqualByComparingTo(newCost);
+        }
+
+        @Test
+        @DisplayName("Success: batch update errorInfo field")
+        void batchUpdate__updateErrorInfo__success() {
+            // Create spans
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .errorInfo(null)
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .errorInfo(null)
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with error info
+            var errorInfo = ErrorInfo.builder()
+                    .exceptionType("ValidationError")
+                    .message("Invalid input")
+                    .traceback("Stack trace here")
+                    .build();
+
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .errorInfo(errorInfo)
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify spans were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.errorInfo()).isNotNull();
+            assertThat(updatedSpan1.errorInfo().exceptionType()).isEqualTo("ValidationError");
+            assertThat(updatedSpan1.errorInfo().message()).isEqualTo("Invalid input");
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.errorInfo()).isNotNull();
+            assertThat(updatedSpan2.errorInfo().exceptionType()).isEqualTo("ValidationError");
+            assertThat(updatedSpan2.errorInfo().message()).isEqualTo("Invalid input");
+        }
+
+        @Test
+        @DisplayName("Success: batch update multiple fields simultaneously")
+        void batchUpdate__updateMultipleFields__success() {
+            // Create spans
+            var span1 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .name("old-name")
+                    .type(SpanType.general)
+                    .tags(Set.of("old-tag"))
+                    .build();
+            var span2 = podamFactory.manufacturePojo(Span.class).toBuilder()
+                    .projectName(DEFAULT_PROJECT)
+                    .traceId(traceId)
+                    .parentSpanId(null)
+                    .name("old-name")
+                    .type(SpanType.general)
+                    .tags(Set.of("old-tag"))
+                    .build();
+
+            var id1 = spanResourceClient.createSpan(span1, API_KEY, TEST_WORKSPACE);
+            var id2 = spanResourceClient.createSpan(span2, API_KEY, TEST_WORKSPACE);
+
+            // Batch update with multiple fields
+            JsonNode newMetadata = JsonUtils.readTree(Map.of("environment", "production"));
+
+            var batchUpdate = SpanBatchUpdate.builder()
+                    .ids(Set.of(id1, id2))
+                    .update(SpanUpdate.builder()
+                            .projectName(DEFAULT_PROJECT)
+                            .traceId(traceId)
+                            .name("updated-name")
+                            .type(SpanType.llm)
+                            .tags(Set.of("new-tag"))
+                            .metadata(newMetadata)
+                            .model("gpt-4")
+                            .provider("openai")
+                            .build())
+                    .mergeTags(false)
+                    .build();
+
+            spanResourceClient.batchUpdateSpans(batchUpdate, API_KEY, TEST_WORKSPACE);
+
+            // Verify all fields were updated
+            var updatedSpan1 = spanResourceClient.getById(id1, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan1.name()).isEqualTo("updated-name");
+            assertThat(updatedSpan1.type()).isEqualTo(SpanType.llm);
+            assertThat(updatedSpan1.tags()).containsExactly("new-tag");
+            assertThat(updatedSpan1.metadata().get("environment").asText()).isEqualTo("production");
+            assertThat(updatedSpan1.model()).isEqualTo("gpt-4");
+            assertThat(updatedSpan1.provider()).isEqualTo("openai");
+
+            var updatedSpan2 = spanResourceClient.getById(id2, TEST_WORKSPACE, API_KEY);
+            assertThat(updatedSpan2.name()).isEqualTo("updated-name");
+            assertThat(updatedSpan2.type()).isEqualTo(SpanType.llm);
+            assertThat(updatedSpan2.tags()).containsExactly("new-tag");
+            assertThat(updatedSpan2.metadata().get("environment").asText()).isEqualTo("production");
+            assertThat(updatedSpan2.model()).isEqualTo("gpt-4");
+            assertThat(updatedSpan2.provider()).isEqualTo("openai");
+        }
+    }
+}
diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java
index b8dddb44b6c..c35687c1709 100644
--- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java
+++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java
@@ -16,9 +16,11 @@
 import com.comet.opik.api.ScoreSource;
 import com.comet.opik.api.Span;
 import com.comet.opik.api.Trace;
+import com.comet.opik.api.TraceBatchUpdate;
 import com.comet.opik.api.TraceSearchStreamRequest;
 import com.comet.opik.api.TraceThread;
 import com.comet.opik.api.TraceThread.TraceThreadPage;
+import com.comet.opik.api.TraceThreadBatchUpdate;
 import com.comet.opik.api.TraceThreadIdentifier;
 import com.comet.opik.api.TraceThreadStatus;
 import com.comet.opik.api.TraceThreadUpdate;
@@ -112,6 +114,7 @@
 import java.util.Base64;
 import
java.util.Comparator; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -3603,6 +3606,182 @@ void update__whenUpdatingTraceWithDifferentAttachments__thenOldAttachmentsAreDel } } + @Nested + @DisplayName("Batch Update Traces Tags:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class BatchUpdateTraces { + + Stream<Arguments> mergeTagsTestCases() { + return Stream.of( + Arguments.of(true, "merge"), + Arguments.of(false, "replace")); + } + + @ParameterizedTest(name = "Success: batch update tags with {1} mode") + @MethodSource("mergeTagsTestCases") + @DisplayName("Success: batch update tags for multiple traces") + void batchUpdate__success(boolean mergeTags, String mode) { + // Create traces with existing tags + var trace1 = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(DEFAULT_PROJECT) + .tags(mergeTags ? Set.of("existing-tag-1", "existing-tag-2") : Set.of("old-tag-1", "old-tag-2")) + .feedbackScores(null) + .build(); + var trace2 = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(DEFAULT_PROJECT) + .tags(mergeTags ? Set.of("existing-tag-3") : Set.of("old-tag-3")) + .feedbackScores(null) + .build(); + var trace3 = mergeTags + ? factory.manufacturePojo(Trace.class).toBuilder() + .projectName(DEFAULT_PROJECT) + .tags(null) + .feedbackScores(null) + .build() + : null; + + var id1 = traceResourceClient.createTrace(trace1, API_KEY, TEST_WORKSPACE); + var id2 = traceResourceClient.createTrace(trace2, API_KEY, TEST_WORKSPACE); + var id3 = mergeTags ? traceResourceClient.createTrace(trace3, API_KEY, TEST_WORKSPACE) : null; + + // Batch update with new tags + var newTags = mergeTags ? Set.of("new-tag-1", "new-tag-2") : Set.of("new-tag"); + var ids = mergeTags ?
Set.of(id1, id2, id3) : Set.of(id1, id2); + var batchUpdate = TraceBatchUpdate.builder() + .ids(ids) + .update(TraceUpdate.builder() + .projectName(DEFAULT_PROJECT) + .tags(newTags) + .build()) + .mergeTags(mergeTags) + .build(); + + traceResourceClient.batchUpdateTraces(batchUpdate, API_KEY, TEST_WORKSPACE); + + // Verify traces were updated + var updatedTrace1 = traceResourceClient.getById(id1, TEST_WORKSPACE, API_KEY); + if (mergeTags) { + assertThat(updatedTrace1.tags()).containsExactlyInAnyOrder( + "existing-tag-1", "existing-tag-2", "new-tag-1", "new-tag-2"); + } else { + assertThat(updatedTrace1.tags()).containsExactly("new-tag"); + } + + var updatedTrace2 = traceResourceClient.getById(id2, TEST_WORKSPACE, API_KEY); + if (mergeTags) { + assertThat(updatedTrace2.tags()).containsExactlyInAnyOrder("existing-tag-3", "new-tag-1", "new-tag-2"); + } else { + assertThat(updatedTrace2.tags()).containsExactly("new-tag"); + } + + if (mergeTags) { + var updatedTrace3 = traceResourceClient.getById(id3, TEST_WORKSPACE, API_KEY); + assertThat(updatedTrace3.tags()).containsExactlyInAnyOrder("new-tag-1", "new-tag-2"); + } + } + + @Test + @DisplayName("when batch update with empty IDs, then return 422") + void batchUpdate__whenEmptyIds__thenReturn422() { + var batchUpdate = TraceBatchUpdate.builder() + .ids(Set.of()) + .update(TraceUpdate.builder() + .projectName(DEFAULT_PROJECT) + .tags(Set.of("tag")) + .build()) + .mergeTags(true) + .build(); + + try (var actualResponse = traceResourceClient.callBatchUpdateTraces(batchUpdate, API_KEY, TEST_WORKSPACE)) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + var error = actualResponse.readEntity(ErrorMessage.class); + assertThat(error.errors()).anySatisfy(msg -> assertThat(msg).contains("ids")); + } + } + + @Test + @DisplayName("when batch update with too many IDs, then return 422") + void batchUpdate__whenTooManyIds__thenReturn422() { + // Create 1001 IDs (exceeds max of 1000) + var ids = new HashSet<UUID>(); + for (int i = 0; i < 1001; i++) { + ids.add(generator.generate()); + } + + var batchUpdate = TraceBatchUpdate.builder() + .ids(ids) + .update(TraceUpdate.builder() + .projectName(DEFAULT_PROJECT) + .tags(Set.of("tag")) + .build()) + .mergeTags(true) + .build(); + + try (var actualResponse = traceResourceClient.callBatchUpdateTraces(batchUpdate, API_KEY, TEST_WORKSPACE)) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + var error = actualResponse.readEntity(ErrorMessage.class); + assertThat(error.errors()).anySatisfy(msg -> assertThat(msg).contains("ids")); + } + } + + @Test + @DisplayName("when batch update with null update, then return 422") + void batchUpdate__whenNullUpdate__thenReturn422() { + var trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(DEFAULT_PROJECT) + .feedbackScores(null) + .build(); + var id = traceResourceClient.createTrace(trace, API_KEY, TEST_WORKSPACE); + + var batchUpdate = TraceBatchUpdate.builder() + .ids(Set.of(id)) + .update(null) + .mergeTags(true) + .build(); + + try (var actualResponse = traceResourceClient.callBatchUpdateTraces(batchUpdate, API_KEY, TEST_WORKSPACE)) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + } + } + + @Test + @DisplayName("when batch update with max size (1000), then success") + void batchUpdate__whenMaxSize__thenSuccess() { + // Create
1000 traces + var ids = new HashSet<UUID>(); + for (int i = 0; i < 1000; i++) { + var trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectName(DEFAULT_PROJECT) + .tags(Set.of("old-tag")) + .feedbackScores(null) + .build(); + var id = traceResourceClient.createTrace(trace, API_KEY, TEST_WORKSPACE); + ids.add(id); + } + + var batchUpdate = TraceBatchUpdate.builder() + .ids(ids) + .update(TraceUpdate.builder() + .projectName(DEFAULT_PROJECT) + .tags(Set.of("new-tag")) + .build()) + .mergeTags(true) + .build(); + + traceResourceClient.batchUpdateTraces(batchUpdate, API_KEY, TEST_WORKSPACE); + + // Verify a sample of traces + var sampleIds = ids.stream().limit(10).toList(); + for (var id : sampleIds) { + var trace = traceResourceClient.getById(id, TEST_WORKSPACE, API_KEY); + assertThat(trace.tags()).containsExactlyInAnyOrder("old-tag", "new-tag"); + } + } + } + @Nested + @DisplayName("Comment:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -5796,4 +5975,171 @@ void whenThreadIsClosedWithMixedScores_andReopened_thenOnlyManualScoresAreDelete } } + @Nested + @DisplayName("Batch Update Threads Tags:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class BatchUpdateThreads { + + Stream<Arguments> mergeTagsTestCases() { + return Stream.of( + Arguments.of(true, "merge", 3), + Arguments.of(false, "replace", 2)); + } + + @ParameterizedTest(name = "Success: batch update tags with {1} mode") + @MethodSource("mergeTagsTestCases") + @DisplayName("Success: batch update tags for multiple threads") + void batchUpdate__success(boolean mergeTags, String mode, int threadCount) { + // Create thread IDs + var threadId1 = UUID.randomUUID().toString(); + var threadId2 = UUID.randomUUID().toString(); + var threadId3 = mergeTags ? UUID.randomUUID().toString() : null; + + var projectName = "project-" + RandomStringUtils.secure().nextAlphanumeric(32); + projectResourceClient.createProject(projectName, API_KEY, TEST_WORKSPACE); + var projectId = projectResourceClient.getByName(projectName, API_KEY, TEST_WORKSPACE).id(); + + // Create traces to create threads + create(createTrace().toBuilder().projectName(projectName).threadId(threadId1).build(), API_KEY, + TEST_WORKSPACE); + create(createTrace().toBuilder().projectName(projectName).threadId(threadId2).build(), API_KEY, + TEST_WORKSPACE); + if (mergeTags) { + create(createTrace().toBuilder().projectName(projectName).threadId(threadId3).build(), API_KEY, + TEST_WORKSPACE); + } + + // Wait for threads to be created + Awaitility.await() + .atMost(5, TimeUnit.SECONDS) + .untilAsserted(() -> { + var threads = traceResourceClient.getTraceThreads(projectId, null, API_KEY, TEST_WORKSPACE, + null, null, null); + assertThat(threads.content()).hasSize(threadCount); + }); + + var threads = traceResourceClient.getTraceThreads(projectId, null, API_KEY, TEST_WORKSPACE, + null, null, null); + var threadModelId1 = threads.content().stream().filter(t -> t.id().equals(threadId1)).findFirst() + .get().threadModelId(); + var threadModelId2 = threads.content().stream().filter(t -> t.id().equals(threadId2)).findFirst() + .get().threadModelId(); + var threadModelId3 = mergeTags + ?
threads.content().stream().filter(t -> t.id().equals(threadId3)).findFirst() + .get().threadModelId() + : null; + + // Update threads with existing tags + if (mergeTags) { + traceResourceClient.updateThread(TraceThreadUpdate.builder().tags(Set.of("existing-tag-1")).build(), + threadModelId1, API_KEY, TEST_WORKSPACE, 204); + traceResourceClient.updateThread(TraceThreadUpdate.builder().tags(Set.of("existing-tag-2")).build(), + threadModelId2, API_KEY, TEST_WORKSPACE, 204); + } else { + traceResourceClient.updateThread( + TraceThreadUpdate.builder().tags(Set.of("old-tag-1", "old-tag-2")).build(), + threadModelId1, API_KEY, TEST_WORKSPACE, 204); + traceResourceClient.updateThread(TraceThreadUpdate.builder().tags(Set.of("old-tag-3")).build(), + threadModelId2, API_KEY, TEST_WORKSPACE, 204); + } + + // Batch update with new tags + var newTags = mergeTags ? Set.of("new-tag-1", "new-tag-2") : Set.of("new-tag"); + var ids = mergeTags + ? Set.of(threadModelId1, threadModelId2, threadModelId3) + : Set.of(threadModelId1, threadModelId2); + var batchUpdate = TraceThreadBatchUpdate.builder() + .ids(ids) + .update(TraceThreadUpdate.builder() + .tags(newTags) + .build()) + .mergeTags(mergeTags) + .build(); + + traceResourceClient.batchUpdateThreads(batchUpdate, API_KEY, TEST_WORKSPACE); + + // Verify threads were updated + var thread1 = traceResourceClient.getTraceThread(threadId1, projectId, API_KEY, TEST_WORKSPACE); + if (mergeTags) { + assertThat(thread1.tags()).containsExactlyInAnyOrder("existing-tag-1", "new-tag-1", "new-tag-2"); + } else { + assertThat(thread1.tags()).containsExactly("new-tag"); + } + + var thread2 = traceResourceClient.getTraceThread(threadId2, projectId, API_KEY, TEST_WORKSPACE); + if (mergeTags) { + assertThat(thread2.tags()).containsExactlyInAnyOrder("existing-tag-2", "new-tag-1", "new-tag-2"); + } else { + assertThat(thread2.tags()).containsExactly("new-tag"); + } + + if (mergeTags) { + var thread3 = traceResourceClient.getTraceThread(threadId3, projectId, API_KEY, TEST_WORKSPACE); + assertThat(thread3.tags()).containsExactlyInAnyOrder("new-tag-1", "new-tag-2"); + } + } + + @Test + @DisplayName("when batch update with empty IDs, then return 422") + void batchUpdate__whenEmptyIds__thenReturn422() { + var batchUpdate = TraceThreadBatchUpdate.builder() + .ids(Set.of()) + .update(TraceThreadUpdate.builder() + .tags(Set.of("tag")) + .build()) + .mergeTags(true) + .build(); + + try (var actualResponse = traceResourceClient.callBatchUpdateThreads(batchUpdate, API_KEY, + TEST_WORKSPACE)) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + var error = actualResponse.readEntity(ErrorMessage.class); + assertThat(error.errors()).anySatisfy(msg -> assertThat(msg).contains("ids")); + } + } + + @Test + @DisplayName("when batch update with too many IDs, then return 422") + void batchUpdate__whenTooManyIds__thenReturn422() { + // Create 1001 IDs (exceeds max of 1000) + var ids = new HashSet<UUID>(); + for (int i = 0; i < 1001; i++) { + ids.add(generator.generate()); + } + + var batchUpdate = TraceThreadBatchUpdate.builder() + .ids(ids) + .update(TraceThreadUpdate.builder() + .tags(Set.of("tag")) + .build()) + .mergeTags(true) + .build(); + + try (var actualResponse = traceResourceClient.callBatchUpdateThreads(batchUpdate, API_KEY, + TEST_WORKSPACE)) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + var error =
actualResponse.readEntity(ErrorMessage.class); + assertThat(error.errors()).anySatisfy(msg -> assertThat(msg).contains("ids")); + } + } + + @Test + @DisplayName("when batch update with null update, then return 422") + void batchUpdate__whenNullUpdate__thenReturn422() { + var batchUpdate = TraceThreadBatchUpdate.builder() + .ids(Set.of(generator.generate())) + .update(null) + .mergeTags(true) + .build(); + + try (var actualResponse = traceResourceClient.callBatchUpdateThreads(batchUpdate, API_KEY, + TEST_WORKSPACE)) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + } + } + } + } \ No newline at end of file diff --git a/apps/opik-frontend/src/api/traces/useSpanBatchUpdateMutation.ts b/apps/opik-frontend/src/api/traces/useSpanBatchUpdateMutation.ts new file mode 100644 index 00000000000..68582d3da82 --- /dev/null +++ b/apps/opik-frontend/src/api/traces/useSpanBatchUpdateMutation.ts @@ -0,0 +1,55 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { AxiosError } from "axios"; +import get from "lodash/get"; + +import api, { SPANS_KEY, SPANS_REST_ENDPOINT } from "@/api/api"; +import { Span } from "@/types/traces"; +import { useToast } from "@/components/ui/use-toast"; + +type UseSpanBatchUpdateMutationParams = { + projectId: string; + spanIds: string[]; + span: Partial<Span>; + mergeTags?: boolean; +}; + +const useSpanBatchUpdateMutation = () => { + const queryClient = useQueryClient(); + const { toast } = useToast(); + + return useMutation({ + mutationFn: async ({ + spanIds, + span, + mergeTags, + }: UseSpanBatchUpdateMutationParams) => { + const { data } = await api.patch(SPANS_REST_ENDPOINT + "batch", { + ids: spanIds, + update: span, + merge_tags: mergeTags, + }); + + return data; + }, + onError: (error: AxiosError) => { + const message = get( + error, + ["response", "data", "message"], + error.message, + ); + + toast({ + title: "Error", + description: message, + variant: "destructive", + }); + }, + onSettled: (data, error, variables) => { + queryClient.invalidateQueries({ + queryKey: [SPANS_KEY, { projectId: variables.projectId }], + }); + }, + }); +}; + +export default useSpanBatchUpdateMutation; diff --git a/apps/opik-frontend/src/api/traces/useThreadBatchUpdateMutation.ts b/apps/opik-frontend/src/api/traces/useThreadBatchUpdateMutation.ts new file mode 100644 index 00000000000..27eab99eb5e --- /dev/null +++ b/apps/opik-frontend/src/api/traces/useThreadBatchUpdateMutation.ts @@ -0,0 +1,56 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { AxiosError } from "axios"; +import get from "lodash/get"; + +import api, { THREADS_KEY, TRACES_REST_ENDPOINT } from "@/api/api"; +import { useToast } from "@/components/ui/use-toast"; + +type UseThreadBatchUpdateMutationParams = { + projectId: string; + threadIds: string[]; + thread: { + tags?: string[]; + }; + mergeTags?: boolean; +}; + +const useThreadBatchUpdateMutation = () => { + const queryClient = useQueryClient(); + const { toast } = useToast(); + + return useMutation({ + mutationFn: async ({ + threadIds, + thread, + mergeTags, + }: UseThreadBatchUpdateMutationParams) => { + const { data } = await api.patch(TRACES_REST_ENDPOINT + "threads/batch", { + ids: threadIds, + update: thread, + merge_tags: mergeTags, + }); + + return data; + }, + onError: (error: AxiosError) => { + const message = get( + error, + ["response", "data", "message"], + error.message, + ); + + toast({ + title: "Error", + description: message,
+ variant: "destructive", + }); + }, + onSettled: (data, error, variables) => { + queryClient.invalidateQueries({ + queryKey: [THREADS_KEY, { projectId: variables.projectId }], + }); + }, + }); +}; + +export default useThreadBatchUpdateMutation; diff --git a/apps/opik-frontend/src/api/traces/useTraceBatchUpdateMutation.ts b/apps/opik-frontend/src/api/traces/useTraceBatchUpdateMutation.ts new file mode 100644 index 00000000000..babb1d1a7fa --- /dev/null +++ b/apps/opik-frontend/src/api/traces/useTraceBatchUpdateMutation.ts @@ -0,0 +1,55 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { AxiosError } from "axios"; +import get from "lodash/get"; + +import api, { TRACES_KEY, TRACES_REST_ENDPOINT } from "@/api/api"; +import { Trace } from "@/types/traces"; +import { useToast } from "@/components/ui/use-toast"; + +type UseTraceBatchUpdateMutationParams = { + projectId: string; + traceIds: string[]; + trace: Partial<Trace>; + mergeTags?: boolean; +}; + +const useTraceBatchUpdateMutation = () => { + const queryClient = useQueryClient(); + const { toast } = useToast(); + + return useMutation({ + mutationFn: async ({ + traceIds, + trace, + mergeTags, + }: UseTraceBatchUpdateMutationParams) => { + const { data } = await api.patch(TRACES_REST_ENDPOINT + "batch", { + ids: traceIds, + update: trace, + merge_tags: mergeTags, + }); + + return data; + }, + onError: (error: AxiosError) => { + const message = get( + error, + ["response", "data", "message"], + error.message, + ); + + toast({ + title: "Error", + description: message, + variant: "destructive", + }); + }, + onSettled: (data, error, variables) => { + queryClient.invalidateQueries({ + queryKey: [TRACES_KEY, { projectId: variables.projectId }], + }); + }, + }); +}; + +export default useTraceBatchUpdateMutation; diff --git a/apps/opik-frontend/src/components/pages-shared/traces/AddTagDialog/AddTagDialog.tsx b/apps/opik-frontend/src/components/pages-shared/traces/AddTagDialog/AddTagDialog.tsx index e2b77e50d80..5f06f2d990f 100644 --- a/apps/opik-frontend/src/components/pages-shared/traces/AddTagDialog/AddTagDialog.tsx +++ b/apps/opik-frontend/src/components/pages-shared/traces/AddTagDialog/AddTagDialog.tsx @@ -1,5 +1,5 @@ import React, { useState } from "react"; -import { Trace, Span } from "@/types/traces"; +import { Trace, Span, Thread } from "@/types/traces"; import { TRACE_DATA_TYPE } from "@/hooks/useTracesOrSpansList"; import { Dialog, @@ -11,16 +11,23 @@ import { import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { useToast } from "@/components/ui/use-toast"; -import useTraceUpdateMutation from "@/api/traces/useTraceUpdateMutation"; -import useSpanUpdateMutation from "@/api/traces/useSpanUpdateMutation"; +import useTraceBatchUpdateMutation from "@/api/traces/useTraceBatchUpdateMutation"; +import useSpanBatchUpdateMutation from "@/api/traces/useSpanBatchUpdateMutation"; +import useThreadBatchUpdateMutation from "@/api/traces/useThreadBatchUpdateMutation"; import useAppStore from "@/store/AppStore"; +export enum TAG_ENTITY_TYPE { + traces = "traces", + spans = "spans", + threads = "threads", +} + type AddTagDialogProps = { - rows: Array<Trace | Span>; + rows: Array<Trace | Span | Thread>; open: boolean | number; setOpen: (open: boolean | number) => void; projectId: string; - type: TRACE_DATA_TYPE; + type: TRACE_DATA_TYPE | TAG_ENTITY_TYPE.threads; onSuccess?: () => void; }; @@ -35,9 +42,9 @@ const AddTagDialog: React.FunctionComponent<AddTagDialogProps> = ({ const { toast } = useToast(); const workspaceName =
useAppStore((state) => state.activeWorkspaceName); const [newTag, setNewTag] = useState(""); - const traceUpdateMutation = useTraceUpdateMutation(); - const spanUpdateMutation = useSpanUpdateMutation(); - const MAX_ENTITIES = 10; + const traceBatchUpdateMutation = useTraceBatchUpdateMutation(); + const spanBatchUpdateMutation = useSpanBatchUpdateMutation(); + const threadBatchUpdateMutation = useThreadBatchUpdateMutation(); const handleClose = () => { setOpen(false); @@ -47,54 +54,55 @@ const AddTagDialog: React.FunctionComponent = ({ const handleAddTag = () => { if (!newTag) return; - const promises: Promise[] = []; - - rows.forEach((row) => { - const currentTags = row.tags || []; - - if (currentTags.includes(newTag)) return; - - const newTags = [...currentTags, newTag]; - - if (type === TRACE_DATA_TYPE.traces) { - promises.push( - traceUpdateMutation.mutateAsync({ - projectId, - traceId: row.id, - trace: { - workspace_name: workspaceName, - project_id: projectId, - tags: newTags, - }, - }), - ); - } else { - const span = row as Span; - const parentId = span.parent_span_id; + let mutationPromise; + let entityName; - promises.push( - spanUpdateMutation.mutateAsync({ - projectId, - spanId: span.id, - span: { - workspace_name: workspaceName, - project_id: projectId, - ...(parentId && { parent_span_id: parentId }), - trace_id: span.trace_id, - tags: newTags, - }, - }), - ); - } - }); + if (type === TRACE_DATA_TYPE.traces) { + const ids = rows.map((row) => row.id); + mutationPromise = traceBatchUpdateMutation.mutateAsync({ + projectId, + traceIds: ids, + trace: { + workspace_name: workspaceName, + project_id: projectId, + tags: [newTag], + }, + mergeTags: true, + }); + entityName = "traces"; + } else if (type === TRACE_DATA_TYPE.spans) { + const ids = rows.map((row) => row.id); + mutationPromise = spanBatchUpdateMutation.mutateAsync({ + projectId, + spanIds: ids, + span: { + workspace_name: workspaceName, + project_id: projectId, + trace_id: "00000000-0000-0000-0000-000000000000", // Placeholder - not used by backend batch update, just to bypass validation + tags: [newTag], + }, + mergeTags: true, + }); + entityName = "spans"; + } else { + // threads - use thread_model_id instead of id + const threadModelIds = rows.map((row) => (row as Thread).thread_model_id); + mutationPromise = threadBatchUpdateMutation.mutateAsync({ + projectId, + threadIds: threadModelIds, + thread: { + tags: [newTag], + }, + mergeTags: true, + }); + entityName = "threads"; + } - Promise.all(promises) + mutationPromise .then(() => { toast({ title: "Success", - description: `Tag "${newTag}" added to ${rows.length} selected ${ - type === TRACE_DATA_TYPE.traces ? "traces" : "spans" - }`, + description: `Tag "${newTag}" added to ${rows.length} selected ${entityName}`, }); if (onSuccess) { @@ -104,7 +112,7 @@ const AddTagDialog: React.FunctionComponent = ({ handleClose(); }) .catch(() => { - // Error handling is already done by the mutation hooks,this just ensures we don't close the dialog on error + // Error handling is already done by the mutation hooks }); }; @@ -114,15 +122,13 @@ const AddTagDialog: React.FunctionComponent = ({ Add tag to {rows.length}{" "} - {type === TRACE_DATA_TYPE.traces ? "traces" : "spans"} + {type === TRACE_DATA_TYPE.traces + ? "traces" + : type === TRACE_DATA_TYPE.spans + ? "spans" + : "threads"} - {rows.length > MAX_ENTITIES && ( -
- You can only add tags to up to {MAX_ENTITIES} entities at a time. - Please select fewer entities. -
- )}
= ({ value={newTag} onChange={(event) => setNewTag(event.target.value)} className="col-span-3" - disabled={rows.length > MAX_ENTITIES} />
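
A usage sketch for the hooks above: the per-row mutateAsync loop the dialog used to run becomes a single PATCH. This assumes a component that already has a projectId string and a selectedThreads: Thread[] selection in scope; the "needs-review" tag is illustrative.

import useThreadBatchUpdateMutation from "@/api/traces/useThreadBatchUpdateMutation";

const threadBatchUpdateMutation = useThreadBatchUpdateMutation();

// Tag every selected thread in one request; threads are addressed by
// thread_model_id, and mergeTags preserves whatever tags each thread already has.
threadBatchUpdateMutation.mutate({
  projectId,
  threadIds: selectedThreads.map((thread) => thread.thread_model_id),
  thread: { tags: ["needs-review"] },
  mergeTags: true,
});
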
@@ -138,10 +143,7 @@ const AddTagDialog: React.FunctionComponent = ({ - diff --git a/apps/opik-frontend/src/components/pages/TracesPage/ThreadsTab/ThreadsActionsPanel.tsx b/apps/opik-frontend/src/components/pages/TracesPage/ThreadsTab/ThreadsActionsPanel.tsx index 3c072dcb662..4676e46cff3 100644 --- a/apps/opik-frontend/src/components/pages/TracesPage/ThreadsTab/ThreadsActionsPanel.tsx +++ b/apps/opik-frontend/src/components/pages/TracesPage/ThreadsTab/ThreadsActionsPanel.tsx @@ -1,5 +1,5 @@ import React, { useCallback, useRef, useState } from "react"; -import { Trash, Brain } from "lucide-react"; +import { Trash, Brain, Tag } from "lucide-react"; import get from "lodash/get"; import first from "lodash/first"; import slugify from "slugify"; @@ -13,6 +13,9 @@ import ExportToButton from "@/components/shared/ExportToButton/ExportToButton"; import AddToDropdown from "@/components/pages-shared/traces/AddToDropdown/AddToDropdown"; import { COLUMN_FEEDBACK_SCORES_ID } from "@/types/shared"; import RunEvaluationDialog from "@/components/pages-shared/automations/RunEvaluationDialog/RunEvaluationDialog"; +import AddTagDialog, { + TAG_ENTITY_TYPE, +} from "@/components/pages-shared/traces/AddTagDialog/AddTagDialog"; type ThreadsActionsPanelProps = { getDataForExport: () => Promise; @@ -90,9 +93,17 @@ const ThreadsActionsPanel: React.FunctionComponent< confirmText="Delete threads" confirmButtonVariant="destructive" /> + row.thread_model_id)} @@ -104,7 +115,7 @@ const ThreadsActionsPanel: React.FunctionComponent< disabled={disabled} dataType="threads" /> - + + + +
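
For reference, a sketch of the wire format implied by the new records and hooks. The /api/v1/private base path, the Authorization header, and the example IDs are assumptions rather than something this diff confirms; the body is snake_case because the batch records use Jackson's SnakeCaseStrategy, and a successful call returns 204 No Content, matching the single-item update endpoints.

// PATCH /api/v1/private/traces/batch — hypothetical standalone call
const apiKey = process.env.OPIK_API_KEY ?? ""; // assumption: API-key auth
const body = {
  ids: [
    "00000000-0000-0000-0000-000000000001", // dummy trace IDs; 1..1000 allowed,
    "00000000-0000-0000-0000-000000000002", // an empty set or >1000 IDs fails validation with 422
  ],
  update: { project_name: "Demo Project", tags: ["new-tag"] },
  merge_tags: true, // merge into existing tags; false or omitted replaces them
};

const response = await fetch("/api/v1/private/traces/batch", {
  method: "PATCH",
  headers: { "Content-Type": "application/json", Authorization: apiKey },
  body: JSON.stringify(body),
});
console.assert(response.status === 204); // No Content on success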