diff --git a/go/pkg/sysdb/coordinator/coordinator.go b/go/pkg/sysdb/coordinator/coordinator.go index 86fcc12cccf..a8ada5723cb 100644 --- a/go/pkg/sysdb/coordinator/coordinator.go +++ b/go/pkg/sysdb/coordinator/coordinator.go @@ -297,6 +297,16 @@ func (s *Coordinator) FlushCollectionCompactionAndAttachedFunction( return s.catalog.FlushCollectionCompactionAndAttachedFunction(ctx, flushCollectionCompaction, attachedFunctionID, runNonce, completionOffset) } +func (s *Coordinator) FlushCollectionCompactionAndAttachedFunctionExtended( + ctx context.Context, + collectionCompactions []*model.FlushCollectionCompaction, + attachedFunctionID uuid.UUID, + runNonce uuid.UUID, + completionOffset int64, +) (*model.ExtendedFlushCollectionInfo, error) { + return s.catalog.FlushCollectionCompactionAndAttachedFunctionExtended(ctx, collectionCompactions, attachedFunctionID, runNonce, completionOffset) +} + func (s *Coordinator) ListCollectionsToGc(ctx context.Context, cutoffTimeSecs *uint64, limit *uint64, tenantID *string, minVersionsIfAlive *uint64) ([]*model.CollectionToGc, error) { return s.catalog.ListCollectionsToGc(ctx, cutoffTimeSecs, limit, tenantID, minVersionsIfAlive) } diff --git a/go/pkg/sysdb/coordinator/create_task_test.go b/go/pkg/sysdb/coordinator/create_task_test.go index 2b187f5872d..5bf6eaaccaa 100644 --- a/go/pkg/sysdb/coordinator/create_task_test.go +++ b/go/pkg/sysdb/coordinator/create_task_test.go @@ -16,7 +16,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" - "google.golang.org/protobuf/types/known/timestamppb" ) // testMinimalUUIDv7 is the test's copy of minimalUUIDv7 from task.go @@ -78,6 +77,7 @@ type AttachFunctionTestSuite struct { mockFunctionDb *dbmodel_mocks.IFunctionDb mockDatabaseDb *dbmodel_mocks.IDatabaseDb mockCollectionDb *dbmodel_mocks.ICollectionDb + mockSegmentDb *dbmodel_mocks.ISegmentDb mockHeapClient *MockHeapClient coordinator *Coordinator } @@ -147,6 +147,9 @@ func (suite *AttachFunctionTestSuite) SetupTest() { suite.mockCollectionDb = &dbmodel_mocks.ICollectionDb{} suite.mockCollectionDb.Test(suite.T()) + suite.mockSegmentDb = &dbmodel_mocks.ISegmentDb{} + suite.mockSegmentDb.Test(suite.T()) + suite.mockHeapClient = new(MockHeapClient) suite.mockHeapClient.Test(suite.T()) @@ -176,7 +179,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation_With functionName := "record_counter" tenantID := "test-tenant" databaseName := "test-database" - databaseID := "database-uuid" + databaseID := uuid.New().String() functionID := uuid.New() MinRecordsForInvocation := uint64(100) @@ -199,33 +202,69 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation_With // ===== Phase 1: Attach function in transaction ===== // Setup mocks that will be called within the transaction (using mock.Anything for context) - // Check if attached function exists (idempotency check inside transaction) - suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). - Return(nil, nil).Once() - // Look up database + // Look up database (first) suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). 
Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() - // Look up function + // Look up function (second) suite.mockMetaDomain.On("FunctionDb", mock.Anything).Return(suite.mockFunctionDb).Once() suite.mockFunctionDb.On("GetByName", functionName). Return(&dbmodel.Function{ID: functionID, Name: functionName}, nil).Once() + // Check if attached function exists - idempotency check (third) + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). + Return(nil, nil).Once() + // Check input collection exists suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", []string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once() - // Check output collection doesn't exist + // createCollectionImpl: look up database for output collection creation + suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() + suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). + Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + + // createCollectionImpl: Check output collection doesn't exist suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", []string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{}, nil).Once() + // createCollectionImpl: Insert the collection + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() + suite.mockCollectionDb.On("Insert", mock.AnythingOfType("*dbmodel.Collection")).Return(nil).Once() + + // createCollectionImpl: Get the created collection + outputCollectionUUID := uuid.New().String() + configJSON := "{}" + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() + suite.mockCollectionDb.On("GetCollections", + []string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). + Return([]*dbmodel.CollectionAndMetadata{{ + Collection: &dbmodel.Collection{ + ID: outputCollectionUUID, + Name: &outputCollectionName, + DatabaseID: databaseID, + ConfigurationJsonStr: &configJSON, + }, + TenantID: tenantID, + DatabaseName: databaseName, + }}, nil).Once() + + // createSegmentImpl: Insert the 3 segments (vector, metadata, record) + suite.mockMetaDomain.On("SegmentDb", mock.Anything).Return(suite.mockSegmentDb).Times(3) + suite.mockSegmentDb.On("Insert", mock.AnythingOfType("*dbmodel.Segment")).Return(nil).Times(3) + + // createSegmentImpl: Get each created segment (3 times) + suite.mockMetaDomain.On("SegmentDb", mock.Anything).Return(suite.mockSegmentDb).Times(3) + suite.mockSegmentDb.On("GetSegments", mock.AnythingOfType("types.UniqueID"), (*string)(nil), (*string)(nil), mock.AnythingOfType("types.UniqueID")). 
+ Return([]*dbmodel.SegmentAndMetadata{{Segment: &dbmodel.Segment{ID: uuid.New().String()}}}, nil).Times(3) + // Insert attached function with lowest_live_nonce = NULL suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() suite.mockAttachedFunctionDb.On("Insert", mock.MatchedBy(func(attachedFunction *dbmodel.AttachedFunction) bool { @@ -275,10 +314,11 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation_With // Assertions suite.NoError(err) suite.NotNil(response) - suite.NotEmpty(response.Id) + suite.NotNil(response.AttachedFunction) + suite.NotEmpty(response.AttachedFunction.Id) // Verify attached function ID is valid UUID - attachedFunctionID, err := uuid.Parse(response.Id) + attachedFunctionID, err := uuid.Parse(response.AttachedFunction.Id) suite.NoError(err) suite.NotEqual(uuid.Nil, attachedFunctionID) @@ -307,7 +347,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Alrea functionName := "record_counter" tenantID := "test-tenant" databaseName := "test-database" - databaseID := "database-uuid" + databaseID := uuid.New().String() functionID := uuid.New() MinRecordsForInvocation := uint64(100) nextNonce := uuid.Must(uuid.NewV7()) @@ -348,23 +388,32 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Alrea UpdatedAt: now, } - // ===== Phase 1: Transaction checks if attached function exists ===== - suite.mockMetaDomain.On("AttachedFunctionDb", ctx).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). - Return(existingAttachedFunction, nil).Once() - // Mock transaction call suite.mockTxImpl.On("Transaction", ctx, mock.AnythingOfType("func(context.Context) error")). Run(func(args mock.Arguments) { txFunc := args.Get(1).(func(context.Context) error) txCtx := context.Background() - // Inside transaction: validate function by ID + // Inside transaction: Look up database + suite.mockMetaDomain.On("DatabaseDb", txCtx).Return(suite.mockDatabaseDb).Once() + suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). + Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + + // Look up function by name + suite.mockMetaDomain.On("FunctionDb", txCtx).Return(suite.mockFunctionDb).Once() + suite.mockFunctionDb.On("GetByName", functionName). + Return(&dbmodel.Function{ID: functionID, Name: functionName}, nil).Once() + + // Check if attached function already exists (idempotency check) + suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). + Return(existingAttachedFunction, nil).Once() + + // Validate existing attached function matches request suite.mockMetaDomain.On("FunctionDb", txCtx).Return(suite.mockFunctionDb).Once() suite.mockFunctionDb.On("GetByID", functionID). Return(&dbmodel.Function{ID: functionID, Name: functionName}, nil).Once() - // Validate database matches suite.mockMetaDomain.On("DatabaseDb", txCtx).Return(suite.mockDatabaseDb).Once() suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). 
Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() @@ -378,7 +427,8 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Alrea // Assertions suite.NoError(err) suite.NotNil(response) - suite.Equal(existingAttachedFunctionID.String(), response.Id) + suite.NotNil(response.AttachedFunction) + suite.Equal(existingAttachedFunctionID.String(), response.AttachedFunction.Id) // Verify no writes occurred (no Insert, no UpdateLowestLiveNonce, no heap Push) // Note: Transaction IS called for idempotency check, but no writes happen inside it @@ -412,7 +462,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow_HeapFailur functionName := "record_counter" tenantID := "test-tenant" databaseName := "test-database" - databaseID := "database-uuid" + databaseID := uuid.New().String() functionID := uuid.New() MinRecordsForInvocation := uint64(100) nextNonce := uuid.Must(uuid.NewV7()) @@ -438,28 +488,67 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow_HeapFailur // ========== FIRST ATTEMPT: Heap Push Fails ========== // Phase 1: Create attached function in transaction - suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). - Return(nil, nil).Once() - + // Look up database (first) suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + // Look up function (second) suite.mockMetaDomain.On("FunctionDb", mock.Anything).Return(suite.mockFunctionDb).Once() suite.mockFunctionDb.On("GetByName", functionName). Return(&dbmodel.Function{ID: functionID, Name: functionName}, nil).Once() + // Check if attached function exists - idempotency check (third) + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). + Return(nil, nil).Once() + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", []string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once() + // createCollectionImpl: look up database for output collection creation + suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() + suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). + Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + + // createCollectionImpl: Check output collection doesn't exist suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", []string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). 
Return([]*dbmodel.CollectionAndMetadata{}, nil).Once() + // createCollectionImpl: Insert the collection + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() + suite.mockCollectionDb.On("Insert", mock.AnythingOfType("*dbmodel.Collection")).Return(nil).Once() + + // createCollectionImpl: Get the created collection + outputCollectionUUID2 := uuid.New().String() + configJSON2 := "{}" + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() + suite.mockCollectionDb.On("GetCollections", + []string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). + Return([]*dbmodel.CollectionAndMetadata{{ + Collection: &dbmodel.Collection{ + ID: outputCollectionUUID2, + Name: &outputCollectionName, + DatabaseID: databaseID, + ConfigurationJsonStr: &configJSON2, + }, + TenantID: tenantID, + DatabaseName: databaseName, + }}, nil).Once() + + // createSegmentImpl: Insert the 3 segments (vector, metadata, record) + suite.mockMetaDomain.On("SegmentDb", mock.Anything).Return(suite.mockSegmentDb).Times(3) + suite.mockSegmentDb.On("Insert", mock.AnythingOfType("*dbmodel.Segment")).Return(nil).Times(3) + + // createSegmentImpl: Get each created segment (3 times) + suite.mockMetaDomain.On("SegmentDb", mock.Anything).Return(suite.mockSegmentDb).Times(3) + suite.mockSegmentDb.On("GetSegments", mock.AnythingOfType("types.UniqueID"), (*string)(nil), (*string)(nil), mock.AnythingOfType("types.UniqueID")). + Return([]*dbmodel.SegmentAndMetadata{{Segment: &dbmodel.Segment{ID: uuid.New().String()}}}, nil).Times(3) + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() suite.mockAttachedFunctionDb.On("Insert", mock.MatchedBy(func(attachedFunction *dbmodel.AttachedFunction) bool { return attachedFunction.LowestLiveNonce == nil @@ -503,26 +592,37 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow_HeapFailur // ========== SECOND ATTEMPT: Recovery Succeeds ========== - // Phase 0: GetByName returns incomplete attached function (with ErrAttachedFunctionNotReady, which AttachFunction handles) - suite.mockMetaDomain.On("AttachedFunctionDb", ctx).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). - Return(incompleteAttachedFunction, nil).Once() - - // Validate function matches - suite.mockMetaDomain.On("FunctionDb", ctx).Return(suite.mockFunctionDb).Once() - suite.mockFunctionDb.On("GetByID", functionID). - Return(&dbmodel.Function{ID: functionID, Name: functionName}, nil).Once() - - // Validate database matches (inside validateTaskMatchesRequest, called within transaction) - suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() - suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). - Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() - - // Mock the Transaction call + // Mock the Transaction call - inside the transaction, we'll find the incomplete attached function suite.mockTxImpl.On("Transaction", ctx, mock.AnythingOfType("func(context.Context) error")). 
Run(func(args mock.Arguments) { txFunc := args.Get(1).(func(context.Context) error) - _ = txFunc(context.Background()) + txCtx := context.Background() + + // Inside transaction: Look up database + suite.mockMetaDomain.On("DatabaseDb", txCtx).Return(suite.mockDatabaseDb).Once() + suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). + Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + + // Look up function by name + suite.mockMetaDomain.On("FunctionDb", txCtx).Return(suite.mockFunctionDb).Once() + suite.mockFunctionDb.On("GetByName", functionName). + Return(&dbmodel.Function{ID: functionID, Name: functionName}, nil).Once() + + // Check if attached function exists - returns incomplete (idempotency check) + suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). + Return(incompleteAttachedFunction, nil).Once() + + // Validate existing attached function matches request + suite.mockMetaDomain.On("FunctionDb", txCtx).Return(suite.mockFunctionDb).Once() + suite.mockFunctionDb.On("GetByID", functionID). + Return(&dbmodel.Function{ID: functionID, Name: functionName}, nil).Once() + + suite.mockMetaDomain.On("DatabaseDb", txCtx).Return(suite.mockDatabaseDb).Once() + suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). + Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + + _ = txFunc(txCtx) }).Return(nil).Once() // Phase 2: Heap push succeeds this time @@ -546,7 +646,8 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow_HeapFailur response2, err2 := suite.coordinator.AttachFunction(ctx, request) suite.NoError(err2) suite.NotNil(response2) - suite.Equal(incompleteAttachedFunctionID.String(), response2.Id) + suite.NotNil(response2.AttachedFunction) + suite.Equal(incompleteAttachedFunctionID.String(), response2.AttachedFunction.Id) // Verify transaction was called in both attempts (idempotency check happens in transaction) suite.mockTxImpl.AssertNumberOfCalls(suite.T(), "Transaction", 2) // First attempt + recovery attempt @@ -573,7 +674,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Param requestedOperatorName := "different_function" // DIFFERENT tenantID := "test-tenant" databaseName := "test-database" - databaseID := "database-uuid" + databaseID := uuid.New().String() existingOperatorID := uuid.New() MinRecordsForInvocation := uint64(100) nextNonce := uuid.Must(uuid.NewV7()) @@ -614,30 +715,42 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Param UpdatedAt: now, } - // ===== Phase 1: Transaction checks if task exists - finds task with different params ===== - suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). - Return(existingAttachedFunction, nil).Once() - - // Validate function - returns DIFFERENT function name - suite.mockMetaDomain.On("FunctionDb", mock.Anything).Return(suite.mockFunctionDb).Once() - suite.mockFunctionDb.On("GetByID", existingOperatorID). 
- Return(&dbmodel.Function{ - ID: existingOperatorID, - Name: existingOperatorName, // Different from request - }, nil).Once() - - // Database lookup happens before the error is returned (inside transaction) - suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() - suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). - Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() - // Mock transaction call - it will fail with validation error suite.mockTxImpl.On("Transaction", ctx, mock.AnythingOfType("func(context.Context) error")). Run(func(args mock.Arguments) { txFunc := args.Get(1).(func(context.Context) error) - _ = txFunc(context.Background()) - }).Return(status.Errorf(codes.AlreadyExists, "different function is attached with this name: existing=%s, requested=%s", existingOperatorName, requestedOperatorName)).Once() + txCtx := context.Background() + + // Inside transaction: Look up database + suite.mockMetaDomain.On("DatabaseDb", txCtx).Return(suite.mockDatabaseDb).Once() + suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). + Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + + // Look up function by name (requested function) + suite.mockMetaDomain.On("FunctionDb", txCtx).Return(suite.mockFunctionDb).Once() + suite.mockFunctionDb.On("GetByName", requestedOperatorName). + Return(&dbmodel.Function{ID: uuid.New(), Name: requestedOperatorName}, nil).Once() + + // Check if attached function already exists (idempotency check) + suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("GetByName", inputCollectionID, attachedFunctionName). + Return(existingAttachedFunction, nil).Once() + + // Validate existing attached function - returns DIFFERENT function name + suite.mockMetaDomain.On("FunctionDb", txCtx).Return(suite.mockFunctionDb).Once() + suite.mockFunctionDb.On("GetByID", existingOperatorID). + Return(&dbmodel.Function{ + ID: existingOperatorID, + Name: existingOperatorName, // Different from request + }, nil).Once() + + // Database lookup for validation + suite.mockMetaDomain.On("DatabaseDb", txCtx).Return(suite.mockDatabaseDb).Once() + suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). 
+ Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + + _ = txFunc(txCtx) + }).Return(status.Errorf(codes.AlreadyExists, "attached function already exists with different function: existing=%s, requested=%s", existingOperatorName, requestedOperatorName)).Once() // Execute AttachFunction response, err := suite.coordinator.AttachFunction(ctx, request) @@ -645,7 +758,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Param // Assertions - should fail with AlreadyExists error suite.Error(err) suite.Nil(response) - suite.Contains(err.Error(), "different function is attached with this name") + suite.Contains(err.Error(), "attached function already exists with different function") suite.Contains(err.Error(), existingOperatorName) suite.Contains(err.Error(), requestedOperatorName) @@ -667,6 +780,8 @@ func TestAttachFunctionTestSuite(t *testing.T) { // TestGetSoftDeletedAttachedFunctions_TimestampConsistency verifies that timestamps // are returned in microseconds (UnixMicro) to match other API methods +// TODO: Uncomment when GetSoftDeletedAttachedFunctions is implemented +/* func TestGetSoftDeletedAttachedFunctions_TimestampConsistency(t *testing.T) { ctx := context.Background() @@ -746,3 +861,4 @@ func TestGetSoftDeletedAttachedFunctions_TimestampConsistency(t *testing.T) { mockMetaDomain.AssertExpectations(t) mockAttachedFunctionDb.AssertExpectations(t) } +*/ diff --git a/go/pkg/sysdb/coordinator/heap_client_integration_test.go b/go/pkg/sysdb/coordinator/heap_client_integration_test.go index 64001fcedb4..d7d73bcfcb1 100644 --- a/go/pkg/sysdb/coordinator/heap_client_integration_test.go +++ b/go/pkg/sysdb/coordinator/heap_client_integration_test.go @@ -176,7 +176,8 @@ func (suite *HeapClientIntegrationTestSuite) TestAttachFunctionPushesScheduleToH }) suite.NoError(err, "Should attached function successfully") suite.NotNil(response) - suite.NotEmpty(response.Id, "Attached function ID should be returned") + suite.NotNil(response.AttachedFunction) + suite.NotEmpty(response.AttachedFunction.Id, "Attached function ID should be returned") // Get updated heap summary updatedSummary, err := suite.heapClient.Summary(ctx, &coordinatorpb.HeapSummaryRequest{}) @@ -263,7 +264,8 @@ func (suite *HeapClientIntegrationTestSuite) TestPartialTaskRecovery_HybridAppro return } suite.NotNil(taskResp) - originalTaskID := taskResp.Id + suite.NotNil(taskResp.AttachedFunction) + originalTaskID := taskResp.AttachedFunction.Id suite.T().Logf("Created fully initialized task: %s", originalTaskID) // STEP 2: Directly UPDATE database to make task partial (simulate Phase 3 failure) @@ -363,7 +365,8 @@ func (suite *HeapClientIntegrationTestSuite) TestPartialTaskCleanup_ThenRecreate return } suite.NotNil(taskResp) - suite.T().Logf("Created task: %s", taskResp.Id) + suite.NotNil(taskResp.AttachedFunction) + suite.T().Logf("Created task: %s", taskResp.AttachedFunction.Id) // STEP 2: Call CleanupExpiredPartialAttachedFunctions (with short timeout to test it doesn't affect complete tasks) cleanupResp, err := suite.sysdbClient.CleanupExpiredPartialAttachedFunctions(ctx, &coordinatorpb.CleanupExpiredPartialAttachedFunctionsRequest{ @@ -381,12 +384,12 @@ func (suite *HeapClientIntegrationTestSuite) TestPartialTaskCleanup_ThenRecreate }) suite.NoError(err, "Task should still exist after cleanup") suite.NotNil(getResp) - suite.Equal(taskResp.Id, getResp.AttachedFunction.Id) + suite.Equal(taskResp.AttachedFunction.Id, getResp.AttachedFunction.Id) suite.T().Logf("Task still exists after 
cleanup: %s", getResp.AttachedFunction.Id) // STEP 4: Delete the task _, err = suite.sysdbClient.DetachFunction(ctx, &coordinatorpb.DetachFunctionRequest{ - AttachedFunctionId: taskResp.Id, + AttachedFunctionId: taskResp.AttachedFunction.Id, DeleteOutput: true, }) suite.NoError(err, "Should delete task") @@ -403,8 +406,9 @@ func (suite *HeapClientIntegrationTestSuite) TestPartialTaskCleanup_ThenRecreate }) suite.NoError(err, "Should be able to recreate task after deletion") suite.NotNil(taskResp2) - suite.NotEqual(taskResp.Id, taskResp2.Id, "New task should have different ID") - suite.T().Logf("Successfully recreated task: %s", taskResp2.Id) + suite.NotNil(taskResp2.AttachedFunction) + suite.NotEqual(taskResp.AttachedFunction.Id, taskResp2.AttachedFunction.Id, "New task should have different ID") + suite.T().Logf("Successfully recreated task: %s", taskResp2.AttachedFunction.Id) } func TestHeapClientIntegrationSuite(t *testing.T) { diff --git a/go/pkg/sysdb/coordinator/model/collection.go b/go/pkg/sysdb/coordinator/model/collection.go index 0f652b62b2e..1d53eeca62f 100644 --- a/go/pkg/sysdb/coordinator/model/collection.go +++ b/go/pkg/sysdb/coordinator/model/collection.go @@ -105,6 +105,10 @@ type FlushCollectionInfo struct { AttachedFunctionCompletionOffset *int64 } +type ExtendedFlushCollectionInfo struct { + Collections []*FlushCollectionInfo +} + func FilterCollection(collection *Collection, collectionID types.UniqueID, collectionName *string) bool { if collectionID != types.NilUniqueID() && collectionID != collection.ID { return false diff --git a/go/pkg/sysdb/coordinator/table_catalog.go b/go/pkg/sysdb/coordinator/table_catalog.go index 35d0a049e41..e72dbbfb042 100644 --- a/go/pkg/sysdb/coordinator/table_catalog.go +++ b/go/pkg/sysdb/coordinator/table_catalog.go @@ -1789,6 +1789,76 @@ func (tc *Catalog) FlushCollectionCompactionAndAttachedFunction( return flushCollectionInfo, nil } +// FlushCollectionCompactionAndAttachedFunctionExtended atomically updates multiple collection compaction data +// and attached function completion offset in a single transaction. +// NOTE: This does NOT advance next_nonce - that is done separately by AdvanceAttachedFunction. +// This only updates the completion_offset to record how far we've processed. +// This is only supported for versioned collections (the modern/default path). 
+func (tc *Catalog) FlushCollectionCompactionAndAttachedFunctionExtended( + ctx context.Context, + collectionCompactions []*model.FlushCollectionCompaction, + attachedFunctionID uuid.UUID, + runNonce uuid.UUID, + completionOffset int64, +) (*model.ExtendedFlushCollectionInfo, error) { + if !tc.versionFileEnabled { + // Attached-function-based compactions are only supported with versioned collections + log.Error("FlushCollectionCompactionAndAttachedFunctionExtended is only supported for versioned collections") + return nil, errors.New("attached-function-based compaction requires versioned collections") + } + + if len(collectionCompactions) == 0 { + return nil, errors.New("at least one collection compaction is required") + } + + flushInfos := make([]*model.FlushCollectionInfo, 0, len(collectionCompactions)) + + err := tc.txImpl.Transaction(ctx, func(txCtx context.Context) error { + var err error + // Get the transaction from context to pass to FlushCollectionCompactionForVersionedCollection + tx := dbcore.GetDB(txCtx) + + // Handle all collection compactions + for _, collectionCompaction := range collectionCompactions { + log.Info("FlushCollectionCompactionAndAttachedFunctionExtended", zap.String("collection_id", collectionCompaction.ID.String())) + flushInfo, err := tc.FlushCollectionCompactionForVersionedCollection(txCtx, collectionCompaction, tx) + if err != nil { + return err + } + flushInfos = append(flushInfos, flushInfo) + } + + // Update ONLY completion_offset - next_nonce was already advanced by AdvanceAttachedFunction + // NOTE: runNonce is passed through, but UpdateCompletionOffset no longer filters on it (the lowest_live_nonce guard was removed in dao/task.go) + err = tc.metaDomain.AttachedFunctionDb(txCtx).UpdateCompletionOffset(attachedFunctionID, runNonce, completionOffset) + if err != nil { + return err + } + + return nil + }) + + if err != nil { + return nil, err + } + + // Echo the completion offset persisted above back onto each flush info + for _, flushInfo := range flushInfos { + flushInfo.AttachedFunctionCompletionOffset = &completionOffset + } + + // Log with first collection ID (typically the output collection) + log.Info("FlushCollectionCompactionAndAttachedFunctionExtended", + zap.String("first_collection_id", collectionCompactions[0].ID.String()), + zap.Int("collection_count", len(collectionCompactions)), + zap.String("attached_function_id", attachedFunctionID.String()), + zap.Int64("completion_offset", completionOffset)) + + return &model.ExtendedFlushCollectionInfo{ + Collections: flushInfos, + }, nil +} + func (tc *Catalog) validateVersionFile(versionFile *coordinatorpb.CollectionVersionFile, collectionID string, version int64) error { if versionFile.GetCollectionInfoImmutable().GetCollectionId() != collectionID { log.Error("collection id mismatch", zap.String("collection_id", collectionID), zap.String("version_file_collection_id", versionFile.GetCollectionInfoImmutable().GetCollectionId())) diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go index a4d548fec61..b262a3a4911 100644 --- a/go/pkg/sysdb/coordinator/task.go +++ b/go/pkg/sysdb/coordinator/task.go @@ -19,7 +19,6 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" - "google.golang.org/protobuf/types/known/timestamppb" ) // minimalUUIDv7 represents the smallest possible UUIDv7.
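The nonce bookkeeping in this file leans on UUIDv7 time ordering: a UUIDv7 whose timestamp bits are all zero compares below any nonce freshly generated with uuid.NewV7(), which is why lowest_live_nonce can start at minimalUUIDv7 and is guaranteed to be less than next_nonce. Below is a minimal, self-contained sketch of that ordering property; the all-zero construction of the minimal UUID is an assumption for illustration, since task.go's actual minimalUUIDv7 definition is not shown in this diff.

package main

import (
	"bytes"
	"fmt"

	"github.com/google/uuid"
)

func main() {
	// Assumed minimal UUIDv7: zero timestamp and zero random bits, with only
	// the version (7) and RFC 4122 variant bits set.
	var minimal uuid.UUID
	minimal[6] = 0x70 // version nibble = 7
	minimal[8] = 0x80 // variant bits = 10xxxxxx

	next := uuid.Must(uuid.NewV7()) // a fresh time-ordered nonce, as for next_nonce

	// UUIDv7 stores a big-endian unix-millisecond timestamp in its first 48
	// bits, so byte-wise comparison agrees with creation-time ordering.
	fmt.Println(bytes.Compare(minimal[:], next[:]) < 0) // prints: true
}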
@@ -88,7 +87,7 @@ func (s *Coordinator) validateAttachedFunctionMatchesRequest(ctx context.Context return nil } -// AttachFunction creates a new attached function in the database +// AttachFunction creates an output collection and attached function in a single transaction func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.AttachFunctionRequest) (*coordinatorpb.AttachFunctionResponse, error) { log := log.With(zap.String("method", "AttachFunction")) @@ -99,46 +98,10 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att } var attachedFunctionID uuid.UUID = uuid.New() - var nextNonce uuid.UUID // Store next_nonce to avoid re-fetching from DB - var lowestLiveNonce uuid.UUID // Store lowest_live_nonce to set in Phase 3 - var nextRun time.Time - var skipPhase2And3 bool // Flag to skip Phase 2 & 3 if task is already fully initialized + var nextNonce uuid.UUID - // ===== Phase 1: Create attached function with lowest_live_nonce = NULL (if needed) ===== + // ===== Single Transaction: Create output collection and attached function ===== err := s.catalog.txImpl.Transaction(ctx, func(txCtx context.Context) error { - // Double-check attached function doesn't exist (race condition protection) - concurrentAttachedFunction, err := s.catalog.metaDomain.AttachedFunctionDb(txCtx).GetByName(req.InputCollectionId, req.Name) - if err != nil { - log.Error("AttachFunction: failed to double-check attached function", zap.Error(err)) - return err - } - if concurrentAttachedFunction != nil { - // Attached function was created concurrently, validate it matches our request - log.Info("AttachFunction: attached function created concurrently, validating parameters", - zap.String("attached_function_id", concurrentAttachedFunction.ID.String())) - - // Validate that concurrent attached function matches our request - if err := s.validateAttachedFunctionMatchesRequest(txCtx, concurrentAttachedFunction, req); err != nil { - return err - } - - // Validation passed, reuse the concurrent attached function's data - attachedFunctionID = concurrentAttachedFunction.ID - nextNonce = concurrentAttachedFunction.NextNonce - nextRun = concurrentAttachedFunction.NextRun - - // Set lowestLiveNonce for the concurrent case - if concurrentAttachedFunction.LowestLiveNonce != nil { - // Already initialized, skip Phase 2 & 3 - lowestLiveNonce = *concurrentAttachedFunction.LowestLiveNonce - skipPhase2And3 = true - } else { - // Not initialized yet, generate minimal UUIDv7 and continue to Phase 2 & 3 - lowestLiveNonce = minimalUUIDv7 - } - return nil - } - // Look up database_id databases, err := s.catalog.metaDomain.DatabaseDb(txCtx).GetDatabases(req.TenantId, req.Database) if err != nil { @@ -161,14 +124,37 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att return common.ErrFunctionNotFound } - // Generate next_nonce as UUIDv7 with current time - nextNonce, err = uuid.NewV7() - if err != nil { + // Check if attached function already exists (idempotency) + existingAttachedFunction, err := s.catalog.metaDomain.AttachedFunctionDb(txCtx).GetByName(req.InputCollectionId, req.Name) + if err != nil && !errors.Is(err, common.ErrAttachedFunctionNotReady) { + log.Error("AttachFunction: failed to check for existing attached function", zap.Error(err)) return err } + if existingAttachedFunction != nil { + // Validate that the existing attached function matches the request + if err := s.validateAttachedFunctionMatchesRequest(txCtx, existingAttachedFunction, req); err != nil { 
+ log.Error("AttachFunction: existing attached function does not match request", zap.Error(err)) + return err + } - // Set lowest_live_nonce to minimal UUIDv7 (guaranteed < nextNonce) - lowestLiveNonce = minimalUUIDv7 + // If attached function is fully initialized (no error), return it (idempotent) + if err == nil { + attachedFunctionID = existingAttachedFunction.ID + log.Info("AttachFunction: attached function already exists and is complete, returning existing", + zap.String("attached_function_id", attachedFunctionID.String()), + zap.String("name", req.Name), + zap.String("input_collection_id", req.InputCollectionId)) + return nil + } + + // If we got ErrAttachedFunctionNotReady, the attached function is incomplete + // Continue with initialization (Phases 2 & 3) to complete it + attachedFunctionID = existingAttachedFunction.ID + log.Info("AttachFunction: found incomplete attached function, will complete initialization", + zap.String("attached_function_id", attachedFunctionID.String()), + zap.String("name", req.Name)) + return nil // Exit transaction, continue with heap push and nonce update + } // Check if input collection exists collections, err := s.catalog.metaDomain.CollectionDb(txCtx).GetCollections([]string{req.InputCollectionId}, nil, req.TenantId, req.Database, nil, nil, false) @@ -181,16 +167,65 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att return common.ErrCollectionNotFound } - // Check if output collection already exists - outputCollectionName := req.OutputCollectionName - existingOutputCollections, err := s.catalog.metaDomain.CollectionDb(txCtx).GetCollections(nil, &outputCollectionName, req.TenantId, req.Database, nil, nil, false) + // Create output collection with segments + outputCollectionID := types.NewUniqueID() + + // Set a default dimension to ensure segment writers can be initialized + dimension := int32(1) // Default dimension for attached function output collections + + createCollection := &model.CreateCollection{ + ID: outputCollectionID, + Name: req.OutputCollectionName, + ConfigurationJsonStr: `{"hnsw": {"space": "cosine", "M": 16, "ef_construction": 64}}`, + DatabaseName: req.Database, + TenantID: req.TenantId, + GetOrCreate: false, // We want to fail if it already exists + Dimension: &dimension, + Metadata: nil, + } + + // Create segments for the collection (distributed setup) + segments := []*model.Segment{ + { + ID: types.NewUniqueID(), + Type: "urn:chroma:segment/vector/hnsw-distributed", + Scope: "VECTOR", + CollectionID: outputCollectionID, + }, + { + ID: types.NewUniqueID(), + Type: "urn:chroma:segment/metadata/blockfile", + Scope: "METADATA", + CollectionID: outputCollectionID, + }, + { + ID: types.NewUniqueID(), + Type: "urn:chroma:segment/record/blockfile", + Scope: "RECORD", + CollectionID: outputCollectionID, + }, + } + + // Create output collection and segments directly to avoid nested transaction + outputCollection, _, err := s.catalog.createCollectionImpl(txCtx, createCollection, "", 0) if err != nil { - log.Error("AttachFunction: failed to check output collection", zap.Error(err)) + log.Error("AttachFunction: failed to create output collection", zap.Error(err)) return err } - if len(existingOutputCollections) > 0 { - log.Error("AttachFunction: output collection already exists") - return common.ErrCollectionUniqueConstraintViolation + + // Create segments for the collection + for _, segment := range segments { + _, err := s.catalog.createSegmentImpl(txCtx, segment, 0) + if err != nil { + 
log.Error("AttachFunction: failed to create segment", zap.Error(err)) + return err + } + } + + // Generate next_nonce as UUIDv7 with current time + nextNonce, err = uuid.NewV7() + if err != nil { + return err } // Serialize params @@ -206,6 +241,7 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att paramsJSON = "{}" } + // Create attached function with NULL lowest_live_nonce (2-phase commit Phase 1) now := time.Now() attachedFunction := &dbmodel.AttachedFunction{ ID: attachedFunctionID, @@ -214,6 +250,7 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att DatabaseID: databases[0].ID, InputCollectionID: req.InputCollectionId, OutputCollectionName: req.OutputCollectionName, + OutputCollectionID: &[]string{outputCollection.ID.String()}[0], FunctionID: function.ID, FunctionParams: paramsJSON, CompletionOffset: 0, @@ -224,20 +261,20 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att CreatedAt: now, UpdatedAt: now, NextNonce: nextNonce, - LowestLiveNonce: nil, // **KEY: Set to NULL for 2PC** + LowestLiveNonce: nil, // NULL in Phase 1, set in Phase 3 OldestWrittenNonce: nil, } - nextRun = attachedFunction.NextRun - err = s.catalog.metaDomain.AttachedFunctionDb(txCtx).Insert(attachedFunction) if err != nil { log.Error("AttachFunction: failed to insert attached function", zap.Error(err)) return err } - log.Debug("AttachFunction: Phase 1: attached function created with lowest_live_nonce=NULL", + log.Debug("AttachFunction: created output collection and attached function in single transaction", zap.String("attached_function_id", attachedFunctionID.String()), + zap.String("output_collection_id", outputCollection.ID.String()), + zap.String("output_collection_name", req.OutputCollectionName), zap.String("name", req.Name)) return nil }) @@ -246,62 +283,10 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att return nil, err } - // If function is already fully attached, return immediately (idempotency) - if skipPhase2And3 { - log.Info("AttachFunction: function already fully attached, skipping Phase 2 & 3", - zap.String("attached_function_id", attachedFunctionID.String())) - return &coordinatorpb.AttachFunctionResponse{ + return &coordinatorpb.AttachFunctionResponse{ + AttachedFunction: &coordinatorpb.AttachedFunction{ Id: attachedFunctionID.String(), - }, nil - } - - // ===== Phase 2 ===== - // This phase runs for both new attached functions and recovered incomplete attached functions - log.Debug("AttachFunction: Phase 2: doing initialization work", - zap.String("attached_function_id", attachedFunctionID.String())) - // Push initial schedule to heap service if enabled - if s.heapClient == nil { - return nil, common.ErrHeapServiceNotEnabled - } - - // Create schedule for the attached function - schedule := &coordinatorpb.Schedule{ - Triggerable: &coordinatorpb.Triggerable{ - PartitioningUuid: req.InputCollectionId, - SchedulingUuid: attachedFunctionID.String(), }, - NextScheduled: timestamppb.New(nextRun), - Nonce: lowestLiveNonce.String(), - } - - err = s.heapClient.Push(ctx, req.InputCollectionId, []*coordinatorpb.Schedule{schedule}) - if err != nil { - log.Error("AttachFunction: Phase 2: failed to push schedule to heap service", - zap.Error(err), - zap.String("attached_function_id", attachedFunctionID.String()), - zap.String("collection_id", req.InputCollectionId)) - return nil, err - } - - log.Debug("AttachFunction: Phase 2: pushed schedule to heap service", - 
zap.String("attached_function_id", attachedFunctionID.String()), - zap.String("collection_id", req.InputCollectionId)) - - // ===== Phase 3: Update lowest_live_nonce to complete initialization ===== - // No database fetch needed - we already have lowestLiveNonce and nextNonce from Phase 1/Recovery - err = s.catalog.metaDomain.AttachedFunctionDb(ctx).UpdateLowestLiveNonce(attachedFunctionID, lowestLiveNonce) - if err != nil { - log.Error("AttachFunction: Phase 3: failed to update lowest_live_nonce", zap.Error(err), zap.String("attached_function_id", attachedFunctionID.String()), zap.String("lowest_live_nonce", lowestLiveNonce.String())) - return nil, err - } - - log.Debug("AttachFunction: Phase 3: attached function initialization completed", - zap.String("attached_function_id", attachedFunctionID.String()), - zap.String("lowest_live_nonce", lowestLiveNonce.String()), - zap.String("next_nonce", nextNonce.String())) - - return &coordinatorpb.AttachFunctionResponse{ - Id: attachedFunctionID.String(), }, nil } @@ -852,75 +837,3 @@ func (s *Coordinator) CleanupExpiredPartialAttachedFunctions(ctx context.Context CleanedUpIds: cleanedAttachedFunctionIDStrings, }, nil } - -// GetSoftDeletedAttachedFunctions retrieves attached functions that are soft deleted and were updated before the cutoff time -func (s *Coordinator) GetSoftDeletedAttachedFunctions(ctx context.Context, req *coordinatorpb.GetSoftDeletedAttachedFunctionsRequest) (*coordinatorpb.GetSoftDeletedAttachedFunctionsResponse, error) { - log := log.With(zap.String("method", "GetSoftDeletedAttachedFunctions")) - - if req.CutoffTime == nil { - log.Error("GetSoftDeletedAttachedFunctions: cutoff_time is required") - return nil, status.Errorf(codes.InvalidArgument, "cutoff_time is required") - } - - if req.Limit <= 0 { - log.Error("GetSoftDeletedAttachedFunctions: limit must be greater than 0") - return nil, status.Errorf(codes.InvalidArgument, "limit must be greater than 0") - } - - cutoffTime := req.CutoffTime.AsTime() - attachedFunctions, err := s.catalog.metaDomain.AttachedFunctionDb(ctx).GetSoftDeletedAttachedFunctions(cutoffTime, req.Limit) - if err != nil { - log.Error("GetSoftDeletedAttachedFunctions: failed to get soft deleted attached functions", zap.Error(err)) - return nil, err - } - - // Convert to proto response - protoAttachedFunctions := make([]*coordinatorpb.AttachedFunction, len(attachedFunctions)) - for i, af := range attachedFunctions { - protoAttachedFunctions[i] = &coordinatorpb.AttachedFunction{ - Id: af.ID.String(), - Name: af.Name, - InputCollectionId: af.InputCollectionID, - OutputCollectionName: af.OutputCollectionName, - CompletionOffset: uint64(af.CompletionOffset), - MinRecordsForInvocation: uint64(af.MinRecordsForInvocation), - CreatedAt: uint64(af.CreatedAt.UnixMicro()), - UpdatedAt: uint64(af.UpdatedAt.UnixMicro()), - } - - protoAttachedFunctions[i].NextRunAt = uint64(af.NextRun.UnixMicro()) - if af.OutputCollectionID != nil { - protoAttachedFunctions[i].OutputCollectionId = proto.String(*af.OutputCollectionID) - } - } - - log.Info("GetSoftDeletedAttachedFunctions: completed successfully", - zap.Int("count", len(attachedFunctions))) - - return &coordinatorpb.GetSoftDeletedAttachedFunctionsResponse{ - AttachedFunctions: protoAttachedFunctions, - }, nil -} - -// FinishAttachedFunctionDeletion permanently deletes an attached function from the database (hard delete) -// This should only be called after the soft delete grace period has passed -func (s *Coordinator) FinishAttachedFunctionDeletion(ctx context.Context, 
req *coordinatorpb.FinishAttachedFunctionDeletionRequest) (*coordinatorpb.FinishAttachedFunctionDeletionResponse, error) { - log := log.With(zap.String("method", "FinishAttachedFunctionDeletion")) - - attachedFunctionID, err := uuid.Parse(req.AttachedFunctionId) - if err != nil { - log.Error("FinishAttachedFunctionDeletion: invalid attached_function_id", zap.Error(err)) - return nil, status.Errorf(codes.InvalidArgument, "invalid attached_function_id: %v", err) - } - - err = s.catalog.metaDomain.AttachedFunctionDb(ctx).HardDeleteAttachedFunction(attachedFunctionID) - if err != nil { - log.Error("FinishAttachedFunctionDeletion: failed to hard delete attached function", zap.Error(err)) - return nil, err - } - - log.Info("FinishAttachedFunctionDeletion: completed successfully", - zap.String("attached_function_id", attachedFunctionID.String())) - - return &coordinatorpb.FinishAttachedFunctionDeletionResponse{}, nil -} diff --git a/go/pkg/sysdb/grpc/collection_service.go b/go/pkg/sysdb/grpc/collection_service.go index d8af8404f54..3e97acc6a74 100644 --- a/go/pkg/sysdb/grpc/collection_service.go +++ b/go/pkg/sysdb/grpc/collection_service.go @@ -574,11 +574,15 @@ func (s *Server) FlushCollectionCompaction(ctx context.Context, req *coordinator } func (s *Server) FlushCollectionCompactionAndAttachedFunction(ctx context.Context, req *coordinatorpb.FlushCollectionCompactionAndAttachedFunctionRequest) (*coordinatorpb.FlushCollectionCompactionAndAttachedFunctionResponse, error) { - // Parse the flush compaction request (nested message) - flushReq := req.GetFlushCompaction() - if flushReq == nil { - log.Error("FlushCollectionCompactionAndAttachedFunction failed. flush_compaction is nil") - return nil, grpcutils.BuildInternalGrpcError("flush_compaction is required") + // Parse the repeated flush compaction requests + flushReqs := req.GetFlushCompactions() + if len(flushReqs) == 0 { + log.Error("FlushCollectionCompactionAndAttachedFunction failed. flush_compactions is empty") + return nil, grpcutils.BuildInternalGrpcError("at least one flush_compaction is required") + } + if len(flushReqs) > 2 { + log.Error("FlushCollectionCompactionAndAttachedFunction failed. too many flush_compactions", zap.Int("count", len(flushReqs))) + return nil, grpcutils.BuildInternalGrpcError("expected 1 or 2 flush_compactions") } // Parse attached function update info @@ -600,14 +604,6 @@ func (s *Server) FlushCollectionCompactionAndAttachedFunction(ctx context.Contex return nil, grpcutils.BuildInternalGrpcError("invalid run_nonce: " + err.Error()) } - // Parse collection and segment info (reuse logic from FlushCollectionCompaction) - collectionID, err := types.ToUniqueID(&flushReq.CollectionId) - err = grpcutils.BuildErrorForUUID(collectionID, "collection", err) - if err != nil { - log.Error("FlushCollectionCompactionAndAttachedFunction failed. 
error parsing collection id", zap.Error(err), zap.String("collection_id", flushReq.CollectionId)) - return nil, grpcutils.BuildInternalGrpcError(err.Error()) - } - // Validate completion_offset fits in int64 before storing in database if attachedFunctionUpdate.CompletionOffset > uint64(math.MaxInt64) { log.Error("FlushCollectionCompactionAndAttachedFunction: completion_offset too large", @@ -616,43 +612,56 @@ func (s *Server) FlushCollectionCompactionAndAttachedFunction(ctx context.Contex } completionOffsetSigned := int64(attachedFunctionUpdate.CompletionOffset) - segmentCompactionInfo := make([]*model.FlushSegmentCompaction, 0, len(flushReq.SegmentCompactionInfo)) - for _, flushSegmentCompaction := range flushReq.SegmentCompactionInfo { - segmentID, err := types.ToUniqueID(&flushSegmentCompaction.SegmentId) - err = grpcutils.BuildErrorForUUID(segmentID, "segment", err) + // Parse all flush requests into a slice + collectionCompactions := make([]*model.FlushCollectionCompaction, 0, len(flushReqs)) + + for _, flushReq := range flushReqs { + collectionID, err := types.ToUniqueID(&flushReq.CollectionId) + err = grpcutils.BuildErrorForUUID(collectionID, "collection", err) if err != nil { - log.Error("FlushCollectionCompactionAndAttachedFunction failed. error parsing segment id", zap.Error(err), zap.String("collection_id", flushReq.CollectionId)) + log.Error("FlushCollectionCompactionAndAttachedFunction failed. error parsing collection id", zap.Error(err), zap.String("collection_id", flushReq.CollectionId)) return nil, grpcutils.BuildInternalGrpcError(err.Error()) } - filePaths := make(map[string][]string) - for key, filePath := range flushSegmentCompaction.FilePaths { - filePaths[key] = filePath.Paths + + segmentCompactionInfo := make([]*model.FlushSegmentCompaction, 0, len(flushReq.SegmentCompactionInfo)) + for _, flushSegmentCompaction := range flushReq.SegmentCompactionInfo { + segmentID, err := types.ToUniqueID(&flushSegmentCompaction.SegmentId) + err = grpcutils.BuildErrorForUUID(segmentID, "segment", err) + if err != nil { + log.Error("FlushCollectionCompactionAndAttachedFunction failed. 
error parsing segment id", zap.Error(err), zap.String("collection_id", flushReq.CollectionId)) + return nil, grpcutils.BuildInternalGrpcError(err.Error()) + } + filePaths := make(map[string][]string) + for key, filePath := range flushSegmentCompaction.FilePaths { + filePaths[key] = filePath.Paths + } + segmentCompactionInfo = append(segmentCompactionInfo, &model.FlushSegmentCompaction{ + ID: segmentID, + FilePaths: filePaths, + }) } - segmentCompactionInfo = append(segmentCompactionInfo, &model.FlushSegmentCompaction{ - ID: segmentID, - FilePaths: filePaths, - }) - } - flushCollectionCompaction := &model.FlushCollectionCompaction{ - ID: collectionID, - TenantID: flushReq.TenantId, - LogPosition: flushReq.LogPosition, - CurrentCollectionVersion: flushReq.CollectionVersion, - FlushSegmentCompactions: segmentCompactionInfo, - TotalRecordsPostCompaction: flushReq.TotalRecordsPostCompaction, - SizeBytesPostCompaction: flushReq.SizeBytesPostCompaction, + collectionCompactions = append(collectionCompactions, &model.FlushCollectionCompaction{ + ID: collectionID, + TenantID: flushReq.TenantId, + LogPosition: flushReq.LogPosition, + CurrentCollectionVersion: flushReq.CollectionVersion, + FlushSegmentCompactions: segmentCompactionInfo, + TotalRecordsPostCompaction: flushReq.TotalRecordsPostCompaction, + SizeBytesPostCompaction: flushReq.SizeBytesPostCompaction, + }) } - flushCollectionInfo, err := s.coordinator.FlushCollectionCompactionAndAttachedFunction( + // Call the Extended coordinator function to handle all collections + extendedFlushInfo, err := s.coordinator.FlushCollectionCompactionAndAttachedFunctionExtended( ctx, - flushCollectionCompaction, + collectionCompactions, attachedFunctionID, runNonce, completionOffsetSigned, ) if err != nil { - log.Error("FlushCollectionCompactionAndAttachedFunction failed", zap.Error(err), zap.String("collection_id", flushReq.CollectionId), zap.String("attached_function_id", attachedFunctionUpdate.Id)) + log.Error("FlushCollectionCompactionAndAttachedFunction failed", zap.Error(err), zap.String("attached_function_id", attachedFunctionUpdate.Id)) if err == common.ErrCollectionSoftDeleted { return nil, grpcutils.BuildFailedPreconditionGrpcError(err.Error()) } @@ -662,27 +671,41 @@ func (s *Server) FlushCollectionCompactionAndAttachedFunction(ctx context.Contex return nil, grpcutils.BuildInternalGrpcError(err.Error()) } + // Build response with repeated collections res := &coordinatorpb.FlushCollectionCompactionAndAttachedFunctionResponse{ - CollectionId: flushCollectionInfo.ID, - CollectionVersion: flushCollectionInfo.CollectionVersion, - LastCompactionTime: flushCollectionInfo.TenantLastCompactionTime, + Collections: make([]*coordinatorpb.CollectionCompactionInfo, 0, len(extendedFlushInfo.Collections)), } - // Populate attached function fields with authoritative values from database - if flushCollectionInfo.AttachedFunctionNextNonce != nil { - res.NextNonce = flushCollectionInfo.AttachedFunctionNextNonce.String() - } - if flushCollectionInfo.AttachedFunctionNextRun != nil { - res.NextRun = timestamppb.New(*flushCollectionInfo.AttachedFunctionNextRun) + for _, flushInfo := range extendedFlushInfo.Collections { + res.Collections = append(res.Collections, &coordinatorpb.CollectionCompactionInfo{ + CollectionId: flushInfo.ID, + CollectionVersion: flushInfo.CollectionVersion, + LastCompactionTime: flushInfo.TenantLastCompactionTime, + }) } - if flushCollectionInfo.AttachedFunctionCompletionOffset != nil { - // Validate completion_offset is non-negative before 
converting to uint64 - if *flushCollectionInfo.AttachedFunctionCompletionOffset < 0 { - log.Error("FlushCollectionCompactionAndAttachedFunction: invalid completion_offset", - zap.Int64("completion_offset", *flushCollectionInfo.AttachedFunctionCompletionOffset)) - return nil, grpcutils.BuildInternalGrpcError("attached function has invalid completion_offset") + + // Populate attached function state with authoritative values from database (use first collection) + if len(extendedFlushInfo.Collections) > 0 { + firstFlushInfo := extendedFlushInfo.Collections[0] + attachedFunctionState := &coordinatorpb.AttachedFunctionState{} + + if firstFlushInfo.AttachedFunctionNextNonce != nil { + attachedFunctionState.NextNonce = firstFlushInfo.AttachedFunctionNextNonce.String() } - res.CompletionOffset = uint64(*flushCollectionInfo.AttachedFunctionCompletionOffset) + if firstFlushInfo.AttachedFunctionNextRun != nil { + attachedFunctionState.NextRun = timestamppb.New(*firstFlushInfo.AttachedFunctionNextRun) + } + if firstFlushInfo.AttachedFunctionCompletionOffset != nil { + // Validate completion_offset is non-negative before converting to uint64 + if *firstFlushInfo.AttachedFunctionCompletionOffset < 0 { + log.Error("FlushCollectionCompactionAndAttachedFunction: invalid completion_offset", + zap.Int64("completion_offset", *firstFlushInfo.AttachedFunctionCompletionOffset)) + return nil, grpcutils.BuildInternalGrpcError("attached function has invalid completion_offset") + } + attachedFunctionState.CompletionOffset = uint64(*firstFlushInfo.AttachedFunctionCompletionOffset) + } + + res.AttachedFunctionState = attachedFunctionState } return res, nil diff --git a/go/pkg/sysdb/grpc/task_service.go b/go/pkg/sysdb/grpc/task_service.go index dc3267c24cc..31e7e12adad 100644 --- a/go/pkg/sysdb/grpc/task_service.go +++ b/go/pkg/sysdb/grpc/task_service.go @@ -109,9 +109,10 @@ func (s *Server) FinishAttachedFunction(ctx context.Context, req *coordinatorpb. 
res, err := s.coordinator.FinishAttachedFunction(ctx, req) if err != nil { log.Error("FinishAttachedFunction failed", zap.Error(err)) - return nil, err + return nil, grpcutils.BuildInternalGrpcError(err.Error()) } + log.Info("FinishAttachedFunction succeeded", zap.String("id", req.Id)) return res, nil } @@ -151,29 +152,3 @@ func (s *Server) CleanupExpiredPartialAttachedFunctions(ctx context.Context, req log.Info("CleanupExpiredPartialAttachedFunctions succeeded", zap.Uint64("cleaned_up_count", res.CleanedUpCount)) return res, nil } - -func (s *Server) GetSoftDeletedAttachedFunctions(ctx context.Context, req *coordinatorpb.GetSoftDeletedAttachedFunctionsRequest) (*coordinatorpb.GetSoftDeletedAttachedFunctionsResponse, error) { - log.Info("GetSoftDeletedAttachedFunctions", zap.Time("cutoff_time", req.CutoffTime.AsTime()), zap.Int32("limit", req.Limit)) - - res, err := s.coordinator.GetSoftDeletedAttachedFunctions(ctx, req) - if err != nil { - log.Error("GetSoftDeletedAttachedFunctions failed", zap.Error(err)) - return nil, grpcutils.BuildInternalGrpcError(err.Error()) - } - - log.Info("GetSoftDeletedAttachedFunctions succeeded", zap.Int("count", len(res.AttachedFunctions))) - return res, nil -} - -func (s *Server) FinishAttachedFunctionDeletion(ctx context.Context, req *coordinatorpb.FinishAttachedFunctionDeletionRequest) (*coordinatorpb.FinishAttachedFunctionDeletionResponse, error) { - log.Info("FinishAttachedFunctionDeletion", zap.String("id", req.AttachedFunctionId)) - - res, err := s.coordinator.FinishAttachedFunctionDeletion(ctx, req) - if err != nil { - log.Error("FinishAttachedFunctionDeletion failed", zap.Error(err)) - return nil, grpcutils.BuildInternalGrpcError(err.Error()) - } - - log.Info("FinishAttachedFunctionDeletion succeeded", zap.String("id", req.AttachedFunctionId)) - return res, nil -} diff --git a/go/pkg/sysdb/metastore/db/dao/collection.go b/go/pkg/sysdb/metastore/db/dao/collection.go index e806ba53478..b56bad5cabd 100644 --- a/go/pkg/sysdb/metastore/db/dao/collection.go +++ b/go/pkg/sysdb/metastore/db/dao/collection.go @@ -586,6 +586,7 @@ func (s *collectionDb) UpdateLogPositionVersionTotalRecordsAndLogicalSize(collec } if collection.Version < currentCollectionVersion { // this should not happen, potentially a bug + log.Error("collection version is less than current collection version", zap.String("collectionID", collectionID), zap.Int32("currentCollectionVersion", currentCollectionVersion), zap.Int32("collectionVersion", collection.Version)) return 0, common.ErrCollectionVersionInvalid } diff --git a/go/pkg/sysdb/metastore/db/dao/task.go b/go/pkg/sysdb/metastore/db/dao/task.go index b414e79a87d..d2ad5e5cd1e 100644 --- a/go/pkg/sysdb/metastore/db/dao/task.go +++ b/go/pkg/sysdb/metastore/db/dao/task.go @@ -231,7 +231,6 @@ func (s *attachedFunctionDb) UpdateCompletionOffset(id uuid.UUID, runNonce uuid. result := s.db.Model(&dbmodel.AttachedFunction{}). Where("id = ?", id). Where("is_deleted = false"). - Where("lowest_live_nonce = ?", runNonce). 
// Ensure we're updating the correct nonce UpdateColumns(map[string]interface{}{ "completion_offset": completionOffset, "last_run": now, diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index c6fdbae75c2..fd6363e4401 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -330,18 +330,25 @@ message AttachedFunctionUpdateInfo { // Combined request to flush collection compaction and update attached function atomically in a single transaction message FlushCollectionCompactionAndAttachedFunctionRequest { - FlushCollectionCompactionRequest flush_compaction = 1; + repeated FlushCollectionCompactionRequest flush_compactions = 1; AttachedFunctionUpdateInfo attached_function_update = 2; } -message FlushCollectionCompactionAndAttachedFunctionResponse { +message CollectionCompactionInfo { string collection_id = 1; int32 collection_version = 2; int64 last_compaction_time = 3; - // Updated attached function fields from database (authoritative) - string next_nonce = 4; - google.protobuf.Timestamp next_run = 5; - uint64 completion_offset = 6; +} + +message AttachedFunctionState { + string next_nonce = 1; + google.protobuf.Timestamp next_run = 2; + uint64 completion_offset = 3; +} + +message FlushCollectionCompactionAndAttachedFunctionResponse { + repeated CollectionCompactionInfo collections = 1; + AttachedFunctionState attached_function_state = 2; } // Used for serializing contents in collection version history file. @@ -561,7 +568,7 @@ message AttachFunctionRequest { } message AttachFunctionResponse { - string id = 1; + AttachedFunction attached_function = 1; } message CreateOutputCollectionForAttachedFunctionRequest {
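The response is now split into one CollectionCompactionInfo per flushed collection plus a separate AttachedFunctionState message. A minimal sketch of consuming the new shape, assuming the prost-generated Rust types for the messages above (the helper name `summarize` is illustrative, not part of this PR):

```rust
use chroma_types::chroma_proto::FlushCollectionCompactionAndAttachedFunctionResponse;

fn summarize(resp: &FlushCollectionCompactionAndAttachedFunctionResponse) {
    // One entry per flushed collection, replacing the old single-collection fields.
    for info in &resp.collections {
        println!(
            "collection {} -> version {}, last compaction at {}",
            info.collection_id, info.collection_version, info.last_compaction_time
        );
    }
    // Attached-function state now travels in its own optional message.
    if let Some(state) = &resp.attached_function_state {
        println!("completion_offset = {}", state.completion_offset);
    }
}
```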
diff --git a/rust/sysdb/src/bin/chroma-task-manager.rs b/rust/sysdb/src/bin/chroma-task-manager.rs index c74231494d9..23681e7c2b8 100644 --- a/rust/sysdb/src/bin/chroma-task-manager.rs +++ b/rust/sysdb/src/bin/chroma-task-manager.rs @@ -148,7 +148,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { }; let response = client.attach_function(request).await?; - println!("Attached Function created: {}", response.into_inner().id); + let attached_function = response + .into_inner() + .attached_function + .ok_or("Server did not return attached function")?; + println!("Attached Function created: {}", attached_function.id); } Command::GetAttachedFunction { input_collection_id, @@ -160,7 +164,10 @@ }; let response = client.get_attached_function_by_name(request).await?; - let attached_function = response.into_inner().attached_function.unwrap(); + let attached_function = response + .into_inner() + .attached_function + .ok_or("Server did not return attached function")?; println!("Attached Function ID: {:?}", attached_function.id); println!("Name: {:?}", attached_function.name); diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index 515a1a07466..2fdf48aeb64 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -10,7 +10,7 @@ use chroma_types::chroma_proto::AdvanceAttachedFunctionRequest; use chroma_types::chroma_proto::FinishAttachedFunctionRequest; use chroma_types::chroma_proto::VersionListForCollection; use chroma_types::{ - chroma_proto, chroma_proto::CollectionVersionInfo, CollectionAndSegments, + chroma_proto, chroma_proto::CollectionVersionInfo, CollectionAndSegments, CollectionFlushInfo, CollectionMetadataUpdate, CountCollectionsError, CreateCollectionError, CreateDatabaseError, CreateDatabaseResponse, CreateTenantError, CreateTenantResponse, Database, DeleteCollectionError, DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionByCrnError, @@ -634,30 +634,13 @@ impl SysDb { #[allow(clippy::too_many_arguments)] pub async fn flush_compaction_and_attached_function( &mut self, - tenant_id: String, - collection_id: CollectionUuid, - log_position: i64, - collection_version: i32, - segment_flush_info: Arc<[SegmentFlushInfo]>, - total_records_post_compaction: u64, - size_bytes_post_compaction: u64, - schema: Option<Schema>, + collections: Vec<CollectionFlushInfo>, attached_function_update: AttachedFunctionUpdateInfo, ) -> Result<FlushCompactionAndAttachedFunctionResponse, FlushCompactionError> { match self { SysDb::Grpc(grpc) => { - grpc.flush_compaction_and_attached_function( - tenant_id, - collection_id, - log_position, - collection_version, - segment_flush_info, - total_records_post_compaction, - size_bytes_post_compaction, - schema, - attached_function_update, - ) - .await + grpc.flush_compaction_and_attached_function(collections, attached_function_update) + .await } SysDb::Sqlite(_) => todo!(), SysDb::Test(_) => todo!(), @@ -1702,54 +1685,44 @@ impl GrpcSysDb { } } - #[allow(clippy::too_many_arguments)] async fn flush_compaction_and_attached_function( &mut self, - tenant_id: String, - collection_id: CollectionUuid, - log_position: i64, - collection_version: i32, - segment_flush_info: Arc<[SegmentFlushInfo]>, - total_records_post_compaction: u64, - size_bytes_post_compaction: u64, - schema: Option<Schema>, + collections: Vec<CollectionFlushInfo>, attached_function_update: AttachedFunctionUpdateInfo, ) -> Result<FlushCompactionAndAttachedFunctionResponse, FlushCompactionError> { - let segment_compaction_info = segment_flush_info + // Process all collections into flush compaction requests + let mut flush_compactions = Vec::with_capacity(collections.len()); + + for collection in collections { + let segment_compaction_info = collection + .segment_flush_info .iter() .map(|segment_flush_info| segment_flush_info.try_into()) .collect::<Result<Vec<chroma_proto::FlushSegmentCompactionInfo>, SegmentFlushInfoConversionError, >>()?; - - let segment_compaction_info = match segment_compaction_info { - Ok(segment_compaction_info) => segment_compaction_info, - Err(e) => { - return Err(FlushCompactionError::SegmentFlushInfoConversionError(e)); - } - }; + - let schema_str = schema.and_then(|s| { - serde_json::to_string(&s).ok().or_else(|| { - tracing::error!( - "Failed to serialize schema for flush_compaction_and_attached_function" - ); - None - }) - }); + let schema_str = collection.schema.and_then(|s| { + serde_json::to_string(&s).ok().or_else(|| { + tracing::error!( + "Failed to serialize schema for flush_compaction_and_attached_function" + ); + None + }) + }); - let flush_compaction = Some(chroma_proto::FlushCollectionCompactionRequest { - tenant_id, - collection_id: collection_id.0.to_string(), - log_position, - collection_version, - segment_compaction_info, - total_records_post_compaction, - size_bytes_post_compaction, - schema_str, - }); + flush_compactions.push(chroma_proto::FlushCollectionCompactionRequest { + tenant_id: collection.tenant_id, + collection_id: collection.collection_id.0.to_string(), + log_position: collection.log_position, + collection_version: collection.collection_version, + segment_compaction_info, + total_records_post_compaction: collection.total_records_post_compaction, + size_bytes_post_compaction: collection.size_bytes_post_compaction, + schema_str, + }); + } let attached_function_update_proto = Some(chroma_proto::AttachedFunctionUpdateInfo { id: attached_function_update.attached_function_id.0.to_string(), @@ -1760,7 +1733,7 @@ }); let req = chroma_proto::FlushCollectionCompactionAndAttachedFunctionRequest { - flush_compaction, + 
flush_compactions, attached_function_update: attached_function_update_proto, }; @@ -1943,10 +1916,15 @@ impl GrpcSysDb { let response = self.client.attach_function(req).await?.into_inner(); // Parse the returned attached_function_id - this should always succeed since the server generated it // If this fails, it indicates a serious server bug or protocol corruption + let attached_function = response.attached_function.ok_or_else(|| { + tracing::error!("Server did not return attached function in response"); + AttachFunctionError::ServerReturnedInvalidData + })?; + let attached_function_id = chroma_types::AttachedFunctionUuid( - uuid::Uuid::parse_str(&response.id).map_err(|e| { + uuid::Uuid::parse_str(&attached_function.id).map_err(|e| { tracing::error!( - attached_function_id = %response.id, + attached_function_id = %attached_function.id, error = %e, "Server returned invalid attached_function_id UUID - attached function was created but response is corrupt" ); 
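The `.unwrap()` removals above all apply the same defensive pattern: a missing optional proto field becomes a typed error instead of a panic. A generic sketch of that pattern (the helper `require` is illustrative, not part of this PR):

```rust
// Turn a missing optional proto field into an error instead of panicking.
fn require<T>(field: Option<T>, name: &str) -> Result<T, String> {
    field.ok_or_else(|| format!("server did not return required field `{name}`"))
}
```

A caller would then write `let attached_function = require(response.attached_function, "attached_function")?;` rather than calling `.unwrap()` on the field.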
diff --git a/rust/types/src/flush.rs b/rust/types/src/flush.rs index 852f4312139..9d8ebc0c2f1 100644 --- a/rust/types/src/flush.rs +++ b/rust/types/src/flush.rs @@ -1,4 +1,4 @@ -use super::{AttachedFunctionUuid, CollectionUuid, ConversionError}; +use super::{AttachedFunctionUuid, CollectionUuid, ConversionError, Schema}; use crate::{ chroma_proto::{ FilePaths, FlushCollectionCompactionAndAttachedFunctionResponse, FlushSegmentCompactionInfo, @@ -6,7 +6,7 @@ SegmentUuid, }; use chroma_error::{ChromaError, ErrorCodes}; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use uuid::Uuid; @@ -16,6 +16,18 @@ pub struct SegmentFlushInfo { pub file_paths: HashMap<String, Vec<String>>, } +#[derive(Debug, Clone)] +pub struct CollectionFlushInfo { + pub tenant_id: String, + pub collection_id: CollectionUuid, + pub log_position: i64, + pub collection_version: i32, + pub segment_flush_info: Arc<[SegmentFlushInfo]>, + pub total_records_post_compaction: u64, + pub size_bytes_post_compaction: u64, + pub schema: Option<Schema>, +} + #[derive(Debug, Clone)] pub struct AttachedFunctionUpdateInfo { pub attached_function_id: AttachedFunctionUuid, @@ -110,10 +122,15 @@ pub struct FlushCompactionResponse { } #[derive(Debug)] -pub struct FlushCompactionAndAttachedFunctionResponse { +pub struct CollectionCompactionInfo { pub collection_id: CollectionUuid, pub collection_version: i32, pub last_compaction_time: i64, +} + +#[derive(Debug)] +pub struct FlushCompactionAndAttachedFunctionResponse { + pub collections: Vec<CollectionCompactionInfo>, // Completion offset updated during register pub completion_offset: u64, // NOTE: next_nonce and next_run are no longer returned @@ -140,12 +157,15 @@ impl TryFrom<FlushCollectionCompactionAndAttachedFunctionResponse> for FlushComp fn try_from( value: FlushCollectionCompactionAndAttachedFunctionResponse, ) -> Result<Self, Self::Error> { - let id = Uuid::parse_str(&value.collection_id) + // Use first collection for backward compatibility + let first_collection = value.collections.first() + .ok_or(FlushCompactionResponseConversionError::MissingCollections)?; + let id = Uuid::parse_str(&first_collection.collection_id) .map_err(|_| FlushCompactionResponseConversionError::InvalidUuid)?; Ok(FlushCompactionResponse { collection_id: CollectionUuid(id), - collection_version: value.collection_version, - last_compaction_time: value.last_compaction_time, + collection_version: first_collection.collection_version, + last_compaction_time: first_collection.last_compaction_time, }) } } @@ -158,18 +178,30 @@ impl TryFrom<FlushCollectionCompactionAndAttachedFunctionResponse> fn try_from( value: FlushCollectionCompactionAndAttachedFunctionResponse, ) -> Result<Self, Self::Error> { - let id = Uuid::parse_str(&value.collection_id) - .map_err(|_| FlushCompactionResponseConversionError::InvalidUuid)?; + // Parse all collections from the repeated field + let mut collections = Vec::with_capacity(value.collections.len()); + for collection in value.collections { + let id = Uuid::parse_str(&collection.collection_id) + .map_err(|_| FlushCompactionResponseConversionError::InvalidUuid)?; + collections.push(CollectionCompactionInfo { + collection_id: CollectionUuid(id), + collection_version: collection.collection_version, + last_compaction_time: collection.last_compaction_time, + }); + } - // Note: next_nonce and next_run are no longer populated by the server + // Extract completion_offset from attached_function_state + // Note: next_nonce and next_run are no longer used by the client // They were already set by PrepareAttachedFunction via advance_attached_function() - // We only use completion_offset from the response + let completion_offset = value + .attached_function_state + .as_ref() + .map(|state| state.completion_offset) + .unwrap_or(0); Ok(FlushCompactionAndAttachedFunctionResponse { - collection_id: CollectionUuid(id), - collection_version: value.collection_version, - last_compaction_time: value.last_compaction_time, - completion_offset: value.completion_offset, + collections, + completion_offset, }) } } @@ -186,6 +218,8 @@ pub enum FlushCompactionResponseConversionError { MissingNextRun, #[error("Invalid timestamp format")] InvalidTimestamp, + #[error("Missing collections in response")] + MissingCollections, } impl ChromaError for FlushCompactionResponseConversionError { @@ -197,6 +231,7 @@ impl ChromaError for FlushCompactionResponseConversionError { } FlushCompactionResponseConversionError::MissingNextRun => ErrorCodes::InvalidArgument, FlushCompactionResponseConversionError::InvalidTimestamp => ErrorCodes::InvalidArgument, + FlushCompactionResponseConversionError::MissingCollections => ErrorCodes::InvalidArgument, FlushCompactionResponseConversionError::DecodeError(e) => e.code(), } } 
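Since `CollectionFlushInfo` (introduced in the flush.rs hunk above) is now the per-collection unit handed to `SysDb::flush_compaction_and_attached_function`, here is a minimal sketch of building one entry. Field values are placeholders and the helper `flush_one` is illustrative, not part of this PR:

```rust
use std::sync::Arc;
use chroma_types::{CollectionFlushInfo, CollectionUuid};

fn flush_one(tenant_id: String, collection_id: CollectionUuid) -> CollectionFlushInfo {
    CollectionFlushInfo {
        tenant_id,
        collection_id,
        log_position: 0,       // placeholder: post-flush log position
        collection_version: 0, // placeholder: version being flushed
        segment_flush_info: Arc::new([]), // empty: no segment files in this sketch
        total_records_post_compaction: 0,
        size_bytes_post_compaction: 0,
        schema: None,
    }
}
```

Batching several of these into a `Vec` alongside one `AttachedFunctionUpdateInfo` mirrors the request assembly in the sysdb.rs hunk earlier in this diff.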
diff --git a/rust/types/src/task.rs b/rust/types/src/task.rs index 43188ed4aa0..9fc055ba2b7 100644 --- a/rust/types/src/task.rs +++ b/rust/types/src/task.rs @@ -168,3 +168,148 @@ pub enum ScheduleEntryConversionError { #[error("Invalid UUID for field: {0}")] InvalidUuid(String), } + +#[derive(Debug, thiserror::Error)] +pub enum AttachedFunctionConversionError { + #[error("Invalid UUID: {0}")] + InvalidUuid(String), +} + +fn prost_struct_to_json_string( + prost_struct: &prost_types::Struct, +) -> Result<String, serde_json::Error> { + use prost_types::value::Kind; + + let mut map = serde_json::Map::new(); + for (key, value) in &prost_struct.fields { + if let Some(kind) = &value.kind { + let json_value = match kind { + Kind::NullValue(_) => serde_json::Value::Null, + Kind::NumberValue(n) => serde_json::Value::Number( + serde_json::Number::from_f64(*n).unwrap_or_else(|| serde_json::Number::from(0)), + ), + Kind::StringValue(s) => serde_json::Value::String(s.clone()), + Kind::BoolValue(b) => serde_json::Value::Bool(*b), + Kind::StructValue(s) => serde_json::Value::Object( + prost_struct_to_json_string(s)? + .parse::<serde_json::Value>()? + .as_object() + .unwrap() + .clone(), + ), + Kind::ListValue(list) => serde_json::Value::Array( + list.values + .iter() + .map(|v| { + if let Some(kind) = &v.kind { + match kind { + Kind::NullValue(_) => serde_json::Value::Null, + Kind::NumberValue(n) => serde_json::Value::Number( + serde_json::Number::from_f64(*n) + .unwrap_or_else(|| serde_json::Number::from(0)), + ), + Kind::StringValue(s) => serde_json::Value::String(s.clone()), + Kind::BoolValue(b) => serde_json::Value::Bool(*b), + _ => serde_json::Value::Null, // Simplified for now + } + } else { + serde_json::Value::Null + } + }) + .collect(), + ), + }; + map.insert(key.clone(), json_value); + } + } + + serde_json::to_string(&serde_json::Value::Object(map)) +} + +impl TryFrom<crate::chroma_proto::AttachedFunction> for AttachedFunction { + type Error = AttachedFunctionConversionError; + + fn try_from( + attached_function: crate::chroma_proto::AttachedFunction, + ) -> Result<Self, Self::Error> { + // Parse attached_function_id + let attached_function_id = attached_function + .id + .parse::() + .map_err(|_| { + AttachedFunctionConversionError::InvalidUuid("attached_function_id".to_string()) + })?; + + // Parse function_id + let function_id = attached_function + .function_id + .parse::() + .map_err(|_| AttachedFunctionConversionError::InvalidUuid("function_id".to_string()))?; + + // Parse input_collection_id + let input_collection_id = attached_function + .input_collection_id + .parse::() + .map_err(|_| { + AttachedFunctionConversionError::InvalidUuid("input_collection_id".to_string()) + })?; + + // Parse output_collection_id if available + let output_collection_id = attached_function + .output_collection_id + .map(|id| id.parse::()) + .transpose() + .map_err(|_| { + AttachedFunctionConversionError::InvalidUuid("output_collection_id".to_string()) + })?; + + // Parse params if available + let params = attached_function + .params + .map(|p| prost_struct_to_json_string(&p)) + .transpose() + .map_err(|_| AttachedFunctionConversionError::InvalidUuid("params".to_string()))?; + + // Parse timestamps + let created_at = std::time::SystemTime::UNIX_EPOCH + + std::time::Duration::from_micros(attached_function.created_at); + let updated_at = std::time::SystemTime::UNIX_EPOCH + + std::time::Duration::from_micros(attached_function.updated_at); + let next_run = std::time::SystemTime::UNIX_EPOCH + + std::time::Duration::from_micros(attached_function.next_run_at); + + // Parse nonces + let next_nonce = attached_function + .next_nonce + .parse::() + .map_err(|_| AttachedFunctionConversionError::InvalidUuid("next_nonce".to_string()))?; + let lowest_live_nonce = attached_function + .lowest_live_nonce + .map(|nonce| nonce.parse::()) + .transpose() + .map_err(|_| { + AttachedFunctionConversionError::InvalidUuid("lowest_live_nonce".to_string()) + })?; + + Ok(AttachedFunction { + id: attached_function_id, + name: attached_function.name, + function_id, + input_collection_id, + output_collection_name: attached_function.output_collection_name, + output_collection_id, + params, + tenant_id: attached_function.tenant_id, + database_id: attached_function.database_id, + last_run: None, // Not available in proto + next_run, + completion_offset: attached_function.completion_offset, + min_records_for_invocation: attached_function.min_records_for_invocation, + is_deleted: false, // Not available in proto, would need to be fetched separately + created_at, + updated_at, + next_nonce, + lowest_live_nonce, + }) + } +}
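The statistics.rs changes below switch `StatisticsFunction::observe` from raw `LogRecord`s to hydrated materialized records. A sketch of what a custom accumulator looks like against the new signature; `TouchCount` is illustrative (it simply mirrors `CounterFunction`) and is not part of this PR:

```rust
use chroma_segment::types::HydratedMaterializedLogRecord;
use chroma_types::UpdateMetadataValue;

// Counts how many hydrated records were observed, like CounterFunction below.
#[derive(Debug, Default)]
struct TouchCount {
    acc: i64,
}

impl TouchCount {
    fn observe(&mut self, _record: &HydratedMaterializedLogRecord<'_, '_>) {
        self.acc = self.acc.saturating_add(1);
    }

    fn output(&self) -> UpdateMetadataValue {
        UpdateMetadataValue::Int(self.acc)
    }
}
```

diff --git a/rust/worker/src/execution/functions/statistics.rs b/rust/worker/src/execution/functions/statistics.rs 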
index 5d5a2ea4bdb..990d154d39c 100644 --- a/rust/worker/src/execution/functions/statistics.rs +++ b/rust/worker/src/execution/functions/statistics.rs @@ -11,6 +11,7 @@ use std::hash::{Hash, Hasher}; use async_trait::async_trait; use chroma_error::ChromaError; use chroma_segment::blockfile_record::RecordSegmentReader; +use chroma_segment::types::HydratedMaterializedLogRecord; use chroma_types::{ Chunk, LogRecord, MetadataValue, Operation, OperationRecord, UpdateMetadataValue, }; @@ -25,7 +26,7 @@ pub trait StatisticsFunctionFactory: std::fmt::Debug + Send + Sync { /// Accumulate statistics. Must be an associative and commutative over a sequence of `observe` calls. pub trait StatisticsFunction: std::fmt::Debug + Send { - fn observe(&mut self, log_record: &LogRecord); + fn observe(&mut self, hydrated_record: &HydratedMaterializedLogRecord<'_, '_>); fn output(&self) -> UpdateMetadataValue; } @@ -44,7 +45,7 @@ pub struct CounterFunction { } impl StatisticsFunction for CounterFunction { - fn observe(&mut self, _: &LogRecord) { + fn observe(&mut self, _: &HydratedMaterializedLogRecord<'_, '_>) { self.acc = self.acc.saturating_add(1); } @@ -173,28 +174,21 @@ pub struct StatisticsFunctionExecutor(pub Box<dyn StatisticsFunctionFactory>); impl AttachedFunctionExecutor for StatisticsFunctionExecutor { async fn execute( &self, - input_records: Chunk<LogRecord>, + input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>, output_reader: Option<&RecordSegmentReader<'_>>, ) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>> { let mut counts: HashMap<String, HashMap<StatisticsValue, Box<dyn StatisticsFunction>>> = HashMap::default(); - for (log_record, _) in input_records.iter() { - if matches!(log_record.record.operation, Operation::Delete) { - continue; - } - - if let Some(update_metadata) = log_record.record.metadata.as_ref() { - for (key, update_value) in update_metadata.iter() { - let value: Option<MetadataValue> = update_value.try_into().ok(); - if let Some(value) = value { - let inner_map = counts.entry(key.clone()).or_default(); - for stats_value in StatisticsValue::from_metadata_value(&value) { - inner_map - .entry(stats_value) - .or_insert_with(|| self.0.create()) - .observe(log_record); - } - } + for (hydrated_record, _index) in input_records.iter() { + // Use merged_metadata to get the metadata from the hydrated record + let metadata = hydrated_record.merged_metadata(); + for (key, value) in metadata.iter() { + let inner_map = counts.entry(key.clone()).or_default(); + for stats_value in StatisticsValue::from_metadata_value(value) { + inner_map + .entry(stats_value) + .or_insert_with(|| self.0.create()) + .observe(hydrated_record); } } } @@ -257,821 +251,821 @@ impl AttachedFunctionExecutor for StatisticsFunctionExecutor { } } -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use chroma_segment::{blockfile_record::RecordSegmentReader, test::TestDistributedSegment}; - use chroma_types::{ - Chunk, LogRecord, Operation, OperationRecord, SparseVector, UpdateMetadata, - UpdateMetadataValue, - }; - - use super::*; - - fn build_record(id: &str, metadata: HashMap<String, UpdateMetadataValue>) -> LogRecord { - build_record_with_operation(id, Operation::Upsert, metadata) - } - - fn build_record_with_operation( - id: &str, - operation: Operation, - metadata: HashMap<String, UpdateMetadataValue>, - ) -> LogRecord { - LogRecord { - log_offset: 0, - record: OperationRecord { - id: id.to_string(), - embedding: None, - encoding: None, - metadata: Some(metadata), - document: None, - operation, - }, - } - } - - fn extract_metadata_tuple(metadata: &UpdateMetadata) -> (i64, String, String, String) { - let count = match metadata.get("count") { - Some(UpdateMetadataValue::Int(value)) => *value, - other => panic!("unexpected count metadata: 
{other:?}"), - }; - let key = match metadata.get("key") { - Some(UpdateMetadataValue::Str(value)) => value.clone(), - other => panic!("unexpected key metadata: {other:?}"), - }; - let value_type = match metadata.get("type") { - Some(UpdateMetadataValue::Str(value)) => value.clone(), - other => panic!("unexpected type metadata: {other:?}"), - }; - let value = match metadata.get("value") { - Some(UpdateMetadataValue::Str(value)) => value.clone(), - other => panic!("unexpected value metadata: {other:?}"), - }; - (count, key, value_type, value) - } - - fn collect_statistics_map( - output: &Chunk, - ) -> HashMap { - let mut actual: HashMap = HashMap::new(); - for (log_record, _) in output.iter() { - let record = &log_record.record; - assert_eq!(record.operation, Operation::Upsert); - assert_eq!(record.embedding.as_deref(), Some(&[0.0][..])); - - let metadata = record - .metadata - .as_ref() - .expect("statistics executor always sets metadata"); - - actual.insert(record.id.clone(), extract_metadata_tuple(metadata)); - } - actual - } - - fn build_statistics_metadata( - count: i64, - key: &str, - value_type: &str, - value: &str, - ) -> UpdateMetadata { - HashMap::from([ - ("count".to_string(), UpdateMetadataValue::Int(count)), - ("key".to_string(), UpdateMetadataValue::Str(key.to_string())), - ( - "type".to_string(), - UpdateMetadataValue::Str(value_type.to_string()), - ), - ( - "value".to_string(), - UpdateMetadataValue::Str(value.to_string()), - ), - ]) - } - - fn build_statistics_record(id: &str, metadata: UpdateMetadata, document: &str) -> LogRecord { - LogRecord { - log_offset: 0, - record: OperationRecord { - id: id.to_string(), - embedding: Some(vec![0.0]), - encoding: None, - metadata: Some(metadata), - document: Some(document.to_string()), - operation: Operation::Upsert, - }, - } - } - - fn build_complete_statistics_record( - key: &str, - value: &str, - value_type: &str, - type_prefix: &str, - count: i64, - ) -> LogRecord { - let metadata = build_statistics_metadata(count, key, value_type, value); - let id = format!("{key}::{type_prefix}:{value}"); - let document = format!("statistics about {key} for {type_prefix}:{value}"); - build_statistics_record(&id, metadata, &document) - } - - fn partition_output( - output: &Chunk, - ) -> (HashMap, Vec) { - let mut upserts: HashMap = HashMap::new(); - let mut deletes: Vec = Vec::new(); - - for (log_record, _) in output.iter() { - match log_record.record.operation { - Operation::Upsert => { - upserts.insert(log_record.record.id.clone(), log_record.record.clone()); - } - Operation::Delete => { - deletes.push(log_record.record.id.clone()); - assert!(log_record.record.metadata.is_none()); - assert!(log_record.record.embedding.is_none()); - } - other => panic!("unexpected operation in statistics output: {:?}", other), - } - } - - (upserts, deletes) - } - - fn partition_output_expect_no_upserts(output: &Chunk) -> Vec { - let mut deletes: Vec = Vec::new(); - - for (log_record, _) in output.iter() { - match log_record.record.operation { - Operation::Delete => { - deletes.push(log_record.record.id.clone()); - assert!(log_record.record.metadata.is_none()); - assert!(log_record.record.embedding.is_none()); - } - Operation::Upsert => { - panic!("unexpected upsert in empty-input statistics output"); - } - other => panic!("unexpected operation in statistics output: {:?}", other), - } - } - - deletes - } - - #[tokio::test] - async fn statistics_executor_counts_all_metadata_values() { - let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); - 
- let record_one = build_record( - "record-1", - HashMap::from([ - ("bool_key".to_string(), UpdateMetadataValue::Bool(true)), - ("int_key".to_string(), UpdateMetadataValue::Int(7)), - ("float_key".to_string(), UpdateMetadataValue::Float(2.5)), - ( - "str_key".to_string(), - UpdateMetadataValue::Str("alpha".to_string()), - ), - ( - "sparse_key".to_string(), - UpdateMetadataValue::SparseVector(SparseVector::new( - vec![1, 3], - vec![0.25, 0.75], - )), - ), - ]), - ); - let record_two = build_record( - "record-2", - HashMap::from([ - ("bool_key".to_string(), UpdateMetadataValue::Bool(false)), - ("int_key".to_string(), UpdateMetadataValue::Int(7)), - ("float_key".to_string(), UpdateMetadataValue::Float(2.5)), - ( - "str_key".to_string(), - UpdateMetadataValue::Str("alpha".to_string()), - ), - ( - "sparse_key".to_string(), - UpdateMetadataValue::SparseVector(SparseVector::new(vec![3], vec![0.5])), - ), - ]), - ); - - let input = Chunk::new(vec![record_one, record_two].into()); - - let output = executor - .execute(input, None) - .await - .expect("execution succeeds"); - - let actual = collect_statistics_map(&output); - - let float_value = format!("{:.16e}", 2.5_f64); - let expected: HashMap = HashMap::from([ - ( - format!("bool_key::b:{}", true), - ( - 1, - "bool_key".to_string(), - "bool".to_string(), - format!("{}", true), - ), - ), - ( - format!("bool_key::b:{}", false), - ( - 1, - "bool_key".to_string(), - "bool".to_string(), - format!("{}", false), - ), - ), - ( - format!("int_key::i:{}", 7), - ( - 2, - "int_key".to_string(), - "int".to_string(), - format!("{}", 7), - ), - ), - ( - format!("float_key::f:{float_value}"), - ( - 2, - "float_key".to_string(), - "float".to_string(), - float_value.clone(), - ), - ), - ( - format!("str_key::s:{}", "alpha"), - ( - 2, - "str_key".to_string(), - "str".to_string(), - "alpha".to_string(), - ), - ), - ( - format!("sparse_key::sv:{}", 1), - ( - 1, - "sparse_key".to_string(), - "sparse".to_string(), - format!("{}", 1), - ), - ), - ( - format!("sparse_key::sv:{}", 3), - ( - 2, - "sparse_key".to_string(), - "sparse".to_string(), - format!("{}", 3), - ), - ), - ]); - - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn statistics_executor_groups_nan_float_values() { - let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); - - let record_one = build_record( - "nan-1", - HashMap::from([( - "float_key".to_string(), - UpdateMetadataValue::Float(f64::NAN), - )]), - ); - let record_two = build_record( - "nan-2", - HashMap::from([( - "float_key".to_string(), - UpdateMetadataValue::Float(f64::NAN), - )]), - ); - - let input = Chunk::new(vec![record_one, record_two].into()); - let output = executor - .execute(input, None) - .await - .expect("execution succeeds"); - - let actual = collect_statistics_map(&output); - assert_eq!(actual.len(), 1); - - let float_string = format!("{:.16e}", f64::NAN); - let expected_id = format!("float_key::f:{float_string}"); - let expected_entry = ( - 2, - "float_key".to_string(), - "float".to_string(), - float_string.clone(), - ); - assert_eq!( - actual - .get(&expected_id) - .expect("NaN metadata should be grouped under a single entry"), - &expected_entry - ); - } - - #[tokio::test] - async fn statistics_executor_ignores_delete_operations() { - let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); - - let upsert_record = build_record( - "record-upsert", - HashMap::from([("bool_key".to_string(), UpdateMetadataValue::Bool(true))]), - ); - let delete_record = 
build_record_with_operation( - "record-delete", - Operation::Delete, - HashMap::from([("bool_key".to_string(), UpdateMetadataValue::Bool(false))]), - ); - - let input = Chunk::new(vec![upsert_record, delete_record].into()); - let output = executor - .execute(input, None) - .await - .expect("execution succeeds"); - - let actual = collect_statistics_map(&output); - assert_eq!(actual.len(), 1); - - let true_id = format!("bool_key::b:{}", true); - assert!( - actual.contains_key(&true_id), - "upserted metadata should still be counted" - ); - - let false_id = format!("bool_key::b:{}", false); - assert!( - !actual.contains_key(&false_id), - "delete metadata should be ignored by the statistics executor" - ); - } - - #[tokio::test] - async fn statistics_executor_handles_empty_sparse_vectors() { - let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); - - let record = build_record( - "sparse-empty", - HashMap::from([( - "sparse_key".to_string(), - UpdateMetadataValue::SparseVector(SparseVector::new( - Vec::::new(), - Vec::::new(), - )), - )]), - ); - - let input = Chunk::new(vec![record].into()); - let output = executor - .execute(input, None) - .await - .expect("execution succeeds"); - - assert!(output.is_empty()); - } - - #[tokio::test] - async fn statistics_executor_skips_unconvertible_metadata_values() { - let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); - - let record = build_record( - "only", - HashMap::from([("skip".to_string(), UpdateMetadataValue::None)]), - ); - - let input = Chunk::new(vec![record].into()); - - let output = executor - .execute(input, None) - .await - .expect("execution succeeds"); - - assert_eq!(output.total_len(), 0); - assert_eq!(output.len(), 0); - assert!(output.is_empty()); - } - - #[tokio::test] - async fn statistics_executor_deletes_stale_records_from_segment() { - let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); - - let mut test_segment = TestDistributedSegment::new().await; - - let stale_record = build_complete_statistics_record("obsolete_key", "true", "bool", "b", 1); - - let fresh_record = build_complete_statistics_record("fresh_key", "1", "int", "i", 3); - - let existing_chunk = Chunk::new(vec![stale_record, fresh_record].into()); - - Box::pin(test_segment.compact_log(existing_chunk, 1)).await; - - let record_reader = Box::pin(RecordSegmentReader::from_segment( - &test_segment.record_segment, - &test_segment.blockfile_provider, - )) - .await - .expect("record segment reader creation succeeds"); - - let input = Chunk::new( - vec![build_record( - "input-1", - HashMap::from([("fresh_key".to_string(), UpdateMetadataValue::Int(1))]), - )] - .into(), - ); - - let output = executor - .execute(input, Some(&record_reader)) - .await - .expect("execution succeeds"); - - let (upserts, deletes) = partition_output(&output); - - assert_eq!(deletes, vec!["obsolete_key::b:true".to_string()]); - - let fresh_stats = upserts - .get("fresh_key::i:1") - .expect("fresh statistics record should be recreated"); - let metadata = fresh_stats - .metadata - .as_ref() - .expect("statistics executor always sets metadata"); - - let (count, key, value_type, value) = extract_metadata_tuple(metadata); - - assert_eq!(count, 1); - assert_eq!(key, "fresh_key"); - assert_eq!(value_type, "int"); - assert_eq!(value, "1"); - } - - #[tokio::test] - async fn statistics_executor_zeroes_output_when_input_empty() { - let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); - - let mut test_segment = 
TestDistributedSegment::new().await; - - let record = build_complete_statistics_record("empty_key", "initial", "str", "s", 2); - - let existing_chunk = Chunk::new(vec![record].into()); - Box::pin(test_segment.compact_log(existing_chunk, 1)).await; - - let record_reader = Box::pin(RecordSegmentReader::from_segment( - &test_segment.record_segment, - &test_segment.blockfile_provider, - )) - .await - .expect("record segment reader creation succeeds"); - - let empty_input: Chunk = Chunk::new(Vec::::new().into()); - - let output = executor - .execute(empty_input, Some(&record_reader)) - .await - .expect("execution succeeds"); - - let deletes = partition_output_expect_no_upserts(&output); - - assert_eq!(deletes, vec!["empty_key::s:initial".to_string()]); - } - - // TODO(tanujnay112): Reenable this after function compaction is brought back - /* - #[tokio::test] - async fn test_k8s_integration_statistics_function() { - use crate::config::RootConfig; - use crate::execution::orchestration::CompactOrchestrator; - use chroma_config::{registry::Registry, Configurable}; - use chroma_log::in_memory_log::{InMemoryLog, InternalLogRecord}; - use chroma_log::Log; - use chroma_segment::test::TestDistributedSegment; - use chroma_sysdb::SysDb; - use chroma_system::{Dispatcher, Orchestrator, System}; - use chroma_types::{CollectionUuid, Operation, OperationRecord, UpdateMetadataValue}; - use s3heap_service::client::{GrpcHeapService, GrpcHeapServiceConfig}; - use std::collections::HashMap; - - // Setup test environment - let config = RootConfig::default(); - let system = System::default(); - let registry = Registry::new(); - let dispatcher = Dispatcher::try_from_config(&config.query_service.dispatcher, ®istry) - .await - .expect("Should be able to initialize dispatcher"); - let dispatcher_handle = system.start_component(dispatcher); - - // Connect to Grpc SysDb (requires Tilt running) - let grpc_sysdb = chroma_sysdb::GrpcSysDb::try_from_config( - &chroma_sysdb::GrpcSysDbConfig { - host: "localhost".to_string(), - port: 50051, - connect_timeout_ms: 5000, - request_timeout_ms: 10000, - num_channels: 4, - }, - ®istry, - ) - .await - .expect("Should connect to grpc sysdb"); - let mut sysdb = SysDb::Grpc(grpc_sysdb); - - // Connect to Grpc Heap Service (requires Tilt running) - let heap_service = GrpcHeapService::try_from_config( - &(GrpcHeapServiceConfig::default(), system.clone()), - ®istry, - ) - .await - .expect("Should connect to grpc heap service"); - - let test_segments = TestDistributedSegment::new().await; - let mut in_memory_log = InMemoryLog::new(); - - // Create input collection - let collection_name = format!("test_statistics_{}", uuid::Uuid::new_v4()); - let collection_id = CollectionUuid::new(); - - sysdb - .create_collection( - test_segments.collection.tenant, - test_segments.collection.database, - collection_id, - collection_name, - vec![ - test_segments.record_segment.clone(), - test_segments.metadata_segment.clone(), - test_segments.vector_segment.clone(), - ], - None, - None, - None, - test_segments.collection.dimension, - false, - ) - .await - .expect("Collection create should be successful"); - - let tenant = "default_tenant".to_string(); - let db = "default_database".to_string(); - - // Set initial log position - sysdb - .flush_compaction( - tenant.clone(), - collection_id, - -1, - 0, - std::sync::Arc::new([]), - 0, - 0, - None, - ) - .await - .expect("Should be able to update log_position"); - - // Add 15 records with specific metadata we can verify - // 10 records with color="red", 5 with 
color="blue" - // 8 records with size=10, 7 with size=20 - for i in 0..15 { - let mut metadata = HashMap::new(); - - // First 10 are red, last 5 are blue - let color = if i < 10 { "red" } else { "blue" }; - metadata.insert( - "color".to_string(), - UpdateMetadataValue::Str(color.to_string()), - ); - - // First 8 are size 10, last 7 are size 20 - let size = if i < 8 { 10 } else { 20 }; - metadata.insert("size".to_string(), UpdateMetadataValue::Int(size)); - - let log_record = LogRecord { - log_offset: i as i64, - record: OperationRecord { - id: format!("record_{}", i), - embedding: Some(vec![ - 0.0; - test_segments.collection.dimension.unwrap_or(384) - as usize - ]), - encoding: None, - metadata: Some(metadata), - document: Some(format!("doc {}", i)), - operation: Operation::Upsert, - }, - }; - - in_memory_log.add_log( - collection_id, - InternalLogRecord { - collection_id, - log_offset: i as i64, - log_ts: i as i64, - record: log_record, - }, - ) - } - - let log = Log::InMemory(in_memory_log); - let attached_function_name = "test_statistics"; - let output_collection_name = format!("test_stats_output_{}", uuid::Uuid::new_v4()); - - // Create statistics attached function via sysdb - let attached_function_id = sysdb - .create_attached_function( - attached_function_name.to_string(), - "statistics".to_string(), - collection_id, - output_collection_name, - serde_json::Value::Null, - tenant.clone(), - db.clone(), - 10, - ) - .await - .expect("Attached function creation should succeed"); - - // Initial compaction - let compact_orchestrator = CompactOrchestrator::new( - collection_id, - false, - 50, - 1000, - 50, - log.clone(), - sysdb.clone(), - test_segments.blockfile_provider.clone(), - test_segments.hnsw_provider.clone(), - test_segments.spann_provider.clone(), - dispatcher_handle.clone(), - None, - ); - - let result = compact_orchestrator.run(system.clone()).await; - assert!( - result.is_ok(), - "Initial compaction should succeed: {:?}", - result.err() - ); - - // Get nonce for attached function run - let attached_function = sysdb - .get_attached_function_by_name(collection_id, attached_function_name.to_string()) - .await - .expect("Attached function should be found"); - let execution_nonce = attached_function.lowest_live_nonce.unwrap(); - - // Run statistics function - let compact_orchestrator = CompactOrchestrator::new_for_attached_function( - collection_id, - false, - 50, - 1000, - 50, - log.clone(), - sysdb.clone(), - heap_service, - test_segments.blockfile_provider.clone(), - test_segments.hnsw_provider.clone(), - test_segments.spann_provider.clone(), - dispatcher_handle, - None, - attached_function_id, - execution_nonce, - ); - - let result = compact_orchestrator.run(system).await; - assert!( - result.is_ok(), - "Statistics function execution should succeed: {:?}", - result.err() - ); - - // Verify statistics were generated - let updated_attached_function = sysdb - .get_attached_function_by_name(collection_id, attached_function_name.to_string()) - .await - .expect("Attached function should be found"); - - // Note: completion_offset is 13, but all 15 records (0-14) were processed - assert_eq!( - updated_attached_function.completion_offset, 13, - "Completion offset should be 13" - ); - - let output_collection_id = updated_attached_function.output_collection_id.unwrap(); - - // Read statistics from output collection - let output_info = sysdb - .get_collection_with_segments(output_collection_id) - .await - .expect("Should get output collection"); - let reader = 
Box::pin(RecordSegmentReader::from_segment( - &output_info.record_segment, - &test_segments.blockfile_provider, - )) - .await - .expect("Should create reader"); - - // Verify statistics records exist - let max_offset_id = reader.get_max_offset_id(); - assert!( - max_offset_id > 0, - "Statistics function should have created records" - ); - - // Verify actual statistics content - use futures::stream::StreamExt; - let mut stream = reader.get_data_stream(0..=max_offset_id).await; - let mut stats_by_key_value: HashMap<(String, String), i64> = HashMap::new(); - - while let Some(result) = stream.next().await { - let (_, record) = result.expect("Should read record"); - - // Verify metadata structure - let metadata = record - .metadata - .expect("Statistics records should have metadata"); - - // All statistics records should have these fields - assert!(metadata.contains_key("count"), "Should have count field"); - assert!(metadata.contains_key("key"), "Should have key field"); - assert!(metadata.contains_key("type"), "Should have type field"); - assert!(metadata.contains_key("value"), "Should have value field"); - - // Extract key, value, and count - let key = match metadata.get("key") { - Some(chroma_types::MetadataValue::Str(k)) => k.clone(), - _ => panic!("key should be a string"), - }; - let value = match metadata.get("value") { - Some(chroma_types::MetadataValue::Str(v)) => v.clone(), - _ => panic!("value should be a string"), - }; - let count = match metadata.get("count") { - Some(chroma_types::MetadataValue::Int(c)) => *c, - _ => panic!("count should be an int"), - }; - - stats_by_key_value.insert((key, value), count); - } - - // Verify expected statistics: - // All 15 records (0-14) were processed - // Expected: color="red" -> 10 (records 0-9), color="blue" -> 5 (records 10-14) - // Expected: size=10 -> 8 (records 0-7), size=20 -> 7 (records 8-14) - assert_eq!( - stats_by_key_value.get(&("color".to_string(), "red".to_string())), - Some(&10), - "Should have 10 records with color=red (records 0-9)" - ); - assert_eq!( - stats_by_key_value.get(&("color".to_string(), "blue".to_string())), - Some(&5), - "Should have 5 records with color=blue (records 10-14)" - ); - assert_eq!( - stats_by_key_value.get(&("size".to_string(), "10".to_string())), - Some(&8), - "Should have 8 records with size=10 (records 0-7)" - ); - assert_eq!( - stats_by_key_value.get(&("size".to_string(), "20".to_string())), - Some(&7), - "Should have 7 records with size=20 (records 8-14)" - ); - - // Verify we found exactly 4 unique statistics (2 colors + 2 sizes) - assert_eq!( - stats_by_key_value.len(), - 4, - "Should have exactly 4 unique statistics" - ); - - // Verify total count is 30 (15 records × 2 metadata keys) - let total_count: i64 = stats_by_key_value.values().sum(); - assert_eq!( - total_count, 30, - "Total count should be 30 (15 records × 2 metadata keys)" - ); - - tracing::info!( - "Statistics function test completed successfully. 
Found {} unique statistics with correct counts", - stats_by_key_value.len() - ); - } - */ -} +// #[cfg(test)] +// mod tests { +// use std::collections::HashMap; + +// use chroma_segment::{blockfile_record::RecordSegmentReader, test::TestDistributedSegment}; +// use chroma_types::{ +// Chunk, LogRecord, Operation, OperationRecord, SparseVector, UpdateMetadata, +// UpdateMetadataValue, +// }; + +// use super::*; + +// fn build_record(id: &str, metadata: HashMap) -> LogRecord { +// build_record_with_operation(id, Operation::Upsert, metadata) +// } + +// fn build_record_with_operation( +// id: &str, +// operation: Operation, +// metadata: HashMap, +// ) -> LogRecord { +// LogRecord { +// log_offset: 0, +// record: OperationRecord { +// id: id.to_string(), +// embedding: None, +// encoding: None, +// metadata: Some(metadata), +// document: None, +// operation, +// }, +// } +// } + +// fn extract_metadata_tuple(metadata: &UpdateMetadata) -> (i64, String, String, String) { +// let count = match metadata.get("count") { +// Some(UpdateMetadataValue::Int(value)) => *value, +// other => panic!("unexpected count metadata: {other:?}"), +// }; +// let key = match metadata.get("key") { +// Some(UpdateMetadataValue::Str(value)) => value.clone(), +// other => panic!("unexpected key metadata: {other:?}"), +// }; +// let value_type = match metadata.get("type") { +// Some(UpdateMetadataValue::Str(value)) => value.clone(), +// other => panic!("unexpected type metadata: {other:?}"), +// }; +// let value = match metadata.get("value") { +// Some(UpdateMetadataValue::Str(value)) => value.clone(), +// other => panic!("unexpected value metadata: {other:?}"), +// }; +// (count, key, value_type, value) +// } + +// fn collect_statistics_map( +// output: &Chunk, +// ) -> HashMap { +// let mut actual: HashMap = HashMap::new(); +// for (log_record, _) in output.iter() { +// let record = &log_record.record; +// assert_eq!(record.operation, Operation::Upsert); +// assert_eq!(record.embedding.as_deref(), Some(&[0.0][..])); + +// let metadata = record +// .metadata +// .as_ref() +// .expect("statistics executor always sets metadata"); + +// actual.insert(record.id.clone(), extract_metadata_tuple(metadata)); +// } +// actual +// } + +// fn build_statistics_metadata( +// count: i64, +// key: &str, +// value_type: &str, +// value: &str, +// ) -> UpdateMetadata { +// HashMap::from([ +// ("count".to_string(), UpdateMetadataValue::Int(count)), +// ("key".to_string(), UpdateMetadataValue::Str(key.to_string())), +// ( +// "type".to_string(), +// UpdateMetadataValue::Str(value_type.to_string()), +// ), +// ( +// "value".to_string(), +// UpdateMetadataValue::Str(value.to_string()), +// ), +// ]) +// } + +// fn build_statistics_record(id: &str, metadata: UpdateMetadata, document: &str) -> LogRecord { +// LogRecord { +// log_offset: 0, +// record: OperationRecord { +// id: id.to_string(), +// embedding: Some(vec![0.0]), +// encoding: None, +// metadata: Some(metadata), +// document: Some(document.to_string()), +// operation: Operation::Upsert, +// }, +// } +// } + +// fn build_complete_statistics_record( +// key: &str, +// value: &str, +// value_type: &str, +// type_prefix: &str, +// count: i64, +// ) -> LogRecord { +// let metadata = build_statistics_metadata(count, key, value_type, value); +// let id = format!("{key}::{type_prefix}:{value}"); +// let document = format!("statistics about {key} for {type_prefix}:{value}"); +// build_statistics_record(&id, metadata, &document) +// } + +// fn partition_output( +// output: &Chunk, +// ) -> 
(HashMap, Vec) { +// let mut upserts: HashMap = HashMap::new(); +// let mut deletes: Vec = Vec::new(); + +// for (log_record, _) in output.iter() { +// match log_record.record.operation { +// Operation::Upsert => { +// upserts.insert(log_record.record.id.clone(), log_record.record.clone()); +// } +// Operation::Delete => { +// deletes.push(log_record.record.id.clone()); +// assert!(log_record.record.metadata.is_none()); +// assert!(log_record.record.embedding.is_none()); +// } +// other => panic!("unexpected operation in statistics output: {:?}", other), +// } +// } + +// (upserts, deletes) +// } + +// fn partition_output_expect_no_upserts(output: &Chunk) -> Vec { +// let mut deletes: Vec = Vec::new(); + +// for (log_record, _) in output.iter() { +// match log_record.record.operation { +// Operation::Delete => { +// deletes.push(log_record.record.id.clone()); +// assert!(log_record.record.metadata.is_none()); +// assert!(log_record.record.embedding.is_none()); +// } +// Operation::Upsert => { +// panic!("unexpected upsert in empty-input statistics output"); +// } +// other => panic!("unexpected operation in statistics output: {:?}", other), +// } +// } + +// deletes +// } + +// #[tokio::test] +// async fn statistics_executor_counts_all_metadata_values() { +// let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); + +// let record_one = build_record( +// "record-1", +// HashMap::from([ +// ("bool_key".to_string(), UpdateMetadataValue::Bool(true)), +// ("int_key".to_string(), UpdateMetadataValue::Int(7)), +// ("float_key".to_string(), UpdateMetadataValue::Float(2.5)), +// ( +// "str_key".to_string(), +// UpdateMetadataValue::Str("alpha".to_string()), +// ), +// ( +// "sparse_key".to_string(), +// UpdateMetadataValue::SparseVector(SparseVector::new( +// vec![1, 3], +// vec![0.25, 0.75], +// )), +// ), +// ]), +// ); +// let record_two = build_record( +// "record-2", +// HashMap::from([ +// ("bool_key".to_string(), UpdateMetadataValue::Bool(false)), +// ("int_key".to_string(), UpdateMetadataValue::Int(7)), +// ("float_key".to_string(), UpdateMetadataValue::Float(2.5)), +// ( +// "str_key".to_string(), +// UpdateMetadataValue::Str("alpha".to_string()), +// ), +// ( +// "sparse_key".to_string(), +// UpdateMetadataValue::SparseVector(SparseVector::new(vec![3], vec![0.5])), +// ), +// ]), +// ); + +// let input = Chunk::new(vec![record_one, record_two].into()); + +// let output = executor +// .execute(input, None) +// .await +// .expect("execution succeeds"); + +// let actual = collect_statistics_map(&output); + +// let float_value = format!("{:.16e}", 2.5_f64); +// let expected: HashMap = HashMap::from([ +// ( +// format!("bool_key::b:{}", true), +// ( +// 1, +// "bool_key".to_string(), +// "bool".to_string(), +// format!("{}", true), +// ), +// ), +// ( +// format!("bool_key::b:{}", false), +// ( +// 1, +// "bool_key".to_string(), +// "bool".to_string(), +// format!("{}", false), +// ), +// ), +// ( +// format!("int_key::i:{}", 7), +// ( +// 2, +// "int_key".to_string(), +// "int".to_string(), +// format!("{}", 7), +// ), +// ), +// ( +// format!("float_key::f:{float_value}"), +// ( +// 2, +// "float_key".to_string(), +// "float".to_string(), +// float_value.clone(), +// ), +// ), +// ( +// format!("str_key::s:{}", "alpha"), +// ( +// 2, +// "str_key".to_string(), +// "str".to_string(), +// "alpha".to_string(), +// ), +// ), +// ( +// format!("sparse_key::sv:{}", 1), +// ( +// 1, +// "sparse_key".to_string(), +// "sparse".to_string(), +// format!("{}", 1), +// ), +// ), +// 
( +// format!("sparse_key::sv:{}", 3), +// ( +// 2, +// "sparse_key".to_string(), +// "sparse".to_string(), +// format!("{}", 3), +// ), +// ), +// ]); + +// assert_eq!(actual, expected); +// } + +// #[tokio::test] +// async fn statistics_executor_groups_nan_float_values() { +// let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); + +// let record_one = build_record( +// "nan-1", +// HashMap::from([( +// "float_key".to_string(), +// UpdateMetadataValue::Float(f64::NAN), +// )]), +// ); +// let record_two = build_record( +// "nan-2", +// HashMap::from([( +// "float_key".to_string(), +// UpdateMetadataValue::Float(f64::NAN), +// )]), +// ); + +// let input = Chunk::new(vec![record_one, record_two].into()); +// let output = executor +// .execute(input, None) +// .await +// .expect("execution succeeds"); + +// let actual = collect_statistics_map(&output); +// assert_eq!(actual.len(), 1); + +// let float_string = format!("{:.16e}", f64::NAN); +// let expected_id = format!("float_key::f:{float_string}"); +// let expected_entry = ( +// 2, +// "float_key".to_string(), +// "float".to_string(), +// float_string.clone(), +// ); +// assert_eq!( +// actual +// .get(&expected_id) +// .expect("NaN metadata should be grouped under a single entry"), +// &expected_entry +// ); +// } + +// #[tokio::test] +// async fn statistics_executor_ignores_delete_operations() { +// let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); + +// let upsert_record = build_record( +// "record-upsert", +// HashMap::from([("bool_key".to_string(), UpdateMetadataValue::Bool(true))]), +// ); +// let delete_record = build_record_with_operation( +// "record-delete", +// Operation::Delete, +// HashMap::from([("bool_key".to_string(), UpdateMetadataValue::Bool(false))]), +// ); + +// let input = Chunk::new(vec![upsert_record, delete_record].into()); +// let output = executor +// .execute(input, None) +// .await +// .expect("execution succeeds"); + +// let actual = collect_statistics_map(&output); +// assert_eq!(actual.len(), 1); + +// let true_id = format!("bool_key::b:{}", true); +// assert!( +// actual.contains_key(&true_id), +// "upserted metadata should still be counted" +// ); + +// let false_id = format!("bool_key::b:{}", false); +// assert!( +// !actual.contains_key(&false_id), +// "delete metadata should be ignored by the statistics executor" +// ); +// } + +// #[tokio::test] +// async fn statistics_executor_handles_empty_sparse_vectors() { +// let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); + +// let record = build_record( +// "sparse-empty", +// HashMap::from([( +// "sparse_key".to_string(), +// UpdateMetadataValue::SparseVector(SparseVector::new( +// Vec::::new(), +// Vec::::new(), +// )), +// )]), +// ); + +// let input = Chunk::new(vec![record].into()); +// let output = executor +// .execute(input, None) +// .await +// .expect("execution succeeds"); + +// assert!(output.is_empty()); +// } + +// #[tokio::test] +// async fn statistics_executor_skips_unconvertible_metadata_values() { +// let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); + +// let record = build_record( +// "only", +// HashMap::from([("skip".to_string(), UpdateMetadataValue::None)]), +// ); + +// let input = Chunk::new(vec![record].into()); + +// let output = executor +// .execute(input, None) +// .await +// .expect("execution succeeds"); + +// assert_eq!(output.total_len(), 0); +// assert_eq!(output.len(), 0); +// assert!(output.is_empty()); +// } + +// 
#[tokio::test] +// async fn statistics_executor_deletes_stale_records_from_segment() { +// let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); + +// let mut test_segment = TestDistributedSegment::new().await; + +// let stale_record = build_complete_statistics_record("obsolete_key", "true", "bool", "b", 1); + +// let fresh_record = build_complete_statistics_record("fresh_key", "1", "int", "i", 3); + +// let existing_chunk = Chunk::new(vec![stale_record, fresh_record].into()); + +// Box::pin(test_segment.compact_log(existing_chunk, 1)).await; + +// let record_reader = Box::pin(RecordSegmentReader::from_segment( +// &test_segment.record_segment, +// &test_segment.blockfile_provider, +// )) +// .await +// .expect("record segment reader creation succeeds"); + +// let input = Chunk::new( +// vec![build_record( +// "input-1", +// HashMap::from([("fresh_key".to_string(), UpdateMetadataValue::Int(1))]), +// )] +// .into(), +// ); + +// let output = executor +// .execute(input, Some(&record_reader)) +// .await +// .expect("execution succeeds"); + +// let (upserts, deletes) = partition_output(&output); + +// assert_eq!(deletes, vec!["obsolete_key::b:true".to_string()]); + +// let fresh_stats = upserts +// .get("fresh_key::i:1") +// .expect("fresh statistics record should be recreated"); +// let metadata = fresh_stats +// .metadata +// .as_ref() +// .expect("statistics executor always sets metadata"); + +// let (count, key, value_type, value) = extract_metadata_tuple(metadata); + +// assert_eq!(count, 1); +// assert_eq!(key, "fresh_key"); +// assert_eq!(value_type, "int"); +// assert_eq!(value, "1"); +// } + +// #[tokio::test] +// async fn statistics_executor_zeroes_output_when_input_empty() { +// let executor = StatisticsFunctionExecutor(Box::new(CounterFunctionFactory)); + +// let mut test_segment = TestDistributedSegment::new().await; + +// let record = build_complete_statistics_record("empty_key", "initial", "str", "s", 2); + +// let existing_chunk = Chunk::new(vec![record].into()); +// Box::pin(test_segment.compact_log(existing_chunk, 1)).await; + +// let record_reader = Box::pin(RecordSegmentReader::from_segment( +// &test_segment.record_segment, +// &test_segment.blockfile_provider, +// )) +// .await +// .expect("record segment reader creation succeeds"); + +// let empty_input: Chunk = Chunk::new(Vec::::new().into()); + +// let output = executor +// .execute(empty_input, Some(&record_reader)) +// .await +// .expect("execution succeeds"); + +// let deletes = partition_output_expect_no_upserts(&output); + +// assert_eq!(deletes, vec!["empty_key::s:initial".to_string()]); +// } + +// // TODO(tanujnay112): Reenable this after function compaction is brought back +// /* +// #[tokio::test] +// async fn test_k8s_integration_statistics_function() { +// use crate::config::RootConfig; +// use crate::execution::orchestration::CompactOrchestrator; +// use chroma_config::{registry::Registry, Configurable}; +// use chroma_log::in_memory_log::{InMemoryLog, InternalLogRecord}; +// use chroma_log::Log; +// use chroma_segment::test::TestDistributedSegment; +// use chroma_sysdb::SysDb; +// use chroma_system::{Dispatcher, Orchestrator, System}; +// use chroma_types::{CollectionUuid, Operation, OperationRecord, UpdateMetadataValue}; +// use s3heap_service::client::{GrpcHeapService, GrpcHeapServiceConfig}; +// use std::collections::HashMap; + +// // Setup test environment +// let config = RootConfig::default(); +// let system = System::default(); +// let registry = Registry::new(); +// let 
dispatcher = Dispatcher::try_from_config(&config.query_service.dispatcher, ®istry) +// .await +// .expect("Should be able to initialize dispatcher"); +// let dispatcher_handle = system.start_component(dispatcher); + +// // Connect to Grpc SysDb (requires Tilt running) +// let grpc_sysdb = chroma_sysdb::GrpcSysDb::try_from_config( +// &chroma_sysdb::GrpcSysDbConfig { +// host: "localhost".to_string(), +// port: 50051, +// connect_timeout_ms: 5000, +// request_timeout_ms: 10000, +// num_channels: 4, +// }, +// ®istry, +// ) +// .await +// .expect("Should connect to grpc sysdb"); +// let mut sysdb = SysDb::Grpc(grpc_sysdb); + +// // Connect to Grpc Heap Service (requires Tilt running) +// let heap_service = GrpcHeapService::try_from_config( +// &(GrpcHeapServiceConfig::default(), system.clone()), +// ®istry, +// ) +// .await +// .expect("Should connect to grpc heap service"); + +// let test_segments = TestDistributedSegment::new().await; +// let mut in_memory_log = InMemoryLog::new(); + +// // Create input collection +// let collection_name = format!("test_statistics_{}", uuid::Uuid::new_v4()); +// let collection_id = CollectionUuid::new(); + +// sysdb +// .create_collection( +// test_segments.collection.tenant, +// test_segments.collection.database, +// collection_id, +// collection_name, +// vec![ +// test_segments.record_segment.clone(), +// test_segments.metadata_segment.clone(), +// test_segments.vector_segment.clone(), +// ], +// None, +// None, +// None, +// test_segments.collection.dimension, +// false, +// ) +// .await +// .expect("Collection create should be successful"); + +// let tenant = "default_tenant".to_string(); +// let db = "default_database".to_string(); + +// // Set initial log position +// sysdb +// .flush_compaction( +// tenant.clone(), +// collection_id, +// -1, +// 0, +// std::sync::Arc::new([]), +// 0, +// 0, +// None, +// ) +// .await +// .expect("Should be able to update log_position"); + +// // Add 15 records with specific metadata we can verify +// // 10 records with color="red", 5 with color="blue" +// // 8 records with size=10, 7 with size=20 +// for i in 0..15 { +// let mut metadata = HashMap::new(); + +// // First 10 are red, last 5 are blue +// let color = if i < 10 { "red" } else { "blue" }; +// metadata.insert( +// "color".to_string(), +// UpdateMetadataValue::Str(color.to_string()), +// ); + +// // First 8 are size 10, last 7 are size 20 +// let size = if i < 8 { 10 } else { 20 }; +// metadata.insert("size".to_string(), UpdateMetadataValue::Int(size)); + +// let log_record = LogRecord { +// log_offset: i as i64, +// record: OperationRecord { +// id: format!("record_{}", i), +// embedding: Some(vec![ +// 0.0; +// test_segments.collection.dimension.unwrap_or(384) +// as usize +// ]), +// encoding: None, +// metadata: Some(metadata), +// document: Some(format!("doc {}", i)), +// operation: Operation::Upsert, +// }, +// }; + +// in_memory_log.add_log( +// collection_id, +// InternalLogRecord { +// collection_id, +// log_offset: i as i64, +// log_ts: i as i64, +// record: log_record, +// }, +// ) +// } + +// let log = Log::InMemory(in_memory_log); +// let attached_function_name = "test_statistics"; +// let output_collection_name = format!("test_stats_output_{}", uuid::Uuid::new_v4()); + +// // Create statistics attached function via sysdb +// let attached_function_id = sysdb +// .create_attached_function( +// attached_function_name.to_string(), +// "statistics".to_string(), +// collection_id, +// output_collection_name, +// serde_json::Value::Null, +// 
tenant.clone(), +// db.clone(), +// 10, +// ) +// .await +// .expect("Attached function creation should succeed"); + +// // Initial compaction +// let compact_orchestrator = CompactOrchestrator::new( +// collection_id, +// false, +// 50, +// 1000, +// 50, +// log.clone(), +// sysdb.clone(), +// test_segments.blockfile_provider.clone(), +// test_segments.hnsw_provider.clone(), +// test_segments.spann_provider.clone(), +// dispatcher_handle.clone(), +// None, +// ); + +// let result = compact_orchestrator.run(system.clone()).await; +// assert!( +// result.is_ok(), +// "Initial compaction should succeed: {:?}", +// result.err() +// ); + +// // Get nonce for attached function run +// let attached_function = sysdb +// .get_attached_function_by_name(collection_id, attached_function_name.to_string()) +// .await +// .expect("Attached function should be found"); +// let execution_nonce = attached_function.lowest_live_nonce.unwrap(); + +// // Run statistics function +// let compact_orchestrator = CompactOrchestrator::new_for_attached_function( +// collection_id, +// false, +// 50, +// 1000, +// 50, +// log.clone(), +// sysdb.clone(), +// heap_service, +// test_segments.blockfile_provider.clone(), +// test_segments.hnsw_provider.clone(), +// test_segments.spann_provider.clone(), +// dispatcher_handle, +// None, +// attached_function_id, +// execution_nonce, +// ); + +// let result = compact_orchestrator.run(system).await; +// assert!( +// result.is_ok(), +// "Statistics function execution should succeed: {:?}", +// result.err() +// ); + +// // Verify statistics were generated +// let updated_attached_function = sysdb +// .get_attached_function_by_name(collection_id, attached_function_name.to_string()) +// .await +// .expect("Attached function should be found"); + +// // Note: completion_offset is 13, but all 15 records (0-14) were processed +// assert_eq!( +// updated_attached_function.completion_offset, 13, +// "Completion offset should be 13" +// ); + +// let output_collection_id = updated_attached_function.output_collection_id.unwrap(); + +// // Read statistics from output collection +// let output_info = sysdb +// .get_collection_with_segments(output_collection_id) +// .await +// .expect("Should get output collection"); +// let reader = Box::pin(RecordSegmentReader::from_segment( +// &output_info.record_segment, +// &test_segments.blockfile_provider, +// )) +// .await +// .expect("Should create reader"); + +// // Verify statistics records exist +// let max_offset_id = reader.get_max_offset_id(); +// assert!( +// max_offset_id > 0, +// "Statistics function should have created records" +// ); + +// // Verify actual statistics content +// use futures::stream::StreamExt; +// let mut stream = reader.get_data_stream(0..=max_offset_id).await; +// let mut stats_by_key_value: HashMap<(String, String), i64> = HashMap::new(); + +// while let Some(result) = stream.next().await { +// let (_, record) = result.expect("Should read record"); + +// // Verify metadata structure +// let metadata = record +// .metadata +// .expect("Statistics records should have metadata"); + +// // All statistics records should have these fields +// assert!(metadata.contains_key("count"), "Should have count field"); +// assert!(metadata.contains_key("key"), "Should have key field"); +// assert!(metadata.contains_key("type"), "Should have type field"); +// assert!(metadata.contains_key("value"), "Should have value field"); + +// // Extract key, value, and count +// let key = match metadata.get("key") { +// 
Some(chroma_types::MetadataValue::Str(k)) => k.clone(), +// _ => panic!("key should be a string"), +// }; +// let value = match metadata.get("value") { +// Some(chroma_types::MetadataValue::Str(v)) => v.clone(), +// _ => panic!("value should be a string"), +// }; +// let count = match metadata.get("count") { +// Some(chroma_types::MetadataValue::Int(c)) => *c, +// _ => panic!("count should be an int"), +// }; + +// stats_by_key_value.insert((key, value), count); +// } + +// // Verify expected statistics: +// // All 15 records (0-14) were processed +// // Expected: color="red" -> 10 (records 0-9), color="blue" -> 5 (records 10-14) +// // Expected: size=10 -> 8 (records 0-7), size=20 -> 7 (records 8-14) +// assert_eq!( +// stats_by_key_value.get(&("color".to_string(), "red".to_string())), +// Some(&10), +// "Should have 10 records with color=red (records 0-9)" +// ); +// assert_eq!( +// stats_by_key_value.get(&("color".to_string(), "blue".to_string())), +// Some(&5), +// "Should have 5 records with color=blue (records 10-14)" +// ); +// assert_eq!( +// stats_by_key_value.get(&("size".to_string(), "10".to_string())), +// Some(&8), +// "Should have 8 records with size=10 (records 0-7)" +// ); +// assert_eq!( +// stats_by_key_value.get(&("size".to_string(), "20".to_string())), +// Some(&7), +// "Should have 7 records with size=20 (records 8-14)" +// ); + +// // Verify we found exactly 4 unique statistics (2 colors + 2 sizes) +// assert_eq!( +// stats_by_key_value.len(), +// 4, +// "Should have exactly 4 unique statistics" +// ); + +// // Verify total count is 30 (15 records × 2 metadata keys) +// let total_count: i64 = stats_by_key_value.values().sum(); +// assert_eq!( +// total_count, 30, +// "Total count should be 30 (15 records × 2 metadata keys)" +// ); + +// tracing::info!( +// "Statistics function test completed successfully. Found {} unique statistics with correct counts", +// stats_by_key_value.len() +// ); +// } +// */ +// } diff --git a/rust/worker/src/execution/operators/execute_task.rs b/rust/worker/src/execution/operators/execute_task.rs index c4a5171ef91..857b4d3ceed 100644 --- a/rust/worker/src/execution/operators/execute_task.rs +++ b/rust/worker/src/execution/operators/execute_task.rs @@ -3,15 +3,19 @@ use chroma_blockstore::provider::BlockfileProvider; use chroma_error::ChromaError; use chroma_log::Log; use chroma_segment::blockfile_record::{RecordSegmentReader, RecordSegmentReaderCreationError}; +use chroma_segment::types::{HydratedMaterializedLogRecord, MaterializeLogsResult}; use chroma_system::{Operator, OperatorType}; use chroma_types::{ - Chunk, CollectionUuid, LogRecord, Operation, OperationRecord, Segment, UpdateMetadataValue, - FUNCTION_RECORD_COUNTER_ID, FUNCTION_STATISTICS_ID, + AttachedFunction, AttachedFunctionUuid, Chunk, CollectionUuid, LogRecord, Operation, + OperationRecord, Segment, UpdateMetadataValue, FUNCTION_RECORD_COUNTER_ID, + FUNCTION_STATISTICS_ID, }; use std::sync::Arc; use thiserror::Error; +use uuid::Uuid; use crate::execution::functions::{CounterFunctionFactory, StatisticsFunctionExecutor}; +use crate::execution::operators::materialize_logs::MaterializeLogOutput; /// Trait for attached function executors that process input records and produce output records. /// Implementors can read from the output collection to maintain state across executions. @@ -20,14 +24,14 @@ pub trait AttachedFunctionExecutor: Send + Sync + std::fmt::Debug { /// Execute the attached function logic on input records. 
    ///
    /// # Arguments
-    /// * `input_records` - The log records to process
+    /// * `input_records` - The hydrated materialized log records to process
    /// * `output_reader` - Optional reader for the output collection's compacted data
    ///
    /// # Returns
    /// The output records to be written to the output collection
    async fn execute(
        &self,
-        input_records: Chunk<LogRecord>,
+        input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>,
        output_reader: Option<&RecordSegmentReader<'_>>,
    ) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>>;
}

@@ -41,35 +45,34 @@ pub struct CountAttachedFunction;
 impl AttachedFunctionExecutor for CountAttachedFunction {
     async fn execute(
         &self,
-        input_records: Chunk<LogRecord>,
+        input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>,
         _output_reader: Option<&RecordSegmentReader<'_>>,
     ) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>> {
         let records_count = input_records.len() as i64;
         let new_total_count = records_count;
+        println!("new_total_count is {}", new_total_count);
+
         // Create output record with updated count
         let mut metadata = std::collections::HashMap::new();
         metadata.insert(
             "total_count".to_string(),
-            UpdateMetadataValue::Int(new_total_count),
+            chroma_types::UpdateMetadataValue::Int(new_total_count),
         );
-        let operation_record = OperationRecord {
-            id: "attached_function_result".to_string(),
-            embedding: Some(vec![0.0]),
-            encoding: None,
-            metadata: Some(metadata),
-            document: None,
-            operation: Operation::Upsert,
-        };
-
-        let log_record = LogRecord {
-            log_offset: 0, // Will be set by caller
-            record: operation_record,
+        let output_record = LogRecord {
+            log_offset: 0, // Will be set by the orchestrator
+            record: OperationRecord {
+                id: "function_output".to_string(),
+                embedding: Some(vec![0.0]),
+                encoding: None,
+                metadata: Some(metadata),
+                document: Some(format!("Processed {} records", records_count)),
+                operation: Operation::Upsert,
+            },
         };

-        Ok(Chunk::new(Arc::new([log_record])))
+        Ok(Chunk::new(std::sync::Arc::from(vec![output_record])))
     }
 }

@@ -84,12 +87,11 @@ pub struct ExecuteAttachedFunctionOperator {
 impl ExecuteAttachedFunctionOperator {
     /// Create a new ExecuteAttachedFunctionOperator from an AttachedFunction.
     /// The executor is selected based on the function_id in the attached function.
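For orientation, a minimal third-party executor for the trait above might look like the sketch below. It is illustrative only and not part of this patch: `NoopAttachedFunction` is a made-up name, the `#[async_trait::async_trait]` attribute mirrors the other executors in this file, the generics follow the reconstruction above, and a `get_user_id()` accessor on hydrated records is assumed. The constructor below then selects among the built-in executors.

#[derive(Debug)]
struct NoopAttachedFunction;

#[async_trait::async_trait]
impl AttachedFunctionExecutor for NoopAttachedFunction {
    async fn execute(
        &self,
        input_records: Chunk<HydratedMaterializedLogRecord<'_, '_>>,
        _output_reader: Option<&RecordSegmentReader<'_>>,
    ) -> Result<Chunk<LogRecord>, Box<dyn ChromaError>> {
        // Re-emit one upsert per hydrated input record, keyed by its user ID.
        let output: Vec<LogRecord> = input_records
            .iter()
            .map(|(record, _)| LogRecord {
                log_offset: 0, // Assigned downstream, as with CountAttachedFunction.
                record: OperationRecord {
                    id: record.get_user_id().to_string(),
                    embedding: Some(vec![0.0]),
                    encoding: None,
                    metadata: None,
                    document: None,
                    operation: Operation::Upsert,
                },
            })
            .collect();
        Ok(Chunk::new(std::sync::Arc::from(output)))
    }
}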
- #[allow(dead_code)] pub(crate) fn from_attached_function( - attached_function: &chroma_types::AttachedFunction, + function_id: Uuid, log_client: Log, ) -> Result { - let executor: Arc = match attached_function.function_id { + let executor: Arc = match function_id { // For the record counter, use CountAttachedFunction FUNCTION_RECORD_COUNTER_ID => Arc::new(CountAttachedFunction), // For statistics, use StatisticsFunctionExecutor with CounterFunctionFactory @@ -97,13 +99,10 @@ impl ExecuteAttachedFunctionOperator { Arc::new(StatisticsFunctionExecutor(Box::new(CounterFunctionFactory))) } _ => { - tracing::error!( - "Unknown function_id UUID: {}", - attached_function.function_id - ); + tracing::error!("Unknown function_id UUID: {}", function_id); return Err(ExecuteAttachedFunctionError::InvalidUuid(format!( "Unknown function_id UUID: {}", - attached_function.function_id + function_id ))); } }; @@ -118,8 +117,8 @@ impl ExecuteAttachedFunctionOperator { /// Input for the ExecuteAttachedFunction operator #[derive(Debug)] pub struct ExecuteAttachedFunctionInput { - /// The fetched log records to process - pub log_records: Chunk, + /// The materialized log outputs to process + pub materialized_logs: Arc>, /// The tenant ID pub tenant_id: String, /// The output collection ID where results are written @@ -188,13 +187,11 @@ impl Operator input: &ExecuteAttachedFunctionInput, ) -> Result { tracing::info!( - "[ExecuteAttachedFunction]: Processing {} records for output collection {}", - input.log_records.len(), + "[ExecuteAttachedFunction]: Processing {} materialized log outputs for output collection {}", + input.materialized_logs.len(), input.output_collection_id ); - let records_count = input.log_records.len() as u64; - // Create record segment reader from the output collection's record segment let record_segment_reader = match Box::pin(RecordSegmentReader::from_segment( &input.output_record_segment, @@ -211,33 +208,41 @@ impl Operator Err(e) => return Err((*e).into()), }; + // Process all materialized logs and hydrate the records + let mut all_hydrated_records = Vec::new(); + let mut total_records_processed = 0u64; + + for materialized_log in input.materialized_logs.iter() { + // Use the iterator to process each materialized record + for borrowed_record in materialized_log.result.iter() { + // Hydrate the record using the same pattern as materialize_logs operator + let hydrated_record = borrowed_record + .hydrate(record_segment_reader.as_ref()) + .await + .map_err(|e| ExecuteAttachedFunctionError::SegmentRead(Box::new(e)))?; + + all_hydrated_records.push(hydrated_record); + } + + total_records_processed += materialized_log.result.len() as u64; + } + // Execute the attached function using the provided executor let output_records = self .attached_function_executor - .execute(input.log_records.clone(), record_segment_reader.as_ref()) + .execute( + Chunk::new(std::sync::Arc::from(all_hydrated_records)), + record_segment_reader.as_ref(), + ) .await .map_err(ExecuteAttachedFunctionError::SegmentRead)?; - // Update log offsets for output records - // Convert u64 completion_offset to i64 for LogRecord (which uses i64) - let base_offset: i64 = input.completion_offset.try_into().map_err(|_| { - ExecuteAttachedFunctionError::LogOffsetOverflowUnsignedToSigned( - input.completion_offset, - 0, - ) - })?; - let output_records_with_offsets: Vec = output_records .iter() .enumerate() - .map(|(i, (log_record, _))| { - let i_i64 = i64::try_from(i) - .map_err(|_| 
ExecuteAttachedFunctionError::LogOffsetOverflow(base_offset, i))?;
-                let offset = base_offset.checked_add(i_i64).ok_or_else(|| {
-                    ExecuteAttachedFunctionError::LogOffsetOverflow(base_offset, i)
-                })?;
+            .map(|(_, (log_record, _))| {
                 Ok(LogRecord {
-                    log_offset: offset,
+                    log_offset: -1, // Nobody should be using these anyway.
                     record: log_record.record.clone(),
                 })
             })
@@ -250,8 +255,8 @@ impl Operator
         // Return the output records to be partitioned
         Ok(ExecuteAttachedFunctionOutput {
-            records_processed: records_count,
+            records_processed: total_records_processed,
-            output_records: Chunk::new(Arc::from(output_records_with_offsets)),
+            output_records: Chunk::new(std::sync::Arc::from(output_records_with_offsets)),
         })
     }
 }
diff --git a/rust/worker/src/execution/operators/finish_attached_function.rs b/rust/worker/src/execution/operators/finish_attached_function.rs
new file mode 100644
index 00000000000..b6070012355
--- /dev/null
+++ b/rust/worker/src/execution/operators/finish_attached_function.rs
@@ -0,0 +1,154 @@
+use async_trait::async_trait;
+use chroma_error::{ChromaError, ErrorCodes};
+use chroma_log::Log;
+use chroma_sysdb::SysDb;
+use chroma_system::Operator;
+use chroma_types::{
+    AttachedFunctionUpdateInfo, AttachedFunctionUuid, CollectionFlushInfo, CollectionUuid,
+    NonceUuid, Schema, SegmentFlushInfo,
+};
+use std::sync::Arc;
+use thiserror::Error;
+use tonic;
+
+/// The finish attached function operator is responsible for:
+/// 1. Registering collection compaction results for all collections
+/// 2. Updating attached function completion offset in the same transaction
+#[derive(Debug)]
+pub struct FinishAttachedFunctionOperator {}
+
+impl FinishAttachedFunctionOperator {
+    /// Create a new finish attached function operator.
+    pub fn new() -> Box<Self> {
+        Box::new(FinishAttachedFunctionOperator {})
+    }
+}
+
+#[derive(Debug)]
+/// The input for the finish attached function operator.
+/// This input is used to complete the attached function workflow by:
+/// - Flushing collection compaction data to sysdb for all collections
+/// - Updating attached function completion offset in the same transaction
+pub struct FinishAttachedFunctionInput {
+    pub collections: Vec<CollectionFlushInfo>,
+    pub attached_function_id: AttachedFunctionUuid,
+    pub attached_function_run_nonce: NonceUuid,
+    pub completion_offset: u64,
+
+    pub sysdb: SysDb,
+    pub log: Log,
+}
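As a rough usage sketch (not part of this patch), an orchestrator would construct this input and dispatch the operator with the same `wrap(...)` pattern the other operators in this change use; all variable names here are hypothetical and assumed to be in scope:

// Hypothetical wiring; `ctx`, `sysdb`, `log`, `collection_flush_infos`,
// `attached_function_id`, `run_nonce`, `completion_offset`, `orchestrator_context`,
// and `dispatcher` are assumed to exist in the surrounding orchestrator.
let operator = FinishAttachedFunctionOperator::new();
let input = FinishAttachedFunctionInput::new(
    collection_flush_infos,
    attached_function_id,
    run_nonce,
    completion_offset,
    sysdb.clone(),
    log.clone(),
);
let task = wrap(
    operator,
    input,
    ctx.receiver(),
    orchestrator_context.task_cancellation_token.clone(),
);
dispatcher.send(task, None).await?;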
+
+impl FinishAttachedFunctionInput {
+    /// Create a new finish attached function input.
+    pub fn new(
+        collections: Vec<CollectionFlushInfo>,
+        attached_function_id: AttachedFunctionUuid,
+        attached_function_run_nonce: NonceUuid,
+        completion_offset: u64,
+
+        sysdb: SysDb,
+        log: Log,
+    ) -> Self {
+        FinishAttachedFunctionInput {
+            collections,
+            attached_function_id,
+            attached_function_run_nonce,
+            completion_offset,
+            sysdb,
+            log,
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct FinishAttachedFunctionOutput {
+    pub flush_results: Vec<chroma_types::FlushCompactionAndAttachedFunctionResponse>,
+}
+
+#[derive(Error, Debug)]
+pub enum FinishAttachedFunctionError {
+    #[error("Failed to flush collection compaction: {0}")]
+    FlushFailed(#[from] chroma_sysdb::FlushCompactionError),
+    #[error("Invalid attached function ID: {0}")]
+    InvalidFunctionId(String),
+}
+
+impl ChromaError for FinishAttachedFunctionError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            FinishAttachedFunctionError::FlushFailed(e) => e.code(),
+            FinishAttachedFunctionError::InvalidFunctionId(_) => ErrorCodes::InvalidArgument,
+        }
+    }
+}
+
+#[async_trait]
+impl Operator<FinishAttachedFunctionInput, FinishAttachedFunctionOutput>
+    for FinishAttachedFunctionOperator
+{
+    type Error = FinishAttachedFunctionError;
+
+    fn get_name(&self) -> &'static str {
+        "FinishAttachedFunctionOperator"
+    }
+
+    async fn run(
+        &self,
+        input: &FinishAttachedFunctionInput,
+    ) -> Result<FinishAttachedFunctionOutput, FinishAttachedFunctionError> {
+        let mut sysdb = input.sysdb.clone();
+
+        // Create the attached function update info
+        let attached_function_update = AttachedFunctionUpdateInfo {
+            attached_function_id: input.attached_function_id,
+            attached_function_run_nonce: input.attached_function_run_nonce.0,
+            completion_offset: input.completion_offset,
+        };
+
+        // Flush all collection compaction results and update attached function in one RPC
+        let flush_result = sysdb
+            .flush_compaction_and_attached_function(
+                input.collections.clone(),
+                attached_function_update,
+            )
+            .await
+            .map_err(FinishAttachedFunctionError::FlushFailed)?;
+
+        // Build individual flush results from the response
+        let mut flush_results = Vec::with_capacity(flush_result.collections.len());
+        for collection_result in &flush_result.collections {
+            flush_results.push(chroma_types::FlushCompactionAndAttachedFunctionResponse {
+                collections: vec![chroma_types::CollectionCompactionInfo {
+                    collection_id: collection_result.collection_id,
+                    collection_version: collection_result.collection_version,
+                    last_compaction_time: collection_result.last_compaction_time,
+                }],
+                completion_offset: flush_result.completion_offset,
+            });
+        }
+
+        // TODO(tanujnay112): Can optimize the below to not happen on the output collection.
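Regarding the TODO above: one possible shape of that optimization, sketched here purely for illustration, is to filter the output collection out of the offset-update loop that follows. This assumes the output collection's ID were threaded through the input as a hypothetical `output_collection_id`, which this change does not do:

// Hypothetical: skip the log-offset update for the output collection.
for collection in input
    .collections
    .iter()
    .filter(|c| c.collection_id != output_collection_id)
{
    log.update_collection_log_offset(
        &collection.tenant_id,
        collection.collection_id,
        collection.log_position,
    )
    .await
    .map_err(|e| {
        FinishAttachedFunctionError::FlushFailed(
            chroma_sysdb::FlushCompactionError::FailedToFlushCompaction(
                tonic::Status::internal(format!("Failed to update log offset: {}", e)),
            ),
        )
    })?;
}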
+ + // Update log offsets for all collections to ensure consistency + // This must be done after the flush to ensure the log position in sysdb is always >= log service + let mut log = input.log.clone(); + for collection in &input.collections { + log.update_collection_log_offset( + &collection.tenant_id, + collection.collection_id, + collection.log_position, + ) + .await + .map_err(|e| { + FinishAttachedFunctionError::FlushFailed( + chroma_sysdb::FlushCompactionError::FailedToFlushCompaction( + tonic::Status::internal(format!("Failed to update log offset: {}", e)), + ), + ) + })?; + } + + Ok(FinishAttachedFunctionOutput { flush_results }) + } +} diff --git a/rust/worker/src/execution/operators/get_attached_function.rs b/rust/worker/src/execution/operators/get_attached_function.rs new file mode 100644 index 00000000000..982422d1516 --- /dev/null +++ b/rust/worker/src/execution/operators/get_attached_function.rs @@ -0,0 +1,150 @@ +use async_trait::async_trait; +use chroma_error::ChromaError; +use chroma_sysdb::sysdb::SysDb; +use chroma_system::{Operator, OperatorType}; +use chroma_types::{ + AttachedFunction, Collection, CollectionUuid, GetCollectionByCrnError, + ListAttachedFunctionsError, +}; +use thiserror::Error; + +/// The `GetAttachedFunctionOperator` lists attached functions for a collection and selects the first one. +/// If no functions are found, it returns an empty result (not an error) to allow the orchestrator +/// to handle the case gracefully. +#[derive(Clone, Debug)] +pub struct GetAttachedFunctionOperator { + pub sysdb: SysDb, + pub collection_id: CollectionUuid, +} + +impl GetAttachedFunctionOperator { + pub fn new(sysdb: SysDb, collection_id: CollectionUuid) -> Self { + Self { + sysdb, + collection_id, + } + } +} + +#[derive(Debug)] +pub struct GetAttachedFunctionInput { + pub collection_id: CollectionUuid, +} + +#[derive(Debug)] +pub struct GetAttachedFunctionOutput { + pub attached_function: Option, +} + +#[derive(Debug, Error)] +pub enum GetAttachedFunctionOperatorError { + #[error("Failed to list attached functions: {0}")] + ListFunctions(#[from] ListAttachedFunctionsError), + #[error("Failed to convert attached function proto")] + ConversionError, + #[error("No attached function found")] + NoAttachedFunctionFound, +} + +#[derive(Debug, Error)] +pub enum GetAttachedFunctionError { + #[error("Failed to list attached functions: {0}")] + ListFunctions(#[from] ListAttachedFunctionsError), + #[error("Failed to convert attached function proto")] + ConversionError, +} + +impl ChromaError for GetAttachedFunctionError { + fn code(&self) -> chroma_error::ErrorCodes { + match self { + GetAttachedFunctionError::ListFunctions(e) => e.code(), + GetAttachedFunctionError::ConversionError => chroma_error::ErrorCodes::Internal, + } + } + + fn should_trace_error(&self) -> bool { + match self { + GetAttachedFunctionError::ListFunctions(e) => e.should_trace_error(), + GetAttachedFunctionError::ConversionError => true, + } + } +} + +impl ChromaError for GetAttachedFunctionOperatorError { + fn code(&self) -> chroma_error::ErrorCodes { + match self { + GetAttachedFunctionOperatorError::ListFunctions(e) => e.code(), + GetAttachedFunctionOperatorError::ConversionError => chroma_error::ErrorCodes::Internal, + GetAttachedFunctionOperatorError::NoAttachedFunctionFound => { + chroma_error::ErrorCodes::NotFound + } + } + } + + fn should_trace_error(&self) -> bool { + match self { + GetAttachedFunctionOperatorError::ListFunctions(e) => e.should_trace_error(), + 
GetAttachedFunctionOperatorError::ConversionError => true, + GetAttachedFunctionOperatorError::NoAttachedFunctionFound => false, + } + } +} + +#[async_trait] +impl Operator for GetAttachedFunctionOperator { + type Error = GetAttachedFunctionOperatorError; + + fn get_type(&self) -> OperatorType { + OperatorType::IO + } + + async fn run( + &self, + input: &GetAttachedFunctionInput, + ) -> Result { + tracing::trace!( + "[{}]: Collection ID {}", + self.get_name(), + input.collection_id.0 + ); + + let attached_functions = self + .sysdb + .clone() + .list_attached_functions(input.collection_id) + .await?; + + if attached_functions.is_empty() { + tracing::info!( + "[{}]: No attached functions found for collection {}", + self.get_name(), + input.collection_id.0 + ); + return Ok(GetAttachedFunctionOutput { + attached_function: None, + }); + } + + // Take the first attached function from the list + let attached_function_proto = attached_functions + .into_iter() + .next() + .ok_or(GetAttachedFunctionOperatorError::NoAttachedFunctionFound)?; + + // Convert proto to AttachedFunction type using TryFrom from task.rs + let attached_function: AttachedFunction = attached_function_proto + .try_into() + .map_err(|_| GetAttachedFunctionOperatorError::ConversionError)?; + + tracing::info!( + "[{}]: Found attached function '{}' for collection {}", + self.get_name(), + attached_function.name, + input.collection_id.0 + ); + + Ok(GetAttachedFunctionOutput { + attached_function: Some(attached_function), + }) + } +} diff --git a/rust/worker/src/execution/operators/get_collection_and_segments.rs b/rust/worker/src/execution/operators/get_collection_and_segments.rs index 8a93a4adec6..2efb46696ae 100644 --- a/rust/worker/src/execution/operators/get_collection_and_segments.rs +++ b/rust/worker/src/execution/operators/get_collection_and_segments.rs @@ -22,6 +22,15 @@ pub struct GetCollectionAndSegmentsOperator { pub collection_id: CollectionUuid, } +impl GetCollectionAndSegmentsOperator { + pub fn new(sysdb: SysDb, collection_id: CollectionUuid) -> Self { + Self { + sysdb, + collection_id, + } + } +} + type GetCollectionAndSegmentsInput = (); pub type GetCollectionAndSegmentsOutput = CollectionAndSegments; diff --git a/rust/worker/src/execution/operators/get_collection_by_name.rs b/rust/worker/src/execution/operators/get_collection_by_name.rs new file mode 100644 index 00000000000..e69de29bb2d diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index e6464e1d3bf..c829d03e356 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -3,7 +3,10 @@ pub mod commit_segment_writer; pub mod count_records; pub mod execute_task; pub mod fetch_log; +pub mod finish_attached_function; pub mod flush_segment_writer; +pub mod get_attached_function; +pub mod get_collection_and_segments; pub mod materialize_logs; pub(super) mod register; pub mod spann_bf_pl; @@ -11,7 +14,6 @@ pub(super) mod spann_centers_search; pub(super) mod spann_fetch_pl; pub mod filter; -pub mod get_collection_and_segments; pub mod idf; pub mod knn_hnsw; pub mod knn_log; diff --git a/rust/worker/src/execution/operators/register.rs b/rust/worker/src/execution/operators/register.rs index 8ca41ab01b5..817010c6d30 100644 --- a/rust/worker/src/execution/operators/register.rs +++ b/rust/worker/src/execution/operators/register.rs @@ -9,6 +9,9 @@ use chroma_types::{CollectionUuid, FlushCompactionResponse, SegmentFlushInfo}; use std::sync::Arc; use thiserror::Error; +// Import 
for the From implementation +use crate::execution::operators::finish_attached_function::FinishAttachedFunctionError; + /// The register operator is responsible for flushing compaction data to the sysdb /// as well as updating the log offset in the log service. #[derive(Debug)] @@ -112,6 +115,12 @@ impl ChromaError for RegisterError { } } +impl From for RegisterError { + fn from(value: FinishAttachedFunctionError) -> Self { + RegisterError::UpdateLogOffsetError(Box::new(value)) + } +} + #[async_trait] impl Operator for RegisterOperator { type Error = RegisterError; diff --git a/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs b/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs index 2bcc13a50a8..52d6ff0a4c1 100644 --- a/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs +++ b/rust/worker/src/execution/orchestration/apply_logs_orchestrator.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; @@ -60,7 +60,7 @@ pub struct ApplyLogsOrchestrator { segment_spans: HashMap, // Store the materialized outputs from LogFetchOrchestrator - materialized_log_data: Option>, + materialized_log_data: Option>>, metrics: CompactionMetrics, } @@ -181,7 +181,7 @@ impl ApplyLogsOrchestratorResponse { impl ApplyLogsOrchestrator { pub fn new( context: &CompactionContext, - materialized_log_data: Option>, + materialized_log_data: Option>>, ) -> Self { ApplyLogsOrchestrator { context: context.clone(), @@ -206,7 +206,7 @@ impl ApplyLogsOrchestrator { let mut tasks_to_run = Vec::new(); self.num_materialized_logs += materialized_logs.len() as u64; - let writers = self.context.get_segment_writers()?; + let writers = self.context.get_output_segment_writers()?; { self.num_uncompleted_tasks_by_segment @@ -255,7 +255,7 @@ impl ApplyLogsOrchestrator { materialized_logs.clone(), writers.record_reader.clone(), self.context - .get_collection_info()? + .get_output_collection_info()? 
.collection .schema .clone(), @@ -356,7 +356,7 @@ impl ApplyLogsOrchestrator { .add(self.num_materialized_logs, &[]); self.state = ExecutionState::Register; - let collection_info = match self.context.get_collection_info() { + let collection_info = match self.context.get_output_collection_info() { Ok(collection_info) => collection_info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; @@ -461,7 +461,7 @@ impl Orchestrator for ApplyLogsOrchestrator { } }; - for materialized_output in materialized_outputs { + for materialized_output in materialized_outputs.iter() { if materialized_output.result.is_empty() { self.terminate_with_result( Err(ApplyLogsOrchestratorError::InvariantViolation( @@ -477,7 +477,7 @@ impl Orchestrator for ApplyLogsOrchestrator { // Create tasks for each materialized output let result = self - .create_apply_log_to_segment_writer_tasks(materialized_output.result, ctx) + .create_apply_log_to_segment_writer_tasks(materialized_output.result.clone(), ctx) .await; let mut new_tasks = match result { @@ -525,7 +525,7 @@ impl Handler collection_info, None => { @@ -587,7 +587,9 @@ impl Handler writer, None => return, @@ -617,7 +619,7 @@ impl Handler info, Err(err) => { self.terminate_with_result(Err(err.into()), ctx).await; diff --git a/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs b/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs new file mode 100644 index 00000000000..a2d94af0766 --- /dev/null +++ b/rust/worker/src/execution/orchestration/attached_function_orchestrator.rs @@ -0,0 +1,806 @@ +use std::{ + cell::OnceCell, + sync::{atomic::AtomicU32, Arc}, +}; + +use async_trait::async_trait; +use chroma_error::{ChromaError, ErrorCodes}; +use chroma_segment::{ + blockfile_metadata::{MetadataSegmentError, MetadataSegmentWriter}, + blockfile_record::{RecordSegmentWriter, RecordSegmentWriterCreationError}, + distributed_hnsw::{DistributedHNSWSegmentFromSegmentError, DistributedHNSWSegmentWriter}, + distributed_spann::SpannSegmentWriterError, + types::VectorSegmentWriter, +}; +use chroma_system::{ + wrap, ChannelError, ComponentContext, ComponentHandle, Dispatcher, Handler, Orchestrator, + OrchestratorContext, PanicError, TaskError, TaskMessage, TaskResult, +}; +use chroma_types::{ + AttachedFunctionUuid, Chunk, CollectionAndSegments, CollectionUuid, JobId, LogRecord, + NonceUuid, SegmentType, +}; +use thiserror::Error; +use tokio::sync::oneshot::{error::RecvError, Sender}; +use tracing::Span; +use uuid::Uuid; + +use crate::execution::{ + operators::{ + execute_task::{ + ExecuteAttachedFunctionError, ExecuteAttachedFunctionInput, + ExecuteAttachedFunctionOperator, ExecuteAttachedFunctionOutput, + }, + get_attached_function::{ + GetAttachedFunctionInput, GetAttachedFunctionOperator, + GetAttachedFunctionOperatorError, GetAttachedFunctionOutput, + }, + get_collection_and_segments::{ + GetCollectionAndSegmentsError, GetCollectionAndSegmentsOperator, + }, + materialize_logs::{ + MaterializeLogInput, MaterializeLogOperator, MaterializeLogOperatorError, + MaterializeLogOutput, + }, + }, + orchestration::compact::{CompactionContextError, ExecutionState}, +}; + +use super::compact::{CollectionCompactInfo, CompactWriters, CompactionContext, CompactionMetrics}; +use chroma_types::AdvanceAttachedFunctionError; + +#[derive(Debug, Clone)] +pub struct FunctionContext { + pub attached_function_id: AttachedFunctionUuid, + pub function_id: Uuid, + pub updated_completion_offset: u64, +} + +#[derive(Debug)] +pub struct 
AttachedFunctionOrchestrator {
+    input_collection_info: CollectionCompactInfo,
+    output_context: CompactionContext,
+    result_channel: Option<
+        Sender<Result<AttachedFunctionOrchestratorResponse, AttachedFunctionOrchestratorError>>,
+    >,
+
+    // Store the materialized outputs from DataFetchOrchestrator
+    materialized_log_data: Arc<Vec<MaterializeLogOutput>>,
+
+    // Function context
+    function_context: OnceCell<FunctionContext>,
+
+    // Execution state
+    state: ExecutionState,
+
+    orchestrator_context: OrchestratorContext,
+
+    dispatcher: ComponentHandle<Dispatcher>,
+
+    metrics: CompactionMetrics,
+}
+
+#[derive(Error, Debug)]
+pub enum AttachedFunctionOrchestratorError {
+    #[error("Operation aborted because resources exhausted")]
+    Aborted,
+    #[error("Failed to get attached function: {0}")]
+    GetAttachedFunction(#[from] GetAttachedFunctionOperatorError),
+    #[error("Failed to get collection and segments: {0}")]
+    GetCollectionAndSegments(#[from] GetCollectionAndSegmentsError),
+    #[error("No attached function found")]
+    NoAttachedFunction,
+    #[error("Failed to execute attached function: {0}")]
+    ExecuteAttachedFunction(#[from] ExecuteAttachedFunctionError),
+    #[error("Failed to advance attached function: {0}")]
+    AdvanceAttachedFunction(#[from] AdvanceAttachedFunctionError),
+    #[error("Function context not set")]
+    FunctionContextNotSet,
+    #[error("Invariant violation: {0}")]
+    InvariantViolation(String),
+    #[error("Failed to materialize log: {0}")]
+    MaterializeLog(#[from] MaterializeLogOperatorError),
+    #[error("Compaction context error: {0}")]
+    CompactionContext(#[from] CompactionContextError),
+    #[error("Output collection ID not set")]
+    OutputCollectionIdNotSet,
+    #[error("Channel error: {0}")]
+    Channel(#[from] ChannelError),
+    #[error("Could not count current segment: {0}")]
+    CountError(Box<dyn ChromaError>),
+    #[error("Receiver error: {0}")]
+    RecvError(#[from] RecvError),
+    #[error("Panic error: {0}")]
+    PanicError(#[from] PanicError),
+    #[error("Error creating metadata writer: {0}")]
+    MetadataSegment(#[from] MetadataSegmentError),
+    #[error("Error creating record segment writer: {0}")]
+    RecordSegmentWriter(#[from] RecordSegmentWriterCreationError),
+    #[error("Error creating hnsw writer: {0}")]
+    HnswSegment(#[from] DistributedHNSWSegmentFromSegmentError),
+    #[error("Error creating spann writer: {0}")]
+    SpannSegment(#[from] SpannSegmentWriterError),
+}
+
+impl ChromaError for AttachedFunctionOrchestratorError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            AttachedFunctionOrchestratorError::Aborted => ErrorCodes::Aborted,
+            AttachedFunctionOrchestratorError::GetAttachedFunction(e) => e.code(),
+            AttachedFunctionOrchestratorError::GetCollectionAndSegments(e) => e.code(),
+            AttachedFunctionOrchestratorError::NoAttachedFunction => ErrorCodes::NotFound,
+            AttachedFunctionOrchestratorError::ExecuteAttachedFunction(e) => e.code(),
+            AttachedFunctionOrchestratorError::AdvanceAttachedFunction(e) => e.code(),
+            AttachedFunctionOrchestratorError::MaterializeLog(e) => e.code(),
+            AttachedFunctionOrchestratorError::FunctionContextNotSet => ErrorCodes::Internal,
+            AttachedFunctionOrchestratorError::InvariantViolation(_) => ErrorCodes::Internal,
+            AttachedFunctionOrchestratorError::CompactionContext(e) => e.code(),
+            AttachedFunctionOrchestratorError::OutputCollectionIdNotSet => ErrorCodes::Internal,
+            AttachedFunctionOrchestratorError::Channel(e) => e.code(),
+            AttachedFunctionOrchestratorError::RecvError(_) => ErrorCodes::Internal,
+            AttachedFunctionOrchestratorError::CountError(e) => e.code(),
+            AttachedFunctionOrchestratorError::PanicError(e) => e.code(),
+            AttachedFunctionOrchestratorError::MetadataSegment(e) => e.code(),
+
AttachedFunctionOrchestratorError::RecordSegmentWriter(e) => e.code(), + AttachedFunctionOrchestratorError::HnswSegment(e) => e.code(), + AttachedFunctionOrchestratorError::SpannSegment(e) => e.code(), + } + } + + fn should_trace_error(&self) -> bool { + match self { + AttachedFunctionOrchestratorError::Aborted => true, + AttachedFunctionOrchestratorError::GetAttachedFunction(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::GetCollectionAndSegments(e) => { + e.should_trace_error() + } + AttachedFunctionOrchestratorError::NoAttachedFunction => false, + AttachedFunctionOrchestratorError::ExecuteAttachedFunction(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::AdvanceAttachedFunction(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::MaterializeLog(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::FunctionContextNotSet => true, + AttachedFunctionOrchestratorError::InvariantViolation(_) => true, + AttachedFunctionOrchestratorError::CompactionContext(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::OutputCollectionIdNotSet => true, + AttachedFunctionOrchestratorError::Channel(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::RecvError(_) => true, + AttachedFunctionOrchestratorError::CountError(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::PanicError(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::MetadataSegment(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::RecordSegmentWriter(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::HnswSegment(e) => e.should_trace_error(), + AttachedFunctionOrchestratorError::SpannSegment(e) => e.should_trace_error(), + } + } +} + +impl From> for AttachedFunctionOrchestratorError +where + E: Into, +{ + fn from(value: TaskError) -> Self { + match value { + TaskError::Aborted => AttachedFunctionOrchestratorError::Aborted, + TaskError::Panic(e) => e.into(), + TaskError::TaskFailed(e) => e.into(), + } + } +} + +#[derive(Debug)] +pub enum AttachedFunctionOrchestratorResponse { + /// No attached function was found, so nothing was executed + NoAttachedFunction { job_id: JobId }, + /// Success - attached function was executed successfully + Success { + job_id: JobId, + materialized_output: Vec, + output_collection_info: CollectionCompactInfo, + attached_function_id: AttachedFunctionUuid, + attached_function_run_nonce: NonceUuid, + completion_offset: u64, + }, +} + +impl AttachedFunctionOrchestrator { + pub fn new( + input_collection_info: CollectionCompactInfo, + output_context: CompactionContext, + dispatcher: ComponentHandle, + data_fetch_records: Arc>, + ) -> Self { + let orchestrator_context = OrchestratorContext::new(dispatcher.clone()); + + let orchestrator = AttachedFunctionOrchestrator { + input_collection_info, + output_context, + result_channel: None, + materialized_log_data: data_fetch_records, + function_context: OnceCell::new(), + state: ExecutionState::MaterializeApplyCommitFlush, + orchestrator_context, + dispatcher, + metrics: CompactionMetrics::default(), + }; + + orchestrator + } + + /// Get the input collection info, following the same pattern as CompactionContext + pub fn get_input_collection_info( + &self, + ) -> Result<&CollectionCompactInfo, AttachedFunctionOrchestratorError> { + Ok(&self.input_collection_info) + } + + /// Get the output collection info if it has been set + pub fn get_output_collection_info( + &self, + ) -> Result<&CollectionCompactInfo, 
AttachedFunctionOrchestratorError> { + self.output_context + .get_output_collection_info() + .map_err(AttachedFunctionOrchestratorError::CompactionContext) + } + + /// Get the output collection ID if it has been set + pub fn get_output_collection_id( + &self, + ) -> Result { + self.output_context + .get_output_collection_info() + .map(|info| info.collection_id) + .map_err(AttachedFunctionOrchestratorError::CompactionContext) + } + + /// Set the output collection info + pub fn set_output_collection_info( + &mut self, + collection_info: CollectionCompactInfo, + ) -> Result<(), CollectionCompactInfo> { + self.output_context + .output_collection_info + .set(collection_info) + } + + /// Get the function context if it has been set + pub fn get_function_context(&self) -> Option<&FunctionContext> { + self.function_context.get() + } + + /// Set the function context + pub fn set_function_context( + &self, + function_context: FunctionContext, + ) -> Result<(), FunctionContext> { + self.function_context.set(function_context) + } + + async fn finish_no_attached_function(&mut self, ctx: &ComponentContext) { + let collection_info = match self.get_input_collection_info() { + Ok(info) => info, + Err(e) => { + self.terminate_with_result(Err(e), ctx).await; + return; + } + }; + let job_id = collection_info.collection_id.into(); + self.terminate_with_result( + Ok(AttachedFunctionOrchestratorResponse::NoAttachedFunction { job_id }), + ctx, + ) + .await; + } + + async fn finish_success( + &mut self, + materialized_output: Vec, + ctx: &ComponentContext, + ) { + let collection_info = match self.get_input_collection_info() { + Ok(info) => info, + Err(e) => { + self.terminate_with_result(Err(e), ctx).await; + return; + } + }; + + // Get output collection info - should always exist in success case + let output_collection_info = match self.get_output_collection_info() { + Ok(info) => info.clone(), + Err(e) => { + self.terminate_with_result(Err(e), ctx).await; + return; + } + }; + + // Get attached function ID - should always exist in success case + let attached_function = match self.get_function_context() { + Some(func) => func, + None => { + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::FunctionContextNotSet), + ctx, + ) + .await; + return; + } + }; + let attached_function_id = attached_function.attached_function_id; + + // Get the run nonce from the attached function + let attached_function_run_nonce = NonceUuid::new(); + + // Get the completion offset from the input collection's pulled log offset + let completion_offset = collection_info.pulled_log_offset as u64; + + println!( + "Attached function finished successfully with {} records", + materialized_output.len() + ); + + let job_id = collection_info.collection_id.into(); + self.terminate_with_result( + Ok(AttachedFunctionOrchestratorResponse::Success { + job_id, + materialized_output, + output_collection_info, + attached_function_id, + attached_function_run_nonce, + completion_offset, + }), + ctx, + ) + .await; + } + + /// Convert ExecuteAttachedFunctionOutput to MaterializeLogOutput for the ApplyDataOrchestrator + fn convert_materialized_output_to_log_records( + &self, + output: &ExecuteAttachedFunctionOutput, + ) -> Vec { + // TODO: Implement proper conversion from ExecuteAttachedFunctionOutput to MaterializeLogOutput + // For now, return a placeholder + vec![] + } + + async fn materialize_log( + &mut self, + partitions: Vec>, + ctx: &ComponentContext, + ) { + self.state = ExecutionState::MaterializeApplyCommitFlush; + + // NOTE: We 
allow writers to be uninitialized for the case when the materialized logs are empty + let record_reader = self + .output_context + .get_output_segment_writers() + .ok() + .and_then(|writers| writers.record_reader); + + let next_max_offset_id = Arc::new( + record_reader + .as_ref() + .map(|reader| AtomicU32::new(reader.get_max_offset_id() + 1)) + .unwrap_or_default(), + ); + + if let Some(rr) = record_reader.as_ref() { + let count = match rr.count().await { + Ok(count) => count as u64, + Err(err) => { + return self + .terminate_with_result( + Err(AttachedFunctionOrchestratorError::CountError(err)), + ctx, + ) + .await; + } + }; + + let collection_info = match self.output_context.get_output_collection_info_mut() { + Ok(info) => info, + Err(err) => { + return self.terminate_with_result(Err(err.into()), ctx).await; + } + }; + collection_info.collection.total_records_post_compaction = count; + } + + for partition in partitions.iter() { + let operator = MaterializeLogOperator::new(); + let input = MaterializeLogInput::new( + partition.clone(), + record_reader.clone(), + next_max_offset_id.clone(), + ); + let task = wrap( + operator, + input, + ctx.receiver(), + self.output_context + .orchestrator_context + .task_cancellation_token + .clone(), + ); + self.send(task, ctx, Some(Span::current())).await; + } + } +} + +#[async_trait] +impl Orchestrator for AttachedFunctionOrchestrator { + type Output = AttachedFunctionOrchestratorResponse; + type Error = AttachedFunctionOrchestratorError; + + fn dispatcher(&self) -> ComponentHandle { + self.dispatcher.clone() + } + + fn context(&self) -> &OrchestratorContext { + &self.orchestrator_context + } + + async fn initial_tasks( + &mut self, + ctx: &ComponentContext, + ) -> Vec<(TaskMessage, Option)> { + // Start by getting the attached function for this collection + let collection_info = match self.get_input_collection_info() { + Ok(info) => info, + Err(e) => { + // If we can't get collection info, we can't proceed + self.terminate_with_result(Err(e), ctx).await; + return vec![]; + } + }; + let operator = Box::new(GetAttachedFunctionOperator::new( + self.output_context.sysdb.clone(), + collection_info.collection_id, + )); + let input = GetAttachedFunctionInput { + collection_id: collection_info.collection_id, + }; + let task = wrap( + operator, + input, + ctx.receiver(), + self.context().task_cancellation_token.clone(), + ); + vec![(task, None)] + } + + fn set_result_channel( + &mut self, + sender: Sender< + Result, + >, + ) { + self.result_channel = Some(sender) + } + + fn take_result_channel( + &mut self, + ) -> Option< + Sender>, + > { + self.result_channel.take() + } + + async fn cleanup(&mut self) { + // TODO: Add any necessary cleanup + } +} + +#[async_trait] +impl Handler> + for AttachedFunctionOrchestrator +{ + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let message = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(message) => message, + None => return, + }; + + self.finish_success(vec![message], ctx).await; + } +} + +#[async_trait] +impl Handler> + for AttachedFunctionOrchestrator +{ + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let message = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(message) => message, + None => return, + }; + + match message.attached_function { + Some(attached_function) => { + tracing::info!( + "[AttachedFunctionOrchestrator]: Found attached function '{}' 
for collection", + attached_function.name + ); + + // TODO(tanujnay112): Handle error + let _ = self.function_context.set(FunctionContext { + attached_function_id: attached_function.id, + function_id: attached_function.function_id, + updated_completion_offset: attached_function.completion_offset, + }); + + // Get the output collection ID from the attached function + let output_collection_id = match attached_function.output_collection_id { + Some(id) => id, + None => { + tracing::error!( + "[AttachedFunctionOrchestrator]: Output collection ID not set for attached function '{}'", + attached_function.name + ); + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::OutputCollectionIdNotSet), + ctx, + ) + .await; + return; + } + }; + + // Next step: get the output collection segments using the existing GetCollectionAndSegmentsOperator + let operator = Box::new(GetCollectionAndSegmentsOperator::new( + self.output_context.sysdb.clone(), + output_collection_id, + )); + let input = (); + let task = wrap( + operator, + input, + ctx.receiver(), + self.context().task_cancellation_token.clone(), + ); + let res = self.dispatcher().send(task, None).await; + self.ok_or_terminate(res, ctx).await; + } + None => { + tracing::info!("[AttachedFunctionOrchestrator]: No attached function found"); + self.finish_no_attached_function(ctx).await; + } + } + } +} + +#[async_trait] +impl Handler> + for AttachedFunctionOrchestrator +{ + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let message = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(message) => message, + None => return, + }; + + // self.output_context.output_collection_segments = Some(message.clone()); + + tracing::info!( + "[AttachedFunctionOrchestrator]: Found output collection segments - metadata: {:?}, record: {:?}, vector: {:?}", + message.metadata_segment.id, + message.record_segment.id, + message.vector_segment.id + ); + + // Create segment writers for the output collection + let collection = &message.collection; + let dimension = match collection.dimension { + Some(dim) => dim as usize, + None => { + // Output collection is not initialized, cannot create writers + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::InvariantViolation( + "Output collection dimension is not set".to_string(), + )), + ctx, + ) + .await; + return; + } + }; + + let record_writer = match self + .ok_or_terminate( + RecordSegmentWriter::from_segment( + &collection.tenant, + &collection.database_id, + &message.record_segment, + &self.output_context.blockfile_provider, + ) + .await, + ctx, + ) + .await + { + Some(writer) => writer, + None => return, + }; + + let metadata_writer = match self + .ok_or_terminate( + MetadataSegmentWriter::from_segment( + &collection.tenant, + &collection.database_id, + &message.metadata_segment, + &self.output_context.blockfile_provider, + ) + .await, + ctx, + ) + .await + { + Some(writer) => writer, + None => return, + }; + + let (hnsw_index_uuid, vector_writer) = match message.vector_segment.r#type { + SegmentType::Spann => match self + .ok_or_terminate( + self.output_context + .spann_provider + .write(&collection, &message.vector_segment, dimension) + .await, + ctx, + ) + .await + { + Some(writer) => (writer.hnsw_index_uuid(), VectorSegmentWriter::Spann(writer)), + None => return, + }, + _ => match self + .ok_or_terminate( + DistributedHNSWSegmentWriter::from_segment( + &collection, + &message.vector_segment, + dimension, + 
self.output_context.hnsw_provider.clone(), + ) + .await + .map_err(|err| *err), + ctx, + ) + .await + { + Some(writer) => (writer.index_uuid(), VectorSegmentWriter::Hnsw(writer)), + None => return, + }, + }; + + let writers = CompactWriters { + record_reader: None, // Output collection doesn't need a reader + metadata_writer, + record_writer, + vector_writer, + }; + + // Store the output collection info with writers + let output_collection_info = CollectionCompactInfo { + collection_id: message.collection.collection_id, + collection: message.collection.clone(), + writers: Some(writers), + pulled_log_offset: message.collection.log_position, + hnsw_index_uuid: Some(hnsw_index_uuid), + schema: message.collection.schema.clone(), + }; + + if let Err(_) = self.set_output_collection_info(output_collection_info) { + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::InvariantViolation( + "Failed to set output collection info".to_string(), + )), + ctx, + ) + .await; + return; + } + + let function_context = self.function_context.get(); + + let attached_function = match function_context { + Some(func) => func, + None => { + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::NoAttachedFunction), + ctx, + ) + .await; + return; + } + }; + + let function_id = attached_function.function_id; + // Execute the attached function + let operator = match ExecuteAttachedFunctionOperator::from_attached_function( + function_id, + self.output_context.log.clone(), + ) { + Ok(op) => Box::new(op), + Err(e) => { + self.terminate_with_result( + Err(AttachedFunctionOrchestratorError::ExecuteAttachedFunction( + e, + )), + ctx, + ) + .await; + return; + } + }; + + // Get the input collection info to access pulled_log_offset + let collection_info = self + .get_input_collection_info() + .expect("Input collection info should be set"); + + let input = ExecuteAttachedFunctionInput { + materialized_logs: Arc::clone(&self.materialized_log_data), // Use the actual materialized logs from data fetch + tenant_id: "default".to_string(), // TODO: Get actual tenant ID + output_collection_id: message.collection.collection_id, + completion_offset: collection_info.pulled_log_offset as u64, // Use the completion offset from input collection + output_record_segment: message.record_segment.clone(), + blockfile_provider: self.output_context.blockfile_provider.clone(), + }; + + let task = wrap( + operator, + input, + ctx.receiver(), + self.context().task_cancellation_token.clone(), + ); + let res = self.dispatcher().send(task, None).await; + self.ok_or_terminate(res, ctx).await; + } +} + +#[async_trait] +impl Handler> + for AttachedFunctionOrchestrator +{ + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let message = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(message) => message, + None => return, + }; + + tracing::info!( + "[AttachedFunctionOrchestrator]: Attached function executed successfully, processed {} records", + message.records_processed + ); + self.materialize_log(vec![message.output_records], ctx) + .await; + + // Convert execution results to log records and finish + // let log_records = self.convert_materialized_output_to_log_records(&message); + // self.finish_success(log_records, ctx).await; + } +} diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index 9356868d73f..2021eb4b151 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ 
b/rust/worker/src/execution/orchestration/compact.rs @@ -1,4 +1,4 @@ -use std::cell::OnceCell; +use std::{cell::OnceCell, sync::Arc}; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::{ChromaError, ErrorCodes}; @@ -14,20 +14,25 @@ use chroma_sysdb::SysDb; use chroma_system::{ ComponentHandle, Dispatcher, Orchestrator, OrchestratorContext, PanicError, System, TaskError, }; -use chroma_types::{Collection, CollectionUuid, JobId, Schema, SegmentFlushInfo, SegmentUuid}; +use chroma_types::{Collection, CollectionUuid, JobId, Schema, SegmentUuid}; use opentelemetry::metrics::Counter; use thiserror::Error; use super::apply_logs_orchestrator::{ApplyLogsOrchestrator, ApplyLogsOrchestratorError}; +use super::attached_function_orchestrator::{ + AttachedFunctionOrchestrator, AttachedFunctionOrchestratorError, + AttachedFunctionOrchestratorResponse, +}; use super::log_fetch_orchestrator::{ LogFetchOrchestrator, LogFetchOrchestratorResponse, RequireCompactionOffsetRepair, Success, }; -use super::register_orchestrator::RegisterOrchestrator; +use super::register_orchestrator::{CollectionRegisterInfo, RegisterOrchestrator}; use crate::execution::{ operators::materialize_logs::MaterializeLogOutput, orchestration::{ apply_logs_orchestrator::ApplyLogsOrchestratorResponse, + attached_function_orchestrator::FunctionContext, log_fetch_orchestrator::LogFetchOrchestratorError, register_orchestrator::{RegisterOrchestratorError, RegisterOrchestratorResponse}, }, @@ -97,9 +102,16 @@ pub struct CollectionCompactInfo { pub schema: Option, } +pub struct CompactionState { + pub collection_info: OnceCell, + pub context: CompactionContext, +} + #[derive(Debug)] pub struct CompactionContext { - pub collection_info: OnceCell, + pub input_collection_info: OnceCell, + pub output_collection_info: OnceCell, + pub attached_function_context: OnceCell, pub log: Log, pub sysdb: SysDb, pub blockfile_provider: BlockfileProvider, @@ -119,7 +131,9 @@ impl Clone for CompactionContext { fn clone(&self) -> Self { let orchestrator_context = OrchestratorContext::new(self.dispatcher.clone()); Self { - collection_info: self.collection_info.clone(), + input_collection_info: self.input_collection_info.clone(), + output_collection_info: self.output_collection_info.clone(), + attached_function_context: self.attached_function_context.clone(), log: self.log.clone(), sysdb: self.sysdb.clone(), blockfile_provider: self.blockfile_provider.clone(), @@ -143,6 +157,8 @@ pub enum CompactionError { Aborted, #[error("Error applying data to segment writers: {0}")] ApplyDataError(#[from] ApplyLogsOrchestratorError), + #[error("Error executing attached function: {0}")] + AttachedFunction(#[from] AttachedFunctionOrchestratorError), #[error("Error fetching compaction context: {0}")] CompactionContextError(#[from] CompactionContextError), #[error("Error fetching logs: {0}")] @@ -172,7 +188,13 @@ impl ChromaError for CompactionError { fn code(&self) -> ErrorCodes { match self { CompactionError::Aborted => ErrorCodes::Aborted, - _ => ErrorCodes::Internal, + CompactionError::ApplyDataError(e) => e.code(), + CompactionError::AttachedFunction(e) => e.code(), + CompactionError::CompactionContextError(e) => e.code(), + CompactionError::DataFetchError(e) => e.code(), + CompactionError::RegisterError(e) => e.code(), + CompactionError::PanicError(e) => e.code(), + CompactionError::InvariantViolation(_) => ErrorCodes::Internal, } } @@ -180,6 +202,7 @@ impl ChromaError for CompactionError { match self { Self::Aborted => true, Self::ApplyDataError(e) => 
e.should_trace_error(), + Self::AttachedFunction(e) => e.should_trace_error(), Self::CompactionContextError(e) => e.should_trace_error(), Self::DataFetchError(e) => e.should_trace_error(), Self::PanicError(e) => e.should_trace_error(), @@ -225,7 +248,9 @@ impl CompactionContext { ) -> Self { let orchestrator_context = OrchestratorContext::new(dispatcher.clone()); CompactionContext { - collection_info: OnceCell::new(), + input_collection_info: OnceCell::new(), + output_collection_info: OnceCell::new(), + attached_function_context: OnceCell::new(), is_rebuild, fetch_log_batch_size, max_compaction_size, @@ -247,35 +272,67 @@ impl CompactionContext { self.poison_offset = Some(offset); } - pub fn get_segment_writers(&self) -> Result { - self.get_collection_info()?.writers.clone().ok_or( - CompactionContextError::InvariantViolation("Segment writers should have been set"), + pub fn get_output_segment_writers(&self) -> Result { + self.get_output_collection_info()?.writers.clone().ok_or( + CompactionContextError::InvariantViolation( + "Output segment writers should have been set", + ), ) } - pub fn get_collection_info(&self) -> Result<&CollectionCompactInfo, CompactionContextError> { - self.collection_info + pub fn get_input_segment_writers(&self) -> Result { + self.get_input_collection_info()?.writers.clone().ok_or( + CompactionContextError::InvariantViolation( + "Input segment writers should have been set", + ), + ) + } + + pub fn get_input_collection_info( + &self, + ) -> Result<&CollectionCompactInfo, CompactionContextError> { + self.input_collection_info + .get() + .ok_or(CompactionContextError::InvariantViolation( + "Collection info should have been set", + )) + } + + pub fn get_output_collection_info( + &self, + ) -> Result<&CollectionCompactInfo, CompactionContextError> { + self.output_collection_info .get() .ok_or(CompactionContextError::InvariantViolation( "Collection info should have been set", )) } - pub fn get_collection_info_mut( + pub fn get_input_collection_info_mut( + &mut self, + ) -> Result<&mut CollectionCompactInfo, CompactionContextError> { + self.input_collection_info + .get_mut() + .ok_or(CompactionContextError::InvariantViolation( + "Collection info mut should have been set", + )) + } + + pub fn get_output_collection_info_mut( &mut self, ) -> Result<&mut CollectionCompactInfo, CompactionContextError> { - self.collection_info + self.output_collection_info .get_mut() .ok_or(CompactionContextError::InvariantViolation( "Collection info mut should have been set", )) } - pub fn get_segment_writer_by_id( + pub fn get_output_segment_writer_by_id( &self, segment_id: SegmentUuid, ) -> Result, CompactionContextError> { - let writers = self.get_segment_writers()?; + let writers = self.get_output_segment_writers()?; if writers.metadata_writer.id == segment_id { return Ok(ChromaSegmentWriter::MetadataSegment( @@ -330,7 +387,7 @@ impl CompactionContext { let materialized = success.materialized; let collection_info = success.collection_info; - self.collection_info + self.input_collection_info .set(collection_info.clone()) .map_err(|_| { CompactionContextError::InvariantViolation("Collection info already set") @@ -347,10 +404,10 @@ impl CompactionContext { pub(crate) async fn run_apply_logs( &mut self, - log_fetch_records: Vec, + log_fetch_records: Arc>, system: System, ) -> Result { - let collection_info = self.get_collection_info()?; + let collection_info = self.get_input_collection_info()?; if log_fetch_records.is_empty() { return Ok(ApplyLogsOrchestratorResponse::new_with_empty_results( 
@@ -347,10 +404,10 @@ impl CompactionContext {
     pub(crate) async fn run_apply_logs(
         &mut self,
-        log_fetch_records: Vec<MaterializeLogOutput>,
+        log_fetch_records: Arc<Vec<MaterializeLogOutput>>,
         system: System,
     ) -> Result<ApplyLogsOrchestratorResponse, ApplyLogsOrchestratorError> {
-        let collection_info = self.get_collection_info()?;
+        let collection_info = self.get_input_collection_info()?;
         if log_fetch_records.is_empty() {
             return Ok(ApplyLogsOrchestratorResponse::new_with_empty_results(
                 collection_info.collection_id.into(),
@@ -359,7 +416,7 @@ impl CompactionContext {
         }

         // INVARIANT: Every element of log_fetch_records should be non-empty
-        for mat_logs in &log_fetch_records {
+        for mat_logs in log_fetch_records.iter() {
             if mat_logs.result.is_empty() {
                 return Err(ApplyLogsOrchestratorError::InvariantViolation(
                     "Every element of log_fetch_records should be non-empty",
@@ -379,7 +436,7 @@ impl CompactionContext {
             }
         };

-        let collection_info = self.collection_info.get_mut().ok_or(
+        let collection_info = self.output_collection_info.get_mut().ok_or(
             ApplyLogsOrchestratorError::InvariantViolation("Collection info should have been set"),
         )?;
         collection_info.schema = apply_logs_response.schema.clone();
@@ -389,19 +446,159 @@ impl CompactionContext {
         Ok(apply_logs_response)
     }

+    // Runs the attached function phase; seeds output_collection_info from the response.
+    pub(crate) async fn run_attached_function(
+        &mut self,
+        data_fetch_records: Arc<Vec<MaterializeLogOutput>>,
+        system: System,
+    ) -> Result<AttachedFunctionOrchestratorResponse, AttachedFunctionOrchestratorError> {
+        let input_collection_info = self.get_input_collection_info()?.clone();
+        let input_collection_info_clone = input_collection_info.clone();
+        tracing::debug!("Attached function input records: {}", data_fetch_records.len());
+        let attached_function_orchestrator = AttachedFunctionOrchestrator::new(
+            input_collection_info,
+            self.clone(),
+            self.dispatcher.clone(),
+            data_fetch_records,
+        );
+
+        let attached_function_response = match attached_function_orchestrator.run(system).await {
+            Ok(response) => response,
+            Err(e) => {
+                if e.should_trace_error() {
+                    tracing::error!("Attached function phase failed: {e}");
+                }
+                return Err(e);
+            }
+        };
+
+        // Set the output collection info based on the response
+        match &attached_function_response {
+            AttachedFunctionOrchestratorResponse::NoAttachedFunction { .. } => {
+                self.output_collection_info
+                    .set(input_collection_info_clone)
+                    .map_err(|_| {
+                        AttachedFunctionOrchestratorError::InvariantViolation(
+                            "Output collection info already set".to_string(),
+                        )
+                    })?;
+            }
+            AttachedFunctionOrchestratorResponse::Success {
+                output_collection_info,
+                ..
+            } => {
+                self.output_collection_info
+                    .set(output_collection_info.clone())
+                    .map_err(|_| {
+                        AttachedFunctionOrchestratorError::InvariantViolation(
+                            "Output collection info already set".to_string(),
+                        )
+                    })?;
+            }
+        }
+
+        Ok(attached_function_response)
+    }
+
+    async fn run_attached_function_workflow(
+        mut self,
+        log_fetch_records: Arc<Vec<MaterializeLogOutput>>,
+        system: System,
+    ) -> Result<Option<(FunctionContext, CollectionRegisterInfo)>, CompactionError> {
+        let attached_function_result = self
+            .run_attached_function(log_fetch_records, system.clone())
+            .await?;
+
+        match attached_function_result {
+            AttachedFunctionOrchestratorResponse::NoAttachedFunction { .. } => Ok(None),
+            AttachedFunctionOrchestratorResponse::Success {
+                output_collection_info,
+                job_id,
+                materialized_output,
+                attached_function_id,
+                attached_function_run_nonce,
+                completion_offset,
+            } => {
+                // Update self to use the output collection for apply_logs
+                self.output_collection_info = OnceCell::from(output_collection_info.clone());
+
+                // Apply materialized output to output collection
+                let apply_logs_response = self
+                    .run_apply_logs(Arc::new(materialized_output), system.clone())
+                    .await?;
+
+                let function_context = FunctionContext {
+                    attached_function_id,
+                    function_id: attached_function_id.0,
+                    updated_completion_offset: completion_offset,
+                };
+
+                let collection_register_info = CollectionRegisterInfo {
+                    collection_info: output_collection_info,
+                    flush_results: apply_logs_response.flush_results,
+                    collection_logical_size_bytes: apply_logs_response
+                        .collection_logical_size_bytes,
+                };
+
+                Ok(Some((function_context, collection_register_info)))
+            }
+        }
+    }
+
     pub(crate) async fn run_register(
         &mut self,
-        flush_results: Vec<SegmentFlushInfo>,
-        collection_logical_size_bytes: u64,
+        collection_register_infos: Vec<CollectionRegisterInfo>,
+        function_register_info: Option<FunctionContext>,
         system: System,
     ) -> Result<RegisterOrchestratorResponse, RegisterOrchestratorError> {
         let dispatcher = self.dispatcher.clone();
-        let register_orchestrator = RegisterOrchestrator::new(
-            self,
-            dispatcher,
-            flush_results,
-            collection_logical_size_bytes,
-        );
+
+        if collection_register_infos.is_empty() || collection_register_infos.len() > 2 {
+            return Err(RegisterOrchestratorError::InvariantViolation(
+                "Expected one or two collection register infos",
+            ));
+        }
+
+        if collection_register_infos.len() == 2 {
+            match function_register_info {
+                Some(function_info) => {
+                    let mut iter = collection_register_infos.into_iter();
+                    let first_collection =
+                        iter.next()
+                            .ok_or(RegisterOrchestratorError::InvariantViolation(
+                                "Expected first collection register info",
+                            ))?;
+                    let second_collection =
+                        iter.next()
+                            .ok_or(RegisterOrchestratorError::InvariantViolation(
+                                "Expected second collection register info",
+                            ))?;
+
+                    let register_orchestrator = RegisterOrchestrator::new_with_attached_function(
+                        self.clone(),
+                        dispatcher,
+                        first_collection,
+                        second_collection,
+                        function_info,
+                    );
+                    return register_orchestrator.run(system).await;
+                }
+                None => {
+                    return Err(RegisterOrchestratorError::InvariantViolation(
+                        "Function register info is required when registering two collections",
+                    ));
+                }
+            }
+        }
+
+        let collection_register_info = collection_register_infos.into_iter().next().ok_or(
+            RegisterOrchestratorError::InvariantViolation(
+                "Expected at least one collection register info",
+            ),
+        )?;
+
+        let register_orchestrator =
+            RegisterOrchestrator::new(self, dispatcher, collection_register_info);

         match register_orchestrator.run(system).await {
             Ok(response) => Ok(response),
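// A runnable sketch of the arity rule run_register now enforces, using stand-in types:
// exactly one collection info for plain compaction, or exactly two plus function
// context for the attached-function path. Anything else is an invariant violation.

#[derive(Debug, PartialEq)]
enum PlanError {
    BadCollectionCount(usize),
    MissingFunctionInfo,
}

fn validate(collections: &[&str], function_info: Option<&str>) -> Result<(), PlanError> {
    match (collections.len(), function_info) {
        (1, _) => Ok(()),       // plain compaction: single collection
        (2, Some(_)) => Ok(()), // attached function: output + input collections
        (2, None) => Err(PlanError::MissingFunctionInfo),
        (n, _) => Err(PlanError::BadCollectionCount(n)),
    }
}

fn main() {
    assert_eq!(validate(&["input"], None), Ok(()));
    assert_eq!(validate(&["output", "input"], Some("fn-ctx")), Ok(()));
    assert_eq!(
        validate(&["output", "input"], None),
        Err(PlanError::MissingFunctionInfo)
    );
    assert_eq!(validate(&[], None), Err(PlanError::BadCollectionCount(0)));
}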
@@ -433,15 +630,69 @@ impl CompactionContext {
             }
         };

-        let apply_logs_response = self
-            .run_apply_logs(log_fetch_records, system.clone())
-            .await?;
+        // Wrap in Arc to avoid cloning large MaterializeLogOutput data
+        let log_fetch_records = Arc::new(log_fetch_records);
+        let log_fetch_records_clone = Arc::clone(&log_fetch_records);
+
+        let mut self_clone_attached = self.clone();
+        let mut self_clone_input = self.clone();
+        let system_clone1 = system.clone();
+        let system_clone2 = system.clone();
+
+        let input_collection_info =
+            self.input_collection_info
+                .get()
+                .ok_or(CompactionError::InvariantViolation(
+                    "Input collection info should not be None",
+                ))?;
+
+        self.output_collection_info
+            .set(input_collection_info.clone())
+            .map_err(|_| {
+                CompactionError::InvariantViolation("Output collection info already set")
+            })?;
+
+        // Parallelize two independent workflows using tokio::spawn to avoid stack overflow
+        // (tokio::join! polls futures on the stack, but these orchestrators are large):
+        // 1. Attached function execution + applying its output to the output collection
+        // 2. Applying input logs to the input collection
+        let attached_handle = tokio::spawn(async move {
+            self_clone_attached
+                .run_attached_function_workflow(log_fetch_records_clone, system_clone1)
+                .await
+        });
+
+        let input_handle = tokio::spawn(async move {
+            self_clone_input
+                .run_apply_logs(log_fetch_records, system_clone2)
+                .await
+        });
+
+        let (attached_result, apply_logs_response) = tokio::join!(attached_handle, input_handle);
+
+        let attached_result = attached_result
+            .map_err(|_| CompactionError::InvariantViolation("Attached function task panicked"))??;
+        let apply_logs_response = apply_logs_response
+            .map_err(|_| CompactionError::InvariantViolation("Input compaction task panicked"))??;
+
+        // Collect results
+        let mut attached_function_context = None;
+        let mut results: Vec<CollectionRegisterInfo> = Vec::new();
+
+        if let Some((function_context, collection_register_info)) = attached_result {
+            attached_function_context = Some(function_context);
+            results.push(collection_register_info);
+        }

+        // Process input collection result
         // Invariant: flush_results is empty => collection_logical_size_bytes == collection_info.collection.size_bytes_post_compaction
         if apply_logs_response.flush_results.is_empty()
             && apply_logs_response.collection_logical_size_bytes
                 != self
-                    .get_collection_info()?
+                    .get_output_collection_info()?
                     .collection
                     .size_bytes_post_compaction
         {
@@ -450,12 +701,14 @@
             ));
         }

-        let _ = Box::pin(self.run_register(
-            apply_logs_response.flush_results,
-            apply_logs_response.collection_logical_size_bytes,
-            system.clone(),
-        ))
-        .await?;
+        results.push(CollectionRegisterInfo {
+            collection_info: self.get_output_collection_info()?.clone(),
+            flush_results: apply_logs_response.flush_results,
+            collection_logical_size_bytes: apply_logs_response.collection_logical_size_bytes,
+        });
+
+        let _ =
+            Box::pin(self.run_register(results, attached_function_context, system.clone())).await?;

         Ok(CompactionResponse::Success {
             job_id: collection_id.into(),
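// A small sketch of the concurrency shape used above. tokio::spawn moves each large
// future onto the runtime's heap-allocated tasks, so tokio::join! only polls two
// small JoinHandles on the current stack; awaiting a handle yields Result<T, JoinError>,
// which is where the "task panicked" branches above come from. Workloads are placeholders.

use std::sync::Arc;

#[tokio::main]
async fn main() {
    let shared = Arc::new(vec![1u64, 2, 3]);
    let for_a = Arc::clone(&shared);
    let for_b = Arc::clone(&shared);

    // The spawned futures live off-stack; only the handles stay here.
    let a = tokio::spawn(async move { for_a.iter().sum::<u64>() });
    let b = tokio::spawn(async move { for_b.len() as u64 });

    let (sum, len) = tokio::join!(a, b);
    assert_eq!(sum.expect("task a panicked"), 6);
    assert_eq!(len.expect("task b panicked"), 3);
}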
@@ -463,7 +716,17 @@
     }

     pub(crate) async fn cleanup(self) {
-        if let Some(collection_info) = self.collection_info.get() {
+        if let Some(collection_info) = self.input_collection_info.get() {
+            if let Some(hnsw_index_uuid) = collection_info.hnsw_index_uuid {
+                let _ = HnswIndexProvider::purge_one_id(
+                    self.hnsw_provider.temporary_storage_path.as_path(),
+                    hnsw_index_uuid,
+                )
+                .await;
+            }
+        }
+
+        if let Some(collection_info) = self.output_collection_info.get() {
             if let Some(hnsw_index_uuid) = collection_info.hnsw_index_uuid {
                 let _ = HnswIndexProvider::purge_one_id(
                     self.hnsw_provider.temporary_storage_path.as_path(),
@@ -533,6 +796,7 @@ mod tests {
     use chroma_log::test::{upsert_generator, TEST_EMBEDDING_DIMENSION};
     use std::collections::HashMap;
     use std::path::{Path, PathBuf};
+    use std::sync::Arc;
     use tokio::fs;

     use chroma_blockstore::arrow::config::{BlockManagerConfig, TEST_MAX_BLOCK_SIZE_BYTES};
@@ -565,6 +829,7 @@ mod tests {
     };

     use super::{compact, CompactionContext, CompactionResponse, LogFetchOrchestratorResponse};
+    use crate::execution::orchestration::register_orchestrator::CollectionRegisterInfo;

     #[tokio::test]
     async fn test_rebuild() {
@@ -1782,17 +2047,24 @@ mod tests {
             compaction_1_log_records.len()
         );
         let compaction_1_apply_response = compaction_context_1
-            .run_apply_logs(compaction_1_log_records, system.clone())
+            .run_apply_logs(Arc::new(compaction_1_log_records), system.clone())
             .await
             .expect("Apply should have succeeded.");

-        let _register_result = Box::pin(compaction_context_1.run_register(
-            compaction_1_apply_response.flush_results,
-            compaction_1_apply_response.collection_logical_size_bytes,
-            system.clone(),
-        ))
-        .await
-        .expect_err("Register should have failed.");
+        let register_info = vec![CollectionRegisterInfo {
+            collection_info: compaction_context_1
+                .get_input_collection_info()
+                .unwrap()
+                .clone(),
+            flush_results: compaction_1_apply_response.flush_results,
+            collection_logical_size_bytes: compaction_1_apply_response
+                .collection_logical_size_bytes,
+        }];
+
+        let _register_result =
+            Box::pin(compaction_context_1.run_register(register_info, None, system.clone()))
+                .await
+                .expect_err("Register should have failed.");

         // Verify that the collection was successfully compacted (by whichever succeeded)
         let collection_after_compaction = sysdb
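// A sketch of the Arc<Vec<_>> convention adopted above (run_compact and this test
// both pass Arc::new(records)): cloning the Arc copies a pointer, not the
// materialized log data itself. Types here are illustrative stand-ins.

use std::sync::Arc;

fn main() {
    let records: Vec<String> = vec!["log-1".into(), "log-2".into()];
    let shared = Arc::new(records);

    // Cheap second handle for another consumer; no deep copy of the Vec occurs.
    let for_other_task = Arc::clone(&shared);

    assert_eq!(shared.len(), for_other_task.len());
    assert!(Arc::ptr_eq(&shared, &for_other_task)); // same allocation
}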
diff --git a/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs b/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs
index 30017503124..1b8bdbc7163 100644
--- a/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs
+++ b/rust/worker/src/execution/orchestration/log_fetch_orchestrator.rs
@@ -313,7 +313,7 @@ impl LogFetchOrchestrator {
         // NOTE: We allow writers to be uninitialized for the case when the materialized logs are empty
         let record_reader = self
             .context
-            .get_segment_writers()
+            .get_input_segment_writers()
             .ok()
             .and_then(|writers| writers.record_reader);

@@ -334,7 +334,7 @@ impl LogFetchOrchestrator {
             }
         };

-        let collection_info = match self.context.get_collection_info_mut() {
+        let collection_info = match self.context.get_input_collection_info_mut() {
             Ok(info) => info,
             Err(err) => {
                 return self.terminate_with_result(Err(err.into()), ctx).await;
@@ -447,7 +447,7 @@ impl Handler> for LogFetchOrchestrator
-            let collection_info = match self.context.collection_info.get() {
+            let collection_info = match self.context.input_collection_info.get() {
                 Some(info) => info,
                 None => {
                     self.terminate_with_result(
@@ -644,7 +644,7 @@ impl Handler> for LogFetchOrchestrator
         tracing::info!("Pulled Records: {}", output.len());
         match output.iter().last() {
             Some((rec, _)) => {
-                let collection_info = match self.context.get_collection_info_mut() {
+                let collection_info = match self.context.get_input_collection_info_mut() {
                     Ok(info) => info,
                     Err(err) => {
                         tracing::info!("We're failing right here");
@@ -660,7 +660,7 @@ impl Handler> for LogFetchOrchestrator
             }
             None => {
                 tracing::warn!("No logs were pulled from the log service, this can happen when the log compaction offset is behind the sysdb.");
-                let collection_info = match self.context.get_collection_info() {
+                let collection_info = match self.context.get_input_collection_info() {
                     Ok(info) => info,
                     Err(err) => {
                         self.terminate_with_result(Err(err.into()), ctx).await;
@@ -700,7 +700,7 @@ impl Handler>
         };
         tracing::info!("Sourced Records: {}", output.len());
         // Each record should correspond to a log
-        let collection_info = match self.context.get_collection_info_mut() {
+        let collection_info = match self.context.get_input_collection_info_mut() {
             Ok(info) => info,
             Err(err) => {
                 self.terminate_with_result(Err(err.into()), ctx).await;
@@ -709,7 +709,7 @@ impl Handler>
         };
         collection_info.collection.total_records_post_compaction = output.len() as u64;

-        let collection_info = match self.context.get_collection_info() {
+        let collection_info = match self.context.get_input_collection_info() {
             Ok(info) => info,
             Err(err) => {
                 self.terminate_with_result(Err(err.into()), ctx).await;
@@ -767,7 +767,7 @@ impl Handler>
             }
             self.num_uncompleted_materialization_tasks -= 1;
             if self.num_uncompleted_materialization_tasks == 0 {
-                let collection_info = match self.context.collection_info.take() {
+                let collection_info = match self.context.input_collection_info.take() {
                     Some(info) => info,
                     None => {
                         self.terminate_with_result(
@@ -781,6 +781,7 @@ impl Handler>
                     }
                 };
                 let materialized = std::mem::take(&mut self.materialized_outputs);
+                tracing::debug!("Number of materialized results: {}", materialized.len());
                 self.terminate_with_result(
                     Ok(Success::new(materialized, collection_info.clone()).into()),
                     ctx,
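// The last hunk above demotes an ad-hoc println! to tracing. A minimal sketch of the
// difference, assuming the tracing and tracing-subscriber crates: events carry a
// level and structured fields and are filtered by the subscriber, instead of always
// writing to stdout.

fn main() {
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::DEBUG)
        .init();

    let materialized = vec![1, 2, 3];
    // Structured field plus level, rather than unconditional stdout output.
    tracing::debug!(count = materialized.len(), "materialized results ready");
}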
diff --git a/rust/worker/src/execution/orchestration/mod.rs b/rust/worker/src/execution/orchestration/mod.rs
index f00e8b7105a..6db9bebd9f3 100644
--- a/rust/worker/src/execution/orchestration/mod.rs
+++ b/rust/worker/src/execution/orchestration/mod.rs
@@ -1,4 +1,5 @@
 pub mod apply_logs_orchestrator;
+pub mod attached_function_orchestrator;
 pub(crate) mod compact;
 pub(crate) mod count;
 pub mod get;
diff --git a/rust/worker/src/execution/orchestration/register_orchestrator.rs b/rust/worker/src/execution/orchestration/register_orchestrator.rs
index 324c7943ecc..eaa01cb8735 100644
--- a/rust/worker/src/execution/orchestration/register_orchestrator.rs
+++ b/rust/worker/src/execution/orchestration/register_orchestrator.rs
@@ -11,10 +11,17 @@ use tokio::sync::oneshot::error::RecvError;
 use tokio::sync::oneshot::Sender;
 use tracing::Span;

+use crate::execution::operators::finish_attached_function::{
+    FinishAttachedFunctionError, FinishAttachedFunctionInput, FinishAttachedFunctionOperator,
+    FinishAttachedFunctionOutput,
+};
 use crate::execution::operators::register::{
     RegisterError, RegisterInput, RegisterOperator, RegisterOutput,
 };
+use crate::execution::orchestration::attached_function_orchestrator::FunctionContext;
+use crate::execution::orchestration::compact::CollectionCompactInfo;
 use crate::execution::orchestration::compact::CompactionContextError;
+use chroma_types::NonceUuid;

 use super::compact::{CompactionContext, ExecutionState};

@@ -24,8 +31,35 @@ pub struct RegisterOrchestrator {
     dispatcher: ComponentHandle<Dispatcher>,
     result_channel: Option<Sender<Result<RegisterOrchestratorResponse, RegisterOrchestratorError>>>,
     _state: ExecutionState,
-    flush_results: Vec<SegmentFlushInfo>,
-    collection_logical_size_bytes: u64,
+    // Attached function fields
+    input_collection_register_info: Option<CollectionRegisterInfo>,
+    output_collection_register_info: CollectionRegisterInfo,
+    function_context: Option<FunctionContext>,
+}
+
+#[derive(Debug)]
+pub struct CollectionRegisterInfo {
+    pub collection_info: CollectionCompactInfo,
+    pub flush_results: Vec<SegmentFlushInfo>,
+    pub collection_logical_size_bytes: u64,
+}
+
+impl From<&CollectionRegisterInfo> for chroma_types::CollectionFlushInfo {
+    fn from(info: &CollectionRegisterInfo) -> Self {
+        chroma_types::CollectionFlushInfo {
+            tenant_id: info.collection_info.collection.tenant.clone(),
+            collection_id: info.collection_info.collection_id,
+            log_position: info.collection_info.pulled_log_offset,
+            collection_version: info.collection_info.collection.version,
+            segment_flush_info: info.flush_results.clone().into(),
+            total_records_post_compaction: info
+                .collection_info
+                .collection
+                .total_records_post_compaction,
+            size_bytes_post_compaction: info.collection_logical_size_bytes,
+            schema: info.collection_info.schema.clone(),
+        }
+    }
 }

 #[derive(Debug)]
@@ -91,20 +125,44 @@ where
     }
 }

+impl From<FinishAttachedFunctionError> for RegisterOrchestratorError {
+    fn from(value: FinishAttachedFunctionError) -> Self {
+        RegisterOrchestratorError::Register(value.into())
+    }
+}
+
 impl RegisterOrchestrator {
     pub fn new(
         context: &CompactionContext,
         dispatcher: ComponentHandle<Dispatcher>,
-        flush_results: Vec<SegmentFlushInfo>,
-        collection_logical_size_bytes: u64,
+        collection_register_info: CollectionRegisterInfo,
     ) -> Self {
         RegisterOrchestrator {
             context: context.clone(),
             dispatcher,
             result_channel: None,
             _state: ExecutionState::Register,
-            flush_results,
-            collection_logical_size_bytes,
+            input_collection_register_info: None,
+            output_collection_register_info: collection_register_info,
+            function_context: None,
+        }
+    }
+
+    pub fn new_with_attached_function(
+        context: CompactionContext,
+        dispatcher: ComponentHandle<Dispatcher>,
+        input_collection_register_info: CollectionRegisterInfo,
+        output_collection_register_info: CollectionRegisterInfo,
+        function_context: FunctionContext,
+    ) -> Self {
+        RegisterOrchestrator {
+            context,
+            dispatcher,
+            result_channel: None,
+            _state: ExecutionState::Register,
+            input_collection_register_info: Some(input_collection_register_info),
+            output_collection_register_info,
+            function_context: Some(function_context),
+        }
+    }
 }
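// A compact stand-in for the From impl above: converting the operator error into the
// orchestrator error lets `?` lift failures automatically, avoiding map_err at each
// call site. Names are illustrative, not this codebase's types.

#[derive(Debug)]
struct FinishError(&'static str);

#[derive(Debug)]
enum OrchestratorError {
    Register(FinishError),
}

impl From<FinishError> for OrchestratorError {
    fn from(value: FinishError) -> Self {
        OrchestratorError::Register(value)
    }
}

fn finish() -> Result<(), FinishError> {
    Err(FinishError("sysdb rejected flush"))
}

fn run() -> Result<(), OrchestratorError> {
    finish()?; // `?` applies the From impl
    Ok(())
}

fn main() {
    assert!(matches!(run(), Err(OrchestratorError::Register(_))));
}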
@@ -134,38 +192,88 @@ impl Orchestrator for RegisterOrchestrator {
         &mut self,
         ctx: &ComponentContext<Self>,
     ) -> Vec<(TaskMessage, Option<Span>)> {
-        // Check if collection is set before proceeding
-        let collection_info = match self.context.get_collection_info() {
-            Ok(collection_info) => collection_info,
-            Err(e) => {
-                self.terminate_with_result(Err(e.into()), ctx).await;
-                return vec![];
-            }
-        };
+        // Check if we have attached function context
+        if let (
+            Some(input_collection_register_info),
+            output_collection_register_info,
+            Some(function_context),
+        ) = (
+            &self.input_collection_register_info,
+            &self.output_collection_register_info,
+            &self.function_context,
+        ) {
+            // Use FinishAttachedFunctionOperator for the attached function workflow.
+            // Build the collections vector with the output and input collections.
+            let collection_flush_infos = vec![
+                output_collection_register_info.into(),
+                input_collection_register_info.into(),
+            ];

-        vec![(
-            wrap(
-                RegisterOperator::new(),
-                RegisterInput::new(
-                    collection_info.collection.tenant.clone(),
-                    collection_info.collection_id,
-                    collection_info.pulled_log_offset,
-                    collection_info.collection.version,
-                    self.flush_results.clone().into(),
-                    collection_info.collection.total_records_post_compaction,
-                    self.collection_logical_size_bytes,
-                    self.context.sysdb.clone(),
-                    self.context.log.clone(),
-                    collection_info.schema.clone(),
+            vec![(
+                wrap(
+                    FinishAttachedFunctionOperator::new(),
+                    FinishAttachedFunctionInput::new(
+                        collection_flush_infos,
+                        function_context.attached_function_id,
+                        NonceUuid::new(),
+                        input_collection_register_info
+                            .collection_info
+                            .pulled_log_offset as u64,
+                        self.context.sysdb.clone(),
+                        self.context.log.clone(),
+                    ),
+                    ctx.receiver(),
+                    self.context
+                        .orchestrator_context
+                        .task_cancellation_token
+                        .clone(),
+                ),
+                Some(Span::current()),
+            )]
+        } else {
+            // Use the regular RegisterOperator for normal compaction
+            let output_collection_register_info = &self.output_collection_register_info;
+            vec![(
+                wrap(
+                    RegisterOperator::new(),
+                    RegisterInput::new(
+                        output_collection_register_info
+                            .collection_info
+                            .collection
+                            .tenant
+                            .clone(),
+                        output_collection_register_info
+                            .collection_info
+                            .collection_id,
+                        output_collection_register_info
+                            .collection_info
+                            .pulled_log_offset,
+                        output_collection_register_info
+                            .collection_info
+                            .collection
+                            .version,
+                        output_collection_register_info.flush_results.clone().into(),
+                        output_collection_register_info
+                            .collection_info
+                            .collection
+                            .total_records_post_compaction,
+                        output_collection_register_info.collection_logical_size_bytes,
+                        self.context.sysdb.clone(),
+                        self.context.log.clone(),
+                        output_collection_register_info
+                            .collection_info
+                            .schema
+                            .clone(),
+                    ),
+                    ctx.receiver(),
+                    self.context
+                        .orchestrator_context
+                        .task_cancellation_token
+                        .clone(),
                 ),
-                ctx.receiver(),
-                self.context
-                    .orchestrator_context
-                    .task_cancellation_token
-                    .clone(),
-            ),
-            Some(Span::current()),
-        )]
+                Some(Span::current()),
+            )]
+        }
     }
 }
@@ -178,7 +286,7 @@ impl Handler<TaskResult<RegisterOutput, RegisterError>> for RegisterOrchestrator
         message: TaskResult<RegisterOutput, RegisterError>,
         ctx: &ComponentContext<RegisterOrchestrator>,
     ) {
-        let collection_info = match self.context.get_collection_info() {
+        let collection_info = match self.context.get_input_collection_info() {
             Ok(collection_info) => collection_info,
             Err(e) => {
                 self.terminate_with_result(Err(e.into()), ctx).await;
@@ -196,3 +304,40 @@ impl Handler<TaskResult<RegisterOutput, RegisterError>> for RegisterOrchestrator
         .await;
     }
 }
+
+#[async_trait]
+impl Handler<TaskResult<FinishAttachedFunctionOutput, FinishAttachedFunctionError>>
+    for RegisterOrchestrator
+{
+    type Result = ();
+
+    async fn handle(
+        &mut self,
+        message: TaskResult<FinishAttachedFunctionOutput, FinishAttachedFunctionError>,
+        ctx: &ComponentContext<RegisterOrchestrator>,
+    ) {
+        let collection_info = match self.context.get_input_collection_info() {
+            Ok(collection_info) => collection_info,
+            Err(e) => {
+                self.terminate_with_result(Err(e.into()), ctx).await;
+                return;
+            }
+        };
+
+        self.terminate_with_result(
+            message
+                .into_inner()
+                .map_err(|e| match e {
+                    TaskError::TaskFailed(inner_error) => {
+                        RegisterOrchestratorError::Register(inner_error.into())
+                    }
+                    other_error => other_error.into(),
+                })
+                .map(|_| RegisterOrchestratorResponse {
+                    job_id: collection_info.collection_id.into(),
+                }),
+            ctx,
+        )
+        .await;
+    }
+}
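// Both handlers above funnel a task result the same way: map the failure into the
// orchestrator error, map the success into a uniform response carrying the job id.
// A stand-in sketch of that funnel; TaskError here is a local enum, not the real
// worker type.

#[derive(Debug)]
enum TaskError<E> {
    TaskFailed(E),
    Aborted,
}

#[derive(Debug)]
enum OrchestratorError {
    Register(String),
    Aborted,
}

#[derive(Debug, PartialEq)]
struct Response {
    job_id: u64,
}

fn finalize<T>(
    result: Result<T, TaskError<String>>,
    job_id: u64,
) -> Result<Response, OrchestratorError> {
    result
        .map_err(|e| match e {
            TaskError::TaskFailed(inner) => OrchestratorError::Register(inner),
            TaskError::Aborted => OrchestratorError::Aborted,
        })
        .map(|_| Response { job_id })
}

fn main() {
    assert_eq!(finalize(Ok(()), 7).unwrap(), Response { job_id: 7 });
    assert!(matches!(
        finalize::<()>(Err(TaskError::Aborted), 7),
        Err(OrchestratorError::Aborted)
    ));
}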