Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Builder;

import java.util.Set;
import java.util.UUID;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
@Schema(description = "Request to batch update multiple spans")
public record SpanBatchUpdate(
        @NotNull @NotEmpty @Size(min = 1, max = SpanBatchUpdate.MAX_BATCH_SIZE) @Schema(description = "List of span IDs to update (max 1000)") Set<UUID> ids,
        @NotNull @Valid @Schema(description = "Update to apply to all spans") SpanUpdate update,
        @Schema(description = "If true, merge tags with existing tags instead of replacing them. Default: false") Boolean mergeTags) {

    // Upper bound on the number of span IDs accepted in one batch request.
    // Keep in sync with the "(max 1000)" wording in the @Schema descriptions above.
    public static final int MAX_BATCH_SIZE = 1000;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Builder;

import java.util.Set;
import java.util.UUID;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
@Schema(description = "Request to batch update multiple traces")
public record TraceBatchUpdate(
        @NotNull @NotEmpty @Size(min = 1, max = TraceBatchUpdate.MAX_BATCH_SIZE) @Schema(description = "List of trace IDs to update (max 1000)") Set<UUID> ids,
        @NotNull @Valid @Schema(description = "Update to apply to all traces") TraceUpdate update,
        @Schema(description = "If true, merge tags with existing tags instead of replacing them. Default: false") Boolean mergeTags) {

    // Upper bound on the number of trace IDs accepted in one batch request.
    // Keep in sync with the "(max 1000)" wording in the @Schema descriptions above.
    public static final int MAX_BATCH_SIZE = 1000;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Builder;

import java.util.Set;
import java.util.UUID;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
@Schema(description = "Request to batch update multiple trace threads")
public record TraceThreadBatchUpdate(
        @NotNull @NotEmpty @Size(min = 1, max = TraceThreadBatchUpdate.MAX_BATCH_SIZE) @Schema(description = "List of thread model IDs to update (max 1000)") Set<UUID> ids,
        @NotNull @Valid @Schema(description = "Update to apply to all threads") TraceThreadUpdate update,
        @Schema(description = "If true, merge tags with existing tags instead of replacing them. Default: false") Boolean mergeTags) {

    // Upper bound on the number of thread model IDs accepted in one batch request.
    // Keep in sync with the "(max 1000)" wording in the @Schema descriptions above.
    public static final int MAX_BATCH_SIZE = 1000;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.comet.opik.api.ProjectStats;
import com.comet.opik.api.Span;
import com.comet.opik.api.SpanBatch;
import com.comet.opik.api.SpanBatchUpdate;
import com.comet.opik.api.SpanSearchStreamRequest;
import com.comet.opik.api.SpanUpdate;
import com.comet.opik.api.filter.FiltersFactory;
Expand Down Expand Up @@ -232,6 +233,28 @@ public Response createSpans(
return Response.noContent().build();
}

@PATCH
@Path("/batch")
@Operation(operationId = "batchUpdateSpans", summary = "Batch update spans", description = "Update multiple spans", responses = {
        @ApiResponse(responseCode = "204", description = "No Content"),
        @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))})
@RateLimited
public Response batchUpdate(
        @RequestBody(content = @Content(schema = @Schema(implementation = SpanBatchUpdate.class))) @Valid @NotNull SpanBatchUpdate batchUpdate) {

    // Applies a single SpanUpdate to every span listed in the request, within the caller's workspace.
    var workspaceId = requestContext.get().getWorkspaceId();
    var spanCount = batchUpdate.ids().size();

    log.info("Batch updating '{}' spans on workspaceId '{}'", spanCount, workspaceId);

    // Blocking the reactive pipeline is intentional: this JAX-RS resource method is synchronous.
    spanService.batchUpdate(batchUpdate)
            .contextWrite(ctx -> setRequestContext(ctx, requestContext))
            .block();

    log.info("Batch updated '{}' spans on workspaceId '{}'", spanCount, workspaceId);

    return Response.noContent().build();
}

@PATCH
@Path("{id}")
@Operation(operationId = "updateSpan", summary = "Update span by id", description = "Update span by id", responses = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import com.comet.opik.api.Trace;
import com.comet.opik.api.Trace.TracePage;
import com.comet.opik.api.TraceBatch;
import com.comet.opik.api.TraceBatchUpdate;
import com.comet.opik.api.TraceSearchStreamRequest;
import com.comet.opik.api.TraceThread;
import com.comet.opik.api.TraceThreadBatchIdentifier;
import com.comet.opik.api.TraceThreadBatchUpdate;
import com.comet.opik.api.TraceThreadIdentifier;
import com.comet.opik.api.TraceThreadSearchStreamRequest;
import com.comet.opik.api.TraceThreadUpdate;
Expand Down Expand Up @@ -305,6 +307,28 @@ public Response createTraces(
return Response.noContent().build();
}

@PATCH
@Path("/batch")
@Operation(operationId = "batchUpdateTraces", summary = "Batch update traces", description = "Update multiple traces", responses = {
        @ApiResponse(responseCode = "204", description = "No Content"),
        @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))})
