4141
4242def backwarp (tenInput , tenFlow ):
4343 if str (tenFlow .shape ) not in backwarp_tenGrid :
44- tenHor = torch .linspace (- 1.0 + ( 1.0 / tenFlow . shape [ 3 ]) , 1.0 - ( 1.0 / tenFlow . shape [ 3 ]) , tenFlow .shape [3 ]).view (1 , 1 , 1 , - 1 ).repeat (1 , 1 , tenFlow .shape [2 ], 1 )
45- tenVer = torch .linspace (- 1.0 + ( 1.0 / tenFlow . shape [ 2 ]) , 1.0 - ( 1.0 / tenFlow . shape [ 2 ]) , tenFlow .shape [2 ]).view (1 , 1 , - 1 , 1 ).repeat (1 , 1 , 1 , tenFlow .shape [3 ])
44+ tenHor = torch .linspace (- 1.0 , 1.0 , tenFlow .shape [3 ]).view (1 , 1 , 1 , - 1 ).repeat (1 , 1 , tenFlow .shape [2 ], 1 )
45+ tenVer = torch .linspace (- 1.0 , 1.0 , tenFlow .shape [2 ]).view (1 , 1 , - 1 , 1 ).repeat (1 , 1 , 1 , tenFlow .shape [3 ])
4646
4747 backwarp_tenGrid [str (tenFlow .shape )] = torch .cat ([ tenHor , tenVer ], 1 ).cuda ()
4848 # end
@@ -51,10 +51,10 @@ def backwarp(tenInput, tenFlow):
5151 backwarp_tenPartial [str (tenFlow .shape )] = tenFlow .new_ones ([ tenFlow .shape [0 ], 1 , tenFlow .shape [2 ], tenFlow .shape [3 ] ])
5252 # end
5353
54- tenFlow = torch .cat ([ tenFlow [:, 0 :1 , :, :] / (( tenInput .shape [3 ] - 1.0 ) / 2.0 ), tenFlow [:, 1 :2 , :, :] / (( tenInput .shape [2 ] - 1.0 ) / 2.0 ) ], 1 )
54+ tenFlow = torch .cat ([ tenFlow [:, 0 :1 , :, :] * ( 2.0 / (tenInput .shape [3 ] - 1.0 )), tenFlow [:, 1 :2 , :, :] * ( 2.0 / (tenInput .shape [2 ] - 1.0 )) ], 1 )
5555 tenInput = torch .cat ([ tenInput , backwarp_tenPartial [str (tenFlow .shape )] ], 1 )
5656
57- tenOutput = torch .nn .functional .grid_sample (input = tenInput , grid = (backwarp_tenGrid [str (tenFlow .shape )] + tenFlow ).permute (0 , 2 , 3 , 1 ), mode = 'bilinear' , padding_mode = 'zeros' , align_corners = False )
57+ tenOutput = torch .nn .functional .grid_sample (input = tenInput , grid = (backwarp_tenGrid [str (tenFlow .shape )] + tenFlow ).permute (0 , 2 , 3 , 1 ), mode = 'bilinear' , padding_mode = 'zeros' , align_corners = True )
5858
5959 tenMask = tenOutput [:, - 1 :, :, :]; tenMask [tenMask > 0.999 ] = 1.0 ; tenMask [tenMask < 1.0 ] = 0.0
6060
0 commit comments