From 8e92a039b73b16549419c008a7494c33ede08889 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Fri, 17 Apr 2026 07:23:37 +1000 Subject: [PATCH] Fix assignment of transport --- torchcfm/optimal_transport.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index fec7db3..d380cd9 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -177,8 +177,8 @@ def sample_plan_with_scipy(self, x0, x1): if self.normalize_cost: M = M / M.max() _, j = scipy.optimize.linear_sum_assignment(M.cpu().numpy()) - pi_x0 = x0[j] - pi_x1 = x1 + pi_x0 = x0 + pi_x1 = x1[j] return pi_x0, pi_x1 def sample_plan_with_labels(self, x0, x1, y0=None, y1=None, replace=True):