diff --git a/CHANGELOG.md b/CHANGELOG.md index 415561d..f1c8869 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project are documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased + +### Fixed + +- Fixed a data-race when running sub-tests in separate goroutines. + ## [1.3.1] - 2026-05-29 ### Added diff --git a/internal/syncutil/syncutil.go b/internal/syncutil/syncutil.go new file mode 100644 index 0000000..f378dbf --- /dev/null +++ b/internal/syncutil/syncutil.go @@ -0,0 +1,29 @@ +package syncutil + +import "sync" + +// Guarded wraps given type in a [sync.RWMutex]. +type Guarded[T any] struct { + value T + mu sync.RWMutex +} + +func (g *Guarded[T]) Load() T { + g.mu.RLock() + defer g.mu.RUnlock() + + return g.value +} + +func (g *Guarded[T]) Store(value T) { + g.Modify(func(v *T) { + *v = value + }) +} + +func (g *Guarded[T]) Modify(f func(value *T)) { + g.mu.Lock() + defer g.mu.Unlock() + + f(&g.value) +} diff --git a/runner.go b/runner.go index 0b5cc7c..eb598bf 100644 --- a/runner.go +++ b/runner.go @@ -141,7 +141,9 @@ func RunSubSuite[Suite suite[Sub], Parent, Sub CommonT]( r := newRunner[Suite](t) - return r.runSuite(t.unwrap().testingT, suite, &t.unwrap().reflection.Suite, options...) + parent := t.unwrap().reflection.Load().Suite + + return r.runSuite(t.unwrap().testingT, suite, &parent, options...) } // Run runs f as a subtest of t called name. It runs f in a separate goroutine @@ -174,14 +176,19 @@ func Run[T CommonT]( &parentT, func(t *testoT) { t.testNamer = parentT.unwrap().testNamer - t.reflection.Suite = parentT.unwrap().reflection.Suite - t.reflection.Test = testoreflect.RegularTestInfo{ - Name: parentT.unwrap().testNamer.Name(parentT.unwrap().Name(), name), - RawBaseName: name, - Level: t.level(), - IsSubtest: true, - FuncPC: reflect.ValueOf(f).Pointer(), - } + + parentSuite := parentT.unwrap().reflection.Load().Suite + + t.reflection.Modify(func(r *testoreflect.Reflection) { + r.Suite = parentSuite + r.Test = testoreflect.RegularTestInfo{ + Name: parentT.unwrap().testNamer.Name(parentT.unwrap().Name(), name), + RawBaseName: name, + Level: t.level(), + IsSubtest: true, + FuncPC: reflect.ValueOf(f).Pointer(), + } + }) }, options..., ) @@ -190,10 +197,12 @@ func Run[T CommonT]( if r := recover(); r != nil { trace := string(debug.Stack()) - t.unwrap().reflection.Panic = &testoreflect.PanicInfo{ - Value: r, - Trace: trace, - } + t.unwrap().reflection.Modify(func(r *testoreflect.Reflection) { + r.Panic = &testoreflect.PanicInfo{ + Value: r, + Trace: trace, + } + }) t.Fatalf("testo: test %q panicked: %v\n\n%s", t.Name(), r, trace) } @@ -269,11 +278,14 @@ func (r *runner[Suite, T]) runSuite( nil, func(t *testoT) { t.testNamer = r.testNamer - t.reflection.Suite = suiteInfo - t.reflection.Test = testoreflect.RegularTestInfo{ - Name: r.caller, - RawBaseName: r.suiteName, - } + + t.reflection.Modify(func(ref *testoreflect.Reflection) { + ref.Suite = suiteInfo + ref.Test = testoreflect.RegularTestInfo{ + Name: r.caller, + RawBaseName: r.suiteName, + } + }) }, options..., ) @@ -311,13 +323,15 @@ func (r *runner[Suite, T]) runSuiteTests(t T, s Suite, tests suiteTests[Suite, T s.BeforeAll(t) + suiteReflection := t.unwrap().reflection.Load().Suite + suiteInfo := testoreflect.SuiteInfo{ - Parent: t.unwrap().reflection.Suite.Parent, - Name: t.unwrap().reflection.Suite.Name, - Caller: t.unwrap().reflection.Suite.Caller, - TestingT: t.unwrap().reflection.Suite.TestingT, + Parent: suiteReflection.Parent, + Name: suiteReflection.Name, + Caller: suiteReflection.Caller, + TestingT: suiteReflection.TestingT, Value: s, - Hooks: t.unwrap().reflection.Suite.Hooks, + Hooks: suiteReflection.Hooks, } allTests := r.applyPlan( @@ -338,8 +352,11 @@ func (r *runner[Suite, T]) runSuiteTests(t T, s Suite, tests suiteTests[Suite, T &t, func(t *testoT) { t.testNamer = r.testNamer - t.reflection.Suite = suiteInfo - t.reflection.Test = test.Info + + t.reflection.Modify(func(r *testoreflect.Reflection) { + r.Suite = suiteInfo + r.Test = test.Info + }) if test.Configure != nil { test.Configure(t) @@ -369,10 +386,12 @@ func (r *runner[Suite, T]) runSuiteTest( if r := recover(); r != nil { trace := string(debug.Stack()) - t.unwrap().reflection.Panic = &testoreflect.PanicInfo{ - Value: r, - Trace: trace, - } + t.unwrap().reflection.Modify(func(ref *testoreflect.Reflection) { + ref.Panic = &testoreflect.PanicInfo{ + Value: r, + Trace: trace, + } + }) t.Fatalf("testo: test %q panicked: %v\n\n%s", t.Name(), r, trace) } diff --git a/runner_test.go b/runner_test.go index df0e9e3..79652da 100644 --- a/runner_test.go +++ b/runner_test.go @@ -2,7 +2,9 @@ package testo import ( "slices" + "sync" "testing" + "time" "github.com/ozontech/testo/testoplugin" ) @@ -202,3 +204,59 @@ func (s TestSuite) TestFoo(t *TestT) { } func (s *TestSuite) TestBar(t *TestT) {} + +type PluginGoroutine struct { + *T + + wg sync.WaitGroup +} + +func (p *PluginGoroutine) Go(f func()) { + p.Helper() + + p.wg.Add(1) + + go func() { + defer p.wg.Done() + + p.Helper() + + f() + }() +} + +func (p *PluginGoroutine) Plugin(testoplugin.Plugin, ...testoplugin.Option) testoplugin.Spec { + return testoplugin.Spec{ + Hooks: testoplugin.Hooks{ + AfterAll: p.after(), + AfterEach: p.after(), + AfterEachSub: p.after(), + }, + } +} + +func (p *PluginGoroutine) after() testoplugin.Hook { + return testoplugin.Hook{ + Priority: testoplugin.TryFirst, + Func: p.wg.Wait, + } +} + +func TestDataRace(t *testing.T) { + t.Parallel() + + type DataRaceT struct { + *T + *PluginGoroutine + } + + RunTest(t, func(t DataRaceT) { + t.Go(func() { + Run(t, "inner", func(t DataRaceT) { + time.Sleep(time.Second) + }) + }) + + t.Log("test") + }) +} diff --git a/suite.go b/suite.go index 904a259..1357c3e 100644 --- a/suite.go +++ b/suite.go @@ -1,6 +1,9 @@ package testo -import "github.com/ozontech/testo/testoplugin" +import ( + "github.com/ozontech/testo/testoplugin" + "github.com/ozontech/testo/testoreflect" +) // singleton is a special (virtual) suite with a single test. // @@ -71,22 +74,30 @@ type Suite[T CommonT] struct { // BeforeAll hook. func (Suite[T]) BeforeAll(t T) { - t.unwrap().reflection.Suite.Hooks.MissedBeforeAll = true + t.unwrap().reflection.Modify(func(r *testoreflect.Reflection) { + r.Suite.Hooks.MissedBeforeAll = true + }) } // BeforeEach hook. func (Suite[T]) BeforeEach(t T) { - t.unwrap().reflection.Suite.Hooks.MissedBeforeEach = true + t.unwrap().reflection.Modify(func(r *testoreflect.Reflection) { + r.Suite.Hooks.MissedBeforeEach = true + }) } // AfterEach hook. func (Suite[T]) AfterEach(t T) { - t.unwrap().reflection.Suite.Hooks.MissedAfterEach = true + t.unwrap().reflection.Modify(func(r *testoreflect.Reflection) { + r.Suite.Hooks.MissedAfterEach = true + }) } // AfterAll hook. func (Suite[T]) AfterAll(t T) { - t.unwrap().reflection.Suite.Hooks.MissedAfterAll = true + t.unwrap().reflection.Modify(func(r *testoreflect.Reflection) { + r.Suite.Hooks.MissedAfterAll = true + }) } func (Suite[T]) private() {} diff --git a/t.go b/t.go index c83d0a6..3a2f62d 100644 --- a/t.go +++ b/t.go @@ -15,6 +15,7 @@ import ( "time" "github.com/ozontech/testo/internal/reflectutil" + "github.com/ozontech/testo/internal/syncutil" "github.com/ozontech/testo/internal/testnamer" "github.com/ozontech/testo/testoplugin" "github.com/ozontech/testo/testoreflect" @@ -74,7 +75,7 @@ type ( levelOptions []testoplugin.Option // reflection holds information for [Reflect]. - reflection testoreflect.Reflection + reflection syncutil.Guarded[testoreflect.Reflection] failureSource atomicInt[testoreflect.TestFailureSource] failureKind atomicInt[testoreflect.TestFailureKind] @@ -124,7 +125,7 @@ func (t *T) parallel() { t.Helper() if t.propagateParallel { - t.reflection.Suite.TestingT.Parallel() + t.reflection.Load().Suite.TestingT.Parallel() return } @@ -409,7 +410,7 @@ func (t *T) Cleanup(f func()) { func (t *T) Name() string { t.Helper() - return t.reflection.Test.GetName() + return t.reflection.Load().Test.GetName() } // unwrap the underlying T. @@ -560,7 +561,7 @@ func Reflect(t CommonT) testoreflect.Reflection { internal := t.unwrap() - info := internal.reflection + info := internal.reflection.Load() info.FailureSource = internal.failureSource.Load() info.FailureKind = internal.failureKind.Load()