We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ed3af4f commit 07c3f5aCopy full SHA for 07c3f5a
correlation/correlation.py
@@ -381,7 +381,8 @@ def backward(self, gradOutput):
381
# end
382
383
def FunctionCorrelation(tenOne, tenTwo):
384
- return _FunctionCorrelation.apply(tenOne, tenTwo)
+ with cupy.cuda.Device(tenOne.get_device()):
385
+ return _FunctionCorrelation.apply(tenOne, tenTwo)
386
387
388
class ModuleCorrelation(torch.nn.Module):
0 commit comments