
Commit 6646ff2

Fix mel length computation in Qwen2-Audio
1 parent 069684e commit 6646ff2

File tree: 1 file changed (+10 −4)

src/transformers/models/qwen2_audio/modeling_qwen2_audio.py

Lines changed: 10 additions & 4 deletions
@@ -347,9 +347,15 @@ def forward(
     ):
         r"""
         Args:
-            attention_mask (`torch.Tensor`)`, *optional*):
-                Qwen2Audio does not support masking of the `input_features`, this argument is preserved for compatibility,
-                but it is not used. By default the silence in the input log mel spectrogram are ignored.
+            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
+                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
+                obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
+                `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
+                the soundfile library (`pip install soundfile`). To prepare the array into
+                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
+                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`), *optional*):
+                attention mask used in the encoder stack (after the convolutional layers).
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under
             returned tensors for more detail.
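
The new `input_features` docstring points to [`AutoFeatureExtractor`] for feature preparation. A minimal sketch of that step follows; the checkpoint name, the silent one-second waveform, and the 16 kHz sampling rate are illustrative assumptions, not part of this commit.

# Sketch only: checkpoint name and waveform are illustrative assumptions.
import numpy as np
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("Qwen/Qwen2-Audio-7B")

# One second of silence standing in for a decoded .flac/.wav array.
waveform = np.zeros(16000, dtype=np.float32)

inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(inputs["input_features"].shape)  # (batch_size, feature_size, sequence_length)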
@@ -765,7 +771,7 @@ def forward(
             feature_attention_mask.sum(-1)
         )
         batch_size, _, max_mel_seq_len = input_features.shape
-        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
+        max_seq_len = (max_mel_seq_len - 1) // 2 + 1
         # Create a sequence tensor of shape (batch_size, max_seq_len)
         seq_range = (
             torch.arange(0, max_seq_len, dtype=audio_feat_lengths.dtype, device=audio_feat_lengths.device)
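
Why `(max_mel_seq_len - 1) // 2 + 1`: the audio encoder downsamples with a stride-2 convolution (kernel size 3, padding 1, the Whisper-style hyperparameters this model reuses), whose output length is floor((L + 2·1 − 3) / 2) + 1 = (L − 1) // 2 + 1. The old `(max_mel_seq_len - 2) // 2 + 1` undercounts by one frame whenever L is odd. A self-contained sanity check, assuming those conv hyperparameters (the channel sizes here are placeholders):

# Quick check of the length formula against a real convolution.
# Assumes the downsampling conv uses kernel_size=3, stride=2, padding=1.
import torch
import torch.nn as nn

conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1)

for max_mel_seq_len in (9, 10, 2999, 3000):
    out_len = conv(torch.zeros(1, 1, max_mel_seq_len)).shape[-1]
    assert out_len == (max_mel_seq_len - 1) // 2 + 1  # fixed formula: always matches
    old = (max_mel_seq_len - 2) // 2 + 1              # old formula
    assert old == out_len - (max_mel_seq_len % 2)     # off by one when L is odd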

Comments (0)