|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +from typing import Optional |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +from .common import apply_rotary_emb_dispatch |
| 9 | +from .mrope import MRotaryEmbedding |
| 10 | + |
| 11 | + |
| 12 | +class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): |
| 13 | + """3D rotary positional embedding. 3D is t:time h:height w:width""" |
| 14 | + |
| 15 | + def forward( |
| 16 | + self, |
| 17 | + positions: torch.Tensor, |
| 18 | + query: torch.Tensor, |
| 19 | + key: Optional[torch.Tensor] = None, |
| 20 | + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
| 21 | + assert positions.ndim == 1 or positions.ndim == 2 |
| 22 | + assert key is not None |
| 23 | + |
| 24 | + num_tokens = positions.shape[-1] |
| 25 | + cos_sin = self.cos_sin_cache[positions] |
| 26 | + cos, sin = cos_sin.chunk(2, dim=-1) |
| 27 | + if positions.ndim == 2: |
| 28 | + assert self.mrope_section |
| 29 | + |
| 30 | + section_h = self.mrope_section[0] # 22 |
| 31 | + section_w = self.mrope_section[1] # 22 |
| 32 | + section_t = self.mrope_section[2] # 20 |
| 33 | + assert section_h == section_w |
| 34 | + # Split according to [h w h w h w h w... t t t...] |
| 35 | + section_cos_t = cos[..., -section_t:] |
| 36 | + section_cos_h = cos[..., :section_h + section_w:2] |
| 37 | + section_cos_w = cos[..., 1:section_h + section_w:2] |
| 38 | + |
| 39 | + cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[ |
| 40 | + 1], section_cos_w[2] |
| 41 | + cos_hw = torch.stack([cos_h, cos_w], |
| 42 | + dim=-1).reshape(cos_h.shape[:-1] + |
| 43 | + (cos_h.shape[-1] * 2, )) |
| 44 | + cos = torch.cat([cos_hw, cos_t], dim=-1) |
| 45 | + |
| 46 | + section_sin_t = sin[..., -section_t:] |
| 47 | + section_sin_h = sin[..., :section_h + section_w:2] |
| 48 | + section_sin_w = sin[..., 1:section_h + section_w:2] |
| 49 | + |
| 50 | + sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[ |
| 51 | + 1], section_sin_w[2] |
| 52 | + sin_hw = torch.stack([sin_h, sin_w], |
| 53 | + dim=-1).reshape(sin_h.shape[:-1] + |
| 54 | + (sin_h.shape[-1] * 2, )) |
| 55 | + sin = torch.cat([sin_hw, sin_t], dim=-1) |
| 56 | + |
| 57 | + query_shape = query.shape |
| 58 | + query = query.view(num_tokens, -1, self.head_size) |
| 59 | + query_rot = query[..., :self.rotary_dim] |
| 60 | + query_pass = query[..., self.rotary_dim:] |
| 61 | + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, |
| 62 | + self.is_neox_style) |
| 63 | + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) |
| 64 | + |
| 65 | + key_shape = key.shape |
| 66 | + key = key.view(num_tokens, -1, self.head_size) |
| 67 | + key_rot = key[..., :self.rotary_dim] |
| 68 | + key_pass = key[..., self.rotary_dim:] |
| 69 | + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, |
| 70 | + self.is_neox_style) |
| 71 | + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) |
| 72 | + return query, key |
0 commit comments