@@ -14,7 +14,8 @@ def _compute_cos_sin_embedding(self, x, position=1):
1414 inv_freq = self .scaling_factor / (
1515 self .max_wavelength ** (ops .arange (0 , dim , 2 , dtype = x .dtype ) / dim )
1616 )
17- t = ops .arange (x .shape [position ], dtype = x .dtype )
17+ # Use ops.shape for dynamic shape compatibility with TFLite
18+ t = ops .arange (ops .shape (x )[position ], dtype = x .dtype )
1819 freqs = ops .outer (t , inv_freq )
1920 emb = ops .concatenate ((freqs , freqs ), axis = - 1 )
2021
@@ -32,11 +33,17 @@ def call(self, q, k, position=1):
3233
3334 def rotate_half (self , x ):
3435 x1 , x2 = ops .split (x , 2 , - 1 )
35- return ops .concatenate ((- x2 , x1 ), axis = - 1 )
36+ # Avoid `ops.concatenate` to prevent XLA compilation issues on JAX
37+ # backend. Use stack + reshape approach from base RotaryEmbedding.
38+ half_rot_x = ops .stack ((- x2 , x1 ), axis = - 2 )
39+ half_rot_x = ops .reshape (half_rot_x , ops .shape (x ))
40+ return half_rot_x
3641
3742 def apply_rotary_pos_emb (self , x , cos , sin ):
38- cos = cos [:, : x .shape [1 ], :, :]
39- sin = sin [:, : x .shape [1 ], :, :]
43+ # Use ops.shape for dynamic shape compatibility with TFLite
44+ seq_len = ops .shape (x )[1 ]
45+ cos = cos [:, :seq_len , :, :]
46+ sin = sin [:, :seq_len , :, :]
4047
4148 return (x * cos ) + (self .rotate_half (x ) * sin )
4249
0 commit comments