
swin_transformer.py compute_mask doesn't have expected torch dtype #92

@wanghao14


Describe the bug

Building the Video Swin Transformer in torch.bfloat16 and running it fails with the following error:

.........
File "/data4/Projects/video_captioning/my_project/experiments/modeling/swin_transformer.py", line 284, in forward
    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
RuntimeError: expected scalar type BFloat16 but found Float 

Because I have modified this code in my project, the line numbers above differ from this repository. The corresponding line in this repository's swin_transformer.py is https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L166
and the function responsible for the bug is https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L317.
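A likely explanation of the traceback, sketched with toy tensors: compute_mask allocates its mask with torch.zeros and no dtype, so the mask is float32 (the default). Adding a float32 mask to bfloat16 attention scores promotes the scores to float32 under PyTorch's type-promotion rules, and the later attn @ v then multiplies float32 by bfloat16, which matmul rejects because it does not promote. The sizes and the masked_attention function below are hypothetical and simplified (no scaling, no relative position bias); they only roughly mirror the masked-add, softmax, matmul sequence of the window attention.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy sizes: B_ = batch * number of windows, N = tokens per window.
nW, batch, num_heads, N, head_dim = 4, 1, 2, 8, 16
B_ = batch * nW

q = torch.randn(B_, num_heads, N, head_dim).to(torch.bfloat16)
k = torch.randn(B_, num_heads, N, head_dim).to(torch.bfloat16)
v = torch.randn(B_, num_heads, N, head_dim).to(torch.bfloat16)


def masked_attention(mask):
    """Simplified masked window attention (hypothetical; no scaling or position bias)."""
    attn = q @ k.transpose(-2, -1)  # bf16 scores, shape (B_, heads, N, N)
    # Broadcasting add of the (nW, N, N) mask; a float32 mask promotes attn to float32 here.
    attn = attn.view(B_ // nW, nW, num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, num_heads, N, N).softmax(dim=-1)
    return attn @ v  # matmul does not promote: float32 @ bfloat16 raises RuntimeError


mask_fp32 = torch.zeros(nW, N, N)                        # default float32, as compute_mask builds it now
mask_bf16 = torch.zeros(nW, N, N, dtype=torch.bfloat16)  # dtype with the proposed fix

try:
    masked_attention(mask_fp32)
    fp32_error = None
except RuntimeError as err:
    fp32_error = err
print("float32 mask:", fp32_error)

out = masked_attention(mask_bf16)
print("bfloat16 mask:", out.dtype)  # bfloat16 mask: torch.bfloat16
```

With a mask in the same dtype as the input, the scores stay in bfloat16 end to end and the final matmul succeeds.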

Bug fix

One way to fix it is to pass the input dtype through to compute_mask:

-    def compute_mask(D, H, W, window_size, shift_size, device):
+    def compute_mask(D, H, W, window_size, shift_size, device, dtype): 
-        img_mask = torch.zeros((1, D, H, W, 1), device=device)  # 1 Dp Hp Wp 1
+        img_mask = torch.zeros((1, D, H, W, 1), device=device, dtype=dtype)  # 1 Dp Hp Wp 1
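For reference, a self-contained sketch of the patched function. The body is modeled on the repository's compute_mask, and window_partition here is a simplified stand-in for the repository helper that assumes D, H and W are already padded to multiples of window_size; check both against the actual file. The sketch keeps an @lru_cache() decorator, as the upstream file appears to use. Since torch.dtype is hashable, adding it as an argument still works with the cache, and float32 and bfloat16 masks are cached as separate entries.

```python
from functools import lru_cache, reduce
from operator import mul

import torch


def window_partition(x, window_size):
    """Split a (B, D, H, W, C) tensor into (B * num_windows, wd * wh * ww, C) windows.

    Simplified stand-in for the repository helper; assumes D, H, W are
    already padded to multiples of window_size.
    """
    B, D, H, W, C = x.shape
    x = x.view(B, D // window_size[0], window_size[0],
               H // window_size[1], window_size[1],
               W // window_size[2], window_size[2], C)
    x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    return x.view(-1, reduce(mul, window_size), C)


@lru_cache()  # window_size/shift_size must be tuples so every argument is hashable
def compute_mask(D, H, W, window_size, shift_size, device, dtype):
    # The fix: allocate the label volume in the caller's dtype instead of the float32 default.
    img_mask = torch.zeros((1, D, H, W, 1), device=device, dtype=dtype)  # 1 Dp Hp Wp 1
    cnt = 0
    # Give each region created by the cyclic shift its own integer label.
    for d in (slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None)):
        for h in (slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None)):
            for w in (slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None)):
                img_mask[:, d, h, w, :] = cnt
                cnt += 1
    mask_windows = window_partition(img_mask, window_size).squeeze(-1)  # nW, N
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)    # nW, N, N
    # Token pairs from different regions get a large negative bias; same-region pairs get 0.
    attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
    return attn_mask


mask = compute_mask(4, 4, 4, (2, 2, 2), (1, 1, 1), torch.device("cpu"), torch.bfloat16)
print(mask.dtype, tuple(mask.shape))  # torch.bfloat16 (8, 8, 8)
```

The mask values (0 and -100) and the region labels are small integers, so they are represented exactly in bfloat16 and the bfloat16 mask matches the float32 one value for value.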

Then change the call at https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L405 to pass the input dtype:
attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device, x.dtype)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests