Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions fireredasr/models/module/transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def batch_beam_search(self, encoder_outputs, src_masks,
softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
B = beam_size
N, Ti, H = encoder_outputs.size()
flash_attn_enabled = (N == 1)

device = encoder_outputs.device
maxlen = decode_max_len if decode_max_len > 0 else Ti
assert eos_penalty > 0.0 and eos_penalty <= 1.0
Expand Down Expand Up @@ -68,7 +70,9 @@ def batch_beam_search(self, encoder_outputs, src_masks,
dec_output = dec_layer.forward(
dec_output, encoder_outputs,
tgt_mask, src_mask,
cache=caches[i])
cache=caches[i],
flash_attn_enabled=flash_attn_enabled
)
caches[i] = dec_output
i += 1

Expand Down Expand Up @@ -183,7 +187,7 @@ def __init__(self, d_model, n_head, dropout):
self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)

def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
cache=None):
cache=None, flash_attn_enabled=False):
x = dec_input
residual = x
x = self.self_attn_norm(x)
Expand All @@ -193,12 +197,12 @@ def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
self_attn_mask = self_attn_mask[:, -1:, :]
else:
xq = x
x = self.self_attn(xq, x, x, mask=self_attn_mask)
x = self.self_attn(xq, x, x, mask=self_attn_mask, flash_attn_enabled=flash_attn_enabled)
x = residual + x

residual = x
x = self.cross_attn_norm(x)
x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask)
x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask, flash_attn_enabled=flash_attn_enabled)
x = residual + x

residual = x
Expand Down Expand Up @@ -227,7 +231,7 @@ def __init__(self, d_model, n_head, dropout=0.1):
self.fc = nn.Linear(n_head * self.d_k, d_model)
self.dropout = nn.Dropout(dropout)

def forward(self, q, k, v, mask=None):
def forward(self, q, k, v, mask=None, flash_attn_enabled=False):
bs = q.size(0)

q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
Expand All @@ -237,10 +241,14 @@ def forward(self, q, k, v, mask=None):
k = k.transpose(1, 2)
v = v.transpose(1, 2)

if mask is not None:
mask = mask.unsqueeze(1)

output = self.attention(q, k, v, mask=mask)
if flash_attn_enabled:
output = F.scaled_dot_product_attention(
q, k, v, scale=1/(self.d_k ** 0.5)
)
else:
if mask is not None:
mask = mask.unsqueeze(1)
output = self.attention(q, k, v, mask=mask)

output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
output = self.fc(output)
Expand Down