Skip to content

Commit bb236aa

Browse files
committed
try
1 parent d1601d6 commit bb236aa

File tree

1 file changed

+16
-92
lines changed

1 file changed

+16
-92
lines changed

MCintegration/integrators_test.py

Lines changed: 16 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -416,102 +416,26 @@ def test_distributed_initialization(self):
416416
integrator = Integrator(bounds=bounds, f=f)
417417
self.assertEqual(integrator.rank, 0)
418418
self.assertEqual(integrator.world_size, 1)
419+
@unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
420+
def test_multi_gpu_consistency(self):
421+
if torch.cuda.device_count() >= 2:
422+
bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
423+
f = lambda x, fx: torch.ones_like(x)
419424

420-
def test_statistics_worldsize_gt1(self):
421-
"""Mock 分布式 gather 测试 world_size > 1 分支覆盖"""
422-
423-
bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
424-
f = lambda x, fx: fx.copy_(x) # 不重要,只是占位
425-
integrator = Integrator(bounds=bounds, f=f)
426-
integrator.world_size = 2
427-
integrator.rank = 0
428-
429-
means = torch.ones((2, 1))
430-
vars = torch.ones((2, 1)) * 0.5
431-
432-
# ---- 构造假的 dist 模块 ----
433-
class DummyDist:
434-
def gather(self, tensor, gather_list=None, dst=0):
435-
# 模拟 rank 0 收到两份数据
436-
if gather_list is not None:
437-
gather_list[0].copy_(tensor)
438-
gather_list[1].copy_(tensor * 2)
439-
440-
def get_rank(self):
441-
return integrator.rank
442-
443-
def get_world_size(self):
444-
return integrator.world_size
445-
446-
def is_initialized(self):
447-
return True
448-
449-
import MCintegration.integrators as integrators_module
450-
orig_dist = integrators_module.dist
451-
integrators_module.dist = DummyDist()
425+
# Create two integrators on different devices
426+
integrator1 = Integrator(bounds=bounds, f=f, device="cuda:0")
427+
integrator2 = Integrator(bounds=bounds, f=f, device="cuda:1")
452428

453-
try:
454-
result = integrator.statistics(means, vars, neval=100)
455-
self.assertIsNotNone(result)
456-
self.assertTrue(hasattr(result, "__len__"))
457-
finally:
458-
# 恢复原 dist
459-
integrators_module.dist = orig_dist
429+
# Results should be consistent across devices
430+
result1 = integrator1(neval=10000)
431+
result2 = integrator2(neval=10000)
460432

461-
def test_statistics_worldsize_gt1_rank1(self):
462-
"""Mock 分布式测试 rank != 0 分支覆盖"""
433+
if hasattr(result1, "mean"):
434+
value1, value2 = result1.mean, result2.mean
435+
else:
436+
value1, value2 = result1, result2
463437

464-
bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
465-
f = lambda x, fx: fx.copy_(x)
466-
integrator = Integrator(bounds=bounds, f=f)
467-
integrator.world_size = 2
468-
integrator.rank = 1
469-
470-
means = torch.ones((2, 1))
471-
vars = torch.ones((2, 1)) * 0.5
472-
473-
class DummyDist:
474-
def gather(self, tensor, gather_list=None, dst=0):
475-
pass # rank!=0 的情况
476-
477-
def get_rank(self):
478-
return integrator.rank
479-
480-
def get_world_size(self):
481-
return integrator.world_size
482-
483-
def is_initialized(self):
484-
return True
485-
486-
import MCintegration.integrators as integrators_module
487-
orig_dist = integrators_module.dist
488-
integrators_module.dist = DummyDist()
489-
490-
try:
491-
result = integrator.statistics(means, vars, neval=100)
492-
self.assertIsNone(result)
493-
finally:
494-
integrators_module.dist = orig_dist
495-
# @unittest.skipIf(not torch.distributed.is_available(), "Distributed not available")
496-
# def test_multi_gpu_consistency(self):
497-
# if torch.cuda.device_count() >= 2:
498-
# bounds = torch.tensor([[0.0, 1.0]], dtype=torch.float64)
499-
# f = lambda x, fx: torch.ones_like(x)
500-
501-
# # Create two integrators on different devices
502-
# integrator1 = Integrator(bounds=bounds, f=f, device="cuda:0")
503-
# integrator2 = Integrator(bounds=bounds, f=f, device="cuda:1")
504-
505-
# # Results should be consistent across devices
506-
# result1 = integrator1(neval=10000)
507-
# result2 = integrator2(neval=10000)
508-
509-
# if hasattr(result1, "mean"):
510-
# value1, value2 = result1.mean, result2.mean
511-
# else:
512-
# value1, value2 = result1, result2
513-
514-
# self.assertAlmostEqual(float(value1), float(value2), places=1)
438+
self.assertAlmostEqual(float(value1), float(value2), places=1)
515439

516440

517441
if __name__ == "__main__":

0 commit comments

Comments
 (0)