File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -246,9 +246,11 @@ def cupy_kernel(strFunction, objVariables):
246246
247247 strTensor = objMatch .group (4 )
248248 intSizes = objVariables [strTensor ].size ()
249-
250- strKernel = strKernel .replace (objMatch .group (), str (intSizes [intArg ]))
251- # end
249+ if torch .is_tensor (intSizes [intArg ]):
250+ strKernel = strKernel .replace (objectMatch .group (), str (intSizes [intArg ].item ()))
251+ else :
252+ strKernel = strKernel .replace (objectMatch .group (), str (intSizes [intArg ]))
253+ # end
252254
253255 while True :
254256 objMatch = re .search ('(VALUE_)([0-4])(\()([^\)]+)(\))' , strKernel )
@@ -394,4 +396,4 @@ def __init__(self):
394396 def forward (self , tenOne , tenTwo ):
395397 return _FunctionCorrelation .apply (tenOne , tenTwo )
396398 # end
397- # end
399+ # end
You can’t perform that action at this time.
0 commit comments