Skip to content

Commit 5865424

Browse files
committed
no message
1 parent 965d122 commit 5865424

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

correlation/correlation.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,8 @@ def cupy_kernel(strFunction, objVariables):
246246

247247
strTensor = objMatch.group(4)
248248
intSizes = objVariables[strTensor].size()
249-
if torch.is_tensor(intSizes[intArg]):
250-
strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg].item()))
251-
else:
252-
strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg]))
253-
# end
249+
250+
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
254251

255252
while True:
256253
objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
@@ -264,7 +261,7 @@ def cupy_kernel(strFunction, objVariables):
264261

265262
strTensor = strArgs[0]
266263
intStrides = objVariables[strTensor].stride()
267-
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
264+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
268265

269266
strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
270267
# end

0 commit comments

Comments
 (0)