@RateLimited
public Response batchUpdate(
        @RequestBody(content = @Content(schema = @Schema(implementation = TraceBatchUpdate.class))) @Valid @NotNull TraceBatchUpdate batchUpdate) {

    // Applies a single TraceUpdate to every trace listed in the request, within the caller's workspace.
    var workspaceId = requestContext.get().getWorkspaceId();
    var traceCount = batchUpdate.ids().size();

    log.info("Batch updating '{}' traces on workspaceId '{}'", traceCount, workspaceId);

    // Blocking the reactive pipeline is intentional: this JAX-RS resource method is synchronous.
    service.batchUpdate(batchUpdate)
            .contextWrite(ctx -> setRequestContext(ctx, requestContext))
            .block();

    log.info("Batch updated '{}' traces on workspaceId '{}'", traceCount, workspaceId);

    return Response.noContent().build();
}

@PATCH
@Path("{id}")
@Operation(operationId = "updateTrace", summary = "Update trace by id", description = "Update trace by id", responses = {
Expand Down Expand Up @@ -787,6 +811,28 @@ public Response closeTraceThread(
return Response.noContent().build();
}

@PATCH
@Path("/threads/batch")
@Operation(operationId = "batchUpdateThreads", summary = "Batch update threads", description = "Update multiple threads", responses = {
        @ApiResponse(responseCode = "204", description = "No Content"),
        @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))})
@RateLimited
public Response batchUpdateThreads(
        @RequestBody(content = @Content(schema = @Schema(implementation = TraceThreadBatchUpdate.class))) @Valid @NotNull TraceThreadBatchUpdate batchUpdate) {

    // Applies a single TraceThreadUpdate to every thread listed in the request, within the caller's workspace.
    var workspaceId = requestContext.get().getWorkspaceId();
    var threadCount = batchUpdate.ids().size();

    log.info("Batch updating '{}' threads on workspaceId '{}'", threadCount, workspaceId);

    // Blocking the reactive pipeline is intentional: this JAX-RS resource method is synchronous.
    traceThreadService.batchUpdate(batchUpdate)
            .contextWrite(ctx -> setRequestContext(ctx, requestContext))
            .block();

    log.info("Batch updated '{}' threads on workspaceId '{}'", threadCount, workspaceId);

    return Response.noContent().build();
}

