Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiter-kernels/build.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[general]
name = "aiter-kernels"
version = 1
version = 2
license = "MIT"
backends = ["rocm"]

Expand Down
10 changes: 7 additions & 3 deletions aiter-kernels/torch-ext/aiter_kernels/rope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _to_bhsd(t: torch.Tensor) -> torch.Tensor:
return t.permute(1, 2, 0, 3)


def apply_rotary_transformers(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_transformers(q, k, cos, sin, unsqueeze_dim=1):
"""Apply NEOX-style RoPE to ``q`` and ``k``.

Drop-in replacement for ``kernels-community/rotary``'s
Expand All @@ -40,8 +40,8 @@ def apply_rotary_transformers(q, k, cos, sin, position_ids=None, unsqueeze_dim=1
Args:
q, k: ``(batch, heads, seq, head_dim)``.
cos, sin: ``(batch, seq, head_dim // 2)`` — pre-``unsqueeze`` form.
position_ids, unsqueeze_dim: accepted for API parity; the kernel reads
positions from the already-computed ``cos`` / ``sin``.
unsqueeze_dim: accepted for API parity; the kernel reads positions from
the already-computed ``cos`` / ``sin``.

Returns:
``(q_embed, k_embed)`` in the same shape as ``q``, ``k``.
Expand All @@ -63,6 +63,10 @@ def apply_rotary_transformers(q, k, cos, sin, position_ids=None, unsqueeze_dim=1
return _to_bhsd(q_out_sbhd), _to_bhsd(k_out_sbhd)


# Add torch compile support for functions
apply_rotary_transformers.can_torch_compile = True


__all__ = [
"RotateStyle",
"apply_rotary_transformers",
Expand Down
2 changes: 1 addition & 1 deletion aiter-rope/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Original code: https://github.com/ROCm/aiter (MIT, © Advanced Micro Devices, In

## Functions

### `apply_rotary_transformers(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)`
### `apply_rotary_transformers(q, k, cos, sin, unsqueeze_dim=1)`

Apply NEOX-style RoPE to query and key tensors.

Expand Down
2 changes: 1 addition & 1 deletion aiter-rope/build.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[general]
name = "aiter-rope"
version = 1
version = 2
license = "MIT"
backends = ["rocm"]

Expand Down
10 changes: 7 additions & 3 deletions aiter-rope/torch-ext/aiter_rope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _to_bhsd(t: torch.Tensor) -> torch.Tensor:
return t.permute(1, 2, 0, 3)


def apply_rotary_transformers(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def apply_rotary_transformers(q, k, cos, sin, unsqueeze_dim=1):
"""Apply NEOX-style RoPE to ``q`` and ``k``.

Signature mirrors ``kernels-community/rotary``'s ``apply_rotary_transformers``
Expand All @@ -42,8 +42,8 @@ def apply_rotary_transformers(q, k, cos, sin, position_ids=None, unsqueeze_dim=1
Args:
q, k: ``(batch, heads, seq, head_dim)``.
cos, sin: ``(batch, seq, head_dim // 2)`` — pre-``unsqueeze`` form.
position_ids, unsqueeze_dim: accepted for API parity; the kernel reads
positions from the already-computed ``cos`` / ``sin``.
unsqueeze_dim: accepted for API parity; the kernel reads positions from
the already-computed ``cos`` / ``sin``.

Returns:
``(q_embed, k_embed)`` in the same shape as ``q``, ``k``.
Expand All @@ -68,6 +68,10 @@ def apply_rotary_transformers(q, k, cos, sin, position_ids=None, unsqueeze_dim=1
return _to_bhsd(q_out_sbhd), _to_bhsd(k_out_sbhd)


# Add torch compile support for functions
apply_rotary_transformers.can_torch_compile = True


__all__ = [
"__kernel_metadata__",
"RotateStyle",
Expand Down