diff --git a/pkg/graveler/retention/active_commits_test.go b/pkg/graveler/retention/active_commits_test.go index 46ffd05a5de..ace2b2f3d69 100644 --- a/pkg/graveler/retention/active_commits_test.go +++ b/pkg/graveler/retention/active_commits_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "hash/fnv" "sort" "testing" "time" @@ -165,12 +166,44 @@ func TestCommitsMap(t *testing.T) { } } +type testRepoCommits struct { + commits map[string]testCommit + headsRetentionDays map[string]int32 + expectedActiveIDs []string +} + +func fingerprint[T ~string](s T) T { + h := fnv.New64a() + h.Write([]byte(s)) // Write cannot fail for a hash func. + fingerprint := h.Sum64() + return T(fmt.Sprintf("%016x-%s", fingerprint, s)) +} + +func scrambleIDs(tst testRepoCommits) testRepoCommits { + commits := make(map[string]testCommit, len(tst.commits)) + for id, commit := range tst.commits { + for idx, parent := range commit.parents { + commit.parents[idx] = fingerprint(parent) + } + commits[fingerprint(id)] = commit + } + headsRetentionDays := make(map[string]int32, len(tst.headsRetentionDays)) + for id, days := range tst.headsRetentionDays { + headsRetentionDays[fingerprint(id)] = days + } + expectedActiveIDs := make([]string, len(tst.expectedActiveIDs)) + for idx, id := range tst.expectedActiveIDs { + expectedActiveIDs[idx] = fingerprint(id) + } + return testRepoCommits{ + commits: commits, + headsRetentionDays: headsRetentionDays, + expectedActiveIDs: expectedActiveIDs, + } +} + func TestActiveCommits(t *testing.T) { - tests := map[string]struct { - commits map[string]testCommit - headsRetentionDays map[string]int32 - expectedActiveIDs []string - }{ + tests := map[string]testRepoCommits{ "two_branches": { commits: map[string]testCommit{ "a": newTestCommit(15), @@ -330,7 +363,7 @@ func TestActiveCommits(t *testing.T) { expectedActiveIDs: []string{"h1", "h2", "h3", "e1", "e2", "e3"}, }, } - for name, tst := range tests { + for name, tstBase := range tests { t.Run(name, func(t *testing.T) { now := time.Now() ctrl := gomock.NewController(t) @@ -339,6 +372,11 @@ func TestActiveCommits(t *testing.T) { repositoryRecord := &graveler.RepositoryRecord{ RepositoryID: "test", } + + // Shuffle startingPoints (below) by replacing all CommitIDs with their + // fingerprints. Otherwise they tend to be alphabetically ordered, + // which can match their time ordering better than will actually happen. + tst := scrambleIDs(tstBase) garbageCollectionRules := &graveler.GarbageCollectionRules{DefaultRetentionDays: 5, BranchRetentionDays: make(map[string]int32)} var branches []*graveler.BranchRecord for head, retentionDays := range tst.headsRetentionDays { @@ -369,12 +407,16 @@ func TestActiveCommits(t *testing.T) { refManagerMock.EXPECT().ListCommits(ctx, repositoryRecord).Return(testutil.NewFakeCommitIterator(commitsRecords), nil).MaxTimes(1) - gcCommits, err := GetGarbageCollectionCommits(ctx, NewGCStartingPointIterator( + startingPoints := NewGCStartingPointIterator( testutil.NewFakeCommitIterator(findMainAncestryLeaves(now, tst.headsRetentionDays, tst.commits)), - testutil.NewFakeBranchIterator(branches)), &repositoryCommitGetter{ - refManager: refManagerMock, - repository: repositoryRecord, - }, garbageCollectionRules) + testutil.NewFakeBranchIterator(branches)) + defer startingPoints.Close() + + gcCommits, err := GetGarbageCollectionCommits(ctx, startingPoints, + &repositoryCommitGetter{ + refManager: refManagerMock, + repository: repositoryRecord, + }, garbageCollectionRules) if err != nil { t.Fatalf("failed to find expired commits: %v", err) } diff --git a/pkg/graveler/retention/starting_point_iterator_test.go b/pkg/graveler/retention/starting_point_iterator_test.go index daa71c2e415..32339dfcc04 100644 --- a/pkg/graveler/retention/starting_point_iterator_test.go +++ b/pkg/graveler/retention/starting_point_iterator_test.go @@ -71,6 +71,7 @@ func TestStartingPointIterator(t *testing.T) { branchIterator := testutil.NewFakeBranchIterator(branchRecords) commitIterator := testutil.NewFakeCommitIterator(commitRecords) it := NewGCStartingPointIterator(commitIterator, branchIterator) + defer it.Close() i := 0 for it.Next() { val := it.Value() @@ -89,7 +90,6 @@ func TestStartingPointIterator(t *testing.T) { if it.Err() != nil { t.Fatalf("unexpected error: %v", it.Err()) } - it.Close() if i != len(expected) { t.Fatalf("got unexpected number of results. expected=%d, got=%d", len(expected), i) }