From d3efabe564d80b2c6ae10766089a7afb01a57fdb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=A1=8C=E7=AD=A0?=
Date: Tue, 21 Oct 2025 10:39:34 +0800
Subject: [PATCH 1/2] accelerate inference with sdpa

---
 .../models/module/transformer_decoder.py | 26 ++++++++++++-------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py
index 2088b08..bcdd6fa 100644
--- a/fireredasr/models/module/transformer_decoder.py
+++ b/fireredasr/models/module/transformer_decoder.py
@@ -40,6 +40,8 @@ def batch_beam_search(self, encoder_outputs, src_masks,
                           softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
         B = beam_size
         N, Ti, H = encoder_outputs.size()
+        flash_attn_enabled = (N == 1)
+
         device = encoder_outputs.device
         maxlen = decode_max_len if decode_max_len > 0 else Ti
         assert eos_penalty > 0.0 and eos_penalty <= 1.0
@@ -68,7 +70,9 @@ def batch_beam_search(self, encoder_outputs, src_masks,
                 dec_output = dec_layer.forward(
                     dec_output, encoder_outputs,
                     tgt_mask, src_mask,
-                    cache=caches[i])
+                    cache=caches[i],
+                    flash_attn_enabled=flash_attn_enabled
+                )
                 caches[i] = dec_output
                 i += 1
 
@@ -183,7 +187,7 @@ def __init__(self, d_model, n_head, dropout):
         self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
 
     def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
-                cache=None):
+                cache=None, flash_attn_enabled=flash_attn_enabled):
         x = dec_input
         residual = x
         x = self.self_attn_norm(x)
@@ -193,12 +197,12 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
             self_attn_mask = self_attn_mask[:, -1:, :]
         else:
             xq = x
-        x = self.self_attn(xq, x, x, mask=self_attn_mask)
+        x = self.self_attn(xq, x, x, mask=self_attn_mask, flash_attn_enabled=flash_attn_enabled)
         x = residual + x
 
         residual = x
         x = self.cross_attn_norm(x)
-        x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask)
+        x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask, flash_attn_enabled=flash_attn_enabled)
         x = residual + x
 
         residual = x
@@ -227,7 +231,7 @@ def __init__(self, d_model, n_head, dropout=0.1):
         self.fc = nn.Linear(n_head * self.d_k, d_model)
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, q, k, v, mask=None):
+    def forward(self, q, k, v, mask=None, flash_attn_enabled=False):
         bs = q.size(0)
 
         q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
@@ -237,10 +241,14 @@ def forward(self, q, k, v, mask=None, flash_attn_enabled=False):
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
-        if mask is not None:
-            mask = mask.unsqueeze(1)
-
-        output = self.attention(q, k, v, mask=mask)
+        if flash_attn_enabled:
+            output = F.scaled_dot_product_attention(
+                q, k, v, scale=1/(self.d_k ** 0.5)
+            )
+        else:
+            if mask is not None:
+                mask = mask.unsqueeze(1)
+            output = self.attention(q, k, v, mask=mask)
 
         output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
         output = self.fc(output)

From f840f4d3b2b6f5dc894524a70a427963fe2bd27b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=A1=8C=E7=AD=A0?=
Date: Tue, 21 Oct 2025 10:55:56 +0800
Subject: [PATCH 2/2] accelerate inference with sdpa

---
 fireredasr/models/module/transformer_decoder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fireredasr/models/module/transformer_decoder.py b/fireredasr/models/module/transformer_decoder.py
index bcdd6fa..00f55c9 100644
--- a/fireredasr/models/module/transformer_decoder.py
+++ b/fireredasr/models/module/transformer_decoder.py
@@ -187,7 +187,7 @@ def __init__(self, d_model, n_head, dropout):
         self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
 
     def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
-                cache=None, flash_attn_enabled=flash_attn_enabled):
+                cache=None, flash_attn_enabled=False):
         x = dec_input
         residual = x
         x = self.self_attn_norm(x)
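
Note (not part of the patch series above): a minimal sanity-check sketch of the SDPA fast path. Assuming the module's existing self.attention computes standard softmax(Q K^T / sqrt(d_k)) V, the unmasked F.scaled_dot_product_attention call with an explicit scale should return the same result, which is why it can stand in for the manual path when N == 1 and no padding mask is needed. The shapes below are hypothetical, and the scale= keyword requires PyTorch >= 2.1.

# Sanity-check sketch with assumed shapes; not taken from the patches above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, n_head, seq, d_k = 1, 8, 16, 64    # hypothetical: one utterance, 8 heads
q = torch.randn(bs, n_head, 1, d_k)    # incremental decoding: a single query step
k = torch.randn(bs, n_head, seq, d_k)
v = torch.randn(bs, n_head, seq, d_k)

# Manual attention, as the existing self.attention path is assumed to compute it
scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
ref = torch.matmul(torch.softmax(scores, dim=-1), v)

# SDPA fast path used by the patch: no attn_mask, explicit scale
out = F.scaled_dot_product_attention(q, k, v, scale=1 / (d_k ** 0.5))

print(torch.allclose(ref, out, atol=1e-5))  # expected: True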