Xe2 scaledMMs with MX format weights #633
Conversation
The CI failures are unrelated and have been affecting all recent PRs. Thanks!
@pengzhao-intel: Thanks for the PR! How do you calculate the throughput? What's the peak of 4-bit computation on the B580?
@sanchitintel: Hi @pengzhao-intel,
Only the compute corresponding to the MMA is considered when calculating throughput, so the actual throughput measured with a profiler would be higher: we don't currently account for the conversion of 4-bit weights, the scaling, or the FP32 -> BF16 conversion (the last is just a mov instruction, so it matters little, but the former two have substantial overhead). Throughput is computed as (2 * M * N * K) / latency.
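As a hedged illustration of the formula above (the shape and latency values are made up, not from this PR's benchmarks):

```cpp
// Minimal sketch of the throughput calculation described above: only the
// MMA FLOPs (2*M*N*K) are counted, so dequantization, scaling, and type
// conversion overheads are deliberately excluded.
#include <cstdio>

int main() {
    const double M = 4096, N = 4096, K = 4096;  // example GEMM shape
    const double latency_s = 1.2e-3;            // example measured kernel latency
    const double tflops = (2.0 * M * N * K) / latency_s / 1e12;
    std::printf("effective throughput: %.1f TFLOP/s\n", tflops);
    return 0;
}
```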
Please see the PR description - BMG doesn't support 4-bit MMA natively, so the 4-bit weights are converted to FP16 or BF16, depending on the activation type. For BF16/FP16, the peak throughput of the B580 is ~117 TFLOP/s.
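A minimal scalar sketch of that upconversion, assuming MXFP4 (E2M1) elements packed two per byte, low nibble first (the element type and packing order are assumptions; `float` stands in for the kernel's BF16/FP16, and the real kernel does this in registers):

```cpp
#include <cstdint>

// All 16 values representable in FP4 E2M1: 1 sign bit, 2 exponent bits
// (bias 1), 1 mantissa bit. Decoding via a lookup table for clarity.
static const float kE2M1[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Unpack two FP4 weights from one byte (low-nibble-first is an assumption).
void unpack_fp4x2(uint8_t packed, float& lo, float& hi) {
    lo = kE2M1[packed & 0x0F];
    hi = kE2M1[(packed >> 4) & 0x0F];
}
```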
@petercad: @sanchitintel -- on the use of shl for bf16->f32 conversion: this comes from the hardware's support for fusing BF16/FP16-to-FP32 conversion for one of the inputs to mul (based on some reference code provided by Peter Caday).
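For reference, the shl trick works because BF16 is exactly the upper 16 bits of an IEEE-754 FP32. A minimal scalar sketch (not the kernel's actual asm):

```cpp
#include <cstdint>
#include <cstring>

// Widening BF16 -> FP32 is a 16-bit left shift of the raw bits followed
// by a bitwise reinterpretation as float.
float bf16_to_f32(uint16_t bf16_bits) {
    const uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
    float result;
    std::memcpy(&result, &f32_bits, sizeof(result));
    return result;
}
```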
@sanchitintel: Thanks for your input, @petercad! I revised the description and added the scaling code in assembly that you provided. I also added some clarifying details on observed performance.
Summary
cute/tutorial/xe_gemm.hpp.

Caution
When the number of output workgroup tiles isn't a multiple of the number of Xe cores, i.e. when the computation has a tail, scaledMM performance looks worse than in the multiple case, because the hardware remains underutilized while the tail's output tiles are processed. In practice, if these scaledMM mainloops were ported to a Grouped GEMM, the tail-latency issue would dissipate: the tails of individual GEMM problems wouldn't matter, only the tail across all output tiles would.
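To make the tail condition concrete, a small sketch under assumed parameters (the workgroup tile size is illustrative; the B580 has 20 Xe cores, and this simplistic model assumes one workgroup tile per Xe core per wave):

```cpp
#include <cstdio>

int main() {
    const int M = 4096, N = 2880;          // example output shape
    const int tile_M = 256, tile_N = 256;  // assumed workgroup tile
    const int xe_cores = 20;               // Arc B580

    // Total output workgroup tiles, rounding partial tiles up.
    const int tiles = ((M + tile_M - 1) / tile_M) * ((N + tile_N - 1) / tile_N);
    const int tail = tiles % xe_cores;
    if (tail != 0)
        std::printf("tail wave uses only %d of %d Xe cores\n", tail, xe_cores);
    else
        std::printf("no tail: %d tiles fill full waves\n", tiles);
    return 0;
}
```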
Important
Please don't be discouraged by the N=2880/5760, K=2880 performance below if you intend to use this scaledMM mainloop in a Grouped GEMM: as described above, the tail latency of individual GEMM problems wouldn't matter there at all.
Details
Known bottlenecks for BF16 -
Issue 2 was identified by Peter and is being tracked for IGC. However, he provided the corresponding asm code as a workaround, which I've since added.
Performance on BMG B580
B quantization group size: 32, as in the OCP-defined MX formats (see the dequantization sketch after the notes below).
Further tuning is possible on a case-by-case basis.
e.g. bf16 x fp8_e4m3 can be tuned further for ColumnMajor B, if required.
In DL models, the B matrix (weights) is usually ColumnMajor.
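For context on the group size above, a minimal sketch of MX-style group dequantization, assuming one shared E8M0 power-of-two scale per 32 elements (the function name and in/out layout are illustrative, not this PR's API):

```cpp
#include <cmath>
#include <cstdint>

// Apply one shared E8M0 scale (a power-of-two exponent byte with bias 127,
// per the OCP MX convention) to a group of 32 already-decoded elements.
void dequant_group32(const float* decoded, uint8_t e8m0_scale, float* out) {
    const float scale = std::ldexp(1.0f, static_cast<int>(e8m0_scale) - 127);
    for (int i = 0; i < 32; ++i)
        out[i] = decoded[i] * scale;
}
```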
Benchmarking instructions
Build instructions
Please do not use -DDPCPP_HOST_COMPILER=g++-13 for now; I'll later revise the code to make it compatible with g++ (it's related to the SYCL kernel launch).

cc @pengzhao-intel @EikanWang @CaoZhongZ