Optimize beam search & add flash attention+xformers support #105
Conversation
@xsank The test did not show any performance improvement.
is_finished_n = is_finished.sum().item()
active_mask = ~is_finished.squeeze()
#active_indices = self.filter_indexes[M][active_mask]
active_indices = torch.nonzero_static(active_mask, size=M - int(is_finished_n)).squeeze(1)
It's weird to get an error here. The environment is torch==2.4.0+cu121, A100, flash_attn-2.8.3-cp310-cp310-linux_x86_64.whl.
Traceback (most recent call last):
File "/data/user/lxp/tools/python/speech/batch_fireredaed.py", line 184, in
main(args)
File "/data/user/lxp/tools/python/speech/batch_fireredaed.py", line 144, in main
texts = model.transcribe_aed(
File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/data/user/lxp/asr/FireRedASR/fireredasr/models/fireredasr.py", line 82, in transcribe_aed
hyps = self.model.transcribe(
File "/data/user/lxp/asr/FireRedASR/fireredasr/models/fireredasr_aed.py", line 33, in transcribe
nbest_hyps = self.decoder.batch_beam_search(
File "/data/user/lxp/asr/FireRedASR/fireredasr/models/module/transformer_decoder.py", line 216, in batch_beam_search
active_indices = torch.nonzero_static(active_mask, size=M - int(is_finished_n)).squeeze(1)
NotImplementedError: Could not run 'aten::nonzero_static' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::nonzero_static' is only available for these backends: [CPU, Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].
I tested it on torch 2.7.1 + Python 3.12.
Sonnet 4.5 replaced torch.nonzero_static with torch.nonzero, which solved the problem:
# Update finished state
is_finished = t_ys.eq(self.eos_id)
is_finished_n = is_finished.sum().item()
active_mask = ~is_finished.squeeze()
active_indices = torch.nonzero(active_mask, as_tuple=False).squeeze(1)
Thanks for adding flash attention support, it's really fast!
There is a small problem: torch.nonzero requires one more GPU-to-CPU transfer than torch.nonzero_static, which costs about 5% in performance.
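A minimal sketch of a fallback helper (my own illustration, not code from this PR; the helper name is hypothetical) that prefers torch.nonzero_static where the backend implements it and falls back to torch.nonzero otherwise:

import torch

def active_indices_from_mask(active_mask: torch.Tensor, num_active: int) -> torch.Tensor:
    # Prefer nonzero_static: fixed output size, avoids the GPU-to-CPU sync.
    # Fall back to nonzero on backends where it is not implemented
    # (e.g. CUDA on torch 2.4, per the traceback above).
    try:
        return torch.nonzero_static(active_mask, size=num_active).squeeze(1)
    except NotImplementedError:
        return torch.nonzero(active_mask, as_tuple=False).squeeze(1)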
@Xujianzhong which test? Let me take a look.
Thanks for your PR, we will review it.
The SDPA performance improvement is approximately 50%, and flash attention nearly 100%, depending on the data and the batch size.
The greater the variation in audio length within a batch, the better the optimization effect. With batch size = 1 there is no effect.
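For reference, a minimal sketch of restricting PyTorch's SDPA to the flash-attention kernel (my own illustration, not this PR's implementation; assumes torch >= 2.3 with the torch.nn.attention module and a CUDA device):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# q, k, v: (batch, num_heads, seq_len, head_dim), half precision on GPU
q = torch.randn(8, 16, 512, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Only the flash-attention backend is allowed inside this context;
# SDPA raises an error if the inputs are not supported by that kernel.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)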
@kaituoxu