Skip to content

Commit 965d122

Browse files
authored
Merge pull request #57 from Etienne66/master
Fix issue #47
2 parents 9712c86 + 4bfb33f commit 965d122

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

correlation/correlation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,11 @@ def cupy_kernel(strFunction, objVariables):
246246

247247
strTensor = objMatch.group(4)
248248
intSizes = objVariables[strTensor].size()
249-
250-
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
251-
# end
249+
if torch.is_tensor(intSizes[intArg]):
250+
strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg].item()))
251+
else:
252+
strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg]))
253+
# end
252254

253255
while True:
254256
objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
@@ -394,4 +396,4 @@ def __init__(self):
394396
def forward(self, tenOne, tenTwo):
395397
return _FunctionCorrelation.apply(tenOne, tenTwo)
396398
# end
397-
# end
399+
# end

0 commit comments

Comments
 (0)