diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py
index f9fe9a3fb..22cf2ee95 100644
--- a/sam2/modeling/sam/transformer.py
+++ b/sam2/modeling/sam/transformer.py
@@ -239,6 +239,10 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
         v = self._separate_heads(v, self.num_heads)
 
         dropout_p = self.dropout_p if self.training else 0.0
+        if q.dtype != k.dtype:
+            k = k.to(dtype=q.dtype)
+        if q.dtype != v.dtype:
+            v = v.to(dtype=q.dtype)
         # Attention
         out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 
@@ -302,6 +306,10 @@ def forward(
         )
 
         dropout_p = self.dropout_p if self.training else 0.0
+        if q.dtype != k.dtype:
+            k = k.to(dtype=q.dtype)
+        if q.dtype != v.dtype:
+            v = v.to(dtype=q.dtype)
         # Attention
         out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 
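
Note (not part of the patch): a minimal, self-contained sketch of the failure mode these two hunks guard against. Under mixed precision, q can end up in a different dtype than k and v (for example, a float32 query attending over half-precision tensors), and F.scaled_dot_product_attention rejects mismatched dtypes. Casting k and v to q.dtype, as the added lines do, avoids the error. Tensor shapes below are illustrative, not taken from SAM 2.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, 64, dtype=torch.float32)  # e.g. produced outside autocast
k = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # e.g. kept in half precision
v = torch.randn(1, 8, 16, 64, dtype=torch.float16)

# Without the coercion, SDPA raises a RuntimeError complaining that
# query, key, and value must have the same dtype.
try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as e:
    print("dtype mismatch:", e)

# The patch's approach: align k and v with q's dtype before attention.
if q.dtype != k.dtype:
    k = k.to(dtype=q.dtype)
if q.dtype != v.dtype:
    v = v.to(dtype=q.dtype)

out = F.scaled_dot_product_attention(q, k, v)
print(out.dtype)  # torch.float32, matching q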