Skip to content

Commit cd82a95

Browse files
authored
Fix ESM attention for TFLite compatibility (#2466)
* Fix ESM attention for TFLite compatibility
* Update esm_attention.py
1 parent 80c615c commit cd82a95

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

keras_hub/src/models/esm/esm_attention.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,8 @@ def _compute_cos_sin_embedding(self, x, position=1):
1414
inv_freq = self.scaling_factor / (
1515
self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
1616
)
17-
t = ops.arange(x.shape[position], dtype=x.dtype)
17+
# Use ops.shape for dynamic shape compatibility with TFLite
18+
t = ops.arange(ops.shape(x)[position], dtype=x.dtype)
1819
freqs = ops.outer(t, inv_freq)
1920
emb = ops.concatenate((freqs, freqs), axis=-1)
2021

@@ -32,11 +33,17 @@ def call(self, q, k, position=1):
3233

3334
def rotate_half(self, x):
3435
x1, x2 = ops.split(x, 2, -1)
35-
return ops.concatenate((-x2, x1), axis=-1)
36+
# Avoid `ops.concatenate` to prevent XLA compilation issues on JAX
37+
# backend. Use stack + reshape approach from base RotaryEmbedding.
38+
half_rot_x = ops.stack((-x2, x1), axis=-2)
39+
half_rot_x = ops.reshape(half_rot_x, ops.shape(x))
40+
return half_rot_x
3641

3742
def apply_rotary_pos_emb(self, x, cos, sin):
38-
cos = cos[:, : x.shape[1], :, :]
39-
sin = sin[:, : x.shape[1], :, :]
43+
# Use ops.shape for dynamic shape compatibility with TFLite
44+
seq_len = ops.shape(x)[1]
45+
cos = cos[:, :seq_len, :, :]
46+
sin = sin[:, :seq_len, :, :]
4047

4148
return (x * cos) + (self.rotate_half(x) * sin)
4249

0 commit comments

Comments
 (0)