@PATCH
@Path("/threads/{threadModelId}")
@Operation(operationId = "updateThread", summary = "Update thread", description = "Update thread", responses = {
Expand Down
162 changes: 162 additions & 0 deletions apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -2253,6 +2253,168 @@ private void bindCost(Span span, Statement statement, String index) {
}
}

// "Updates" spans by inserting a fresh version of each matched row (ClickHouse
// ReplacingMergeTree pattern): select the latest row per id (ORDER BY ... DESC LIMIT 1 BY id),
// override the fields present on the update via StringTemplate <if(...)> conditionals, and
// re-insert with the original created_at/created_by preserved.
// NOTE: when merge_tags is set, the result is deduplicated with arrayDistinct — tags are a
// Set in the API, so merging must not produce duplicate entries.
private static final String BULK_UPDATE = """
        INSERT INTO spans (
            id,
            project_id,
            workspace_id,
            trace_id,
            parent_span_id,
            name,
            type,
            start_time,
            end_time,
            input,
            output,
            metadata,
            model,
            provider,
            total_estimated_cost,
            total_estimated_cost_version,
            tags,
            usage,
            error_info,
            created_at,
            created_by,
            last_updated_by,
            truncation_threshold
        )
        SELECT
            s.id,
            s.project_id,
            s.workspace_id,
            s.trace_id,
            s.parent_span_id,
            <if(name)> :name <else> s.name <endif> as name,
            <if(type)> :type <else> s.type <endif> as type,
            s.start_time,
            <if(end_time)> parseDateTime64BestEffort(:end_time, 9) <else> s.end_time <endif> as end_time,
            <if(input)> :input <else> s.input <endif> as input,
            <if(output)> :output <else> s.output <endif> as output,
            <if(metadata)> :metadata <else> s.metadata <endif> as metadata,
            <if(model)> :model <else> s.model <endif> as model,
            <if(provider)> :provider <else> s.provider <endif> as provider,
            <if(total_estimated_cost)> toDecimal128(:total_estimated_cost, 12) <else> s.total_estimated_cost <endif> as total_estimated_cost,
            <if(total_estimated_cost_version)> :total_estimated_cost_version <else> s.total_estimated_cost_version <endif> as total_estimated_cost_version,
            <if(tags)><if(merge_tags)>arrayDistinct(arrayConcat(s.tags, :tags))<else>:tags<endif><else>s.tags<endif> as tags,
            <if(usage)> CAST((:usageKeys, :usageValues), 'Map(String, Int64)') <else> s.usage <endif> as usage,
            <if(error_info)> :error_info <else> s.error_info <endif> as error_info,
            s.created_at,
            s.created_by,
            :user_name as last_updated_by,
            :truncation_threshold
        FROM spans s
        WHERE s.id IN :ids AND s.workspace_id = :workspace_id
        ORDER BY (s.workspace_id, s.project_id, s.trace_id, s.parent_span_id, s.id) DESC, s.last_updated_at DESC
        LIMIT 1 BY s.id;
        """;

/**
 * Applies the same {@code SpanUpdate} to every span in {@code ids}, scoped to the current
 * workspace. Returns a {@code Mono} that completes when the ClickHouse insert finishes.
 */
@WithSpan
public Mono<Void> bulkUpdate(@NonNull Set<UUID> ids, @NonNull SpanUpdate update, boolean mergeTags) {
    Preconditions.checkArgument(!ids.isEmpty(), "ids must not be empty");
    log.info("Bulk updating '{}' spans", ids.size());

    // Render the template once up front; only fields present on the update are emitted.
    var sql = newBulkUpdateTemplate(update, BULK_UPDATE, mergeTags).render();

    return Mono.from(connectionFactory.create())
            .flatMapMany(connection -> {
                var statement = connection.createStatement(sql).bind("ids", ids);

                bindBulkUpdateParams(update, statement);
                TruncationUtils.bindTruncationThreshold(statement, "truncation_threshold", configuration);

                var segment = startSegment("spans", "Clickhouse", "bulk_update");

                // User name and workspace id are bound from the reactive context.
                return makeFluxContextAware(bindUserNameAndWorkspaceContextToStream(statement))
                        .doFinally(signal -> endSegment(segment));
            })
            .then()
            .doOnSuccess(__ -> log.info("Completed bulk update for '{}' spans", ids.size()));
}

/**
 * Builds the BULK_UPDATE StringTemplate, enabling a conditional section for each field that is
 * present on the update. The actual parameter values are bound separately in
 * bindBulkUpdateParams; here the template only needs to know WHICH placeholders to render, so
 * the attribute values themselves are mostly irrelevant.
 */
private ST newBulkUpdateTemplate(SpanUpdate spanUpdate, String sql, boolean mergeTags) {
    var template = TemplateUtils.newST(sql);

    if (StringUtils.isNotBlank(spanUpdate.name())) {
        template.add("name", spanUpdate.name());
    }
    if (spanUpdate.type() != null) {
        template.add("type", spanUpdate.type().toString());
    }
    if (spanUpdate.input() != null) {
        template.add("input", spanUpdate.input().toString());
    }
    if (spanUpdate.output() != null) {
        template.add("output", spanUpdate.output().toString());
    }
    if (spanUpdate.tags() != null) {
        // merge_tags only matters when tags are present; it switches replace vs. merge in SQL.
        template.add("tags", spanUpdate.tags().toString());
        template.add("merge_tags", mergeTags);
    }
    if (spanUpdate.metadata() != null) {
        template.add("metadata", spanUpdate.metadata().toString());
    }
    if (StringUtils.isNotBlank(spanUpdate.model())) {
        template.add("model", spanUpdate.model());
    }
    if (StringUtils.isNotBlank(spanUpdate.provider())) {
        template.add("provider", spanUpdate.provider());
    }
    if (spanUpdate.endTime() != null) {
        template.add("end_time", spanUpdate.endTime().toString());
    }
    if (spanUpdate.usage() != null) {
        template.add("usage", spanUpdate.usage().toString());
    }
    if (spanUpdate.errorInfo() != null) {
        template.add("error_info", JsonUtils.readTree(spanUpdate.errorInfo()).toString());
    }
    if (spanUpdate.totalEstimatedCost() != null) {
        // A manually supplied cost resets the version; both placeholders are enabled together.
        template.add("total_estimated_cost", "total_estimated_cost");
        template.add("total_estimated_cost_version", "total_estimated_cost_version");
    }
    return template;
}

/**
 * Binds the concrete values for every placeholder that newBulkUpdateTemplate enabled. The
 * presence checks here MUST mirror that method exactly: binding a parameter the template did
 * not render (or vice versa) fails at execution time.
 */
private void bindBulkUpdateParams(SpanUpdate spanUpdate, Statement statement) {
    if (StringUtils.isNotBlank(spanUpdate.name())) {
        statement.bind("name", spanUpdate.name());
    }
    if (spanUpdate.type() != null) {
        statement.bind("type", spanUpdate.type().toString());
    }
    if (spanUpdate.input() != null) {
        statement.bind("input", spanUpdate.input().toString());
    }
    if (spanUpdate.output() != null) {
        statement.bind("output", spanUpdate.output().toString());
    }
    if (spanUpdate.tags() != null) {
        statement.bind("tags", spanUpdate.tags().toArray(String[]::new));
    }
    if (spanUpdate.usage() != null) {
        // The SQL casts (keys, values) to Map(String, Int64), so bind two parallel arrays.
        var usageKeys = new ArrayList<String>(spanUpdate.usage().size());
        var usageValues = new ArrayList<Integer>(spanUpdate.usage().size());
        spanUpdate.usage().forEach((key, value) -> {
            usageKeys.add(key);
            usageValues.add(value);
        });
        statement.bind("usageKeys", usageKeys.toArray(String[]::new));
        statement.bind("usageValues", usageValues.toArray(Integer[]::new));
    }
    if (spanUpdate.endTime() != null) {
        statement.bind("end_time", spanUpdate.endTime().toString());
    }
    if (spanUpdate.metadata() != null) {
        statement.bind("metadata", spanUpdate.metadata().toString());
    }
    if (StringUtils.isNotBlank(spanUpdate.model())) {
        statement.bind("model", spanUpdate.model());
    }
    if (StringUtils.isNotBlank(spanUpdate.provider())) {
        statement.bind("provider", spanUpdate.provider());
    }
    if (spanUpdate.errorInfo() != null) {
        statement.bind("error_info", JsonUtils.readTree(spanUpdate.errorInfo()).toString());
    }
    if (spanUpdate.totalEstimatedCost() != null) {
        // Manually supplied cost: clear the version marker (empty string) on purpose.
        statement.bind("total_estimated_cost", spanUpdate.totalEstimatedCost().toString());
        statement.bind("total_estimated_cost_version", "");
    }
}

private JsonNode getMetadataWithProvider(Row row, Set<SpanField> exclude, String provider) {
// Parse base metadata from database
JsonNode baseMetadata = Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.comet.opik.api.ProjectStats;
import com.comet.opik.api.Span;
import com.comet.opik.api.SpanBatch;
import com.comet.opik.api.SpanBatchUpdate;
import com.comet.opik.api.SpanUpdate;
import com.comet.opik.api.SpansCountResponse;
import com.comet.opik.api.attachment.AttachmentInfo;
Expand Down Expand Up @@ -204,6 +205,15 @@ public Mono<Void> update(@NonNull UUID id, @NonNull SpanUpdate spanUpdate) {
.then()))));
}

/**
 * Applies one {@code SpanUpdate} to every span in the batch request.
 */
@WithSpan
public Mono<Void> batchUpdate(@NonNull SpanBatchUpdate batchUpdate) {
    var spanCount = batchUpdate.ids().size();
    log.info("Batch updating '{}' spans", spanCount);

    // mergeTags is a nullable Boolean on the API payload; absent/null means "replace tags".
    return spanDAO.bulkUpdate(batchUpdate.ids(), batchUpdate.update(), Boolean.TRUE.equals(batchUpdate.mergeTags()))
            .doOnSuccess(__ -> log.info("Completed batch update for '{}' spans", spanCount));
}

private Mono<Long> insertUpdate(Project project, SpanUpdate spanUpdate, UUID id) {
return IdGenerator
.validateVersionAsync(id, SPAN_KEY)
Expand Down
Loading
Loading