@@ -417,6 +417,81 @@ def test_distributed_initialization(self):
417417 self .assertEqual (integrator .rank , 0 )
418418 self .assertEqual (integrator .world_size , 1 )
419419
420+ def test_statistics_worldsize_gt1 (self ):
421+ """Mock 分布式 gather 测试 world_size > 1 分支覆盖"""
422+
423+ bounds = torch .tensor ([[0.0 , 1.0 ]], dtype = torch .float64 )
424+ f = lambda x , fx : fx .copy_ (x ) # 不重要,只是占位
425+ integrator = Integrator (bounds = bounds , f = f )
426+ integrator .world_size = 2
427+ integrator .rank = 0
428+
429+ means = torch .ones ((2 , 1 ))
430+ vars = torch .ones ((2 , 1 )) * 0.5
431+
432+ # ---- 构造假的 dist 模块 ----
433+ class DummyDist :
434+ def gather (self , tensor , gather_list = None , dst = 0 ):
435+ # 模拟 rank 0 收到两份数据
436+ if gather_list is not None :
437+ gather_list [0 ].copy_ (tensor )
438+ gather_list [1 ].copy_ (tensor * 2 )
439+
440+ def get_rank (self ):
441+ return integrator .rank
442+
443+ def get_world_size (self ):
444+ return integrator .world_size
445+
446+ def is_initialized (self ):
447+ return True
448+
449+ import MCintegration .integrators as integrators_module
450+ orig_dist = integrators_module .dist
451+ integrators_module .dist = DummyDist ()
452+
453+ try :
454+ result = integrator .statistics (means , vars , neval = 100 )
455+ self .assertIsNotNone (result )
456+ self .assertTrue (hasattr (result , "__len__" ))
457+ finally :
458+ # 恢复原 dist
459+ integrators_module .dist = orig_dist
460+
461+ def test_statistics_worldsize_gt1_rank1 (self ):
462+ """Mock 分布式测试 rank != 0 分支覆盖"""
463+
464+ bounds = torch .tensor ([[0.0 , 1.0 ]], dtype = torch .float64 )
465+ f = lambda x , fx : fx .copy_ (x )
466+ integrator = Integrator (bounds = bounds , f = f )
467+ integrator .world_size = 2
468+ integrator .rank = 1
469+
470+ means = torch .ones ((2 , 1 ))
471+ vars = torch .ones ((2 , 1 )) * 0.5
472+
473+ class DummyDist :
474+ def gather (self , tensor , gather_list = None , dst = 0 ):
475+ pass # rank!=0 的情况
476+
477+ def get_rank (self ):
478+ return integrator .rank
479+
480+ def get_world_size (self ):
481+ return integrator .world_size
482+
483+ def is_initialized (self ):
484+ return True
485+
486+ import MCintegration .integrators as integrators_module
487+ orig_dist = integrators_module .dist
488+ integrators_module .dist = DummyDist ()
489+
490+ try :
491+ result = integrator .statistics (means , vars , neval = 100 )
492+ self .assertIsNone (result )
493+ finally :
494+ integrators_module .dist = orig_dist
420495 # @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
421496 # def test_multi_gpu_consistency(self):
422497 # if torch.cuda.device_count() >= 2:
0 commit comments