@@ -416,102 +416,26 @@ def test_distributed_initialization(self):
         integrator = Integrator(bounds=bounds, f=f)
         self.assertEqual(integrator.rank, 0)
         self.assertEqual(integrator.world_size, 1)
+    @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
+    def test_multi_gpu_consistency(self):
+        if torch.cuda.device_count() >= 2:
+            bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
+            f = lambda x, fx: torch.ones_like(x)
 
-    def test_statistics_worldsize_gt1(self):
-        """Mock the distributed gather to cover the world_size > 1 branch"""
-
-        bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
-        f = lambda x, fx: fx.copy_(x)  # not important, just a placeholder
-        integrator = Integrator(bounds=bounds, f=f)
-        integrator.world_size = 2
-        integrator.rank = 0
-
-        means = torch.ones((2, 1))
-        vars = torch.ones((2, 1)) * 0.5
-
-        # ---- build a fake dist module ----
-        class DummyDist:
-            def gather(self, tensor, gather_list=None, dst=0):
-                # simulate rank 0 receiving two copies of the data
-                if gather_list is not None:
-                    gather_list[0].copy_(tensor)
-                    gather_list[1].copy_(tensor * 2)
-
-            def get_rank(self):
-                return integrator.rank
-
-            def get_world_size(self):
-                return integrator.world_size
-
-            def is_initialized(self):
-                return True
-
-        import MCintegration.integrators as integrators_module
-        orig_dist = integrators_module.dist
-        integrators_module.dist = DummyDist()
+            # Create two integrators on different devices
+            integrator1 = Integrator(bounds=bounds, f=f, device="cuda:0")
+            integrator2 = Integrator(bounds=bounds, f=f, device="cuda:1")
 
-        try:
-            result = integrator.statistics(means, vars, neval=100)
-            self.assertIsNotNone(result)
-            self.assertTrue(hasattr(result, "__len__"))
-        finally:
-            # restore the original dist
-            integrators_module.dist = orig_dist
+            # Results should be consistent across devices
+            result1 = integrator1(neval=10000)
+            result2 = integrator2(neval=10000)
 
-    def test_statistics_worldsize_gt1_rank1(self):
-        """Mock distributed calls to cover the rank != 0 branch"""
+            if hasattr(result1, "mean"):
+                value1, value2 = result1.mean, result2.mean
+            else:
+                value1, value2 = result1, result2
 
-        bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
-        f = lambda x, fx: fx.copy_(x)
-        integrator = Integrator(bounds=bounds, f=f)
-        integrator.world_size = 2
-        integrator.rank = 1
-
-        means = torch.ones((2, 1))
-        vars = torch.ones((2, 1)) * 0.5
-
-        class DummyDist:
-            def gather(self, tensor, gather_list=None, dst=0):
-                pass  # the rank != 0 case
-
-            def get_rank(self):
-                return integrator.rank
-
-            def get_world_size(self):
-                return integrator.world_size
-
-            def is_initialized(self):
-                return True
-
-        import MCintegration.integrators as integrators_module
-        orig_dist = integrators_module.dist
-        integrators_module.dist = DummyDist()
-
-        try:
-            result = integrator.statistics(means, vars, neval=100)
-            self.assertIsNone(result)
-        finally:
-            integrators_module.dist = orig_dist
-    # @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
-    # def test_multi_gpu_consistency(self):
-    #     if torch.cuda.device_count() >= 2:
-    #         bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
-    #         f = lambda x, fx: torch.ones_like(x)
-
-    #         # Create two integrators on different devices
-    #         integrator1 = Integrator(bounds=bounds, f=f, device="cuda:0")
-    #         integrator2 = Integrator(bounds=bounds, f=f, device="cuda:1")
-
-    #         # Results should be consistent across devices
-    #         result1 = integrator1(neval=10000)
-    #         result2 = integrator2(neval=10000)
-
-    #         if hasattr(result1, "mean"):
-    #             value1, value2 = result1.mean, result2.mean
-    #         else:
-    #             value1, value2 = result1, result2
-
-    #         self.assertAlmostEqual(float(value1), float(value2), places=1)
+            self.assertAlmostEqual(float(value1), float(value2), places=1)
 
 
 if __name__ == "__main__":
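
The deleted mock-based tests were what exercised the world_size > 1 and rank != 0 branches of Integrator.statistics without real hardware. If that coverage is still wanted alongside the new multi-GPU test, a minimal sketch along the following lines could restore it using unittest.mock.patch.object instead of hand-swapping and restoring integrators_module.dist. The statistics(means, vars, neval=...) call, the MCintegration.integrators module, and its dist attribute are taken from the removed code; FakeDist, run_statistics_branch, and the assumption that Integrator lives in MCintegration.integrators are hypothetical and only for illustration.

from unittest import mock

import torch

import MCintegration.integrators as integrators_module


class FakeDist:
    """Minimal stand-in for the dist module used inside Integrator.statistics (assumed API)."""

    def __init__(self, rank, world_size=2):
        self.rank = rank
        self.world_size = world_size

    def gather(self, tensor, gather_list=None, dst=0):
        # Only rank 0 supplies a gather_list; pretend it received one tensor per rank.
        if gather_list is not None:
            gather_list[0].copy_(tensor)
            gather_list[1].copy_(tensor * 2)

    def get_rank(self):
        return self.rank

    def get_world_size(self):
        return self.world_size

    def is_initialized(self):
        return True


def run_statistics_branch(rank):
    bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
    integrator = integrators_module.Integrator(bounds=bounds, f=lambda x, fx: fx.copy_(x))
    integrator.world_size = 2  # pretend a second rank exists
    integrator.rank = rank
    means = torch.ones((2, 1))
    variances = torch.ones((2, 1)) * 0.5
    # patch.object swaps integrators_module.dist for the fake and restores it on exit,
    # replacing the manual orig_dist bookkeeping in the deleted tests.
    with mock.patch.object(integrators_module, "dist", FakeDist(rank)):
        return integrator.statistics(means, variances, neval=100)

Mirroring the deleted assertions, run_statistics_branch(0) would be expected to return a result while run_statistics_branch(1) would return None on the non-root rank.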