@@ -29,6 +29,10 @@ import mill.scalalib.bsp.BspModule
2929import mill .scalalib .publish .Artifact
3030import mill .util .Jvm
3131import os .Path
32+ import mill .testrunner .TestResult
33+ import mill .scalalib .api .TransitiveSourceStampResults
34+ import scala .collection .immutable .TreeMap
35+ import scala .util .Try
3236
3337/**
3438 * Core configuration required to compile a single Java compilation target
@@ -103,6 +107,173 @@ trait JavaModule
103107 case _ : ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaJSModule
104108 }
105109 }
110+
111+ def testQuick (args : String * ): Command [(String , Seq [TestResult ])] = Task .Command (persistent = true ) {
112+ val quicktestFailedClassesLog = Task .dest / " quickTestFailedClasses.json"
113+ val invalidatedClassesLog = Task .dest / " invalidatedClasses.json"
114+ val failedTestClasses =
115+ if (! os.exists(quicktestFailedClassesLog)) {
116+ Set .empty[String ]
117+ } else {
118+ Try {
119+ upickle.default.read[Seq [String ]](os.read.stream(quicktestFailedClassesLog))
120+ }.getOrElse(Seq .empty[String ]).toSet
121+ }
122+
123+ val transitiveStampsFile = Task .dest / " transitiveStamps.json"
124+ val previousStampsOpt = if (os.exists(transitiveStampsFile)) {
125+ val previousStamps = upickle.default.read[TransitiveSourceStampResults ](
126+ os.read.stream(transitiveStampsFile)
127+ ).currentStamps
128+ os.remove(transitiveStampsFile)
129+ Some (previousStamps)
130+ } else {
131+ None
132+ }
133+
134+ def getAnalysisStore (compileResult : CompilationResult ): Option [xsbti.compile.CompileAnalysis ] = {
135+ val analysisStore = sbt.internal.inc.consistent.ConsistentFileAnalysisStore .binary(
136+ file = compileResult.analysisFile.toIO,
137+ mappers = xsbti.compile.analysis.ReadWriteMappers .getEmptyMappers(),
138+ reproducible = true ,
139+ parallelism = math.min(Runtime .getRuntime.availableProcessors(), 8 )
140+ )
141+ val analysisOptional = analysisStore.get()
142+ if (analysisOptional.isPresent) Some (analysisOptional.get.getAnalysis) else None
143+ }
144+
145+ val combinedAnalysis = (compile() +: upstreamCompileOutput())
146+ .flatMap(getAnalysisStore)
147+ .flatMap {
148+ case analysis : sbt.internal.inc.Analysis => Some (analysis)
149+ case _ => None
150+ }
151+ .foldLeft(sbt.internal.inc.Analysis .empty)(_ ++ _)
152+
153+ val result = TransitiveSourceStampResults (
154+ currentStamps = TreeMap .from(
155+ combinedAnalysis.stamps.sources.view.map { (source, stamp) =>
156+ source.id() -> stamp.writeStamp()
157+ }
158+ ),
159+ previousStamps = previousStampsOpt
160+ )
161+
162+ def getInvalidatedClasspaths (
163+ initialInvalidatedClassNames : Set [String ],
164+ relations : sbt.internal.inc.Relations
165+ ): Set [os.Path ] = {
166+ val seen = collection.mutable.Set .empty[String ]
167+ val seenList = collection.mutable.Buffer .empty[String ]
168+ val queued = collection.mutable.Queue .from(initialInvalidatedClassNames)
169+
170+ while (queued.nonEmpty) {
171+ val current = queued.dequeue()
172+ seenList.append(current)
173+ seen.add(current)
174+
175+ for (next <- relations.usesInternalClass(current)) {
176+ if (! seen.contains(next)) {
177+ seen.add(next)
178+ queued.enqueue(next)
179+ }
180+ }
181+
182+ for (next <- relations.usesExternal(current)) {
183+ if (! seen.contains(next)) {
184+ seen.add(next)
185+ queued.enqueue(next)
186+ }
187+ }
188+ }
189+
190+ seenList
191+ .iterator
192+ .flatMap { invalidatedClassName =>
193+ relations.definesClass(invalidatedClassName)
194+ }
195+ .flatMap { source =>
196+ relations.products(source)
197+ }
198+ .map { product =>
199+ os.Path (product.id)
200+ }
201+ .toSet
202+ }
203+
204+ val relations = combinedAnalysis.relations
205+
206+ val invalidatedAbsoluteClasspaths = getInvalidatedClasspaths(
207+ result.changedSources.flatMap { source =>
208+ relations.classNames(xsbti.VirtualFileRef .of(source))
209+ },
210+ combinedAnalysis.relations
211+ )
212+
213+ // We only care about testing class, so we can:
214+ // - filter out all class path that start with `testClasspath()`
215+ // - strip the prefix and safely turn them into module class path
216+
217+ val testClasspaths = testClasspath()
218+ val invalidatedClassNames = invalidatedAbsoluteClasspaths.flatMap { absoluteClasspath =>
219+ testClasspaths.collectFirst {
220+ case path if absoluteClasspath.startsWith(path.path) =>
221+ absoluteClasspath.relativeTo(path.path).segments.map(_.stripSuffix(" .class" )).mkString(" ." )
222+ }
223+ }
224+ val testingClasses = invalidatedClassNames ++ failedTestClasses
225+ val testClasses = testForkGrouping().map(_.filter(testingClasses.contains)).filter(_.nonEmpty)
226+
227+ // Clean up the directory for test runners
228+ os.walk(Task .dest).foreach { subPath => os.remove.all(subPath) }
229+
230+ val quickTestReportXml = testReportXml()
231+
232+ val testModuleUtil = new TestModuleUtil (
233+ testUseArgsFile(),
234+ forkArgs(),
235+ Seq .empty,
236+ zincWorker().scalalibClasspath(),
237+ resources(),
238+ testFramework(),
239+ runClasspath(),
240+ testClasspaths,
241+ args.toSeq,
242+ testClasses,
243+ zincWorker().testrunnerEntrypointClasspath(),
244+ forkEnv(),
245+ testSandboxWorkingDir(),
246+ forkWorkingDir(),
247+ quickTestReportXml,
248+ zincWorker().javaHome().map(_.path),
249+ testParallelism()
250+ )
251+
252+ val results = testModuleUtil.runTests()
253+
254+ val badTestClasses = (results match {
255+ case Result .Failure (_) =>
256+ // Consider all quick testing classes as failed
257+ testClasses.flatten
258+ case Result .Success ((_, results)) =>
259+ // Get all test classes that failed
260+ results
261+ .filter(testResult => Set (" Error" , " Failure" ).contains(testResult.status))
262+ .map(_.fullyQualifiedName)
263+ }).distinct
264+
265+ os.write.over(transitiveStampsFile, upickle.default.write(result))
266+ os.write.over(quicktestFailedClassesLog, upickle.default.write(badTestClasses))
267+ os.write.over(invalidatedClassesLog, upickle.default.write(invalidatedClassNames))
268+ results match {
269+ case Result .Failure (errMsg) => Result .Failure (errMsg)
270+ case Result .Success ((doneMsg, results)) =>
271+ try TestModule .handleResults(doneMsg, results, Task .ctx(), quickTestReportXml)
272+ catch {
273+ case e : Throwable => Result .Failure (" Test reporting failed: " + e)
274+ }
275+ }
276+ }
106277 }
107278
108279 def defaultCommandName (): String = " run"
0 commit comments