
Conversation

@xsank commented Nov 13, 2025

The SDPA performance improvement is approximately 50%, and flash attention nearly 100%, depending on the data and the batch size.
The greater the difference in audio lengths within a batch, the better the optimization effect. With batch size = 1 there is no effect.
@kaituoxu
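The batching effect described above can be sketched with PyTorch's built-in SDPA entry point. This is a minimal, hypothetical illustration (the shapes and mask construction are illustrative, not FireRedASR's actual code): a boolean key-padding mask keeps padded frames of shorter audio from contributing, and PyTorch dispatches to an efficient backend such as flash attention when the inputs are eligible.

```python
import torch
import torch.nn.functional as F

# Toy batch: two "audio" sequences of different lengths, padded to max length.
B, H, L, D = 2, 4, 6, 8
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Key-padding mask: sequence 0 uses all 6 frames, sequence 1 only 3.
lengths = torch.tensor([6, 3])
key_valid = torch.arange(L)[None, :] < lengths[:, None]  # (B, L), True = real frame
attn_mask = key_valid[:, None, None, :]                  # broadcasts to (B, H, L, L)

# SDPA masks out padded keys; an efficient backend is chosen automatically.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

With a boolean `attn_mask`, `True` means "attend"; padded positions are excluded without any per-sequence Python loop, which is where the batched speedup comes from.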

@Xujianzhong

@xsank The test did not show any performance improvement.

is_finished_n = is_finished.sum().item()
active_mask = ~is_finished.squeeze()
#active_indices = self.filter_indexes[M][active_mask]
active_indices = torch.nonzero_static(active_mask, size=M - int(is_finished_n)).squeeze(1)

It's weird to get an error here. The environment is torch==2.4.0+cu121, A100, flash_attn-2.8.3-cp310-cp310-linux_x86_64.whl.

Traceback (most recent call last):
File "/data/user/lxp/tools/python/speech/batch_fireredaed.py", line 184, in
main(args)
File "/data/user/lxp/tools/python/speech/batch_fireredaed.py", line 144, in main
texts = model.transcribe_aed(
File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/data/user/lxp/asr/FireRedASR/fireredasr/models/fireredasr.py", line 82, in transcribe_aed
hyps = self.model.transcribe(
File "/data/user/lxp/asr/FireRedASR/fireredasr/models/fireredasr_aed.py", line 33, in transcribe
nbest_hyps = self.decoder.batch_beam_search(
File "/data/user/lxp/asr/FireRedASR/fireredasr/models/module/transformer_decoder.py", line 216, in batch_beam_search
active_indices = torch.nonzero_static(active_mask, size=M - int(is_finished_n)).squeeze(1)
NotImplementedError: Could not run 'aten::nonzero_static' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::nonzero_static' is only available for these backends: [CPU, Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Author (@xsank):

I tested it on torch 2.7.1 + Python 3.12.


Sonnet 4.5 replaced torch.nonzero_static with torch.nonzero, which solved the problem.

# Update finished state
is_finished = t_ys.eq(self.eos_id)
is_finished_n = is_finished.sum().item()
active_mask = ~is_finished.squeeze()
active_indices = torch.nonzero(active_mask, as_tuple=False).squeeze(1)

Thanks for the flash attention support, it's really fast!

Author (@xsank), replying to the comment above:

> Sonnet 4.5 replaced torch.nonzero_static with torch.nonzero, which solved the problem. Thanks for the flash attention support, it's really fast!

There is a small problem: torch.nonzero needs one more GPU-to-CPU transfer than torch.nonzero_static, which costs about 5% in performance.
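For context on that ~5% note, here is an illustrative sketch (not from the PR): `torch.nonzero` sizes its output from the data, so on CUDA it must transfer the nonzero count to the host before allocating the result, while `torch.nonzero_static` takes the size up front and skips that transfer. When the supplied size equals the true count, the two produce identical indices.

```python
import torch

mask = torch.tensor([True, False, True, False, True])
n = int(mask.sum().item())  # this .item() already forces one device-to-host sync

# torch.nonzero's output size is data-dependent, hence the extra transfer on GPU
dynamic = torch.nonzero(mask, as_tuple=False).squeeze(1)

# torch.nonzero_static takes the size up front, avoiding that transfer
# (guarded because the op only exists in recent PyTorch versions)
if hasattr(torch, "nonzero_static"):
    static = torch.nonzero_static(mask, size=n).squeeze(1)
    assert torch.equal(dynamic, static)
```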


xsank commented Nov 14, 2025

@xsank The test did not show any performance improvement.

@Xujianzhong which test? Let me take a look.

@xsank xsank changed the title Optimize beam search & add flash attention support Optimize beam search & add flash attention+xformers support Nov 18, 2025
@kaituoxu (Collaborator):

Thanks for your PR, we will review.
