diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index fec7db3..d380cd9 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -177,8 +177,8 @@ def sample_plan_with_scipy(self, x0, x1): if self.normalize_cost: M = M / M.max() _, j = scipy.optimize.linear_sum_assignment(M.cpu().numpy()) - pi_x0 = x0[j] - pi_x1 = x1 + pi_x0 = x0 + pi_x1 = x1[j] return pi_x0, pi_x1 def sample_plan_with_labels(self, x0, x1, y0=None, y1=None, replace=True):