
Commit d1601d6 (parent 231bad6)

Commit message: fix

1 file changed: +75 −0

MCintegration/integrators_test.py

Lines changed: 75 additions & 0 deletions
@@ -417,6 +417,81 @@ def test_distributed_initialization(self):
         self.assertEqual(integrator.rank, 0)
         self.assertEqual(integrator.world_size, 1)
 
+    def test_statistics_worldsize_gt1(self):
+        """Mock the distributed gather to cover the world_size > 1 branch."""
+
+        bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
+        f = lambda x, fx: fx.copy_(x)  # placeholder integrand; its value is irrelevant here
+        integrator = Integrator(bounds=bounds, f=f)
+        integrator.world_size = 2
+        integrator.rank = 0
+
+        means = torch.ones((2, 1))
+        vars = torch.ones((2, 1)) * 0.5
+
+        # ---- build a fake dist module ----
+        class DummyDist:
+            def gather(self, tensor, gather_list=None, dst=0):
+                # simulate rank 0 receiving data from both ranks
+                if gather_list is not None:
+                    gather_list[0].copy_(tensor)
+                    gather_list[1].copy_(tensor * 2)
+
+            def get_rank(self):
+                return integrator.rank
+
+            def get_world_size(self):
+                return integrator.world_size
+
+            def is_initialized(self):
+                return True
+
+        import MCintegration.integrators as integrators_module
+        orig_dist = integrators_module.dist
+        integrators_module.dist = DummyDist()
+
+        try:
+            result = integrator.statistics(means, vars, neval=100)
+            self.assertIsNotNone(result)
+            self.assertTrue(hasattr(result, "__len__"))
+        finally:
+            # restore the original dist
+            integrators_module.dist = orig_dist
+
+    def test_statistics_worldsize_gt1_rank1(self):
+        """Mock distributed test covering the rank != 0 branch."""
+
+        bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
+        f = lambda x, fx: fx.copy_(x)
+        integrator = Integrator(bounds=bounds, f=f)
+        integrator.world_size = 2
+        integrator.rank = 1
+
+        means = torch.ones((2, 1))
+        vars = torch.ones((2, 1)) * 0.5
+
+        class DummyDist:
+            def gather(self, tensor, gather_list=None, dst=0):
+                pass  # rank != 0 passes no gather_list, so there is nothing to fill
+
+            def get_rank(self):
+                return integrator.rank
+
+            def get_world_size(self):
+                return integrator.world_size
+
+            def is_initialized(self):
+                return True
+
+        import MCintegration.integrators as integrators_module
+        orig_dist = integrators_module.dist
+        integrators_module.dist = DummyDist()
+
+        try:
+            result = integrator.statistics(means, vars, neval=100)
+            self.assertIsNone(result)
+        finally:
+            integrators_module.dist = orig_dist
     # @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
     # def test_multi_gpu_consistency(self):
     #     if torch.cuda.device_count() >= 2:
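Note: both tests swap integrators_module.dist by hand and restore it in a finally block. An equivalent pattern (a sketch, not part of this commit) uses unittest.mock.patch.object, which undoes the patch automatically even if an assertion fails; run_with_dummy_dist below is a hypothetical helper, not a function in the test suite.

    # Sketch (not part of this commit): the same dist swap written with
    # unittest.mock.patch.object; the patch is reverted on context exit,
    # replacing the manual orig_dist / try-finally bookkeeping.
    from unittest import mock

    import MCintegration.integrators as integrators_module

    def run_with_dummy_dist(integrator, means, vars, dummy_dist):  # hypothetical helper
        with mock.patch.object(integrators_module, "dist", dummy_dist):
            return integrator.statistics(means, vars, neval=100)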
