diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index f9d15e1cca..67ccd1d07f 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,5 +1,5 @@ name: Bug Report -description: Create a bug report to help us improve CUTLASS +description: Create a bug report to help us improve SYCL*TLA title: "[BUG] " labels: ["? - Needs Triage", "bug"] assignees: [] @@ -10,8 +10,9 @@ body: attributes: label: Which component has the problem? options: - - CuTe DSL - - CUTLASS C++ + - CuTe APIs + - CUTLASS APIs + - Python (APIs or Pypi package) validations: required: true - type: textarea diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml deleted file mode 100644 index 4572ae1b98..0000000000 --- a/.github/ISSUE_TEMPLATE/config.yml +++ /dev/null @@ -1,5 +0,0 @@ -blank_issues_enabled: true -contact_links: - - name: CUTLASS Discord - url: https://discord.gg/nvidiadeveloper - about: Come chat about using and contributing to CUTLASS! diff --git a/.github/ISSUE_TEMPLATE/documentation_request.md b/.github/ISSUE_TEMPLATE/documentation_request.md index c9fa21fac9..9a3f2f0bab 100644 --- a/.github/ISSUE_TEMPLATE/documentation_request.md +++ b/.github/ISSUE_TEMPLATE/documentation_request.md @@ -1,6 +1,6 @@ --- name: Documentation request -about: Report incorrect or needed documentation to improve CUTLASS +about: Report incorrect or needed documentation to improve SYCL*TLA title: "[DOC]" labels: "? - Needs Triage, documentation" assignees: '' diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 4f1d616576..ca4c4880f1 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,5 +1,5 @@ name: Feature Request -description: Suggest an idea for CUTLASS +description: Suggest an idea for SYCL*TLA title: "[FEA] " labels: ["? - Needs Triage", "feature request"] assignees: [] @@ -10,8 +10,9 @@ body: attributes: label: Which component requires the feature? options: - - CuTe DSL - - CUTLASS C++ + - CuTe APIs + - CUTLASS APIs + - Python (APIs or Pypi package) validations: required: true - type: textarea @@ -21,7 +22,7 @@ body: description: Please fill out all sections below value: | **Is your feature request related to a problem? Please describe.** - A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...] + A clear and concise description of what the problem is. Ex. I wish I could use SYCL*TLA to do [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. diff --git a/.github/ISSUE_TEMPLATE/submit_question.md b/.github/ISSUE_TEMPLATE/submit_question.md index 5aa2a672d2..efd4322849 100644 --- a/.github/ISSUE_TEMPLATE/submit_question.md +++ b/.github/ISSUE_TEMPLATE/submit_question.md @@ -1,6 +1,6 @@ --- name: Submit question -about: Ask a general question about CUTLASS +about: Ask a general question about SYCL*TLA title: "[QST]" labels: "? 
- Needs Triage, question" assignees: '' diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000..ac43cf77b9 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,19 @@ +## Description + + +## Type +- [ ] Bug - [ ] Feature - [ ] Performance - [ ] Refactor + +## Testing +- [ ] Tests pass - [ ] Xe12 - [ ] Xe20 + +## Performance +| Metric | Before | After | +|--------|--------|-------| +| | | | + +## References +Fixes # + +## Checklist +- [ ] Copyright - [ ] Co-pilot Review - [ ] Deprecated APIs not used diff --git a/.github/PULL_REQUEST_TEMPLATE/bug_fix.md b/.github/PULL_REQUEST_TEMPLATE/bug_fix.md new file mode 100644 index 0000000000..ce123ee50e --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/bug_fix.md @@ -0,0 +1,20 @@ +## Bug + + +Severity: + +## Root Cause + + +## Fix + + +## Verification +Before: +After: + +## Testing +- [ ] Regression/Units test +- [ ] Tests pass + +## Details diff --git a/.github/PULL_REQUEST_TEMPLATE/feature.md b/.github/PULL_REQUEST_TEMPLATE/feature.md new file mode 100644 index 0000000000..42e6f25d9f --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/feature.md @@ -0,0 +1,23 @@ +## Feature + + +## Use Case + + +## API +```cpp +// Signature +``` + +## Example +```cpp +// Usage +``` + +## Testing +- [ ] Tests - [ ] Example - [ ] Docs + +## ToDo +- [ ] Implement A +- [ ] Implement B +- [ ] Document diff --git a/.github/PULL_REQUEST_TEMPLATE/performance.md b/.github/PULL_REQUEST_TEMPLATE/performance.md new file mode 100644 index 0000000000..92a18959db --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/performance.md @@ -0,0 +1,19 @@ +## Optimization + + +## Profiling +Tool: +Bottleneck: + +## Results +| Case | Before | After | Gain | +|------|--------|-------|------| +| | | | | + +## Changes + + +## Testing +- [ ] Tests pass - [ ] Xe12 - [ ] Xe20 + +Related: # diff --git a/.github/PULL_REQUEST_TEMPLATE/refactoring.md b/.github/PULL_REQUEST_TEMPLATE/refactoring.md new file mode 100644 index 0000000000..63ead276f9 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/refactoring.md @@ -0,0 +1,23 @@ +## Refactoring + + +## Why + + +## Changes + + +## Preservation +- [ ] Tests unchanged +- [ ] Perf unchanged + +## Quality +| Metric | Before | After | +|--------|--------|-------| +| LOC | | | +| Performance | | | + +## ToDo +- [ ] Implement A +- [ ] Implement B +- [ ] Document diff --git a/.github/actions/install-intel-graphics/action.yml b/.github/actions/install-intel-graphics/action.yml index 93aa9c3cdc..2d3a71b35d 100644 --- a/.github/actions/install-intel-graphics/action.yml +++ b/.github/actions/install-intel-graphics/action.yml @@ -17,20 +17,22 @@ runs: run: | shopt -s expand_aliases which sudo || alias sudo="" - if [[ "${{ inputs.GPU }}" == "BMG" ]]; then + if [[ "${{ inputs.GPU }}" == "BMG" || "${{ inputs.GPU }}" == "PVC" ]]; then sudo add-apt-repository ppa:kobuk-team/intel-graphics sudo apt update else - . /etc/os-release - wget https://repositories.intel.com/gpu/ubuntu/dists/${VERSION_CODENAME}/intel-gpu-ubuntu-${VERSION_CODENAME}.run - chmod +x intel-gpu-ubuntu-${VERSION_CODENAME}.run - sudo ./intel-gpu-ubuntu-${VERSION_CODENAME}.run - sudo apt install -y \ - intel-media-va-driver-non-free libmfx-gen1 libvpl2 \ - libegl-mesa0 libegl1-mesa-dev libgl1-mesa-dev \ - libgles2-mesa-dev libigdgmm12 libxatracker2 mesa-va-drivers \ - mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo \ - libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev hwinfo + # LTS PVC drivers + # . 
/etc/os-release + # wget https://repositories.intel.com/gpu/ubuntu/dists/${VERSION_CODENAME}/intel-gpu-ubuntu-${VERSION_CODENAME}.run + # chmod +x intel-gpu-ubuntu-${VERSION_CODENAME}.run + # sudo ./intel-gpu-ubuntu-${VERSION_CODENAME}.run + # sudo apt install -y \ + # intel-media-va-driver-non-free libmfx-gen1 libvpl2 \ + # libegl-mesa0 libegl1-mesa-dev libgl1-mesa-dev \ + # libgles2-mesa-dev libigdgmm12 libxatracker2 mesa-va-drivers \ + # mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo \ + # libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev hwinfo + exit 1 fi sudo apt-get install -y libze-intel-gpu1 libze-dev intel-metrics-discovery \ intel-opencl-icd ocl-icd-opencl-dev clinfo intel-gsc intel-ocloc g++ diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000000..8ce2a9079d --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,167 @@ +# Copilot Coding Agent Onboarding — SYCL*TLA + +Purpose +------- +This file is a short, focused onboarding guide so a coding agent (Copilot coding agent) can make correct, CI-safe changes to the SYCL*TLA repository without long exploratory searches. Keep edits conservative: prefer small, well-tested changes and follow the PR checklist described below. + +Top-level constraints (read first) +--------------------------------- +- **Intel copyright headers**: Many files carry dual NVIDIA/Intel copyright headers. Do not remove or alter copyright headers on modified files. +- **Intel Xe APIs**: The codebase uses new Intel "Xe" APIs (Xe12 for PVC, Xe20 for BMG) and Intel oneAPI toolchain conventions; prefer SYCL-compatible code and avoid adding CUDA-only code paths without explicit gating. +- **CI Requirements**: Changes must build and pass CI workflows in `.github/workflows/*` (notably `intel_test.yml`, `intel_test_gpp_host.yml`, `sycl_python_test.yml`). +- **Test Coverage**: Check for test coverage before making changes. C++ tests are in `test/unit/`, Python tests in `test/python/`. +- **PR Descriptions**: Must include: what changed, why, local build/test steps performed, and expected CI/benchmark impact (see PR templates in `.github/PULL_REQUEST_TEMPLATE/`). + +Quick actions to always run locally before creating a PR +------------------------------------------------------ +1. **ALWAYS source Intel environment first** (required for builds that target Intel compilers; if not available, CMake configure will still catch syntax errors but linking will fail): + +```bash +source /opt/intel/oneapi/setvars.sh +export CXX=icpx +export CC=icx +``` + +2. **ALWAYS create a clean build directory** and configure for SYCL: + +```bash +rm -rf build && mkdir build && cd build +cmake .. -G Ninja \ + -DCUTLASS_ENABLE_SYCL=ON \ + -DDPCPP_SYCL_TARGET=intel_gpu_bmg_g21 \ + -DCUTLASS_SYCL_RUNNING_CI=ON +ninja +``` + +**Critical Notes:** +- `-DDPCPP_SYCL_TARGET` must match your hardware: `intel_gpu_bmg_g21` for BMG (Arc B580), `intel_gpu_pvc` for PVC (Data Center Max). This affects intrinsic availability. +- Build time: ~10-20 minutes for full build on 8-core machine. +- If Intel oneAPI is not installed, CMake configure will still catch syntax errors but linking and target-specific checks will fail. +- **NEVER commit without running a full build locally first**. + +Build / Test / Lint summary +--------------------------- +- **Bootstrap**: No special bootstrap required. Python dependencies in `pyproject.toml` (`networkx`, `numpy`, `pydot`, `scipy`, `treelib`) are needed for Python tests. 
Install with `pip install -e .` in project root. +- **Build**: Use CMake 3.22+ and Ninja (see commands above). **ALWAYS** run from clean build directory to avoid stale state. +- **C++ Unit Tests**: After build, run `cmake --build . --target test_unit` (runs all unit tests in `test/unit/`). +- **C++ Examples**: `cmake --build . --target test_examples` (builds and validates examples in `examples/`). +- **Python Tests**: + ```bash + cd python + python3 -m pytest -q + ``` + CI runs specific test like `test/python/cutlass/gemm/gemm_bf16_pvc.py`. **ALWAYS** set `export CUTLASS_USE_SYCL=1` before running Python tests. +- **Linting**: No automated linter. Follow existing code style and ensure `-Werror` flag passes (set in CI). + +**Environment Variables Required for Runtime:** +```bash +export ONEAPI_DEVICE_SELECTOR=level_zero:gpu +export IGC_ExtraOCLOptions="-cl-intel-256-GRF-per-thread" +export SYCL_PROGRAM_COMPILE_OPTIONS="-ze-opt-large-register-file -gline-tables-only" +export IGC_VectorAliasBBThreshold=100000000000 +``` +These are set in CI workflows and should be set locally for accurate testing. + +Common failure modes & mitigations +--------------------------------- +- **Missing Intel environment**: builds fail at linking or with unknown compilers. Mitigation: Source `/opt/intel/oneapi/setvars.sh` or unset `CXX`/`CC` to use system compilers for syntax-only checks. +- **Wrong SYCL target**: some intrinsics are target-specific (e.g., 2D block prefetch intrinsics). Match the CI target or use conservative code paths. +- **Layout constraints in Intel Xe epilogues** (ColumnMajor/RowMajor): prefer to reuse existing epilogue code and tests to avoid violating layout constraints. If making changes, run the affected tests locally. +- **Missing libraries in LD_LIBRARY_PATH** for runtime: set `LD_LIBRARY_PATH` to include `build/tools/library` when running `python` tests that load `.so` wrappers. +- **CMake cache issues**: If you see unexpected build behavior, **ALWAYS** delete `build/` completely and reconfigure. Stale CMake cache causes many hard-to-debug issues. +- **Python import errors**: If Python tests fail with import errors, run `pip install -e .` from project root first. + +CI and validation pipelines (what will run) +------------------------------------------- +- See `.github/workflows/` for exact pipelines. Most important: + - `intel_test.yml` — primary CI build for Intel targets + - `intel_test_gpp_host.yml` — GPP host builds + - `sycl_python_test.yml` — Python test workflow + - `nvidia_test.yml` / `cuda_test.yml` — CUDA-targeted tests (keep changes SYCL-first unless explicitly modifying CUDA paths) + +How the agent should validate its changes +----------------------------------------- +1. Run a local CMake configure and build (fast smoke test): + +```bash +rm -rf build && mkdir build && cd build +cmake .. -G Ninja -DCUTLASS_ENABLE_SYCL=ON -DDPCPP_SYCL_TARGET=intel_gpu_bmg_g21 -DCUTLASS_SYCL_RUNNING_CI=ON +ninja -k 0 +``` + +2. Run the Python test subset that touches modified components (or all Python tests if the change is cross-cutting): + +```bash +cd python +python3 -m pytest -q +``` + +3. For C++ kernel changes, run unit tests: `cmake --build . --target test_unit -j 8` + +4. For examples changes, run: `cmake --build . --target test_examples -j 1` +-------------------------------------------------- +- Short summary of change and the files modified. +- Build steps executed locally (CMake + Ninja commands, environment variables set). 
+- Tests run and their results (include pytest subset names and pass/fail counts). +- If the change affects performance or kernel selection, include expected performance impact and a short benchmark (size and results). +- State whether the Intel oneAPI environment was required to fully validate the change. + +Project layout (quick map) +-------------------------- +- Root files: `CMakeLists.txt`, `README.md`, `CHANGELOG-SYCL.md`, `SYCL.cmake`, `pyproject.toml` +- Major directories: + - `include/` — core headers and kernel templates + - `python/` — Python wrapper, generator, tests + - `examples/` — usage examples, e.g., `11_xe20_cutlass_library` + - `test/` — C++ tests and validation kernels + - `tools/` — build/test utilities + - `media/` — documentation and architecture notes (search `media/docs/cpp/xe_rearchitecture.md`, `media/docs/python/xe_cutlass_library.md`) + +Files the agent should inspect when making changes +-------------------------------------------------- +- `python/cutlass_library/generator.py` — kernel generation and filtering logic +- `python/cutlass_library/arch_constants.py` — architecture detection and constants +- `include/cutlass/gemm/kernel/*` — GEMM kernel implementations +- `.github/workflows/*` — CI steps; ensure changes don't break these workflows + +Search tips +----------- +- Use `grep -R "HACK\|TODO\|WORKAROUND\|FIXME"` to find fragile areas. +- Search for `intel` and `Xe` keywords to find Intel-specific code paths. + +Testing coverage +---------------- +- The repo contains Python tests in `python/` and C++ tests under `test/`. +- Before assuming full coverage, run the test suite locally and include failing tests in your PR notes. + +Special rules for changes +------------------------- +- Keep changes minimal and well-scoped. If modifying kernel selection or architecture constants, include tests or fallbacks. +- Preserve Intel copyright headers. +- Avoid introducing CUDA-only code paths in SYCL code. + +When to run a wider search +-------------------------- +Trust these instructions first. Only perform a broad code search if: +- The instructions are clearly missing information for the requested change, or +- A test or build step fails unexpectedly after following these steps. + +Where to look for help +---------------------- +- `README.md` and `media/docs/*` for architecture details +- `.github/workflows/` for CI expectations +- Open issues and PR templates in `.github/ISSUE_TEMPLATE` and `.github/PULL_REQUEST_TEMPLATE` + +Short checklist before opening a PR +----------------------------------- +- [ ] Build configured and compiled locally (or syntax-checked if environment unavailable) +- [ ] Relevant tests run locally and passed +- [ ] PR description includes steps and validation results +- [ ] No removal of Intel copyright headers + +If anything here proves inaccurate +---------------------------------- +Run the minimal searches you need, then update this file with the corrected steps so future agents benefit. + +--- +This file is intentionally short (<= 2 pages). For deeper onboarding, consult `README.md`, `media/docs/*`, and the workflows in `.github/workflows/`. 
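The onboarding file above asks contributors to keep CUDA-only code paths out of SYCL builds unless they are explicitly gated. As a minimal sketch of that gating pattern — assuming only the `CUTLASS_ENABLE_SYCL` compile definition that this PR's `CMakeLists.txt` changes add; the `device_sync()` helper below is hypothetical and exists only to keep the example self-contained:

```cpp
// Sketch of gating a CUDA-only path, assuming the CUTLASS_ENABLE_SYCL
// compile definition set by CMakeLists.txt. device_sync() is a
// hypothetical helper, not a SYCL*TLA API.
#if defined(CUTLASS_ENABLE_SYCL)
  #include <sycl/sycl.hpp>
#else
  #include <cuda_runtime.h>
#endif

inline void device_sync() {
#if defined(CUTLASS_ENABLE_SYCL)
  // SYCL path: synchronize through an in-order queue on the default device.
  static sycl::queue q{sycl::default_selector_v,
                       sycl::property::queue::in_order{}};
  q.wait();
#else
  // CUDA-only path stays behind the guard so SYCL builds never compile it.
  cudaDeviceSynchronize();
#endif
}
```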
diff --git a/.github/instructions/copilot.instructions.md b/.github/instructions/copilot.instructions.md new file mode 100644 index 0000000000..2ae263a078 --- /dev/null +++ b/.github/instructions/copilot.instructions.md @@ -0,0 +1,167 @@ +# GitHub Copilot Instructions for SYCL*TLA + +## Project Overview + +SYCL*TLA (SYCL Templates for Linear Algebra) is a high-performance C++ template library for tensor linear algebra operations on Intel GPUs using SYCL. This repository is forked from NVIDIA CUTLASS and extends CUTLASS and CuTe API support to Intel GPUs through SYCL enablement. + +## Key Technologies and Frameworks + +- **SYCL**: Primary parallel programming model for Intel GPU acceleration +- **Intel oneAPI**: Compiler toolchain (icpx/icx) and runtime environment +- **CUTLASS 3.x**: Template-based GEMM library architecture +- **CuTe**: Compile-time tensor abstraction library +- **Intel Xe Architecture**: Target hardware (BMG arch=20, PVC arch=12) +- **CMake/Ninja**: Build system configuration +- **Python**: High-level interface and testing framework + +## Architecture-Specific Guidelines + +### Intel Xe GPU Support +- **Primary Targets**: Intel Data Center GPU Max/Flex Series (PVC) and Intel Arc B580 (BMG) +- **Architecture Constants**: Use BMG (arch=20) for consumer GPUs, PVC (arch=12) for data center +- **Compilation Targets**: + - BMG: `intel_gpu_bmg_g21` + - PVC: `intel_gpu_pvc` +- **DPAS Operations**: Intel's Dot Product Accumulate Systolic instruction support + +### Layout Constraints +- **Intel Xe Requirements**: Specific layout constraints for epilogue operations +- **Memory Access Patterns**: 2D block operations with Intel GPU intrinsics +- **Stride Requirements**: Use CuTe stride descriptors with int64_t types + +## Code Style and Conventions + +### C++ Templates +- Follow CUTLASS 3.x template parameter conventions +- Use CuTe abstractions for tensor operations +- Maintain compile-time optimization patterns +- Template specializations for Intel Xe architecture features + +### SYCL-Specific Patterns +- Use `sycl::queue` for device operations +- Prefer `sycl::buffer` and `sycl::accessor` for memory management +- Follow Intel GPU subgroup programming model +- Use Intel GPU intrinsics for optimized operations + +### File Organization +- **Headers**: `include/cutlass/` for core templates +- **Kernels**: Architecture-specific implementations in `include/cutlass/gemm/kernel/` +- **Examples**: `examples/` with Intel Xe demonstrations +- **Python**: `python/cutlass_library/` for high-level interface +- **Tests**: `test/` with Intel GPU validation + +## Development Workflow + +### Environment Setup +```bash +# Source Intel oneAPI environment +source /opt/intel/oneapi/setvars.sh + +# Configure for Intel GPU compilation +export CXX=icpx +export CC=icx +``` + +### Build Configuration +- Use CMake with `-DCUTLASS_ENABLE_SYCL=ON` +- Specify target architecture: `-DDPCPP_SYCL_TARGET=intel_gpu_bmg_g21` +- Enable CI mode for testing: `-DCUTLASS_SYCL_RUNNING_CI=ON` + +### Testing Patterns +- Validate on both BMG and PVC architectures when possible +- Include mixed-precision test cases (FP16, BF16, FP8) +- Test epilogue fusion operations +- Verify against reference implementations + +## Common Anti-Patterns to Avoid + +### Performance Issues +- **Avoid**: Unnecessary template instantiations +- **Avoid**: Non-coalesced memory access patterns +- **Avoid**: Suboptimal tile sizes for Intel Xe +- **Avoid**: Missing architecture-specific optimizations + +### SYCL-Specific Issues +- **Avoid**: 
CUDA-specific code paths in SYCL builds +- **Avoid**: Hardcoded NVIDIA architecture assumptions +- **Avoid**: Missing Intel GPU intrinsic availability checks +- **Avoid**: Incorrect compilation target specifications + +### Template Code Issues +- **Avoid**: Template parameter naming conflicts with CUTLASS conventions +- **Avoid**: Missing SFINAE constraints for Intel Xe specializations +- **Avoid**: Breaking CuTe compile-time optimization patterns + +## Key Files and Components + +### Core Templates +- `include/cutlass/gemm/kernel/xe_gemm_universal.hpp`: Intel Xe GEMM kernels +- `include/cutlass/epilogue/threadblock/xe_epilogue.hpp`: Intel Xe epilogue operations +- `include/cutlass/arch/arch.h`: Architecture detection and constants + +### Python Interface +- `python/cutlass_library/generator.py`: Kernel generation logic +- `python/cutlass_library/arch_constants.py`: Architecture configuration +- `python/cutlass_library/cutlass_test_wrapper_source.cpp`: C++ wrapper + +### Build System +- `CMakeLists.txt`: Main build configuration +- `SYCL.cmake`: SYCL-specific build rules +- `.github/workflows/intel_test.yml`: CI configuration + +## Documentation Standards + +### Code Comments +- Document Intel Xe-specific optimizations and constraints +- Explain template parameter purposes and valid ranges +- Include usage examples for complex template patterns +- Reference Intel GPU programming guides where applicable + +### Commit Messages +- Use conventional commit format: `type(scope): description` +- Common types: `feat`, `fix`, `perf`, `refactor`, `test`, `docs` +- Include architecture scope when relevant: `feat(xe20): add BMG GEMM support` + +### Pull Request Guidelines +- Include performance impact analysis for kernel changes +- Test on multiple Intel GPU architectures when possible +- Document any breaking changes to template interfaces +- Provide before/after benchmarks for optimizations + +## Intel GPU-Specific Considerations + +### Memory Hierarchy +- L3 cache optimization for Intel Xe +- Shared local memory usage patterns +- Global memory coalescing requirements + +### Compute Capabilities +- DPAS instruction utilization +- Subgroup size considerations (16 or 32) +- Matrix engine programming patterns + +### Debugging and Profiling +- Use Intel VTune for performance analysis +- Intel Graphics Compiler optimization flags +- SYCL profiling and debugging tools + +## Integration Points + +### Python Ecosystem +- PyTorch XPU backend compatibility +- NumPy array interface support +- dpctl device management integration + +### Intel oneAPI Ecosystem +- oneMKL integration for reference implementations +- Level Zero runtime compatibility +- Intel GPU driver requirements + +## Version and Compatibility + +- **SYCL Version**: Follow Intel SYCL implementation standards +- **Intel oneAPI**: Support latest LTS and current releases +- **Python**: Maintain compatibility with Python 3.8+ +- **CMake**: Minimum version 3.18 for SYCL support + +When reviewing or contributing code, always consider the Intel GPU architecture implications, SYCL best practices, and template library design patterns that maintain the high-performance characteristics of the CUTLASS foundation while enabling optimal execution on Intel hardware. 
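The "SYCL-Specific Patterns" section above names the preferred constructs (`sycl::queue`, `sycl::buffer`/`sycl::accessor`, the subgroup programming model) without showing them together. The sketch below is generic SYCL 2020 usage for illustration only; it is not code from this repository, and real SYCL*TLA kernels build on CuTe atoms instead:

```cpp
// Generic SYCL 2020 illustration of the patterns listed above; not SYCL*TLA code.
#include <sycl/sycl.hpp>
#include <vector>

int main() {
  std::vector<float> data(1024, 1.0f);
  sycl::queue q;  // device work is submitted through a queue
  {
    sycl::buffer<float, 1> buf(data.data(), sycl::range<1>(data.size()));
    q.submit([&](sycl::handler& h) {
      sycl::accessor acc(buf, h, sycl::read_write);  // accessor-managed memory
      // One work-group of 64 work-items; Intel Xe sub-groups are typically 16 or 32 wide.
      h.parallel_for(sycl::nd_range<1>(sycl::range<1>(1024), sycl::range<1>(64)),
                     [=](sycl::nd_item<1> it) {
        auto sg = it.get_sub_group();  // subgroup programming model
        float v = acc[it.get_global_id(0)];
        // Broadcast lane 0's value across the sub-group and accumulate it.
        acc[it.get_global_id(0)] = v + sycl::group_broadcast(sg, v, 0);
      });
    });
  }  // buffer destruction waits for the kernel and copies results back to data
  return 0;
}
```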
\ No newline at end of file diff --git a/.github/workflows/intel_test.yml b/.github/workflows/intel_test.yml index 09bcd640ab..e32c8128e3 100644 --- a/.github/workflows/intel_test.yml +++ b/.github/workflows/intel_test.yml @@ -43,16 +43,22 @@ jobs: gpu: BMG intel_graphics: ROLLING sycl_target: intel_gpu_bmg_g21 + igc_version_major: 2 + igc_version_minor: 18 runner: bmg108629-01 - compiler: RELEASE gpu: PVC intel_graphics: ROLLING sycl_target: intel_gpu_pvc + igc_version_major: 2 + igc_version_minor: 11 runner: pvc146162-01 - compiler: NIGHTLY gpu: PVC intel_graphics: ROLLING sycl_target: intel_gpu_pvc + igc_version_major: 2 + igc_version_minor: 11 runner: pvc146162-01 name: Run Intel ${{ matrix.compiler }} tests on ${{ matrix.gpu }} with intel-graphics ${{ matrix.intel_graphics }} @@ -103,6 +109,8 @@ jobs: cmake -G Ninja \ -DCUTLASS_ENABLE_SYCL=ON \ -DDPCPP_SYCL_TARGET=${{ matrix.sycl_target }} \ + -DIGC_VERSION_MAJOR=${{ matrix.igc_version_major }} \ + -DIGC_VERSION_MINOR=${{ matrix.igc_version_minor }} \ -DCMAKE_CXX_FLAGS="-Werror" \ -DCUTLASS_SYCL_RUNNING_CI=ON cmake --build . diff --git a/.github/workflows/intel_test_gpp_host.yml b/.github/workflows/intel_test_gpp_host.yml index c75c9ad508..62c253324e 100644 --- a/.github/workflows/intel_test_gpp_host.yml +++ b/.github/workflows/intel_test_gpp_host.yml @@ -28,11 +28,15 @@ jobs: gpu: BMG intel_graphics: ROLLING sycl_target: intel_gpu_bmg_g21 + igc_version_major: 2 + igc_version_minor: 18 runner: bmg108629-01 - compiler: RELEASE gpu: PVC intel_graphics: ROLLING sycl_target: intel_gpu_pvc + igc_version_major: 2 + igc_version_minor: 11 runner: pvc146162-01 @@ -83,9 +87,12 @@ jobs: cmake -G Ninja \ -DCUTLASS_ENABLE_SYCL=ON \ -DDPCPP_SYCL_TARGET=${{ matrix.sycl_target }} \ + -DIGC_VERSION_MAJOR=${{ matrix.igc_version_major }} \ + -DIGC_VERSION_MINOR=${{ matrix.igc_version_minor }} \ -DCUTLASS_SYCL_RUNNING_CI=ON \ -DCMAKE_CXX_FLAGS="-Werror" \ - -DDPCPP_HOST_COMPILER=g++-13 + -DDPCPP_HOST_COMPILER=g++-13 \ + -DDPCPP_HOST_COMPILER_OPTIONS="-Werror" cmake --build . 
- name: Unit test diff --git a/CHANGELOG-SYCL.md b/CHANGELOG-SYCL.md index df341ceb58..08ad066fc3 100644 --- a/CHANGELOG-SYCL.md +++ b/CHANGELOG-SYCL.md @@ -1,6 +1,46 @@ -# CUTLASS SYCL Changelog +# SYCL*TLA (previously referred to as cutlass-sycl) Changelog -## [CUTLASS SYCL 0.5](https://github.com/intel/cutlass-sycl/releases/tag/v0.5) (2025-09-26) +## [SYCL*TLA 0.6](https://github.com/intel/sycl-tla/releases/tag/v0.6) (2025-11-03) +### Major Architecture Changes +- **Flash Attention Reimplementation ([#d02c58b](https://github.com/intel/sycl-tla/commit/d02c58b4))**: Complete rewrite of Flash Attention using new Xe atoms + - Enhanced performance with optimized memory access patterns + - Better integration with Intel Xe hardware capabilities +- **CUTLASS Library Generation ([#578](https://github.com/intel/sycl-tla/pull/578))**: Full support for CUTLASS library generation and operations + - New Xe architecture support in library generation pipeline + - Automated kernel instantiation and compilation support + +### Enhancements +- **Python Operations Support ([#595](https://github.com/intel/sycl-tla/pull/595))**: Enhanced Python bindings with comprehensive test coverage + - Improved Python API stability and usability + - Enhanced test framework for Python operations +- **CuTe Subgroup Extensions**: New subgroup-scope operations for Intel Xe + - Subgroup broadcast and reduction operations ([#9a6aa27](https://github.com/intel/sycl-tla/commit/9a6aa27c)) + - `make_subgroup_tensor` helpers for improved tensor manipulation ([#21fb89a](https://github.com/intel/sycl-tla/commit/21fb89a8)) +- **Enhanced 2D Copy Operations**: Extended block 2D copy functionality + - New `make_block_2d_copy_{C,D}` variants with subtiling support ([#48d82e8](https://github.com/intel/sycl-tla/commit/48d82e87)) + - Support for size-1 fragments in block 2D copies ([#2212f1b](https://github.com/intel/sycl-tla/commit/2212f1b9)) +- **4-bit VNNI Reorders ([#593](https://github.com/intel/sycl-tla/pull/593))**: New 4-bit unit stride to VNNI reorder operations +- **Batch GEMM with new APIs ([#540](https://github.com/intel/sycl-tla/pull/540))**: Enhanced Batch GEMM with new streamlined APIs +- **Grouped GEMM with new APIs ([#574](https://github.com/intel/sycl-tla/pull/574))**: Enhanced grouped GEMM with new streamlined APIs + +### Test Improvements +- **Python Test Coverage**: Comprehensive test suite improvements for Python operations +- **CI Infrastructure**: Enhanced continuous integration with PVC driver updates ([#575](https://github.com/intel/sycl-tla/pull/575)) +- **Code Reorganization**: Renamed `python/cutlass` to `python/cutlass_cppgen` for clarity ([#587](https://github.com/intel/sycl-tla/pull/587)) + +### Bug Fixes +- **Epilogue Data Type Fixes**: + - Fixed trD compute type in Xe Epilogue ([#580](https://github.com/intel/sycl-tla/pull/580)) + - Resolved epilogue data type mismatches ([#563](https://github.com/intel/sycl-tla/pull/563)) +- **CuTe Copy (new APIs) Improvements**: Multiple fixes for Xe copy operations ([#dec36a9](https://github.com/intel/sycl-tla/commit/dec36a9e)) +- **Split Barrier Refactoring**: Improved split barrier functionality for better reliability ([#521dfcd](https://github.com/intel/sycl-tla/commit/521dfcd4)) + +### Notes and Known Issues +- Python Operations for FP8 and INT8 are not generated for the CUTLASS library in this release. +- Unit tests and benchmark tests are not yet migrated to the newly re-architected CuTe APIs.
+ + +## [SYCL*TLA 0.5](https://github.com/intel/cutlass-sycl/releases/tag/v0.5) (2025-09-26) ### Major Architecture Changes - **Xe Rearchitecture ([#477](https://github.com/intel/cutlass-sycl/pull/477))**: Complete redesign of Xe CuTe atoms with new architecture - New MMA atoms for improved performance diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a897f95bd..4d4e8e7599 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,34 +2,93 @@ # CUTLASS 4.x -## [4.2.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-08-21) +## [4.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.1) (2025-09-22) ### CuTe DSL -* We will likely be skipping 4.2.dev release and directly target 4.2. -* CuTeDSL version remains at 4.1.0 till then. +* Bug fixings and improvements + - Fixed an issue when running DSL codes with cuda-python 13.0 + - Fixed an issue when running inductor with DSL codes + - Fixed an issue with unexpected logging when running DSL codes in FlashInfer + - Fixed the issue reported in https://github.com/NVIDIA/cutlass/issues/2647 + - Fixed an issue when conditional define of variables outside of dynamic control flow ### CUTLASS C++ -* Add K major scale factor support for Hopper SM90 blockwise kernels. +* Bypass EVT for nosmem blockwise kernels on Blackwell. +* Rename cutlass/python/cutlass directory to cutlass/python/cutlass_cppgen. + +## [4.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.2.0) (2025-09-15) + +### CuTe DSL +* More Python versions are now supported for both x86-64 and aarch64, including + - Python 3.10, 3.11, 3.12, and 3.13 +* Added new example and updated notebook to get started with CuTe DSL + - [Call kernels with dlpack bypassed](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py) + - Updates on [TensorSSA demonstration](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks/tensorssa.ipynb) + + Added a section for introducing the broadcast +* API updates + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details +* Bug fixings and improvements + - Fixed ``cute.print_tensor`` for coordinate tensor + - Fixed `cute.print` for tuple of layouts + - Fixed frozen object is not properly updated after fully assigned in dynamic control flow + - Fixed assign tuple/list element in a dynamic control flow may cause compilation failure + - Improved error message when CUDA context is not initialized + - Improved docstring of congruent and weakly_congruent + +### CUTLASS C++ +* Support for Blackwell SM103 kernels for B300 GPUs. + - Collective mainloop codes: [Blockscaled datatypes with support for dense GEMM mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp) + - New [GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/dispatch_policy.hpp) and [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/dispatch_policy.hpp) dispatch policies for collectives, kernel layers, and builders. + - Kernel codes: [Blockscaled datatypes with support for dense GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp). +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM103 architecture: + - [Blockscaled ultra fp4 dense GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/89_sm103_fp4_ultra_gemm/). 
+ - [Blockscaled ultra fp4 dense grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/90_sm103_fp4_ultra_grouped_gemm). +* Set of unit tests that demonstrate the usage of Blackwell SM103 blockscaled GEMM + - Unit test files with prefix name of `sm103_` under [GEMM device unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/). +* Support for Blackwell SM121 kernels for DGX Spark GPUs. + - Share the major codes with Blackwell SM120 kernels. +* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics` to find the best kernels for a given scenario. + - Details please refer to [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md). * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). - Add fused reduction kernel support for cutlass MLA. + - Add softmax skip correction. + - Support for GQA in FMHA backward kernel. - Fix an issue where `get_unmasked_trip_count` may return a negative value. - Fix an issue where mbarriers are initialized with a zero arrival count. -* Add Blackwell SM120 blockwise gemm kernel example: [example 87](https://github.com/NVIDIA/cutlass/tree/main/87_blackwell_geforce_gemm_blockwise/). -* Support for Blackwell SM100 cpasync kernel. - - Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp). - - Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp). -* Support for Blackwell SM121 kernels for DGX Spark GPUs. - - Share the major codes with Blackwell SM120 kernels. + - Fix a corner case issue where the sequence length of q is not a multiple of tile_q. + - Remove tma padding for forward kernel inputs. +* Add Blackwell SM100 kernels for MoEs (focusing on Low-Latency inference performance): [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm/). It uses TMA (for weights) and CPASYNC (for tokens) to load input matrices and allow only one problem dimension to vary across groups/experts, unlike general Grouped GEMMs. Note: further API simplifications and kernel improvements are upcoming. Any feedback on API is welcome. +* Further enhance blockwise and groupwise GEMMs on Hopper and Blackwell + - On Blackwell SM120, a blockwise gemm kernel is added: [example 87](https://github.com/NVIDIA/cutlass/tree/main/examples/87_blackwell_geforce_gemm_blockwise/). + - On Hopper, add K major scale factor support for SM90 blockwise kernels. + - On Hopper, relax the restriction that the k dimension of the problem size has to be the multiple of the k dimension of the tile size. + - On Hopper, grouped version supports the case when k = 0. +* Support for Blackwell SM100 fp4 gemv kernels. + - Kernel codes: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h). + - Example codes: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/) * Support for Blackwell SM100 legacy mixed input GEMM kernels. - Collective mainloop codes: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp). 
- Kernel codes: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp). - Example codes: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/). -* Support for Blackwell SM100 fp4 gemv kernels. - - Kernel codes: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h). - - Example codes: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/) +* Support for Blackwell SM100 cpasync kernel. + - Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp). + - Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp). +* Support Blackwell SM120 mixed input blockscaled grouped GEMM. +* Instantiating more Blackwell kernels in profiler. + - Blackwell SM100 and SM103 kernels support `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate all possible combinations. + - To use this feature, `CUTLASS_LIBRARY_KERNELS` must be non-empty. Profiler will combine `CUTLASS_LIBRARY_KERNELS` and `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` to instantiate specific kernels. + - Details please check [Profiler Doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/profiler.md). +* Fix some profiler issues: + - Modify default cluster callback values to none 0 to avoid profiler failure when these values are not set in command line. + - Fix some no output and timeout issues. + - Fix Pingpong Blockwise Hopper library generation. * From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110. - For CUDA toolkit version < 13.0, SM101 is still used for Thor GPUs. - For CUDA toolkit version >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid. +* Rename legacy Python API package from `cutlass` to `cutlass_cppgen` and add Blackwell EVT support to legacy Python interface. + - Restructuring the C++ Blackwell SM100 Collective Epilogue Builder to work with the Python interface's `EpilogueDescriptors`. + - Added Blackwell SM100 EVT Emitter on the Python side and routed most emission through Hopper SM90 Emitter. + - Added some support for running SM100 kernels via the Python interface. * CuTe changes: - Fix inaccurate GridDim calculation under [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/). - Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support. @@ -38,18 +97,15 @@ - Shorten `nullspace` implementation. - Isolate and comment on `cosize` hacks. - Important documentation correction: `E<0,1> == 1@0@1`. -* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics`. - - Details please refer to [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md). -* Rename legacy Python API package from `cutlass` to `cutlass_cppgen`. -* Fix some profiler issues: - - Modify default cluster callback values to none 0 to avoid profiler failure when these values are not set in command line. - - Fix some no output and timeout issues. +* Fix some kernel issues: + - Fix Hopper SM90 group gemm kernel to only use the commit group and wait group instead of also waiting on mbarriers. 
+ - Fix a tiny bug when K is large for Blackwell SM103 fp4 grouped GEMM kernel. * Add following unit tests: - [fp16 accmulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu) - [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu) - [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp16 narrow mma n](test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu) * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! -* Optimal code generation with CUDA toolkit versions 13.0. +* Optimal code generation with CUDA toolkit versions 13.0U1. ## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16) diff --git a/CMakeLists.txt b/CMakeLists.txt index f0110c4d71..c9f974cac2 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -125,24 +125,50 @@ option(CUTLASS_SYCL_BUILTIN_ENABLE "Enable this option to use builtin functions if (CUTLASS_ENABLE_SYCL) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) - if(DPCPP_SYCL_TARGET STREQUAL "nvptx64-nvidia-cuda") + string(REPLACE "," ";" DPCPP_SYCL_TARGET_LIST "${DPCPP_SYCL_TARGET}") + message(STATUS "DPCPP_SYCL_TARGET as list ${DPCPP_SYCL_TARGET_LIST}") + + if("nvptx64-nvidia-cuda" IN_LIST DPCPP_SYCL_TARGET_LIST) set(SYCL_NVIDIA_TARGET ON) add_compile_definitions(SYCL_NVIDIA_TARGET) - elseif(DPCPP_SYCL_TARGET STREQUAL "intel_gpu_pvc" OR - DPCPP_SYCL_TARGET STREQUAL "spir64" OR - DPCPP_SYCL_TARGET STREQUAL "intel_gpu_bmg_g21") - set(SYCL_INTEL_TARGET ON) - add_compile_definitions(SYCL_INTEL_TARGET) - add_compile_options(-Wall - -Wno-unused-variable - -Wno-unused-local-typedef - -Wno-unused-but-set-variable - -Wno-uninitialized - -Wno-reorder-ctor - -Wno-logical-op-parentheses - -Wno-unused-function - -Wno-unknown-pragmas - ) + list(REMOVE_ITEM DPCPP_SYCL_TARGET_LIST "nvptx64-nvidia-cuda") + endif() + + set(INTEL_SYCL_TARGETS + intel_gpu_pvc + spir64 + intel_gpu_bmg_g21 + bmg + pvc + ) + + foreach(ITEM IN LISTS DPCPP_SYCL_TARGET_LIST) + if(ITEM IN_LIST INTEL_SYCL_TARGETS) + set(SYCL_INTEL_TARGET ON) + else() + message(FATAL_ERROR "In SYCL target list '${DPCPP_SYCL_TARGET}' ${ITEM} is invalid.") + endif() + endforeach() + + if (SYCL_NVIDIA_TARGET AND SYCL_INTEL_TARGET) + message(FATAL_ERROR "Both SYCL_NVIDIA_TARGET and SYCL_INTEL_TARGET are set. 
Only one is allowed.") + elseif (NOT SYCL_NVIDIA_TARGET AND NOT SYCL_INTEL_TARGET) + message(FATAL_ERROR "Both SYCL_NVIDIA_TARGET and SYCL_INTEL_TARGET are unset.") + endif() + + if(SYCL_INTEL_TARGET) + add_compile_definitions(SYCL_INTEL_TARGET) + add_compile_options( + -Wall + -Wno-unused-variable + -Wno-unused-local-typedef + -Wno-unused-but-set-variable + -Wno-uninitialized + -Wno-reorder-ctor + -Wno-logical-op-parentheses + -Wno-unused-function + -Wno-unknown-pragmas + ) endif() add_compile_definitions(CUTLASS_ENABLE_SYCL) @@ -158,7 +184,6 @@ if (CUTLASS_ENABLE_SYCL) add_compile_definitions(CUTLASS_SYCL_BUILTIN_ENABLE) endif() - include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/onemkl.cmake) endif() find_package(Doxygen QUIET) @@ -201,7 +226,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library") if(CUTLASS_ENABLE_HEADERS_ONLY) set(CUTLASS_ENABLE_EXAMPLES_INIT OFF) - set(CUTLASS_ENABLE_TOOLS_INIT ON) + set(CUTLASS_ENABLE_TOOLS_INIT OFF) set(CUTLASS_ENABLE_LIBRARY_INIT OFF) set(CUTLASS_ENABLE_TESTS_INIT OFF) else() @@ -213,6 +238,9 @@ else() else() set(CUTLASS_ENABLE_TESTS_INIT OFF) endif() + + # Not include MKL when headers only + include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/onemkl.cmake) endif() set(CUTLASS_TEST_UNIT_ENABLE_WARNINGS OFF CACHE BOOL "Enable warnings on waived unit tests.") @@ -417,7 +445,7 @@ set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of opera set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.") set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.") set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).") -set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.") +set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 and SM100 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.") if(CUTLASS_LIBRARY_INSTANTIATION_LEVEL OR CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE) message(STATUS "Enable extended SM90 WGMMA instruction shapes for instantiation levels") diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index e6715dd490..0c725d3248 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -2,7 +2,7 @@ [README](./README.md#documentation) > **Contributors** -# CUTLASS SYCL Developers ** +# SYCL*TLA Developers ** Alejandro Acosta
Amit Singh Chandel
diff --git a/README.md b/README.md index 0f2e74ed4d..a4544af34a 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,15 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS SYCL 0.5 +# SYCL\* Templates for Linear Algebra (SYCL\*TLA) -**This repository fast-follows NVIDIA CUTLASS repository adding SYCL support for Intel GPUs.** +**This repository is forked from the NVIDIA CUTLASS repository and extends CUTLASS and CuTe API support to Intel GPUs through SYCL enablement.** +*This project was previously referred to as CUTLASS-SYCL, you may see references to CUTLASS-SYCL in the code and documentation.* +*For SYCL support instructions, refer to the [SYCL build documentation](./media/docs/cpp/build/building_with_sycl_support.md)* -**For SYCL support instructions, refer to the [SYCL build documentation](./media/docs/cpp/build/building_with_sycl_support.md)** +*SYCL is a trademark of the Khronos Group Inc, Other names and brands may be claimed as the property of others.* +[![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/intel/sycl-tla/badge)](https://scorecard.dev/viewer/?uri=github.com/intel/sycl-tla) -[![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/intel/cutlass-sycl/badge)](https://scorecard.dev/viewer/?uri=github.com/intel/cutlass-sycl) - -CUTLASS‑SYCL is a modular, header‑only C++ template framework for high‑performance +SYCL\*TLA is a modular, header‑only C++ template framework for high‑performance GEMM, and fused epilogue kernels. It applies hierarchical tiling, composable policy abstractions, and efficient data‑movement primitives to build flexible, reusable building blocks for dense linear algebra. The SYCL implementation brings those @@ -16,13 +17,13 @@ optimizations to Intel GPUs with tuned kernels for modern execution units and me hierarchies. It adds mixed‑precision and epilogue fusion pathways designed to simplify integrating advanced quantization and post‑processing into custom pipelines. -To support a wide variety of applications, CUTLASS-SYCL provides extensive +To support a wide variety of applications, SYCL\*TLA provides extensive support for mixed-precision computations on Intel hardware, providing specialized data-movement and multiply-accumulate abstractions for FP64, FP32, FP16, BF16, 8b floating point types (E5M2 and E4M3 for FP8), narrow integer types (4 and 8b signed and unsigned integers with support for zero-point quantization), and mixed-precision operations with tensor-wise, channel-wise, -and group-wise quantization support. CUTLASS-SYCL demonstrates optimal matrix +and group-wise quantization support. SYCL\*TLA demonstrates optimal matrix multiply operations targeting Intel's programmable, high-throughput execution units implemented in Intel Data Center GPU Max/Flex Series (Intel Xe architecture, codename: Ponte-Vecchio) and Intel Arc B580 GPUs. @@ -33,41 +34,48 @@ See the [functionality docs](./media/docs/cpp/functionality.md) for a more compr list of kernel level features, data types, instructions, and minimum supported by CUTLASS on each GPU architecture. -Base NVIDIA CUTLASS Versions for CUTLASS-SYCL releases: -| CUTLASS SYCL | NVIDIA CUTLASS | +This project fast follows NVIDIA CUTLASS releases to ensure parity of APIs and features. 
+ +Base NVIDIA CUTLASS Versions for SYCL*TLA releases: +| SYCL*TLA | NVIDIA CUTLASS | |-----------------|----------| |0.1| 3.9| |0.2 | 3.9.2 | |0.3 | 3.9.2 | |0.5 | 4.2.0 | +|0.6 | 4.2.0 | -# What's New in CUTLASS SYCL 0.5 +# What's New in SYCL*TLA 0.6 +## [SYCL*TLA 0.6](https://github.com/intel/sycl-tla/releases/tag/v0.6) (2025-11-03) ### Major Architecture Changes -- **Xe Rearchitecture ([#477](https://github.com/intel/cutlass-sycl/pull/477))**: Complete redesign of Xe CuTe atoms with new architecture - - New MMA atoms for improved performance - - Enhanced 2D copy atoms (loads, stores, prefetch with VNNI/transpose support) - - New 2D copy helpers (low-level `make_block_2d_copy` and high-level `make_block_2d_copy_{A,B,C}`) - - Generic and optimized reorder atoms for {int4, uint4, int8, uint8, e2m1, e4m3, e5m2} -> {half, bfloat16} - - Requires IGC version [v2.18.5](https://github.com/intel/intel-graphics-compiler/releases/tag/v2.18.5) or later - -### New Features -- **G++ Host Compiler Support ([#490](https://github.com/intel/cutlass-sycl/pull/490))**: Support for G++ 13 as host compiler -- Migrated `syclcompat` to this repository as `cutlasscompat` for better compatibility - - Fixed compilation issues when using G++ instead of clang++ - - Added new CI workflow for testing G++ host compiler builds - - Enhanced build system to support `-DDPCPP_HOST_COMPILER=g++` option -- **Grouped GEMM for Mixed Dtype ([#457](https://github.com/intel/cutlass-sycl/pull/457))**: Extended grouped GEMM support to mixed precision operations - - Added support for BF16 + S8 mixed dtype grouped GEMM - - Added support for FP16 + U4 mixed dtype grouped GEMM - - New examples: `10_bmg_grouped_gemm_bf16_f16_s8.cpp` and `10_bmg_grouped_gemm_f16_u4.cpp` +- **Flash Attention Reimplementation ([#d02c58b](https://github.com/intel/sycl-tla/commit/d02c58b4))**: Complete rewrite of Flash Attention using new Xe atoms + - Enhanced performance with optimized memory access patterns + - Better integration with Intel Xe hardware capabilities +- **CUTLASS Library Generation ([#578](https://github.com/intel/sycl-tla/pull/578))**: Full support for CUTLASS library generation and operations + - New Xe architecture support in library generation pipeline + - Automated kernel instantiation and compilation support + +### Enhancements +- **Python Operations Support ([#595](https://github.com/intel/sycl-tla/pull/595))**: Enhanced Python bindings with comprehensive test coverage + - Improved Python API stability and usability + - Enhanced test framework for Python operations +- **CuTe Subgroup Extensions**: New subgroup-scope operations for Intel Xe + - Subgroup broadcast and reduction operations ([#9a6aa27](https://github.com/intel/sycl-tla/commit/9a6aa27c)) + - `make_subgroup_tensor` helpers for improved tensor manipulation ([#21fb89a](https://github.com/intel/sycl-tla/commit/21fb89a8)) +- **Enhanced 2D Copy Operations**: Extended block 2D copy functionality + - New `make_block_2d_copy_{C,D}` variants with subtiling support ([#48d82e8](https://github.com/intel/sycl-tla/commit/48d82e87)) + - Support for size-1 fragments in block 2D copies ([#2212f1b](https://github.com/intel/sycl-tla/commit/2212f1b9)) +- **4-bit VNNI Reorders ([#593](https://github.com/intel/sycl-tla/pull/593))**: New 4-bit unit stride to VNNI reorder operations +- **Batch GEMM with new APIs ([#540](https://github.com/intel/sycl-tla/pull/540))**: Enhanced Batch GEMM with new streamlined APIs +- **Grouped GEMM with new APIs 
([#574](https://github.com/intel/sycl-tla/pull/574))**: Enhanced grouped GEMM with new streamlined APIs **See the [CHANGELOG](CHANGELOG-SYCL.md) for details of all past releases and updates.** # CuTe -CUTLASS-SYCL supports the newly introducted core library, CuTe, to describe and manipulate tensors of threads and data. -CuTe in CUTLASS-SYCL is a collection of C++ SYCL template abstractions for +SYCL\*TLA supports the newly introduced core library, CuTe, to describe and manipulate tensors of threads and data. +CuTe in SYCL\*TLA is a collection of C++ SYCL template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. @@ -81,7 +89,7 @@ The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning. -CUTLASS-SYCL and beyond adopts CuTe throughout the GEMM hierarchy in its templates. +SYCL\*TLA and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md). @@ -97,7 +105,7 @@ Minimum requirements: ## Hardware Support -CUTLASS-SYCL runs successfully on the following Intel GPUs. +SYCL*TLA runs successfully on the following Intel GPUs. |**GPU**|**Intel GPU Architecture** |---|---| @@ -119,7 +127,7 @@ We are regularly testing following setup in CI. ## Target Architecture -The target architecture information is passed on to CUTLASS-SYCL via the cmake flag +The target architecture information is passed on to SYCL*TLA via the cmake flag `DPCPP_SYCL_TARGET`. ``` @@ -156,13 +164,13 @@ CUTLASS is described in the following documents and the accompanying # Resources -# Building CUTLASS-SYCL +# Building SYCL*TLA -CUTLASS-SYCL is a header-only template library and does not need to be built to be used by other -projects. Client applications should target CUTLASS-SYCL's `include/` directory in their include +SYCL*TLA is a header-only template library and does not need to be built to be used by other +projects. Client applications should target SYCL*TLA's `include/` directory in their include paths. -CUTLASS-SYCL unit tests, examples, and utilities can be built with CMake. +SYCL*TLA unit tests, examples, and utilities can be built with CMake. The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md). Make sure you have Intel oneAPI DPC++ compiler installed and the environment is properly set up. @@ -170,7 +178,7 @@ Make sure you have Intel oneAPI DPC++ compiler installed and the environment is $ source /opt/intel/oneapi/setvars.sh ``` -Create a build directory within the CUTLASS-SYCL project, then run CMake. You need to specify +Create a build directory within the SYCL*TLA project, then run CMake. You need to specify the target Intel GPU architecture using the `DPCPP_SYCL_TARGET` flag. For Intel Data Center GPU Max Series (Ponte Vecchio), use `intel_gpu_pvc`. For Intel Arc GPU B580 Graphics, use `intel_gpu_bmg_g21`. 
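As a brief illustration of the `Layout`/`Tensor` abstractions described in the CuTe section of this README — generic, host-side CuTe usage for orientation only, not a kernel from this repository:

```cpp
// Generic CuTe illustration of Layout/Tensor and functional composition;
// not SYCL*TLA kernel code.
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  // A static 4x8 column-major layout: shape (4,8) with strides (1,4).
  auto layout = make_layout(make_shape(Int<4>{}, Int<8>{}),
                            make_stride(Int<1>{}, Int<4>{}));

  // A Tensor packages a pointer with that layout and performs the indexing.
  float storage[32] = {};
  auto tensor = make_tensor(&storage[0], layout);
  tensor(2, 3) = 1.0f;  // element at logical coordinate (2,3)

  // Functional composition: view the same data as 2x2 tiles,
  // shaped ((coords within a tile), (which tile)) without moving anything.
  auto tiled = zipped_divide(tensor, make_shape(Int<2>{}, Int<2>{}));

  print(layout);         print("\n");
  print(tiled.layout()); print("\n");
  return 0;
}
```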
@@ -193,9 +201,9 @@ To compile with G++ as host compiler, add the flag `-DDPCPP_HOST_COMPILER=g++-13 $ CC=icx CXX=icpx cmake .. -G Ninja -DCUTLASS_ENABLE_SYCL=ON -DDPCPP_HOST_COMPILER=g++-13 -DDPCPP_SYCL_TARGET="intel_gpu_bmg_g21" # compiles for Intel Arc GPU B580 Graphics with G++ as host compiler ``` -From the `build/` directory, compile and run the CUTLASS-SYCL unit tests by building the target `test_unit` with make. +From the `build/` directory, compile and run the SYCL*TLA unit tests by building the target `test_unit` with make. -The unit tests are organized as several binaries mirroring the top-level namespaces of CUTLASS-SYCL, +The unit tests are organized as several binaries mirroring the top-level namespaces of SYCL*TLA, and they may be executed in parallel via make's `-j` command line argument. ```bash @@ -213,12 +221,12 @@ All tests should pass on supported Intel GPU platforms, though the exact number # Project Structure -CUTLASS-SYCL is arranged as a header-only library along with Utilities, Tools, Examples, and unit tests. +SYCL*TLA is arranged as a header-only library along with Utilities, Tools, Examples, and unit tests. -A detailed explanation of the source code organization may be found in the -[CUTLASS-SYCL documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below. +A detailed explanation of the source code organization may be found in the +[SYCL*TLA documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below. -## CUTLASS-SYCL Template Library +## SYCL*TLA ``` include/ # client applications should target this directory in their build's include paths @@ -263,23 +271,23 @@ include/ # client applications should target this directory ``` -### CUTLASS SDK Examples +### SYCL*TLA Examples -[CUTLASS SDK examples](./examples) apply CUTLASS templates to implement basic computations. +[SYCL*TLA examples](./examples) apply SYCL*TLA templates to implement basic computations. ### Tools ``` tools/ - library/ # CUTLASS-SYCL Instance Library - contains instantiations of all supported CUTLASS-SYCL templates + library/ # SYCL*TLA Instance Library - contains instantiations of all supported SYCL*TLA templates include/ cutlass/ library/ - profiler/ # CUTLASS Profiler - SYCL support not yet available + profiler/ # Profiler - SYCL support not yet available # (command-line utility for executing operations) - util/ # CUTLASS-SYCL Utilities - contains numerous helper classes for + util/ # Utilities - contains numerous helper classes for include/ # managing tensors in Intel GPU device memory, reference cutlass/ # implementations for SYCL GEMM, random initialization util/ # of tensors, and I/O for Intel GPU environments. @@ -294,12 +302,40 @@ Instructions for building and running the Unit tests are described in the [Quick # About -CUTLASS-SYCL is released by INTEL Corporation as Open Source software under the +SYCL*TLA is released by INTEL Corporation as Open Source software under the [3-clause "New" BSD license](LICENSE.txt). # Contributors -The official list of CUTLASS-SYCL developers and contributors is available here: [CONTRIBUTORS](CONTRIBUTORS.md). +The official list of SYCL*TLA developers and contributors is available here: [CONTRIBUTORS](CONTRIBUTORS.md). 
+ +# Contributing + +## Pull Request Templates + +We provide concise PR templates to streamline documentation: + +### Quick Start + +**GitHub CLI:** +```bash +gh pr create --template .github/PULL_REQUEST_TEMPLATE/bug_fix.md +gh pr create --template .github/PULL_REQUEST_TEMPLATE/performance.md +gh pr create --template .github/PULL_REQUEST_TEMPLATE/feature.md +gh pr create --template .github/PULL_REQUEST_TEMPLATE/refactoring.md +``` + +**GitHub Web:** Add `?template=.md` to PR URL (e.g., `?template=bug_fix.md`) + +### Which Template? + +- 🐛 **Bug fixes** → `bug_fix.md` - Root cause + verification +- ⚡ **Performance** → `performance.md` - Profiling data + benchmarks +- ✨ **Features** → `feature.md` - API design + examples +- 🔨 **Refactoring** → `refactoring.md` - Refactored/Redesigned code +- 📝 **Mixed/Other** → Default template + +See [`.github/PULL_REQUEST_TEMPLATE/README.md`](.github/PULL_REQUEST_TEMPLATE/README.md) for details. # Copyright diff --git a/SECURITY.md b/SECURITY.md index 87f1750e7f..56b64f8994 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -2,7 +2,4 @@ ## Reporting a Vulnerability -To report a vulnerability or a security issue please fill the security -advisories form [here](../../security/advisories/new), send an email to -security@codeplay.com or contact us using the [contact form on our web -page](https://codeplay.com/company/contact/?q=Report%20Security%20Issue). +Please report any security vulnerabilities in this project utilizing the guidelines [here](https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html). diff --git a/applications/flash_attention_v2/collective/copy_block_slm.hpp b/applications/flash_attention_v2/collective/copy_block_slm.hpp new file mode 100644 index 0000000000..41b34a2140 --- /dev/null +++ b/applications/flash_attention_v2/collective/copy_block_slm.hpp @@ -0,0 +1,169 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +namespace cute { + +/* Flat copies */ +template +CUTE_HOST_DEVICE +void +copy_block_r2s(Tensor const& src, + Tensor & dst) +{ + static_assert(is_rmem_v && is_smem_v, "Expected rmem->smem copy"); + + auto atom_r2s = Copy_Atom, float>{}; // TODO: larger block messages + + auto atom_shape = make_shape(_1{}, size(src)); + auto src_v = src.compose(make_layout(atom_shape)); + auto dst_v = dst.compose(make_layout(atom_shape, Stride<_0, _16>{})); + + copy(atom_r2s, src_v, dst_v); +} + +template +CUTE_HOST_DEVICE +void +copy_block_s2r(Tensor const& src, + Tensor & dst) +{ + static_assert(is_smem_v && is_rmem_v, "Expected smem->rmem copy"); + + auto atom_s2r = Copy_Atom, float>{}; + + auto atom_shape = make_shape(_1{}, size(dst)); + auto src_v = src.compose(make_layout(atom_shape, Stride<_0, _16>{})); + auto dst_v = dst.compose(make_layout(atom_shape)); + + copy(atom_s2r, src_v, dst_v); +} + +/* Coordinate-aware copies */ +template +CUTE_HOST_DEVICE +void +copy_block_r2s(SubgroupTensor const& src, + Tensor & dst, + DstCoordLayout const& dst_c) +{ + using _SG = intel::_SGSize; + + static_assert(is_rmem_v && is_smem_v, "Expected rmem->smem copy"); + static_assert(sizeof_bits_v == 32, "Only 32-bit data supported"); + + auto atom_r2s = Copy_Atom, float>{}; // TODO: larger block messages + + auto atom_shape = make_shape(_1{}, size(SrcLayout{})); + + auto src_c_wi0 = composition(project_strides(SrcCoordLayout{}), make_layout(atom_shape, Stride<_0, _SG>{})); + auto rlayout = composition(right_inverse(project_strides(dst_c)), src_c_wi0); + + auto src_v = src.compose(make_layout(atom_shape)); + auto dst_v = dst.compose(rlayout); + + copy(atom_r2s, src_v, dst_v); +} + +template +CUTE_HOST_DEVICE +void +copy_block_s2r(Tensor const& src, + SrcCoordLayout const& src_c, + SubgroupTensor & dst) +{ + using _SG = intel::_SGSize; + + static_assert(is_smem_v && is_rmem_v, "Expected smem->rmem copy"); + static_assert(sizeof_bits_v == 32, "Only 32-bit data supported"); + + auto atom_s2r = Copy_Atom, float>{}; + + auto atom_shape = make_shape(_1{}, size(DstLayout{})); + + auto dst_c_wi0 = composition(project_strides(DstCoordLayout{}), make_layout(atom_shape, Stride<_0, _SG>{})); + auto rlayout = composition(right_inverse(project_strides(src_c)), dst_c_wi0); + + auto src_v = src.compose(rlayout); + auto dst_v = dst.compose(make_layout(atom_shape)); + + copy(atom_s2r, src_v, dst_v); +} + +/* Variants accepting rvalue dst */ +template +CUTE_HOST_DEVICE +void +copy_block_r2s(Tensor const& src, + Tensor && dst) +{ + return copy_block_r2s(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_block_s2r(Tensor const& src, + Tensor && dst) +{ + return copy_block_s2r(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_block_r2s(SubgroupTensor const& src, + Tensor && dst, + DstCoordLayout const& dst_c) +{ + return copy_block_r2s(src, dst, dst_c); +} + +template 
+CUTE_HOST_DEVICE +void +copy_block_s2r(Tensor const& src, + SrcCoordLayout const& src_c, + SubgroupTensor && dst) +{ + return copy_block_s2r(src, dst); +} + +} /* namespace cute */ \ No newline at end of file diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp new file mode 100644 index 0000000000..efa54931d3 --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/algorithm/subgroup_algorithms.hpp" +#include "cute/algorithm/tensor_algorithms.hpp" + +#include "copy_block_slm.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template // Optional TiledCopy for loading O +class FMHAFwdEpilogue { + +public: + // + // Type Aliases + // + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); + using TileShapeO = TileShapeO_; + using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAPV::ThrLayoutVMNK{})))); + + using TensorO = TensorO_; + using TensorO2D = decltype(TensorO_{}(append>(make_coord(_,_),0))); + using ElementO = typename TensorO_::value_type; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + using ElementA = typename FragA::value_type; + + // Split k-reduced tiles between participating subgroups. 
+ // Assumption: the A tile is contiguous. + using ReduceK = decltype(size<3>(typename TiledMMAPV::ThrLayoutVMNK{})); + + static auto reduce_sg_v_helper() { + constexpr auto v_total_sg = get<1>(SGTileShapeA{}) / intel::_SGSize{}; + constexpr auto v_avail_sg = ReduceK{} / ReduceSGQ{}; + return Int<(v_total_sg > v_avail_sg) ? cute::gcd(v_total_sg, v_avail_sg) : v_total_sg>{}; + } + + using SGTileShapeA = decltype(atuple_coshape(FragA{}.tv_layout())); + using ReduceSGQ = decltype(cute::gcd(get<0>(SGTileShapeA{}), ReduceK{})); + using ReduceSGV = decltype(reduce_sg_v_helper()); + using ReduceSGLayout = decltype(make_identity_layout(Shape{})); + + using SGTileShapeO = decltype(shape_div(take<0,2>(SGTileShapeA{}), shape(ReduceSGLayout{}))); + + using ReduceFragA = decltype(make_subgroup_tensor( + make_layout(select<1,0>(SGTileShapeO{}), + Stride, E<0>>{}) + )); + using ReduceFragARow = decltype(reduce<1>(ReduceFragA{}, sycl::plus{})); + + static auto default_tiled_copy_O_helper() { + if constexpr (ReduceK{} == _1{}) + return make_block_2d_copy_D(TiledMMAPV{}, TensorO2D{}); + else + return make_block_2d_copy_D_subtiled(TiledMMAPV{}, ReduceFragA{}.tv_layout(), ReduceSGLayout{}, TensorO2D{}); + } + + using DefaultTiledCopyO = decltype(default_tiled_copy_O_helper()); + using TiledCopyO = conditional_t, DefaultTiledCopyO, TiledCopyO_>; + + // Stateless design -- no arguments or parameters. + struct Arguments {}; + struct Params {}; + + // Shared memory storage + // Note sum/max tiles are padded to 16 elements, due to limitations in CuTe block load infrastructure. + using AlignedSGTileA_Q = C<((size<0>(SGTileShapeA{}) + intel::sg_size - 1) / intel::sg_size) * intel::sg_size>; + + struct SharedStorageNone {}; + struct SharedStorageReduceK { + cute::array a_data; + cute::array a_sum_data, a_max_data; + }; + + using SharedStorage = conditional_t<(ReduceK{} > _1{}), SharedStorageReduceK, SharedStorageNone>; + +private: + SharedStorage &shared; + +public: + static constexpr + Params to_underlying_arguments(Arguments const &args, void * /* workspace */) { + return {}; + } + + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const&) { + return true; + } + + CUTLASS_HOST_DEVICE + FMHAFwdEpilogue(Params const&, SharedStorage& shared_) : shared(shared_) {} + + template + CUTLASS_DEVICE + void + operator()(TensorO2D const& O, // Global O tensor: (q,v) + FragA & tArA, // O accumulator: (q,v) + FragARow & tA_max, // Softmax row-wise max accumulator + FragARow & tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (q,v) + int thr_id) { // Work-item ID + + using namespace cute; + using ElementA = typename FragA::element_type; + + // Reduce k-blocks of A and A_sum across WG, if needed. + auto [rA, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + /* Some subgroups may not have any work to do; if so, quit early. */ + if (!active) return; + + /* Complete softmax, dividing out sums. 
*/ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum.size(); i++) + rA_sum(i) = ElementA(1) / rA_sum(i); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA.size(); i++) + rA(i) *= broadcast<0>(rA_sum, rA, i); + + /* Tile output */ + Tensor cO = make_identity_tensor(O.shape()); // (q,v) + Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v) + + /* Prepare slices */ + TiledCopyO copy_o{O}; + auto thr_copy_o = copy_o.get_slice(thr_id); + + auto tOrO = thr_copy_o.partition_sg_fragment_S(gO); + auto tOgO = thr_copy_o.partition_D(gO); + + /* Reorder tile and write out */ + reorder(rA, tOrO); + copy(copy_o, tOrO, tOgO); + } + + // Reduce k-blocks of A and A_sum across WG, if needed. + // Note that each k block has its own scale factor based on A_max, + // so A/A_sum contributions need to be rescaled to match. + template + CUTLASS_DEVICE + decltype(auto) + reduce_A(FragA & tArA, // O accumulator: (q,v) + FragARow & tA_max, // Softmax row-wise max accumulator + FragARow & tA_sum, // Softmax row-wise sum accumulator + int thr_id) { // Work-item ID + + using namespace sycl::ext::oneapi::this_work_item; + + if constexpr (ReduceK{} == _1{}) { + return std::make_tuple(tArA, tA_sum, true); + } else { + /* Identify A tile ID and k block for this subgroup. */ + auto thr_vak = group<1,3>(TiledMMAPV{}.get_thr_layout_vmnk()).get_flat_coord(assert_uniform(thr_id)); + auto a_tile = get<1>(thr_vak); + auto k_blk = get<2>(thr_vak); + + /* Set up SLM tensors and partition A tiles among participating subgroups */ + auto shape_A = append(append(SGTileShapeA{}, ReduceK{}), SGPerWG{}/ReduceK{}); + auto shape_A_row = make_shape(get<0>(SGTileShapeO{}), shape(ReduceSGLayout{}), ReduceK{}, SGPerWG{}/ReduceK{}); + + /* Physical layouts, with subtile modes broken out */ + auto sA_layout = group<2,4>(flat_divide(make_ordered_layout(shape_A, Step<_1,_0,_2,_3>{}), SGTileShapeO{})); + auto sA_row_stride = make_stride(_1{}, make_stride(get<0>(shape_A_row), _0{}), + AlignedSGTileA_Q{}, AlignedSGTileA_Q{} * ReduceK{}); + auto sA_row_layout = make_layout(shape_A_row, sA_row_stride); + + /* Coordinate layouts, with subtile modes broken out */ + auto basis2 = make_basis_like(SGTileShapeO{}); + auto sA_coords = make_layout(append(SGTileShapeO{}, shape(ReduceSGLayout{})), + append(basis2, product_each(zip(SGTileShapeO{}, basis2)))); + + auto sA = make_tensor(make_smem_ptr(&shared.a_data), sA_layout); // (q,v,rblk_dst,rblk_src,a_tile) + auto sA_max = make_tensor(make_smem_ptr(&shared.a_max_data), sA_row_layout); // (q,rblk_dst,rblk_src,a_tile) + auto sA_sum = make_tensor(make_smem_ptr(&shared.a_sum_data), sA_row_layout); // (q,rblk_dst,rblk_src,a_tile) + + /* Write my contributions to SLM. */ + copy_block_r2s(tA_max, sA_max(_,_,k_blk,a_tile)); + barrier_arrive(ScopeWorkgroup, SemanticsRelease | SemanticsWGMemory); + copy_block_r2s(tA_sum, sA_sum(_,_,k_blk,a_tile)); + copy_block_r2s(tArA, sA(_,_,_,k_blk,a_tile), sA_coords); + + bool active = (k_blk < size(ReduceSGLayout{})) + || (ReduceK{} == size(ReduceSGLayout{})); // help compiler out + + /* Wait for maxima to be available, signal other data available */ + barrier_wait(ScopeWorkgroup, SemanticsAcquire | SemanticsWGMemory); + barrier_arrive(ScopeWorkgroup, SemanticsRelease | SemanticsWGMemory); + + ReduceFragA rA; + ReduceFragARow rA_sum, rA_max, rA_kmax[ReduceK{}]; + + if (active) { + /* Read A_max back from SLM and reduce. 
*/ + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + copy_block_s2r(sA_max(_,k_blk,kr,a_tile), rA_kmax[kr]); + } + + rA_max = rA_kmax[0]; + for (int kr = 1; kr < ReduceK{}; kr++) + cute::transform(rA_max, rA_kmax[kr], rA_max, cute::max_fn{}); + + /* Calculate scale factors for aligning per-block maxima. */ + for (int kr = 0; kr < ReduceK{}; kr++) { + cute::transform(rA_max, rA_kmax[kr], rA_kmax[kr], [](auto gmax, auto kmax) { + return sycl::native::exp2(kmax - gmax); + }); + } + } + + /* Wait for A/A_sum data to be available */ + barrier_wait(ScopeWorkgroup, SemanticsAcquire | SemanticsWGMemory); + + if (active) { + /* Read A/A_sum back from SLM, align scaling to new maxima, and reduce. */ + clear(rA_sum); + + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + ReduceFragARow rA_sum_read; + copy_block_s2r(sA_sum(_,k_blk,kr,a_tile), rA_sum_read); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum_read.size(); i++) { + rA_sum(i) += rA_sum_read(i) * rA_kmax[kr](i); + } + } + + clear(rA); + + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + ReduceFragA rA_read; + copy_block_s2r(sA(_,_,k_blk,kr,a_tile), sA_coords(_,_,0), rA_read); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_read.size(); i++) { + rA(i) += rA_read(i) * broadcast<0>(rA_kmax[kr], rA, i); + } + } + } + return std::make_tuple(rA, rA_sum, active); + } + } +}; + + +} // namespace cutlass::fmha::collective diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp new file mode 100644 index 0000000000..b7c400a63a --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -0,0 +1,428 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/algorithm/subgroup_algorithms.hpp" +#include "cute/atom/mma_atom.hpp" +#include "fmha_fusion.hpp" + +namespace cutlass::fmha { + +template class XeDefault {}; // Default FMHA mainloop, P in registers. + +}; + +namespace cutlass::fmha::collective { + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template // Optional TiledCopy for loading V +struct FMHAFwdMainloop { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FMHAFwdMainloop, CausalMask_, + TiledMMAQK_, TiledMMAPV_, VTiles_, + TensorQ_, TensorK_, TensorV_, + TiledCopyQ_, TiledCopyK_, TiledCopyV_> { + // + // Type Aliases + // + using TiledMMAQK = TiledMMAQK_; + using TiledMMAPV = TiledMMAPV_; + using TileShapeQK = decltype(TiledMMAQK{}.tile_mnk()); + using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); + static constexpr int VTiles = VTiles_; + using SubgroupLayoutQK = decltype(TiledMMAQK{}.get_atom_layout_mnk()); + using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAQK::ThrLayoutVMNK{})))); + + using TensorQ = TensorQ_; + using TensorK = TensorK_; + using TensorV = TensorV_; + + using TensorQ2D = decltype(TensorQ_{}(append>(make_coord(_,_),0))); + using TensorK2D = decltype(TensorK_{}(append>(make_coord(_,_),0))); + using TensorV2D = decltype(TensorV_{}(append>(make_coord(_,_),0))); + + using TiledCopyQ = conditional_t, decltype(make_block_2d_copy_A(TiledMMAQK{}, TensorQ2D{})), TiledCopyQ_>; + using TiledCopyK = conditional_t, decltype(make_block_2d_copy_B(TiledMMAQK{}, TensorK2D{})), TiledCopyK_>; + using TiledCopyV = conditional_t, decltype(make_block_2d_copy_B(TiledMMAPV{}, TensorV2D{})), TiledCopyV_>; + + // TODO: static_asserts on TiledMMAPV here... + + // + // Accumulator types + // + // FragS: accumulator for Q*K MMA + // FragO: accumulator for P*V MMAs. + // Note: v mode may be split into multiple pieces + // to reduce register pressure. + // Frag*Row types are reductions of the corresponding Frag* types + // over rows. 
+ // + template + using FragC = decltype(TiledMMA{}.get_slice(0).partition_sg_fragment_C( + make_identity_tensor(select<0,1>(TiledMMA{}.tile_mnk())))); + + using FragS = FragC; + using FragSRow = decltype(reduce<1>(FragS{}, sycl::plus{})); + using ElementS = typename TiledMMAQK::ValTypeD; + + using SingleFragA = FragC; // (atom val,q',v') + using FragA = expand_sg_fragment_t; // (atom val,q',v',VV) + using FragARow = decltype(reduce<1>(FragA{}, sycl::plus{})); + using ElementA = typename TiledMMAPV::ValTypeD; + + static constexpr bool CausalMask = CausalMask_; + + // User-facing arguments + struct Arguments { + ElementS const scale; + }; + + // Kernel-facing parameters + using Params = Arguments; + + // SLM data + struct SharedStorage {}; + + Params params; + + // + // Methods + // + + FMHAFwdMainloop(Params const& params_, SharedStorage&) : params(params_) {} + + static constexpr + Params to_underlying_arguments(Arguments const &args, void * /* workspace */) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) + ElementS val = args.scale * static_cast(kLog2e); + return Params{val}; + } + + CUTLASS_HOST_DEVICE static + bool can_implement(Arguments const&) { + return true; + } + + template + CUTLASS_DEVICE + void + operator()(TensorQ2D const& Q_2D, // (q,d) + TensorK2D const& K_2D, // (k,d) + TensorV2D const& V_2D, // (d,k) + FragA & tArA, // Output accumulator (q,v) + FragARow & tA_max, // Softmax row-wise max accumulator + FragARow & tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (Q,V) + int blk_k0, // K block range: [K0,K1) + int blk_k1, + int total_blk, // Total # of K blocks + int thr_id, + int seq_len, + int full_tile_offset, + int discard_seq_coord) { + using namespace sycl::ext::oneapi::this_work_item; + + // Short dimension names: + // q = sequence len dimension for Q + // k = sequence len dimension for K + // d = head size dimension for K/Q + // v = head size dimension for V + // VV = MMA tile indices for V + // Capital letters (Q, K, ...) refer to WG block indices. + // Primed letters (q', k', ...) refer to atom block indices. 
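> Editor's note: the k-blocked loop later in this `operator()` implements the standard online-softmax (flash-attention) recurrence. The scalar reference below is a simplified sketch, not part of the kernel; tiling, subgroup layouts, causal/remainder masking, and the log2(e) factor folded into `params.scale` (which lets the device code use `exp2`) are all omitted, and the final division by the running sum is deferred to the epilogue, as in `FMHAFwdEpilogue`.

```cpp
// Scalar reference for one output row; illustrative only.
#include <algorithm>
#include <cmath>
#include <vector>

// score_blocks[b][k]      : attention scores of block b (already scaled)
// value_blocks[b][k*V + v]: flattened (k,v) value tile for block b
// out                     : output accumulator of size V
void online_softmax_row(const std::vector<std::vector<float>>& score_blocks,
                        const std::vector<std::vector<float>>& value_blocks,
                        std::vector<float>& out) {
  const size_t V = out.size();
  std::fill(out.begin(), out.end(), 0.0f);
  float m = -INFINITY, l = 0.0f;                      // running max and sum
  for (size_t b = 0; b < score_blocks.size(); ++b) {
    float m_new = m;
    for (float s : score_blocks[b]) m_new = std::max(m_new, s);
    float rescale = std::exp(m - m_new);              // realign previous blocks
    l *= rescale;
    for (float& o : out) o *= rescale;
    for (size_t k = 0; k < score_blocks[b].size(); ++k) {
      float p = std::exp(score_blocks[b][k] - m_new); // softmax numerator
      l += p;
      for (size_t v = 0; v < V; ++v) out[v] += p * value_blocks[b][k * V + v];
    }
    m = m_new;
  }
  for (float& o : out) o /= l;  // in the kernel this division happens in the epilogue
}
```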
+ + auto tile_shape_v = make_shape(get<1>(TileShapePV{}) * C{}, get<2>(TileShapePV{})); + + /* Create proxy coordinate tensors for Q/K/P/V */ + Tensor cQ = make_identity_tensor(Q_2D.shape()); // (q,d) + Tensor cK = make_identity_tensor(K_2D.shape()); // (k,d) + Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) + Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k) + + /* Partition global tensors into workgroup tiles */ + Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D) + Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step{}); // (k,d,K,D) + Tensor gV = local_tile(cV, tile_shape_v, make_coord(get<1>(blk_qv),_)); // (v,k,K) + Tensor gV_split = local_tile(gV, TileShapePV{}, make_coord(_,_,0), Step{}); // (v,k,VV,K) + + /* Create global -> register copies */ + TiledCopyQ copy_q{Q_2D}; + TiledCopyK copy_k{K_2D}; + TiledCopyV copy_v{V_2D}; + + /* Create MMAs */ + TiledMMAQK mma_qk{}; + TiledMMAPV mma_pv{}; + + /* Slice TiledCopy/TiledMMA operations down to to work-item level */ + auto thr_copy_q = copy_q.get_slice(thr_id); + auto thr_copy_k = copy_k.get_slice(thr_id); + auto thr_copy_v = copy_v.get_slice(thr_id); + auto thr_mma_qk = mma_qk.get_slice(thr_id); + auto thr_mma_pv = mma_pv.get_slice(thr_id); + + /* Partition coordinate tensors for copy */ + auto tQgQ = thr_copy_q.partition_S(gQ); // (atom_val,q',d',D) + auto tKgK = thr_copy_k.partition_S(gK); // (atom_val,k',d',K,D) + auto tVgV = thr_copy_v.partition_S(gV_split); // (atom_val,v',k',VV,K) + + /* Create register fragments for MMA and copies */ + auto tQrQ = thr_copy_q.partition_sg_fragment_D(gQ(_,_,0)); + auto tSrQ = thr_mma_qk.partition_sg_fragment_A(gQ(_,_,0)); + + auto tKrK = thr_copy_k.partition_sg_fragment_D(gK(_,_,0,0)); + auto tSrK = thr_mma_qk.partition_sg_fragment_B(gK(_,_,0,0)); + + auto tSrS = thr_mma_qk.partition_sg_fragment_C(cP); + auto tArP = thr_mma_pv.partition_sg_fragment_A(cP); + + auto tVrV = thr_copy_v.partition_sg_fragment_D(gV_split(_,_,0,0)); + auto tArV = thr_mma_pv.partition_sg_fragment_B(gV_split(_,_,0,0)); + + /* Create TiledCopy objects for prefetches */ + auto prefetch_q = make_block_2d_prefetch(copy_q); + auto prefetch_k = make_block_2d_prefetch(copy_k); + auto prefetch_v = make_block_2d_prefetch(tile_shape_v, V_2D); + + /* Partition global tensors for prefetch */ + auto pQgQ = prefetch_q.get_slice(thr_id).partition_S(gQ); + auto pKgK = prefetch_k.get_slice(thr_id).partition_S(gK); + auto pVgV = prefetch_v.get_slice(thr_id).partition_S(gV); + + // ------ + // Kernel + // ------ + + /* Initialization steps for first block: Q/K prefetch, O init */ + /* TODO: limit D prefetch for large head size, and reorder K prefetches */ + if (blk_k0 == 0) { + for (int D = 0; D < size<3>(pQgQ); D++) { + prefetch(prefetch_q, pQgQ(_,_,_,D)); + } + + for (int D = 0; D < size<4>(pKgK); D++) { + CUTLASS_PRAGMA_UNROLL + for (int K = 0; K < Stages; K++) { + prefetch(prefetch_k, pKgK(_,_,_,K,D)); + } + } + + clear(tArA); + fill(tA_max, cutlass::platform::numeric_limits::lowest()); + clear(tA_sum); + } + + /* Check if */ + bool check_remainder_k = (seq_len % get<1>(TileShapeQK{}) != 0); + + /* Main loop, blocked in k. 
*/ + for (int K = blk_k0; K < blk_k1; K++) { + /* Split barrier to keep threads together */ + barrier_arrive(ScopeWorkgroup); + + /* GEMM 1: S = K * Q */ + clear(tSrS); /* TODO: fuse w/ initial gemm call */ + for (int D = 0; D < size<4>(tKgK); D++) { + copy(copy_q, tQgQ(_,_,_,D), tQrQ); + copy(copy_k, tKgK(_,_,_,K,D), tKrK); + + reorder(tQrQ, tSrQ); + reorder(tKrK, tSrK); + + cute::gemm(mma_qk, tSrQ, tSrK, tSrS); + } + + /* V prefetch for GEMM 2 */ + prefetch(prefetch_v, pVgV(_,_,_,K)); + + /* Causal masking */ + if constexpr (CausalMask) { + if (K == blk_k1 - 1) { + // Need to get global col and row indices to mask the elements + Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len)); + Tensor gP = local_tile(cPgP, take<0,2>(TileShapeQK{}), make_coord(get<0>(blk_qv), K)); + auto cS_thread = thr_mma_qk.partition_C(gP); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tSrS.size(); ++i) { + int row_idx = get<0>(cS_thread(i)); + int col_idx = get<1>(cS_thread(i)); + if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + tSrS(i) = ElementS(-INFINITY); + } + } + } + } + /* k masking for remainder tiles */ + if (check_remainder_k && K == total_blk - 1) { + FragSRow k_rem_mask; + int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) { + k_rem_mask(i) = (k < seq_len) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tSrS.size(); i++) { + tSrS(i) = sycl::fmin(tSrS(i), broadcast<1>(k_rem_mask, tSrS, i)); + } + } + + /* Apply softmax and scaling */ + softmax(K == 0, tSrS, tA_max, tA_sum, tArA); + reorder(tSrS, tArP); + + /* GEMM 2: A += P * V, split in v dimension */ + CUTLASS_PRAGMA_UNROLL + for (int VV = 0; VV < VTiles; VV++) { + copy(copy_v, tVgV(_,_,_,VV,K), tVrV); + reorder(tVrV, tArV); + cute::gemm(mma_pv, tArP, tArV, tArA(_,_,_,VV)); + } + + /* K prefetch */ + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_,_,_,K+Stages,D)); + } + + barrier_wait(ScopeWorkgroup); + } + } + + // Single step of blocked softmax. + CUTLASS_DEVICE + void + softmax(bool first_block, // First softmax block? 
+ FragS & tS, // Softmax src/dst block + FragSRow & tS_max, // Softmax row-wise max accumulator + FragSRow & tS_sum, // Softmax row-wise sum accumulator + FragA & tA) { // O accumulator (for rescaling) + + /* Compute row-wise maxima for this block */ + auto tS_bmax = reduce<1>(tS, sycl::maximum{}); + + /* Update (scaled) maxima */ + auto tS_prev_max = tS_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS_max.size(); i++) { + tS_max(i) = sycl::max(tS_max(i), params.scale * tS_bmax(i)); + } + + /* Scale S and subtract maxima, then exponentiate */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS.size(); i++) + tS(i) = sycl::native::exp2(params.scale * tS(i) - broadcast<0>(tS_max, tS, i)); + + /* Rescale existing S sums and O accumulator */ + if (!first_block) { + FragSRow rescale; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS_max.size(); i++) { + rescale(i) = sycl::native::exp2(tS_prev_max(i) - tS_max(i)); + tS_sum(i) *= rescale(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tA.size(); i++) + tA(i) *= broadcast<0>(rescale, tA, i); + } + + /* Update sums */ + auto tS_bsum = reduce<1>(tS, sycl::plus{}); + for (int i = 0; i < tS_sum.size(); i++) + tS_sum(i) += tS_bsum(i); + } +}; + + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_sg_layout_pv(SGLayoutQK const&) +{ + return make_layout( + get<0>(SGLayoutQK{}), + Layout<_1, _0>{}, + get<1>(SGLayoutQK{}) + ); +} + +// Get a P*V TiledMMA given K*Q tile size and SG configuration, for mainloops +// not supporting S data interchange among subgroups (e.g. XeDefault). +template +CUTLASS_HOST_DEVICE +constexpr auto +get_tiled_mma_pv(MMAOp const&, WGTileQK const& wg_tile_qk, SGLayoutQK const& sg_layout_qk, TileV const&) { + using TileQ = decltype(get<0>(wg_tile_qk)); + using TileK = decltype(get<1>(wg_tile_qk)); + + using WGTilePV = Shape; + using SGLayoutPV = decltype(get_sg_layout_pv(sg_layout_qk)); + + static_assert(size(SGLayoutPV{}) == size(SGLayoutQK{}), + "Q*K cannot be parallelized in the head size dimension"); + + return TiledMMAHelper{}; +} + +} // namespace cutlass::fmha::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp new file mode 100644 index 0000000000..f5905f746a --- /dev/null +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -0,0 +1,700 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp" +#include "flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp" +#include "cute/util/type_traits.hpp" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/kernel/xe_tile_scheduler.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// +template +struct FMHAProblemShape { + using SeqLenType = cute::conditional_t; + int batch; + int num_heads_q, num_heads_kv; + SeqLenType seq_len_qo, seq_len_kv; + int head_size_qk, head_size_vo; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class XeFMHAFwdKernel { + +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using VariableLength = cutlass::fmha::collective::VariableLength; + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK; + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK; + using ElementQ = typename CollectiveMainloop::TensorQ::element_type; + using ElementK = typename CollectiveMainloop::TensorK::element_type; + using ElementV = typename CollectiveMainloop::TensorV::element_type; + + using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{})); + using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{})); + using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{})); + + using SGPerWG = typename CollectiveMainloop::SGPerWG; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + + // Tile scheduler derived types + using TileScheduler = TileScheduler_; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename 
CollectiveEpilogue::Params; + + using TileShapeO = typename CollectiveEpilogue::TileShapeO; + using ElementO = typename CollectiveEpilogue::TensorO::element_type; + using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); + + // Kernel level shared memory storage + using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; + using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; + union SharedStorage { + MainloopSharedStorage mainloop; + EpilogueSharedStorage epilogue; + }; + + static constexpr int SharedStorageSize = is_empty_v ? size_t(0) + : sizeof(SharedStorage); + + // Device side arguments + struct KernelArguments { + ProblemShape shape; + const ElementQ *Q; + StrideQ dQ; + const ElementK *K; + StrideK dK; + const ElementV *V; + StrideV dV; + ElementO *O; + StrideO dO; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + }; + + // Kernel entry point API + struct Params { + KernelParams kernel; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + static Params to_underlying_arguments(Arguments const &args, void *workspace) { + return {args.kernel, + CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{})}; + } + + static bool can_implement(Arguments const &args) { + return CollectiveMainloop::can_implement(args.mainloop) + && CollectiveEpilogue::can_implement(args.epilogue); + } + + static int get_workspace_size(Arguments const &args) { return 0; } + + static cutlass::Status initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}, batch); + } else { + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + } + } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) + { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto &p = params.kernel; + ProblemShape const& s = p.shape; + int head_group_q = s.num_heads_q / s.num_heads_kv; + + int thr_id = int(ThreadIdxX()); + int sub_group_id = thr_id / intel::sg_size; + int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); + + auto cS = make_identity_tensor(take<0,2>(TiledMMAQK{}.tile_mnk())); + auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS); + auto q_offset_wi = get<0>(tScS(0)); + auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); + + TileScheduler tile_scheduler{params.scheduler}; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto [blk_q, blk_v, head_q, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b) + 
auto blk_qv = make_coord(blk_q, blk_v); + int head = head_q / head_group_q; + + auto sequence_length_shape = get_sequence_length_shape(s, idx_b); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; + + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + int seq_coord = cute::min(seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); + + if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; + const int seq_len = CollectiveMainloop::CausalMask ? full_tile_offset + cute::min(seq_len_kv, seq_coord - discard_seq_coord) + q_sg_tile : seq_len_kv; + const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); + + int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0; + if constexpr (is_var_len) { + int group_heads_q = s.num_heads_q / s.num_heads_kv; + auto qo_cumulative = s.seq_len_qo.cumulative_length; + auto kv_cumulative = s.seq_len_kv.cumulative_length; + offset_q = s.num_heads_q * s.head_size_qk * qo_cumulative[idx_b]; + offset_k = s.num_heads_kv * s.head_size_qk * kv_cumulative[idx_b]; + offset_v = s.num_heads_kv * s.head_size_vo * kv_cumulative[idx_b]; + offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b]; + } + + auto batch_dim = is_var_len ? 1 : s.batch; + auto shape_Q = make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim); + auto shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); + auto shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); + auto shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_kv, batch_dim); + + auto dcQ = const_cast(p.Q + offset_q); + auto dcK = const_cast(p.K + offset_k); + auto dcV = const_cast(p.V + offset_v); + auto ptrO = p.O + offset_o; + + auto stride_q = is_var_len ? cutlass::make_cute_packed_stride(StrideQ{}, shape_Q) : p.dQ; + auto stride_k = is_var_len ? cutlass::make_cute_packed_stride(StrideK{}, shape_K) : p.dK; + auto stride_v = is_var_len ? cutlass::make_cute_packed_stride(StrideV{}, shape_V) : p.dV; + auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO; + + Tensor Q = make_tensor(make_gmem_ptr(dcQ), make_layout(shape_Q, stride_q)); + Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, stride_k)); + Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, stride_v)); + Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); + + + // O accumulator types + FragA tArA; + FragARow tA_max, tA_sum; + + // Main loop + int l_coord = is_var_len ? 
0 : idx_b; + CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); + mainloop(Q(_,_,head_q,l_coord), + K(_,_,head,l_coord), + V(_,_,head,l_coord), + tArA, tA_max, tA_sum, + blk_qv, 0, k_blocks, k_blocks, + thr_id, seq_len, + full_tile_offset, discard_seq_coord); + if constexpr (!is_empty_v && !is_empty_v) { + sycl::group_barrier(get_work_group<3>()); + } + + // Epilogue + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + epilogue(O(_,_,head_q,l_coord), + tArA, tA_max, tA_sum, + blk_qv, thr_id); + } + } +}; + +template +class XeFMHAFwdDynamicSplitKernel { + +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK; + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + + using ElementQ = typename CollectiveMainloop::TensorQ::element_type; + using ElementK = typename CollectiveMainloop::TensorK::element_type; + using ElementV = typename CollectiveMainloop::TensorV::element_type; + + using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{})); + using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{})); + using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{})); + + using SGPerWG = typename CollectiveMainloop::SGPerWG; + + using FragA = typename CollectiveMainloop::FragA; + using SingleFragA = typename CollectiveMainloop::SingleFragA; + using FragARow = typename CollectiveMainloop::FragARow; + // element dtype for MmaPV results + using ElementA = typename CollectiveMainloop::ElementA; + + // Tile scheduler derived types + static_assert(is_same_v); + using TileScheduler = TileScheduler_; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileShapeO = typename CollectiveEpilogue::TileShapeO; + using ElementO = typename CollectiveEpilogue::TensorO::element_type; + using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); + + // Kernel level shared memory storage + using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; + using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; + union SharedStorage { + MainloopSharedStorage mainloop; + EpilogueSharedStorage epilogue; + }; + + static constexpr int SharedStorageSize = is_empty_v ? 
size_t(0) + : sizeof(SharedStorage); + + // Important: make sure multiple of 16 element for each copy + // this is for storing partial results from different KV partitions + static constexpr int num_elem_per_thread = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16; + static const int max_num_partitions = 8; + + // Device side arguments + struct KernelArguments { + ProblemShape shape; + const ElementQ *Q; + StrideQ dQ; + const ElementK *K; + StrideK dK; + const ElementV *V; + StrideV dV; + ElementO *O; + StrideO dO; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + }; + + // Kernel entry point API + struct Params { + KernelParams kernel; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + // workspace for storing partial results of different KV partitions + ElementA *partial_results_ptr = nullptr; + // for atomic add + int32_t *atomic_reduce_cnt_ptr = nullptr; + }; + + // + // Methods + // + + static Params to_underlying_arguments(Arguments const &args, void *workspace) { + int num_batch_heads = args.kernel.shape.batch * args.kernel.shape.num_heads_q; + int32_t *atomic_reduce_cnt_ptr = reinterpret_cast(workspace); + ElementA *partial_results_ptr = reinterpret_cast(atomic_reduce_cnt_ptr + num_batch_heads); + return {args.kernel, + CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}), + partial_results_ptr, atomic_reduce_cnt_ptr + }; + } + + static bool can_implement(Arguments const &args) { + // current kernel only support decode + if (args.kernel.shape.seq_len_qo > 1) { + return false; + } + // current kernel only support num batch heads less than total XeCore count + if (args.kernel.shape.batch * args.kernel.shape.num_heads_q > args.hw_info.sm_count) { + return false; + } + return CollectiveMainloop::can_implement(args.mainloop) + && CollectiveEpilogue::can_implement(args.epilogue); + } + + static int get_workspace_size(Arguments const &args) { + int ws_size = 0; + int num_batch_heads = args.kernel.shape.batch * args.kernel.shape.num_heads_q; + const int wg_size = SGPerWG::value * intel::sg_size; + + // partial attn outputs, exp sum and max logits + ws_size += (max_num_partitions * num_batch_heads) * wg_size * num_elem_per_thread * sizeof(ElementA); + // atomic counter + ws_size += num_batch_heads * sizeof(int32_t); + return ws_size; + } + + static cutlass::Status initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + int num_batch_heads = args.kernel.shape.batch * args.kernel.shape.num_heads_q; + compat::fill(reinterpret_cast(workspace), (int32_t)0, num_batch_heads); + auto partial_ws_count = (get_workspace_size(args) - num_batch_heads * sizeof(int32_t)) / sizeof(ElementA); + auto* partial_results_ptr = reinterpret_cast(reinterpret_cast(workspace) + num_batch_heads); + compat::fill(partial_results_ptr, (ElementA)0, partial_ws_count); + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } + + CUTLASS_DEVICE + int get_partition_id(const 
int cur_wg_id, const int batch_head_id, const int num_blocks_per_wg, const int local_k_blocks) { + int partition_id = 0; + if (batch_head_id == 0) { + return cur_wg_id; + } + int start_wg_id = batch_head_id * local_k_blocks / num_blocks_per_wg; + partition_id = cur_wg_id - start_wg_id; + return partition_id; + } + + CUTLASS_DEVICE + int get_num_partitions(const int batch_head_id, const int num_blocks_per_wg, const int local_k_blocks) { + int num_partitions = 1; + int start_wg_id = batch_head_id * local_k_blocks / num_blocks_per_wg; + int end_wg_id = (batch_head_id + 1) * local_k_blocks / num_blocks_per_wg; + num_partitions = end_wg_id - start_wg_id + 1; + // end_wg_id is the starting wg id of next batch head id + if (((batch_head_id + 1) * local_k_blocks) % num_blocks_per_wg == 0) { + num_partitions -= 1; + } + return num_partitions; + } + + template + CUTLASS_DEVICE + void reduce_split2(const Params ¶ms, FragA &out1, FragARow& max_val1, FragARow& exp_sum_val1, FragA &out2, FragARow& max_val2, FragARow& exp_sum_val2) { + // global max value + FragARow max_prev1 = max_val1; + FragARow max_prev2 = max_val2; + + auto scale = params.mainloop.scale; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < max_val1.size(); i++) { + max_val1(i) = sycl::max(max_val1(i), max_val2(i)); + } + + FragARow rescale1, rescale2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < max_val1.size(); i++) { + rescale1(i) = sycl::native::exp2(max_prev1(i) - max_val1(i)); + rescale2(i) = sycl::native::exp2(max_prev2(i) - max_val1(i)); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < exp_sum_val1.size(); i++) { + exp_sum_val1(i) = exp_sum_val1(i) * rescale1(i) + exp_sum_val2(i) * rescale2(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < out1.size(); i++) + out1(i) = out1(i) * broadcast<0>(rescale1, out1, i) + out2(i) * broadcast<0>(rescale2, out2, i); + } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) + { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto &p = params.kernel; + ProblemShape const& s = p.shape; + int head_group_q = s.num_heads_q / s.num_heads_kv; + + int thr_id = int(ThreadIdxX()); + int wg_id = int(BlockIdxZ()); + + int sg_id = thr_id / intel::sg_size; + int tid_in_sg = thr_id % intel::sg_size; + int num_batch_heads = s.batch * s.num_heads_q; + + int local_k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{})); + // total number of blocks need to be processed across all wgs + int total_k_blocks = local_k_blocks * num_batch_heads; + // to guarantee all wg process similar number of blocks of KV + int num_blocks_per_wg = cute::ceil_div(total_k_blocks, GridDimZ()); + + TileScheduler tile_scheduler{params.scheduler, get<1>(TileShapeQK{}), local_k_blocks, num_batch_heads}; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto [blk_q, blk_v, start_batch_head_id] = tile_scheduler.get_block_coord(); // (Q,V, batch_head_idx) + auto blk_qv = make_coord(blk_q, blk_v); + + auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch); + auto shape_K = make_shape(s.seq_len_kv, s.head_size_qk, s.num_heads_kv, s.batch); + auto shape_V = make_shape(s.head_size_vo, s.seq_len_kv, s.num_heads_kv, s.batch); + auto shape_O = make_shape(s.seq_len_qo, s.head_size_vo, s.num_heads_kv, s.batch); + + auto dcQ = const_cast(p.Q); // de-const these for uniformity + auto dcK = const_cast(p.K); + auto dcV = const_cast(p.V); + + Tensor Q = make_tensor(make_gmem_ptr(dcQ), 
make_layout(shape_Q, p.dQ)); // (q,d,h,b) + Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, p.dK)); // (k,d,h,b) + Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, p.dV)); // (v,k,h,b) + Tensor O = make_tensor(make_gmem_ptr(p.O), make_layout(shape_O, p.dO)); // (q,v,h,b) + + // O accumulator types + FragA tArA; + FragARow tA_max, tA_sum; + + // compute num computed blocks for start batch head id + int num_computed_blocks = (start_batch_head_id == 0) ? (wg_id * num_blocks_per_wg) : (wg_id * num_blocks_per_wg - start_batch_head_id * local_k_blocks); + int start_blk, end_blk, head_q, idx_b, head_kv; + // leader wg is also responsible for reducing partial results, while other + // worker wg only to compute partial results + bool is_leader_wg = wg_id < num_batch_heads; + + if (thr_id == 0 && is_leader_wg) { + // reset atomic counter before computation + *(params.atomic_reduce_cnt_ptr + wg_id) = 0; + } + + // Main loop + CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); + + // compute blocks budget remained for each wg + int block_budget_remained = num_blocks_per_wg; + int batch_head_id = start_batch_head_id; + bool is_update_batch_head_id = false; + while (block_budget_remained > 0) { + int num_new_blocks = local_k_blocks - num_computed_blocks; + if (num_new_blocks <= block_budget_remained) { + // finished current batch head id + start_blk = num_computed_blocks; + end_blk = start_blk + num_new_blocks; + + // update states + num_computed_blocks = 0; + block_budget_remained -= num_new_blocks; + is_update_batch_head_id = true; + } else { + // budget cannot afford finishing current batch head id + start_blk = num_computed_blocks; + end_blk = start_blk + block_budget_remained; + + block_budget_remained = 0; + is_update_batch_head_id = false; + } + + head_q = batch_head_id % s.num_heads_q; + idx_b = batch_head_id / s.num_heads_q; + head_kv = head_q / head_group_q; + // mainloop + mainloop(Q(_,_,head_q,idx_b), + K(_,_,head_kv,idx_b), + V(_,_,head_kv,idx_b), + tArA, tA_max, tA_sum, + blk_qv, start_blk, end_blk, local_k_blocks, + thr_id, s.seq_len_kv, /*for causal*/0, 0); + + // partition id of start batch head id in current wg + int partition_id = get_partition_id(wg_id, batch_head_id, num_blocks_per_wg, local_k_blocks); + + // store partial result: tArA, tA_max and tA_sum + int offset = batch_head_id * max_num_partitions * num_elem_per_thread * SGPerWG::value * intel::sg_size + + partition_id * num_elem_per_thread * SGPerWG::value * intel::sg_size + + sg_id * intel::sg_size * num_elem_per_thread + + tid_in_sg * num_elem_per_thread; + Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int{})); + Tensor merged_res = make_tensor(Int{}); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size(FragA{}.shape()); ++i) { + merged_res(i) = tArA(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(FragARow{}.shape()); ++i) { + merged_res(2 * i + size(FragA{}.shape())) = tA_max(i); + merged_res(2 * i + 1 + size(FragA{}.shape())) = tA_sum(i); + } + copy(merged_res, tPartial); + + // after store, set atomic cnt + if (thr_id == 0) { + atomicAdd(params.atomic_reduce_cnt_ptr + batch_head_id, 1); + } + + // advance to next batch head id + if (is_update_batch_head_id) { + batch_head_id += 1; + if (batch_head_id >= num_batch_heads) { + break; + } + } + } + + if (is_leader_wg) { + int num_partitions = get_num_partitions(wg_id, num_blocks_per_wg, local_k_blocks); + + // check atomic to wait for partial results ready + 
while(atomicLoad(params.atomic_reduce_cnt_ptr + wg_id) != num_partitions) {} + + clear(tArA); + clear(tA_max); + clear(tA_sum); + + for (int i = 0; i < num_partitions; ++i) { + int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thread + + i * SGPerWG::value * intel::sg_size * num_elem_per_thread + + sg_id * intel::sg_size * num_elem_per_thread + + tid_in_sg * num_elem_per_thread; + Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int{})); + Tensor merged_res = make_tensor(Int{}); + copy(tPartial, merged_res); + + if (i == 0) { + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size(FragA{}.shape()); ++i) { + tArA(i) = merged_res(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(FragARow{}.shape()); ++i) { + tA_max(i) = merged_res(2 * i + size(FragA{}.shape())); + tA_sum(i) = merged_res(2 * i + 1 + size(FragA{}.shape())); + } + + continue; + } + + FragA tArA_2; + FragARow tA_max_2, tA_sum_2; + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size(FragA{}.shape()); ++i) { + tArA_2(i) = merged_res(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(FragARow{}.shape()); ++i) { + tA_max_2(i) = merged_res(2 * i + size(FragA{}.shape())); + tA_sum_2(i) = merged_res(2 * i + 1 + size(FragA{}.shape())); + } + + reduce_split2(params, tArA, tA_max, tA_sum, tArA_2, tA_max_2, tA_sum_2); + } + + // require group barrier if using SLM + if constexpr (!is_empty_v && !is_empty_v) { + sycl::group_barrier(get_work_group<3>()); + } + + head_q = wg_id % s.num_heads_q; + idx_b = wg_id / s.num_heads_q; + head_kv = head_q / head_group_q; + + // Epilogue + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + epilogue(O(_,_,head_q,idx_b), + tArA, tA_max, tA_sum, + blk_qv, thr_id); + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp new file mode 100644 index 0000000000..24a686993c --- /dev/null +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -0,0 +1,163 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +struct XeFHMAIndividualTileScheduler { + + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeFHMAIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& shape, KernelHardwareInfo hw_info, + TileShape const& tile_shape) + { + using namespace cute; + + dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V + size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q + size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + return Params{grid, {shape.num_heads_q}}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int idx_b = BlockIdxZ(); + int head; + params.divmod_num_heads(idx_b, head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b); + } + + CUTLASS_DEVICE + XeFHMAIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +struct XeFHMAIndividualPersistentTileScheduler { + + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + int kv_tile_size_; + // num of kv blocks for each head + int local_num_kv_blocks_; + int num_batch_heads_; + + CUTLASS_DEVICE + XeFHMAIndividualPersistentTileScheduler(Params const& params, int kv_tile_size, + int local_num_kv_blocks, int num_batch_heads) + : params(params), kv_tile_size_(kv_tile_size), local_num_kv_blocks_(local_num_kv_blocks), num_batch_heads_(num_batch_heads) {} + + template + static Params to_underlying_arguments( + ProblemShape const& shape, KernelHardwareInfo hw_info, + TileShape const& tile_shape) + { + using namespace cute; + + dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V + size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q + size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + int num_heads = shape.num_heads_q; + grid.z = hw_info.sm_count; + + return Params{grid, {num_heads}}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int wg_id = BlockIdxZ(); + + // total number of blocks need to be processed across all wgs + int total_num_kv_blocks = local_num_kv_blocks_ * num_batch_heads_; + // guarantee all wg process similar number of blocks of KV (load balance) + int num_blocks_per_wg = cute::ceil_div(total_num_kv_blocks, GridDimZ()); + + // compute start batch head id 
for current wg + int start_batch_head_id = wg_id * num_blocks_per_wg / local_num_kv_blocks_; + + return make_coord(BlockIdxY(), BlockIdxX(), start_batch_head_id); + } + + CUTLASS_DEVICE + XeFHMAIndividualPersistentTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/benchmarks/gemm/benchmark_runner.hpp b/benchmarks/gemm/benchmark_runner.hpp index 69dc80a741..7fbca3a76b 100644 --- a/benchmarks/gemm/benchmark_runner.hpp +++ b/benchmarks/gemm/benchmark_runner.hpp @@ -172,7 +172,7 @@ struct BenchmarkRunnerGemm { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveMainloop = typename Gemm::GemmKernel::CollectiveMainloop; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; @@ -187,7 +187,6 @@ struct BenchmarkRunnerGemm { using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -660,7 +659,7 @@ struct BenchmarkRunnerGemm { stride_S, block_zero.get(), stride_Z, 128}; } - arguments.epilogue = {{ElementAcc(options.alpha), ElementAcc(options.beta)}, block_C[0].get(), stride_C, block_D.get(), stride_D}; + arguments.epilogue = {{ElementAccumulator(options.alpha), ElementAccumulator(options.beta)}, block_C[0].get(), stride_C, block_D.get(), stride_D}; arguments.hw_info = hw_info; if constexpr(epi_is_deeltactmul){ @@ -748,7 +747,7 @@ struct BenchmarkRunnerGemm { gemm::GemmUniversalMode::kGemm, problem_size, {block_A[input_num].get(), stride_A, block_B[input_num].get(), stride_B}, - {{ElementAcc(options.alpha), ElementAcc(options.beta)}, block_C[input_num].get(), stride_C, block_D.get(), stride_D}, + {{ElementAccumulator(options.alpha), ElementAccumulator(options.beta)}, block_C[input_num].get(), stride_C, block_D.get(), stride_D}, hw_info }; if constexpr (is_mixed_dtype) { diff --git a/cmake/FindDPCPP.cmake b/cmake/FindDPCPP.cmake index d92070c72c..3013c2302f 100644 --- a/cmake/FindDPCPP.cmake +++ b/cmake/FindDPCPP.cmake @@ -41,41 +41,65 @@ add_library(DPCPP::DPCPP INTERFACE IMPORTED) set(DPCPP_FLAGS "-fsycl;") if(DPCPP_HOST_COMPILER) list(APPEND DPCPP_FLAGS "-fsycl-host-compiler=${DPCPP_HOST_COMPILER}") - list(APPEND DPCPP_FLAGS "-fsycl-host-compiler-options=-Wno-changes-meaning -D$, -D> -I$, -I>") + set(_host_opts "-Wno-changes-meaning $<$>:-fPIC> -D$, -D> -I$, -I>") + if(DEFINED DPCPP_HOST_COMPILER_OPTIONS AND NOT "${DPCPP_HOST_COMPILER_OPTIONS}" STREQUAL "") + set(_host_opts "${DPCPP_HOST_COMPILER_OPTIONS} ${_host_opts}") + string(STRIP "${_host_opts}" _host_opts) + endif() + list(APPEND DPCPP_FLAGS "-fsycl-host-compiler-options=${_host_opts}") endif() set(DPCPP_COMPILE_ONLY_FLAGS "") set(DPCPP_LINK_ONLY_FLAGS "") -if(NOT "${DPCPP_SYCL_TARGET}" STREQUAL "") - list(APPEND DPCPP_FLAGS "-fsycl-targets=${DPCPP_SYCL_TARGET};") -endif() - option(DPCPP_DISABLE_ITT_FOR_CUTLASS "Disables linking of the Instrumentation and Tracing Technology (ITT) device libraries for VTune" ON) if(NOT "${DPCPP_USER_FLAGS}" STREQUAL "") list(APPEND DPCPP_FLAGS "${DPCPP_USER_FLAGS};") endif() +string(REPLACE "," ";" DPCPP_SYCL_TARGET_LIST "${DPCPP_SYCL_TARGET}") + if(NOT "${DPCPP_SYCL_ARCH}" STREQUAL "") - 
if("${DPCPP_SYCL_TARGET}" STREQUAL "nvptx64-nvidia-cuda") + if(SYCL_NVIDIA_TARGET) + list(APPEND DPCPP_FLAGS "-fsycl-targets=nvptx64-nvidia-cuda;") list(APPEND DPCPP_FLAGS "-Xsycl-target-backend") list(APPEND DPCPP_FLAGS "--cuda-gpu-arch=${DPCPP_SYCL_ARCH}") list(APPEND DPCPP_COMPILE_ONLY_FLAGS; "-mllvm;-enable-global-offset=false;") endif() endif() -if("${DPCPP_SYCL_TARGET}" STREQUAL "intel_gpu_pvc" OR - "${DPCPP_SYCL_TARGET}" STREQUAL "spir64" OR - "${DPCPP_SYCL_TARGET}" STREQUAL "intel_gpu_bmg_g21") - if ((CMAKE_CXX_COMPILER_ID MATCHES "IntelLLVM" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 2025.2) OR CUTLASS_SYCL_BUILTIN_ENABLE) - list(APPEND DPCPP_LINK_ONLY_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier") - else() - list(APPEND DPCPP_LINK_ONLY_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate") - endif() +if (SYCL_INTEL_TARGET) if(DPCPP_DISABLE_ITT_FOR_CUTLASS) list(APPEND DPCPP_FLAGS "-fno-sycl-instrument-device-code") endif() + set(SYCL_DEVICES) + + foreach(TGT IN LISTS DPCPP_SYCL_TARGET_LIST) + if(TGT STREQUAL "intel_gpu_bmg_g21" OR TGT STREQUAL "bmg") + list(APPEND SYCL_DEVICES "bmg_g21") + elseif(TGT STREQUAL "intel_gpu_pvc" OR TGT STREQUAL "pvc") + list(APPEND SYCL_DEVICES "pvc") + endif() + endforeach() + + list(REMOVE_DUPLICATES SYCL_DEVICES) + + string(JOIN "," SYCL_DEVICES_STR ${SYCL_DEVICES}) + + list(APPEND DPCPP_LINK_ONLY_FLAGS "-fsycl-targets=spir64") + list(APPEND DPCPP_LINK_ONLY_FLAGS "-Xs;-device ${SYCL_DEVICES_STR}") + + list(APPEND DPCPP_LINK_ONLY_FLAGS "-Xspirv-translator") + + if((CMAKE_CXX_COMPILER_ID MATCHES "IntelLLVM" AND + CMAKE_CXX_COMPILER_VERSION VERSION_LESS 2025.2) OR CUTLASS_SYCL_BUILTIN_ENABLE) + set(SPIRV_EXT "+SPV_INTEL_split_barrier") + else() + set(SPIRV_EXT "+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate") + endif() + list(APPEND DPCPP_LINK_ONLY_FLAGS "-spirv-ext=${SPIRV_EXT}") + endif() if(UNIX) diff --git a/examples/00_bmg_gemm/00_bmg_gemm.cpp b/examples/00_bmg_gemm/00_bmg_gemm.cpp index 7e9291227e..b9c7738872 100644 --- a/examples/00_bmg_gemm/00_bmg_gemm.cpp +++ b/examples/00_bmg_gemm/00_bmg_gemm.cpp @@ -154,13 +154,12 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -277,9 +276,7 @@ struct ExampleRunner { bool passed = verify(problem_size, options.alpha, options.beta); std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; - if(!passed) return cutlass::Status::kErrorInternal; - - if (options.iterations > 0) { + if (passed && options.iterations > 0) { GPU_Clock timer; timer.start(); for (int i = 0; i < options.iterations; ++i) { @@ -345,9 +342,13 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - // The 2D block copy operations used for the A and B matrices - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + // [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects + // appropriate 2D block copy operations for matrices A and B. Alternatively, you can + // explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI, + // or XE_LOAD_2D_TRANSPOSE. + // Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md + using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>; + using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; @@ -355,21 +356,21 @@ int main(int argc, const char** argv) // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional // hardware (sub-groups for Intel BMG) and iterations by each sub-group. // - // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom - // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The - // TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom. This example uses + // the XE_DPAS_TT<8, float, cute::bfloat16_t> atom, which represents an 8x16x16 DPAS operation with + //float32 accumulation and bfloat16 inputs, TileShape (<256, 256, 32>) and sub-group layout (8x4x1). + // The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a // single contiguous chunk of the work-group TileShape. For this configuration, this implies that // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for // performance reasons. - using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B. constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + // For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; // This is the 'default' epilogue operation (Linear Combination) which performs everything in: // (D = alpha * (A*B) + beta * C) @@ -380,22 +381,22 @@ int main(int argc, const char** argv) // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch // policy/architecture) and defines the epilogue arguments. 
- using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any // auxiliary data required using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, // Epilogue tile (void = automatic) ElementAccumulator, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation ElementOutput, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation - FusionCallBacks, - XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C - void, void, - XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D - void, void>; + FusionCallbacks, + void, // The copy atom used to load matrix C (void = automatic) + void>; // The copy atom used to store matrix D (void = automatic) // GEMM Mainloop - iteration over blocks in K dimension using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< @@ -412,9 +413,9 @@ int main(int argc, const char** argv) // Define the whole kernel (mainloop and epilogue) using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, // Defer global problem shape definition to runtime - CollectiveMainloop, - CollectiveEpilogue + Shape, // Defer global problem shape definition to runtime + CollectiveMainloop, + CollectiveEpilogue >; // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g. diff --git a/examples/00_bmg_gemm/00_bmg_gemm_padded.cpp b/examples/00_bmg_gemm/00_bmg_gemm_padded.cpp index b231825fe7..4e1905e58d 100644 --- a/examples/00_bmg_gemm/00_bmg_gemm_padded.cpp +++ b/examples/00_bmg_gemm/00_bmg_gemm_padded.cpp @@ -39,7 +39,7 @@ This example makes use of BMGs subgroup cooperative 2d-block copy operations and DPAS instructions. To support more input shapes using these instructions, rows of the input/output matrices are padded - to a multiple of 16 and each matrix in batch is padded to a multiple of 64, as required by these + to a multiple of 16 and each matrix in batch is padded to a multiple of 64, as required by these instructions. The shapes of the A and B matrix are defined at runtime by `options.m`, `.n` and `.k`, and the @@ -161,14 +161,14 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementD = typename Gemm::ElementD; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -200,7 +200,7 @@ struct ExampleRunner { bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { auto [M, N, K, L] = problem_size; - + // Padded values // The inner dimension is padded. 
Since this example is all RowMajor, // we require the following: @@ -208,7 +208,7 @@ struct ExampleRunner { int N_C = cute::round_up(N, AlignElemC); int N_D = cute::round_up(N, AlignElemD); int K_A = cute::round_up(K, AlignElemA); - + int AlignmentOuter = AlignmentPtr / AlignmentInner; int M_ACD = cute::round_up(M, AlignmentOuter); int K_B = cute::round_up(K, AlignmentOuter); @@ -383,9 +383,13 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - // The 2D block copy operations used for the A and B matrices - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + // [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects + // appropriate 2D block copy operations for matrices A and B. Alternatively, you can + // explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI, + // or XE_LOAD_2D_TRANSPOSE. + // Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md + using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>; + using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; @@ -393,21 +397,21 @@ int main(int argc, const char** argv) // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional // hardware (sub-groups for Intel BMG) and iterations by each sub-group. // - // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom - // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The - // TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom. This example uses + // the XE_DPAS_TT<8, float, cute::bfloat16_t> atom, which represents an 8x16x16 DPAS operation with + // float32 accumulation and bfloat16 inputs, TileShape (<256, 256, 32>) and sub-group layout (8x4x1). + // The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a // single contiguous chunk of the work-group TileShape. For this configuration, this implies that // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for // performance reasons. - using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B. constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + // For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; // This is the 'default' epilogue operation (Linear Combination) which performs everything in: // (D = alpha * (A*B) + beta * C) @@ -418,22 +422,21 @@ int main(int argc, const char** argv) // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch // policy/architecture) and defines the epilogue arguments. 
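Returning to the padding logic in this example's `verify()`/`initialize()` above: the sketch below spells out the same arithmetic standalone, with hypothetical sizes and a hand-written `round_up` standing in for `cute::round_up`.

```cpp
// Worked example of the padding used by this example: the inner (contiguous) dimension is
// rounded up to AlignmentInner bytes, the outer dimension to AlignmentPtr / AlignmentInner
// elements. Element sizes below are assumptions for illustration (bf16 inputs).
#include <cstdio>

static int round_up(int x, int multiple) { return ((x + multiple - 1) / multiple) * multiple; }

int main() {
  constexpr int AlignmentInner = 16;                              // bytes
  constexpr int AlignmentPtr   = 64;                              // bytes
  constexpr int AlignElemA     = AlignmentInner / 2;              // bf16 -> 8 elements per 16 bytes
  constexpr int AlignmentOuter = AlignmentPtr / AlignmentInner;   // 4

  int M = 5121, K = 4097;                                         // hypothetical odd sizes
  int K_A   = round_up(K, AlignElemA);                            // 4104: padded row length of A
  int M_ACD = round_up(M, AlignmentOuter);                        // 5124: padded rows of A/C/D
  printf("A padded from %d x %d to %d x %d\n", M, K, M_ACD, K_A);
  return 0;
}
```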
- using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any // auxiliary data required using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, // Epilogue tile (void = automatic) ElementAccumulator, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation ElementOutput, cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation - FusionCallBacks, - XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C - void, void, - XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D - void, void>; + FusionCallbacks, + void, // The copy atom used to load matrix C (void = automatic) + void>; // The copy atom used to store matrix D (void = automatic) // GEMM Mainloop - iteration over blocks in K dimension using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< @@ -464,4 +467,4 @@ int main(int argc, const char** argv) CUTLASS_CHECK(runner.run(options, hw_info)); return 0; -} +} \ No newline at end of file diff --git a/examples/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp b/examples/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp index 67e1193e75..4e42e48e4c 100644 --- a/examples/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp +++ b/examples/00_bmg_gemm/00_bmg_gemm_with_sycl_queue.cpp @@ -136,13 +136,12 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -348,42 +347,50 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + // [New Copy Atom] When left unspecified (void), MainloopXeL1Staged automatically selects + // appropriate 2D block copy operations for matrices A and B. Alternatively, you can + // explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI, + // or XE_LOAD_2D_TRANSPOSE. + // Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md + using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>; + using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - // The Tile of this layout describes how 8x4x1 sub-groups tile the TileShape of <256, 256, 32>. - // This permutation (which can be thought of as a scatter operation on the default tiling) - // ensures that each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations) - // See 0t_mma_atom.md#TiledMMAs for more info. - // Sub-groups are arranged row-major (stride 4,1,0) for performance reasons. - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional + // hardware (sub-groups for Intel BMG) and iterations by each sub-group. 
+ // + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom. This example uses + // the XE_DPAS_TT<8, float, cute::bfloat16_t> atom, which represents an 8x16x16 DPAS operation with + //float32 accumulation and bfloat16 inputs, TileShape (<256, 256, 32>) and sub-group layout (8x4x1). + // The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a + // single contiguous chunk of the work-group TileShape. For this configuration, this implies that + // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See + // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for + // performance reasons. + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; + FusionCallbacks, + void, + void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< @@ -411,4 +418,4 @@ int main(int argc, const char** argv) CUTLASS_CHECK(runner.run(options, hw_info)); return 0; -} +} \ No newline at end of file diff --git a/examples/00_bmg_gemm/legacy/00_bmg_gemm.cpp b/examples/00_bmg_gemm/legacy/00_bmg_gemm.cpp new file mode 100644 index 0000000000..91139cf0d6 --- /dev/null +++ b/examples/00_bmg_gemm/legacy/00_bmg_gemm.cpp @@ -0,0 +1,429 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Gemm Example. + + This example constructs and executes a simple CUTLASS GEMM kernel on Intel BMG hardware, and + verifies its correctness with a reference implementation + (cutlass::reference::device::GemmComplex). The example also provides a performance measurement + for the GEMM in TFLOPS. + + This example makes use of BMGs subgroup cooperative 2d-block copy operations and DPAS instructions. + + The shapes of the A and B matrix are defined at runtime by `options.m`, `.n` and `.k`, and the + batch size is defined by `options.l`. The tile shape, which defines how much work is executed by + a single work-group, is defined at compile time by: + ``` + using TileShape = Shape<_256, _256, _32>; + ``` + That is, each work-group processes a tile of M=256, N=256, and iterates over `options.k` in + blocks of K=32. + + Performance of GEMM on BMG is heavily dependent on prefetching the A and B matrices. That is, + executing Intel specific prefetch instructions for future iterations to ensure that the required + blocks of A and B are resident in cache before they are needed. + + To build & run this example (from your build dir): + + $ ninja 00_bmg_gemm + $ ./examples/sycl/00_bmg_gemm/00_bmg_gemm + + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(20), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + 
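The `Options` struct above is driven entirely by `cutlass::CommandLine`; a minimal usage sketch (assuming the struct above is in scope) is:

```cpp
// Sketch only: exercises Options::parse with a hand-built argv. Flag names follow parse()
// above; any value not passed on the command line keeps the default given there.
#include <cstdio>

int demo_parse() {
  char const* argv[] = {"00_bmg_gemm", "--m=1024", "--n=2048", "--k=512", "--l=2"};
  Options options;
  options.parse(5, argv);
  printf("GEMM: %d x %d x %d, batch %d\n", options.m, options.n, options.k, options.l);
  return options.error ? -1 : 0;
}
```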
/// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; // Reference GEMM result for verification + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // CUTLASS on SYCL uses the compatibility library compat for e.g. 
default in-order queue + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + // Complete the stride by combining static layout info (StrideA) with runtime size info (M,K,L) + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(static_cast(M) * K * L); + block_B.reset(static_cast(K) * N * L); + block_C.reset(static_cast(M) * N * L); + block_D.reset(static_cast(M) * N * L); + block_ref_D.reset(static_cast(M) * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){ + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + compat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. 
This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + // The 2D block copy operations used for the A and B matrices + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional + // hardware (sub-groups for Intel BMG) and iterations by each sub-group. + // + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom + // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The + // TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a + // single contiguous chunk of the work-group TileShape. For this configuration, this implies that + // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See + // 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for + // performance reasons. + using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B. + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // This is the 'default' epilogue operation (Linear Combination) which performs everything in: + // (D = alpha * (A*B) + beta * C) + // aside from the (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for more + // complex epilogue examples. + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch + // policy/architecture) and defines the epilogue arguments. 
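Before tying `EpilogueOp` to `FusionCallBacks` below, it is worth spelling out what the Linear Combination epilogue produces per element. This is a reference-level sketch of the formula in the comment above, not the CUTLASS implementation:

```cpp
// D = alpha * (A*B) + beta * C, applied element-wise to the accumulator tile.
// Types mirror this example (float accumulator, float output); tiles are flattened here.
#include <cstddef>
#include <vector>

void linear_combination(std::vector<float> const& acc,  // accumulator tile, i.e. A*B
                        std::vector<float> const& C,    // source tile
                        std::vector<float>& D,          // destination tile
                        float alpha, float beta) {
  for (std::size_t i = 0; i < acc.size(); ++i) {
    D[i] = alpha * acc[i] + beta * C[i];
  }
}
```

This matches the runtime epilogue arguments passed later as `{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}`.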
+ using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any + // auxiliary data required + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + ElementOutput, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + FusionCallBacks, + XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C + void, void, + XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D + void, void>; + + // GEMM Mainloop - iteration over blocks in K dimension + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + ElementInputB, + cutlass::gemm::TagToStrideB_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + // Define the whole kernel (mainloop and epilogue) + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Defer global problem shape definition to runtime + CollectiveMainloop, + CollectiveEpilogue + >; + + // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g. + // persistent scratch memory if required. + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} \ No newline at end of file diff --git a/examples/00_bmg_gemm/legacy/00_bmg_gemm_padded.cpp b/examples/00_bmg_gemm/legacy/00_bmg_gemm_padded.cpp new file mode 100644 index 0000000000..32a7285e71 --- /dev/null +++ b/examples/00_bmg_gemm/legacy/00_bmg_gemm_padded.cpp @@ -0,0 +1,467 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Gemm Example. + + This example constructs and executes a simple CUTLASS GEMM kernel on Intel BMG hardware, and + verifies its correctness with a reference implementation + (cutlass::reference::device::GemmComplex). The example also provides a performance measurement + for the GEMM in TFLOPS. + + This example makes use of BMGs subgroup cooperative 2d-block copy operations and DPAS instructions. + To support more input shapes using these instructions, rows of the input/output matrices are padded + to a multiple of 16 and each matrix in batch is padded to a multiple of 64, as required by these + instructions. + + The shapes of the A and B matrix are defined at runtime by `options.m`, `.n` and `.k`, and the + batch size is defined by `options.l`. The tile shape, which defines how much work is executed by + a single work-group, is defined at compile time by: + ``` + using TileShape = Shape<_256, _256, _32>; + ``` + That is, each work-group processes a tile of M=256, N=256, and iterates over `options.k` in + blocks of K=32. + + Performance of GEMM on BMG is heavily dependent on prefetching the A and B matrices. That is, + executing Intel specific prefetch instructions for future iterations to ensure that the required + blocks of A and B are resident in cache before they are needed. 
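The paragraph above is the key performance point for this kernel, so a conceptual sketch of what "prefetch `PipelineStages` k-blocks ahead" means may be useful. This illustrates the idea only; it is not the actual mainloop code.

```cpp
// Software-pipelined k-loop: issue prefetches PipelineStages k-blocks ahead of the block
// currently being multiplied, so A/B tiles are already in cache when the MMA consumes them.
void pipelined_k_loop(int k_tiles, int PipelineStages) {
  auto prefetch_tile = [](int /*k*/) { /* issue A/B prefetches for k-block k */ };
  auto mma_tile      = [](int /*k*/) { /* DPAS on k-block k, prefetched earlier */ };

  for (int k = 0; k < PipelineStages && k < k_tiles; ++k) {
    prefetch_tile(k);                       // warm-up: fill the prefetch distance
  }
  for (int k = 0; k < k_tiles; ++k) {
    if (k + PipelineStages < k_tiles) {
      prefetch_tile(k + PipelineStages);    // stay PipelineStages blocks ahead
    }
    mma_tile(k);
  }
}
```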
+ + To build & run this example (from your build dir): + + $ ninja 00_bmg_gemm + $ ./examples/sycl/00_bmg_gemm/00_bmg_gemm + + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// The alignment requirement in bytes on inner dimmension that will work for both PVC and BMG +constexpr int AlignmentInner = 16; +// The alignment requirement in bytes on outer dimmension that will work for both PVC and BMG +constexpr int AlignmentPtr = 64; + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(20), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementD = typename Gemm::ElementD; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + static constexpr int AlignElemA = AlignmentInner / sizeof(ElementA); + static constexpr int AlignElemB = AlignmentInner / sizeof(ElementB); + static constexpr int AlignElemC = AlignmentInner / sizeof(ElementB); + static constexpr int AlignElemD = AlignmentInner / sizeof(ElementD); + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; // Reference GEMM result for verification + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + // Padded values + // The inner dimension is padded. 
Since this example is all RowMajor, + // we require the following: + int N_B = cute::round_up(N, AlignElemB); + int N_C = cute::round_up(N, AlignElemC); + int N_D = cute::round_up(N, AlignElemD); + int K_A = cute::round_up(K, AlignElemA); + + int AlignmentOuter = AlignmentPtr / AlignmentInner; + int M_ACD = cute::round_up(M, AlignmentOuter); + int K_B = cute::round_up(K, AlignmentOuter); + + cutlass::TensorRef ref_A(block_A.get(), LayoutA(K_A)); + cutlass::TensorRef ref_B(block_B.get(), LayoutB(N_B)); + cutlass::TensorRef ref_C(block_C.get(), LayoutC(N_C)); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD(N_D)); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M_ACD * K_A, // batch_stride_A + K_B * N_B, // batch_stride_B + M_ACD * N_C, // batch_stride_C + M_ACD * N_D // batch_stride_D + ); + + // CUTLASS on SYCL uses the compatibility library compat for e.g. default in-order queue + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + // Padded values + int N_B = cute::round_up(N, AlignElemB); + int N_C = cute::round_up(N, AlignElemC); + int N_D = cute::round_up(N, AlignElemD); + int K_A = cute::round_up(K, AlignElemA); + + int AlignmentOuter = AlignmentPtr / AlignmentInner; + int M_ACD = cute::round_up(M, AlignmentOuter); + int K_B = cute::round_up(K, AlignmentOuter); + + // Complete the stride by combining static layout info (StrideA) with runtime size info (M,K,L) + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M_ACD, K_A, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N_B, K_B, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M_ACD, N_C, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M_ACD, N_D, L)); + + block_A.reset(M_ACD * K_A * L); + block_B.reset(K_B * N_B * L); + block_C.reset(M_ACD * N_C * L); + block_D.reset(M_ACD * N_D * L); + block_ref_D.reset(M_ACD * N_D * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) { + std::cout << "Warning: Invalid problem size: " + << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l + << ".\nThis size is not directly supported by the selected 
kernel.\n" + << "However, this example applies padding as needed, so it will still run correctly." + << std::endl; + } + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + compat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + // The 2D block copy operations used for the A and B matrices + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + // A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional + // hardware (sub-groups for Intel BMG) and iterations by each sub-group. + // + // The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom + // (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The + // TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a + // single contiguous chunk of the work-group TileShape. For this configuration, this implies that + // each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See + // 0t_mma_atom.md#TiledMMAs for more info. 
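The comment above compresses a fair amount of arithmetic; spelled out as compile-time checks (a sketch that follows the numbers in the comment, not an inspection of the resulting `TiledMma` type):

```cpp
// 256x256x32 work-group tile, 8x4x1 sub-group layout, 8x16x16 DPAS atom.
constexpr int WgTileM = 256, WgTileN = 256, WgTileK = 32;
constexpr int SgM = 8, SgN = 4, SgK = 1;
constexpr int AtomM = 8, AtomN = 16, AtomK = 16;

static_assert(WgTileM / SgM == 32 && WgTileN / SgN == 64 && WgTileK / SgK == 32,
              "each sub-group owns a contiguous 32x64x32 chunk of the work-group tile");
static_assert((WgTileM / SgM) / AtomM == 4 &&
              (WgTileN / SgN) / AtomN == 4 &&
              (WgTileK / SgK) / AtomK == 2,
              "i.e. 4x4x2 MMA-atom iterations per sub-group");
```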
Sub-groups are arranged row-major (stride 4,1,0) for + // performance reasons. + using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B. + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // This is the 'default' epilogue operation (Linear Combination) which performs everything in: + // (D = alpha * (A*B) + beta * C) + // aside from the (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for more + // complex epilogue examples. + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch + // policy/architecture) and defines the epilogue arguments. + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + // GEMM Epilogue - loads & stores C/D matrices, performs epilogue operations & load/stores any + // auxiliary data required + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + ElementOutput, + cutlass::gemm::TagToStrideC_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + FusionCallBacks, + XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C + void, void, + XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D + void, void>; + + // GEMM Mainloop - iteration over blocks in K dimension + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + ElementInputB, + cutlass::gemm::TagToStrideB_t, // Converts CUTLASS 2.x to CUTLASS 3.x representation + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + // Define the whole kernel (mainloop and epilogue) + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Defer global problem shape definition to runtime + CollectiveMainloop, + CollectiveEpilogue + >; + + // The GemmUniversalAdapter wraps the defined GEMM kernel and handles the launch, and e.g. + // persistent scratch memory if required. + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} \ No newline at end of file diff --git a/examples/00_bmg_gemm/legacy/00_bmg_gemm_with_sycl_queue.cpp b/examples/00_bmg_gemm/legacy/00_bmg_gemm_with_sycl_queue.cpp new file mode 100644 index 0000000000..3a4e0fa704 --- /dev/null +++ b/examples/00_bmg_gemm/legacy/00_bmg_gemm_with_sycl_queue.cpp @@ -0,0 +1,414 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Intel BMG Gemm Example with non-default SYCL queue. + This example modifies 00_bmg_gemm to use a non-default queue. The main changes are passing the + queue to gemm_op.initialize and gemm_op.run. Otherwise, changes are made to allocate memory with + the correct queue. + + To build & run this example (from your build dir): + $ ninja 00_bmg_gemm_with_sycl_queue + $ ./examples/sycl/00_bmg_gemm_with_sycl_queue/00_bmg_gemm_with_sycl_queue + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(20), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + 
cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + struct Memory { + ElementA* block_A; + ElementB* block_B; + ElementC* block_C; + ElementOutput* block_D; + ElementOutput* block_ref_D; + sycl::queue q; + + Memory(sycl::queue q, ProblemShapeType problem_shape_MNKL) : q(q) { + auto [M, N, K, L] = problem_shape_MNKL; + block_A = sycl::malloc_device(static_cast(M) * K * L, q); + block_B = sycl::malloc_device(static_cast(N) * K * L, q); + block_C = sycl::malloc_device(static_cast(M) * N * L, q); + block_D = sycl::malloc_device(static_cast(M) * N * L, q); + block_ref_D = sycl::malloc_device(static_cast(M) * N * L, q); + } + + ~Memory() { + sycl::free(block_A, q); + sycl::free(block_B, q); + sycl::free(block_C, q); + sycl::free(block_D, q); + sycl::free(block_ref_D, q); + } + + // delete other constructors so avoiding leaks is easy + Memory(const Memory&) = delete; + Memory(Memory&&) noexcept = delete; + Memory& operator=(const Memory&) = delete; + Memory& operator=(Memory&&) noexcept = delete; + }; + + // + // Methods + // + + bool verify(Memory& mem, const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(mem.block_A, LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(mem.block_B, LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(mem.block_C, LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(mem.block_ref_D, LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // 
batch_stride_C + M * N // batch_stride_D + ); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + mem.block_ref_D, mem.block_D, M * N * L); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size, Memory& mem) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + cutlass::initialize_block(mem.block_A, M * K * L, seed + 2023); + cutlass::initialize_block(mem.block_B, N * K * L, seed + 2022); + cutlass::initialize_block(mem.block_C, M * N * L, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + auto q = compat::create_queue(); + Memory mem(q, problem_size); + initialize(problem_size, mem); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {mem.block_A, stride_A, mem.block_B, stride_B}, + {{options.alpha, options.beta}, mem.block_C, stride_C, mem.block_D, stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + if (workspace_size != 0) { + return cutlass::Status::kErrorInternal; + } + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){ + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } + + CUTLASS_CHECK(gemm_op.initialize(arguments, nullptr, &q)); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run(&q)); + + q.wait_and_throw(); + + // Verify that the result is correct + bool passed = verify(mem, problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(&q); + } + + q.wait_and_throw(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. 
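+  // In this non-default-queue variant, the kernel is launched on whichever device the sycl::queue
+  // created in ExampleRunner::run() (via compat::create_queue()) belongs to, so the hardware info
+  // queried below is assumed to describe that same device. A minimal sketch of reading the EU
+  // count directly from a queue, shown for illustration only (not used by this example):
+  //
+  //   sycl::queue q = compat::create_queue();
+  //   auto eu_count = q.get_device().get_info<sycl::info::device::max_compute_units>();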
+ cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + // The Tile of this layout describes how 8x4x1 sub-groups tile the TileShape of <256, 256, 32>. + // This permutation (which can be thought of as a scatter operation on the default tiling) + // ensures that each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations) + // See 0t_mma_atom.md#TiledMMAs for more info. + // Sub-groups are arranged row-major (stride 4,1,0) for performance reasons. + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} \ No newline at end of file diff --git a/examples/00_bmg_gemm/legacy/CMakeLists.txt b/examples/00_bmg_gemm/legacy/CMakeLists.txt new file mode 100644 index 0000000000..50fbd31472 --- /dev/null +++ b/examples/00_bmg_gemm/legacy/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set(TEST_BATCHES --l=2) +set(TEST_LARGE "--l=513 --m=8 --n=16384 --k=512") # B matrix capacity > uint32_max +set(TEST_SMALL_SHAPE --m=4 --n=8 --k=8 --l=2) + + +cutlass_example_add_executable( + 00_bmg_gemm_legacy + 00_bmg_gemm.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES + TEST_LARGE + TEST_SMALL_SHAPE +) + +set(TEST_SMALL_SHAPE_PADDABLE --m=1 --n=1 --k=2 --l=2) +cutlass_example_add_executable( + 00_bmg_gemm_padded_legacy + 00_bmg_gemm_padded.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES + TEST_SMALL_SHAPE_PADDABLE +) + +cutlass_example_add_executable( + 00_bmg_gemm_with_sycl_queue_legacy + 00_bmg_gemm_with_sycl_queue.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES +) \ No newline at end of file diff --git a/examples/01_bmg_gemm_with_collective_builder/01_bmg_gemm_with_collective_builder.cpp b/examples/01_bmg_gemm_with_collective_builder/01_bmg_gemm_with_collective_builder.cpp index 2c15047d49..1f598965a2 100644 --- a/examples/01_bmg_gemm_with_collective_builder/01_bmg_gemm_with_collective_builder.cpp +++ b/examples/01_bmg_gemm_with_collective_builder/01_bmg_gemm_with_collective_builder.cpp @@ -151,7 +151,6 @@ struct ExampleRunner { using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -194,7 +193,7 @@ struct ExampleRunner { beta, ref_C, ref_D, - ElementAccumulator(0), + ElementAcc(0), L, // batch_count M * K, // batch_stride_A K * N, // batch_stride_B @@ -338,7 +337,7 @@ int main(int argc, const char** argv) constexpr int AlignmentB = sizeof(ElementInputB); constexpr int AlignmentC = sizeof(ElementAccumulator); constexpr int AlignmentD = sizeof(ElementOutput); - + using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; @@ -346,7 +345,7 @@ int main(int argc, const char** argv) // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::IntelXe, 
cutlass::arch::OpClassTensorOp, ElementInputA, LayoutA, AlignmentA, @@ -358,21 +357,21 @@ int main(int argc, const char** argv) >::CollectiveOp; // Define a Linear Combination, Elementwise Activation (LinCombEltAct) epilogue with ReLU activation - using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape, Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueTileAuto, ElementComputeEpilogue, - ElementAccumulator, + ElementAccumulator, ElementAccumulator, LayoutC, AlignmentC, ElementOutput, LayoutD, AlignmentD, cutlass::epilogue::collective::EpilogueScheduleAuto, EpilogueOp >::CollectiveOp; - + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveMainloop, diff --git a/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp b/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp index ffc01d0825..7b65ae34cd 100644 --- a/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp +++ b/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp @@ -567,22 +567,18 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>; + using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - using TiledMma = - TiledMMA, - Layout, Stride<_4, _1, _0>>, - Tile, Stride<_1, _32, _8>>, - Layout, Stride<_1, _64, _16>>, _32>>; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; // Dispatch to grouped gemm algorithm - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1StagedGroup; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGenericGroup; using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; @@ -592,15 +588,14 @@ int main(int argc, const char** argv) using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, // Epilogue tile (void = automatic) ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, cutlass::gemm::TagToStrideC_t, FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; + void, // The copy atom used to load matrix C (void = automatic) + void>; // The copy atom used to store matrix D (void = automatic) // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< diff --git a/examples/04_bmg_grouped_gemm/legacy/04_bmg_grouped_gemm.cpp b/examples/04_bmg_grouped_gemm/legacy/04_bmg_grouped_gemm.cpp new file mode 100644 index 0000000000..ffc01d0825 --- /dev/null +++ b/examples/04_bmg_grouped_gemm/legacy/04_bmg_grouped_gemm.cpp @@ -0,0 +1,632 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Group Gemm + + This example demonstrates fusing multiple GEMM operations into one kernel. + + Note that the scalar arguments to e.g. the standard 00_bmg_gemm example, have been + replaced with vector equivalents, as each individual GEMM has its own inputs and outputs, which + needn't be contiguous in memory. For example, where 00_bmg_gemm receives an `ElementA *` + defining Matrix A, grouped gemm receives a `ElementA **`, i.e. a pointer to pointers, each + pointing to a distinct Matrix A. Likewise, each individual GEMM operation may have its own alpha + and beta factors for linear combination. This example demonstrates two approaches: the user can + provide `options.alpha` and `options.beta`, in which case they will apply to all GEMMs; + otherwise, random values are generated per GEMM. + + Group GEMM scheduling (cutlass::gemm::GroupScheduler) is more complex than standard GEMM, + because each GEMM may have a unique size, only known at runtime. Thus, the scheduler will + distribute an a priori unknown number of tiles to each work-group. See + include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp for implementation. + + Note that for simplicity, this example sets every GEMM in the group to the same shape. + + Verification for this example is a conventional GEMM kernel, executed iteratively per group. + + To build & run this example (from your build dir): + + $ ninja 04_bmg_grouped_gemm + $ ./examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm + + Call with `--help` for information about available options. + + Note: the code may spill registers once compiled which will result in sub-optimal performance. This is because + of an issue inside Intel Graphics Compiler (IGC) related to VectorAliasBBThreshold being debugged internally. 
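+  (When building ahead-of-time for a specific device, the CMake files in this change pass the
+  equivalent option to IGC at link time, e.g. -Xs "-options \"-igc_opts 'VectorAliasBBThreshold=10000'\"";
+  the environment variable below remains the way to apply it for plain spir64 JIT builds.)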
+ To avoid register spills, build the example by setting the environment variable: + $ export IGC_VectorAliasBBThreshold=10000 +*/ +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_array_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" + +#include + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +using ElementAccumulator = float; // <- data type of accumulator +using ElementComputeEpilogue = float; // <- data type of epilogue operations +using ElementA = bfloat16_t; // <- data type of elements in input matrix A +using ElementB = bfloat16_t; // <- data type of elements in input matrix B +using ElementOutput = float; // <- data type of elements in output matrix D + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +// Command line options parsing +struct Options { + + bool error = false; + bool help = false; + + float alpha, beta; + int iterations; + int m, n, k, groups; + std::vector problem_sizes_host; + + Options() : error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100), + m(5120), n(4096), k(4096), groups(2) { + problem_sizes_host.reserve(groups); + for(int i = 0; i < groups; i++) { + problem_sizes_host.push_back({m, n, k}); + } + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("groups", groups, 2); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + + assert(groups > 0); + problem_sizes_host.clear(); + problem_sizes_host.reserve(groups); + for(int i = 0; i < groups; i++) { + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "BMG Grouped GEMM\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "bmg_grouped_gemm" << " --m=5120 --n=4096 --k=4096 --groups=5 --alpha=2.5 --beta=0.5 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementAccumulator = ElementOutput; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + // Host-side allocations + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector stride_A_host; + std::vector stride_B_host; + std::vector stride_C_host; + std::vector stride_D_host; + + std::vector alpha_host; + std::vector beta_host; + + // Device-side allocations + cutlass::DeviceAllocation problem_sizes; + + // This example defines all matrices in a single allocation (e.g. block_A), but this is not a + // requirement. Matrix base pointers are read from device allocation (e.g. 
ptr_A) + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation ptr_ref_D; + + cutlass::DeviceAllocation stride_A; + cutlass::DeviceAllocation stride_B; + cutlass::DeviceAllocation stride_C; + cutlass::DeviceAllocation stride_D; + + // Note, this is an array of pointers to alpha and beta scaling values per group + cutlass::DeviceAllocation alpha_device; + cutlass::DeviceAllocation beta_device; + cutlass::DeviceAllocation block_alpha; + cutlass::DeviceAllocation block_beta; + + uint64_t seed = 0; + + // + // Methods + // + + bool verify(const Options &options) { + bool passed = true; + // Verify against individual reference GEMMs + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), LayoutD::packed({M, N})); + + // + // Compute reference output + // + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha_host.at(i), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta_host.at(i), + ref_C, + ref_D, + ElementAccumulator(0), + 1, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // Wait for kernel to finish + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N); + if(!passed) + break; + } + return passed; + } + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + // Compute total allocation sizes across group + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Offset into block allocation of each matrix base pointer + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); 
+ block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + // Compute offsets, alpha & beta over group on host + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + // Fill host vector of alpha & beta with random values if using per-group values + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + // Fill host ptr vectors with offset addresses into device alpha/beta blocks + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + // Allocate device memory & copy from host + ptr_A.reset(options.groups); + // Per-group alpha and beta + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + // Per-group alpha and beta ptrs + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers + // (alpha_device/beta_device) are passed instead + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); +} + + /// Populates a Gemm::Arguments structure from the given commandline options + typename Gemm::Arguments args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) + { + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. 
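+      // The group-stride encodes the broadcast: dAlpha/dBeta use a batch stride of 0 here so the
+      // single scalar is reused for every group, whereas the per-group branch below sets a batch
+      // stride of 1 and reads one value per group from alpha_ptr_array/beta_ptr_array.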
+ fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; + + // Per-GEMM problem shape info may only exist on the device. + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + + return arguments; + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) { + allocate(options); + initialize(options); + + Gemm gemm_op; + + auto arguments = args_from_options(options, hw_info, host_problem_shapes_available); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(options); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.run()); + } + compat::wait(); + + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time) / double(options.iterations); + double gflops = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host); + + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; + std::cout << " GFLOPS : " << gflops << std::endl; + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + TiledMMA, + Layout, Stride<_4, _1, _0>>, + Tile, Stride<_1, _32, _8>>, + Layout, Stride<_1, _64, _16>>, _32>>; + + constexpr int PipelineStages = 2; + // Dispatch to grouped gemm algorithm + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::GroupScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + 
return 0; +} diff --git a/examples/04_bmg_grouped_gemm/legacy/CMakeLists.txt b/examples/04_bmg_grouped_gemm/legacy/CMakeLists.txt new file mode 100644 index 0000000000..42e2051d29 --- /dev/null +++ b/examples/04_bmg_grouped_gemm/legacy/CMakeLists.txt @@ -0,0 +1,42 @@ +# Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set(TEST_GROUPS_2 --groups=2) +set(TEST_GROUPS_4 --groups=4) + +cutlass_example_add_executable( + 04_bmg_grouped_gemm_legacy + 04_bmg_grouped_gemm.cpp + TEST_COMMAND_OPTIONS + TEST_GROUPS_2 + TEST_GROUPS_4 +) +if(NOT DPCPP_SYCL_TARGET STREQUAL "spir64") + # TODO(codeplay): Remove these once IGC VectorAliasThreshold issue is fixed + target_link_options( 04_bmg_grouped_gemm PRIVATE -Xs "-options \"-igc_opts 'VectorAliasBBThreshold=10000'\"" ) +endif() diff --git a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu.cpp b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu.cpp index 0d330b0360..b11e231244 100644 --- a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu.cpp +++ b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu.cpp @@ -1,3 +1,4 @@ + /*************************************************************************************************** * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * Copyright (C) 2025 Intel Corporation, All rights reserved. 
@@ -149,13 +150,12 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -343,38 +343,35 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + using GmemTiledCopyA = void; + using GmemTiledCopyB = void; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; // Linear Combination + element-wise GELU epilogue using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; + FusionCallbacks, + void, + void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< @@ -402,4 +399,4 @@ int main(int argc, const char** argv) CUTLASS_CHECK(runner.run(options, hw_info)); return 0; -} +} \ No newline at end of file diff --git a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu.cpp b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu.cpp index 1a21713b34..0f6a25af9c 100644 --- a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu.cpp +++ b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu.cpp @@ -149,13 +149,12 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -343,38 +342,35 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + using GmemTiledCopyA = void; + using GmemTiledCopyB = void; // Workgroup-level tile using 
TileShape = Shape<_256, _256, _32>; - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; // The Linear Combination with ReLU epilogue using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; + FusionCallbacks, + void, + void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< @@ -402,4 +398,4 @@ int main(int argc, const char** argv) CUTLASS_CHECK(runner.run(options, hw_info)); return 0; -} +} \ No newline at end of file diff --git a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu.cpp b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu.cpp index d4f040ad33..f33cec5243 100644 --- a/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu.cpp +++ b/examples/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu.cpp @@ -148,13 +148,12 @@ struct ExampleRunner { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -342,38 +341,35 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + using GmemTiledCopyA = void; + using GmemTiledCopyB = void; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + using TiledMma = typename TiledMMAHelper>, Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA; constexpr int PipelineStages = 2; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric; // The Linear Combination with SiLu epilogue using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< EpilogueDispatchPolicy, TileShape, + void, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, 
void, - XE_2D_U32x8x16_ST_N, - void, void>; + FusionCallbacks, + void, + void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< @@ -401,4 +397,4 @@ int main(int argc, const char** argv) CUTLASS_CHECK(runner.run(options, hw_info)); return 0; -} +} \ No newline at end of file diff --git a/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_gelu.cpp b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_gelu.cpp new file mode 100644 index 0000000000..0d330b0360 --- /dev/null +++ b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_gelu.cpp @@ -0,0 +1,405 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Gemm with GELU Activation Fn epilogue + + This example constructs and executes a standard GEMM fused with a GELU (Gaussian Error Linear + Unit) activation epilogue. Aside from the epilogue operation, it is identical to 00_bmg_gemm. + + CUTLASS 3.x epilogues are implemented using the Epilogue Visitor Tree design pattern, and + typically combine 'Linear Combination' (i.e. `D = alpha * A*B + beta * C`) with an additional + epilogue operation. 
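+
+  As a rough sketch of how that composition is spelled out further down in this file (the exact
+  template parameter list is abbreviated here), the fused operation is declared along the lines of:
+
+    using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct<
+        cutlass::epilogue::thread::GELU,               // element-wise activation
+        ElementOutput, ElementComputeEpilogue, ElementAccumulator>;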
+ + In this case, the GELU Element-wise activation function is applied: + + // D = GELU(alpha * (A*B) + beta * C) + + To build & run this example (from your build dir): + + $ ninja 05_bmg_gemm_with_epilogue_gelu + $ ./examples/sycl/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_gelu + + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_gelu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + compat::wait(); + + using TensorView = cutlass::TensorView; + for(int batch = 0, offset = 0; batch < L; batch++, offset += M * N) { + cutlass::reference::device::TensorGeLu(TensorView(block_ref_D.get() + offset, LayoutD::packed({M, N}), + cutlass::make_Coord(M, N))); + } + + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 
L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(static_cast(M) * K * L); + block_B.reset(static_cast(K) * N * L); + block_C.reset(static_cast(M) * N * L); + block_D.reset(static_cast(M) * N * L); + block_ref_D.reset(static_cast(M) * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + compat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. 
+ using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // Linear Combination + element-wise GELU epilogue + using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_relu.cpp b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_relu.cpp new file mode 100644 index 0000000000..1a21713b34 --- /dev/null +++ b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_relu.cpp @@ -0,0 +1,405 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS Intel BMG Gemm with ReLU Activation Fn epilogue + + This example constructs and executes a standard GEMM fused with a ReLU (Rectified Linear Unit) + activation epilogue. Aside from the epilogue operation, it is identical to 00_bmg_gemm. + + CUTLASS 3.x epilogues are implemented using the Epilogue Visitor Tree design pattern, and + typically combine 'Linear Combination' (i.e. `D = alpha * A*B + beta * C`) with an additional + epilogue operation. + + In this case, the ReLU Element-wise activation function is applied: + + // D = ReLU(alpha * (A*B) + beta * C) + + To build & run this example (from your build dir): + + $ ninja 05_bmg_gemm_with_epilogue_relu + $ ./examples/sycl/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_relu + + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_relu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + compat::wait(); + + using TensorView = cutlass::TensorView; + for(int batch = 0, offset = 0; batch < L; batch++, offset += M * N) { + cutlass::reference::device::TensorReLu(TensorView(block_ref_D.get() + offset, LayoutD::packed({M, N}), + cutlass::make_Coord(M, N))); + } + + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 
L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(static_cast(M) * K * L); + block_B.reset(static_cast(K) * N * L); + block_C.reset(static_cast(M) * N * L); + block_D.reset(static_cast(M) * N * L); + block_ref_D.reset(static_cast(M) * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + compat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. 
+ using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // The Linear Combination with ReLU epilogue + using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_silu.cpp b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_silu.cpp new file mode 100644 index 0000000000..d4f040ad33 --- /dev/null +++ b/examples/05_bmg_gemm_with_epilogues/legacy/05_bmg_gemm_with_epilogue_silu.cpp @@ -0,0 +1,404 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Intel BMG Gemm with SiLu Activation Fn epilogue + + This example constructs and executes a standard GEMM fused with a SiLu (Sigmoid Linear Unit) + activation epilogue. Aside from the epilogue operation, it is identical to + 05_bmg_gemm_with_epilogue_relu. + + The SiLu Element-wise activation function is applied as: + + // D = SiLu(alpha * (A*B) + beta * C) + + To build & run this example (from your build dir): + + $ ninja 05_bmg_gemm_with_epilogue_silu + $ ./examples/sycl/05_bmg_gemm_with_epilogues/05_bmg_gemm_with_epilogue_silu + + Call with `--help` for information about available options +*/ + + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_silu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "sycl_common.hpp" +#include "helper.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + compat::wait(); + + using TensorView = cutlass::TensorView; + for(int batch = 0, offset = 0; batch < L; batch++, offset += M * N) { + cutlass::reference::device::TensorSiLu(TensorView(block_ref_D.get() + offset, LayoutD::packed({M, N}), + cutlass::make_Coord(M, N))); + } + + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 
L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + compat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. 
+ using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + // The Linear Combination with SiLu epilogue + using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; +} diff --git a/examples/05_bmg_gemm_with_epilogues/legacy/CMakeLists.txt b/examples/05_bmg_gemm_with_epilogues/legacy/CMakeLists.txt new file mode 100644 index 0000000000..01f30aeeaf --- /dev/null +++ b/examples/05_bmg_gemm_with_epilogues/legacy/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. +# Copyright (C) 2025 Intel Corporation, All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set(TEST_BATCHES --l=2) + +cutlass_example_add_executable( + 05_bmg_gemm_with_epilogue_gelu_legacy + 05_bmg_gemm_with_epilogue_gelu.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES +) + +cutlass_example_add_executable( + 05_bmg_gemm_with_epilogue_relu_legacy + 05_bmg_gemm_with_epilogue_relu.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES +) + +cutlass_example_add_executable( + 05_bmg_gemm_with_epilogue_silu_legacy + 05_bmg_gemm_with_epilogue_silu.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES +) \ No newline at end of file diff --git a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp new file mode 100644 index 0000000000..9de908336f --- /dev/null +++ b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp @@ -0,0 +1,179 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Flash Attention V2 Prefill for Intel BMG + + This example constructs and executes a Flash Attention Prefill kernel on Intel BMG. The + definition of the GEMM, options etc for this example are defined in the associated + bmg_flash_attn_runner.hpp header file. 
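Before the options and configuration below, it helps to keep the underlying math in mind: the kernel fuses the two matrix multiplications and the softmax of scaled dot-product attention, and the host-side verification in the runner header recomputes the same quantity batch by batch. The following is a minimal single-query sketch of that math (illustration only; `attention_row_ref` and its parameters are hypothetical names, and the real kernel is tiled and runs entirely on-device).

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// O_row = softmax(q . K^T * softmax_scale) * V for one query row, using the
// numerically stable max-subtraction form of softmax.
std::vector<float> attention_row_ref(const std::vector<float>& q,               // [head_size_qk]
                                     const std::vector<std::vector<float>>& K,  // [seq_len_kv][head_size_qk]
                                     const std::vector<std::vector<float>>& V,  // [seq_len_kv][head_size_vo]
                                     float softmax_scale) {
  const std::size_t seq_kv = K.size(), d_qk = q.size(), d_vo = V[0].size();
  std::vector<float> s(seq_kv);
  float row_max = -INFINITY, row_sum = 0.f;
  for (std::size_t j = 0; j < seq_kv; ++j) {        // scores: s_j = (q . k_j) * scale
    float dot = 0.f;
    for (std::size_t t = 0; t < d_qk; ++t) dot += q[t] * K[j][t];
    s[j] = dot * softmax_scale;
    if (s[j] > row_max) row_max = s[j];
  }
  for (std::size_t j = 0; j < seq_kv; ++j) {        // softmax numerator / denominator
    s[j] = std::exp(s[j] - row_max);
    row_sum += s[j];
  }
  std::vector<float> o(d_vo, 0.f);                  // O_row = P * V
  for (std::size_t j = 0; j < seq_kv; ++j)
    for (std::size_t t = 0; t < d_vo; ++t) o[t] += (s[j] / row_sum) * V[j][t];
  return o;
}
```

A causal mask corresponds to setting the scores of key positions ahead of the query position to -INFINITY before the softmax, which is what the masking loop in the runner's verification does.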
+ + See https://arxiv.org/pdf/2307.08691 for details of Flash Attention V2 algorithm + + To run this example: + $ ./examples/sycl/06_bmg_flash_attention/06_xe_fmha_fwd --seq_len_qo=512 + --seq_len_kv=512 --head_size_vo=128 --head_size_qk=128 + + To build & run this example (from your build dir): + + $ ninja 06_xe_fmha_fwd + $ ./examples/sycl/06_bmg_flash_attention/06_xe_fmha_fwd + + Call with `--help` for information about available options +*/ + +#include "xe_fmha_fwd_runner.hpp" + +int main(int argc, const char **argv) { + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // Define the work-group tile shape depending on the head-size of the second matmul + +#ifdef PREFILL +#if HEAD_DIM == 16 + /* Tiny config for testing */ + using ShapeQK = Shape<_1, _16, _16>; // (q,k,d) + using ShapePV = Shape<_1, _16, _16>; // (q,v,k) + using ShapeOut = Shape<_1, _16>; // (q,v) + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 64 + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOut = Shape<_128, _64>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 96 + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOut = Shape<_128, _96>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 128 + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOut = Shape<_128, _128>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 192 + using ShapeQK = Shape<_256, _64, _32>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOut = Shape<_256, _192>; + using SubgroupLayoutQK = Layout>; + +#endif +#elif defined(DECODE) + +#if PERSISTENT +#define NUM_SG _16 +#define KV_TILE_SIZE _256 +#else +#define NUM_SG _8 +#define KV_TILE_SIZE _512 +#endif + +#if HEAD_DIM == 16 + /* Tiny config for testing */ + using ShapeQK = Shape<_1, _16, _16>; // (q,k,d) + using ShapePV = Shape<_1, _16, _16>; // (q,v,k) + using ShapeOut = Shape<_1, _16>; // (q,v) + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 64 + using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; + using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; + using ShapeOut = Shape<_1, _64>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 96 + using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; + using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; + using ShapeOut = Shape<_1, _96>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 128 + using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; + using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; + using ShapeOut = Shape<_1, _128>; + using SubgroupLayoutQK = Layout>; + +#elif HEAD_DIM == 192 + using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; + using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; + using ShapeOut = Shape<_1, _192>; + using SubgroupLayoutQK = Layout>; +#endif +#else +#error Either DECODE or PREFILL should be defined. 
+#endif + +#ifdef DECODE + constexpr int PipelineStages = 1; +#else + constexpr int PipelineStages = 2; +#endif +#ifdef IS_FLOAT_E5M2 + using ElementQ = cutlass::float_e5m2_t; + using ElementK = cutlass::float_e5m2_t; + using ElementV = cutlass::float_e5m2_t; +#elif defined(IS_FLOAT_E4M3) + using ElementQ = cutlass::float_e4m3_t; + using ElementK = cutlass::float_e4m3_t; + using ElementV = cutlass::float_e4m3_t; +#else + using ElementQ = bfloat16_t; + using ElementK = bfloat16_t; + using ElementV = bfloat16_t; +#endif + +#if PERSISTENT + return FMHAConfig::run(options); +#else + return options.is_causal ? FMHAConfig::run(options) + : FMHAConfig::run(options); +#endif +} diff --git a/examples/06_bmg_flash_attention/CMakeLists.txt b/examples/06_bmg_flash_attention/CMakeLists.txt index 39752da4ed..435a65e6ea 100644 --- a/examples/06_bmg_flash_attention/CMakeLists.txt +++ b/examples/06_bmg_flash_attention/CMakeLists.txt @@ -33,6 +33,39 @@ set(TEST_NO_PAGED "") set(TEST_PAGED "--use_paged_kv") foreach(HEAD_DIM 64 96 128 192) + foreach(INPUT_TYPE bfloat16_t float_e5m2_t float_e4m3_t) + cutlass_example_add_executable( + 06_xe_fmha_fwd_prefill_${INPUT_TYPE}_hdim${HEAD_DIM} + 06_xe_fmha_fwd.cpp + ) + + cutlass_example_add_executable( + 06_xe_fmha_fwd_decode_${INPUT_TYPE}_hdim${HEAD_DIM} + 06_xe_fmha_fwd.cpp + ) + + if (NOT HEAD_DIM STREQUAL 192) + # specific test for persistent kernel + cutlass_example_add_executable( + 06_xe_fmha_fwd_decode_persistent_${INPUT_TYPE}_hdim${HEAD_DIM} + 06_xe_fmha_fwd.cpp + ) + endif() + + if(INPUT_TYPE STREQUAL "bfloat16_t") + set(INPUT_MACRO "IS_BFLOAT16") + elseif(INPUT_TYPE STREQUAL "float_e5m2_t") + set(INPUT_MACRO "IS_FLOAT_E5M2") + elseif(INPUT_TYPE STREQUAL "float_e4m3_t") + set(INPUT_MACRO "IS_FLOAT_E4M3") + endif() + + target_compile_definitions(06_xe_fmha_fwd_prefill_${INPUT_TYPE}_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} PREFILL SHOW_DIFF=1 INPUT_TYPE=${INPUT_TYPE} ${INPUT_MACRO}) + target_compile_definitions(06_xe_fmha_fwd_decode_${INPUT_TYPE}_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE SHOW_DIFF=1 INPUT_TYPE=${INPUT_TYPE} ${INPUT_MACRO}) + if (NOT HEAD_DIM STREQUAL 192) + target_compile_definitions(06_xe_fmha_fwd_decode_persistent_${INPUT_TYPE}_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE PERSISTENT SHOW_DIFF=1 INPUT_TYPE=${INPUT_TYPE} ${INPUT_MACRO}) + endif() + endforeach() cutlass_example_add_executable( 06_bmg_prefill_attention_hdim${HEAD_DIM} diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp new file mode 100644 index 0000000000..b1e9f0284d --- /dev/null +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -0,0 +1,729 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/util/packed_stride.hpp" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp" +#include "flash_attention_v2/kernel/xe_tile_scheduler.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/sycl_event_manager.hpp" +#include +#include + +#include "helper.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" + +#include + +using namespace cute; + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool is_causal; + bool varlen = false; + std::string scheduler; + + int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, iterations; + float softmax_scale; + + Options() + : help(false), error(false), is_causal(false), varlen(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128), + seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {} + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("is_causal")) { + is_causal = true; + } + + if (cmd.check_cmd_line_flag("varlen")) { + varlen = true; + } + + cmd.get_cmd_line_argument("scheduler", scheduler, std::string("Individual")); + +#ifdef PERSISTENT + cmd.get_cmd_line_argument("batch", batch, 1); + cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 8); + cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, 1); + cmd.get_cmd_line_argument("seq_len_kv", seq_len_kv, 4096); +#else + cmd.get_cmd_line_argument("batch", batch, 32); + cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 16); + cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, num_heads_q); + cmd.get_cmd_line_argument("seq_len_kv", seq_len_kv, 512); +#endif +#ifdef DECODE + cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, 1); +#else + cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, seq_len_kv); +#endif + cmd.get_cmd_line_argument("head_size_vo", head_size_vo, HEAD_DIM); + 
cmd.get_cmd_line_argument("head_size_qk", head_size_qk, head_size_vo); + cmd.get_cmd_line_argument("iterations", iterations, 100); + + softmax_scale = 1 / sqrt(static_cast(head_size_qk)); + } + + /// Prints the usage statement. + std::ostream &print_usage(std::ostream &out) const { + + out << "Xe FMHA Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --is_causal Apply Causal Mask to the output of first Matmul\n" + << " --varlen Enable variable sequence length\n" + << " --scheduler=\"Value\" Choose between Individual or Persistent Scheduler\n" + << " --batch= Sets the Batch Size of the Multi-Head Self Attention module\n" + << " --num_heads_q= Sets the Number of Attention Heads for Key-Value pair the Multi-Head Self Attention module\n" + << " --num_heads_kv= Sets the Number of Attention Heads for Query input in the Multi-Head Self Attention module\n" + << " --seq_len_qo= Sets the Sequence length of the Query input in Multi-Head Self Attention module\n" + << " --seq_len_kv= Sets the Sequence length of the Key-Value pair in Multi-Head Self Attention module\n" + << " --head_size_qk= Sets the Attention Head dimension of the 1st Matrix Multiplication in Multi-Head Self Attention module\n" + << " --head_size_vo= Sets the Attention Head dimension of the 2nd Matrix Multiplication in Multi-Head Self Attention module\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Helpers + +template class ConvertTensorKernelTag{}; + +template +void convert_tensor(const SrcT* d_src, DstT* d_dst, size_t size) { + using Tag = ConvertTensorKernelTag; + compat::get_default_queue().parallel_for(size, [=](auto indx) { + d_dst[indx] = static_cast(d_src[indx]); + }).wait(); +} + +template inline auto in_memory(cutlass::DeviceAllocation& in) { + using OutT = cute::conditional_t<(sizeof_bits_v <= 8), half_t, InT>; + if constexpr (!is_same_v) { + cutlass::DeviceAllocation out(in.size()); + convert_tensor(in.get(), out.get(), in.size()); + return out; + } else { + return in; + }; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// 3 input matrices: (K)eys, (Q)ueries and (V)alues. 
+using LayoutQ = cutlass::layout::RowMajor; +using LayoutK = cutlass::layout::ColumnMajor; +using LayoutV = cutlass::layout::RowMajor; +using LayoutO = cutlass::layout::RowMajor; + +template struct ExampleRunner { + + using StrideQ = typename FMHAKernel::StrideQ; + using StrideK = typename FMHAKernel::StrideK; + using StrideV = typename FMHAKernel::StrideV; + using StrideO = typename FMHAKernel::StrideO; + + using ElementQ = typename FMHAKernel::ElementQ; + using ElementK = typename FMHAKernel::ElementK; + using ElementV = typename FMHAKernel::ElementV; + using ElementO = typename FMHAKernel::ElementO; + + using CollectiveMainloop = typename FMHAKernel::CollectiveMainloop; + using ElementS = typename CollectiveMainloop::ElementS; + + using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_ref_O; + + std::vector cumulative_seqlen_q; + std::vector cumulative_seqlen_kv; + cutlass::DeviceAllocation device_cumulative_seqlen_q; + cutlass::DeviceAllocation device_cumulative_seqlen_kv; + + // + // Methods + // + + template + auto initialize_varlen(const ProblemShape& problem_size) { + int num_batches = get<0>(problem_size); + + // generate Q as --b times + // gaussian (--Q, --Q / 2) sampled positive + // track cumulative + std::mt19937 rng(0x202305151552ull); + std::normal_distribution dist_q(get<3>(problem_size), get<3>(problem_size) / 2); + std::normal_distribution dist_kv(get<4>(problem_size), get<4>(problem_size) / 2); + + // Use Cacheline Size to calculate alignment + constexpr int cacheline_bytes = 64; + constexpr int AlignmentQ = cacheline_bytes / sizeof(ElementQ); // Alignment of Q matrix in units of elements + constexpr int AlignmentKV = cacheline_bytes / sizeof(ElementK); // Alignment of Kand V matrix in units of elements + + auto generate_positive_int = [](auto& dist, auto& gen) { + int result = 0; + do { + result = static_cast(dist(gen)); + } while (result <= 0); + return result; + }; + + cumulative_seqlen_q = {0}; + cumulative_seqlen_kv = {0}; + + int total_seqlen_q = 0; + int total_seqlen_kv = 0; + int max_seqlen_q = 0; + int max_seqlen_kv = 0; + + for (int i = 0; i < num_batches; i++) { + int seqlen_q = cutlass::round_up(generate_positive_int(dist_q, rng), AlignmentQ); + int seqlen_kv = cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV); + + total_seqlen_q += seqlen_q; + total_seqlen_kv += seqlen_kv; + + max_seqlen_q = std::max(max_seqlen_q, seqlen_q); + max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv); + + cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q); + cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv); + } + + ProblemShape problem_size_for_init = problem_size; + get<0>(problem_size_for_init) = 1; + get<3>(problem_size_for_init) = total_seqlen_q; + get<4>(problem_size_for_init) = total_seqlen_kv; + + ProblemShapeType problem_size_for_launch; + problem_size_for_launch.batch = get<0>(problem_size); + problem_size_for_launch.num_heads_q = get<1>(problem_size); + problem_size_for_launch.num_heads_kv = get<2>(problem_size); + problem_size_for_launch.seq_len_qo = cutlass::fmha::collective::VariableLength{max_seqlen_q}; + problem_size_for_launch.seq_len_kv = 
cutlass::fmha::collective::VariableLength{max_seqlen_kv}; + problem_size_for_launch.head_size_qk = get<5>(problem_size); + problem_size_for_launch.head_size_vo = get<6>(problem_size); + + return cute::make_tuple(problem_size_for_init, problem_size_for_launch); + } + + + + bool verify(ProblemShapeType shape, bool is_causal) { + + if constexpr (isVarLen) { + int max_seq_len_q = shape.seq_len_qo; + int max_seq_len_kv = shape.seq_len_kv; + shape.seq_len_qo = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; + shape.seq_len_kv = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; + } + + auto batch = shape.batch; + auto num_heads_q = shape.num_heads_q; + auto num_heads_kv = shape.num_heads_kv; + auto head_size_qk = shape.head_size_qk; + auto head_size_vo = shape.head_size_vo; + int seq_len_qo, seq_len_kv; + + auto block_Q_ = in_memory(block_Q); + auto block_K_ = in_memory(block_K); + auto block_V_ = in_memory(block_V); + + using ElementV_ = std::remove_pointer_t; + + int offset_q = 0; + int offset_k = 0; + int offset_v = 0; + int offset_o = 0; + + // loop over the batch dimension to compute the output + // to avoid the risk of running out of device memory + int q_group_size = num_heads_q/num_heads_kv; + for (int b = 0; b < batch; b++) { + if constexpr (isVarLen) { + auto logical_seq_shape = cutlass::fmha::collective::apply_variable_length(make_shape(shape.seq_len_qo, shape.seq_len_kv), b); + seq_len_qo = get<0>(logical_seq_shape); + seq_len_kv = get<1>(logical_seq_shape); + } else { + seq_len_qo = shape.seq_len_qo; + seq_len_kv = shape.seq_len_kv; + } + + int kv_group_update=1; + for (int h = 0; h < num_heads_q; h++) { + cutlass::DeviceAllocation block_S; + block_S.reset(seq_len_qo * seq_len_kv); + + cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk})); + cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv})); + cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo})); + cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); + + cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q, + cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone, + 0.f, ref_S, ref_S, ElementS(0), + 1, // batch_count + seq_len_qo * head_size_qk, // batch_stride_Q + seq_len_kv * head_size_qk, // batch_stride_K + seq_len_qo * seq_len_kv, // batch_stride_S + seq_len_qo * seq_len_kv // batch_stride_S + ); + + compat::wait(); + + std::vector host_S(block_S.size()); + compat::memcpy(host_S.data(), block_S.get(), host_S.size()); + + // delete this memory as it is no longer needed + block_S.reset(); + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + if (is_causal) { + // apply mask to S + for (int row = 0; row < seq_len_qo; row++) { + for (int col = 0; col < seq_len_kv; col++) { + if ((col - full_tile_offset) > (row - discard_seq_coord)) + host_S[col + row * seq_len_kv] = ElementS{-INFINITY}; + } + } + } + + // compute max element per row of S + std::vector max_vec(seq_len_qo, ElementS{-INFINITY}); + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv; + int max_idx = row; + max_vec[max_idx] = host_S[idx++]; + for (int col = 1; col < seq_len_kv; col++, idx++) { + if (max_vec[max_idx] < host_S[idx]) + max_vec[max_idx] = host_S[idx]; 
+ } + } + + // compute exp of S + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv; + int max_idx = row; + for (int col = 0; col < seq_len_kv; col++, idx++) { + /* FIXME: use softmax_scale instead of assuming its value here */ + host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast((head_size_qk)))); + } + } + + // compute sum per row of S + std::vector sum_vec(seq_len_qo, ElementS{0}); + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv; + int sum_idx = row; + for (int col = 0; col < seq_len_kv; col++, idx++) { + sum_vec[sum_idx] += host_S[idx]; + } + + // scale each row with the sum to compute softmax + idx = row * seq_len_kv; + sum_idx = row; + for (int col = 0; col < seq_len_kv; col++, idx++) { + if(is_causal && row < discard_seq_coord) { + host_S[idx] = 0; + } else { + host_S[idx] /= sum_vec[sum_idx]; + } + } + } + + std::vector host_P(host_S.size()); + for (int p = 0; p < host_P.size(); p++) + host_P[p] = static_cast(host_S[p]); + + cutlass::DeviceAllocation block_P; + block_P.reset(host_P.size()); + + compat::memcpy(block_P.get(), host_P.data(), host_P.size()); + + cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); + + cutlass::DeviceAllocation block_acc; + block_acc.reset(seq_len_qo * head_size_vo); + cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo})); + + cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementS{1}, ref_P, + cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone, + ElementS{0}, ref_acc, ref_acc, ElementS{0}, + 1, // batch_count + seq_len_qo * seq_len_kv, // batch_stride_P + seq_len_kv * head_size_vo, // batch_stride_V + seq_len_qo * head_size_vo, // batch_stride_O + seq_len_qo * head_size_vo // batch_stride_O + ); + + compat::wait(); + // delete this memory as it is no longer needed + block_P.reset(); + + std::vector vec_acc(block_acc.size()); + compat::memcpy(vec_acc.data(), block_acc.get(), vec_acc.size()); + + // delete this memory as it is no longer needed + block_acc.reset(); + std::vector vec_out(vec_acc.size()); + for(int i = 0; i < vec_out.size(); i++) { + vec_out[i] = static_cast(vec_acc[i]); + } + compat::memcpy(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size()); + + offset_q += seq_len_qo * head_size_qk; + if(kv_group_update % q_group_size==0) { + offset_k += seq_len_kv * head_size_qk; + offset_v += seq_len_kv * head_size_vo; + } + kv_group_update++; + offset_o += seq_len_qo * head_size_vo; + } + } + + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(), + block_O.size(), ElementO{0.05}, ElementO{0.05}); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + ProblemShapeType initialize(const Options &options) { + auto problem_shape_in = cute::make_tuple(options.batch, options.num_heads_q, options.num_heads_kv, options.seq_len_qo, options.seq_len_kv, options.head_size_qk, options.head_size_vo); + ProblemShapeType shape; + + decltype(problem_shape_in) problem_size; + + if constexpr (isVarLen) { + auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in); + problem_size = problem_shape_init; + shape = problem_shape_launch; + } else { + problem_size = problem_shape_in; + shape.batch = options.batch; + shape.num_heads_q = options.num_heads_q; + 
shape.num_heads_kv = options.num_heads_kv; + shape.seq_len_qo = options.seq_len_qo; + shape.seq_len_kv = options.seq_len_kv; + shape.head_size_qk = options.head_size_qk; + shape.head_size_vo = options.head_size_vo; + } + + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_size; + stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, num_heads_q, batch)); + stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, num_heads_kv, batch)); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv, num_heads_kv, batch)); + stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q, batch)); + + block_Q.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_qk); + block_K.reset(static_cast(batch) * num_heads_kv * seq_len_kv * head_size_qk); + block_V.reset(static_cast(batch) * num_heads_kv * seq_len_kv * head_size_vo); + block_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); + block_ref_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); + + initialize_block(block_Q, seed + 2023); + initialize_block(block_K, seed + 2022); + initialize_block(block_V, seed + 2021); + + if (!cumulative_seqlen_q.empty()) { + device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); + device_cumulative_seqlen_q.copy_from_host(cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); + } + + if (!cumulative_seqlen_kv.empty()) { + device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); + device_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); + } + if constexpr (isVarLen) { + shape.seq_len_qo.cumulative_length = device_cumulative_seqlen_q.get(); + shape.seq_len_kv.cumulative_length = device_cumulative_seqlen_kv.get(); + } + return shape; + } + + // Note that the GemmUniversalAdapter currently doesn't support flash attention, which is why this + // secondary `run` function is required to launch the kernel. 
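For contrast, here is a minimal sketch of the adapter-driven flow used by kernels that GemmUniversalAdapter does support (the same can_implement / initialize / run sequence the grouped-GEMM runners later in this patch follow); `launch_via_adapter` is an illustrative name, not part of the example:

```cpp
// Sketch only: how a supported kernel would be driven through GemmUniversalAdapter.
template <class Gemm>   // e.g. Gemm = cutlass::gemm::device::GemmUniversalAdapter<SomeKernel>
cutlass::Status launch_via_adapter(typename Gemm::Arguments const& args, uint8_t* workspace) {
  Gemm op;
  CUTLASS_CHECK(op.can_implement(args));           // validate the problem for this kernel
  CUTLASS_CHECK(op.initialize(args, workspace));   // build device-side params and workspace
  return op.run();                                 // enqueue the kernel
}
```

The FMHA path cannot go through this adapter yet, so the `run(Params)` overload below assembles the SYCL launch by hand.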
+ static void run(typename FMHAKernel::Params params) + { + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + dim3 const block = FMHAKernel::get_block_shape(); + dim3 const grid = FMHAKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAKernel::SharedStorageSize; + + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); + + // Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension + compat::experimental::launch_properties launch_props { + syclex::work_group_scratch_size(smem_size), + }; + compat::experimental::kernel_properties kernel_props{ + syclex::sub_group_size, + intelex::grf_size<256> + }; + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = compat::experimental::launch, FMHAKernel>(policy, params); + + EventManager::getInstance().addEvent(event); + } + + cutlass::Status run(const Options &options, const cutlass::KernelHardwareInfo &hw_info) { + + ProblemShapeType shape = initialize(options); + + typename FMHAKernel::Arguments arguments{ + { + shape, + block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_O.get(), stride_O + }, + {options.softmax_scale}, + {}, + hw_info + }; + + // Define device-global scratch memory + size_t workspace_size = FMHAKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (!FMHAKernel::can_implement(arguments)) { + std::cout << "Invalid Problem Size: " << options.batch << 'x' << options.num_heads_q << 'x' << + options.seq_len_qo << 'x' << options.seq_len_kv << 'x' << options.head_size_qk << 'x' << options.head_size_vo + << (options.is_causal ? "xCausal" : "xNonCausal") << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + CUTLASS_CHECK(FMHAKernel::initialize_workspace(arguments, workspace.get())); + + // Convert host-side arguments to device-side arguments to be passed to the kernel + auto params = FMHAKernel::to_underlying_arguments(arguments, workspace.get()); + + // Run the GEMM + run(params); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(shape, options.is_causal); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + return cutlass::Status::kErrorInternal; + } + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + run(params); + } + compat::wait(); + // when seq_len_qo is not equal to seq_len_kv we use bottom up approach for the masking. + // Following changes will adjust the effective_seq_len_kv when masking applied for such cases + auto offset = cute::min(options.seq_len_qo, options.seq_len_kv); + auto discard_seq_coord = options.seq_len_qo - offset; + auto full_tile_offset = options.seq_len_kv - offset; + // offset + 1 is going to be ceil_div + auto effective_seq_len_kv = options.is_causal ? full_tile_offset + ((offset + 1) / 2.0): options.seq_len_kv; + auto effective_seq_len_qo = options.is_causal ? 
options.seq_len_qo - discard_seq_coord : options.seq_len_qo; + double cute_time = timer.seconds() / options.iterations; + double flops_qk = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * effective_seq_len_kv * options.head_size_qk; + double flops_pv = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo * effective_seq_len_kv; + double tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time; + double gbps_qk = options.batch * (sizeof(ElementQ) * options.num_heads_q * effective_seq_len_qo * options.head_size_qk + + sizeof(ElementK) * options.num_heads_kv * effective_seq_len_kv * options.head_size_qk); + double gbps_pv = sizeof(ElementV) * options.batch * options.num_heads_kv * effective_seq_len_kv * options.head_size_vo + + sizeof(ElementO) * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo; + double gbps = ((gbps_qk + gbps_pv) * 1e-9) / (cute_time); + std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv << "\tSeq Length QO: " << options.seq_len_qo + << "\tSeq Length KV: " << options.seq_len_kv << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo + << "\tCausal Mask: " << (options.is_causal ? "true" : "false") << "\tVariable Sequence Length: " << (options.varlen ? "true" : "false") + << "\t Scheduler: " << options.scheduler; + printf("\nPerformance: %4.3f GB/s, %4.3f TFlop/s, %6.4f ms\n\n", gbps, tflops, cute_time * 1000); + } + + return cutlass::Status::kSuccess; + } +}; + +template default */ + int PipelineStages, + bool persistent, + typename ElementQ = bfloat16_t, + typename ElementK = bfloat16_t, + typename ElementV = bfloat16_t, + typename ElementO = float, + typename MMAOperation_ = void, /* void -> default */ + typename StrideQ = Stride, + typename StrideK = Stride, + typename StrideV = Stride<_1, int, int, int>, + typename StrideO = Stride, + typename GmemTiledCopyQ = void, /* void -> default block 2D */ + typename GmemTiledCopyK = void, + typename GmemTiledCopyV = void, + typename GmemTiledCopyO = void> +struct FMHAConfig { + + static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); + using MMAOperation = cute::conditional_t, + typename cute::conditional_t< + cute::is_same_v || cute::is_same_v, + XE_DPAS_TT, + XE_DPAS_TT + >, + MMAOperation_>; + using SubgroupLayoutPV = cute::conditional_t, + decltype(cutlass::fmha::collective::get_sg_layout_pv(SubgroupLayoutQK{})), + SubgroupLayoutPV_>; + + template + static int run(const Options &options) { + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. 
+ cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; + + using TiledMMAQK = typename TiledMMAHelper, Layout, SubgroupLayoutQK>::TiledMMA; + using TiledMMAPV = typename TiledMMAHelper, Layout, SubgroupLayoutPV>::TiledMMA; + + static_assert(get<0>(TileShapeOutput{}) == get<0>(TileShapePV{}), + "Output tile and P*V tile have different sizes in Q dimension"); + constexpr int VTiles = get<1>(TileShapeOutput{}) / get<1>(TileShapePV{}); + + auto make_dummy_tensor = [&](auto val, auto stride) { + return make_tensor(make_gmem_ptr(&val), + make_layout(repeat>(1), stride)); + }; + + using TensorQ = decltype(make_dummy_tensor(ElementQ{}, StrideQ{})); + using TensorK = decltype(make_dummy_tensor(ElementK{}, StrideK{})); + using TensorV = decltype(make_dummy_tensor(ElementV{}, StrideV{})); + using TensorO = decltype(make_dummy_tensor(ElementO{}, StrideO{})); + + // Mainloop + using MainloopDispatchPolicy = cutlass::fmha::XeDefault; + using CollectiveMainloop = cutlass::fmha::collective::FMHAFwdMainloop< + MainloopDispatchPolicy, Causal, + TiledMMAQK, TiledMMAPV, VTiles, + TensorQ, TensorK, TensorV, + GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV + >; + + // Epilogue + using CollectiveEpilogue = cutlass::fmha::collective::FMHAFwdEpilogue< + CollectiveMainloop, + TileShapeOutput, + TensorO, + GmemTiledCopyO + >; + + static_assert(!(persistent & Causal), "persistent SDPA kernel not support Causal yet"); + using FMHAKernel = conditional_t, + cutlass::fmha::kernel::XeFMHAFwdDynamicSplitKernel< + ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler>, + cutlass::fmha::kernel::XeFMHAFwdKernel< + ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler> + >; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + return 0; + } + + static int run(const Options &options) { + if (options.varlen) { + return run(options); + } else { + return persistent ? 
run(options) : + run(options); + } + } +}; diff --git a/examples/09_bmg_grouped_gemm_f8/09_bmg_grouped_gemm_f8.cpp b/examples/09_bmg_grouped_gemm_f8/09_bmg_grouped_gemm_f8.cpp index c5b0233ac7..9c8ef827cb 100644 --- a/examples/09_bmg_grouped_gemm_f8/09_bmg_grouped_gemm_f8.cpp +++ b/examples/09_bmg_grouped_gemm_f8/09_bmg_grouped_gemm_f8.cpp @@ -576,19 +576,24 @@ int launcher(Options& options) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U8x32x32_LD_V; - using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; + // When left unspecified, the MainloopXeL1Staged dispatch + // automatically selects the appropriate 2D block copy op + using GmemTiledCopyA = void; + using GmemTiledCopyB = void; // Workgroup-level tile - using TileShape = Shape<_256, _256, _32>; + using TileShape = Shape<_256, _256, _16>; using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; + typename TiledMMAHelper< + MMA_Atom>, // A,B=FP16; accumulator=FP32 + Layout, + Layout, Stride<_4, _1, _0>> + >::TiledMMA; constexpr int PipelineStages = 2; // Dispatch to grouped gemm algorithm - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupFP8; + using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1StagedGroup; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; using EpilogueOp = cutlass::epilogue::fusion::LinearCombination +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" + +#include + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +using ElementAccumulator = float; // <- data type of accumulator +using ElementComputeEpilogue = float; // <- data type of epilogue operations +using ElementOutput = float; // <- data type of elements in output matrix D + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +// Command line options parsing +struct Options { + + bool error = false; + bool help = false; + + float alpha, beta; + int iterations; + int m, n, k, groups; + std::vector problem_sizes_host; + + Options() : error(false), help(false), alpha(FLT_MAX), beta(FLT_MAX), iterations(100), + m(5120), n(4096), k(4096), groups(2) { + problem_sizes_host.reserve(groups); + for(int i = 0; i < groups; i++) { + problem_sizes_host.push_back({m, n, k}); + } + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("groups", groups, 2); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + + assert(groups > 0); + problem_sizes_host.clear(); + problem_sizes_host.reserve(groups); + for(int i = 0; i < groups; i++) { + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "BMG Grouped GEMM\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "09_bmg_grouped_gemm_fp8" << " --m=5120 --n=4096 --k=4096 --groups=5 --alpha=2.5 --beta=0.5 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementAccumulator = ElementOutput; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + // Host-side allocations + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector stride_A_host; + std::vector stride_B_host; + std::vector stride_C_host; + std::vector stride_D_host; + + std::vector alpha_host; + std::vector beta_host; + + // Device-side allocations + cutlass::DeviceAllocation problem_sizes; + + // This example defines all matrices in a single allocation (e.g. block_A), but this is not a + // requirement. Matrix base pointers are read from device allocation (e.g. 
ptr_A) + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation ptr_ref_D; + + cutlass::DeviceAllocation stride_A; + cutlass::DeviceAllocation stride_B; + cutlass::DeviceAllocation stride_C; + cutlass::DeviceAllocation stride_D; + + // Note, this is an array of pointers to alpha and beta scaling values per group + cutlass::DeviceAllocation alpha_device; + cutlass::DeviceAllocation beta_device; + cutlass::DeviceAllocation block_alpha; + cutlass::DeviceAllocation block_beta; + + uint64_t seed = 0; + + // + // Methods + // + template + bool verify(const Options &options) { + bool passed = true; + // Verify against individual reference GEMMs + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + cutlass::DeviceAllocation block_A_fp16(block_A.size()); + cutlass::DeviceAllocation block_B_fp16(block_B.size()); + + // fp8 -> fp16 + convert_dtype( + block_A.get(), + block_A_fp16.get(), + block_A.size() + ); + convert_dtype( + block_B.get(), + block_B_fp16.get(), + block_B.size() + ); + + cutlass::TensorRef ref_A(block_A_fp16.get() + offset_A.at(i), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B_fp16.get() + offset_B.at(i), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), LayoutD::packed({M, N})); + + // + // Compute reference output + // + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha_host.at(i), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta_host.at(i), + ref_C, + ref_D, + ElementAccumulator(0), + 1, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // Wait for kernel to finish + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N); + if(!passed) + break; + } + return passed; + } + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + // Compute total allocation sizes across group + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Offset into block allocation of each matrix base pointer + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + 
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +template +void initialize(const Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + // Compute offsets, alpha & beta over group on host + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + // Fill host vector of alpha & beta with random values if using per-group values + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast(rand() % 5 + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + // Fill host ptr vectors with offset addresses into device alpha/beta blocks + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + // Allocate device memory & copy from host + ptr_A.reset(options.groups); + // Per-group alpha and beta + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + // Per-group alpha and beta ptrs + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + // Per-group alpha and beta values - note these are not directly passed to kernel - the pointers + // (alpha_device/beta_device) are passed instead + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); +} + + /// Populates a Gemm::Arguments structure from the given commandline options + typename Gemm::Arguments args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) + { + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both 
alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; + + // Per-GEMM problem shape info may only exist on the device. + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {1, RasterOrderOptions::AlongN} + }; + } + + return arguments; + } + + template + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) { + allocate(options); + initialize(options); + + Gemm gemm_op; + + auto arguments = args_from_options(options, hw_info, host_problem_shapes_available); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(options); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.run()); + } + compat::wait(); + + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time) / double(options.iterations); + double gflops = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host); + if constexpr (std::is_same_v) { + std::cout << "Datatype: float_e4m3_t"<< std::endl; + } else if constexpr (std::is_same_v) { + std::cout << "Datatype: float_e5m2_t"<< std::endl; + } else { + static_assert(cutlass::detail::dependent_false, "Not a valid fp8 datatype."); + } + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; + std::cout << " GFLOPS : " << gflops << std::endl; + } + + return cutlass::Status::kSuccess; + } + +}; + + +template +int launcher(Options& options) +{ + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U8x32x32_LD_V; + using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + // Dispatch to grouped gemm algorithm + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupFP8; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementType, + cutlass::gemm::TagToStrideA_t, + ElementType, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::GroupScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.template run(options, hw_info)); + + 
return 0; +} + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + launcher(options); + launcher(options); + return 0; +} diff --git a/examples/09_bmg_grouped_gemm_f8/legacy/CMakeLists.txt b/examples/09_bmg_grouped_gemm_f8/legacy/CMakeLists.txt new file mode 100644 index 0000000000..0f064b6650 --- /dev/null +++ b/examples/09_bmg_grouped_gemm_f8/legacy/CMakeLists.txt @@ -0,0 +1,42 @@ +# Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set(TEST_GROUPS_2 --groups=2) +set(TEST_GROUPS_4 --groups=4) + +cutlass_example_add_executable( + 09_bmg_grouped_gemm_f8_legacy + 09_bmg_grouped_gemm_f8.cpp + TEST_COMMAND_OPTIONS + TEST_GROUPS_2 + TEST_GROUPS_4 +) +if(NOT DPCPP_SYCL_TARGET STREQUAL "spir64") + # TODO(codeplay): Remove these once IGC VectorAliasThreshold issue is fixed + target_link_options( 09_bmg_grouped_gemm_f8_legacy PRIVATE -Xs "-options \"-igc_opts 'VectorAliasBBThreshold=10000'\"" ) +endif() \ No newline at end of file diff --git a/examples/11_xe20_cutlass_library/CMakeLists.txt b/examples/11_xe20_cutlass_library/CMakeLists.txt new file mode 100644 index 0000000000..a4fe624e94 --- /dev/null +++ b/examples/11_xe20_cutlass_library/CMakeLists.txt @@ -0,0 +1,99 @@ +# Copyright (C) 2025 Intel Corporation, All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Example 11: XE20 CUTLASS Library BF16 GEMM +# This example creates a shared library (.so) that exports CUTLASS BF16 GEMM +# functionality for use with Python via ctypes. + +# Create shared library for Python integration +add_library(xe20_cutlass_library_bf16 SHARED + xe_20_cutlass_library_b16.cpp +) + +# Set library properties (this creates shared library for python example to link) +set_target_properties(xe20_cutlass_library_bf16 PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + VERSION 1.0 + SOVERSION 1 + OUTPUT_NAME "xe20_cutlass_library_bf16" + POSITION_INDEPENDENT_CODE ON +) + +# Include directories +target_include_directories(xe20_cutlass_library_bf16 PRIVATE + ${CUTLASS_EXAMPLES_COMMON_SOURCE_DIR} + ${CUTLASS_EXAMPLES_UTILS_DIR} + ${CUTLASS_APPLICATIONS_DIR} +) + +# Link libraries +target_link_libraries(xe20_cutlass_library_bf16 PRIVATE + CUTLASS + cutlass_tools_util_includes +) + +# Add compile definitions +target_compile_definitions(xe20_cutlass_library_bf16 PRIVATE + CUTLASS_ENABLE_SYCL=1 + SYCL_INTEL_TARGET=1 + DPCPP_SYCL_TARGET=intel_gpu_bmg_g21 +) + +# Add Intel-specific SYCL link flags for XE20 optimization +if(CUTLASS_ENABLE_SYCL AND SYCL_INTEL_TARGET) + target_link_options(xe20_cutlass_library_bf16 PRIVATE + -Xspirv-translator + -spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate + ) + + add_sycl_to_target(TARGET xe20_cutlass_library_bf16) + add_onemkl_to_target(TARGET xe20_cutlass_library_bf16) +endif() + +# Link against CUTLASS XE20 GEMM library if available +if(TARGET cutlass_gemm_xe20_gemm) + target_link_libraries(xe20_cutlass_library_bf16 PRIVATE cutlass_gemm_xe20_gemm) +endif() + +# Install the shared library +install(TARGETS xe20_cutlass_library_bf16 + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +) + +# Add to examples target +add_dependencies(cutlass_examples xe20_cutlass_library_bf16) + +# Custom target for building just this library +add_custom_target(xe20_cutlass_library + DEPENDS xe20_cutlass_library_bf16 + COMMENT "Building XE20 CUTLASS Library BF16 GEMM Shared Library (.so)" +) + +message(STATUS "Added shared library xe20_cutlass_library_bf16 for Python integration") \ No newline at end of file diff --git a/examples/11_xe20_cutlass_library/xe_20_cutlass_library_b16.cpp 
b/examples/11_xe20_cutlass_library/xe_20_cutlass_library_b16.cpp new file mode 100644 index 0000000000..812af797d7 --- /dev/null +++ b/examples/11_xe20_cutlass_library/xe_20_cutlass_library_b16.cpp @@ -0,0 +1,225 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ***************************************************************************************************/ + + + +#include +#include +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/device_memory.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +//#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" + + +// We compile all models with -fvisibility=hidden. Any symbols that need to be +// exposed in the final shared library must be declared with PT_EXPORT to make +// them visible. 
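// (In practice, with hidden default visibility only the PT_EXPORT'd entry points, such as
//  sycl_tla_gemm_xe20_bf16 defined below, remain in the .so's dynamic symbol table for the
//  Python/ctypes caller described in this example's CMakeLists.txt to resolve.)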
+#ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) +#define PT_EXPORT __attribute__((__visibility__("default"))) +#else +#ifdef _WIN32 +#define PT_EXPORT __declspec(dllexport) +#else +#define PT_EXPORT +#endif +#endif + +using namespace cute; +#define CUTLASS_CHECK(status) \ +{ \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \ + throw std::runtime_error(msg); \ + } \ +} + +// Used as pass-through functor in EVT just for type casting / rounding +template +struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } +}; + + + +using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Xe20, cutlass::arch::OpClassTensorOp, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto, + cutlass::epilogue::fusion::LinearCombination< + float, + float, + float, + float + > + >::CollectiveOp; + +using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Xe20, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, + float, + cute::Shape, + cute::Shape, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + +// Gemm operator cutlass3x_xe11_tensorop_gemm_bf16_128x256_16x0_tn_align2 +using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_base = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop, + cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_epilogue, + cutlass::gemm::PersistentScheduler>; + +// Define named type +struct cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8 : +public cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_base { }; + + + using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type = cutlass::gemm::device::GemmUniversalAdapter; + +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. 
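A caller-side sketch of that two-phase convention, assuming the shared library is driven directly from C++ rather than through ctypes; the device pointers, leading dimensions and swizzle value below are illustrative placeholders, not values taken from this example:

```cpp
#include <cstdint>
#include <sycl/sycl.hpp>

extern "C" int sycl_tla_gemm_xe20_bf16(const uint16_t*, const uint16_t*, uint16_t*,
                                       int M, int N, int K, int B,
                                       int lda, int ldb, int ldc, int ldd,
                                       int X_offset, int W_offset, int Y_offset,
                                       uint8_t swizzle,
                                       size_t* workspace_size, uint8_t* workspace,
                                       sycl::queue* stream);

int run_bf16_gemm(sycl::queue& q, const uint16_t* d_X, const uint16_t* d_W, uint16_t* d_Y,
                  int M, int N, int K, int lda, int ldb, int ldd) {
  // Phase 1: a non-null workspace_size only reports the scratch requirement.
  size_t ws_bytes = 0;
  sycl_tla_gemm_xe20_bf16(d_X, d_W, d_Y, M, N, K, /*B=*/1, lda, ldb, /*ldc=*/0, ldd,
                          0, 0, 0, /*swizzle=*/1, &ws_bytes, nullptr, &q);
  uint8_t* ws = ws_bytes ? sycl::malloc_device<uint8_t>(ws_bytes, q) : nullptr;

  // Phase 2: a null workspace_size initializes and launches the GEMM on `q`.
  int rc = sycl_tla_gemm_xe20_bf16(d_X, d_W, d_Y, M, N, K, /*B=*/1, lda, ldb, /*ldc=*/0, ldd,
                                   0, 0, 0, /*swizzle=*/1, nullptr, ws, &q);
  q.wait();
  if (ws) sycl::free(ws, q);
  return rc;
}
```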
+extern "C" { +PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint16_t* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const int X_offset, const int W_offset, const int Y_offset, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, sycl::queue* stream) { + try { + using ElementComputeEpilogue = cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + + // Initialize GemmUniversal3xInstance arguments using constructor + cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode + { + static_cast(M), + static_cast(N), + static_cast(K), + static_cast(B) + }, // ProblemShape problem_shape + { + (cutlass::bfloat16_t*)(X + X_offset), // ElementA const* ptr_A + cute::make_tuple(cute::Int<1>{}, int64_t(lda), int64_t(0)), // StrideA dA (column-major: stride_m=1, stride_n=lda, batch=0) + (cutlass::bfloat16_t*)(W + W_offset), // ElementB const* ptr_B + cute::make_tuple(int64_t(ldb), cute::Int<1>{}, int64_t(0)), // StrideB dB (column-major: stride_m=ldb, stride_n=1, batch=0) + }, // MainloopArguments mainloop + + // see https://tinyurl.com/4rk89z48 + { + {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) + nullptr, // ElementC const* ptr_C + cute::make_tuple(int64_t(0), cute::Int<1>{}, int64_t(0)), // StrideC dC (row-major: stride_m, stride_n=1, batch=0) + (float*)(Y + Y_offset), // ElementD ptr_D (output is float, not bfloat16) + cute::make_tuple(int64_t(ldd), cute::Int<1>{}, int64_t(0)), // StrideD dD (row-major: stride_m=ldd, stride_n=1, batch=0) + }, // EpilogueArguments epilogue, + hw_info + }; + arguments.scheduler.max_swizzle_size = swizzle; + cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) 
{ + return -1; + } + return 0; +} +} + +// configuration name: cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8 \ No newline at end of file diff --git a/examples/12_bmg_moe_gemm_cute_interface/12_bmg_moe_gemm_cute_interface.cpp b/examples/12_bmg_moe_gemm_cute_interface/12_bmg_moe_gemm_cute_interface.cpp new file mode 100644 index 0000000000..b223593bcd --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/12_bmg_moe_gemm_cute_interface.cpp @@ -0,0 +1,425 @@ +/*************************************************************************************************** + * Copyright (c) 2025 Intel Corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief CUTLASS Intel BMG MoE API example based on sycl-tla Group GEMM + +*/ + +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include +#include +#include + +#include + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/platform/platform.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/initialize_block.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/sycl_event_manager.hpp" + +#include "moe_grouped_gemm.hpp" +#include "moe_tile_scheduler.hpp" + +#pragma clang diagnostic ignored "-Wpass-failed" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +using namespace cute; +using namespace MoE; + +using ElementAccumulator = float; // <- data type of accumulator + +struct VerificationHelper { + + bool error = false; + bool help = false; + + float alpha = 1.f; + float beta = 0.f; + int iterations; + int m = 0, n = 0, k = 0, groups; + int *num_rows_per_expert = nullptr; + std::vector + problem_sizes_host; + + VerificationHelper() + : error(false), help(false), alpha(1.f), beta(0.f), iterations(100) {} + + void parse(const int num_experts, const int *num_tokens_per_expert_host, + int moe_n, int moe_k, + const int *num_tokens_per_expert_device = nullptr) { + n = moe_n; + k = moe_k; + groups = num_experts; + iterations = 2; + num_rows_per_expert = const_cast(num_tokens_per_expert_device); + assert(groups > 0); + problem_sizes_host.clear(); + problem_sizes_host.reserve(groups); + for (int i = 0; i < groups; i++) { + problem_sizes_host.push_back({num_tokens_per_expert_host[i], n, k}); + m += num_tokens_per_expert_host[i]; + } + } + + /// Compute performance in GFLOP/s + std::tuple + gflops(double runtime_s, + std::vector + problem_sizes_host) const { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + uint64_t bytes_loaded = 0; + + for (auto const &problem : problem_sizes_host) { + auto M = static_cast(get<0>(problem)); + auto N = static_cast(get<1>(problem)); + auto K = static_cast(get<2>(problem)); + fmas += M * N * K; + bytes_loaded += + /* sizeof(cutlass::bfloat16_t) */ 2 * (2 * M * N + N * K + M * K); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + double arithmetic_intensity = double(flop) / double(bytes_loaded); + double peak_mwm_bw = 456.0; + double gflops_attainable = std::min( + 117 * double(1.0e12), + arithmetic_intensity * (peak_mwm_bw * 1024 * 1024 * 1024)); + double projected_time = flop / gflops_attainable; + return std::make_tuple(gflop / runtime_s, + double(bytes_loaded) / 1024 / 1024 / 1024 / + runtime_s, + projected_time * 1000); + } + + template && + is_any_of_v && + is_any_of_v>> + bool verify(const ElementA *activations, const ElementB *weights, + ElementD *outputs) { + cutlass::DeviceAllocation output_ref; + cutlass::DeviceAllocation unused_c_matrix; + output_ref.reset(m * n); + unused_c_matrix.reset(m * n); + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + bool passed = true; + // Verify against individual reference GEMMs + int cumulative_sum = 0; + for (int32_t i = 0; i < groups; ++i) { + auto problem = problem_sizes_host.at(i); + auto M = get<0>(problem); + 
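// Only M differs between experts here; N and K are shared by every group (parse() builds
// each entry as {num_tokens_per_expert_host[i], n, k}), so the per-expert reference GEMM
// below simply advances cumulative_sum by M rows through the activation and output blocks.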
cutlass::TensorRef ref_A(activations + cumulative_sum * k, + LayoutA::packed({M, k})); + cutlass::TensorRef ref_B(weights + i * n * k, LayoutB::packed({k, n})); + cutlass::TensorRef ref_C(unused_c_matrix.get() + cumulative_sum * n, + LayoutC::packed({M, n})); + cutlass::TensorRef ref_D(output_ref.get() + cumulative_sum * n, + LayoutD::packed({M, n})); + + // + // Compute reference output + // + cutlass::reference::device::GemmComplex( + {M, n, k}, 1.0, ref_A, cutlass::ComplexTransform::kNone, ref_B, + cutlass::ComplexTransform::kNone, 0.0, ref_C, ref_D, + ElementAccumulator(0), + 1, // batch_count + M * k, // batch_stride_A + k * n, // batch_stride_B + M * n, // batch_stride_C + M * n // batch_stride_D + ); + + // Wait for kernel to finish + compat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or + // not + passed &= cutlass::reference::device::BlockCompareEqual( + output_ref.get() + cumulative_sum * n, outputs + cumulative_sum * n, + M * n); + if (!passed) { + break; + } + cumulative_sum += M; + } + return passed; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template auto choose_tiled_mma(TA *A, TB *B) { + using TA_non_CV = cutlass::platform::remove_cv_t; + using TB_non_CV = cutlass::platform::remove_cv_t; + auto op = XE_DPAS_TT<8, float, TA_non_CV, TB_non_CV>{}; + + using WGTile = Shape<_256, _128, _32>; // 256x128 WG tile size + using SGLayout = + Layout, Stride<_2, _1, _0>>; // 8x2 SG tiling, n-major + + using MMA = typename TiledMMAHelper, Layout, + SGLayout>::TiledMMA; + + return MMA{}; +} + +// type tag to define a unique sycl kernel name +template class GemmCuteName; + +template +void MoEGEMMLauncher(const ElementA *activations, const ElementB *weights, + const ElementS *scales, ElementD *outputs, + const int gemm_n, const int gemm_k, + const int *num_rows_per_expert_device, + const int *num_tokens_per_expert_host, + const int num_experts) { + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. + // For example, in a framework, you could query device ID. + int sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + cutlass::KernelHardwareInfo hw_info{0, sm_count}; + auto dummy_problem_shape = cute::Shape{1, gemm_k, gemm_n}; + // The GroupedGEMM API requires creation of a vector of ProblemShape objects + // for each GEMM problem, which is used in the GroupedGEMM tile-scheduler. If + // there are 32 groups, then a vector of 32 `ProblemShape` objects is created. + // Since these would not be known at compile time for a framework, they would + // have to be created at run-time instead. However, for MoEGEMM, I just + // provide one dummy shape, and then the custom code in tile scheduler can + // derive the shape of each GEMM problem. 
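For comparison, a conventional grouped-GEMM setup would materialize one shape per expert at run time before building the scheduler arguments. A minimal sketch, assuming the same (M, N, K) ordering the verification helper's problem_sizes_host uses; `build_group_shapes` is an illustrative name:

```cpp
#include <vector>
#include "cute/tensor.hpp"

// Sketch of the per-group shape construction that this MoE example deliberately avoids.
std::vector<cute::Shape<int, int, int>> build_group_shapes(const int* tokens_per_expert,
                                                           int num_experts, int n, int k) {
  std::vector<cute::Shape<int, int, int>> shapes;
  shapes.reserve(num_experts);
  for (int e = 0; e < num_experts; ++e) {
    shapes.push_back(cute::make_shape(tokens_per_expert[e], n, k));  // per-expert (M, N, K)
  }
  return shapes;  // shapes.data() could back a GroupProblemShape spanning num_experts groups
}
```

The MoE path below instead hands the tile scheduler a single dummy shape plus num_rows_per_expert_device and lets the scheduler derive each expert's M on the fly.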
+ auto dummy_group_problem_shape = + cutlass::gemm::GroupProblemShape>{ + 1, &dummy_problem_shape, nullptr}; + using TileShape = Shape<_256, _128, _32>; + using ClusterShape = Shape<_1, _1, _1>; + auto scheduler_params = + PersistentTileSchedulerXeMoE::to_underlying_arguments( + dummy_group_problem_shape, TileShape{}, ClusterShape{}, hw_info, + PersistentTileSchedulerXeMoE::Arguments{ + 1, RasterOrderOptions::AlongN}); + auto group_distribution = + PersistentTileSchedulerXeMoE::get_grid_shape( + scheduler_params, dummy_group_problem_shape, TileShape{}, + ClusterShape{}, hw_info, + PersistentTileSchedulerXeMoE::Arguments{ + 1, RasterOrderOptions::AlongN}); + auto mma = choose_tiled_mma(activations, weights); + auto MaxThreadsPerWorkgroup = size(mma); + dim3 local_range{MaxThreadsPerWorkgroup, 1, 1}; + + sycl::range<3> local = {local_range.z, local_range.y, local_range.x}; + sycl::range<3> groups = {group_distribution.z, group_distribution.y, + group_distribution.x}; + sycl::range<3> global = {local[0] * groups[0], local[1] * groups[1], + local[2] * groups[2]}; + + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + syclex::properties kernel_props{syclex::sub_group_size<16>, + intelex::grf_size<256>}; + sycl::queue Q = compat::get_default_queue(); + + GPU_Clock timer; + timer.start(); + auto event = Q.parallel_for< + GemmCuteName>( + sycl::nd_range<3>(global, local), kernel_props, [=](auto) { + // Can also use void for copy atoms. + // In that case, they will be chosen automatically. + MoE::MoEGEMM, + XE_LOAD_2D_VNNI<16, 32, 16, 16>, XE_STORE_2D<16, 8, 32>, + 'R', 'R', 'R'>(activations, weights, scales, outputs, mma, + num_rows_per_expert_device, num_experts, + gemm_n, gemm_k, scheduler_params); + }); + EventManager::getInstance().addEvent(event); + Q.wait_and_throw(); + float cute_time = timer.seconds() * 1000; + double cute_average_time = double(cute_time); + + VerificationHelper helper; + helper.parse(num_experts, num_tokens_per_expert_host, gemm_n, gemm_k); + if (helper.verify(activations, weights, outputs) == false) { + std::cout << "\n\nFailed accuracy verification :(\n\n"; + } + + auto [gflops, mem_bw_util, projected_time] = + helper.gflops(cute_average_time / 1000.0, helper.problem_sizes_host); + + std::cout << " Problem Sizes" << std::endl; + for (int32_t i = 0; i < num_experts; ++i) { + std::cout << " " << num_tokens_per_expert_host[i] << std::endl; + } + std::cout << " N : " << gemm_n << std::endl; + std::cout << " K : " << gemm_k << std::endl; + std::cout << " Groups : " << num_experts << std::endl; + std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; + std::cout << " GFLOPS : " << gflops << std::endl; + std::cout << " Memory BW utilization : " << mem_bw_util << " GBPs" + << std::endl; +} + +void launcher(int *M_per_expert, int N, int K, const int &num_experts) { + int n_moe = N; + int k_moe = K; + int num_tokens_incl_duplicated = 0; + for (int i = 0; i < num_experts; i++) { + num_tokens_incl_duplicated += M_per_expert[i]; + } + + float M_occupancy = 0.f; + float actual_num_units = 0.f; + int total_num_M_tiles = 0; + for (int i = 0; i < num_experts; i++) { + total_num_M_tiles += (M_per_expert[i] + 255) / 256; + actual_num_units += M_per_expert[i] / 256.0; + } + M_occupancy = actual_num_units / total_num_M_tiles; + std::cout << "\n\n M-occupancy is " << M_occupancy << std::endl; + cutlass::DeviceAllocation num_rows_per_expert_device; + cutlass::DeviceAllocation activations_data; + 
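+  // Packed buffer layout used throughout this example: activations are stored
+  // back to back for all experts ((sum of M_per_expert) x K, row-major), the
+  // weights hold one K x N row-major matrix per expert, and the output mirrors
+  // the activation row grouping with N columns per row.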
cutlass::DeviceAllocation weights_data; + cutlass::DeviceAllocation output_data; + size_t A_size = num_tokens_incl_duplicated * k_moe; + size_t B_size = num_experts * n_moe * k_moe; + size_t D_size = num_tokens_incl_duplicated * n_moe; + num_rows_per_expert_device.reset(num_experts); + num_rows_per_expert_device.copy_from_host(M_per_expert); + activations_data.reset(A_size); + weights_data.reset(B_size); + output_data.reset(D_size); + uint64_t seed = 2023; + initialize_block(activations_data, seed + 2023); + initialize_block(weights_data, seed + 2022); + initialize_block(output_data, seed + 2021); + + MoEGEMMLauncher<'R', 'R'>(activations_data.get(), weights_data.get(), + static_cast(nullptr), output_data.get(), + n_moe, k_moe, num_rows_per_expert_device.get(), + M_per_expert, num_experts); +} + +int main(int argc, const char **argv) { + constexpr int num_experts = 32; + constexpr int num_layers = 24; + + int total_rows_for_each_expert[num_layers][num_experts] = { + {148, 231, 404, 180, 127, 244, 224, 244, 110, 617, 289, + 845, 191, 424, 30, 97, 57, 324, 62, 77, 75, 144, + 250, 287, 629, 370, 161, 101, 215, 113, 224, 35}, + {666, 214, 448, 87, 4, 28, 48, 13, 74, 40, 546, + 397, 487, 350, 26, 95, 517, 487, 295, 58, 637, 97, + 139, 33, 126, 15, 352, 311, 995, 193, 135, 135}, + {1016, 30, 36, 452, 469, 473, 232, 0, 493, 14, 954, + 6, 4, 6, 279, 3, 94, 106, 96, 48, 49, 113, + 142, 169, 75, 99, 25, 220, 249, 289, 4, 1803}, + {350, 229, 703, 154, 8, 64, 80, 339, 2, 56, 5, + 312, 1005, 29, 9, 11, 23, 0, 23, 431, 48, 129, + 496, 476, 8, 1234, 7, 130, 34, 58, 41, 1554}, + {39, 10, 6, 2, 110, 1, 894, 8, 53, 0, 275, + 6, 506, 421, 700, 178, 0, 530, 1623, 15, 231, 74, + 6, 222, 1246, 116, 35, 20, 0, 6, 381, 334}, + {399, 5, 201, 6, 134, 93, 1748, 1, 51, 4, 38, + 336, 53, 88, 328, 724, 15, 388, 706, 52, 19, 55, + 52, 33, 623, 1, 222, 215, 69, 45, 308, 1036}, + {11, 8, 407, 571, 458, 275, 197, 211, 13, 564, 462, + 114, 15, 13, 132, 24, 514, 2, 71, 13, 694, 47, + 16, 203, 610, 40, 0, 1587, 66, 23, 196, 491}, + {0, 230, 116, 136, 315, 643, 6, 183, 37, 26, 960, + 1, 8, 258, 21, 1602, 213, 198, 6, 196, 455, 557, + 47, 282, 493, 18, 101, 11, 616, 45, 268, 0}, + {392, 305, 179, 14, 227, 98, 114, 39, 64, 1456, 465, + 0, 18, 372, 0, 0, 189, 257, 25, 290, 486, 0, + 12, 1534, 468, 4, 555, 35, 146, 0, 161, 143}, + {4, 107, 20, 125, 236, 898, 0, 0, 375, 2, 125, + 0, 0, 1429, 36, 195, 1660, 0, 127, 454, 73, 358, + 47, 79, 32, 20, 1465, 0, 0, 6, 109, 66}, + {19, 0, 0, 0, 2, 1638, 75, 135, 392, 2, 1494, 3, 23, 5, 4, 58, + 0, 0, 71, 1285, 8, 441, 0, 145, 209, 408, 450, 2, 824, 13, 326, 16}, + {4, 2, 14, 0, 30, 206, 41, 131, 0, 429, 16, 895, 35, 21, 44, 128, + 12, 0, 417, 0, 838, 917, 42, 115, 109, 1759, 0, 36, 17, 0, 1790, 0}, + {6, 483, 241, 1327, 17, 11, 480, 9, 880, 58, 4, + 0, 61, 30, 16, 176, 9, 309, 26, 0, 0, 1882, + 4, 281, 475, 783, 197, 0, 19, 15, 6, 243}, + {370, 1222, 0, 6, 108, 929, 2, 7, 157, 348, 149, 106, 2, 5, 25, 33, + 1569, 8, 6, 106, 69, 1298, 0, 2, 529, 520, 0, 421, 0, 25, 26, 0}, + {59, 89, 0, 26, 25, 40, 1873, 141, 527, 371, 262, + 62, 16, 0, 127, 234, 1637, 64, 132, 8, 0, 7, + 161, 1005, 22, 1, 49, 6, 83, 925, 80, 16}, + {269, 617, 30, 4, 90, 26, 0, 16, 154, 212, 21, + 269, 379, 174, 129, 32, 8, 121, 344, 15, 0, 591, + 1494, 6, 737, 50, 112, 856, 483, 25, 454, 330}, + {0, 98, 1488, 22, 73, 0, 0, 343, 77, 4, 0, + 612, 165, 268, 4, 10, 43, 0, 598, 271, 2, 73, + 185, 0, 112, 779, 24, 1626, 0, 0, 0, 1171}, + {0, 0, 0, 189, 266, 1743, 0, 462, 20, 7, 668, 310, 40, 0, 10, 236, + 423, 18, 0, 0, 0, 999, 
0, 139, 1754, 8, 619, 3, 23, 0, 102, 9}, + {131, 1753, 0, 113, 24, 94, 2, 12, 108, 0, 0, + 252, 97, 0, 1319, 233, 93, 1254, 195, 152, 14, 413, + 4, 2, 220, 67, 20, 4, 34, 559, 837, 42}, + {55, 76, 0, 8, 0, 3, 1557, 975, 135, 271, 4, 0, 0, 666, 207, 152, + 5, 2, 97, 364, 0, 13, 1423, 771, 159, 31, 223, 0, 431, 7, 409, 4}, + {4, 1026, 1799, 166, 694, 753, 0, 16, 0, 240, 1119, 19, 6, 0, 46, 659, + 10, 0, 112, 808, 181, 0, 28, 22, 90, 0, 176, 0, 37, 5, 10, 22}, + {44, 0, 4, 153, 299, 1357, 6, 23, 0, 12, 4, 419, 73, 24, 16, 24, + 1, 4, 4, 102, 16, 4, 0, 1953, 1850, 0, 908, 4, 0, 13, 708, 23}, + {6, 13, 123, 28, 197, 0, 202, 69, 0, 6, 0, 21, 1434, 1582, 11, 0, 6, + 0, 7, 190, 4, 1700, 6, 434, 1886, 0, 14, 28, 8, 30, 25, 18}, + {5, 27, 1442, 18, 0, 6, 0, 73, 6, 781, 0, 1915, 291, 649, 98, 4, + 33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9} + }; + + for (int i = 0; i < num_layers; i++) { + launcher(total_rows_for_each_expert[i], 5760, 2880, num_experts); + launcher(total_rows_for_each_expert[i], 2880, 2880, num_experts); + } + + return 0; +} diff --git a/examples/12_bmg_moe_gemm_cute_interface/CMakeLists.txt b/examples/12_bmg_moe_gemm_cute_interface/CMakeLists.txt new file mode 100644 index 0000000000..000e98a28e --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/CMakeLists.txt @@ -0,0 +1,37 @@ +# Copyright (c) 2025 Intel Corporation. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
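+
+# Standalone MoE GEMM example (CuTe interface). Assuming the usual SYCL*TLA
+# example flow, it is built as the CMake target of the same name (for example
+# `ninja 12_bmg_moe_gemm_cute_interface`) and runs without command-line
+# arguments; the per-expert token counts are hard-coded in main().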
+ +if(NOT "${DPCPP_HOST_COMPILER}" MATCHES "g\\+\\+") + cutlass_example_add_executable( + 12_bmg_moe_gemm_cute_interface + 12_bmg_moe_gemm_cute_interface.cpp + ) + if(NOT DPCPP_SYCL_TARGET STREQUAL "spir64") + target_link_options( 12_bmg_moe_gemm_cute_interface PRIVATE -Xs "-options \"-igc_opts 'VectorAliasBBThreshold=10000'\"" ) + endif() +endif() diff --git a/examples/12_bmg_moe_gemm_cute_interface/moe_gemms.hpp b/examples/12_bmg_moe_gemm_cute_interface/moe_gemms.hpp new file mode 100644 index 0000000000..8b5af99d94 --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/moe_gemms.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include +#include +#include + +#include + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/platform/platform.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/sycl_event_manager.hpp" + +#pragma clang diagnostic ignored "-Wpass-failed" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +template struct is_16_bit_fp : std::false_type {}; + +template <> struct is_16_bit_fp : std::true_type {}; +template <> struct is_16_bit_fp : std::true_type {}; + +template +inline constexpr bool is_16_bit_fp_v = + is_16_bit_fp>>::value; + +static_assert(is_16_bit_fp_v); +static_assert(is_16_bit_fp_v); + +namespace MoE { + +using namespace cute; + +template < + class GmemTiledCopyA, class GmemTiledCopyB, class GmemTiledCopyD, + class ATensor, class BTensor, class DTensor, class TiledMMA, + class = std::enable_if_t && + is_16_bit_fp_v>> +CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K) + BTensor const &B, // (N,K) + DTensor &D, // (M,N) + Coord blk_coord, + TiledMMA const &mma) { + auto item = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + auto local_id = item.get_local_linear_id(); + auto wg_m = get<0>(blk_coord); + auto wg_n = get<1>(blk_coord); + + Tensor cA = make_identity_tensor(A.shape()); // (M,K) + Tensor cB = make_identity_tensor(B.shape()); // (N,K) + Tensor cD = make_identity_tensor(D.shape()); // (M,N) + + auto wg_coord = make_coord(wg_m, wg_n, 0); + auto wg_tile = mma.tile_mnk(); + + Tensor gA = local_tile(cA, select<0, 2>(wg_tile), make_coord(wg_m, _)); + Tensor gB = local_tile(cB, select<1, 2>(wg_tile), make_coord(wg_n, _)); + Tensor gD = local_tile(cD, wg_tile, wg_coord, Step<_1, _1, X>{}); + + auto thr_mma = mma.get_slice(local_id); + + auto tiled_copy_a = get_block_2d_copy_A(mma, A); + auto tiled_copy_b = get_block_2d_copy_B(mma, B); + auto tiled_copy_d = get_block_2d_copy_D(mma, D); + + auto thr_copy_a = tiled_copy_a.get_slice(local_id); + auto thr_copy_b = tiled_copy_b.get_slice(local_id); + auto thr_copy_d = tiled_copy_d.get_slice(local_id); + + auto tCrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0)); + auto tCrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0)); + auto tCrD = thr_mma.partition_sg_fragment_C(gD); + auto tCrD_final = thr_copy_d.partition_sg_fragment_S(gD); + + auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_, _, 0)); + auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_, _, 0)); + + Tensor tAgA = thr_copy_a.partition_S(gA); + Tensor tBgB = thr_copy_b.partition_S(gB); + auto tCgD = thr_copy_d.partition_D(gD); + + auto prefetch_a = make_block_2d_prefetch(tiled_copy_a); + auto prefetch_b = make_block_2d_prefetch(tiled_copy_b); + + auto thr_prefetch_A = prefetch_a.get_slice(local_id); + auto thr_prefetch_B = prefetch_b.get_slice(local_id); + + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + + constexpr int barrier_scope = 2; + int k_start_idx = 0; + int prefetch_k = k_start_idx; + const int prefetch_dist = 3; + int k_tile_count = ceil_div(shape<1>(A), get<2>(wg_tile)); + + CUTE_UNROLL + for (; prefetch_k < prefetch_dist; prefetch_k++) { + prefetch(prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + for (int k_tile = 
k_start_idx; k_tile < k_tile_count; + k_tile++, prefetch_k++) { + barrier_arrive(barrier_scope); + + copy(tiled_copy_a, tAgA(_, _, _, k_tile), tArA); + copy(tiled_copy_b, tBgB(_, _, _, k_tile), tBrB); + + if (prefetch_k < k_tile_count) { + prefetch(prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + reorder(tArA, tCrA); + reorder(tBrB, tCrB); + + cute::gemm(mma, tCrA, tCrB, tCrD); + barrier_wait(barrier_scope); + } + reorder(tCrD, tCrD_final); + copy(tiled_copy_d, tCrD_final, tCgD); +} + +} // namespace MoE diff --git a/examples/12_bmg_moe_gemm_cute_interface/moe_grouped_gemm.hpp b/examples/12_bmg_moe_gemm_cute_interface/moe_grouped_gemm.hpp new file mode 100644 index 0000000000..56566b9263 --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/moe_grouped_gemm.hpp @@ -0,0 +1,144 @@ +/*************************************************************************************************** + * Copyright 2025 Intel corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/platform/platform.h" +#include "moe_gemms.hpp" +#include "moe_tile_scheduler.hpp" +#include + +#pragma clang diagnostic ignored "-Wpass-failed" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +namespace MoE { +using namespace cute; + +using ProblemShapeMNKL = Shape; +using ProblemShape = cutlass::gemm::GroupProblemShape>; +using TileScheduler = typename MoE::PersistentTileSchedulerXeMoE; +using RasterOrderOptions = typename TileScheduler::RasterOrderOptions; + +template +CUTE_DEVICE auto make_moe_tensor(T *ptr, int r, int c) { + auto shape = make_shape(r, c); + if constexpr (LayoutKind == 'C') + return make_tensor(make_gmem_ptr(ptr), + make_layout(shape, make_stride(_1{}, r))); + else + return make_tensor(make_gmem_ptr(ptr), + make_layout(shape, make_stride(c, _1{}))); +} + +template +CUTE_DEVICE void +MoEGEMM(const ElementA *Activations, const ElementB *Weights, + const ElementS *Scales, ElementD *Outputs, TiledMMA const &mma, + const int32_t *M_per_group, const int32_t num_experts, const int32_t N, + const int32_t K, + PersistentTileSchedulerSm90GroupParams scheduler_params) { + + TileScheduler scheduler{scheduler_params, const_cast(M_per_group), + N, K, num_experts}; + + auto work_tile_info = scheduler.initial_work_tile_info(Shape<_1, _1, _1>{}); + constexpr char actual_layout_of_B = LayoutKindB ^ ('R' ^ 'C'); + bool did_group_change = true; + int32_t curr_group = 0; + int32_t prev_group = 0; + int32_t cumulative_M = 0; + int32_t M = 0; + + if (work_tile_info.is_valid()) { + // We don't really need this conditional outside the while loop. + // It simply helps initialize tensors. If using nullptr would be + // fine for their initialization, then we can remove this conditional. 
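+    // Note that these tensors are only provisional: inside the while loop below
+    // they are rebuilt whenever the scheduler hands out a tile from a different
+    // group, since every expert can have a different row count M and therefore
+    // a different base pointer into the packed A and D buffers.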
+ curr_group = work_tile_info.L_idx; + M = M_per_group[curr_group]; + } + + auto A_tensor = make_moe_tensor( + const_cast(Activations), M, K); + auto B_tensor = make_moe_tensor( + const_cast(Weights), N, K); + auto D_tensor = make_moe_tensor(Outputs, M, N); + + while (work_tile_info.is_valid()) { + auto m_coord = work_tile_info.M_idx; + auto n_coord = work_tile_info.N_idx; + auto tile_coord = make_coord(m_coord, n_coord, _, 0); + + if (did_group_change) { + curr_group = work_tile_info.L_idx; + M = M_per_group[curr_group]; + // recompute each time because the groups don't necessarily increment by 1 + for (int i = prev_group; i < curr_group; i++) { + cumulative_M += M_per_group[i]; + } + prev_group = curr_group; + + ElementA *ptr_A_curr_batch = + const_cast(Activations) + cumulative_M * K; + ElementB *ptr_B_curr_batch = + const_cast(Weights) + curr_group * K * N; + ElementD *ptr_D_curr_batch = Outputs + cumulative_M * N; + + A_tensor = make_moe_tensor(ptr_A_curr_batch, M, K); + B_tensor = + make_moe_tensor(ptr_B_curr_batch, N, K); + D_tensor = make_moe_tensor(ptr_D_curr_batch, M, N); + did_group_change = false; + } + + // After adding scaledMM mainloops, add something like + // if constexpr (!cute::is_void_v) { + // moe_gemm( + // A_tensor, B_tensor, Scales, D_tensor, tile_coord, mma); + // } else { + moe_gemm( + A_tensor, B_tensor, D_tensor, tile_coord, mma); + + // Get next work tile + work_tile_info = scheduler.fetch_next_work(work_tile_info); + did_group_change = curr_group != work_tile_info.L_idx; + } // end while loop +} + +} // namespace MoE diff --git a/examples/12_bmg_moe_gemm_cute_interface/moe_tile_scheduler.hpp b/examples/12_bmg_moe_gemm_cute_interface/moe_tile_scheduler.hpp new file mode 100644 index 0000000000..c367a61550 --- /dev/null +++ b/examples/12_bmg_moe_gemm_cute_interface/moe_tile_scheduler.hpp @@ -0,0 +1,302 @@ +/*************************************************************************************************** + * Copyright (c) 2025 Intel Corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/gemm_coord.hpp" +#include "cutlass/kernel_hardware_info.hpp" + +namespace MoE { +using namespace cutlass::gemm::kernel::detail; +using namespace cutlass; +using namespace cutlass::gemm; +using namespace cute; +/////////////////////////////////////////////////////////////////////////////// +// Adapted from xe_tile_scheduler_group.hpp +// Persistent Thread Block (TB) scheduler for MoE GEMM +template +class PersistentTileSchedulerXeMoE + : public PersistentTileSchedulerXeGroup { + // + // Data members + // + +private: + uint64_t current_work_linear_idx_ = 0; + uint64_t total_grid_size_ = 0; + int32_t *num_rows_per_expert_ = nullptr; + int32_t K_ = 0; + int32_t N_ = 0; + int32_t num_experts_ = 0; + + // Tracking current group, its starting linear idx and total tiles + struct GroupInfo { + int group_idx = 0; + uint64_t start_linear_idx = 0; + uint64_t total_tiles = 0; + } current_group_info_; + +public: + struct WorkTileInfo { + int32_t M_idx = 0; + int32_t N_idx = 0; + int32_t L_idx = 0; + bool is_valid_tile = false; + + CUTLASS_HOST_DEVICE + bool is_valid() const { return is_valid_tile; } + + CUTLASS_HOST_DEVICE + static WorkTileInfo invalid_work_tile() { return {-1, -1, -1, false}; } + }; + + using ProblemShape = typename GroupProblemShape::UnderlyingProblemShape; + using Params = PersistentTileSchedulerSm90GroupParams; + using RasterOrder = typename Params::RasterOrder; + using RasterOrderOptions = typename Params::RasterOrderOptions; + using BaseClass = PersistentTileSchedulerXeGroup; + + Params scheduler_params; + + // + // Methods + // + + // Given the inputs, computes the total number of output blocks this problem + // will compute over Note that this is only the logical size of our grid, not + // the physical grid we will actually launch. 
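+  // In practice the launch is persistent: on the order of hw_info.sm_count
+  // workgroups are launched once, and each one walks the linearized tile space
+  // in strides of the total grid size via fetch_next_work(), instead of
+  // launching one workgroup per output tile.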
+  template <class ClusterShape>
+  CUTLASS_HOST_DEVICE static dim3
+  get_tiled_cta_shape_mnl(const KernelHardwareInfo &hw_info,
+                          ClusterShape cluster_shape) {
+    uint32_t total_ctas = 0;
+    uint32_t cta_in_N_dim =
+        1; // We linearize the blocks across all the problems here
+
+    total_ctas = hw_info.sm_count;
+
+    return Params::get_tiled_cta_shape_mnl(to_gemm_coord(cluster_shape),
+                                           total_ctas, cta_in_N_dim);
+  }
+
+  template <class TileShape, class ClusterShape>
+  static Params to_underlying_arguments(
+      GroupProblemShape problem_shapes, TileShape tile_shape,
+      ClusterShape cluster_shape, KernelHardwareInfo const &hw_info,
+      typename BaseClass::Arguments const &arguments,
+      [[maybe_unused]] void *workspace = nullptr,
+      [[maybe_unused]] const uint32_t epilogue_subtile = 1,
+      [[maybe_unused]] uint32_t ktile_start_alignment_count = 1u) {
+    return BaseClass::to_underlying_arguments(problem_shapes, tile_shape,
+                                              cluster_shape, hw_info, arguments,
+                                              workspace);
+  }
+
+  // Given the inputs, computes the physical grid we should launch.
+  template <class TileShape, class ClusterShape>
+  CUTLASS_HOST_DEVICE static dim3
+  get_grid_shape([[maybe_unused]] Params const &params,
+                 GroupProblemShape problem_shapes, TileShape tile_shape,
+                 ClusterShape cluster_shape, KernelHardwareInfo hw_info,
+                 typename BaseClass::Arguments arguments,
+                 bool truncate_by_problem_size = true) {
+
+    return BaseClass::get_grid_shape(params, problem_shapes, tile_shape,
+                                     cluster_shape, hw_info, arguments,
+                                     truncate_by_problem_size);
+  }
+
+  CUTLASS_DEVICE explicit PersistentTileSchedulerXeMoE(
+      Params const &params_, int32_t *num_rows_per_expert, int32_t N, int32_t K,
+      int32_t num_experts)
+      : scheduler_params(params_) {
+    num_rows_per_expert_ = num_rows_per_expert;
+    N_ = N;
+    K_ = K;
+    num_experts_ = num_experts;
+    if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
+      current_work_linear_idx_ =
+          uint64_t(BlockIdxX()) + uint64_t(BlockIdxY()) * uint64_t(GridDimX());
+    } else {
+      current_work_linear_idx_ =
+          uint64_t(BlockIdxX()) * uint64_t(GridDimY()) + uint64_t(BlockIdxY());
+    }
+    total_grid_size_ =
+        uint64_t(GridDimX()) * uint64_t(GridDimY()) * uint64_t(GridDimZ());
+  }
+
+  CUTLASS_DEVICE
+  WorkTileInfo get_current_work() {
+    return get_current_work_for_linear_idx(current_work_linear_idx_);
+  }
+
+  CUTLASS_DEVICE
+  WorkTileInfo get_current_work_for_linear_idx(uint64_t linear_idx) {
+    return get_work_idx_m_and_n(
+        linear_idx, current_group_info_, scheduler_params.problem_shapes_,
+        scheduler_params.cta_shape_, scheduler_params.cluster_shape_,
+        scheduler_params.divmod_cluster_shape_major_,
+        scheduler_params.divmod_cluster_shape_minor_,
+        scheduler_params.divmod_cta_shape_m_,
+        scheduler_params.divmod_cta_shape_n_,
+        scheduler_params.log_swizzle_size_, scheduler_params.raster_order_);
+  }
+
+  CUTLASS_DEVICE
+  void advance_to_next_work(uint32_t advance_count = 1) {
+    current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
+  }
+
+  // get work_idx_m, work_idx_n from linear_idx while applying swizzle
+  CUTLASS_DEVICE
+  WorkTileInfo
+  get_work_idx_m_and_n(uint64_t linear_idx, struct GroupInfo &group_info,
+                       GroupProblemShape &problem_shapes, GemmCoord cta_shape,
+                       cutlass::gemm::GemmCoord cluster_shape,
+                       FastDivmodU64Pow2 const &divmod_cluster_shape_major,
+                       FastDivmodU64Pow2 const &divmod_cluster_shape_minor,
+                       FastDivmodU64 const &divmod_cta_shape_m,
+                       FastDivmodU64 const &divmod_cta_shape_n,
+                       int32_t log_swizzle_size, RasterOrder raster_order) {
+
+    bool valid_tile = true;
+    uint64_t ctas_along_m, ctas_along_n;
+    int total_problem_groups = num_experts_;
+    ctas_along_m = divmod_cta_shape_m.divide(
+        cute::shape<0>(
+            ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) +
+        divmod_cta_shape_m.divisor - 1);
+    ctas_along_n = divmod_cta_shape_n.divide(
+        cute::shape<1>(
+            ProblemShape(num_rows_per_expert_[group_info.group_idx], N_, K_)) +
+        divmod_cta_shape_n.divisor - 1);
+
+    auto problem_blocks_m =
+        round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m());
+    auto problem_blocks_n =
+        round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n());
+    group_info.total_tiles = problem_blocks_m * problem_blocks_n;
+
+    while (group_info.start_linear_idx + group_info.total_tiles <= linear_idx) {
+      group_info.group_idx++;
+
+      if (group_info.group_idx >= total_problem_groups)
+        return WorkTileInfo::invalid_work_tile();
+
+      group_info.start_linear_idx += group_info.total_tiles;
+      ctas_along_m = divmod_cta_shape_m.divide(
+          cute::shape<0>(ProblemShape(
+              num_rows_per_expert_[group_info.group_idx], N_, K_)) +
+          divmod_cta_shape_m.divisor - 1);
+      ctas_along_n = divmod_cta_shape_n.divide(
+          cute::shape<1>(ProblemShape(
+              num_rows_per_expert_[group_info.group_idx], N_, K_)) +
+          divmod_cta_shape_n.divisor - 1);
+
+      problem_blocks_m =
+          round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m());
+      problem_blocks_n =
+          round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n());
+      group_info.total_tiles = problem_blocks_m * problem_blocks_n;
+    }
+
+    uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
+    uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(
+        linear_idx - group_info.start_linear_idx);
+    divmod_cluster_shape_major(cluster_id, cluster_major_offset,
+                               blk_per_grid_dim);
+
+    // With static schedulers, we launch the grid such that all clusters are in
+    // linear (1-D) order, i.e., there can only be one cluster in the minor
+    // dimension. get_grid_shape() in the scheduler params puts
+    // cluster_shape.m()/n() as the minor dimension based on raster order
+    // AlongN/AlongM respectively. Therefore, the offset of a CTA (inside a
+    // cluster) in the minor dimension can be inferred directly from the
+    // blockIdx along the minor dimension.
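+    // The rest of this function undoes the swizzle: cluster_id is split into a
+    // swizzle offset plus a major/minor cluster coordinate within the current
+    // group, and the resulting indices are mapped to (M_idx, N_idx) according
+    // to the raster order below.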
+ if (raster_order == RasterOrder::AlongN) { + cluster_minor_offset = BlockIdxX(); + } else { + cluster_minor_offset = BlockIdxY(); + } + + uint64_t cluster_idx_minor, cluster_idx_major; + + uint64_t cluster_idx_minor_div_swizzle, extra, offset; + + offset = cluster_id & ((1 << log_swizzle_size) - 1); + extra = cluster_id >> log_swizzle_size; + + uint64_t curr_group_cluster_blk_major; + if (raster_order == RasterOrder::AlongN) { + curr_group_cluster_blk_major = + divmod_cluster_shape_major.divide(problem_blocks_n); + } else { + curr_group_cluster_blk_major = + divmod_cluster_shape_major.divide(problem_blocks_m); + } + cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major; + cluster_idx_major = extra % curr_group_cluster_blk_major; + + cluster_idx_minor = + cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; + + auto minor_work_idx = static_cast( + cluster_idx_minor * divmod_cluster_shape_minor.divisor + + cluster_minor_offset); + auto major_work_idx = static_cast( + cluster_idx_major * divmod_cluster_shape_major.divisor + + cluster_major_offset); + + if (raster_order == RasterOrder::AlongN) { + return {minor_work_idx, major_work_idx, group_info.group_idx, valid_tile}; + } else { + return {major_work_idx, minor_work_idx, group_info.group_idx, valid_tile}; + } + } + + // Kernel helper function to get next work tile + CUTLASS_DEVICE + auto fetch_next_work(WorkTileInfo work_tile_info) { + advance_to_next_work(); + return get_current_work(); + } + + // Returns the initial work tile info that will be computed over + template + CUTLASS_DEVICE WorkTileInfo initial_work_tile_info(ClusterShape) { + return get_current_work(); + } +}; + +} // namespace MoE diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index d141f5b7de..99e11c1494 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -101,16 +101,23 @@ if(CUTLASS_ENABLE_SYCL) message(STATUS "Building examples for Intel GPU targets") foreach(EXAMPLE 00_bmg_gemm + 00_bmg_gemm/legacy 01_bmg_gemm_with_collective_builder 02_bmg_gemm_mixed_dtype 03_bmg_gemm_streamk 04_bmg_grouped_gemm + 04_bmg_grouped_gemm/legacy 05_bmg_gemm_with_epilogues + 05_bmg_gemm_with_epilogues/legacy 06_bmg_flash_attention 07_bmg_dual_gemm 08_bmg_gemm_f8 09_bmg_grouped_gemm_f8 + 09_bmg_grouped_gemm_f8/legacy 10_bmg_grouped_gemm_mixed_dtype + 11_xe20_cutlass_library + 12_bmg_moe_gemm_cute_interface + sdpa_bwd ) add_subdirectory(${EXAMPLE}) endforeach() diff --git a/examples/README.md b/examples/README.md index c37ff8c495..9ac067be68 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,9 +1,9 @@ -# CUTLASS SYCL - Programming Examples +# SYCL*TLA - Programming Examples > [!IMPORTANT] > ### ⚠️ **Not for Benchmarking!** ⚠️ > -> These examples are designed **solely for demonstrating CUTLASS-SYCL functionality** and may **NOT be optimized for performance benchmarking**. +> These examples are designed **solely for demonstrating SYCL*TLA functionality** and may **NOT be optimized for performance benchmarking**. 
> ### Build Requirements @@ -29,7 +29,7 @@ examples/ └── 01_gemm_softmax/ ``` -## CUTLASS-SYCL Examples for Intel GPUs +## SYCL*TLA Examples for Intel GPUs The following examples are optimized for Intel GPU architectures using SYCL: diff --git a/examples/cute/tutorial/xe_gemm.cpp b/examples/cute/tutorial/xe_gemm.cpp index 887c23599d..ea35bb6497 100644 --- a/examples/cute/tutorial/xe_gemm.cpp +++ b/examples/cute/tutorial/xe_gemm.cpp @@ -46,8 +46,12 @@ #include "../../common/sycl_cute_common.hpp" -#pragma clang diagnostic ignored "-Wpass-failed" -#pragma clang diagnostic ignored "-Wdeprecated-declarations" +#if defined(__clang__) + #pragma clang diagnostic ignored "-Wpass-failed" + #pragma clang diagnostic ignored "-Wdeprecated-declarations" +#elif defined(__GNUC__) + #pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif using namespace cute; @@ -86,7 +90,7 @@ gemm_device(ATensor const& A, // (M,K) /* Create block 2D TiledCopies */ auto copy_a = make_block_2d_copy_A(mma, A); auto copy_b = make_block_2d_copy_B(mma, B); - auto copy_c = make_block_2d_copy_C(mma, C); + auto copy_c = make_block_2d_copy_D(mma, C); /* Slice TiledCopy/TiledMMA operations to thread (work-item) level */ auto thr_mma = mma.get_slice(local_id); @@ -170,18 +174,6 @@ gemm_device(ATensor const& A, // (M,K) copy(copy_c, tCrC, tCgC); } - - -template -struct is_complete : std::false_type {}; - -template -struct is_complete : std::true_type {}; - -template -static constexpr bool is_complete_v = is_complete::value; - - template auto choose_mma_op() @@ -205,14 +197,22 @@ choose_tiled_mma(ATensor const& A, BTensor const& B, CTensor const&) auto op = choose_mma_op(); constexpr bool byte = (cute::max(sizeof_bits_v, sizeof_bits_v) <= 8); - constexpr bool use_1x_dpas_per_k = is_constant_v<1, decltype(stride<0>(A))> // Use one DPAS in k dimension for A^T case - || (byte && is_constant_v<1, decltype(stride<0>(B))>); // pending compiler improvements (also int8 B^N) + constexpr bool a_t = is_constant_v<1, decltype(stride<0>(A))>; + constexpr bool b_n = is_constant_v<1, decltype(stride<0>(B))>; + + constexpr bool use_1x_dpas_per_k = a_t // Use one DPAS in k dimension for A^T case + || (byte && b_n); // pending compiler improvements (also int8 B^N). + constexpr bool use_4x8_sg = ((sizeof_bits_v < sizeof_bits_v) // Use smaller B loads for expensive reorders. 
+ && !(is_same_v)) + || (b_n && sizeof_bits_v < 8); using _K = conditional_t, C>; using WGTile = Shape<_256, _256, _K>; // 256x256 WG tile size - using SGLayout = Layout, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major + using SGLayout8x4 = Layout, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major + using SGLayout4x8 = Layout, Stride<_8, _1, _0>>; // 4x8 SG tiling, n-major + using SGLayout = conditional_t; using MMA = typename TiledMMAHelper, Layout, SGLayout>::TiledMMA; @@ -391,7 +391,7 @@ int main(int argc, char** argv) auto n = parse_size(); auto k = parse_size(); - sycl::queue Q; + sycl::queue Q = compat::get_default_queue(); // Native compute test_case(Q, m, n, k); @@ -414,6 +414,7 @@ int main(int argc, char** argv) test_case(Q, m, n, k); test_case(Q, m, n, k); + test_case(Q, m, n, k); // Upconversion cases test_case(Q, m, n, k); diff --git a/examples/python/cutlass_library/xe20_gemm_bf16.py b/examples/python/cutlass_library/xe20_gemm_bf16.py new file mode 100644 index 0000000000..93205b69ef --- /dev/null +++ b/examples/python/cutlass_library/xe20_gemm_bf16.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (C) 2025 Intel Corporation, All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+############################################################################### + +""" +Test the generated CUTLASS GEMM kernel (sycl_tla_gemm_xe20_bf16) +""" + +import ctypes +from ctypes import c_void_p, c_int, c_size_t, c_uint8, c_uint16, POINTER, byref +import numpy as np +import time +from pathlib import Path + + +def test_sycl_tla_gemm_xe20_bf16(): + """Test the compiled sycl_tla_gemm_xe20_bf16 function""" + + # Load the shared library + lib_path = Path(__file__).parent / '../../../build/examples/11_xe20_cutlass_library/libxe20_cutlass_library_bf16.so' + if not lib_path.exists(): + print(f"Error: {lib_path} not found!") + print("Please build the library first: ninja xe20_cutlass_library_bf16") + return + + lib = ctypes.CDLL(str(lib_path)) + + # Define function signature + # int sycl_tla_gemm_xe20_bf16( + # const uint16_t* X, const uint16_t* W, uint16_t* Y, + # const int M, const int N, const int K, const int B, + # const int lda, const int ldb, const int ldc, const int ldd, + # const int X_offset, const int W_offset, const int Y_offset, + # const uint8_t swizzle, + # size_t* workspace_size, uint8_t* workspace, sycl::queue* stream) + lib.sycl_tla_gemm_xe20_bf16.argtypes = [ + c_void_p, # X (input A) + c_void_p, # W (input B) + c_void_p, # Y (output) + c_int, # M + c_int, # N + c_int, # K + c_int, # B (batch) + c_int, # lda + c_int, # ldb + c_int, # ldc + c_int, # ldd + c_int, # X_offset + c_int, # W_offset + c_int, # Y_offset + c_uint8, # swizzle + POINTER(c_size_t), # workspace_size + c_void_p, # workspace + c_void_p, # stream (sycl::queue*) + ] + lib.sycl_tla_gemm_xe20_bf16.restype = c_int + + print("="*80) + print("Testing sycl_tla_gemm_xe20_bf16 (BF16 256x256x32 GEMM)") + print("="*80) + + # Problem dimensions (matching the kernel tile: 256x256x32) + M = 256 + N = 256 + K = 32 + B = 1 # batch size + + print(f"\nProblem size: M={M}, N={N}, K={K}, B={B}") + print(f" A: {M} x {K} (bfloat16, column-major)") + print(f" B: {K} x {N} (bfloat16, column-major)") + print(f" C: {M} x {N} (float, row-major)") + + # Leading dimensions (column-major for inputs, row-major for output) + lda = M # column-major: leading dimension is M + ldb = K # column-major: leading dimension is K + ldc = 0 # not used (ptr_C is nullptr) + ldd = N # row-major: leading dimension is N + + print(f"\nLeading dimensions: lda={lda}, ldb={ldb}, ldd={ldd}") + + # Allocate input/output matrices + # Note: Using uint16 to represent bfloat16 in memory + X = np.random.randint(0, 100, size=(M * K), dtype=np.uint16) + W = np.random.randint(0, 100, size=(K * N), dtype=np.uint16) + Y = np.zeros(M * N, dtype=np.float32) # Output is float32 + + print(f"\nAllocated matrices:") + print(f" X: {X.nbytes} bytes") + print(f" W: {W.nbytes} bytes") + print(f" Y: {Y.nbytes} bytes") + + # Query workspace size + print("\n1. 
Querying workspace size...") + workspace_size = c_size_t(0) + result = lib.sycl_tla_gemm_xe20_bf16( + c_void_p(), # X (not needed for workspace query) + c_void_p(), # W + c_void_p(), # Y + M, N, K, B, + lda, ldb, ldc, ldd, + 0, 0, 0, # offsets + 1, # swizzle + byref(workspace_size), + c_void_p(), # workspace + c_void_p(), # stream (NULL = use default) + ) + + if result != 0: + print(f" ✗ Workspace query failed with code {result}") + return + + print(f" ✓ Workspace required: {workspace_size.value} bytes") + + # Allocate workspace if needed + workspace = None + workspace_ptr = c_void_p() + if workspace_size.value > 0: + workspace = np.zeros(workspace_size.value, dtype=np.uint8) + workspace_ptr = workspace.ctypes.data_as(c_void_p) + print(f" ✓ Workspace allocated") + + # Run GEMM + print("\n2. Executing GEMM...") + + X_ptr = X.ctypes.data_as(c_void_p) + W_ptr = W.ctypes.data_as(c_void_p) + Y_ptr = Y.ctypes.data_as(c_void_p) + + # Warmup run + result = lib.sycl_tla_gemm_xe20_bf16( + X_ptr, W_ptr, Y_ptr, + M, N, K, B, + lda, ldb, ldc, ldd, + 0, 0, 0, # offsets + 1, # swizzle + None, # workspace_size (None = execute mode, not query) + workspace_ptr, + c_void_p(), # stream (NULL = use default) + ) + + if result != 0: + print(f" ✗ GEMM execution failed with code {result}") + return + + print(f" ✓ Warmup run completed") + + # Benchmark + print("\n3. Benchmarking...") + num_runs = 10 + times = [] + + for i in range(num_runs): + start = time.time() + result = lib.sycl_tla_gemm_xe20_bf16( + X_ptr, W_ptr, Y_ptr, + M, N, K, B, + lda, ldb, ldc, ldd, + 0, 0, 0, + 1, + None, # workspace_size (None = execute mode) + workspace_ptr, + c_void_p(), + ) + elapsed = time.time() - start + + if result != 0: + print(f" ✗ Run {i+1} failed with code {result}") + continue + + times.append(elapsed) + + if not times: + print(" ✗ All runs failed!") + return + + # Calculate statistics + avg_time = np.mean(times) + min_time = np.min(times) + max_time = np.max(times) + std_time = np.std(times) + + # Calculate FLOPS (2*M*N*K for GEMM) + flops = 2 * M * N * K + avg_gflops = flops / avg_time / 1e9 + peak_gflops = flops / min_time / 1e9 + + print(f"\n{'='*80}") + print(f"Performance Results ({num_runs} runs)") + print(f"{'='*80}") + print(f" Average time: {avg_time*1000:.3f} ms") + print(f" Min time: {min_time*1000:.3f} ms") + print(f" Max time: {max_time*1000:.3f} ms") + print(f" Std dev: {std_time*1000:.3f} ms") + print(f"") + print(f" Average GFLOPS: {avg_gflops:.2f}") + print(f" Peak GFLOPS: {peak_gflops:.2f}") + print(f"{'='*80}") + + # Check output (basic sanity check) + non_zero = np.count_nonzero(Y) + print(f"\nOutput sanity check:") + print(f" Non-zero elements: {non_zero}/{Y.size}") + print(f" Output range: [{Y.min():.3f}, {Y.max():.3f}]") + + return avg_gflops + + +def benchmark_multiple_sizes(): + """Benchmark different problem sizes""" + + print("\n" + "="*80) + print("Benchmarking Multiple Problem Sizes") + print("="*80) + + # Test different sizes (all should be compatible with 256x256x32 tile) + sizes = [ + (256, 256, 32), + (512, 512, 32), + (256, 256, 64), + (512, 512, 64), + (1024, 1024, 32), + ] + + # Note: This would require modifying the function to accept variable sizes + # For now, the kernel is hard-coded to 256x256x32 + print("\nNote: Current kernel is optimized for 256x256x32 tile size") + print("Multi-size benchmarking would require different kernel configurations") + + +if __name__ == "__main__": + try: + gflops = test_sycl_tla_gemm_xe20_bf16() + if gflops: + print(f"\n✓ Test completed successfully!") 
+ print(f" Average performance: {gflops:.2f} GFLOPS") + except Exception as e: + print(f"\n✗ Test failed with exception:") + print(f" {e}") + import traceback + traceback.print_exc() diff --git a/examples/sdpa_bwd/CMakeLists.txt b/examples/sdpa_bwd/CMakeLists.txt new file mode 100644 index 0000000000..8db8ae2445 --- /dev/null +++ b/examples/sdpa_bwd/CMakeLists.txt @@ -0,0 +1,10 @@ +set(TEST_GROUPS --groups=128) + +cutlass_example_add_executable( + sdpa_backward + sdpa_backward.cpp + cnpy.cpp +) +target_link_options(sdpa_backward PUBLIC "-lz") +set_target_properties(sdpa_backward PROPERTIES CXX_COMPILER_LAUNCHER "IGC_VectorAliasBBThreshold=10000" ) +set_target_properties(sdpa_backward PROPERTIES CXX_LINKER_LAUNCHER "IGC_VectorAliasBBThreshold=10000" ) diff --git a/examples/sdpa_bwd/cnpy.cpp b/examples/sdpa_bwd/cnpy.cpp new file mode 100644 index 0000000000..2d28578643 --- /dev/null +++ b/examples/sdpa_bwd/cnpy.cpp @@ -0,0 +1,340 @@ +//Copyright (C) 2011 Carl Rogers +//Released under MIT License +//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php + +#include"cnpy.h" +#include +#include +#include +#include +#include +#include +#include +#include + +char cnpy::BigEndianTest() { + int x = 1; + return (((char *)&x)[0]) ? '<' : '>'; +} + +char cnpy::map_type(const std::type_info& t) +{ + if(t == typeid(float) ) return 'f'; + if(t == typeid(double) ) return 'f'; + if(t == typeid(long double) ) return 'f'; + + if(t == typeid(int) ) return 'i'; + if(t == typeid(char) ) return 'i'; + if(t == typeid(short) ) return 'i'; + if(t == typeid(long) ) return 'i'; + if(t == typeid(long long) ) return 'i'; + + if(t == typeid(unsigned char) ) return 'u'; + if(t == typeid(unsigned short) ) return 'u'; + if(t == typeid(unsigned long) ) return 'u'; + if(t == typeid(unsigned long long) ) return 'u'; + if(t == typeid(unsigned int) ) return 'u'; + + if(t == typeid(bool) ) return 'b'; + + if(t == typeid(std::complex) ) return 'c'; + if(t == typeid(std::complex) ) return 'c'; + if(t == typeid(std::complex) ) return 'c'; + + else return '?'; +} + +template<> std::vector& cnpy::operator+=(std::vector& lhs, const std::string rhs) { + lhs.insert(lhs.end(),rhs.begin(),rhs.end()); + return lhs; +} + +template<> std::vector& cnpy::operator+=(std::vector& lhs, const char* rhs) { + //write in little endian + size_t len = strlen(rhs); + lhs.reserve(len); + for(size_t byte = 0; byte < len; byte++) { + lhs.push_back(rhs[byte]); + } + return lhs; +} + +void cnpy::parse_npy_header(unsigned char* buffer,size_t& word_size, std::vector& shape, bool& fortran_order) { + //std::string magic_string(buffer,6); + uint8_t major_version = *reinterpret_cast(buffer+6); + uint8_t minor_version = *reinterpret_cast(buffer+7); + uint16_t header_len = *reinterpret_cast(buffer+8); + std::string header(reinterpret_cast(buffer+9),header_len); + + size_t loc1, loc2; + + //fortran order + loc1 = header.find("fortran_order")+16; + fortran_order = (header.substr(loc1,4) == "True" ? true : false); + + //shape + loc1 = header.find("("); + loc2 = header.find(")"); + + std::regex num_regex("[0-9][0-9]*"); + std::smatch sm; + shape.clear(); + + std::string str_shape = header.substr(loc1+1,loc2-loc1-1); + while(std::regex_search(str_shape, sm, num_regex)) { + shape.push_back(std::stoi(sm[0].str())); + str_shape = sm.suffix().str(); + } + + //endian, word size, data type + //byte order code | stands for not applicable. 
+ //not sure when this applies except for byte array + loc1 = header.find("descr")+9; + bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? true : false); + assert(littleEndian); + + //char type = header[loc1+1]; + //assert(type == map_type(T)); + + std::string str_ws = header.substr(loc1+2); + loc2 = str_ws.find("'"); + word_size = atoi(str_ws.substr(0,loc2).c_str()); +} + +void cnpy::parse_npy_header(FILE* fp, size_t& word_size, std::vector& shape, bool& fortran_order) { + char buffer[256]; + size_t res = fread(buffer,sizeof(char),11,fp); + if(res != 11) + throw std::runtime_error("parse_npy_header: failed fread"); + std::string header = fgets(buffer,256,fp); + assert(header[header.size()-1] == '\n'); + + size_t loc1, loc2; + + //fortran order + loc1 = header.find("fortran_order"); + if (loc1 == std::string::npos) + throw std::runtime_error("parse_npy_header: failed to find header keyword: 'fortran_order'"); + loc1 += 16; + fortran_order = (header.substr(loc1,4) == "True" ? true : false); + + //shape + loc1 = header.find("("); + loc2 = header.find(")"); + if (loc1 == std::string::npos || loc2 == std::string::npos) + throw std::runtime_error("parse_npy_header: failed to find header keyword: '(' or ')'"); + + std::regex num_regex("[0-9][0-9]*"); + std::smatch sm; + shape.clear(); + + std::string str_shape = header.substr(loc1+1,loc2-loc1-1); + while(std::regex_search(str_shape, sm, num_regex)) { + shape.push_back(std::stoi(sm[0].str())); + str_shape = sm.suffix().str(); + } + + //endian, word size, data type + //byte order code | stands for not applicable. + //not sure when this applies except for byte array + loc1 = header.find("descr"); + if (loc1 == std::string::npos) + throw std::runtime_error("parse_npy_header: failed to find header keyword: 'descr'"); + loc1 += 9; + bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? 
true : false); + assert(littleEndian); + + //char type = header[loc1+1]; + //assert(type == map_type(T)); + + std::string str_ws = header.substr(loc1+2); + loc2 = str_ws.find("'"); + word_size = atoi(str_ws.substr(0,loc2).c_str()); +} + +void cnpy::parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_size, size_t& global_header_offset) +{ + std::vector footer(22); + fseek(fp,-22,SEEK_END); + size_t res = fread(&footer[0],sizeof(char),22,fp); + if(res != 22) + throw std::runtime_error("parse_zip_footer: failed fread"); + + uint16_t disk_no, disk_start, nrecs_on_disk, comment_len; + disk_no = *(uint16_t*) &footer[4]; + disk_start = *(uint16_t*) &footer[6]; + nrecs_on_disk = *(uint16_t*) &footer[8]; + nrecs = *(uint16_t*) &footer[10]; + global_header_size = *(uint32_t*) &footer[12]; + global_header_offset = *(uint32_t*) &footer[16]; + comment_len = *(uint16_t*) &footer[20]; + + assert(disk_no == 0); + assert(disk_start == 0); + assert(nrecs_on_disk == nrecs); + assert(comment_len == 0); +} + +cnpy::NpyArray load_the_npy_file(FILE* fp) { + std::vector shape; + size_t word_size; + bool fortran_order; + cnpy::parse_npy_header(fp,word_size,shape,fortran_order); + + cnpy::NpyArray arr(shape, word_size, fortran_order); + size_t nread = fread(arr.data(),1,arr.num_bytes(),fp); + if(nread != arr.num_bytes()) + throw std::runtime_error("load_the_npy_file: failed fread"); + return arr; +} + +cnpy::NpyArray load_the_npz_array(FILE* fp, uint32_t compr_bytes, uint32_t uncompr_bytes) { + + std::vector buffer_compr(compr_bytes); + std::vector buffer_uncompr(uncompr_bytes); + size_t nread = fread(&buffer_compr[0],1,compr_bytes,fp); + if(nread != compr_bytes) + throw std::runtime_error("load_the_npy_file: failed fread"); + + int err; + z_stream d_stream; + + d_stream.zalloc = Z_NULL; + d_stream.zfree = Z_NULL; + d_stream.opaque = Z_NULL; + d_stream.avail_in = 0; + d_stream.next_in = Z_NULL; + err = inflateInit2(&d_stream, -MAX_WBITS); + + d_stream.avail_in = compr_bytes; + d_stream.next_in = &buffer_compr[0]; + d_stream.avail_out = uncompr_bytes; + d_stream.next_out = &buffer_uncompr[0]; + + err = inflate(&d_stream, Z_FINISH); + err = inflateEnd(&d_stream); + + std::vector shape; + size_t word_size; + bool fortran_order; + cnpy::parse_npy_header(&buffer_uncompr[0],word_size,shape,fortran_order); + + cnpy::NpyArray array(shape, word_size, fortran_order); + + size_t offset = uncompr_bytes - array.num_bytes(); + memcpy(array.data(),&buffer_uncompr[0]+offset,array.num_bytes()); + + return array; +} + +cnpy::npz_t cnpy::npz_load(std::string fname) { + FILE* fp = fopen(fname.c_str(),"rb"); + + if(!fp) { + throw std::runtime_error("npz_load: Error! 
Unable to open file "+fname+"!"); + } + + cnpy::npz_t arrays; + + while(1) { + std::vector local_header(30); + size_t headerres = fread(&local_header[0],sizeof(char),30,fp); + if(headerres != 30) + throw std::runtime_error("npz_load: failed fread"); + + //if we've reached the global header, stop reading + if(local_header[2] != 0x03 || local_header[3] != 0x04) break; + + //read in the variable name + uint16_t name_len = *(uint16_t*) &local_header[26]; + std::string varname(name_len,' '); + size_t vname_res = fread(&varname[0],sizeof(char),name_len,fp); + if(vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + + //erase the lagging .npy + varname.erase(varname.end()-4,varname.end()); + + //read in the extra field + uint16_t extra_field_len = *(uint16_t*) &local_header[28]; + if(extra_field_len > 0) { + std::vector buff(extra_field_len); + size_t efield_res = fread(&buff[0],sizeof(char),extra_field_len,fp); + if(efield_res != extra_field_len) + throw std::runtime_error("npz_load: failed fread"); + } + + uint16_t compr_method = *reinterpret_cast(&local_header[0]+8); + uint32_t compr_bytes = *reinterpret_cast(&local_header[0]+18); + uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0]+22); + + if(compr_method == 0) {arrays[varname] = load_the_npy_file(fp);} + else {arrays[varname] = load_the_npz_array(fp,compr_bytes,uncompr_bytes);} + } + + fclose(fp); + return arrays; +} + +cnpy::NpyArray cnpy::npz_load(std::string fname, std::string varname) { + FILE* fp = fopen(fname.c_str(),"rb"); + + if(!fp) throw std::runtime_error("npz_load: Unable to open file "+fname); + + while(1) { + std::vector local_header(30); + size_t header_res = fread(&local_header[0],sizeof(char),30,fp); + if(header_res != 30) + throw std::runtime_error("npz_load: failed fread"); + + //if we've reached the global header, stop reading + if(local_header[2] != 0x03 || local_header[3] != 0x04) break; + + //read in the variable name + uint16_t name_len = *(uint16_t*) &local_header[26]; + std::string vname(name_len,' '); + size_t vname_res = fread(&vname[0],sizeof(char),name_len,fp); + if(vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + vname.erase(vname.end()-4,vname.end()); //erase the lagging .npy + + //read in the extra field + uint16_t extra_field_len = *(uint16_t*) &local_header[28]; + fseek(fp,extra_field_len,SEEK_CUR); //skip past the extra field + + uint16_t compr_method = *reinterpret_cast(&local_header[0]+8); + uint32_t compr_bytes = *reinterpret_cast(&local_header[0]+18); + uint32_t uncompr_bytes = *reinterpret_cast(&local_header[0]+22); + + if(vname == varname) { + NpyArray array = (compr_method == 0) ? 
+                load_the_npy_file(fp) : load_the_npz_array(fp,compr_bytes,uncompr_bytes);
+            fclose(fp);
+            return array;
+        }
+        else {
+            //skip past the data
+            uint32_t size = *(uint32_t*) &local_header[22];
+            fseek(fp,size,SEEK_CUR);
+        }
+    }
+
+    fclose(fp);
+
+    //if we get here, we haven't found the variable in the file
+    throw std::runtime_error("npz_load: Variable name "+varname+" not found in "+fname);
+}
+
+cnpy::NpyArray cnpy::npy_load(std::string fname) {
+
+    FILE* fp = fopen(fname.c_str(), "rb");
+
+    if(!fp) throw std::runtime_error("npy_load: Unable to open file "+fname);
+
+    NpyArray arr = load_the_npy_file(fp);
+
+    fclose(fp);
+    return arr;
+}
+
+
diff --git a/examples/sdpa_bwd/cnpy.h b/examples/sdpa_bwd/cnpy.h
new file mode 100644
index 0000000000..0d3bb4c3c2
--- /dev/null
+++ b/examples/sdpa_bwd/cnpy.h
@@ -0,0 +1,269 @@
+//Copyright (C) 2011 Carl Rogers
+//Released under MIT License
+//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php
+
+#ifndef LIBCNPY_H_
+#define LIBCNPY_H_
+
+#include <string>
+#include <stdexcept>
+#include <sstream>
+#include <vector>
+#include <cstdio>
+#include <typeinfo>
+#include <iostream>
+#include <cassert>
+#include <zlib.h>
+#include <map>
+#include <memory>
+#include <stdint.h>
+#include <numeric>
+
+namespace cnpy {
+
+    struct NpyArray {
+        NpyArray(const std::vector<size_t>& _shape, size_t _word_size, bool _fortran_order) :
+            shape(_shape), word_size(_word_size), fortran_order(_fortran_order)
+        {
+            num_vals = 1;
+            for(size_t i = 0;i < shape.size();i++) num_vals *= shape[i];
+            data_holder = std::shared_ptr<std::vector<char>>(
+                new std::vector<char>(num_vals * word_size));
+        }
+
+        NpyArray() : shape(0), word_size(0), fortran_order(0), num_vals(0) { }
+
+        template<typename T>
+        T* data() {
+            return reinterpret_cast<T*>(&(*data_holder)[0]);
+        }
+
+        template<typename T>
+        const T* data() const {
+            return reinterpret_cast<T*>(&(*data_holder)[0]);
+        }
+
+        template<typename T>
+        std::vector<T> as_vec() const {
+            const T* p = data<T>();
+            return std::vector<T>(p, p+num_vals);
+        }
+
+        size_t num_bytes() const {
+            return data_holder->size();
+        }
+
+        std::shared_ptr<std::vector<char>> data_holder;
+        std::vector<size_t> shape;
+        size_t word_size;
+        bool fortran_order;
+        size_t num_vals;
+    };
+
+    using npz_t = std::map<std::string, NpyArray>;
+
+    char BigEndianTest();
+    char map_type(const std::type_info& t);
+    template<typename T> std::vector<char> create_npy_header(const std::vector<size_t>& shape);
+    void parse_npy_header(FILE* fp,size_t& word_size, std::vector<size_t>& shape, bool& fortran_order);
+    void parse_npy_header(unsigned char* buffer,size_t& word_size, std::vector<size_t>& shape, bool& fortran_order);
+    void parse_zip_footer(FILE* fp, uint16_t& nrecs, size_t& global_header_size, size_t& global_header_offset);
+    npz_t npz_load(std::string fname);
+    NpyArray npz_load(std::string fname, std::string varname);
+    NpyArray npy_load(std::string fname);
+
+    template<typename T> std::vector<char>& operator+=(std::vector<char>& lhs, const T rhs) {
+        //write in little endian
+        for(size_t byte = 0; byte < sizeof(T); byte++) {
+            char val = *((char*)&rhs+byte);
+            lhs.push_back(val);
+        }
+        return lhs;
+    }
+
+    template<> std::vector<char>& operator+=(std::vector<char>& lhs, const std::string rhs);
+    template<> std::vector<char>& operator+=(std::vector<char>& lhs, const char* rhs);
+
+
+    template<typename T> void npy_save(std::string fname, const T* data, const std::vector<size_t> shape, std::string mode = "w") {
+        FILE* fp = NULL;
+        std::vector<size_t> true_data_shape; //if appending, the shape of existing + new data
+
+        if(mode == "a") fp = fopen(fname.c_str(),"r+b");
+
+        if(fp) {
+            //file exists. we need to append to it.
read the header, modify the array size + size_t word_size; + bool fortran_order; + parse_npy_header(fp,word_size,true_data_shape,fortran_order); + assert(!fortran_order); + + if(word_size != sizeof(T)) { + std::cout<<"libnpy error: "< header = create_npy_header(true_data_shape); + size_t nels = std::accumulate(shape.begin(),shape.end(),1,std::multiplies()); + + fseek(fp,0,SEEK_SET); + fwrite(&header[0],sizeof(char),header.size(),fp); + fseek(fp,0,SEEK_END); + fwrite(data,sizeof(T),nels,fp); + fclose(fp); + } + + template void npz_save(std::string zipname, std::string fname, const T* data, const std::vector& shape, std::string mode = "w") + { + //first, append a .npy to the fname + fname += ".npy"; + + //now, on with the show + FILE* fp = NULL; + uint16_t nrecs = 0; + size_t global_header_offset = 0; + std::vector global_header; + + if(mode == "a") fp = fopen(zipname.c_str(),"r+b"); + + if(fp) { + //zip file exists. we need to add a new npy file to it. + //first read the footer. this gives us the offset and size of the global header + //then read and store the global header. + //below, we will write the the new data at the start of the global header then append the global header and footer below it + size_t global_header_size; + parse_zip_footer(fp,nrecs,global_header_size,global_header_offset); + fseek(fp,global_header_offset,SEEK_SET); + global_header.resize(global_header_size); + size_t res = fread(&global_header[0],sizeof(char),global_header_size,fp); + if(res != global_header_size){ + throw std::runtime_error("npz_save: header read error while adding to existing zip"); + } + fseek(fp,global_header_offset,SEEK_SET); + } + else { + fp = fopen(zipname.c_str(),"wb"); + } + + std::vector npy_header = create_npy_header(shape); + + size_t nels = std::accumulate(shape.begin(),shape.end(),1,std::multiplies()); + size_t nbytes = nels*sizeof(T) + npy_header.size(); + + //get the CRC of the data to be added + uint32_t crc = crc32(0L,(uint8_t*)&npy_header[0],npy_header.size()); + crc = crc32(crc,(uint8_t*)data,nels*sizeof(T)); + + //build the local header + std::vector local_header; + local_header += "PK"; //first part of sig + local_header += (uint16_t) 0x0403; //second part of sig + local_header += (uint16_t) 20; //min version to extract + local_header += (uint16_t) 0; //general purpose bit flag + local_header += (uint16_t) 0; //compression method + local_header += (uint16_t) 0; //file last mod time + local_header += (uint16_t) 0; //file last mod date + local_header += (uint32_t) crc; //crc + local_header += (uint32_t) nbytes; //compressed size + local_header += (uint32_t) nbytes; //uncompressed size + local_header += (uint16_t) fname.size(); //fname length + local_header += (uint16_t) 0; //extra field length + local_header += fname; + + //build global header + global_header += "PK"; //first part of sig + global_header += (uint16_t) 0x0201; //second part of sig + global_header += (uint16_t) 20; //version made by + global_header.insert(global_header.end(),local_header.begin()+4,local_header.begin()+30); + global_header += (uint16_t) 0; //file comment length + global_header += (uint16_t) 0; //disk number where file starts + global_header += (uint16_t) 0; //internal file attributes + global_header += (uint32_t) 0; //external file attributes + global_header += (uint32_t) global_header_offset; //relative offset of local file header, since it begins where the global header used to begin + global_header += fname; + + //build footer + std::vector footer; + footer += "PK"; //first part of sig + footer += 
(uint16_t) 0x0605; //second part of sig + footer += (uint16_t) 0; //number of this disk + footer += (uint16_t) 0; //disk where footer starts + footer += (uint16_t) (nrecs+1); //number of records on this disk + footer += (uint16_t) (nrecs+1); //total number of records + footer += (uint32_t) global_header.size(); //nbytes of global headers + footer += (uint32_t) (global_header_offset + nbytes + local_header.size()); //offset of start of global headers, since global header now starts after newly written array + footer += (uint16_t) 0; //zip file comment length + + //write everything + fwrite(&local_header[0],sizeof(char),local_header.size(),fp); + fwrite(&npy_header[0],sizeof(char),npy_header.size(),fp); + fwrite(data,sizeof(T),nels,fp); + fwrite(&global_header[0],sizeof(char),global_header.size(),fp); + fwrite(&footer[0],sizeof(char),footer.size(),fp); + fclose(fp); + } + + template void npy_save(std::string fname, const std::vector data, std::string mode = "w") { + std::vector shape; + shape.push_back(data.size()); + npy_save(fname, &data[0], shape, mode); + } + + template void npz_save(std::string zipname, std::string fname, const std::vector data, std::string mode = "w") { + std::vector shape; + shape.push_back(data.size()); + npz_save(zipname, fname, &data[0], shape, mode); + } + + template std::vector create_npy_header(const std::vector& shape) { + + std::vector dict; + dict += "{'descr': '"; + dict += BigEndianTest(); + dict += map_type(typeid(T)); + dict += std::to_string(sizeof(T)); + dict += "', 'fortran_order': False, 'shape': ("; + dict += std::to_string(shape[0]); + for(size_t i = 1;i < shape.size();i++) { + dict += ", "; + dict += std::to_string(shape[i]); + } + if(shape.size() == 1) dict += ","; + dict += "), }"; + //pad with spaces so that preamble+dict is modulo 16 bytes. preamble is 10 bytes. 
dict needs to end with \n + int remainder = 16 - (10 + dict.size()) % 16; + dict.insert(dict.end(),remainder,' '); + dict.back() = '\n'; + + std::vector header; + header += (char) 0x93; + header += "NUMPY"; + header += (char) 0x01; //major version of numpy format + header += (char) 0x00; //minor version of numpy format + header += (uint16_t) dict.size(); + header.insert(header.end(),dict.begin(),dict.end()); + + return header; + } + + +} + +#endif diff --git a/examples/sdpa_bwd/params.hpp b/examples/sdpa_bwd/params.hpp new file mode 100644 index 0000000000..2011273c02 --- /dev/null +++ b/examples/sdpa_bwd/params.hpp @@ -0,0 +1,272 @@ +#pragma once +#include +#include +using namespace cute; + +template +struct FAKernel { + /* + Q BATCH,NUM_HEAD_Q,SEQ_LEN_QO,HEAD_SIZE_QK + K BATCH,NUM_HEAD_KV,SEQ_LEN_KV,HEAD_SIZE_QK + V BATCH,NUM_HEAD_KV,SEQ_LEN_KV,HEAD_SIZE_VO + P BATCH,NUM_HEAD_Q,SEQ_LEN_QO,SEQ_LEN_KV + O BATCH,NUM_HEAD_Q,SEQ_LEN_QO,HEAD_SIZE_VO + */ + using DType = T_; + using VType = float; // accumulation + static constexpr int kHeadDim = kHeadDim_; + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kNSGs = kNSGs_; + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static constexpr int AtomLayoutNdKV = AtomLayoutNdKV_; + static constexpr int AtomLayoutMdQ = AtomLayoutMdQ_; + static constexpr bool is_causal = is_causal_; + using MMA_Atom_ARCH = XE_DPAS_TT<8, VType, DType>; + using _K = Int; + using SubgroupLayoutSdP = Layout, Int, _1>>; + using SubgroupLayoutdKV = Layout, Int, _1>>; + using SubgroupLayoutdQ = Layout, Int, _1>>; + using TileShapeSdP = Layout, Int, _K>>; + static_assert(size<0>(TileShapeSdP{}) <= kBlockM && "tile size M must be smaller than or equal to kBlockM"); + static_assert(kBlockM % size<0>(TileShapeSdP{}) == 0 && "kBlockM must be dividable by tile size M"); + static_assert(size<1>(TileShapeSdP{}) <= kBlockN && "tile size N must be smaller than or equal to kBlockN"); + static_assert(kBlockN % size<1>(TileShapeSdP{}) == 0 && "kBlockN must be dividable by tile size N "); + + using TileShapedKV = Layout, Int, _K>>; + static_assert(size<0>(TileShapedKV{}) <= kBlockN && "tile size M must be smaller than or equal to kBlockN"); + static_assert(kBlockN % size<0>(TileShapedKV{}) == 0 && "kBlockN must be dividable by tile size M"); + static_assert(size<1>(TileShapedKV{}) <= kHeadDim && "tile size N must be smaller than or equal to kHeadDim"); + static_assert(kHeadDim % size<1>(TileShapedKV{}) == 0 && "kHeadDim must be dividable by tile size N"); + + using TileShapedQ = Layout, Int, _K>>; + static_assert(size<0>(TileShapedQ{}) <= kBlockM && "tile size M must be smaller than or equal to kBlockM"); + static_assert(kBlockM % size<0>(TileShapedQ{}) == 0 && "kBlockM must dividable by tile size M"); + static_assert(size<1>(TileShapedQ{}) <= kHeadDim && "tile size N must be smaller than or equal to kHeadDim"); + static_assert(kHeadDim % size<1>(TileShapedQ{}) == 0 && "kHeadDim must be dividable by tile size N"); + + using TiledMmaSdP = typename TiledMMAHelper, + TileShapeSdP, + SubgroupLayoutSdP>::TiledMMA; + + using TiledMmadKV = typename TiledMMAHelper, + TileShapedKV, + SubgroupLayoutdKV>::TiledMMA; + + using TiledMmadQ = typename TiledMMAHelper, + TileShapedQ, + SubgroupLayoutdQ>::TiledMMA; + static constexpr auto bP = Int<2>{}; // Pipeline + static constexpr int SubgroupSize = 16; + static constexpr int smem_size = 0; + + FAKernel() {} +}; + +using index_t = uint64_t; + +template +struct Param { + Param(const T 
*dO, + const T *o, + const T *q, + const T *k, + const T *v, + const float *lse, + float *odo, + float *dqaccum, + T *dq, + T *dk, + T *dv, + T *pb, + const float softmax_scale) + : do_ptr(dO), + o_ptr(o), + q_ptr(q), + k_ptr(k), + v_ptr(v), + lse_ptr(lse), + odo_ptr(odo), + dqaccum_ptr(dqaccum), + dq_ptr(dq), + dk_ptr(dk), + dv_ptr(dv), + pb_ptr(pb), + scale_softmax(softmax_scale), + scale_softmax_log2(softmax_scale * M_LOG2E), + is_bhsd(true) {} + // read only + const T *do_ptr; + const T *o_ptr; + const T *q_ptr; + const T *k_ptr; + const T *v_ptr; + const float *lse_ptr; + const float scale_softmax; + const float scale_softmax_log2; + // write + float *odo_ptr; + float *dqaccum_ptr; + T *dq_ptr; + T *dk_ptr; + T *dv_ptr; + T *pb_ptr; + + // const dimension + int batch; + int num_head_q; + int num_head_kv; + int seq_len_q; + int seq_len_q_pad; + int seq_len_kv; + int seq_len_kv_pad; + int head_dim; + int n_block; + int tail_n; + int m_block; + int tail_m; + int num_qh_per_kvh; + int num_nb_per_blk; + int q_r_stride; + int q_h_stride; + int q_b_stride; + + int k_r_stride; + int k_h_stride; + int k_b_stride; + + int dk_r_stride; + int dk_h_stride; + int dk_b_stride; + + int v_r_stride; + int v_h_stride; + int v_b_stride; + + int dv_r_stride; + int dv_h_stride; + int dv_b_stride; + + int o_r_stride; + int o_h_stride; + int o_b_stride; + + int s_r_stride; + int s_s_stride; + int s_b_stride; + + int dq_r_stride; + int dq_h_stride; + int dq_b_stride; + /* + * input output layout + * true batch, numhead, seqlen, headsize + * false batch, seqlen, numhead, headsize + */ + bool is_bhsd; +}; + +template +struct Boffset { + Boffset(Param ¶m_) : param(param_) {} + index_t q_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.q_b_stride + h_id * param.q_h_stride + s_id * param.q_r_stride; + } + index_t k_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.k_b_stride + h_id * param.k_h_stride + s_id * param.k_r_stride; + } + index_t v_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.v_b_stride + h_id * param.v_h_stride + s_id * param.v_r_stride; + } + index_t dk_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.dk_b_stride + h_id * param.dk_h_stride + s_id * param.dk_r_stride; + } + index_t dv_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.dv_b_stride + h_id * param.dv_h_stride + s_id * param.dv_r_stride; + } + index_t lse_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.seq_len_q * param.num_head_q + h_id * param.seq_len_q + s_id; + } + + index_t o_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.o_b_stride + h_id * param.o_h_stride + s_id * param.o_r_stride; + } + + index_t dq_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.dq_b_stride + h_id * param.dq_h_stride + s_id * param.dq_r_stride; + } + Param ¶m; +}; + +// for debug +template +void setup_bhsd_stride(Param ¶m) { + param.q_r_stride = param.head_dim; + param.q_h_stride = param.seq_len_q * param.head_dim; + param.q_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + // param.dq_r_stride = param.head_dim; + // param.dq_h_stride = param.seq_len_q * param.head_dim; + // param.dq_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.k_r_stride = param.head_dim; + param.k_h_stride = 
param.seq_len_kv * param.head_dim; + param.k_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.dk_r_stride = param.head_dim; + param.dk_h_stride = param.seq_len_kv * param.head_dim; + param.dk_b_stride = param.num_head_q * param.seq_len_kv * param.head_dim; + + param.v_r_stride = param.head_dim; + param.v_h_stride = param.seq_len_kv * param.head_dim; + param.v_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.dv_r_stride = param.head_dim; + param.dv_h_stride = param.seq_len_kv * param.head_dim; + param.dv_b_stride = param.num_head_q * param.seq_len_kv * param.head_dim; + + param.o_r_stride = param.head_dim; + param.o_h_stride = param.seq_len_q * param.head_dim; + param.o_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.dq_r_stride = param.head_dim; + param.dq_h_stride = param.seq_len_q_pad * param.head_dim; + param.dq_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; +} + +template +void setup_bshd_stride(Param ¶m) { + param.q_r_stride = param.num_head_q * param.head_dim; + param.q_h_stride = param.head_dim; + param.q_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + // param.dq_r_stride = param.head_dim; + // param.dq_h_stride = param.seq_len_q * param.head_dim; + // param.dq_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.k_r_stride = param.num_head_kv * param.head_dim; + param.k_h_stride = param.head_dim; + param.k_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.dk_r_stride = param.num_head_q * param.head_dim; + param.dk_h_stride = param.head_dim; + param.dk_b_stride = param.num_head_q * param.seq_len_kv * param.head_dim; + + param.v_r_stride = param.num_head_kv * param.head_dim; + param.v_h_stride = param.head_dim; + param.v_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.dv_r_stride = param.num_head_q * param.head_dim; + param.dv_h_stride = param.head_dim; + param.dv_b_stride = param.num_head_q * param.seq_len_kv * param.head_dim; + + param.o_r_stride = param.num_head_q * param.head_dim; + param.o_h_stride = param.head_dim; + param.o_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.dq_r_stride = param.num_head_q * param.head_dim; + param.dq_h_stride = param.head_dim; + param.dq_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; +} diff --git a/examples/sdpa_bwd/sdpa_backward.cpp b/examples/sdpa_bwd/sdpa_backward.cpp new file mode 100644 index 0000000000..d1ff577089 --- /dev/null +++ b/examples/sdpa_bwd/sdpa_backward.cpp @@ -0,0 +1,1385 @@ +#include +#include +#include +#include + +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/sycl_event_manager.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cnpy.h" +#include "sdpa_util.hpp" +#include "params.hpp" + + +void read_args(int argc, char**argv, int n, int64_t *p) { + if (argc >= n + 1) + sscanf(argv[n], "%ld", p); +} + +void +debug_info() { + print("block idx (%d,%d,%d) dim (%d,%d,%d) thread idx (%d,%d,%d) dim (%d,%d,%d)\n", + BlockIdxX(), BlockIdxY(), BlockIdxZ(), + GridDimX(), GridDimY(), GridDimZ(), + ThreadIdxX(), ThreadIdxY(), ThreadIdxZ(), + BlockDimX(), BlockDimY(), BlockDimZ()); +} + +template +void print_t(Tensor &r) { + print(r); + for (int i = 0; i < size(r); ++i) { + if (i % 8 == 0) + print("\n(%03d): ", i / 8); + print("%10.7f ", (float)r(i)); + } + print("\n"); +} + 
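+// ---------------------------------------------------------------------------
+// Outline of the SDPA backward pass as implemented by the kernels below
+// (inputs and reference outputs are read from "mha.npz" in main(); keys:
+//  q, k, v, grad, out, lse, odo, q_grad, k_grad, v_grad, shape):
+//   1) mha_dot_do_o:     dPsum = rowsum(dO .* O) per query row, and the float
+//                        dQ accumulator (dqaccum) is zero-initialized.
+//   2) mha_backward_seq: for each KV column block,
+//                          S  = Q * K^T,   P  = exp2(S*scale*log2(e) - LSE*log2(e)),
+//                          dP = dO * V^T,  dS = P .* (dP - dPsum) * scale,
+//                          dV += P^T * dO, dK += dS^T * Q,
+//                          and dQ is accumulated into dqaccum via atomic adds.
+//   3) mhd_convert_dq:   casts the float dqaccum back to the input dtype dQ.
+// ---------------------------------------------------------------------------
+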
+template +void print_t(T1 m, T2 g) { + print(m); + for (int i = 0; i < size(g); ++i) { + if (i % 8 == 0) + print("\n(%03d): ", i / 8); + print("%10.7f ", (float)m(g(i))); + } + print("\n"); +} + +template +void print_t_2d(T t) { + static_assert(rank(t) == 2, "Only support 2D Tensor"); + print(t); + for (int i = 0; i < size < 0>(t); ++i) { + print("\n(%03d): ", i); + for (int j = 0; j < size<1>(t); ++j) { + print("%10.7f ", (float)t(i,j)); + } + } + print("\n"); +} + +template +void print_d(T t) { + print(t); + for (int i = 0; i < size(t); ++i) { + if (i % 8 == 0) + print("\n(%03d): ", i / 8); + print("%10u ", t(i)); + } + print("\n"); +} + +template +void print_c(T t) { + print(t); + for (int i = 0; i < size(t); ++i) { + if (i % 8 == 0) + print("\n(%03d): ", i / 8); + print(t(i)); + } + print("\n"); +} + +using ProblemShapeRegular = cute::tuple; // batch, num_head_q,num_head_kv,seq_len_qo,seq_len_kv,head_size_qk,head_size_vo + +template +auto convert_layout_2d_layout(Layout layout) { + auto l = make_layout(make_layout(get<0>(layout), + get<1>(layout)), + get<2>(layout)); + return l; +} + +constexpr int tid = 0; +constexpr int bid = 0; + +const bool +is_cur_thread() { + return cute::thread(tid, bid); +} + +template +CUTLASS_DEVICE void +apply_mask_causal(Tensor &tensor, + Tensor &rC, + int m_offset, int n_offset, int diagonal_offset = 0) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto group = compat::get_nd_item<1>().get_group(); + int sg_local_id = sg.get_local_id(); + int sg_group_id = sg.get_group_id(); + Tensor rC_2d = make_tensor( + rC.data(), + convert_layout_2d_layout(rC.layout())); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<1>(tensor); ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tensor); ++m) { + int x = n_offset + get<1>(rC_2d(m, n)) + sg_local_id; + int y = m_offset + get<0>(rC_2d(m, n)) + diagonal_offset; + if (x > y) { + tensor(m, n) = -INFINITY; + } + } + } + return; +} + +template +auto +create_reg(Trait const &trait, + MTensor const &C, + TiledMMA const &tiled_mma) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * Trait::SubgroupSize; + auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + + Tensor cC = make_identity_tensor(C.shape()); // (M,N) + auto tile_mnk = tiled_mma.tile_mnk(); + Tensor gC = local_tile(cC, select<0, 1>(tile_mnk), make_coord(0, 0)); // (BLK_M,BLK_N) + auto copy_c = make_block_2d_copy_D(tiled_mma, C); + auto thr_copy_c = copy_c.get_slice(first_thread_in_sg_idx); + if constexpr(is_same_v) { + auto r32 = thr_mma.partition_sg_fragment_C(make_identity_tensor(select<0,1>(tile_mnk))); // allocate C fragment storage + return r32; + } else { + auto r16 = thr_copy_c.partition_sg_fragment_S(gC); + return r16; + } +} + +template +void +gemm_kernel(Trait &trait, + Tensor const& A, // (M,K) + Tensor const& B, // (N,K) + SubgroupTensor & acc, + TiledMMA const & mma, + const int m_block = 0, + const int n_block = 0) { + // ----- + // Setup + // ----- + + /* Get workgroup and local IDs */ + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * Trait::SubgroupSize; + + /* Create proxy coordinate tensors for each global tensor */ + Tensor cA = make_identity_tensor(A.shape()); // (M,K) + Tensor cB = make_identity_tensor(B.shape()); // (N,K) + + auto tile_mnk = mma.tile_mnk(); + + Tensor gA = local_tile(cA, select<0,2>(tile_mnk), make_coord(m_block,_)); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(cB, 
select<1,2>(tile_mnk), make_coord(n_block,_)); // (BLK_N,BLK_K,k) + + /* Create block 2D TiledCopies */ + auto copy_a = make_block_2d_copy_A(mma, A); + auto copy_b = make_block_2d_copy_B(mma, B); + + /* Slice TiledCopy/TiledMMA operations to thread (work-item) level */ + auto thr_mma = mma.get_slice(first_thread_in_sg_idx); + auto thr_copy_a = copy_a.get_slice(first_thread_in_sg_idx); + auto thr_copy_b = copy_b.get_slice(first_thread_in_sg_idx); + + /* Register fragments for MMA */ + auto tCrA = thr_mma.partition_sg_fragment_A(gA(_,_,0)); + auto tCrB = thr_mma.partition_sg_fragment_B(gB(_,_,0)); + + /* Register fragments for copies */ + auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_,_,0)); + auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_,_,0)); + + /* Partition global tensor (proxies) for copies */ + Tensor tAgA = thr_copy_a.partition_S(gA); + Tensor tBgB = thr_copy_b.partition_S(gB); + + /* Partition C */ + // Tensor tCrC = partition_fragment_C(mma, select<0,1>(tile_mnk)); + + /* Create prefetch TiledCopy instances */ + auto prefetch_a = make_block_2d_prefetch(copy_a); + auto prefetch_b = make_block_2d_prefetch(copy_b); + + auto thr_prefetch_A = prefetch_a.get_slice(first_thread_in_sg_idx); + auto thr_prefetch_B = prefetch_b.get_slice(first_thread_in_sg_idx); + + /* Partition global tensor (proxies) for prefetch */ + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + + /* Prefetch distance, in units of k tiles */ + const int prefetch_dist = 3; + + // ------ + // Kernel + // ------ + + constexpr int barrier_scope = 2; + + int k_tile_count = ceil_div(shape<1>(A), get<2>(tile_mnk)); + int k_tile_prefetch = 0; + /* Clear the accumulators */ + if constexpr(clear_acc) + clear(acc); + + /* Warm up loops with prefetch to L1 */ + CUTE_UNROLL + for (; k_tile_prefetch < prefetch_dist; k_tile_prefetch++) { + prefetch(prefetch_a, pAgA(_,_,_,k_tile_prefetch)); + prefetch(prefetch_b, pBgB(_,_,_,k_tile_prefetch)); + } + + /* Main loop */ + for (int k_tile = 0; k_tile < k_tile_count; k_tile++, k_tile_prefetch++) { + /* Split barrier keeping threads loosely together */ + barrier_arrive(barrier_scope); + + /* Copy A/B from global memory (ideally L1 cache) to registers */ + copy(copy_a, tAgA(_,_,_,k_tile), tArA); + copy(copy_b, tBgB(_,_,_,k_tile), tBrB); + + /* Prefetch A/B tiles to L1 */ + prefetch(prefetch_a, pAgA(_,_,_,k_tile_prefetch)); + prefetch(prefetch_b, pBgB(_,_,_,k_tile_prefetch)); + + /* Shuffle data from copy fragments to MMA fragments */ + reorder(tArA, tCrA); + reorder(tBrB, tCrB); + + /* Accumulate C += A * B */ + gemm(mma, tCrA, tCrB, acc); + + /* Other half of split barrier */ + barrier_wait(barrier_scope); + } +} + +template +void +gemm_SdP(Trait &trait, + Tensor const& A, // (M,K) + Tensor const& B, // (N,K) + SubgroupTensor & rSdP, + TiledMMA const & mma) { + gemm_kernel(trait, A, B, rSdP, mma); +} + +template +void +gemm_dKV(Trait &trait, + Tensor const& A, // (M,K) + Tensor const& B, // (N,K) + SubgroupTensor & rdKV, + TiledMMA const & mma) { + gemm_kernel(trait, A, B, rdKV, mma); +} + +template +void +gemm_dQ(Trait &trait, + Tensor const& A, // (M,K) + Tensor const& B, // (N,K) + Tensor const& C, // (M,N) + TiledMMA const & mma, + const int m_block = 0, + const int n_block = 0) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + auto tile_mnk = mma.tile_mnk(); + Tensor cC = make_identity_tensor(C.shape()); // (M,N) + Tensor gC = local_tile(cC, 
select<0, 1>(tile_mnk), make_coord(m_block, n_block)); // (BLK_M,BLK_N) + auto thr_mma = mma.get_slice(first_thread_in_sg_idx); + auto tCrC = thr_mma.partition_sg_fragment_C(make_identity_tensor(select<0,1>(tile_mnk))); // allocate C fragment storage + Tensor tCgC = thr_mma.partition_C(gC); + gemm_kernel(trait, A, B, tCrC, mma, m_block, n_block); + int local_id = sg.get_local_id(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCgC); ++i) { + auto [m, n] = tCgC(i); + cutlass::atomicAdd(&C(m, n + local_id), tCrC(i)); + } +} + +template +void +mha_copy(Trait & trait, TiledMma &tiled_mma, + SubgroupTensor &r, + Tensor &m, + int m_block = 0, int n_block = 0) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + auto copy_c = make_block_2d_copy_D(tiled_mma, m); + auto thr_copy_c = copy_c.get_slice(first_thread_in_sg_idx); + auto tile_mnk = tiled_mma.tile_mnk(); + Tensor cC = make_identity_tensor(m.shape()); + Tensor gC = local_tile(cC, select<0, 1>(tile_mnk), make_coord(m_block, n_block)); + Tensor tCgC = thr_copy_c.partition_D(gC); + copy(copy_c, r, tCgC); +} + +template +void +mha_reorder_copy(Trait & trait, TiledMma &tiled_mma, + SubgroupTensor &r, + Tensor &m){ + auto r16 = create_reg(trait, m, tiled_mma); + reorder(r, r16); + mha_copy(trait, tiled_mma, r16, m); +} + +template +CUTLASS_DEVICE void +mha_atomic_add(Tensor& m_tile, + Tensor& g_tile, + Tensor& r_tile, + const int local_id) { + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<3>(g_tile); ++ni) { + auto g = g_tile(_, _, _, ni); + auto r = r_tile(_, _, _, ni); + CUTLASS_PRAGMA_UNROLL + for (int ki = 0; ki < size(g); ++ki) { + auto [m, n, l] = g(ki); + cutlass::atomicAdd(&m_tile(m, n + local_id, 0), r(ki)); + } + } +} +template +CUTLASS_DEVICE void +load_1colvec(Tensor0 ®, Tensor1 &mT, Tensor2 &coord_row, + int tail_m = 0) { + if constexpr(Is_even_M) { + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size(reg); ++mi) { + reg(mi) = mT(get<0>(coord_row(mi))); + } + } else { + for (int mi = 0; mi < size(reg); ++mi) { + int row = get<0>(coord_row(mi)); + if (row < tail_m) { + reg(mi) = mT(row); + } + } + } +} + +template +CUTLASS_DEVICE auto convert_layout_acc_layout(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 8); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_1>{}); // ((2, 2), MMA_M, MMA_N, Tile_M, M, N) + auto l2 = make_layout(make_layout(get<0, 1>(l), get<1>(l)), + make_layout(get<2>(l))); + return l2; +} + +template +CUTLASS_DEVICE void scale_apply_exp2(Tensor &tensor, Tensor &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const float max_scaled = max(mi) == -INFINITY ? 
0.f : max(mi) * M_LOG2E; + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(tensor); ++ni) { + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +template +CUTLASS_DEVICE void softmax_backward(Tensor0 &P, Tensor1 &dP_sum, Tensor2 &dP, const float scale) { + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(dP); ++mi) { + CUTLASS_PRAGMA_UNROLL + for (int mj = 0; mj < size<1>(dP); ++mj) { + dP(mi, mj) = P(mi, mj) * (dP(mi, mj) - dP_sum(mi)) * scale; + } + } +} + +template +CUTLASS_DEVICE auto +softmax_forward(Tensor0 &scores, Tensor1 &tSrS, + Tensor2 &taccScS, Tensor3 &mLSE, + Tensor4 & mdPsum, Tensor5 &tPgP, + TilesaveP &tilesaveP, Param ¶m, + int m_block, int n_block, + bool Is_even_M, int tail_m = 0) { + if constexpr(is_causal) { + Tensor taccScS_rc = logical_divide(taccScS, Shape<_1>{}); + apply_mask_causal(scores, taccScS_rc, + m_block * kBlockM, n_block * kBlockN, + param.seq_len_kv - param.seq_len_q); + } + Tensor taccScS_row = logical_divide(taccScS, Shape<_1>{})(make_coord(0, _), _, 0); + Tensor lse = make_tensor(Shape>{}); + + if (Is_even_M) { + load_1colvec(lse, mLSE, taccScS_row); + } else { + load_1colvec(lse, mLSE, taccScS_row, tail_m); + } + + // P=softmax(S,lse) + scale_apply_exp2(scores, lse, param.scale_softmax_log2); + if (Is_even_M) + load_1colvec(lse, mdPsum, taccScS_row); + else + load_1colvec(lse, mdPsum, taccScS_row, tail_m); + return lse; +} + +template +void +dq_dk_dv_1colblock(Trait &trait, Param ¶m, + const int bidb, const int bidh, const int bidhkv, const int n_block, + const int tail_n = 0) { + using T = typename Trait::DType; + using V = typename Trait::VType; + constexpr int kHeadDim = Trait::kHeadDim; + constexpr int kBlockM = Trait::kBlockM; + constexpr int kBlockN = Trait::kBlockN; + constexpr int kNSGs = Trait::kNSGs; + constexpr int SubgroupSize = Trait::SubgroupSize; + constexpr int AtomLayoutMdQ = Trait::AtomLayoutMdQ; + constexpr bool is_causal = Trait::is_causal; + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto group = compat::get_nd_item<1>().get_group(); + const int local_id = sg.get_local_id(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + auto bofst = Boffset(param); + + const index_t q_offset = bofst.q_offset(bidb, bidh, 0); + const index_t k_offset = bofst.k_offset(bidb, bidhkv, n_block * kBlockN); + const index_t v_offset = bofst.v_offset(bidb, bidhkv, n_block * kBlockN); + const index_t dk_offset = bofst.dk_offset(bidb, bidh, n_block * kBlockN); + const index_t dv_offset = bofst.dv_offset(bidb, bidh, n_block * kBlockN); + const index_t o_offset = bofst.o_offset(bidb, bidh, 0); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, 0); + const index_t lse_offset = bofst.lse_offset(bidb, bidh, 0); + // buff offset + const index_t pb_offset = (bidb * param.num_head_q * param.seq_len_kv_pad * kBlockM + + bidh * param.seq_len_kv_pad * kBlockM + + n_block * kBlockN * kBlockM) * 2; + const index_t dsb_offset = pb_offset + kBlockN * kBlockM; + + const auto block_n_dim = tail_n == 0 ? 
Int{} : ((tail_n + 1) & ~1); + auto shapeO = make_shape(kBlockM, Int{}); + auto shapeQtOt = make_shape(Int{}, kBlockM); + auto shapeSP = make_shape(kBlockM, block_n_dim); + auto shapePt = make_shape(block_n_dim, kBlockM); + + using Shape1 = Shape< + std::conditional_t, int>, Int>; + using Shape2 = Shape< + Int , + std::conditional_t, int>>; + auto shapeQ = make_shape(kBlockM, Int{}); + auto shapedQ = Shape, Int>{}; + Shape1 shapeKtV; + Shape2 shapeK; + if constexpr(Is_even_N) { + shapeKtV = make_shape(Int{}, Int{}); + shapeK = make_shape(Int{}, Int{}); + } else { + shapeKtV = make_shape(tail_n, Int{}); + shapeK = make_shape(Int{}, tail_n); + } + + Tensor mQ = make_tensor(make_gmem_ptr(param.q_ptr + q_offset), + make_layout( + shapeQ, + make_stride(param.q_r_stride, _1{}))); + Tensor mKt = make_tensor(make_gmem_ptr(param.k_ptr + k_offset), + make_layout( + shapeKtV, + make_stride(param.k_r_stride, _1{}))); + Tensor mdO = make_tensor(make_gmem_ptr(param.do_ptr + o_offset), + make_layout( + shapeO, + make_stride(param.o_r_stride, _1{}))); + Tensor mV = make_tensor(make_gmem_ptr(param.v_ptr + v_offset), + make_layout( + shapeKtV, + make_stride(param.v_r_stride, _1{}))); + // intermediate buffer + Tensor mP = make_tensor(make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout( + shapeSP, + make_stride(block_n_dim, _1{}))); + Tensor mPt = make_tensor(make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout( + shapePt, + make_stride(_1{}, block_n_dim))); + Tensor mdOt = make_tensor(make_gmem_ptr(param.do_ptr + o_offset), + make_layout( + shapeQtOt, + make_stride(_1{}, param.o_r_stride))); + Tensor mK = make_tensor(make_gmem_ptr(param.k_ptr + k_offset), + make_layout( + shapeK, + make_stride(_1{}, param.k_r_stride))); + Tensor mdPt = make_tensor(make_gmem_ptr(param.pb_ptr + dsb_offset), + make_layout( + shapePt, + make_stride(_1{}, block_n_dim))); + Tensor mQt = make_tensor(make_gmem_ptr(param.q_ptr + q_offset), + make_layout( + shapeQtOt, + make_stride(_1{}, param.q_r_stride))); + + Tensor mLSE = make_tensor(make_gmem_ptr(param.lse_ptr + lse_offset), + make_layout( + Shape>{}, + Stride<_1>{})); + Tensor mdPsum = make_tensor(make_gmem_ptr(param.odo_ptr + lse_offset), + make_layout( + Shape>{}, + Stride<_1>{})); + + Tensor mdV = make_tensor(make_gmem_ptr(param.dv_ptr + dv_offset), + make_layout( + shapeKtV, + make_stride(param.dv_r_stride, _1{}))); + Tensor mdP = make_tensor(make_gmem_ptr(param.pb_ptr + dsb_offset), + make_layout( + shapeSP, + make_stride(block_n_dim, _1{}))); + Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dq_offset), + make_layout( + shapedQ, + make_stride(param.dq_r_stride, _1{}))); + Tensor mdK = make_tensor(make_gmem_ptr(param.dk_ptr + dk_offset), + make_layout( + shapeKtV, + make_stride(param.dk_r_stride, _1{}))); + + typename Trait::TiledMmaSdP tiled_mma_sdp; + typename Trait::TiledMmadKV tiled_mma_dkv; + typename Trait::TiledMmadQ tiled_mma_dq; + + auto thr_mma_sdp = tiled_mma_sdp.get_slice(first_thread_in_sg_idx); + + // for lse read + Tensor caccS = make_identity_tensor(Shape, Int>{}); // same buffer as accS + Tensor taccScS = thr_mma_sdp.partition_C(caccS); + static_assert(decltype(size<0>(taccScS))::value == 8); + Tensor taccScS_rc = logical_divide(taccScS, Shape<_1>{}); + Tensor taccScS_row = logical_divide(taccScS, Shape<_1>{})(make_coord(0, _), _, 0); + Tensor lse = make_tensor(Shape>{}); + // static_assert(size<0>(tSrS) * size<1>(tSrS) == size<0>(lse) && "row of acc and lse not match"); + // misc + + const int max_m_block = ceil_div(param.seq_len_q, 
kBlockM); + const int tail_m = param.seq_len_q % kBlockM; + + auto rdV = create_reg(trait, + mdV, + tiled_mma_dkv); + auto rdK = create_reg(trait, + mdK, + tiled_mma_dkv); + clear(rdV); + clear(rdK); + // clear accumulator + for (int m_block = 0; m_block < max_m_block; ++m_block) { + const bool Is_even_M = not ((m_block == max_m_block - 1) and (tail_m != 0)); + if (not Is_even_M) { + mQ = make_tensor(make_gmem_ptr(mQ.data()), + make_layout( + make_shape(tail_m, Int{}), + make_stride(param.q_r_stride, _1{}))); + mdO = make_tensor(make_gmem_ptr(mdO.data()), + make_layout( + make_shape(tail_m, Int{}), + make_stride(param.o_r_stride, _1{}))); + mdOt = make_tensor(make_gmem_ptr(mdOt.data()), + make_layout( + make_shape(Int{}, tail_m), + make_stride(_1{}, param.o_r_stride))); + mdQaccum = make_tensor(make_gmem_ptr(mdQaccum.data()), + make_layout( + shapedQ, + make_stride(param.dq_r_stride, _1{}))); + mQt = make_tensor(make_gmem_ptr(mQt.data()), + make_layout( + make_shape(Int{}, tail_m), + make_stride(_1{}, param.q_r_stride))); + } + { + auto rS = create_reg(trait, + mP, + tiled_mma_sdp); + clear(rS); + // S=QKt + gemm_SdP(trait, mQ, mKt, rS, + tiled_mma_sdp); + Tensor scores = make_tensor(rS.data(), convert_layout_acc_layout(rS.layout())); + if constexpr(is_causal) { + apply_mask_causal(scores, taccScS_rc, m_block * kBlockM, n_block * kBlockN, param.seq_len_kv - param.seq_len_q); + } + + if (Is_even_M) { + load_1colvec(lse, mLSE, taccScS_row); + } else { + load_1colvec(lse, mLSE, taccScS_row, tail_m); + } + + Tensor dP_sum = make_fragment_like(lse); + + if (Is_even_M) + load_1colvec(dP_sum, mdPsum, taccScS_row); + else + load_1colvec(dP_sum, mdPsum, taccScS_row, tail_m); + + // P=softmax(S,lse) + scale_apply_exp2(scores, lse, param.scale_softmax_log2); + mha_reorder_copy(trait, tiled_mma_sdp, rS, mP); + auto rdP = create_reg(trait, + mdP, + tiled_mma_sdp); + clear(rdP); + // dP=dO*Vt + gemm_SdP(trait, mdO, mV, rdP, + tiled_mma_sdp); + Tensor dS = make_tensor(rdP.data(), scores.layout()); + // dS=P(dP-sum_row(P))*scale + softmax_backward(scores, dP_sum, dS, param.scale_softmax); + mha_reorder_copy(trait, tiled_mma_sdp, rdP, mdP); // copy dP to internal buff + } + // dV=Pt*dO + gemm_dKV(trait, mPt, mdOt, rdV, + tiled_mma_dkv); + // dQ=dP*K + gemm_dQ(trait, mdP, mK, mdQaccum, + tiled_mma_dq); + // dK=dPt*Q + gemm_dKV(trait, mdPt, mQt, rdK, + tiled_mma_dkv); + // update ptr/atom copy + mQ.data() = mQ.data() + int(kBlockM * param.q_r_stride); + mdO.data() = mdO.data() + int(kBlockM * param.o_r_stride); + mdOt.data() = mdOt.data() + int(kBlockM * param.o_r_stride); + mdQaccum.data() = mdQaccum.data() + int(kBlockM * param.dq_r_stride); + mQt.data() = mQt.data() + int(kBlockM * param.q_r_stride); + mLSE.data() = mLSE.data() + int(kBlockM); + mdPsum.data() = mdPsum.data() + int(kBlockM); + + } + mha_reorder_copy(trait, tiled_mma_dkv, rdV, mdV); + mha_reorder_copy(trait, tiled_mma_dkv, rdK, mdK); +} + +template +void +compute_o_dot_do(T &trait, Param ¶m, + const int m_block, const int bidb, const int bidh) { + // The thread index. 
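+  // compute_o_dot_do: for one kBlockM block of one (batch, head), load the dO
+  // and O tiles, reduce dPsum(row) = sum over the head dim of dO*O across the
+  // subgroup, and zero-initialize the corresponding rows of dqaccum.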
+ constexpr int kBlockM = T::kBlockM; + constexpr int kBlockN = T::kBlockN; + constexpr int kHeadDim = T::kHeadDim; + constexpr int kNSGs = T::kNSGs; + constexpr int SubgroupSize = T::SubgroupSize; + using DType = typename T::DType; + using VType = typename T::VType; + + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto group = compat::get_nd_item<1>().get_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + auto bofst = Boffset(param); + + const index_t o_offset = bofst.o_offset(bidb, bidh, m_block * kBlockM); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); + const index_t dpsum_offset = bofst.lse_offset(bidb, bidh, m_block * kBlockM); + + using ShapeO = Shape< + std::conditional_t , int>, + Int>; + using ShapeP = Shape< + std::conditional_t , int>>; + ShapeO O_shape; + ShapeP dP_shape; + if constexpr(Is_even_M) { + O_shape = make_shape(Int{}, Int{}); + dP_shape = make_shape(Int{}); + } else { + O_shape = make_shape(param.tail_m, Int{}); + dP_shape = make_shape(param.tail_m); + } + auto dQ_shape = make_shape(Int{}, Int{}); + + Tensor mdO = make_tensor(make_gmem_ptr(param.do_ptr + o_offset), + make_layout( + O_shape, + make_stride(param.o_r_stride, _1{}))); + Tensor mO = make_tensor(make_gmem_ptr(param.o_ptr + o_offset), + make_layout( + O_shape, + make_stride(param.o_r_stride, _1{}))); + Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dq_offset), + make_layout( + make_shape(Int{}, Int{}), + make_stride(param.dq_r_stride, _1{}))); + Tensor mdPsum = make_tensor(make_gmem_ptr(param.odo_ptr + dpsum_offset), + make_layout( + dP_shape, + Stride<_1>{})); + + auto tileload_odo = make_tiled_copy(Copy_Atom, DType>{}, + Layout, Int>, + Stride, _1>>{}, + Layout>{}); + auto tileload_dq = make_tiled_copy(Copy_Atom, VType>{}, + Layout, Int>>{}, + Layout>{}); + auto thr_load_odo = tileload_odo.get_thread_slice(ThreadIdxX()); + auto thr_load_dq = tileload_dq.get_thread_slice(ThreadIdxX()); + + Tensor thr_tile_do_S = thr_load_odo.partition_S(mdO); + Tensor thr_tile_o_S = thr_load_odo.partition_S(mO); + Tensor thr_tile_dq_D = thr_load_dq.partition_D(mdQaccum); + Tensor rdQ = make_fragment_like(thr_tile_dq_D); + Tensor rdO = make_fragment_like(rdQ); + Tensor rO = make_fragment_like(rdQ); + clear(rdQ); + copy(tileload_dq, rdQ, thr_tile_dq_D); + + Tensor cO = make_identity_tensor(dQ_shape); + Tensor tcO = thr_load_odo.partition_S(cO); + Tensor tcO_row = logical_divide(tcO, Shape<_1>{})(make_coord(0, 0), _, 0); + Tensor rdO_2d = make_tensor(rdO.data(), + convert_layout_2d_layout(rdO.layout())); + Tensor rO_2d = make_tensor(rO.data(), + convert_layout_2d_layout(rO.layout())); + if constexpr(Is_even_M) { + copy(tileload_odo, thr_tile_do_S, rdO); + copy(tileload_odo, thr_tile_o_S, rO); + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(rdO_2d); ++mi) { + float accum = 0.0f; + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(rdO_2d); ++ni) { + accum = accum + (float)rdO_2d(mi, ni) * (float)rO_2d(mi, ni); + } + accum = sycl::reduce_over_group(sg, accum, sycl::plus<>()); + if (sg.get_local_id() == 0) { + mdPsum(get<0>(tcO_row(mi))) = accum; + } + } + } else { + for (int mi = 0; mi < size<0>(rdO_2d); ++mi) { + if (get<0>(tcO_row(mi)) < param.tail_m) { + copy(tileload_odo, thr_tile_do_S(_, mi, _), rdO(_, mi, _)); + copy(tileload_odo, thr_tile_o_S(_, mi, _), rO(_, mi, _)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(rdO_2d); ++mi) { + float accum = 0.0f; + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < 
size<1>(rdO_2d); ++ni) { + accum = accum + (float)rdO_2d(mi, ni) * (float)rO_2d(mi, ni); + } + accum = sycl::reduce_over_group(sg, accum, sycl::plus<>()); + if (sg.get_local_id() == 0 and get<0>(tcO_row(mi)) < param.tail_m) + mdPsum(get<0>(tcO_row(mi))) = accum; + } + } +} + +template +void +mha_backward_seq(T trait, + Param param) { + const int bidb = BlockIdxZ(); + const int bidhq = BlockIdxY(); + const int bidnblk = BlockIdxX(); + const int bidhkv = bidhq / param.num_qh_per_kvh; + for (int n_block = bidnblk; n_block < param.n_block; n_block += GridDimX()) { + if (param.tail_n > 0 and n_block == param.n_block - 1) + dq_dk_dv_1colblock(trait, param, bidb, bidhq, bidhkv, param.n_block - 1, param.tail_n); + else + dq_dk_dv_1colblock(trait, param, bidb, bidhq, bidhkv, n_block); + } +} + +template +void +mha_backward_parallel(T trait, + Param param) { + const int bidb = BlockIdxZ(); + const int bidhq = BlockIdxY(); + const int n_block = BlockIdxX(); + const int bidhkv = bidhq / param.num_qh_per_kvh; + if (param.tail_n > 0 and n_block == param.n_block - 1) + dq_dk_dv_1colblock(trait, param, bidb, bidhq, bidhkv, param.n_block - 1, param.tail_n); + else + dq_dk_dv_1colblock(trait, param, bidb, bidhq, bidhkv, n_block); +} + +template +void +mha_dot_do_o(T trait, + Param param) { + // The block index for the M dimension. + const int m_block = BlockIdxX(); + // The block index for the batch. + const int bidb = BlockIdxZ(); + // The block index for the head. + const int bidh = BlockIdxY();; + if (m_block == param.m_block - 1 and param.tail_m > 0) { + compute_o_dot_do(trait, param, m_block, bidb, bidh); + } else { + compute_o_dot_do(trait, param, m_block, bidb, bidh); + } +} + +template +void +convert_dq(T &trait, Param ¶m, int m_block, int bidb, int bidh) { + constexpr int kBlockM = T::kBlockM; + constexpr int kBlockN = T::kBlockN; + constexpr int kHeadDim = T::kHeadDim; + constexpr int kNSGs = T::kNSGs; + using DType = typename T::DType; + using VType = typename T::VType; + + auto bofst = Boffset(param); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); + const index_t q_offset = bofst.q_offset(bidb, bidh, m_block * kBlockM); + VType * dQaccum = param.dqaccum_ptr + dq_offset; + DType * dQ = param.dq_ptr + q_offset; + + int tail_m = param.seq_len_q - m_block * kBlockM; + int m = ThreadIdxX(); + if (m < tail_m) { + for (int h = 0; h < kHeadDim; ++h) { + dQ[m * param.q_r_stride + h] = static_cast(dQaccum[m * param.dq_r_stride + h]); + } + } +} + +template +void +convert_dq(T &trait, Param ¶m, int m_block, int bidb, int bidh) { + constexpr int kBlockM = T::kBlockM; + constexpr int kBlockN = T::kBlockN; + constexpr int kHeadDim = T::kHeadDim; + using DType = typename T::DType; + using VType = typename T::VType; + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + + auto bofst = Boffset(param); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); + const index_t q_offset = bofst.q_offset(bidb, bidh, m_block * kBlockM); + using ShapeQ = Shape< + std::conditional_t, int>, + Int>; + ShapeQ shapeQ; + if constexpr (Is_even_M) { + shapeQ = make_shape(Int{}, Int{}); + } else { + shapeQ = make_shape(param.tail_m, Int{}); + } + + Tensor mdQaccum = make_tensor(make_gmem_ptr(param.dqaccum_ptr + dq_offset), + make_layout( + Shape, Int>{}, + make_stride(param.dq_r_stride, _1{}))); + Tensor mdQ = make_tensor(make_gmem_ptr(param.dq_ptr + q_offset), + make_layout( + shapeQ, + 
make_stride(param.q_r_stride, _1{}))); + + typename T::TiledMmadQ tiled_mma_dq; + auto thr_mma_dq = tiled_mma_dq.get_slice(first_thread_in_sg_idx); + + auto tile_dq = tiled_mma_dq.tile_mnk(); + + auto tileloaddQ = make_block_2d_copy_C(tiled_mma_dq, mdQaccum); + auto tilesavedQ = make_block_2d_copy_D(tiled_mma_dq, mdQ); + + auto thr_load_dQ = tileloaddQ.get_slice(first_thread_in_sg_idx); + auto thr_save_dQ = tilesavedQ.get_slice(first_thread_in_sg_idx); + + Tensor gdQaccum = local_tile(make_identity_tensor(mdQaccum.shape()), + select<0, 1>(tile_dq), make_coord(0,0)); // read dQaccum + Tensor gdQ = local_tile(make_identity_tensor(mdQ.shape()), + select<0, 1>(tile_dq), make_coord(0,0)); // dump dQ + Tensor tdQgdQaccum = thr_load_dQ.partition_S(gdQaccum); // load from dqaccum + auto tdQrdQaccum = thr_load_dQ.partition_sg_fragment_D(gdQaccum); // register for dqaccum + auto tdQrdQ = thr_save_dQ.partition_sg_fragment_S(gdQ); // register for dq + Tensor tdQgdQ = thr_save_dQ.partition_D(gdQ); // save to dq + + copy(tileloaddQ, tdQgdQaccum, tdQrdQaccum); + reorder(tdQrdQaccum, tdQrdQ); + copy(tilesavedQ, tdQrdQ, tdQgdQ); +} + +template +void +mhd_convert_dq(T trait, + Param param) { + // The block index for the M dimension. + const int m_block = BlockIdxX(); + // The block index for the batch. + const int bidb = BlockIdxZ(); + // The block index for the head. + const int bidh = BlockIdxY(); + if (param.tail_m > 0 and m_block == param.m_block - 1) { + convert_dq(trait, param, m_block, bidb, bidh); + } else { + convert_dq(trait, param, m_block, bidb, bidh); + } +} + +/* +template +void +mha_backward_seq(T trait, + Param param) { + const int bidb = BlockIdxZ(); + const int bidh = BlockIdxY(); + CUTLASS_PRAGMA_UNROLL + for (int m_block = 0; m_block < param.m_block - 1; ++m_block) + compute_o_dot_do(trait, param, m_block, bidb, bidh); + if (param.tail_m > 0) { + compute_o_dot_do(trait, param, param.m_block - 1, bidb, bidh); + } else { + compute_o_dot_do(trait, param, param.m_block - 1, bidb, bidh); + } + CUTLASS_PRAGMA_UNROLL + for (int n_block = 0; n_block < param.n_block; ++n_block) + dq_dk_dv_1colblock(trait, param, bidb, bidh, n_block); + if (param.tail_n > 0) + dq_dk_dv_1colblock(trait, param, bidb, bidh, param.n_block, param.tail_n); + CUTLASS_PRAGMA_UNROLL + for (int m_block = 0; m_block < param.m_block - 1; ++m_block) + convert_dq(trait, param, m_block, bidb, bidh); + if (param.tail_m > 0) { + convert_dq(trait, param, param.m_block - 1, bidb, bidh); + } else { + convert_dq(trait, param, param.m_block - 1, bidb, bidh); + } +} +*/ + +template class mhaodoDeviceName; +template class mhabwdDeviceName; +template class mhacvtDeviceName; + +template +void launch_mha_backward_headdim(ProblemShape problem_shape, + const T *do_d, + const T *o_d, + const T *q_d, + const T *k_d, + const T *v_d, + const float *lse_d, + float *odo_d, + float *dqaccum_d, + T *dq_d, + T *dk_d, + T *dv_d, + const int seq_len_q_pad, + const int seq_len_kv_pad) { + auto trait = FAKernel{}; + + const int BATCH = get<0>(problem_shape); + const int NUM_HEAD_Q = get<1>(problem_shape); + const int NUM_HEAD_KV = get<2>(problem_shape); + const int SEQ_LEN_Q = get<3>(problem_shape); + const int SEQ_LEN_KV = get<4>(problem_shape); + const int N_BLOCK = ceil_div(SEQ_LEN_KV, kBlockN); + const int tail_n = SEQ_LEN_KV % kBlockN; + const int M_BLOCK = ceil_div(SEQ_LEN_Q, kBlockM); + const int tail_m = SEQ_LEN_Q % kBlockM; + T * pbuff = compat::malloc(BATCH * NUM_HEAD_Q * seq_len_kv_pad * 2 * kBlockM); + auto param = Param(do_d, o_d, q_d, k_d, 
v_d, lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, pbuff, + 1 / sqrt(static_cast(kHeadDim))); + param.batch = BATCH; + param.num_head_q = NUM_HEAD_Q; + param.num_head_kv = NUM_HEAD_KV; + param.num_qh_per_kvh = NUM_HEAD_Q / NUM_HEAD_KV; + param.num_nb_per_blk = std::max(N_BLOCK * NUM_HEAD_Q * BATCH / 1024, 1); // 1024 tuneable here, best for pvc + param.seq_len_q = SEQ_LEN_Q; + param.seq_len_kv = SEQ_LEN_KV; + param.head_dim = kHeadDim; + param.n_block = N_BLOCK; + param.tail_n = tail_n; + param.m_block = M_BLOCK; + param.tail_m = tail_m; + param.seq_len_kv_pad = seq_len_kv_pad; + param.seq_len_q_pad = seq_len_q_pad; + if constexpr(is_bhsd) { + setup_bhsd_stride(param); + } else { + setup_bshd_stride(param); + } + + auto dimGrid0 = compat::dim3(size(M_BLOCK), size(param.num_head_q), size(param.batch)); + auto dimBlock0 = compat::dim3(size(kNSGs * trait.SubgroupSize), size(1), size(1)); + compat::experimental::launch_properties launch_props0{ + // sycl::ext::oneapi::experimental::work_group_scratch_size(0), + }; + compat::experimental::kernel_properties kernel_props0{ + sycl::ext::oneapi::experimental::sub_group_size}; + compat::experimental::launch_policy policy0{dimGrid0, dimBlock0, launch_props0, kernel_props0}; + auto event0 = compat::experimental::launch< + mha_dot_do_o, + mhaodoDeviceName>(policy0, + trait, + param); + EventManager::getInstance().addEvent(event0); + compat::wait_and_throw(); + + auto dimGrid1 = compat::dim3(size(ceil_div(param.n_block, param.num_nb_per_blk)), + size(param.num_head_q), size(param.batch)); + assert((param.num_head_q % param.num_head_kv == 0) && "num_head_q must be dividable by num_head_kv"); + assert((param.num_head_q >= param.num_head_kv) && "num_head_q must be bigger than or equal to num_head_kv"); + auto dimBlock1 = compat::dim3(size(kNSGs * trait.SubgroupSize), size(1), size(1)); + // auto dimBlock = compat::dim3(size(trait.tiled_mma_sdp)); + + compat::experimental::launch_properties launch_props1{ + sycl::ext::oneapi::experimental::work_group_scratch_size(trait.smem_size), + }; + compat::experimental::kernel_properties kernel_props1{ + sycl::ext::oneapi::experimental::sub_group_size}; + compat::experimental::launch_policy policy1{dimGrid1, dimBlock1, launch_props1, kernel_props1}; + auto event1 = compat::experimental::launch< + mha_backward_seq, + mhabwdDeviceName>(policy1, + trait, + param); + EventManager::getInstance().addEvent(event1); + compat::wait_and_throw(); + + auto dimGrid2 = compat::dim3(size(M_BLOCK), size(param.num_head_q), size(param.batch)); + auto dimBlock2 = compat::dim3(size(kNSGs * trait.SubgroupSize), size(1), size(1)); + compat::experimental::launch_properties launch_props2{ + // sycl::ext::oneapi::experimental::work_group_scratch_size(0), + }; + compat::experimental::kernel_properties kernel_props2{ + sycl::ext::oneapi::experimental::sub_group_size}; + compat::experimental::launch_policy policy2{dimGrid2, dimBlock2, launch_props2, kernel_props2}; + auto event2 = compat::experimental::launch< + mhd_convert_dq, + mhacvtDeviceName>(policy2, + trait, + param); + EventManager::getInstance().addEvent(event2); + compat::wait_and_throw(); +} + +template +void launch_mha_backward(ProblemShape problem_shape, + const T *do_d, + const T *o_d, + const T *q_d, + const T *k_d, + const T *v_d, + const float *lse_d, + float *odo_d, + float *dqaccum_d, + T *dq_d, + T *dk_d, + T *dv_d, + const int seq_len_q_pad, + const int seq_len_kv_pad) { + const int headdim = get<5>(problem_shape); + if (headdim == 64) { + constexpr int kBlockM = 64; + 
constexpr int kBlockN = 32; + constexpr int kHeadDim = 64; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 4; + constexpr int AtomLayoutNdKV = 2; + constexpr int AtomLayoutMdQ = 2; + static_assert(kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + static_assert(kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + seq_len_q_pad, seq_len_kv_pad); + } else if (headdim == 96) { + constexpr int kBlockM = 64; + constexpr int kBlockN = 64; + constexpr int kHeadDim = 96; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 2; + constexpr int AtomLayoutNdKV = 4; + constexpr int AtomLayoutMdQ = 4; + static_assert(kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + static_assert(kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + seq_len_q_pad, seq_len_kv_pad); + } else if (headdim == 128) { + constexpr int kBlockM = 64; + constexpr int kBlockN = 64; + constexpr int kHeadDim = 128; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 4; + constexpr int AtomLayoutNdKV = 4; + constexpr int AtomLayoutMdQ = 4; + static_assert(kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + static_assert(kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + seq_len_q_pad, seq_len_kv_pad); + } else if (headdim == 192) { + constexpr int kBlockM = 64; + constexpr int kBlockN = 32; + constexpr int kHeadDim = 192; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 4; + constexpr int AtomLayoutNdKV = 2; + constexpr int AtomLayoutMdQ = 2; + static_assert(kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + static_assert(kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + seq_len_q_pad, seq_len_kv_pad); + } else if (headdim == 256) { + constexpr int kBlockM = 64; + constexpr int kBlockN = 32; + constexpr int kHeadDim = 256; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 4; + constexpr int AtomLayoutNdKV = 2; + constexpr int AtomLayoutMdQ = 2; + static_assert(kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + static_assert(kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + launch_mha_backward_headdim( + problem_shape, + do_d, o_d, q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + seq_len_q_pad, seq_len_kv_pad); + } else { + assert(false && "only support headdim 64,96,128,192,256"); + } +} + +int main(int argc, char**argv) { + // using T = cute::bfloat16_t; + using T = cute::half_t; + using V = float; + std::string data_file = "mha.npz"; + // read qkv + cnpy::NpyArray q_npy = cnpy::npz_load(data_file, "q"); + cnpy::NpyArray k_npy = cnpy::npz_load(data_file, "k"); + cnpy::NpyArray v_npy = cnpy::npz_load(data_file, "v"); + + // read grad output + cnpy::NpyArray do_npy = cnpy::npz_load(data_file, "grad"); + cnpy::NpyArray o_npy = cnpy::npz_load(data_file, "out"); + + // read lse + cnpy::NpyArray lse_npy = cnpy::npz_load(data_file, "lse"); + // read odo + cnpy::NpyArray odo_npy = cnpy::npz_load(data_file, "odo"); + + // read grad 
reference + cnpy::NpyArray dq_npy = cnpy::npz_load(data_file, "q_grad"); + cnpy::NpyArray dk_npy = cnpy::npz_load(data_file, "k_grad"); + cnpy::NpyArray dv_npy = cnpy::npz_load(data_file, "v_grad"); + + // read shape + cnpy::NpyArray shape = cnpy::npz_load(data_file, "shape"); + + int64_t BATCH = shape.data()[0]; + int64_t NUM_HEAD_Q = shape.data()[1]; + int64_t NUM_HEAD_KV = shape.data()[2]; + int64_t SEQ_LEN_QO = shape.data()[3]; + int64_t SEQ_LEN_KV = shape.data()[4]; + int64_t HEAD_SIZE_QK = shape.data()[5]; + int64_t HEAD_SIZE_VO = shape.data()[6]; + bool is_causal = shape.data()[7]; + bool is_bhsd = shape.data()[8]; + assert(HEAD_SIZE_QK == HEAD_SIZE_VO && "only support head_size_qk==head_size_vo"); + constexpr int kBlockN = 128; + constexpr int kBlockM = 128; + int64_t SEQ_LEN_QO_PAD = ceil_div(SEQ_LEN_QO, kBlockM) * kBlockM; + int64_t SEQ_LEN_KV_PAD = ceil_div(SEQ_LEN_KV, kBlockN) * kBlockN; + printf("batch %d nh_q %d nh_k %d sq_q %d(%d) sq_k %d(%d) hd_q %d hd_v %d causal %d bhsd %d\n", BATCH, NUM_HEAD_Q, NUM_HEAD_KV, SEQ_LEN_QO, SEQ_LEN_QO_PAD, SEQ_LEN_KV, SEQ_LEN_KV_PAD, HEAD_SIZE_QK, HEAD_SIZE_VO, is_causal, is_bhsd); + // read_args(argc, argv, 1, &BATCH); + // read_args(argc, argv, 2, &NUM_HEAD_Q); + // read_args(argc, argv, 3, &NUM_HEAD_KV); + // read_args(argc, argv, 4, &SEQ_LEN_QO); + // read_args(argc, argv, 5, &SEQ_LEN_KV); + // read_args(argc, argv, 6, &HEAD_SIZE_QK); + // read_args(argc, argv, 7, &HEAD_SIZE_VO); + + // alloc qkv + T *q_d = compat::malloc(q_npy.num_vals); + T *k_d = compat::malloc(k_npy.num_vals); + T *v_d = compat::malloc(v_npy.num_vals); + + // alloc lse, odo + V *lse_d = compat::malloc(lse_npy.num_vals); + V *odo_d = compat::malloc(odo_npy.num_vals); + + // alloc grad output + T *do_d = compat::malloc(do_npy.num_vals); + T *o_d = compat::malloc(o_npy.num_vals); + + // alloc grad test on device + T *dq_d = compat::malloc(dq_npy.num_vals); + V *dqaccum_d = compat::malloc(BATCH * NUM_HEAD_Q * SEQ_LEN_QO_PAD * HEAD_SIZE_QK); + T *dk_d = compat::malloc(dk_npy.num_vals); + T *dv_d = compat::malloc(dv_npy.num_vals); + // copy qkv + compat::memcpy(q_d, q_npy.data(), q_npy.num_vals); + compat::memcpy(k_d, k_npy.data(), k_npy.num_vals); + compat::memcpy(v_d, v_npy.data(), v_npy.num_vals); + + // copy grad output + compat::memcpy(do_d, do_npy.data(), do_npy.num_vals); + compat::memcpy(o_d, o_npy.data(), o_npy.num_vals); + + // copy lse + compat::memcpy(lse_d, lse_npy.data(), lse_npy.num_vals); + + // copy odo + // compat::memcpy(odo_d, odo_npy.data(), odo_npy.num_vals); + + auto problem_shape = ProblemShapeRegular(BATCH, NUM_HEAD_Q, NUM_HEAD_KV, + SEQ_LEN_QO, SEQ_LEN_KV, HEAD_SIZE_QK, HEAD_SIZE_VO); + if (is_bhsd) { + if (is_causal) + launch_mha_backward( + problem_shape, + do_d, o_d, + q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + SEQ_LEN_QO_PAD, SEQ_LEN_KV_PAD); + else + launch_mha_backward( + problem_shape, + do_d, o_d, + q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + SEQ_LEN_QO_PAD, SEQ_LEN_KV_PAD); + } else { + if (is_causal) { + launch_mha_backward( + problem_shape, + do_d, o_d, + q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + SEQ_LEN_QO_PAD, SEQ_LEN_KV_PAD); + } else { + launch_mha_backward( + problem_shape, + do_d, o_d, + q_d, k_d, v_d, + lse_d, odo_d, + dqaccum_d, dq_d, dk_d, dv_d, + SEQ_LEN_QO_PAD, SEQ_LEN_KV_PAD); + } + } + float atol = 5e-3f; + float rtol = 5e-3f; + + std::vector odo_test(odo_npy.num_vals); + compat::memcpy(odo_test.data(), odo_d, odo_test.size()); + compat::wait_and_throw(); + 
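For reference, the checks that follow use the helpers from `sdpa_util.hpp`; their pass criterion amounts to the following NumPy sketch (illustrative only, not part of the patch):

```python
import numpy as np

def passes(ref, test, atol=5e-3, rtol=5e-3):
    """Illustrative sketch of the verify() criterion in sdpa_util.hpp."""
    ref = np.asarray(ref, dtype=np.float32).ravel()
    test = np.asarray(test, dtype=np.float32).ravel()
    # Element-wise closeness; rtol is scaled by |test| as in isclose().
    close = np.abs(ref - test) <= atol + rtol * np.abs(test)
    ratio_ok = close.mean() > 0.99          # tolerate up to 1% outliers
    # Global cosine similarity of the flattened tensors must also exceed 0.99.
    cosine = ref.dot(test) / np.sqrt(ref.dot(ref) * test.dot(test))
    return ratio_ok and cosine > 0.99
```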
printf("odo val: "); + verify(odo_npy.data(), odo_test.data(), BATCH, NUM_HEAD_Q, SEQ_LEN_QO, atol, rtol); + + compat::wait_and_throw(); + std::vector dv_test(BATCH * NUM_HEAD_Q * SEQ_LEN_KV * HEAD_SIZE_VO); + compat::memcpy(dv_test.data(), dv_d, dv_test.size()); + compat::wait_and_throw(); + printf("dV val: "); + verify(dv_npy.data(), dv_test.data(), BATCH * NUM_HEAD_Q, SEQ_LEN_KV, HEAD_SIZE_VO, atol, rtol); + + std::vector dk_test(BATCH * NUM_HEAD_Q * SEQ_LEN_KV * HEAD_SIZE_QK); + compat::memcpy(dk_test.data(), dk_d, dk_test.size()); + compat::wait_and_throw(); + printf("dK val: "); + verify(dk_npy.data(), dk_test.data(), BATCH * NUM_HEAD_Q, SEQ_LEN_KV, HEAD_SIZE_QK, atol, rtol); + + std::vector dq_test(BATCH * NUM_HEAD_Q * SEQ_LEN_QO * HEAD_SIZE_QK); + compat::memcpy(dq_test.data(), dq_d, dq_test.size()); + compat::wait_and_throw(); + printf("dQ val: "); + verify(dq_npy.data(), dq_test.data(), BATCH * NUM_HEAD_Q, SEQ_LEN_QO, HEAD_SIZE_QK, atol, rtol); + +} diff --git a/examples/sdpa_bwd/sdpa_util.hpp b/examples/sdpa_bwd/sdpa_util.hpp new file mode 100644 index 0000000000..eeb18b0b68 --- /dev/null +++ b/examples/sdpa_bwd/sdpa_util.hpp @@ -0,0 +1,247 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +bool isclose(float a, float b, float atol, float rtol) { + return std::abs(a - b) <= atol + rtol * std::abs(b); +} + +template +float +cosinesimilarity(T *refe, V *test, size_t m) { + float ab = 0.0f; + float a2 = 0.0f; + float b2 = 0.0f; + for (size_t i = 0; i < m; ++i) { + float t_f = (float)test[i]; + float r_f = (float)refe[i]; + ab += t_f * r_f; + a2 += t_f * t_f; + b2 += r_f * r_f; + } + float factor = ab / sqrtf(a2 * b2); + // printf("f=%f\n", factor); + return factor; +} + +template +float +cosinesimilarity(T *refe, V *test, size_t L, size_t M, size_t M_PAD, size_t N, size_t N_PAD) { + float ab = 0.0f; + float a2 = 0.0f; + float b2 = 0.0f; + for (size_t l = 0; l < L; ++l) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t i = l * M * N + m * N + n; + size_t j = l * M_PAD * N_PAD + m * N_PAD + n; + float r_f = (float)refe[i]; + float t_f = (float)test[j]; + ab += t_f * r_f; + a2 += t_f * t_f; + b2 += r_f * r_f; + } + } + } + float factor = ab / sqrtf(a2 * b2); + // printf("f=%f\n", factor); + return factor; +} + +template +float +cosinesimilarity(T *refe, V *test, size_t B, size_t H, size_t S, size_t S_PAD, size_t D) { + float ab = 0.0f; + float a2 = 0.0f; + float b2 = 0.0f; + if (is_bhsd) { + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + for (int d = 0; d < D; ++d) { + int i = b * H * S * D + h * S * D + s * D + d; + int j = b * H * S_PAD * D + h * S_PAD * D + s * D + d; + float r_f = (float)refe[i]; + float t_f = (float)test[j]; + ab += t_f * r_f; + a2 += t_f * t_f; + b2 += r_f * r_f; + } + } + } + } + } else { + for (int b = 0; b < B; ++b) { + for (int s = 0; s < S; ++s) { + for (int h = 0; h < H; ++h) { + for (int d = 0; d < D; ++d) { + int i = b * S * H * D + s * H * D + h * D + d; + int j = b * S_PAD * H * D + s * H * D + h * D + d; + float r_f = (float)refe[i]; + float t_f = (float)test[j]; + ab += t_f * r_f; + a2 += t_f * t_f; + b2 += r_f * r_f; + } + } + } + } + } + float factor = ab / sqrtf(a2 * b2); + // printf("f=%f\n", factor); + return factor; +} + +template +bool allclose(T *refe, V *test, int L, int M, int N, float atol, float rtol) { + size_t err = 0; + size_t count = L * M * N; + bool flag = true; + for (int l = 0; l < L; ++l) { + for (int m = 0; m < M; ++m) { 
+ for (int n = 0; n < N; ++n) { + float expect = (float)refe[l * M * N + m * N + n]; + float value = (float)test[l * M * N + m * N + n]; + // printf("(%d, %d, %d) expect: %f value: %f ratio %f\n", l, m, n, expect, value, value / expect); + if (not isclose(expect, value, atol, rtol)) { + printf("(%d, %d, %d) expect: %f value: %f ratio %f\n", l, m, n, expect, value, value / expect); + err++; + } + if (isnan(value) or isinf(value)) { + printf("\x1B[31m %f detected \x1B[0m at (%d, %d, %d)\n", value, l, m, n); + exit(1); + } + } + } + } + float ratio = static_cast(count - err) / static_cast(count); + // printf("c=%f (%ld)\n", ratio, err); + // printf("CHECK SUM SUCCESS\n"); + return ratio > 0.99f; +} + +template +bool allclose(T *refe, V *test, int L, int M, int M_PAD, int N, int N_PAD, float atol, float rtol) { + size_t err = 0; + size_t count = L * M * N; + bool flag = true; + for (int l = 0; l < L; ++l) { + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + int i = l * M * N + m * N + n; + int j = l * M_PAD * N_PAD + m * N_PAD + n; + float expect = (float)refe[i]; + float value = (float)test[j]; + // printf("(%d, %d, %d) expect: %f value: %f ratio %f\n", l, m, n, expect, value, value / expect); + if (not isclose(expect, value, atol, rtol)) { + printf("(%d, %d, %d) expect: %f value: %f ratio %f\n", l, m, n, expect, value, value / expect); + err++; + } + if (isnan(value) or isinf(value)) { + printf("\x1B[31m %f detected \x1B[0m at (%d, %d, %d)\n", value, l, m, n); + exit(1); + } + } + } + } + float ratio = static_cast(count - err) / static_cast(count); + // printf("c=%f (%ld)\n", ratio, err); + // printf("CHECK SUM SUCCESS\n"); + return ratio > 0.99f; +} + +template +bool allclose(T *refe, V *test, int B, int H, int S, int S_PAD, int D, float atol, float rtol) { + size_t err = 0; + size_t count = B * S * H * D; + bool flag = true; + if (is_bhsd) { + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H; ++h) { + for (int s = 0; s < S; ++s) { + for (int d = 0; d < D; ++d) { + int i = b * H * S * D + h * S * D + s * D + d; + int j = b * H * S_PAD * D + h * S_PAD * D + s * D + d; + float expect = (float)refe[i]; + float value = (float)test[j]; + if (not isclose(expect, value, atol, rtol)) { + printf("(%d, %d, %d, %d) expect: %f value: %f ratio %f\n", b, h, s, d, expect, value, value / expect); + err++; + } + if (isnan(value) or isinf(value)) { + printf("\x1B[31m %f detected\x1B[0m at (%d, %d, %d, %d)\n", value, b, h, s, d); + exit(1); + } + } + } + } + } + } else { + for (int b = 0; b < B; ++b) { + for (int s = 0; s < S; ++s) { + for (int h = 0; h < H; ++h) { + for (int d = 0; d < D; ++d) { + int i = b * S * H * D + s * H * D + h * D + d; + int j = b * S_PAD * H * D + s * H * D + h * D + d; + float expect = (float)refe[i]; + float value = (float)test[j]; + if (not isclose(expect, value, atol, rtol)) { + printf("(%d, %d, %d, %d) expect: %f value: %f ratio %f\n", b, s, h, d, expect, value, value / expect); + err++; + } + if (isnan(value) or isinf(value)) { + printf("\x1B[31m %f detected\x1B[0m at (%d, %d, %d, %d)\n", value, b, s, h, d); + exit(1); + } + } + } + } + } + } + float ratio = static_cast(count - err) / static_cast(count); + // printf("c=%f (%ld)\n", ratio, err); + // printf("CHECK SUM SUCCESS\n"); + return ratio > 0.99f; +} + +static constexpr char strSUCCESS[] = "\x1B[32mPASS\x1B[0m"; +static constexpr char strFAILURE[] = "\x1B[31mFAIL\x1B[0m"; +template +void verify(T *refe, V *test, int l, int m, int n, float atol, float rtol) { + bool close = allclose(refe, test, l, m, 
n, atol, rtol); + bool cosine = cosinesimilarity(refe, test, l * m * n) > 0.99f; + printf("%s allclose %s cosinesim %s\n", (close and cosine) ? strSUCCESS : strFAILURE, close ? strSUCCESS : strFAILURE, cosine ? strSUCCESS : strFAILURE); +} + +template +void verify(T *refe, V *test, int l, int m, int m_pad, int n, int n_pad, float atol, float rtol) { + bool close = allclose(refe, test, l, m, m_pad, n, n_pad, atol, rtol); + bool cosine = cosinesimilarity(refe, test, l, m, m_pad, n, n_pad) > 0.99f; + printf("%s allclose %s cosinesim %s\n", (close and cosine) ? strSUCCESS : strFAILURE, close ? strSUCCESS : strFAILURE, cosine ? strSUCCESS : strFAILURE); +} + +template +void verify(T *refe, V *test, int b, int h, int s, int s_pad, int d, float atol, float rtol) { + bool close = allclose(refe, test, b, h, s, s_pad, d, atol, rtol); + bool cosine = cosinesimilarity(refe, test, b, h, s, s_pad, d) > 0.99f; + printf("%s allclose %s cosinesim %s\n", (close and cosine) ? strSUCCESS : strFAILURE, close ? strSUCCESS : strFAILURE, cosine ? strSUCCESS : strFAILURE); +} + +template +void read_file(T *ptr, std::string filename, size_t rsize) { + std::ifstream file(filename, std::ios::in | std::ios::binary | std::ios::ate); + if (file.is_open()) { + size_t fsize = file.tellg(); + assert(fsize == rsize); + size_t len = fsize / sizeof(T); + file.seekg(0, std::ios::beg); + file.read((char *)ptr, len * sizeof(T)); + file.close(); + } else { + std::cout << "fail to open " << filename << std::endl; + } +} diff --git a/examples/sdpa_bwd/test_sdpa_gqa.py b/examples/sdpa_bwd/test_sdpa_gqa.py new file mode 100644 index 0000000000..575b1d4f52 --- /dev/null +++ b/examples/sdpa_bwd/test_sdpa_gqa.py @@ -0,0 +1,356 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +import numpy as np +def is_close(refe: torch.Tensor, + test: torch.Tensor): + test = test.to(torch.float32) + refe = refe.to(torch.float32) + cosfactor = F.cosine_similarity(test.reshape(-1), refe.reshape(-1), dim=0) > 0.99 + allclose = torch.allclose(test, refe, atol=3e-3, rtol=3e-3) + return cosfactor and allclose + +def num_head_bcast(query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor): + q_num_heads = query.size(-3) + k_num_heads = key.size(-3) + v_num_heads = value.size(-3) + k_dim0 = key.size(0) + k_dim1 = key.size(1) + k_dim2 = key.size(2) + k_dim3 = key.size(3) + v_dim0 = value.size(0) + v_dim1 = value.size(1) + v_dim2 = value.size(2) + v_dim3 = value.size(3) + if (q_num_heads == k_num_heads) and (q_num_heads == v_num_heads): + return key, value + k_repeat = q_num_heads // k_num_heads + v_repeat = q_num_heads // v_num_heads + key = key.repeat_interleave(k_repeat, 1).reshape(k_dim0, k_repeat * k_dim1, k_dim2, k_dim3) + value = value.repeat_interleave(v_repeat, 1).reshape(v_dim0, v_repeat * v_dim1, v_dim2, v_dim3) + return key, value + +def num_head_reduce(expand_grad: torch.Tensor, + x: torch.Tensor): + num_heads_expand = expand_grad.size(-3) + num_heads_orig = x.size(-3) + if (num_heads_expand == num_heads_orig): + return expand_grad + n_repeat = num_heads_expand // num_heads_orig + assert len(x.shape) == 4 + batch, num_head, seq_len, head_size = x.size() + expand_grad = expand_grad.reshape(batch, num_head, n_repeat, seq_len, head_size) + grad = torch.sum(expand_grad, dim=2).reshape(batch, num_head, seq_len, head_size) + return grad + +GRAD_DICT = {} + +def dump_grad(name, value): + global GRAD_DICT + if name not in GRAD_DICT: + GRAD_DICT[name] = value.clone() + 
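+        # NOTE: dump_grad is installed through register_hook on the intermediate tensors
+        # (s, p and the attention output) in SDPA.forward, so GRAD_DICT captures each upstream
+        # gradient exactly once and the test can compare it against backward_ref().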
else: + print(f'duplicated grad {name}') + return + +def softmax_backward(y: torch.Tensor, + grad_y: torch.Tensor, + scale: float): + orig_dtype = y.dtype + rest_dim = y.shape[:-1] + dim = y.shape[-1] + y = y.to(torch.float32) + grad_y = grad_y.to(torch.float32) + ydy = grad_y * y + sum_row = torch.sum(ydy, dim= -1).reshape(*rest_dim, 1) + grad_x2 = ydy - y * sum_row + grad_x = grad_x2.reshape(*rest_dim, dim) * scale + return grad_x.to(orig_dtype) + +def softmax_backward_odo(p: torch.Tensor, + dp: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + scale: float): + orig_dtype = p.dtype + o = o.to(torch.float32) + do = do.to(torch.float32) + p = p.to(torch.float32) + dp = dp.to(torch.float32) + odo = o * do + sum_odo = torch.sum(odo, dim= -1, keepdim=True) + ds = p * (dp - sum_odo) * scale + ds = ds.to(orig_dtype) + return ds, sum_odo + +def dropout_backward(mask: torch.Tensor, + grad_y: torch.Tensor, + dropout_p: float): + return mask * grad_y / (1 - dropout_p) + +def dropout_backward2(grad_y: torch.Tensor, + dropout_p: float): + return dropout_backward(mask, grad_y, dropout_p) + +def dropout_forward(seed: int, + dropout_p: float, + x: torch.Tensor): + torch.manual_seed(seed) + mask = torch.empty_like(x).fill_(dropout_p) + prob = torch.bernoulli(mask).logical_not() + y = x * prob / (1 - dropout_p) + return y + +def softmax_causal_backward(y1: torch.Tensor, + y2: torch.Tensor, + grad_y: torch.Tensor): + # y1 attn2 after dropout mask + # y2 attn after softmax, only half mask + orig_dtype = y1.dtype + rest_dim = y1.shape[:-1] + dim = y1.shape[-1] + # seq_len_q = y.size()[-2] + # seq_len_k = y.size()[-1] + # seq_len_q = grad_y.size()[-2] + # seq_len_k = grad_y.size()[-1] + # mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=0) + y1 = y1.to(torch.float32) + grad_y = grad_y.to(torch.float32) + grad_y = grad_y + ydy = grad_y * y1 + sum_row = torch.sum(ydy, dim= -1).reshape(*rest_dim, 1) + grad_x2 = ydy - y2 * sum_row + grad_x = grad_x2.reshape(*rest_dim, dim) + return grad_x.to(orig_dtype) + +class SDPA(nn.Module): + def __init__(self, dropout_p) -> None: + super().__init__() + if dropout_p > 0.0: + self.do_m = nn.Dropout(p=dropout_p) + + def forward(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = None): + dtype = q.dtype + self.head_size_q = q.size()[-1] + self.q = q.clone() + self.k = k.clone() + self.v = v.clone() + seq_len_q, seq_len_k = q.size(-2), k.size(-2) + + attn_bias = torch.zeros(seq_len_q, seq_len_k, dtype=dtype) + self.is_causal = is_causal + if is_causal: + temp_mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(dtype) + + k_expand, v_expand = num_head_bcast(q, k, v) + #k_expand = k + #v_expand = v + # k_expand.register_hook(lambda x: dump_grad('k_expand_grad', k_expand)) + # v_expand.register_hook(lambda x: dump_grad('v_expand_grad', v_expand)) + #self.p = q@k_expand.transpose(-2, -1) + s = torch.matmul(q, k_expand.transpose(-1, -2)) + s.register_hook(lambda x: dump_grad('s_grad', x)) + self.s = s + s = s.to(torch.float32) + self.softmax_scale = 1 / np.sqrt(self.head_size_q) if scale is None else scale + s = s * self.softmax_scale + attn_bias + sum_row, _ = torch.max(s, dim= -1, keepdim=True) + s = s - sum_row + self.lse = torch.logsumexp(s, dim=-1, keepdim=True) + sum_row + p = torch.softmax(s, dim= -1).to(dtype) + self.p = p + p.register_hook(lambda x: 
dump_grad('p_grad', x)) + attn = torch.matmul(p, v_expand) + attn.register_hook(lambda x: dump_grad('O_grad', x)) + self.o = attn + return attn + + def backward_ref(self, + o_grad: torch.Tensor): + q_grad = torch.empty_like(self.q) + k_grad = torch.empty_like(self.k) + v_grad = torch.empty_like(self.v) + k_expand, v_expand = num_head_bcast(self.q, self.k, self.v) + # forward + s = torch.matmul(self.q, k_expand.transpose(-1, -2)) + s = s.to(torch.float32) + dtype = self.q.dtype + seq_len_q, seq_len_k = q_grad.size(-2), k_grad.size(-2) + attn_bias = torch.zeros(seq_len_q, seq_len_k, dtype=dtype) + if self.is_causal: + temp_mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(dtype) + s = s * self.softmax_scale + attn_bias + p = torch.exp(s - self.lse).to(dtype) + # backward + v_grad = torch.matmul(p.transpose(-1, -2), o_grad) + + p_grad = torch.matmul(o_grad, v_expand.transpose(-1, -2)) + s_grad, odo = softmax_backward_odo(p, p_grad, self.o, o_grad, self.softmax_scale) + self.odo = odo + # s_grad = softmax_backward(p, p_grad, self.softmax_scale) + k_grad = torch.matmul(s_grad.transpose(-1, -2), self.q) + q_grad = torch.matmul(s_grad, k_expand) + k_grad_red = num_head_reduce(k_grad, self.k) + v_grad_red = num_head_reduce(v_grad, self.v) + return (q_grad, k_grad, k_grad_red, v_grad, v_grad_red, p_grad, s_grad) + +class ptSDPA(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = None): + dtype = q.dtype + with sdpa_kernel(backends=[SDPBackend.MATH]): + return F.scaled_dot_product_attention(q, k, v, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=True) + +def set_dict(dump_dict, name, value): + if value.dtype == torch.bfloat16 or value.dtype == torch.float16: + dump_dict[name] = value.detach().clone().view(torch.uint16).numpy() + elif value.dtype == torch.bool: + dump_dict[name] = value.detach().clone().to(torch.uint16).numpy() + else: + dump_dict[name] = value.detach().clone().numpy() + +def test_sdpa(dtype, + seed: int, + batch: int, + num_heads_q: int, + num_heads_kv: int, + seq_len_qo: int, + seq_len_kv: int, + head_size_qk: int, + head_size_vo: int, + dropout_p: float = 0.0, + is_causal: bool = False, + is_bhsd: bool = True): + torch.manual_seed(seed) + q = torch.randn(batch, num_heads_q, seq_len_qo, head_size_qk, requires_grad=True).to(dtype) + k = torch.randn(batch, num_heads_kv, seq_len_kv, head_size_qk, requires_grad=True).to(dtype) + v = torch.randn(batch, num_heads_kv, seq_len_kv, head_size_vo, requires_grad=True).to(dtype) + q2 = q.clone() + k2 = k.clone() + v2 = v.clone() + q.retain_grad() + k.retain_grad() + v.retain_grad() + q2.retain_grad() + k2.retain_grad() + v2.retain_grad() + test_model = SDPA(dropout_p).to(dtype) + refe_model = ptSDPA().to(dtype) + + torch.manual_seed(seed) + attn_out = test_model(q, k, v, dropout_p, is_causal) + torch.manual_seed(seed) + attn_out_pt = refe_model(q2, k2, v2, dropout_p, is_causal) + grad = torch.empty_like(attn_out) + torch.manual_seed(seed) + grad.uniform_(-1, 1) + grad = grad.to(dtype) + attn_out.backward(grad) + attn_out_pt.backward(grad) + q_grad, k_grad, k_grad_red, v_grad, v_grad_red, p_grad, s_grad = test_model.backward_ref(grad) + dump_dict = {} + print(f"seed {seed} bsz {batch} nh_q {num_heads_q} nh_kv {num_heads_kv} sl_qo {seq_len_qo} sl_kv 
{seq_len_kv} hs_qk {head_size_qk} hs_vo {head_size_vo} dp {dropout_p} is_causal {is_causal} is_bhsd {is_bhsd}") + print('attn_out ', is_close(attn_out, attn_out_pt)) + print('p_grad ', is_close(GRAD_DICT['p_grad'], p_grad)) + # print('s2_grad ', is_close(GRAD_DICT['s2_grad'], s2_grad)) + print('s_grad ', is_close(GRAD_DICT['s_grad'], s_grad)) + print('k_grad ', is_close(k_grad_red, k2.grad)) + print('q_grad ', is_close(q_grad, q2.grad)) + print('v_grad ', is_close(v_grad_red, v2.grad)) + if is_bhsd: + set_dict(dump_dict, 'out', attn_out) + set_dict(dump_dict, 'grad', grad) + set_dict(dump_dict, 'v_grad', v_grad) + set_dict(dump_dict, 'k_grad', k_grad) + set_dict(dump_dict, 'q_grad', q_grad) + set_dict(dump_dict, 'q', q) + set_dict(dump_dict, 'k', k) + set_dict(dump_dict, 'v', v) + else: + set_dict(dump_dict, 'out', attn_out.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'grad', grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'v_grad', v_grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'k_grad', k_grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'q_grad', q_grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'q', q.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'k', k.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'v', v.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'lse', test_model.lse) + set_dict(dump_dict, 'odo', test_model.odo) + set_dict(dump_dict, 's', test_model.s) + set_dict(dump_dict, 'p', test_model.p) + set_dict(dump_dict, 'p_grad', p_grad) + set_dict(dump_dict, 's_grad', s_grad) + shape = np.array([batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, is_causal, is_bhsd], dtype=np.int32) + dump_dict['shape'] = shape + # print('test', v_grad[0,0:4,0,0:16]) + # print('upstream', v2.grad[0,0:4,0,0:16]) + np.savez(f'mha-{batch}-{num_heads_q}-{num_heads_kv}-{seq_len_qo}-{seq_len_kv}-{head_size_qk}-{head_size_vo}-{dropout_p}-{int(is_causal)}-{int(is_bhsd)}.npz', **dump_dict) + +def loop_run(): + global GRAD_DICT + for h in [2, 1]: + # for seq_q in list(range(512, 512+32)): + # for seq_k in list(range(512, 512+32)): + for seq_q in [512, 513, 523, 528, 543]: + for seq_k in [512, 513, 523, 528, 543]: + for dim in [64, 96, 128, 192, 256]: + # print('test_run', 4, 4, h, seq_q, seq_k, dim, dim) + # bhsd + test_sdpa(torch.float16, 123, 4, 4, h, seq_q, seq_k, dim, dim, is_bhsd = True) + GRAD_DICT = {} + # bshd + test_sdpa(torch.float16, 123, 4, 4, h, seq_q, seq_k, dim, dim, is_bhsd = False) + GRAD_DICT = {} + +if __name__ == '__main__': + # test_sdpa(torch.bfloat16, 123, 128, 4, 4, 900, 900, 128, 128) + loop_run() + # test_sdpa(torch.float16, 123, 4, 4, 2, 514, 513, 128, 128, is_bhsd=True) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 4, 513, 784, 128, 128, is_causal=False) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 2, 513, 784, 128, 128, is_causal=False) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 2, 513, 784, 128, 128, is_causal=True) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 1, 513, 784, 128, 128, is_causal=False) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 1, 513, 784, 128, 128, is_causal=True) + # GRAD_DICT = {} + # test_sdpa(torch.bfloat16, 123, 4, 4, 2, 513, 513, 128, 64, is_causal=True) + # GRAD_DICT = {} + # test_sdpa(torch.bfloat16, 123, 4, 4, 2, 513, 513, 128, 64, 0.3, is_causal=False) + # GRAD_DICT = {} + # test_sdpa(torch.bfloat16, 123, 4, 4, 2, 513, 513, 128, 64, is_causal=False) + # GRAD_DICT = {} + # GRAD_DICT = {} + # 
test_sdpa(torch.bfloat16, 123, 4, 4, 4, 513, 513, 128, 64, is_causal=False) + # test_sdpa(torch.bfloat16, 123, 4, 4, 1, 513, 513, 128, 64, False) + # test_sdpa(torch.bfloat16, 123, 4, 4, 4, 1024, 513, 128, 128) + # test_sdpa(123, 2, 16, 1, 513, 513, 128, 128) diff --git a/examples/sdpa_bwd/test_sdpa_s.py b/examples/sdpa_bwd/test_sdpa_s.py new file mode 100644 index 0000000000..e2be37994e --- /dev/null +++ b/examples/sdpa_bwd/test_sdpa_s.py @@ -0,0 +1,368 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +import numpy as np +def is_close(refe: torch.Tensor, + test: torch.Tensor): + test = test.to(torch.float32) + refe = refe.to(torch.float32) + cosfactor = F.cosine_similarity(test.reshape(-1), refe.reshape(-1), dim=0) > 0.99 + allclose = torch.allclose(test, refe, atol=3e-3, rtol=3e-3) + return cosfactor and allclose + +def num_head_bcast(query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor): + q_num_heads = query.size(-3) + k_num_heads = key.size(-3) + v_num_heads = value.size(-3) + k_dim0 = key.size(0) + k_dim1 = key.size(1) + k_dim2 = key.size(2) + k_dim3 = key.size(3) + v_dim0 = value.size(0) + v_dim1 = value.size(1) + v_dim2 = value.size(2) + v_dim3 = value.size(3) + if (q_num_heads == k_num_heads) and (q_num_heads == v_num_heads): + return key, value + k_repeat = q_num_heads // k_num_heads + v_repeat = q_num_heads // v_num_heads + key = key.repeat_interleave(k_repeat, 1).reshape(k_dim0, k_repeat * k_dim1, k_dim2, k_dim3) + value = value.repeat_interleave(v_repeat, 1).reshape(v_dim0, v_repeat * v_dim1, v_dim2, v_dim3) + return key, value + +def num_head_reduce(expand_grad: torch.Tensor, + x: torch.Tensor): + num_heads_expand = expand_grad.size(-3) + num_heads_orig = x.size(-3) + if (num_heads_expand == num_heads_orig): + return expand_grad + n_repeat = num_heads_expand // num_heads_orig + assert len(x.shape) == 4 + batch, num_head, seq_len, head_size = x.size() + expand_grad = expand_grad.reshape(batch, num_head, n_repeat, seq_len, head_size) + grad = torch.sum(expand_grad, dim=2).reshape(batch, num_head, seq_len, head_size) + return grad + +GRAD_DICT = {} + +def dump_grad(name, value): + global GRAD_DICT + if name not in GRAD_DICT: + GRAD_DICT[name] = value.clone() + else: + print(f'duplicated grad {name}') + return + +def softmax_backward(y: torch.Tensor, + grad_y: torch.Tensor, + scale: float): + orig_dtype = y.dtype + rest_dim = y.shape[:-1] + dim = y.shape[-1] + y = y.to(torch.float32) + grad_y = grad_y.to(torch.float32) + ydy = grad_y * y + sum_row = torch.sum(ydy, dim= -1).reshape(*rest_dim, 1) + grad_x2 = ydy - y * sum_row + grad_x = grad_x2.reshape(*rest_dim, dim) * scale + return grad_x.to(orig_dtype) + +def softmax_backward_odo(p: torch.Tensor, + dp: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + scale: float): + orig_dtype = p.dtype + o = o.to(torch.float32) + do = do.to(torch.float32) + p = p.to(torch.float32) + dp = dp.to(torch.float32) + odo = o * do + sum_odo = torch.sum(odo, dim= -1, keepdim=True) + ds = p * (dp - sum_odo) * scale + ds = ds.to(orig_dtype) + return ds, sum_odo + +def dropout_backward(mask: torch.Tensor, + grad_y: torch.Tensor, + dropout_p: float): + return mask * grad_y / (1 - dropout_p) + +def dropout_backward2(grad_y: torch.Tensor, + dropout_p: float): + return dropout_backward(mask, grad_y, dropout_p) + +def dropout_forward(seed: int, + dropout_p: float, + x: torch.Tensor): + torch.manual_seed(seed) + mask = torch.empty_like(x).fill_(dropout_p) + prob = 
torch.bernoulli(mask).logical_not() + y = x * prob / (1 - dropout_p) + return y + +def softmax_causal_backward(y1: torch.Tensor, + y2: torch.Tensor, + grad_y: torch.Tensor): + # y1 attn2 after dropout mask + # y2 attn after softmax, only half mask + orig_dtype = y1.dtype + rest_dim = y1.shape[:-1] + dim = y1.shape[-1] + # seq_len_q = y.size()[-2] + # seq_len_k = y.size()[-1] + # seq_len_q = grad_y.size()[-2] + # seq_len_k = grad_y.size()[-1] + # mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=0) + y1 = y1.to(torch.float32) + grad_y = grad_y.to(torch.float32) + grad_y = grad_y + ydy = grad_y * y1 + sum_row = torch.sum(ydy, dim= -1).reshape(*rest_dim, 1) + grad_x2 = ydy - y2 * sum_row + grad_x = grad_x2.reshape(*rest_dim, dim) + return grad_x.to(orig_dtype) + +class SDPA(nn.Module): + def __init__(self, dropout_p) -> None: + super().__init__() + if dropout_p > 0.0: + self.do_m = nn.Dropout(p=dropout_p) + + def forward(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = None): + dtype = q.dtype + self.head_size_q = q.size()[-1] + self.q = q.clone() + self.k = k.clone() + self.v = v.clone() + seq_len_q, seq_len_k = q.size(-2), k.size(-2) + + attn_bias = torch.zeros(seq_len_q, seq_len_k, dtype=dtype) + self.is_causal = is_causal + if is_causal: + diagonal_offset = seq_len_k - seq_len_q + temp_mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=diagonal_offset) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(dtype) + + k_expand, v_expand = num_head_bcast(q, k, v) + #k_expand = k + #v_expand = v + # k_expand.register_hook(lambda x: dump_grad('k_expand_grad', k_expand)) + # v_expand.register_hook(lambda x: dump_grad('v_expand_grad', v_expand)) + #self.p = q@k_expand.transpose(-2, -1) + s = torch.matmul(q, k_expand.transpose(-1, -2)) + s.register_hook(lambda x: dump_grad('s_grad', x)) + self.s = s + s = s.to(torch.float32) + self.softmax_scale = 1 / np.sqrt(self.head_size_q) if scale is None else scale + s = s * self.softmax_scale + attn_bias + if is_causal: + diagonal_offset = seq_len_q - seq_len_k + if (diagonal_offset > 0): + for i in range(diagonal_offset): + s[:, :, i, :] = 0.0 + + sum_row, _ = torch.max(s, dim= -1, keepdim=True) + s = s - sum_row + self.lse = torch.logsumexp(s, dim=-1, keepdim=True) + sum_row + p = torch.softmax(s, dim= -1).to(dtype) + self.p = p + p.register_hook(lambda x: dump_grad('p_grad', x)) + attn = torch.matmul(p, v_expand) + attn.register_hook(lambda x: dump_grad('O_grad', x)) + self.o = attn + return attn + + def backward_ref(self, + o_grad: torch.Tensor): + q_grad = torch.empty_like(self.q) + k_grad = torch.empty_like(self.k) + v_grad = torch.empty_like(self.v) + k_expand, v_expand = num_head_bcast(self.q, self.k, self.v) + # forward + s = torch.matmul(self.q, k_expand.transpose(-1, -2)) + s = s.to(torch.float32) + dtype = self.q.dtype + seq_len_q, seq_len_k = q_grad.size(-2), k_grad.size(-2) + attn_bias = torch.zeros(seq_len_q, seq_len_k, dtype=dtype) + if self.is_causal: + diagonal_offset = seq_len_k - seq_len_q + temp_mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=diagonal_offset) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(dtype) + s = s * self.softmax_scale + attn_bias + p = torch.exp(s - self.lse).to(dtype) + # backward + v_grad = torch.matmul(p.transpose(-1, -2), o_grad) + + p_grad = torch.matmul(o_grad, v_expand.transpose(-1, -2)) + s_grad, 
odo = softmax_backward_odo(p, p_grad, self.o, o_grad, self.softmax_scale) + self.odo = odo + # s_grad = softmax_backward(p, p_grad, self.softmax_scale) + k_grad = torch.matmul(s_grad.transpose(-1, -2), self.q) + q_grad = torch.matmul(s_grad, k_expand) + k_grad = num_head_reduce(k_grad, self.k) + v_grad = num_head_reduce(v_grad, self.v) + return (q_grad, k_grad, v_grad, p_grad, s_grad) + +class ptSDPA(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = None): + dtype = q.dtype + if is_causal: + seq_len_q, seq_len_k = q.size(-2), k.size(-2) + diagonal_offset = seq_len_k - seq_len_q + temp_mask = torch.ones(seq_len_q, seq_len_k, dtype=torch.bool).tril(diagonal=diagonal_offset) + with sdpa_kernel(backends=[SDPBackend.MATH]): + return F.scaled_dot_product_attention(q, k, v, + dropout_p=dropout_p, + scale=scale, + attn_mask=temp_mask, + enable_gqa=True) + +def set_dict(dump_dict, name, value): + if value.dtype == torch.bfloat16 or value.dtype == torch.float16: + dump_dict[name] = value.detach().clone().view(torch.uint16).numpy() + elif value.dtype == torch.bool: + dump_dict[name] = value.detach().clone().to(torch.uint16).numpy() + else: + dump_dict[name] = value.detach().clone().numpy() + +def test_sdpa(dtype, + seed: int, + batch: int, + num_heads_q: int, + num_heads_kv: int, + seq_len_qo: int, + seq_len_kv: int, + head_size_qk: int, + head_size_vo: int, + dropout_p: float = 0.0, + is_causal: bool = False, + is_bhsd: bool = True): + torch.manual_seed(seed) + q = torch.randn(batch, num_heads_q, seq_len_qo, head_size_qk, requires_grad=True).to(dtype) + k = torch.randn(batch, num_heads_kv, seq_len_kv, head_size_qk, requires_grad=True).to(dtype) + v = torch.randn(batch, num_heads_kv, seq_len_kv, head_size_vo, requires_grad=True).to(dtype) + q2 = q.clone() + k2 = k.clone() + v2 = v.clone() + q.retain_grad() + k.retain_grad() + v.retain_grad() + q2.retain_grad() + k2.retain_grad() + v2.retain_grad() + test_model = SDPA(dropout_p).to(dtype) + refe_model = ptSDPA().to(dtype) + + torch.manual_seed(seed) + attn_out = test_model(q, k, v, dropout_p, is_causal) + torch.manual_seed(seed) + attn_out_pt = refe_model(q2, k2, v2, dropout_p, is_causal) + grad = torch.empty_like(attn_out) + torch.manual_seed(seed) + grad.uniform_(-1, 1) + grad = grad.to(dtype) + attn_out.backward(grad) + attn_out_pt.backward(grad) + q_grad, k_grad, v_grad, p_grad, s_grad = test_model.backward_ref(grad) + dump_dict = {} + print(f"seed {seed} bsz {batch} nh_q {num_heads_q} nh_kv {num_heads_kv} sl_qo {seq_len_qo} sl_kv {seq_len_kv} hs_qk {head_size_qk} hs_vo {head_size_vo} dp {dropout_p} is_causal {is_causal} is_bhsd {is_bhsd}") + print('attn_out ', is_close(attn_out, attn_out_pt)) + print('p_grad ', is_close(GRAD_DICT['p_grad'], p_grad)) + # print('s2_grad ', is_close(GRAD_DICT['s2_grad'], s2_grad)) + print('s_grad ', is_close(GRAD_DICT['s_grad'], s_grad)) + print('k_grad ', is_close(k_grad, k2.grad)) + print('q_grad ', is_close(q_grad, q2.grad)) + print('v_grad ', is_close(v_grad, v2.grad)) + if is_bhsd: + set_dict(dump_dict, 'out', attn_out) + set_dict(dump_dict, 'grad', grad) + set_dict(dump_dict, 'v_grad', v_grad) + set_dict(dump_dict, 'k_grad', k_grad) + set_dict(dump_dict, 'q_grad', q_grad) + set_dict(dump_dict, 'q', q) + set_dict(dump_dict, 'k', k) + set_dict(dump_dict, 'v', v) + else: + set_dict(dump_dict, 'out', attn_out.transpose(1, 2).contiguous()) + 
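+        # NOTE: for the bshd layout every activation and gradient is dumped transposed to
+        # [batch, seq, heads, head_dim] so the .npz matches what the C++ example reads when
+        # is_bhsd == 0; lse and odo are dumped untransposed in both layouts.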
set_dict(dump_dict, 'grad', grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'v_grad', v_grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'k_grad', k_grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'q_grad', q_grad.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'q', q.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'k', k.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'v', v.transpose(1, 2).contiguous()) + set_dict(dump_dict, 'lse', test_model.lse) + set_dict(dump_dict, 'odo', test_model.odo) + set_dict(dump_dict, 's', test_model.s) + set_dict(dump_dict, 'p', test_model.p) + set_dict(dump_dict, 'p_grad', p_grad) + set_dict(dump_dict, 's_grad', s_grad) + shape = np.array([batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, is_causal, is_bhsd], dtype=np.int32) + dump_dict['shape'] = shape + # print('test', v_grad[0,0:4,0,0:16]) + # print('upstream', v2.grad[0,0:4,0,0:16]) + np.savez(f'mha-{batch}-{num_heads_q}-{num_heads_kv}-{seq_len_qo}-{seq_len_kv}-{head_size_qk}-{head_size_vo}-{dropout_p}-{int(is_causal)}-{int(is_bhsd)}.npz', **dump_dict) + +def loop_run(): + global GRAD_DICT + for h in [4]: + # for seq_q in list(range(512, 512+32)): + # for seq_k in list(range(512, 512+32)): + for seq_q in [512, 513, 523, 527, 535, 543]: + for seq_k in [512, 513, 523, 527, 535, 543]: + for dim in [64, 96, 128, 192]: + # print('test_run', 4, 4, h, seq_q, seq_k, dim, dim) + # bhsd + test_sdpa(torch.float16, 123, 4, 4, h, seq_q, seq_k, dim, dim, is_causal=True, is_bhsd = True) + GRAD_DICT = {} + # bshd + test_sdpa(torch.float16, 123, 4, 4, h, seq_q, seq_k, dim, dim, is_causal=True, is_bhsd = False) + GRAD_DICT = {} + +if __name__ == '__main__': + # test_sdpa(torch.float16, 123, 4, 4, 4, 513, 512, 128, 128, is_causal=True) + loop_run() + # test_sdpa(torch.float16, 123, 4, 4, 4, 513, 784, 128, 128, is_causal=True) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 4, 513, 784, 128, 128, is_causal=False) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 2, 513, 784, 128, 128, is_causal=False) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 2, 513, 784, 128, 128, is_causal=True) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 1, 513, 784, 128, 128, is_causal=False) + # GRAD_DICT = {} + # test_sdpa(torch.float16, 123, 4, 4, 1, 513, 784, 128, 128, is_causal=True) + # GRAD_DICT = {} + # test_sdpa(torch.bfloat16, 123, 4, 4, 2, 513, 513, 128, 64, is_causal=True) + # GRAD_DICT = {} + # test_sdpa(torch.bfloat16, 123, 4, 4, 2, 513, 513, 128, 64, 0.3, is_causal=False) + # GRAD_DICT = {} + # test_sdpa(torch.bfloat16, 123, 4, 4, 2, 513, 513, 128, 64, is_causal=False) + # GRAD_DICT = {} + # GRAD_DICT = {} + # test_sdpa(torch.bfloat16, 123, 4, 4, 4, 513, 513, 128, 64, is_causal=False) + # test_sdpa(torch.bfloat16, 123, 4, 4, 1, 513, 513, 128, 64, False) + # test_sdpa(torch.bfloat16, 123, 4, 4, 4, 1024, 513, 128, 128) + # test_sdpa(123, 2, 16, 1, 513, 513, 128, 128) diff --git a/include/cute/algorithm/reorder.hpp b/include/cute/algorithm/reorder.hpp index ea2d6c6018..4c5ae1a4fd 100644 --- a/include/cute/algorithm/reorder.hpp +++ b/include/cute/algorithm/reorder.hpp @@ -165,21 +165,41 @@ reorder_impl(ReorderAtom const& atom, constexpr int values = size(SLayout{}) / size<0>(SLayout{}); constexpr int vchunk = sizeof_bits_v / sizeof_bits_v; - // Calculate mapping from src val -> dst val on a chunk-by-chunk basis. Unlike a plain copy, there is no intrinsic - // correspondence of src/dst values for subgroup reorders. 
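(The rewrite below keeps this source-to-destination mapping for the ordinary case and adds a second path for broadcasts: when the destination holds more values than the source, the val-to-val layout is built in the inverse direction, dst index -> src index, and the copy loop runs over destination values so every broadcast element gets written.)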
- auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index - auto vrlayout = composition(composition(Layout>, Stride<_0, _1>>{}, - rlayout), - Layout>, Stride<_0, _SG>>{}); // src val -> dst val - - CUTE_UNROLL - for (int sv = 0; sv < values; sv += vchunk) { - auto pS = recast_ptr(src.data() + sv); - auto pD = recast_ptr(dst.data() + vrlayout(sv)); - - detail::explode(detail::CallReorder{}, - pS, make_int_sequence{}, - pD, make_int_sequence{}); + static constexpr bool has_broadcast = (size(DLayoutWI{}) > size(SLayoutWI{})); + + if (!has_broadcast) { + // Calculate mapping from src val -> dst val on a chunk-by-chunk basis. Unlike a plain copy, there is no intrinsic + // correspondence of src/dst values for subgroup reorders. + auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index + auto vrlayout = composition(composition(Layout>, Stride<_0, _1>>{}, + rlayout), + Layout>, Stride<_0, _SG>>{}); // src val -> dst val + + CUTE_UNROLL + for (int sv = 0; sv < values; sv += vchunk) { + auto pS = recast_ptr(src.data() + sv); + auto pD = recast_ptr(dst.data() + vrlayout(sv)); + + detail::explode(detail::CallReorder{}, + pS, make_int_sequence{}, + pD, make_int_sequence{}); + } + } else { + // If there is broadcast happening, then we need to loop over dst values instead. + auto rlayout = coalesce(composition(right_inverse(slayout), dlayout)); // dst index -> src index + auto vrlayout = composition(composition(Layout>, Stride<_0, _1>>{}, + rlayout), + Layout>, Stride<_0, _SG>>{}); // dst val -> src val + + CUTE_UNROLL + for (int dv = 0; dv < values; dv += vchunk) { + auto pS = recast_ptr(src.data() + vrlayout(dv)); + auto pD = recast_ptr(dst.data() + dv); + + detail::explode(detail::CallReorder{}, + pS, make_int_sequence{}, + pD, make_int_sequence{}); + } } } diff --git a/include/cute/algorithm/subgroup_algorithms.hpp b/include/cute/algorithm/subgroup_algorithms.hpp new file mode 100644 index 0000000000..78f651603a --- /dev/null +++ b/include/cute/algorithm/subgroup_algorithms.hpp @@ -0,0 +1,379 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" +#include "cute/util/sycl_vec.hpp" + +namespace cute { + +// Uniformize a value, in case the compiler cannot prove it is subgroup-uniform. +template +CUTE_HOST_DEVICE +T +assert_uniform(T x) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + return group_broadcast(sg, x, 0); +} + +// Set a value in a single work-item -- x[i] = val. +// WARNING: i _must_ be a compile-time constant. +// No diagnostics/error will be issued by the compiler if it is not. +template +CUTE_HOST_DEVICE void +set_wi_value(T &x, int i, T val) +{ +#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET) + asm ( + "mov (M1_NM, 1) %0(0,%2)<1> %1(0,0)<1;1,0>" + : "+rw"(x) + : "rw.u"(val), "P"(i) + ); +#else + int lane = sycl::ext::oneapi::this_work_item::get_sub_group().get_local_id()[0]; + if (lane == i) + x = val; +#endif +} + +// Set an element of a 1D SG-shared fragment x. +// WARNING: i _must_ be a compile-time constant. +// No diagnostics/error will be issued by the compiler if it is not. +template +CUTE_HOST_DEVICE void +set_single_value(FragX& x, int i, typename FragX::element_type val) { + set_wi_value(x(i / intel::sg_size), i % intel::sg_size, val); +} + +// Broadcast the element from a 1D SG-shared fragment x +// corresponding to the Mode'th dimension of the logical coordinates of src(val). 
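+// When Mode is the mode that the TV layout spreads across the sub-group, each lane already
+// holds the element it needs and x is indexed directly; otherwise the value is fetched from
+// the owning lane via group_broadcast. (Descriptive note; see the TMode check below.)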
+template ::value)> +CUTE_HOST_DEVICE +constexpr auto +broadcast(FragX const& x, SGTensorSrc const& src, int val) +{ + auto coord = src.tv_layout()(0, val); + auto coord_i = get(coord); + + constexpr auto TMode = rank(as_arithmetic_tuple(stride<0>(SGTensorSrc{}.tv_layout()))) - 1; + if constexpr (TMode == Mode) { + return x(coord_i / intel::sg_size); + } else { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + return group_broadcast(sg, x(coord_i / intel::sg_size), coord_i % intel::sg_size); + } +} + +#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET) +#define DEFINE_HREDUCE16_FLOAT(op) \ + CUTE_DEVICE \ + float \ + hreduce16_float_ ## op(float x[16]) \ + { \ + float y; \ + asm ( \ + "{\n" \ + ".decl INTERLEAVE_2 v_type=P num_elts=16\n" \ + ".decl INTERLEAVE_4 v_type=P num_elts=16\n" \ + ".decl INTERLEAVE_8 v_type=P num_elts=16\n" \ + ".decl IN0 v_type=G type=UD num_elts=16 alias=<%1,0>\n" \ + ".decl IN1 v_type=G type=UD num_elts=16 alias=<%2,0>\n" \ + ".decl IN2 v_type=G type=UD num_elts=16 alias=<%3,0>\n" \ + ".decl IN3 v_type=G type=UD num_elts=16 alias=<%4,0>\n" \ + ".decl IN4 v_type=G type=UD num_elts=16 alias=<%5,0>\n" \ + ".decl IN5 v_type=G type=UD num_elts=16 alias=<%6,0>\n" \ + ".decl IN6 v_type=G type=UD num_elts=16 alias=<%7,0>\n" \ + ".decl IN7 v_type=G type=UD num_elts=16 alias=<%8,0>\n" \ + ".decl IN8 v_type=G type=UD num_elts=16 alias=<%9,0>\n" \ + ".decl IN9 v_type=G type=UD num_elts=16 alias=<%10,0>\n" \ + ".decl IN10 v_type=G type=UD num_elts=16 alias=<%11,0>\n" \ + ".decl IN11 v_type=G type=UD num_elts=16 alias=<%12,0>\n" \ + ".decl IN12 v_type=G type=UD num_elts=16 alias=<%13,0>\n" \ + ".decl IN13 v_type=G type=UD num_elts=16 alias=<%14,0>\n" \ + ".decl IN14 v_type=G type=UD num_elts=16 alias=<%15,0>\n" \ + ".decl IN15 v_type=G type=UD num_elts=16 alias=<%16,0>\n" \ + ".decl RA0 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA2 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA4 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA6 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA8 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA10 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA12 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA14 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RF0 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF1 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF2 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF3 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF4 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF5 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF6 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF7 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF8 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF9 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF10 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF11 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF12 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF13 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF14 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF15 v_type=G type=F num_elts=16 alias=\n" \ + "setp (M1_NM,16) INTERLEAVE_2 0x5555:uw\n" \ + "setp (M1_NM,16) INTERLEAVE_4 0x3333:uw\n" \ + "setp (M1_NM,16) INTERLEAVE_8 0x0F0F:uw\n" \ + /* Round 1: interleave 2n with 2n+1 */ \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA0(0,0)<1> IN1(0,0)<2;2,0> IN0(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA0(1,0)<1> IN0(0,1)<2;2,0> IN1(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA2(0,0)<1> 
IN3(0,0)<2;2,0> IN2(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA2(1,0)<1> IN2(0,1)<2;2,0> IN3(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA4(0,0)<1> IN5(0,0)<2;2,0> IN4(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA4(1,0)<1> IN4(0,1)<2;2,0> IN5(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA6(0,0)<1> IN7(0,0)<2;2,0> IN6(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA6(1,0)<1> IN6(0,1)<2;2,0> IN7(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA8(0,0)<1> IN9(0,0)<2;2,0> IN8(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA8(1,0)<1> IN8(0,1)<2;2,0> IN9(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA10(0,0)<1> IN11(0,0)<2;2,0> IN10(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA10(1,0)<1> IN10(0,1)<2;2,0> IN11(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA12(0,0)<1> IN13(0,0)<2;2,0> IN12(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA12(1,0)<1> IN12(0,1)<2;2,0> IN13(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA14(0,0)<1> IN15(0,0)<2;2,0> IN14(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA14(1,0)<1> IN14(0,1)<2;2,0> IN15(0,0)<1;1,0>\n" \ + /* Reduce */ \ + #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF3(0,0)<1> RF2(0,0)<1;1,0> RF3(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF4(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF7(0,0)<1> RF6(0,0)<1;1,0> RF7(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF8(0,0)<1> RF8(0,0)<1;1,0> RF9(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF11(0,0)<1> RF10(0,0)<1;1,0> RF11(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF12(0,0)<1> RF12(0,0)<1;1,0> RF13(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF15(0,0)<1> RF14(0,0)<1;1,0> RF15(0,0)<1;1,0>\n" \ + /* Round 2: interleave 4n+{0,1} with 4n+{2,3} */ \ + "(!INTERLEAVE_4) sel (M1_NM,16) RA0(1,0)<1> RA2(0,14)<1;1,0> RA0(0,0)<1;1,0>\n" \ + " (INTERLEAVE_4) sel (M1_NM,16) RA0(0,0)<1> RA0(0,2)<1;1,0> RA2(1,0)<1;1,0>\n" \ + "(!INTERLEAVE_4) sel (M1_NM,16) RA4(1,0)<1> RA6(0,14)<1;1,0> RA4(0,0)<1;1,0>\n" \ + " (INTERLEAVE_4) sel (M1_NM,16) RA4(0,0)<1> RA4(0,2)<1;1,0> RA6(1,0)<1;1,0>\n" \ + "(!INTERLEAVE_4) sel (M1_NM,16) RA8(1,0)<1> RA10(0,14)<1;1,0> RA8(0,0)<1;1,0>\n" \ + " (INTERLEAVE_4) sel (M1_NM,16) RA8(0,0)<1> RA8(0,2)<1;1,0> RA10(1,0)<1;1,0>\n" \ + "(!INTERLEAVE_4) sel (M1_NM,16) RA12(1,0)<1> RA14(0,14)<1;1,0> RA12(0,0)<1;1,0>\n" \ + " (INTERLEAVE_4) sel (M1_NM,16) RA12(0,0)<1> RA12(0,2)<1;1,0> RA14(1,0)<1;1,0>\n" \ + /* Reduce */ \ + #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF5(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF8(0,0)<1> RF8(0,0)<1;1,0> RF9(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF13(0,0)<1> RF12(0,0)<1;1,0> RF13(0,0)<1;1,0>\n" \ + /* Round 3: interleave 8n+{0,1,2,3} with 8n+{4,5,6,7} */ \ + "(!INTERLEAVE_8) sel (M1_NM,16) RA0(1,0)<1> RA4(0,12)<1;1,0> RA0(0,0)<1;1,0>\n" \ + " (INTERLEAVE_8) sel (M1_NM,16) RA0(0,0)<1> RA0(0,4)<1;1,0> RA4(1,0)<1;1,0>\n" \ + "(!INTERLEAVE_8) sel (M1_NM,16) RA8(1,0)<1> RA12(0,12)<1;1,0> RA8(0,0)<1;1,0>\n" \ + " (INTERLEAVE_8) sel (M1_NM,16) RA8(0,0)<1> RA8(0,4)<1;1,0> RA12(1,0)<1;1,0>\n" \ + /* Reduce */ \ + #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF8(0,0)<1> RF8(0,0)<1;1,0> RF9(0,0)<1;1,0>\n" \ + /* Round 4: final interleave */ \ + "mov (M1_NM, 8) RA0(1,0)<1> RA0(0,8)<1;1,0>\n" \ + "mov (M1_NM, 8) RA8(1,8)<1> RA8(0,0)<1;1,0>\n" \ + /* Reduce */ \ + #op " (M1_NM,8) %0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \ + #op " (M1_NM,8) %0(0,8)<1> RF8(0,8)<1;1,0> 
RF9(0,8)<1;1,0>\n" \ + "}\n" \ + : "=rw"(y) \ + : "rw"(x[0]), "rw"(x[1]), "rw"(x[2]), "rw"(x[3]), "rw"(x[4]), "rw"(x[5]), "rw"(x[6]), "rw"(x[7]), \ + "rw"(x[8]), "rw"(x[9]), "rw"(x[10]), "rw"(x[11]), "rw"(x[12]), "rw"(x[13]), "rw"(x[14]), "rw"(x[15]) \ + ); \ + return y; \ + } + +#define DEFINE_HREDUCE8_FLOAT(op) \ + CUTE_DEVICE \ + float \ + hreduce8_float_ ## op(float x[8]) \ + { \ + float y; \ + asm ( \ + "{\n" \ + ".decl INTERLEAVE_2 v_type=P num_elts=16\n" \ + ".decl INTERLEAVE_4 v_type=P num_elts=16\n" \ + ".decl INTERLEAVE_8 v_type=P num_elts=16\n" \ + ".decl IN0 v_type=G type=UD num_elts=16 alias=<%1,0>\n" \ + ".decl IN1 v_type=G type=UD num_elts=16 alias=<%2,0>\n" \ + ".decl IN2 v_type=G type=UD num_elts=16 alias=<%3,0>\n" \ + ".decl IN3 v_type=G type=UD num_elts=16 alias=<%4,0>\n" \ + ".decl IN4 v_type=G type=UD num_elts=16 alias=<%5,0>\n" \ + ".decl IN5 v_type=G type=UD num_elts=16 alias=<%6,0>\n" \ + ".decl IN6 v_type=G type=UD num_elts=16 alias=<%7,0>\n" \ + ".decl IN7 v_type=G type=UD num_elts=16 alias=<%8,0>\n" \ + ".decl RA0 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA2 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA4 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RA6 v_type=G type=UD num_elts=32 align=64\n" \ + ".decl RF0 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF1 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF2 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF3 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF4 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF5 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF6 v_type=G type=F num_elts=16 alias=\n" \ + ".decl RF7 v_type=G type=F num_elts=16 alias=\n" \ + "setp (M1_NM,16) INTERLEAVE_2 0x5555:uw\n" \ + "setp (M1_NM,16) INTERLEAVE_4 0x3333:uw\n" \ + "setp (M1_NM,16) INTERLEAVE_8 0x0F0F:uw\n" \ + /* Round 1: interleave 2n with 2n+1 */ \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA0(0,0)<1> IN1(0,0)<2;2,0> IN0(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA0(1,0)<1> IN0(0,1)<2;2,0> IN1(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA2(0,0)<1> IN3(0,0)<2;2,0> IN2(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA2(1,0)<1> IN2(0,1)<2;2,0> IN3(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA4(0,0)<1> IN5(0,0)<2;2,0> IN4(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA4(1,0)<1> IN4(0,1)<2;2,0> IN5(0,0)<1;1,0>\n" \ + "(!INTERLEAVE_2) sel (M1_NM,16) RA6(0,0)<1> IN7(0,0)<2;2,0> IN6(0,0)<1;1,0>\n" \ + " (INTERLEAVE_2) sel (M1_NM,16) RA6(1,0)<1> IN6(0,1)<2;2,0> IN7(0,0)<1;1,0>\n" \ + /* Reduce */ \ + #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF3(0,0)<1> RF2(0,0)<1;1,0> RF3(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF4(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF7(0,0)<1> RF6(0,0)<1;1,0> RF7(0,0)<1;1,0>\n" \ + /* Round 2: interleave 4n+{0,1} with 4n+{2,3} */ \ + "(!INTERLEAVE_4) sel (M1_NM,16) RA0(1,0)<1> RA2(0,14)<1;1,0> RA0(0,0)<1;1,0>\n" \ + " (INTERLEAVE_4) sel (M1_NM,16) RA0(0,0)<1> RA0(0,2)<1;1,0> RA2(1,0)<1;1,0>\n" \ + "(!INTERLEAVE_4) sel (M1_NM,16) RA4(1,0)<1> RA6(0,14)<1;1,0> RA4(0,0)<1;1,0>\n" \ + " (INTERLEAVE_4) sel (M1_NM,16) RA4(0,0)<1> RA4(0,2)<1;1,0> RA6(1,0)<1;1,0>\n" \ + /* Reduce */ \ + #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \ + #op " (M1_NM,16) RF5(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n" \ + /* Round 3: interleave 8n+{0,1,2,3} with 8n+{4,5,6,7} */ \ + "(!INTERLEAVE_8) sel (M1_NM,16) RA0(1,0)<1> RA4(0,12)<1;1,0> RA0(0,0)<1;1,0>\n" \ + " (INTERLEAVE_8) sel (M1_NM,16) 
RA0(0,0)<1> RA0(0,4)<1;1,0> RA4(1,0)<1;1,0>\n" \ + /* Reduce */ \ + #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \ + /* Round 4: reduce top and bottom halves */ \ + #op " (M1_NM,8) %0(0,0)<1> RF0(0,0)<1;1,0> RF0(0,8)<1;1,0>\n" \ + "}\n" \ + : "=rw"(y) \ + : "rw"(x[0]), "rw"(x[1]), "rw"(x[2]), "rw"(x[3]), "rw"(x[4]), "rw"(x[5]), "rw"(x[6]), "rw"(x[7]), \ + "rw"(x[8]), "rw"(x[9]), "rw"(x[10]), "rw"(x[11]), "rw"(x[12]), "rw"(x[13]), "rw"(x[14]), "rw"(x[15]) \ + ); \ + return y; \ + } +#else +#define DEFINE_HREDUCE16_FLOAT(op) \ + CUTE_DEVICE float hreduce16_float_ ## op(float x[16]) { return 0.f; } +#define DEFINE_HREDUCE8_FLOAT(op) \ + CUTE_DEVICE float hreduce8_float_ ## op(float x[8]) { return 0.f; } +#endif + +DEFINE_HREDUCE8_FLOAT(add) +DEFINE_HREDUCE8_FLOAT(max) +DEFINE_HREDUCE16_FLOAT(add) +DEFINE_HREDUCE16_FLOAT(max) + +// Subgroup-cooperative reduction of a SubgroupTensor. +template +CUTE_HOST_DEVICE +auto +reduce(SubgroupTensor const& src, BinaryOp op) +{ + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + using T = typename Engine::value_type; + using TVToV = Layout, Stride<_0,_1>>; + + /* Retrieve logical coordinate -> (T,V) mapping */ + constexpr auto shape = atuple_coshape(SubgroupTVLayout{}); + constexpr auto coord_to_tv = right_inverse(project_strides(SubgroupTVLayout{})).with_shape(shape); + + /* Move reduction coordinate to mode-0 and group the rest in mode-1. Then, remove work-item modes. */ + constexpr auto rcoord_to_tv = make_layout(select(coord_to_tv), remove(coord_to_tv)); + constexpr auto rcoord_to_v = filter(composition(TVToV{}, rcoord_to_tv), Step<_1,_1>{}); + + /* Regroup input tensor */ + Tensor src_r = make_tensor(src.data(), rcoord_to_v); + + /* Create output tensor */ + auto rshape = replace(shape, _1{}); + Tensor out = make_subgroup_tensor(make_tensor(ceil_div(size(rshape), intel::_SGSize{})), + make_identity_layout(rshape)); + + /* Check for reduction type */ + constexpr bool horizontal = (size<0>(rcoord_to_tv) == intel::_SGSize{} * size<0>(rcoord_to_v)); + constexpr bool vertical = (size<1>(rcoord_to_tv) == intel::_SGSize{} * size<1>(rcoord_to_v)); + + /* Check for optimized reductions */ + constexpr bool align16 = is_constant_v<0, decltype(size<1>(rcoord_to_v) % _16{})>; + constexpr bool align8 = is_constant_v<8, decltype(size<1>(rcoord_to_v))>; + + constexpr bool hadd = (horizontal && is_same_v && is_same_v>); + constexpr bool hmax = (horizontal && is_same_v && is_same_v>); + + constexpr bool hadd16 = hadd && align16; + constexpr bool hmax16 = hmax && align16; + + constexpr bool hadd8 = hadd && align8; + constexpr bool hmax8 = hmax && align8; + + [[maybe_unused]] T temp[size<1>(rcoord_to_v)]; /* array of partial reductions */ + + CUTE_UNROLL + for (int j = 0; j < size<1>(rcoord_to_v); j++) { + T acc = src_r(0, j); + CUTE_UNROLL + for (int i = 1; i < size<0>(rcoord_to_v); i++) { + acc = op(acc, src_r(i, j)); + } + + if constexpr (hadd16 || hmax16 || hadd8 || hmax8) + temp[j] = acc; + else if constexpr (horizontal) + set_single_value(out, j, reduce_over_group(sg, acc, op)); + else if constexpr (vertical) + out(j) = acc; + else + static_assert("Unimplemented reduction type"); + } + + if constexpr (hadd16) { + CUTE_UNROLL + for (int j = 0; j < size<1>(rcoord_to_v); j += 16) { + out(j/16) = hreduce16_float_add(&temp[j]); + } + } else if constexpr (hmax16) { + CUTE_UNROLL + for (int j = 0; j < size<1>(rcoord_to_v); j += 16) { + out(j/16) = hreduce16_float_max(&temp[j]); + } + } else if constexpr (hadd8) { + out(0) = 
hreduce8_float_add(&temp[0]); + } else if constexpr (hmax8) { + out(0) = hreduce8_float_max(&temp[0]); + } + + return out; +} + +} // namespace cute diff --git a/include/cute/arch/copy_xe_legacy.hpp b/include/cute/arch/copy_xe_legacy.hpp index a414885033..fc3da90055 100644 --- a/include/cute/arch/copy_xe_legacy.hpp +++ b/include/cute/arch/copy_xe_legacy.hpp @@ -47,25 +47,4 @@ #include #include -// FIXME: these are not copy-related and should be declared elsewhere. -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics); -SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics); -#endif - -namespace cute -{ - -// scope = 3 is for subgroup, scop = 2 is for workgroup -CUTE_HOST_DEVICE void barrier_arrive(int scope, int memory_scope = 0, int memory_semantics = 0) { -#ifdef __SYCL_DEVICE_ONLY__ - __spirv_ControlBarrierArriveINTEL(scope, memory_scope, memory_semantics); -#endif -} -CUTE_HOST_DEVICE void barrier_wait(int scope, int memory_scope = 0, int memory_semantics = 0) { -#ifdef __SYCL_DEVICE_ONLY__ - __spirv_ControlBarrierWaitINTEL(scope, memory_scope, memory_semantics); -#endif -} - -} // end namespace cute +#include \ No newline at end of file diff --git a/include/cute/arch/reorder.hpp b/include/cute/arch/reorder.hpp index a2caa033f0..dd0e32e193 100644 --- a/include/cute/arch/reorder.hpp +++ b/include/cute/arch/reorder.hpp @@ -41,7 +41,7 @@ struct Universal_Reorder_UU { CUTE_HOST_DEVICE static void reorder(SrcType const& src0, DstType& dst0) { - dst0 = src0; + dst0 = DstType(src0); } }; diff --git a/include/cute/arch/reorder_xe.hpp b/include/cute/arch/reorder_xe.hpp index 42e701a4ce..6029889400 100644 --- a/include/cute/arch/reorder_xe.hpp +++ b/include/cute/arch/reorder_xe.hpp @@ -64,7 +64,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 2 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -93,7 +93,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 2 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -122,7 +122,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 2 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -151,7 +151,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 2 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -192,7 +192,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 4 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -219,7 +219,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& 
src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 4 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -259,7 +259,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t scale = 0x7E000000; const uint32_t shift = 0xBF000000; asm ( /* 4 cycles/output register */ @@ -288,7 +288,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t scale = 0x7E000000; const uint32_t shift = 0xBF000000; asm ( /* 4 cycles/output register */ @@ -317,7 +317,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 1 cycle/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -343,7 +343,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 1 cycle/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -380,7 +380,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 5 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -407,7 +407,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 5 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -453,7 +453,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 6 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -480,7 +480,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 6 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -525,7 +525,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 7 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -552,7 +552,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort4& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) asm ( /* 7 cycles/output register */ "{\n" ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" @@ -592,7 +592,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts 
= 0x00040000; asm ( /* 3 cycles/output register */ "{\n" @@ -623,7 +623,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 4 cycles/output register */ "{\n" @@ -658,7 +658,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 3 cycles/output register */ "{\n" @@ -702,7 +702,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 3 cycles/output register */ "{\n" @@ -733,7 +733,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 4 cycles/output register */ "{\n" @@ -768,7 +768,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 3 cycles/output register */ "{\n" @@ -812,7 +812,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 4 cycles/output register */ "{\n" @@ -843,7 +843,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 5 cycles/output register */ "{\n" @@ -878,7 +878,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 4 cycles/output register */ "{\n" @@ -922,7 +922,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 4 cycles/output register */ "{\n" @@ -953,7 +953,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 5 cycles/output register */ "{\n" @@ -988,7 +988,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x00040000; asm ( /* 4 cycles/output register */ "{\n" @@ -1035,7 +1035,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x0008000C; asm ( /* 4 cycles/output register */ "{\n" @@ -1066,7 +1066,7 @@ struct Xe_Reorder 
CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x0008000C; asm ( /* 5 cycles/output register */ "{\n" @@ -1101,7 +1101,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x0008000C; asm ( /* 4 cycles/output register */ "{\n" @@ -1148,7 +1148,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x0008000C; asm ( /* 5 cycles/output register */ "{\n" @@ -1179,7 +1179,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x0008000C; asm ( /* 6 cycles/output register */ "{\n" @@ -1214,7 +1214,7 @@ struct Xe_Reorder CUTE_HOST_DEVICE static void reorder(intel::uchar4 const& src0, intel::ushort8& dst0) { -#if defined(CUTE_ARCH_COPY_XE_ENABLED) +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) const uint32_t shifts = 0x0008000C; asm ( /* 5 cycles/output register */ "{\n" @@ -1236,7 +1236,77 @@ struct Xe_Reorder } }; +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::float4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::float4& dst0) + { +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) + asm ( /* 3 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,1)<2> IN_UB(0,0)<1;1,0> 7:uw\n" + "shl (M1_NM, 32) OUT_UW(2,1)<2> IN_UB(0,32)<1;1,0> 7:uw\n" + "add.sat (M1_NM, 32) OUT_UW(0,0)<2> IN_UB(0,0)<1;1,0> -254:w\n" + "add.sat (M1_NM, 32) OUT_UW(2,0)<2> IN_UB(0,32)<1;1,0> -254:w\n" + "max (M1_NM, 32) OUT_UW(0,1)<2> OUT_UW(0,1)<2> 0x40:uw\n" + "max (M1_NM, 32) OUT_UW(2,1)<2> OUT_UW(2,1)<2> 0x40:uw\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::uchar4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::uchar4& dst0) + { +#if defined(CUTE_ARCH_REORDER_XE_ENABLED) + const uint32_t lshifts = 0x00000004; + const uint32_t rshifts = 0x00040000; + asm ( /* 9 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UB v_type=G type=UB num_elts=64 alias=<%0,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=32 alias=<%0,0>\n" + ".decl LSHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + ".decl RSHIFTS v_type=G type=UW num_elts=2 alias=<%3,0>\n" + ".decl TMP_UB v_type=G type=UB num_elts=64 align=64\n" + ".decl TMP_UW v_type=G type=UW num_elts=32 alias=\n" + "shr (M1_NM, 16) OUT_UB(0,0)<4> IN_UB(0, 0)<1;2,0> RSHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UB(0,1)<4> IN_UB(0,16)<1;2,0> RSHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UB(0,2)<4> IN_UB(0,32)<1;2,0> RSHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UB(0,3)<4> IN_UB(0,48)<1;2,0> RSHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) TMP_UB(0,0)<4> IN_UB(0, 8)<1;2,0> LSHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) 
TMP_UB(0,1)<4> IN_UB(0,24)<1;2,0> LSHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) TMP_UB(0,2)<4> IN_UB(0,40)<1;2,0> LSHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) TMP_UB(0,3)<4> IN_UB(0,56)<1;2,0> LSHIFTS(0,0)<0;2,1>\n" + "bfn.xCA (M1_NM, 32) OUT_UW(0,0)<1> OUT_UW(0,0)<1;1,0> TMP_UW(0,0)<1;1,0> 0xF0F0:uw\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(lshifts), "rw.u"(rshifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; +template <> struct Xe_Reorder : Xe_Reorder {}; +template <> struct Xe_Reorder : Xe_Reorder {}; } // end namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 69f492807d..efdbff16dc 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -221,11 +221,11 @@ struct TiledCopy : Copy_Atom // Tile a tensor or a layout from shape // (M,N,...) // to shape - // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) // where - // ThrV: The threads local to a COPY_ATOM Src. - // ThrX: The threads tiled across COPY_ATOMs Src. + // Thr: The logical threads within the tiled copy. // FrgV: The values local to a COPY_ATOM Src. + // FrgX: The values tiled across COPY_ATOMs Src. // RestM: The values tiled in M. // RestN: The values tiled in N. template @@ -242,11 +242,11 @@ struct TiledCopy : Copy_Atom // Tile a tensor or a layout from shape // (M,N,...) // to shape - // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) // where - // ThrV: The threads local to a COPY_ATOM Dst. - // ThrX: The threads tiled across COPY_ATOMs Dst. + // Thr: The logical threads within the tiled copy. // FrgV: The values local to a COPY_ATOM Dst. + // FrgX: The values tiled across COPY_ATOMs Dst. // RestM: The values tiled in M. // RestN: The values tiled in N. template @@ -263,7 +263,7 @@ struct TiledCopy : Copy_Atom // Tile a tensor or a layout from shape // ((TileM,TileN,...), (RestM,RestN,...)) // to shape - // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) template CUTE_HOST_DEVICE constexpr static auto @@ -395,30 +395,31 @@ struct ThrCopy return thr_tensor(thr_idx_, _, repeat>(_)); } - template + template CUTE_HOST_DEVICE auto - atom_partition_S(STensor&& stensor) const { - // Get fragment layout, and group atom thread modes (ThrV) since that is not done by tidfrg_D. - static constexpr auto RThrV = rank<0>(typename TiledCopy::AtomLayoutSrc{}); - auto tf_layout0 = TiledCopy::tidfrg_S(stensor.layout()); - auto tf_layout = replace<0>(tf_layout0, group<0,RThrV>(get<0>(tf_layout0))); - auto thr_tensor = make_tensor(static_cast(stensor).data(), tf_layout); + atom_partition(SDTensor&& sdtensor, ThrFrgLayout const& tf_layout0) const { + // Get fragment layout, and group atom thread modes (ThrV) since that is not done by tidfrg_S/D. + auto tf_layout = logical_divide(tf_layout0, Shape{}); + auto thr_tensor = make_tensor(static_cast(sdtensor).data(), tf_layout); + // Index, selecting full ThrV slice. 
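The shared `atom_partition` helper here backs both `atom_partition_S` and `atom_partition_D`. A minimal usage sketch, assuming a `tiled_copy`, a thread index, and a global source tensor `gA` that are not part of this change:

```cpp
// Inside a kernel, with cute names in scope (all objects below are placeholders).
auto thr_copy = tiled_copy.get_slice(thread_idx);

// Per-thread source partition of the tiled copy.
auto tAgA = thr_copy.partition_S(gA);

// Atom-level partition: keeps the full ThrV slice of the copy atom, which is
// useful when a single copy atom spans several work-items.
auto tAatom = thr_copy.atom_partition_S(gA);
```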
auto thr = idx2crd(thr_idx_, shape<0>(thr_tensor)); - return thr_tensor(replace<0>(thr, _), _, _); + return thr_tensor(replace<0>(thr,_),_,_); + } + + template + CUTE_HOST_DEVICE + auto + atom_partition_S(STensor&& stensor) const { + return atom_partition(static_cast(stensor), TiledCopy::tidfrg_S(stensor.layout())); } template CUTE_HOST_DEVICE auto atom_partition_D(DTensor&& dtensor) const { - static constexpr auto RThrV = rank<0>(typename TiledCopy::AtomLayoutDst{}); - auto tf_layout0 = TiledCopy::tidfrg_D(dtensor.layout()); - auto tf_layout = replace<0>(tf_layout0, group<0,RThrV>(get<0>(tf_layout0))); - auto thr_tensor = make_tensor(static_cast(dtensor).data(), tf_layout); - auto thr = idx2crd(thr_idx_, shape<0>(thr_tensor)); - return thr_tensor(replace<0>(thr, _), _, _); + return atom_partition(static_cast(dtensor), TiledCopy::tidfrg_D(dtensor.layout())); } template @@ -608,7 +609,8 @@ make_cotiled_copy(Copy_Atom const& copy_atom, auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); // Check validity - CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), + // Append 1:0 to data_layout so that OOB coordinates get the stride-0 + CUTE_STATIC_ASSERT_V(coalesce(composition(make_layout(data_layout, Layout<_1,_0>{}), layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); // // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them diff --git a/include/cute/atom/copy_traits_xe_2d.hpp b/include/cute/atom/copy_traits_xe_2d.hpp index 2df8ae0a38..a59d5a5c77 100644 --- a/include/cute/atom/copy_traits_xe_2d.hpp +++ b/include/cute/atom/copy_traits_xe_2d.hpp @@ -56,7 +56,7 @@ namespace cute { // Utility to check if a layout belongs to a coordinate tensor. template -static constexpr bool is_counting_layout_v = is_arithmetic_tuple_like::value; +static constexpr bool is_counting_layout_v = is_arithmetic_tuple_like::value || is_constant_v<1, decltype(size(Layout{}))>; @@ -125,7 +125,7 @@ struct Xe2DTraitsBase assert((height <= 0xFFFFFF) && "CuTe runtime error: block 2D tensor height exceeds 2^24"); assert((pitch <= 0xFFFFFF) && "CuTe runtime error: block 2D tensor pitch exceeds 2^24"); #endif - init_payload(); + device_init(); } template @@ -134,7 +134,7 @@ struct Xe2DTraitsBase : base_ptr(other.base_ptr), width(other.width), height(other.height), pitch(other.pitch), tiled_strides(other.tiled_strides) { - init_payload(); + device_init(); } // Initialize a previously-uninitialized atom. @@ -145,7 +145,7 @@ struct Xe2DTraitsBase } CUTE_DEVICE - void init_payload() { + void device_init() const { #ifdef __SYCL_DEVICE_ONLY__ payload = __builtin_IB_subgroup_createBlock2DAddressPayload( base_ptr, @@ -500,7 +500,8 @@ make_block_2d_copy(const CopyOp& op, using LayoutCopy_TV = typename SGCopy::TiledLayout_TV; // Expand the shape. - auto x_shape = elem_scale(ShapeTiler_MN{}, atom_shape); + auto x_atom_shape = append>(atom_shape, _1{}); + auto x_shape = elem_scale(ShapeTiler_MN{}, x_atom_shape); // Expand the single-SG TV layout to the full shape, then tile. 
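The `append`/`elem_scale` step above pads the atom shape to the tiler's rank before scaling. A small worked sketch with assumed shapes:

```cpp
#include <cute/tensor.hpp>
using namespace cute;

// Assumed for illustration only: a (16,8) copy atom tiled 2 x 4 x 1.
constexpr auto atom_shape   = make_shape(_16{}, _8{});
constexpr auto tiler_mn     = make_shape(_2{}, _4{}, _1{});
constexpr auto x_atom_shape = append<3>(atom_shape, _1{});        // (_16,_8,_1): pad to tiler rank
constexpr auto x_shape      = elem_scale(tiler_mn, x_atom_shape); // (_32,_32,_1): element-wise product
CUTE_STATIC_ASSERT_V(size(x_shape) == Int<1024>{});
```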
auto x_tv_layout1 = composition(make_layout(ShapeTiler_MN{}, make_layout(x_shape).stride()), LayoutCopy_TV{}); @@ -724,12 +725,11 @@ block_2d_selector(CoordLayout const&, GlobalStride const&) } // Helper for make_block_2d_copy_* routines -template CUTE_HOST_DEVICE auto make_block_2d_copy_X(CopyOp const& op, // Copy operation - TiledMMA const& mma, // TiledMMA instance Stride const& gstride, // Global memory strides XMode const& x_mode, // x, y modes YMode const& y_mode, @@ -756,6 +756,16 @@ make_block_2d_copy_X(CopyOp const& op, // Copy operation return make_block_2d_copy(op, gstride, x_mode, y_mode, atom_shape, sv_layout_t); } +// Single trait with specializations +template struct is_xe_block_2d_atom : std::false_type {}; +template struct is_xe_block_2d_atom> : std::true_type {}; +template struct is_xe_block_2d_atom> : std::true_type {}; +template struct is_xe_block_2d_atom> : std::true_type {}; +template struct is_xe_block_2d_atom> : std::true_type {}; + +template constexpr bool is_xe_block_2d_atom_v = is_xe_block_2d_atom::value; + + // MMA-focused TiledCopy creation functions. template CUTE_HOST_DEVICE @@ -774,6 +784,7 @@ make_block_2d_copy_A(CopyOp const& op, // Copy operation TiledMMA const& mma, // TiledMMA instance Tensor const& gmem) // Global tensor { + static_assert(is_xe_block_2d_atom_v, "Expected a block 2D atom"); using ValType = typename GEngine::value_type; return make_block_2d_copy_A(op, mma, gmem.stride()).with(gmem); } @@ -826,7 +837,7 @@ make_block_2d_copy_A(CopyOp const& op, // Copy operation make_tile(sg_to_vmk, _)); // (SG,V) -> (M,K) // Derive copy tile layout and create TiledCopy - return make_block_2d_copy_X(op, mma, gstride, x_mode, y_mode, tile_mk, svA); + return make_block_2d_copy_X(op, gstride, x_mode, y_mode, tile_mk, svA); } template @@ -846,6 +857,7 @@ make_block_2d_copy_B(CopyOp const& op, // Copy operation TiledMMA const& mma, // TiledMMA instance Tensor const& gmem) // Global tensor { + static_assert(is_xe_block_2d_atom_v, "Expected a block 2D atom"); using ValType = typename GEngine::value_type; return make_block_2d_copy_B(op, mma, gmem.stride()).with(gmem); } @@ -887,7 +899,7 @@ make_block_2d_copy_B(CopyOp const& op, // Copy operation auto thr_vmnk = mma.get_thr_layout_vmnk(); // (ThrV,ThrM,ThrN,ThrK) -> thr auto shape_vmnk = shape(thr_vmnk); // (ThrV,ThrM,ThrN,ThrK) auto drop_m = make_layout(shape_vmnk, - make_stride(_1{}, _0{}, get<0>(shape_vmnk), _0{}, + make_stride(_1{}, _0{}, get<0>(shape_vmnk), get<0>(shape_vmnk) * get<2>(shape_vmnk))); // (ThrV,ThrM,ThrN,ThrK) -> (ThrV,ThrN,ThrK) auto thr_to_vnk = composition(drop_m, right_inverse(thr_vmnk)); // thr -> (ThrV,ThrN,ThrK) @@ -898,7 +910,7 @@ make_block_2d_copy_B(CopyOp const& op, // Copy operation make_tile(sg_to_vnk, _)); // (SG,V) -> (N,K) // Derive copy tile layout and create TiledCopy - return make_block_2d_copy_X(op, mma, gstride, x_mode, y_mode, tile_nk, svB); + return make_block_2d_copy_X(op, gstride, x_mode, y_mode, tile_nk, svB); } template @@ -911,15 +923,26 @@ make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance return make_block_2d_copy_C(mma, gmem.stride()).with(gmem); } -template +template CUTE_HOST_DEVICE auto -make_block_2d_copy_C(CopyOp const& op, // Copy operation - TiledMMA const& mma, // TiledMMA instance +make_block_2d_copy_D(TiledMMA const& mma, // TiledMMA instance Tensor const& gmem) // Global tensor { using ValType = typename GEngine::value_type; - return make_block_2d_copy_C(op, mma, gmem.stride()).with(gmem); + return make_block_2d_copy_D(mma, 
gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_CD(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem) // Global tensor +{ + static_assert(is_xe_block_2d_atom_v, "Expected a block 2D atom"); + using ValType = typename GEngine::value_type; + return make_block_2d_copy_CD(op, mma, gmem.stride()).with(gmem); } template @@ -928,32 +951,46 @@ auto make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance Stride const& gstride) // Global memory strides { - using MMAType = typename TiledMMA::ValTypeA; + using MMAType = typename TiledMMA::ValTypeC; auto cC = make_identity_tensor(select<0,1>(mma.tile_mnk())); - auto op = block_2d_selector( + auto op = block_2d_selector( mma.get_slice(0).atom_partition_C(cC).layout(), gstride ); - return make_block_2d_copy_C(op, mma, gstride); + return make_block_2d_copy_CD(op, mma, gstride); } -template +template CUTE_HOST_DEVICE auto -make_block_2d_copy_C(CopyOp const& op, // Copy operation - TiledMMA const& mma, // TiledMMA instance +make_block_2d_copy_D(TiledMMA const& mma, // TiledMMA instance Stride const& gstride) // Global memory strides { - return make_block_2d_copy_C(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride)); + using MMAType = typename TiledMMA::ValTypeD; + auto cD = make_identity_tensor(select<0,1>(mma.tile_mnk())); + auto op = block_2d_selector( + mma.get_slice(0).atom_partition_C(cD).layout(), gstride + ); + return make_block_2d_copy_CD(op, mma, gstride); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_CD(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride) // Global memory strides +{ + return make_block_2d_copy_CD(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride)); } template CUTE_HOST_DEVICE auto -make_block_2d_copy_C(CopyOp const& op, // Copy operation - TiledMMA const& mma, // TiledMMA instance - Stride const& gstride, // Global memory strides - XMode const& x_mode, // x, y modes - YMode const& y_mode) +make_block_2d_copy_CD(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride, // Global memory strides + XMode const& x_mode, // x, y modes + YMode const& y_mode) { // Retrieve MMA atom's (subgroup, value) -> (M,N) layout auto tile_mn = select<0,1>(mma.tile_mnk()); @@ -971,7 +1008,162 @@ make_block_2d_copy_C(CopyOp const& op, // Copy operation make_tile(sg_to_vmn, _)); // (SG,V) -> (M,N) // Derive copy tile layout and create TiledCopy - return make_block_2d_copy_X(op, mma, gstride, x_mode, y_mode, tile_mn, svC); + return make_block_2d_copy_X(op, gstride, x_mode, y_mode, tile_mn, svC); +} + +// Variants of make_block_2d_copy_C/D where the C/D tile is further subdivided by the user. +// (e.g. split-k parallelization). 
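For contrast with the subtiled variants that follow, the plain entry points are used directly from a `TiledMMA`; a hedged sketch (the tensors, the MMA instance, and the explicit op are placeholders, not names from this change):

```cpp
// gC/gD are global (M,N) tensors, tiled_mma is a TiledMMA instance (placeholders).
auto tiled_copy_c = make_block_2d_copy_C(tiled_mma, gC);  // op auto-selected from ValTypeC
auto tiled_copy_d = make_block_2d_copy_D(tiled_mma, gD);  // op auto-selected from ValTypeD

// An explicit block-2D op can still be forced through the shared C/D helper;
// SomeBlock2DStoreOp is a hypothetical stand-in for a concrete Xe copy op.
auto tiled_copy_d2 = make_block_2d_copy_CD(SomeBlock2DStoreOp{}, tiled_mma, gD);
```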
+ +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C_subtiled(TiledMMA const& mma, // TiledMMA instance + SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_C_subtiled(mma, stv_layout, ssg_layout, gmem.stride()).with(gmem); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_D_subtiled(TiledMMA const& mma, // TiledMMA instance + SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_D_subtiled(mma, stv_layout, ssg_layout, gmem.stride()).with(gmem); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_CD_subtiled(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + SubtileShape const& sshape, // Subtile shape: (m,n) + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_CD_subtiled(op, sshape, ssg_layout, mma, gmem.stride()).with(gmem); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C_subtiled(TiledMMA const& mma, // TiledMMA instance + SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Stride const& gstride) // Global memory strides +{ + using MMAType = typename TiledMMA::ValTypeA; + auto op = block_2d_selector(stv_layout, gstride); + return make_block_2d_copy_CD_subtiled(op, mma, atuple_coshape(stv_layout), ssg_layout, gstride); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_D_subtiled(TiledMMA const& mma, // TiledMMA instance + SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Stride const& gstride) // Global memory strides +{ + using MMAType = typename TiledMMA::ValTypeA; + auto op = block_2d_selector(stv_layout, gstride); + return make_block_2d_copy_CD_subtiled(op, mma, atuple_coshape(stv_layout), ssg_layout, gstride); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_CD_subtiled(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + SubtileShape const& sshape, // Subtile shape: (m,n) + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Stride const& gstride) // Global memory strides +{ + return make_block_2d_copy_CD_subtiled(op, mma, sshape, ssg_layout, gstride, + find_x_mode(gstride), find_y_mode(gstride)); +} + +template )> +CUTE_HOST_DEVICE +auto +make_block_2d_copy_CD_subtiled(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + SubtileShape const& sshape, // Subtile shape: (m,n) + SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile) + Stride const& gstride, // Global memory strides + XMode const& x_mode, // x, y modes + YMode const& y_mode) +{ + // Expand subtile layout. 
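A call-site sketch of the subtiled D entry point declared above, under assumed inputs:

```cpp
// All inputs are placeholders: subtile_tv_layout maps (T,V) -> coord within one
// subtile, subtile_sg_layout maps SG_K -> (m_subtile,n_subtile), and gD is the
// global (M,N) tensor. Each ThrK subgroup then writes its own subtile of D.
auto tiled_copy_d = make_block_2d_copy_D_subtiled(tiled_mma,
                                                  subtile_tv_layout,
                                                  subtile_sg_layout,
                                                  gD);
```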
+ auto xssg_layout = make_layout(shape(ssg_layout), + elem_scale(stride(ssg_layout), sshape)); // SG_K -> (M,N) + + // Retrieve MMA atom's (subgroup, value) -> (M,N) layout. + // Allow cross-MMA tiling. + auto tile_mn = round_up(select<0,1>(mma.tile_mnk()), + atuple_coshape(xssg_layout)); + + auto thr_vmnk = mma.get_thr_layout_vmnk(); // (ThrV,ThrM,ThrN,ThrK) -> thr + auto shape_vmnk = shape(thr_vmnk); // (ThrV,ThrM,ThrN,ThrK) + auto drop_k = replace<3>(make_layout(shape_vmnk), + make_layout(get<3>(shape_vmnk), _0{})); // (ThrV,ThrM,ThrN,ThrK) -> (ThrV,ThrM,ThrN) + + auto thr_to_vmn = composition(drop_k, right_inverse(thr_vmnk)); // thr -> (ThrV,ThrM,ThrN) + auto sg_to_vmn = composition(thr_to_vmn, + make_layout(product(take<1,4>(shape_vmnk)), get<0>(shape_vmnk))); // SG -> (0,ThrM,ThrN) + + auto svC = composition(mma.thrfrg_C(make_layout(tile_mn)), + make_tile(sg_to_vmn, _)); // (SG,V) -> (M,N) + + // Add subtile modes. Limitations: + // - ThrK must be covered by a single mode in svC. + // - SubtileSGLayout must have a subtile for each ThrK, OR ThrK must be the last mode. + decltype(coalesce(get<0>(svC))) sC{}; + constexpr auto mode_thr_k = find_if(stride(sC), [](auto const &x) { return C>{}; }); + static_assert(shape(sC) == shape<3>(thr_vmnk), "ThrK split into multiple modes; unsupported"); + + auto k_to_mn = composition(make_layout(tile_mn), xssg_layout); // ThrK -> (M,N) + + static_assert(size(SubtileSGLayout{}) == shape<3>(thr_vmnk) || mode_thr_k + 1 >= rank(sC), + "Unsupported partially occupied ThrK scenario"); + + // Remove subtile value modes. + auto drop_subtiles = make_layout(zip(sshape, shape_div(tile_mn, sshape)), + zip(stride(make_layout(tile_mn)), Stride<_0,_0>{})); + + auto svC_tiled = make_layout(replace(sC, k_to_mn), + coalesce(composition(drop_subtiles, get<1>(svC)))); + + // Derive copy tile layout and create TiledCopy + return make_block_2d_copy_X(op, gstride, x_mode, y_mode, tile_mn, svC_tiled); } // Prefetch selection and creation. @@ -1024,7 +1216,7 @@ make_block_2d_prefetch(const Shape&, Stride const& stride, const XMo constexpr auto shape_y = get(Shape{}); // Try to retrieve whole cache lines (contiguous dimension = x) - constexpr auto width = cute::min(shape_x, 512 / sizeof_bits_v); + constexpr auto width = cute::gcd(shape_x, 512 / sizeof_bits_v); // Do a preliminary tiling to choose appropriate height. constexpr int n_sg_x = cute::gcd(SGCount, ceil_div(shape_x, width)); @@ -1072,13 +1264,54 @@ make_block_2d_prefetch(PrefetchOp const& op, Int{}); // Tile atom grid across collective op tile. - auto sv_layout = zipped_divide(make_layout(collective_op_tile), atom_shape); + auto sv_layout = zipped_divide(make_layout(atom_shape), collective_op_tile); // Create the TiledCopy object. 
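The prefetch width above now uses `gcd` rather than `min`, so the chosen width always divides the contiguous extent exactly. A worked check of the arithmetic, assuming bf16 data and a 48-element x extent:

```cpp
#include <numeric>   // std::gcd, mirroring cute::gcd for plain integers

constexpr int line_elems = 512 / 16;                  // 32 bf16 elements per 64-byte cache line
constexpr int shape_x    = 48;                        // assumed contiguous tile extent
static_assert(shape_x % line_elems != 0);             // min(48,32) = 32 would tile 48 unevenly
static_assert(std::gcd(shape_x, line_elems) == 16);   // gcd result always divides shape_x
```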
return make_block_2d_copy(op, stride, x_mode, y_mode, atom_shape, sv_layout); } +// +// Block 2D Copy Utilities - Helper functions for conditional copy operation selection +// +template +auto get_block_2d_copy_A(TiledMMA const& tiled_mma, ATensor const& a_tensor) +{ + if constexpr (!std::is_void_v) { + return make_block_2d_copy_A(CopyOp{}, tiled_mma, a_tensor); + } else { + return make_block_2d_copy_A(tiled_mma, a_tensor); + } +} + +template +auto get_block_2d_copy_B(TiledMMA const& tiled_mma, BTensor const& b_tensor) +{ + if constexpr (!std::is_void_v) { + return make_block_2d_copy_B(CopyOp{}, tiled_mma, b_tensor); + } else { + return make_block_2d_copy_B(tiled_mma, b_tensor); + } +} + +template +auto get_block_2d_copy_C(TiledMMA const& tiled_mma, CTensor const& c_tensor) +{ + if constexpr (!std::is_void_v) { + return make_block_2d_copy_CD(CopyOp{}, tiled_mma, c_tensor); + } else { + return make_block_2d_copy_C(tiled_mma, c_tensor); + } +} +template +auto get_block_2d_copy_D(TiledMMA const& tiled_mma, DTensor const& d_tensor) +{ + if constexpr (!std::is_void_v) { + return make_block_2d_copy_CD(CopyOp{}, tiled_mma, d_tensor); + } else { + return make_block_2d_copy_D(tiled_mma, d_tensor); + } +} // // Display utilities diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index b950a8d64f..4909aaae98 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -236,6 +236,11 @@ struct TiledMMA : MMA_Atom return thr_layout_vmnk_; } + CUTE_HOST_DEVICE constexpr auto + get_atom_layout_mnk() const { + return AtomLayoutMNK{}; + } + // Tile a tensor or a layout from shape // (M,N,...) // to shape @@ -263,7 +268,7 @@ struct TiledMMA : MMA_Atom make_layout(size<1>(AtomShape_MNK{}))); auto c_tensor = zipped_divide(t_tensor, c_tile); // ((AtomM,AtomN),(RestM,RestN)) - // Transform the Atom mode from (M,K) to (Thr,Val) + // Transform the Atom mode from (M,N) to (Thr,Val) auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) // Tile the tensor for the C-threads @@ -341,7 +346,7 @@ struct TiledMMA : MMA_Atom make_layout(size<2>(AtomShape_MNK{}))); auto b_tensor = zipped_divide(t_tensor, b_tile); // ((AtomN,AtomK),(RestN,RestK)) - // Transform the Atom mode from (M,K) to (Thr,Val) + // Transform the Atom mode from (N,K) to (Thr,Val) auto tv_tensor = b_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) // Tile the tensor for the Thread diff --git a/include/cute/atom/reorder_atom_xe.hpp b/include/cute/atom/reorder_atom_xe.hpp index c7bc546fce..7151fc7a15 100644 --- a/include/cute/atom/reorder_atom_xe.hpp +++ b/include/cute/atom/reorder_atom_xe.hpp @@ -115,8 +115,8 @@ constexpr ReorderKind classify_xe_reorder() template -constexpr auto choose_xe_reorder_impl(SLayout const& slayout, // (src thr, src val) -> coord - DLayout const& dlayout) { // (dst thr, dst val) -> coord +auto choose_xe_reorder_impl(SLayout const& slayout, // (src thr, src val) -> coord + DLayout const& dlayout) { // (dst thr, dst val) -> coord // Calculate data transformation, interleaving WI-owned values: // (thr0,val0) ... (thr15,val0), (thr0,val1), ..., (thr15,val1), ... 
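The `rlayout` computed next is the src-index to dst-index permutation; the same construction on toy layouts (shapes assumed, not taken from any real atom) looks like this:

```cpp
using namespace cute;

// Column-major source vs. row-major destination over the same 4x2 coordinates.
constexpr auto slayout = make_layout(make_shape(_4{}, _2{}));                          // (i,j) -> i + 4j
constexpr auto dlayout = make_layout(make_shape(_4{}, _2{}), make_stride(_2{}, _1{})); // (i,j) -> 2i + j
constexpr auto rlayout = coalesce(composition(right_inverse(dlayout), slayout));       // src idx -> dst idx
static_assert(rlayout(1) == 4, "source element 1 is written to destination slot 4");
```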
auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index @@ -135,7 +135,7 @@ constexpr auto choose_xe_reorder_impl(SLayout const& slayout, // (src thr, src else if constexpr (is_subbyte_v) return ReorderDispatchRelayoutConvert{}; else if constexpr (!is_same_v, remove_cv_t>) - return ReorderDispatchRelayoutConvert{}; + return ReorderDispatchConvertRelayout{}; else return ReorderDispatchXeGeneric{}; } @@ -176,20 +176,30 @@ void reorder_impl(ReorderDispatchXeGeneric const&, Tensor const& src, // WI fragment Tensor & dst, // WI fragment - SLayout const& slayout, // (src thr, src val) -> coord - DLayout const& dlayout) // (dst thr, dst val) -> coord + SLayout const&, // (src thr, src val) -> coord + DLayout const&) // (dst thr, dst val) -> coord { using SrcType = typename SEngine::element_type; using DstType = typename DEngine::element_type; static_assert(is_same_v, "No type conversions allowed on this path"); - auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index - auto ilayout = coalesce(composition(right_inverse(slayout), dlayout)); // dst index -> src index + static constexpr SLayout slayout{}; + static constexpr DLayout dlayout{}; + static constexpr auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index + static constexpr auto ilayout = coalesce(composition(right_inverse(slayout), dlayout)); // dst index -> src index + + // Check for broadcast cases. This path allows a single src element to be copied + // to multiple dst elements (useful for grouped quantization cases). + // Broadcast in (flattened) mode 0 requires special handling. + static constexpr bool has_broadcast = (size(DLayoutWI{}) > size(SLayoutWI{})); + static constexpr bool mode0_broadcast = has_broadcast && (stride<0>(ilayout) == _0{}); // Decide whether to stride on src or dst, depending on which allows a longer vector length. static constexpr int elems_per_grf = 64 / sizeof(SrcType); - static constexpr int ds_vl = cute::min(32, cute::min(shape<0>(rlayout), elems_per_grf / stride<0>(rlayout))); - static constexpr int ss_vl = cute::min(32, cute::min(shape<0>(ilayout), elems_per_grf / stride<0>(ilayout))); + static constexpr auto dstride = stride<0>(rlayout); + static constexpr int sstride = mode0_broadcast ? 1 : stride<0>(ilayout); + static constexpr int ds_vl = cute::min(32, cute::min(shape<0>(rlayout), elems_per_grf / dstride)); + static constexpr int ss_vl = cute::min(32, cute::min(shape<0>(ilayout), elems_per_grf / sstride)); // Make dst live, to prevent compiler from inserting its own initialization. #ifdef __SYCL_DEVICE_ONLY__ @@ -202,19 +212,28 @@ reorder_impl(ReorderDispatchXeGeneric const&, } #endif - if constexpr (ss_vl >= ds_vl) { + if constexpr (mode0_broadcast) { + // Stride on dst, with mode-0 broadcast. + for_each(make_seq(ilayout)>{}, [&](auto j) { + for_each(make_seq{}, [&](auto i) { + constexpr auto sidx = i * ds_vl; + constexpr auto didx = rlayout(sidx) + j; + reorder_span(src, dst); + }); + }); + } else if constexpr (ss_vl >= ds_vl || has_broadcast) { // Stride on src. For simplicity, take 1 GRF at a time. for_each(make_seq{}, [&](auto i) { constexpr auto didx = i * ss_vl; constexpr auto sidx = ilayout(didx); - reorder_span(decltype(ilayout){}), 1, sidx, didx>(src, dst); + reorder_span(src, dst); }); } else { // Stride on dst. 
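For reference, `elems_per_grf` above reflects the 64-byte Xe GRF assumed by this path; with a 16-bit source type:

```cpp
// A 64-byte GRF holds 64 / 2 = 32 half elements, so a unit-stride span can move
// up to min(32, fragment size) elements per instruction on this path.
static_assert(64 / sizeof(cute::half_t) == 32);
```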
for_each(make_seq{}, [&](auto i) { constexpr auto sidx = i * ds_vl; constexpr auto didx = rlayout(sidx); - reorder_span(decltype(rlayout){}), sidx, didx>(src, dst); + reorder_span(src, dst); }); } } diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 76446f0244..0b6da2e0a8 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -326,6 +326,9 @@ struct is_layout : false_type {}; template struct is_layout> : true_type {}; +template +static constexpr bool is_layout_v = is_layout::value; + // // Layout construction // @@ -682,8 +685,10 @@ CUTE_HOST_DEVICE constexpr auto atuple_coshape(Layout const& layout) { + auto _0E0 = ScaledBasis,0>{}; auto flayout = filter(flatten(layout)); - return inner_product_atuple_max(shape(flayout), stride(flayout)); + auto coshape = inner_product_atuple_max(shape(flayout), stride(flayout)) + _0E0 + _0E0; + return cute::transform(coshape, [](auto a) { return cute::max(a, _1{}); }); } // Return the codomain size of a mode @@ -1062,6 +1067,15 @@ group(Layout const& layout) group(layout.stride())); } +template +CUTE_HOST_DEVICE constexpr +auto +remove(Layout const& layout) +{ + return make_layout(remove(layout.shape()), + remove(layout.stride())); +} + // // Composition of two layouts: lhs o rhs // @post compatible(rhs, result) diff --git a/include/cute/tensor_sg.hpp b/include/cute/tensor_sg.hpp index b128bf4e13..4421914d50 100644 --- a/include/cute/tensor_sg.hpp +++ b/include/cute/tensor_sg.hpp @@ -32,6 +32,7 @@ #pragma once #include // cute::Tensor +#include // intel::_SGSize namespace cute { @@ -74,7 +75,7 @@ struct SubgroupTensor : Tensor *this = static_cast(base); } - static constexpr int rank = Layout::rank; + static constexpr int rank = Layout::rank; CUTE_HOST_DEVICE constexpr decltype(auto) @@ -89,22 +90,88 @@ struct SubgroupTensor : Tensor } }; +template +struct is_sg_tensor : false_type {}; +template +struct is_sg_tensor> : true_type {}; + template struct is_tensor> : true_type {}; -template::value)> +// Create a SubgroupTensor from its component parts: +// a regular rmem Tensor and the subgroup-scope TV-layout. +template ::value)> CUTE_HOST_DEVICE -constexpr auto -make_subgroup_tensor(Tensor const& tensor, SubgroupTVLayout const&) +constexpr decltype(auto) +make_subgroup_tensor(Tensor& tensor, SubgroupTVLayout const& tv_layout) { static_assert(is_static_v, "Subgroup TV layout must be static"); static_assert(is_rmem_v, "Expected an rmem tensor"); - return static_cast const&>(tensor); + return make_subgroup_tensor(make_tensor(tensor.data(), tensor.layout()), tv_layout); +} + +template ::value)> +CUTE_HOST_DEVICE +constexpr decltype(auto) +make_subgroup_tensor(Tensor&& tensor, SubgroupTVLayout const&) +{ + static_assert(is_static_v, "Subgroup TV layout must be static"); + static_assert(is_rmem_v, "Expected an rmem tensor"); + return static_cast&&>(tensor); +} + +// Create a new owning SubgroupTensor with the given subgroup-level layout. +// Elements are assigned to threads following the normal Xe interleaved mapping +// (i.e. work-item i gets elements i, i + 16, i + 32, ...) 
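A hedged usage sketch of this owning overload, assuming the element type is supplied as the leading template argument (that spelling, and the shape, are assumptions):

```cpp
// Inside a kernel, with cute names in scope. 128 logical elements split across
// the 16 work-items of the subgroup, 8 interleaved elements per work-item.
auto acc = make_subgroup_tensor<float>(make_shape(_16{}, _8{}));
// Per work-item, acc behaves like a regular 8-element rmem fragment, while
// acc.tv_layout() records the subgroup-wide (T,V) -> coord mapping.
```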
+template +CUTE_HOST_DEVICE +constexpr auto +make_subgroup_tensor(Layout const& sg_layout) +{ + using _SG = intel::_SGSize; + auto ilayout = make_layout(make_shape(_SG{}, size(sg_layout) / _SG{}), + make_stride(_1{}, _16{})); + auto sv_layout = sg_layout.compose(ilayout); + return make_subgroup_tensor(make_fragment_like(sv_layout(0,_)), sv_layout); +} + +// Create a new owning SubgroupTensor with a subgroup-level layout, constructed +// from the argument list with make_layout. +template +CUTE_HOST_DEVICE +constexpr auto +make_subgroup_tensor(Args const&... args) +{ + return make_subgroup_tensor(make_layout(args...)); +} + + +// Replicate a subgroup fragment in a given mode. +template +CUTE_HOST_DEVICE +constexpr auto +expand_sg_fragment_helper(SubgroupTensor const&) +{ + constexpr SubgroupTensor frag; + constexpr int ModeSize = get(atuple_coshape(frag.tv_layout())); + + auto xlayout = append(frag.layout(), + Layout, C>>{}); + auto xv_layout = append(get<1>(frag.tv_layout()), + make_layout(C{}, C{} * E{})); + auto xtv_layout = make_layout(get<0>(frag.tv_layout()), xv_layout); + + return make_subgroup_tensor(make_tensor(xlayout), xtv_layout); } +template +using expand_sg_fragment_t = decltype(expand_sg_fragment_helper(SGTensor{})); // // Display utilities diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index d7add6812f..ad340d85d6 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -327,4 +327,22 @@ struct is_any_of { template inline constexpr bool is_any_of_v = is_any_of::value; +// +// replace_void_t +// +template +using replace_void_t = conditional_t, ReplacementTypeIfVoid, T>; + +// +// is_complete -- check for complete types +// +template +struct is_complete : CUTE_STL_NAMESPACE::false_type {}; + +template +struct is_complete : std::true_type {}; + +template +static constexpr bool is_complete_v = is_complete::value; + } // end namespace cute diff --git a/include/cute/util/xe_split_barrier.hpp b/include/cute/util/xe_split_barrier.hpp new file mode 100644 index 0000000000..ad96f8df1f --- /dev/null +++ b/include/cute/util/xe_split_barrier.hpp @@ -0,0 +1,82 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +enum SPIRVScope { + ScopeCrossDevice = 0, + ScopeDevice = 1, + ScopeWorkgroup = 2, + ScopeSubgroup = 3, + ScopeInvocation = 4, +}; + +enum SPIRVMemorySemantics { + SemanticsNone = 0, + SemanticsAcquire = 0x2, + SemanticsRelease = 0x4, + SemanticsAcquireRelease = 0x8, + SemanticsSGMemory = 0x80, + SemanticsWGMemory = 0x100, + SemanticsCrossWGMemory = 0x200, +}; + +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics); +SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics); +#endif + +namespace cute +{ + +CUTE_HOST_DEVICE void barrier_arrive(SPIRVScope scope, int memory_semantics = SemanticsNone) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierArriveINTEL(scope, scope, memory_semantics); +#endif +} +CUTE_HOST_DEVICE void barrier_wait(SPIRVScope scope, int memory_semantics = SemanticsNone) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierWaitINTEL(scope, scope, memory_semantics); +#endif +} + +CUTE_HOST_DEVICE void barrier_arrive(int scope, int memory_scope = ScopeCrossDevice, int memory_semantics = SemanticsNone) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierArriveINTEL(scope, memory_scope, memory_semantics); +#endif +} +CUTE_HOST_DEVICE void barrier_wait(int scope, int memory_scope = ScopeCrossDevice, int memory_semantics = SemanticsNone) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierWaitINTEL(scope, memory_scope, memory_semantics); +#endif +} + +} // end namespace cute diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index 3e4f55c5bd..e344b2922a 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -123,6 +123,17 @@ struct IntelXe { static int const kMinComputeCapability = 0; }; +// Intel Xe architecture aliases for library generation compatibility +// Xe12 = PVC (Ponte Vecchio) +struct Xe12 : IntelXe { + static int const kIntelXeArch = 12; +}; + +// Xe20 = BMG (Battlemage) +struct Xe20 : IntelXe { + static int const kIntelXeArch = 20; +}; + struct Agnostic { static int const kMinComputeCapability = 1; }; diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index 61741b3af1..af21b1f6c7 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -748,7 +748,7 @@ struct MixedInputUtils { auto smem_tiled_copy_S = cute::get<0>(partitioned_transform_extra_info); auto&& scales = cute::get<1>(partitioned_transform_extra_info); using ScaleType = decltype(scales); - auto tSrS = make_tensor(static_cast(scales).data(), scales.layout()); + auto tSrS = make_tensor(scales.data(), scales.layout()); auto tSsS = 
cute::get<2>(partitioned_transform_extra_info); copy(smem_tiled_copy_S, tSsS(_,_,_,_,load2transform_consumer_index), tSrS); @@ -757,7 +757,7 @@ struct MixedInputUtils { } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { auto&& zeros = cute::get<3>(partitioned_transform_extra_info); using ZeroType = decltype(zeros); - auto tZrZ = make_tensor(static_cast(zeros).data(), zeros.layout()); + auto tZrZ = make_tensor(zeros.data(), zeros.layout()); auto tZsZ = cute::get<4>(partitioned_transform_extra_info); copy(smem_tiled_copy_S, tZsZ(_,_,_,_,load2transform_consumer_index), tZrZ); @@ -1061,9 +1061,8 @@ struct MixedInputUtils { using ScaleArray = cutlass::Array; auto scale_arr = recast(filter_zeros(scales)); - if constexpr (is_same_v){ - Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack)); - Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, pack)); + if constexpr (is_same_v){ + Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, pack)); for (int i = 0; i < size<1>(dst_vm); ++i){ auto&& r = cute::recast(dst_vm(_,i))(0); @@ -1239,13 +1238,7 @@ struct MixedInputUtils { Tensor tCsS = cta_mma.partition_A(sS); Tensor tSsS = smem_thr_copy_S.partition_S(tCsS); Tensor tSrS = make_tensor(tSsS(_,_,_,_,0).shape()); -#if 0 - if(cute::thread(128, 0)){ - print("sS: ");print(sS);print("\n"); - print("tSsS: ");print(tSsS);print("\n"); - print("tSrS: ");print(tSrS);print("\n"); - } -#endif + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS); } @@ -1254,16 +1247,6 @@ struct MixedInputUtils { Tensor tCsZ = cta_mma.partition_A(sZ); Tensor tZsZ = smem_thr_copy_S.partition_S(tCsZ); Tensor tZrZ = make_tensor(tZsZ(_,_,_,_,0).shape()); -#if 0 - if(cute::thread(128, 0)){ - print("sS: ");print(sS);print("\n"); - print("tSsS: ");print(tSsS);print("\n"); - print("tSrS: ");print(tSrS);print("\n"); - print("sZ: ");print(sZ);print("\n"); - print("tZsZ: ");print(tZsZ);print("\n"); - print("tZrZ: ");print(tZrZ);print("\n"); - } -#endif return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS, tZrZ, tZsZ); } else { diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl index 4c3fd96b53..7c14ac3378 100644 --- a/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -582,7 +582,13 @@ sm100_sparse_get_tma_dispatch_policy() { * Selected op also maximizes the TMEM_LOAD shape in order to minimize TMEM_LOADs issued, * subject to the constraint of the provided per-warp tmem subpartition shape **/ -template +template< + class GmemStrideTypeD, + class ElementAccumulator, + class ElementD, + class TmemShape_MN, + bool IsBlockScaleSupported +> constexpr auto sm100_get_tmem_load_op() { using namespace cute; @@ -958,6 +964,172 @@ struct CallbacksBuilder< >; }; +// Attempts to compute a reasonably performant epilogue tile or allows the user to provide one. 
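Earlier in this change, `Xe12`/`Xe20` arch aliases are introduced for library generation; a quick sketch of how such a tag is referenced (the surrounding kernel types are not shown):

```cpp
using ArchTag = cutlass::arch::Xe20;                      // Battlemage (BMG)
static_assert(ArchTag::kIntelXeArch == 20);
static_assert(cutlass::arch::Xe12::kIntelXeArch == 12);   // Ponte Vecchio (PVC)
```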
+template< + class OpClass, + class CtaTileShape_MNK, + class EpilogueTileType, + class TmemWarpShape_MN, + class ElementC_, + class GmemStrideTypeC, + class ElementD, + class GmemStrideTypeD, + bool IsPerColScaleSupported +> +static constexpr auto +sm100_dense_compute_tile_shape_or_override() { + using namespace cute; + static_assert(!cute::is_same_v && !cute::is_same_v); + + constexpr bool DisableSource = cute::is_void_v; + using ElementC = cute::conditional_t; + + if constexpr (is_same_v && + is_same_v && + size<1>(CtaTileShape_MNK{}) == 256) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int DpFull = 32; + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + // Note: + // Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile. + // This is a general workable epi_tile_N which does not promise best perf. + return make_tile(Int{}, Int<128>{}); + } + else if constexpr (is_same_v) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int CtaN = size<1>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int WarpN = size<1>(TmemWarpShape_MN{}); + constexpr int MaxBits = cute::max(sizeof_bits_v, sizeof_bits_v); + + constexpr int DpFull = 32; // tmem datapaths in 1 subpartition + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf + // Epilogues w/o residual load are less sensitive to smem allocation + // Target a fixed amount of compute per epilogue iteration + if (DisableSource) { + if (MaxBits == 4) { + // Make epilogue tile larger to reduce the epilogue iterations. + // 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. + constexpr int ComputeElts = 8192; + return ComputeElts / M; + } + constexpr int ComputeElts = 4096; + return ComputeElts / M; + } + // Epilogues w/ residual load are more sensitive to smem allocation + // Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize + else { + if (MaxBits == 32) { + return (CtaM > 64 && CtaN <= 128) ? 16 : 32; + } + // Per-column scaling is high register pressure, reduce tile to prevent spills + else if (IsPerColScaleSupported) { + return 32; + } + else if (MaxBits == 16) { + return (CtaN <= 128) ? 32 : 64; + } + else { + return 64; + } + } + }(); + constexpr int N_min_C = (DisableSource || detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N_min_D = (detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 
128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D)); + static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small"); + + // stride by tmem warp layout and return a by-mode tiler + auto tile_m = Layout>{}; + auto tile_n = Layout,Int< WarpN>>, + Stride,Int>>{}; + + return make_tile(tile_m, coalesce(tile_n)); + } + else { + static_assert(cute::is_tuple::value && not is_layout::value, + "EpilogueTile must be a cute::Tile or cute::Shape"); + + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + static_assert(N % 8 == 0, "Unsupported tile shape"); + + return epi_tile; + } +} + +template< + bool Is2SmMma, + class MmaTileShape_MNK +> +static constexpr auto +sm100_tmem_warps() { + if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { + return Shape<_2,_2>{}; + } + else { + return Shape<_4,_1>{}; + } +} + +template< + bool Is2SmMma, + class MmaTileShape_MNK +> +static constexpr auto +sm100_cta_tile_shape() { + if constexpr (Is2SmMma) { // 2x1 threadblock shape + auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{}; + auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode + return make_shape(cta_tile_m, mma_tile_n, mma_tile_k); + } + else { // 1x1 threadblock shape + return MmaTileShape_MNK{}; + } +} + +template< + class EpilogueScheduleType, + class ElementC_, + class ElementD, + int EpiTiles, + int FragmentSize +> +static constexpr auto +sm100_dense_dispatch_policy() { + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = sizeof_bits_v > 8; + // TMA store delay performs worse with residual loads + constexpr bool DelayTmaStore = is_void_v; + + constexpr int StagesD = cute::min(EpiTiles, 2); + constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) + : cute::min(EpiTiles, 4); + + if constexpr (is_base_of_v || + is_base_of_v) { + return Sm100PtrArrayNoSmemWarpSpecialized{}; + } + else if constexpr (is_base_of_v || is_base_of_v) { + return Sm100NoSmemWarpSpecialized{}; + } + else if constexpr (is_same_v || + is_same_v) { + constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs + return Sm100PtrArrayTmaWarpSpecialized{}; + } + else { + return Sm100TmaWarpSpecialized{}; + } +} + // Helper for building TMA warp-specialized collective epilogues, specialized by // the fusion operation performed and the dispatch policy to use. template < @@ -1017,17 +1189,7 @@ private: } } using CtaTileShape_MNK = decltype(cta_tile_shape()); - - static constexpr auto - tmem_warps() { - if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { - return Shape<_2,_2>{}; - } - else { - return Shape<_4,_1>{}; - } - } - using TmemWarpShape_MN = decltype(tmem_warps()); + using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps()); // Attempts to compute a reasonably performant epilogue tile or allows the user to provide one. 
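As a concrete reading of `sm100_dense_compute_tile_shape_or_override` factored out above, consider a void-C, 16-bit-D case (all tile values assumed):

```cpp
#include <algorithm>   // std::min / std::max, standing in for cute::min / cute::max

constexpr int CtaM = 128, CtaN = 256;                           // CTA tile (M,N), assumed
constexpr int WarpM = 4,  WarpN = 1;                            // tmem warp shape, assumed
constexpr int M      = std::min(CtaM, 32 * WarpM);              // 128: full 32-datapath tmem load
constexpr int N_perf = 4096 / M;                                // 32: fixed compute per iteration
constexpr int N_min  = 8 * WarpN;                               // 8: no residual C load, n-major D
constexpr int N      = std::min(CtaN, std::max(N_perf, N_min));
static_assert(M == 128 && N == 32, "EpilogueTile_MN resolves to (128, 32)");
```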
static constexpr auto @@ -1041,84 +1203,10 @@ private: ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, Schedule, FusionOp>(); } - else if constexpr (is_same_v && - is_same_v && - size<1>(CtaTileShape_MNK{}) == 256) { - constexpr int CtaM = size<0>(CtaTileShape_MNK{}); - constexpr int WarpM = size<0>(TmemWarpShape_MN{}); - constexpr int DpFull = 32; - constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load - // Note: - // Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile. - // This is a general workable epi_tile_N which does not promise best perf. - return make_tile(Int{}, Int<128>{}); - } - else if constexpr (is_same_v) { - constexpr int CtaM = size<0>(CtaTileShape_MNK{}); - constexpr int CtaN = size<1>(CtaTileShape_MNK{}); - constexpr int WarpM = size<0>(TmemWarpShape_MN{}); - constexpr int WarpN = size<1>(TmemWarpShape_MN{}); - constexpr int MaxBits = cute::max(sizeof_bits_v, sizeof_bits_v); - - constexpr int DpFull = 32; // tmem datapaths in 1 subpartition - constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load - constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf - // Epilogues w/o residual load are less sensitive to smem allocation - // Target a fixed amount of compute per epilogue iteration - if (DisableSource) { - if (MaxBits == 4) { - // Make epilogue tile larger to reduce the epilogue iterations. - // 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. - constexpr int ComputeElts = 8192; - return ComputeElts / M; - } - constexpr int ComputeElts = 4096; - return ComputeElts / M; - } - // Epilogues w/ residual load are more sensitive to smem allocation - // Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize - else { - if (MaxBits == 32) { - return (CtaM > 64 && CtaN <= 128) ? 16 : 32; - } - // Per-column scaling is high register pressure, reduce tile to prevent spills - else if (FusionOp::IsPerColScaleSupported) { - return 32; - } - else if (MaxBits == 16) { - return (CtaN <= 128) ? 32 : 64; - } - else { - return 64; - } - } - }(); - constexpr int N_min_C = (DisableSource || detail::is_m_major()) ? 8 * WarpN - : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type - : 128 / sizeof_bits_v * WarpN; - constexpr int N_min_D = (detail::is_m_major()) ? 8 * WarpN - : (sizeof_bits_v == 6) ? 
128 * WarpN // TMA store only supports SW128B for FP6 data type - : 128 / sizeof_bits_v * WarpN; - constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D)); - static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small"); - - // stride by tmem warp layout and return a by-mode tiler - auto tile_m = Layout>{}; - auto tile_n = Layout,Int< WarpN>>, - Stride,Int>>{}; - - return make_tile(tile_m, coalesce(tile_n)); - } else { - static_assert(cute::is_tuple::value && not is_layout::value, - "EpilogueTile must be a cute::Tile or cute::Shape"); - - EpilogueTileType epi_tile; - constexpr int M = size<0>(shape(epi_tile)); - constexpr int N = size<1>(shape(epi_tile)); - static_assert(N % 8 == 0, "Unsupported tile shape"); - - return epi_tile; + return sm100_dense_compute_tile_shape_or_override< + OpClass, CtaTileShape_MNK, EpilogueTileType, TmemWarpShape_MN, + ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, FusionOp::IsPerColScaleSupported>(); } } using EpilogueTile_MN = decltype(epilogue_tile()); @@ -1129,30 +1217,18 @@ private: using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{})); using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< - GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, + FusionOp::IsBlockScaleSupported + >()); static constexpr auto dispatch_policy() { - // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation - constexpr bool ReuseSmem = sizeof_bits_v > 8; - // TMA store delay performs worse with residual loads - constexpr bool DelayTmaStore = is_void_v; - - constexpr int StagesD = cute::min(EpiTiles, 2); - constexpr int StagesC = ReuseSmem ? 
cute::max(cute::min(EpiTiles, 4), StagesD+1) - : cute::min(EpiTiles, 4); - if constexpr (is_same_v || is_same_v) { return detail::sparse::sm100_sparse_get_tma_dispatch_policy(); } - else if constexpr (is_same_v || - is_same_v) { - constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs - return Sm100PtrArrayTmaWarpSpecialized{}; - } else { - return Sm100TmaWarpSpecialized{}; + return detail::sm100_dense_dispatch_policy(); } } @@ -1228,6 +1304,87 @@ public: >; }; +template< + class OpClass, + class MmaTileShape_MNK, + class EpilogueTileType, + class ElementAccumulator_, + class ElementC, + class ElementD, + class Schedule, + class GmemStrideTypeC, + class GmemStrideTypeD, + bool IsPerColScaleSupported, + bool IsBlockScaleSupported +> +struct Sm100EpilogueDescriptor { + using ElementAccumulator = ElementAccumulator_; + + static constexpr bool Is2SmMma = is_base_of_v || is_base_of_v; + using CtaTileShape_MNK = decltype(sm100_cta_tile_shape()); + using TileShape = CtaTileShape_MNK; + + using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps()); + + using EpilogueTile = decltype( + sm100_dense_compute_tile_shape_or_override() + ); + + using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{}))); + static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{})); + static constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup; + + using DispatchPolicy = decltype(sm100_dense_dispatch_policy()); + + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; + + using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{})); + using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, + IsBlockScaleSupported + >()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct Sm100AuxLoadDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesC; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + + using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, EpilogueTile>()); + + using CopyOpS2R = decltype(detail::sm100_get_smem_load_op< + Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct Sm100AuxStoreDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesD; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + + using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, EpilogueTile>()); + + using CopyOpR2S = decltype(detail::sm100_get_smem_store_op< + Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>()); +}; + } // namespace detail /////////////////////////////////////////////////////////////////////////////// @@ -1278,8 +1435,12 @@ private: is_same_v || is_same_v || is_same_v; - // Input transform 
kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support. - static constexpr bool IsInputTransformSchedule = IsInterleavedComplex || IsFastF32Schedule; + static constexpr bool IsBlockwiseSchedule = is_same_v || + is_same_v || + is_same_v || + is_same_v; + // Transform kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support. + static constexpr bool IsTransformSchedule = IsInterleavedComplex || IsFastF32Schedule || IsBlockwiseSchedule; static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); @@ -1304,17 +1465,7 @@ private: } } using CtaTileShape_MNK = decltype(cta_tile_shape()); - - static constexpr auto - tmem_warps() { - if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { - return Shape<_2,_2>{}; - } - else { - return Shape<_4,_1>{}; - } - } - using TmemWarpShape_MN = decltype(tmem_warps()); + using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps()); static constexpr auto epilogue_tile() { @@ -1323,7 +1474,7 @@ private: static_assert(is_tuple_v, "Shape or Tile"); return EpilogueTileType{}; } - else if constexpr (is_same_v || not IsInputTransformSchedule) { + else if constexpr (is_same_v || not IsTransformSchedule) { // Save register usage for sm103 blockscaled kernels and sm100 cpasync kernels // to avoid register spilling. constexpr int EpiM = size<0>(CtaTileShape_MNK{}); @@ -1338,20 +1489,15 @@ private: using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, TmemWarpShape_MN{})); using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< - GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, + FusionOp::IsBlockScaleSupported + >()); static constexpr int FragmentSize = size(EpilogueTile{}) / NumThreadsPerWarpGroup; - static constexpr auto - dispatch_policy() { - if constexpr (std::is_base_of_v || - std::is_base_of_v) { - return Sm100PtrArrayNoSmemWarpSpecialized{}; - } - else { - return Sm100NoSmemWarpSpecialized{}; - } - } - using DispatchPolicy = decltype(dispatch_policy()); + using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{}))); + static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{})); + + using DispatchPolicy = decltype(detail::sm100_dense_dispatch_policy()); static constexpr auto fusion_callbacks() { @@ -1359,7 +1505,7 @@ private: DisableSource ? thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; if constexpr (IsDefaultFusionOp::value &&\ not is_same_v && \ - (IsInputTransformSchedule || \ + (IsTransformSchedule || \ is_same_v || \ is_same_v) ) { diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index 809cede6f7..f9ccabd8c2 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
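Before moving on to the Xe builders, the staging arithmetic inside the sm100_dense_dispatch_policy helper introduced earlier is worth spelling out. Below is a scalar sketch, under the assumption (suggested by the surrounding comments, since the template arguments are not visible in this rendering of the patch) that ReuseSmem keys off the residual C element width and DelayTmaStore off whether C is void.

```cpp
#include <algorithm>
#include <cassert>

// Scalar sketch of the smem staging choices made by sm100_dense_dispatch_policy
// above (illustrative only; the real function returns a dispatch-policy type
// carrying these values as template parameters).
struct EpiStagingSketch {
  int  stages_c;        // C-load pipeline depth
  int  stages_d;        // D-store pipeline depth
  bool reuse_smem;      // reuse C smem buffers for D
  bool delay_tma_store; // batch TMA stores
};

constexpr EpiStagingSketch
sketch_epilogue_staging(int epi_tiles, int c_bits, bool c_is_void) {
  // 8-bit residuals load fast and use little smem, so reusing their buffers
  // for D is not worth waiting on the stores; wider C types do reuse smem.
  bool reuse_smem = c_bits > 8;
  // Delaying TMA stores only helps when there is no residual C load.
  bool delay_tma_store = c_is_void;

  int stages_d = std::min(epi_tiles, 2);
  int stages_c = reuse_smem ? std::max(std::min(epi_tiles, 4), stages_d + 1)
                            : std::min(epi_tiles, 4);
  return {stages_c, stages_d, reuse_smem, delay_tma_store};
}

int main() {
  constexpr auto s = sketch_epilogue_staging(/*epi_tiles=*/4, /*c_bits=*/16,
                                             /*c_is_void=*/false);
  static_assert(s.stages_c == 4 && s.stages_d == 2, "expected 4 C / 2 D stages");
  assert(s.reuse_smem && !s.delay_tma_store);
}
```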
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,92 +40,10 @@ namespace cutlass::epilogue::collective { -namespace detail { - template - struct FusionOpInfo { - static_assert(cutlass::detail::dependent_false, - "Could not find a builder specialization."); - }; - - template < - class ElementD, - class ElementCompute, - class ElementC - > - struct FusionOpInfo> { - constexpr static bool HasBuilder = true; - - template < - class DispatchPolicy, - class TileShape_MNK, - class EpilogueTile, - class> - using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< - DispatchPolicy, - cutlass::epilogue::fusion::LinearCombination, - TileShape_MNK, - EpilogueTile - >; - }; - - template < - template class ActivationFn, - class ElementD, - class ElementCompute, - class ElementC - > - struct FusionOpInfo> { - constexpr static bool HasBuilder = true; - template < - class DispatchPolicy, - class TileShape_MNK, - class EpilogueTile, - class> - - using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< - DispatchPolicy, - cutlass::epilogue::fusion::LinCombEltAct, - TileShape_MNK, - EpilogueTile - >; - }; - - template < - class GmemLayoutTagC, - template class ActivationFn, - class ElementD, - class ElementCompute, - class ElementC - > - struct FusionOpInfo> { - constexpr static bool HasBuilder = true; - - template < - class DispatchPolicy, - class TileShape_MNK, - class EpilogueTile, - class CopyOpG2R> - using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< - DispatchPolicy, - cutlass::epilogue::fusion::LinCombDeEltAct, - TileShape_MNK, - EpilogueTile, - CopyOpG2R - >; - }; -} // namespace detail - - -// Intel epilogue builder +// Xe epilogue builder template < class TileShape_MNK, - class EpilogueTileType, + class EpilogueTile_MN, class ElementAccumulator, class ElementCompute, class ElementC, @@ -136,79 +55,109 @@ template < class EpilogueScheduleType, class FusionOpOrCallbacks > - struct CollectiveBuilder< - arch::IntelXe, - arch::OpClassTensorOp, +struct CollectiveBuilder< + arch::IntelXe, + arch::OpClassTensorOp, + TileShape_MNK, + Shape<_1, _1, _1>, // Cluster Shape + EpilogueTile_MN, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOpOrCallbacks, + cute::enable_if_t< + cute::is_same_v && + cute::is_any_of_v + > +> { +#ifdef SYCL_NVIDIA_TARGET + static_assert(cutlass::detail::dependent_false, + "Trying to use Intel pipeline on Non Intel hardware"); +#endif + static_assert(is_static::value); + static_assert(cute::is_any_of_v, + "ElementC needs to be one of: float, bfloat, half, int32, or void for the Intel pipeline"); + + using EpilogueSchedule = std::conditional_t, + IntelXeGeneric, + EpilogueScheduleType>; + static constexpr bool IsGroup = cute::is_same_v; + using DispatchPolicy = std::conditional_t; + + using StrideC = std::conditional_t>, + GmemLayoutTagC, + cutlass::detail::TagToStrideC_t>>; + using StrideD = std::conditional_t>, + GmemLayoutTagD, + cutlass::detail::TagToStrideC_t>>; + + static_assert(IsGroup == std::is_pointer_v, "Group GEMM should have a pointer to strides"); + static_assert(IsGroup == std::is_pointer_v, "Group GEMM should have a pointer to strides"); + static_assert(get<1>(std::remove_pointer_t{}) == 1, "Only N-major/row-major layouts for C are supported in the Xe epilogue collective builder"); + static_assert(get<1>(std::remove_pointer_t{}) == 1, "Only 
N-major/row-major layouts for D are supported in the Xe epilogue collective builder"); + + // Use default copy operations. + using CopyOpG2R = void; + using CopyOpR2G = void; + + using FusionCallbacks = + typename detail::CallbacksBuilder< + DispatchPolicy, + FusionOpOrCallbacks, TileShape_MNK, - Shape<_1, _1, _1>, // Cluster Shape - EpilogueTileType, - ElementAccumulator, - ElementCompute, + EpilogueTile_MN, + ElementAccumulator + >::Callbacks; + + using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< + DispatchPolicy, + TileShape_MNK, // xxx + EpilogueTile_MN, ElementC, - GmemLayoutTagC, - AlignmentC, + StrideC, ElementD, - GmemLayoutTagD, - AlignmentD, - EpilogueScheduleType, - FusionOpOrCallbacks, - cute::enable_if_t< - cute::is_same_v && - cute::is_any_of_v && - detail::FusionOpInfo::HasBuilder - > - >{ - #ifdef SYCL_NVIDIA_TARGET - static_assert(cutlass::detail::dependent_false, - "Trying to use Intel pipeline on Non Intel hardware"); - #endif - static_assert(is_static::value); - static_assert(cute::is_any_of_v, - "ElementC needs to be one of: float, bfloat, half for the Intel pipeline"); - - using EpilogueSchedule = std::conditional_t, - IntelXeXMX16, - EpilogueScheduleType>; - static constexpr bool IsGroup = cute::is_same_v; - using DispatchPolicy = std::conditional_t; - - using StrideC = std::conditional_t>, GmemLayoutTagC, cutlass::detail::TagToStrideC_t>>; - using StrideD = std::conditional_t>, GmemLayoutTagD, cutlass::detail::TagToStrideC_t>>; - - static_assert(IsGroup == std::is_pointer_v, "Group GEMM should have a pointer to strides"); - static_assert(IsGroup == std::is_pointer_v, "Group GEMM should have a pointer to strides"); - static_assert(get<1>(std::remove_pointer_t{}) == 1, "Only N-Major/Row-Major layouts for C are supported in the xe epilogue collective builder"); - static_assert(get<1>(std::remove_pointer_t{}) == 1, "Only N-Major/Row-Major layouts for D are supported in the xe epilogue collective builder"); - - using CopyOpG2R = std::conditional_t, void, std::conditional_t == 32, XE_2D_U32x8x16_LD_N, XE_2D_U16x8x16_LD_N>>; - using CopyOpR2G = std::conditional_t == 32, XE_2D_U32x8x16_ST_N, XE_2D_U16x8x16_ST_N>; - - // Intel Epilogue with Linear Combination does not use shared memory - using SmemLayoutAtomC_ = void; - using CopyOpS2R_ = void; - using SmemLayoutAtomD_ = void; - using CopyOpR2S_ = void; - - //TODO(Codeplay): Should FusionCallbacks use DispatchPolicy IntelXeGroupEpilogue for group gemm? That does not work. 
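The stride-related static_asserts in this builder encode two requirements: group GEMM must pass a pointer to per-group strides, and C/D must be N-major (unit stride along N). A small standalone illustration of the N-major check follows, assuming the usual (stride_M, stride_N, stride_L) ordering of CUTLASS C/D strides (the ordering is not spelled out in this hunk, so treat it as an assumption):

```cpp
#include <cstdint>
#include <tuple>

// Illustrative stand-in for the stride tuples used above: (stride_M, stride_N,
// stride_L) for an M x N x L tensor. The builder only accepts N-major
// ("row-major") C/D, i.e. stride_N == 1, which is what the static_asserts on
// get<1>(StrideC) / get<1>(StrideD) enforce.
using StrideMNL = std::tuple<std::int64_t, std::int64_t, std::int64_t>;

constexpr StrideMNL make_row_major_stride(std::int64_t M, std::int64_t N) {
  return {N, 1, M * N};   // consecutive elements along N, then rows, then batches
}

constexpr bool is_n_major(StrideMNL const& s) {
  return std::get<1>(s) == 1;
}

static_assert(is_n_major(make_row_major_stride(128, 256)),
              "row-major (N-major) strides pass the builder's check");
static_assert(!is_n_major(StrideMNL{1, 128, 128 * 256}),
              "an M-major (column-major) layout would be rejected");

int main() {}
```

Plain layout tags appear to be converted to such stride tuples via TagToStrideC_t, while a tag that is already a stride tuple (or, for group GEMM, a pointer to one) is used as-is.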
- using FusionCallbacks = typename detail::FusionOpInfo::template FusionCallbacks< - IntelXeXMX16, TileShape_MNK, TileShape_MNK, CopyOpG2R>; - - using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< - DispatchPolicy, - TileShape_MNK, - ElementAccumulator, - StrideC, - ElementD, - StrideD, - FusionCallbacks, - CopyOpG2R, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpR2G, - SmemLayoutAtomD_, - CopyOpR2S_ - >; - }; + StrideD, + FusionCallbacks, + CopyOpG2R, + CopyOpR2G + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Forward Xe12/Xe20 builders to IntelXe +///////////////////////////////////////////////////////////////////////////////////////////////// + +#define CUTLASS_FORWARD_XE_EPI_BUILDER(Arch) \ +template \ +struct CollectiveBuilder, EpilogueTile_MN, \ + ElementAccumulator, ElementCompute, \ + ElementC, GmemLayoutTagC, AlignmentC, \ + ElementD, GmemLayoutTagD, AlignmentD, \ + EpilogueScheduleType, FusionOpOrCallbacks> \ + : CollectiveBuilder, EpilogueTile_MN, \ + ElementAccumulator, ElementCompute, \ + ElementC, GmemLayoutTagC, AlignmentC, \ + ElementD, GmemLayoutTagD, AlignmentD, \ + EpilogueScheduleType, FusionOpOrCallbacks> {}; + +CUTLASS_FORWARD_XE_EPI_BUILDER(Xe12) +CUTLASS_FORWARD_XE_EPI_BUILDER(Xe20) + +#undef CUTLASS_FORWARD_XE_EPI_BUILDER + } // namespace cutlass::epilogue::collective + diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index 2e7fbeb579..c58fdb21f5 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -71,7 +71,9 @@ class CollectiveEpilogue { #include "sm100_epilogue_array_tma_warpspecialized.hpp" #if defined (SYCL_INTEL_TARGET) #include "xe_epilogue.hpp" +#include "xe_epilogue_legacy.hpp" #include "xe_array_epilogue.hpp" +#include "xe_array_epilogue_legacy.hpp" #endif // // Conv diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index c2993000c2..5f3af0dc36 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -507,8 +507,7 @@ class CollectiveEpilogue< int thread_idx, TensorStorage& shared_tensors, TensorMapC const& load_tensormap, - int subtile_idx=-1, - bool wait_until_load_finishes = false) { + int subtile_idx=-1) { using namespace cute; // Indexing variables @@ -595,12 +594,6 @@ class CollectiveEpilogue< // Post-loop fusion callback entry point pld_callbacks.end(); - if (wait_until_load_finishes && did_load) { - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = - {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()}; - load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); - } - return load_pipe_producer_state; } diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index 9b879bd14d..6875c8d995 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2024 - 2025 Codeplay Software Ltd. 
All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,6 +38,7 @@ #include #include "cutlass/cutlass.h" #include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_epilogue.hpp" #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/epilogue/fusion/callbacks.hpp" @@ -55,114 +57,85 @@ namespace collective { ///////////////////////////////////////////////////////////////////////////////////////////////// template < - class CtaTileMNK_, + class WGTileMNK_, + class EpilogueTile_, class ElementC_, class StrideC_, class ElementD_, class StrideD_, class FusionCallbacks_, class CopyOpG2R_, - class SmemLayoutAtomC_, - class CopyOpS2R_, - class CopyOpR2G_, - class SmemLayoutAtomD_, - class CopyOpR2S_ + class CopyOpR2G_ > class CollectiveEpilogue< - IntelXeXMX16Group, - CtaTileMNK_, + IntelXeGenericGroup, + WGTileMNK_, + EpilogueTile_, ElementC_, StrideC_, ElementD_, StrideD_, FusionCallbacks_, CopyOpG2R_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpR2G_, - SmemLayoutAtomD_, - CopyOpR2S_ + CopyOpR2G_ > { public: // // Type Aliases // using DispatchPolicy = IntelXeXMX16Group; - using CtaTileMNK = CtaTileMNK_; - using FusionCallbacks = FusionCallbacks_; + + using WGTileMNK = WGTileMNK_; using ElementC = ElementC_; - using ElementAccumulator = ElementC_; using StrideC = StrideC_; using InternalStrideC = cute::remove_pointer_t; using ElementD = ElementD_; using StrideD = StrideD_; using InternalStrideD = cute::remove_pointer_t; + using FusionCallbacks = FusionCallbacks_; using CopyOpG2R = CopyOpG2R_; - using SmemLayoutAtomC = SmemLayoutAtomC_; - using CopyOpS2R = CopyOpS2R_; using CopyOpR2G = CopyOpR2G_; - using SmemLayoutAtomD = SmemLayoutAtomD_; - using CopyOpR2S = CopyOpR2S_; + + using NonVoidElementC = replace_void_t; using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = CopyOpG2R; - using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, - CopyOpR2G, XE_2D_U32x8x16_ST_N>; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; using ElementOutput = ElementD; - using ElementCompute = ElementAccumulator; - using ElementSource = typename FusionCallbacks::ElementSource; - using ElementScalar = typename FusionCallbacks::ElementScalar; - static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; - static_assert(cute::is_same_v>, - "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); + static constexpr int CopyBitsC = cute::min(sizeof(NonVoidElementC) * 8, 64); + static constexpr int CopyBitsD = cute::min(sizeof(ElementD) * 8, 64); + + // NOTE: GmemTiledCopy* may not be the actual C/D copy operations. They are declared here only so + // that GemmUniversalAdapter can inspect their alignment requirements. + // The real C/D copy operations are deduced inside operator() once we have access to + // the TiledMMA. 
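Two small utilities do a lot of quiet work in these aliases: replace_void_t substitutes a concrete element type when C is void, and the C/D copy element width is clamped to 64 bits. Their template arguments are not visible in this rendering of the patch, so the sketch below is a plausible reconstruction rather than the actual definitions:

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical re-creation of the two helpers the epilogue relies on above:
// a "replace void with a fallback type" trait (the real replace_void_t's
// arguments are not shown here, so the fallback choice is an assumption) and
// the 64-bit clamp applied to the C/D copy element width.
template <class T, class Fallback>
using replace_void_t = std::conditional_t<std::is_void_v<T>, Fallback, T>;

template <class T>
inline constexpr int copy_bits_v = (sizeof(T) * 8 < 64) ? int(sizeof(T) * 8) : 64;

// With ElementC = void (no residual), the epilogue still needs a concrete
// element type to size its placeholder copy atom; ElementD is a natural stand-in.
static_assert(std::is_same_v<replace_void_t<void, float>, float>);
static_assert(std::is_same_v<replace_void_t<std::uint16_t, float>, std::uint16_t>);
static_assert(copy_bits_v<std::uint16_t> == 16 && copy_bits_v<double> == 64);

int main() {}
```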
+ using GmemTiledCopyC = replace_void_t>; + using GmemTiledCopyD = replace_void_t>; static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(WGTileMNK{}) == 3, "WGTileMNK must be rank-3: [M, N, K]"); static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - - using CopyThreadShape = Shape<_1, Int>; - using Trait_C = Copy_Traits; - using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, - Layout{}, - make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})))); - using Trait_D = Copy_Traits; - using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, - Layout{}, - make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})))); -private: - constexpr static bool is_source_supported = not cute::is_void_v; - constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; + using TensorC = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + Layout, InternalStrideC>{})); -public: + using TensorD = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + Layout, InternalStrideD>{})); - using EmptyType = cute::tuple<>; - using SmemCStorage = EmptyType; - using SmemDStorage = EmptyType; + using EpilogueTensors = cute::tuple; - struct TensorStorageImpl: cute::tuple { - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - }; +private: + constexpr static bool is_source_supported = not is_void_v; + constexpr static bool is_destination_supported = not is_void_v; +public: struct SharedStorage { - using TensorStorage = TensorStorageImpl; - - TensorStorage tensors; + using FusionSharedStorage = typename FusionCallbacks::SharedStorage; + FusionSharedStorage thread; }; - using TensorStorage = typename SharedStorage::TensorStorage; - - using TensorC = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideC{})); //(m, n) - using TensorD = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideD{})); //(m, n) - using EpilogueTensors = cute::tuple; + using TensorStorage = SharedStorage; // Compatibility with legacy epilogues // Host side epilogue arguments struct Arguments { @@ -172,12 +145,9 @@ class CollectiveEpilogue< ElementD** ptr_D; StrideD dD; }; - // Device side epilogue params struct Params { typename FusionCallbacks::Params thread{}; - XE_Copy_C xe_load_c; - XE_Copy_D xe_store_d; ElementC const** ptr_C; StrideC dC; ElementD** ptr_D; @@ -198,24 +168,8 @@ class CollectiveEpilogue< auto problem_shape_MNL = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); auto [M, N, L] = problem_shape_MNL; - XE_Copy_C xe_load_c = {}; - if constexpr (is_source_supported) { - ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); - TensorC mC_mnl = make_tensor(make_gmem_ptr(ptr_C_first_batch), make_layout(make_shape(M, N, L), InternalStrideC{})); - xe_load_c = {xe_load_c.with(mC_mnl)}; - } - - XE_Copy_D xe_store_d = {}; - if 
constexpr (is_destination_supported) { - ElementD* ptr_D_first_batch = reinterpret_cast(args.ptr_D); - TensorD mD_mnl = make_tensor(make_gmem_ptr(ptr_D_first_batch), make_layout(make_shape(M, N, L), InternalStrideD{})); - xe_store_d = {xe_store_d.with(mD_mnl)}; - } - return { FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - xe_load_c, - xe_store_d, args.ptr_C, args.dC, args.ptr_D, @@ -231,7 +185,7 @@ class CollectiveEpilogue< template static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { return Status::kSuccess; } @@ -298,157 +252,172 @@ class CollectiveEpilogue< class TileShapeMNK, class TileCoordMNKL, class Accumulator, - class TiledMma, + class TiledMMA, class LoadStoreTensor > CUTLASS_DEVICE void operator() ( ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, + TileShapeMNK, /* compatibility with legacy epilogues */ TileCoordMNKL tile_coord_mnkl, - Accumulator accumulators, - TiledMma tiled_mma, + Accumulator accumulators, + TiledMMA, int thread_idx, LoadStoreTensor const& load_store_tensors) { - - (void) tiled_mma; + using namespace cute; - static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); - static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); - - using MmaAtomShape = typename TiledMma::AtomShape_MNK; - static constexpr auto BLK_M = get<0>(CtaTileMNK{}); - static constexpr auto BLK_N = get<1>(CtaTileMNK{}); - static constexpr auto BLK_K = get<2>(CtaTileMNK{}); - // static_assert(is_same_v, "assertation fail"); - static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); - - static_assert( - BLK_M % ATOM_M == 0 && - BLK_N % ATOM_N == 0 && - BLK_K % ATOM_K == 0, - "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); - static constexpr auto SG_M = BLK_M / ATOM_M; - static constexpr auto SG_N = BLK_N / ATOM_N; - static constexpr auto SG_K = BLK_K / ATOM_K; - using SubgroupTileShape = Shape; - - static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group - static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group - - static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - auto m_sg = get_sub_group_id() / ATOM_N; - auto n_sg = get_sub_group_id() % ATOM_N; - - auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); - - auto sg_local_m_coord = get_sub_group_id() / ATOM_N; - auto sg_local_n_coord = get_sub_group_id() % ATOM_N; - - auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; - auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; - auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); + using MMATile = decltype(take<0,2>(typename TiledMMA::AtomShape_MNK{})); + + static constexpr int EpiRPreferred = 8; + 
static constexpr int EpiCPreferred = 512 / cute::min(sizeof_bits_v, sizeof_bits_v); // 1 cache line + static constexpr int EpiR = cute::gcd(EpiRPreferred, get<0>(MMATile{})); + static constexpr int EpiC = cute::gcd(EpiCPreferred, get<1>(MMATile{})); + + using DefaultEpilogueTile = Shape, Int>; + using EpilogueTile = conditional_t || is_same_v, + DefaultEpilogueTile, + EpilogueTile_>; + + using DefaultCopyOpG2R = XE_LOAD_2D(EpilogueTile{})), cute::gcd(512 / CopyBitsC, get<1>(EpilogueTile{}))>; + using DefaultCopyOpR2G = XE_STORE_2D(EpilogueTile{})), cute::gcd(512 / CopyBitsD, get<1>(EpilogueTile{}))>; + + using ActualGmemTiledCopyC = replace_void_t; + using ActualGmemTiledCopyD = replace_void_t; bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - // Represent the full output tensor - Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); - - // Tile the output tensor per WG and select the tile for current WG - Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) - - // Tile the output tensor per SG and select tile for the current SG - Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) - - auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); - Tensor tCgD = thread_xe_store_d.partition_D(gD); - - Tensor trC = make_tensor(Shape>{}); - Tensor trD_compute = make_tensor(Shape>{}); - - // Because Sm90 uses shared memory, they are not tied to using the same accumulator values - // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be - // sure that we are operating on the same values. - ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); - - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); - Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - - Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - - // Get the fusion callbacks - // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + + auto MN = take<0,2>(problem_shape_mnkl); + auto cCD = make_identity_tensor(MN); // (m,n) + auto gCD = local_tile(cCD, take<0,2>(WGTileMNK{}), take<0,2>(tile_coord_mnkl)); // (m_in_wg_tile, n_in_wg_tile) + + auto thr_mma = TiledMMA{}.get_slice(thread_idx); + auto tCDgCD = thr_mma.partition_C(gCD); // (mma_v,mma_m,mma_n) -> coord + + // Tile accumulator into epilogue tiles. + auto mma_per_epi = shape_div(EpilogueTile{}, MMATile{}); + auto tiled_acc_layout = group<0,3>(prepend(flat_divide(remove<0>(accumulators.layout()), mma_per_epi), + get<0>(accumulators.layout()))); + auto tiled_acc = make_tensor(accumulators.data(), tiled_acc_layout); // ((mma_v,mma_m,mma_n),epi_m,epi_n) + + // Tile subgroup's TV coord layout into epilogue tiles. + auto sg_v_coord = prepend(flat_divide(remove<0>(tCDgCD.layout()), mma_per_epi), + get<0>(tCDgCD.layout())); // (mma_v,mma_m,mma_n,epi_m,epi_n) -> coord + + // Copy C/D one epilogue tile at a time. 
Prepare: + // - subgroup-scope TiledCopy objects + // - global coordinate tensors, partitioned into epilogue tiles + // - copy fragments + // - compute fragments (same layout as accumulator) + // Both copy and compute fragments are SubgroupTensors, holding coordinate mappings + // within the epilogue tile. + auto copy_c = make_block_2d_copy(ActualGmemTiledCopyC{}, get<0>(load_store_tensors)(_,_,0)); + auto copy_d = make_block_2d_copy(ActualGmemTiledCopyD{}, get<1>(load_store_tensors)(_,_,0)); + + int wi_idx = thread_idx % intel::sg_size; + auto thr_copy_c = copy_c.get_slice(wi_idx); + auto thr_copy_d = copy_d.get_slice(wi_idx); + + // Partition global coordinate tensors into epilogue tiles, matching + // the work-division from the TiledMMA. + auto gCD_epi_layout = append(append(make_identity_layout(EpilogueTile{}), + get<3>(sg_v_coord)), get<4>(sg_v_coord)); + auto gCD_epi = make_tensor(tCDgCD.data(), gCD_epi_layout); // (m,n,epi_m,epi_n) -> coord + + auto tCgC = thr_copy_c.partition_S(gCD_epi); // (atom_v,atom_m,atom_n,epi_m,epi_n) + auto tDgD = thr_copy_d.partition_D(gCD_epi); // (atom_v,atom_m,atom_n,epi_m,epi_n) + + auto tCrC = thr_copy_c.partition_sg_fragment_D(gCD_epi(_,_,0,0)); // (atom_v,atom_m,atom_n,epi_m,epi_n) + auto tDrD = thr_copy_d.partition_sg_fragment_S(gCD_epi(_,_,0,0)); // (atom_v,atom_m,atom_n,epi_m,epi_n) + + // Create C subgroup fragments for epilogue compute. + using AccTVLayout = decltype(thr_mma.partition_sg_fragment_C(gCD).tv_layout()); + auto cd_compute_tv = make_layout(get<0>(AccTVLayout{}), + sg_v_coord(_,_,_,_0{},_0{})); + + auto tCrC_compute_wi = make_fragment_like(tiled_acc(_,_0{},_0{})); + auto tCrC_compute = make_subgroup_tensor(tCrC_compute_wi, cd_compute_tv); // (mma_v,mma_m,mma_n) + + // Calculate residues. + auto residue_gCD = MN - gCD(_0{}); // (res_m, res_n) + auto residue_tCDgCD = MN - tCDgCD(_0{}); // (res_m, res_n) + + // Pass data to fusions. + // FIXME: Some Xe visitors expect subgroup tiles/coordinates here and should be updated to accept + // workgroup tiles/coordinates, like the NV code. Note that CuTe has no concept of a "subgroup tile." + // Work division within a TiledMMA is flexible, and a subgroup's data need not be contiguous. + // Instead, visitors should retrieve data coordinates within the WG tile via tDgD. constexpr bool RefSrc = true; - auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - SubgroupTileShape{}, - sg_coord, - tiled_mma, - mn_shape, - params.xe_store_d, - cD, - residue_mn, - tRS_cD, - residue_mn, - trC, - thread_idx, - }; + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs { + problem_shape_mnkl, + WGTileMNK{}, + tile_coord_mnkl, + TiledMMA{}, + EpilogueTile{}, + copy_d, + gCD, + residue_gCD, + tDgD, + residue_tCDgCD, + tCrC_compute, + thread_idx, + }; auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - cst_callbacks.begin(); + // Epilogue visitors work on cutlass::Arrays of values for better vectorization. + // For now, choose array size so there is one array per MMA atom C tile -- later + // we might want to make it configurable (FragmentSize in NV code). 
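The default epilogue tile above is derived from the MMA atom's C tile by shrinking a preferred shape with gcd: 8 rows, and enough columns to cover roughly one 64-byte cache line of C/D data (512 bits divided by what appears to be the smaller of the two element widths; the exact operands of the cute::min are not visible in this rendering). A scalar sketch of that derivation:

```cpp
#include <cstdio>
#include <numeric>   // std::gcd

// Scalar sketch of how the Xe epilogue above derives its default epilogue tile
// from the TiledMMA's atom C tile (names are illustrative; the real code works
// on CuTe shapes and produces a Shape<Int<R>, Int<C>>).
struct XeEpiTileSketch { int rows, cols; };

constexpr XeEpiTileSketch
sketch_default_epi_tile(int mma_atom_m, int mma_atom_n, int min_cd_bits) {
  constexpr int preferred_rows = 8;              // EpiRPreferred
  int preferred_cols = 512 / min_cd_bits;        // EpiCPreferred: one 64-byte cache line
  // Shrink the preferred tile so it evenly divides the MMA atom's C tile.
  return { std::gcd(preferred_rows, mma_atom_m),
           std::gcd(preferred_cols, mma_atom_n) };
}

int main() {
  // e.g. an 8x16 MMA atom with 16-bit C/D: the preferred 8x32 shrinks to 8x16.
  constexpr auto t = sketch_default_epi_tile(8, 16, 16);
  std::printf("default epilogue tile: %d x %d\n", t.rows, t.cols);  // 8 x 16
}
```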
+ using ElementAccumulator = typename Accumulator::element_type; + constexpr int ComputeVectorLen = size<0>(Accumulator{}); + auto tiled_acc_v = recast>(tiled_acc); - auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + // Create D subgroup fragments for epilogue compute. + using FragmentVisit = decltype(cst_callbacks.visit(tiled_acc_v(0), 0, 0, 0)); + using ElementVisit = typename FragmentVisit::Element; - Tensor trD = make_tensor(Shape>{}); - auto trD_frag = recast>(trD); + auto tDrD_compute_wi = make_fragment_like(tiled_acc(_,_0{},_0{})); + auto tDrD_compute = make_subgroup_tensor(tDrD_compute_wi, cd_compute_tv); // (mma_v,mma_m,mma_n) + auto tDrD_compute_v = recast(tDrD_compute_wi); - constexpr int ValuesLoaded = - FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; - constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); - static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + // Outer loops over epilogue tiles. + constexpr auto EpiTilesM = size<2>(gCD_epi); + constexpr auto EpiTilesN = size<3>(gCD_epi); + + cst_callbacks.begin(); - auto synchronize = [&] () {}; CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < FragsN; epi_n++) { + for (int epi_m = 0; epi_m < EpiTilesM; epi_m++) { CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < FragsM; epi_m++) { - - if (is_C_load_needed) { - //cordinates for C and D are the same - copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC); + for (int epi_n = 0; epi_n < EpiTilesN; epi_n++) { + cst_callbacks.begin_loop(epi_m, epi_n); + + // Load C + reorder. + if constexpr (is_source_supported) { + if (is_C_load_needed) { + copy(copy_c, tCgC(_,_,_,epi_m,epi_n), tCrC); + reorder(tCrC, tCrC_compute); + } } cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); - auto acc_frag_mn = acc_frag(_, epi_m, epi_n); - + // Epilogue computation, one ComputeVectorLen-sized array at a time. CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { - trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + for (int epi_v = 0; epi_v < size<0>(tiled_acc_v); ++epi_v) { + tDrD_compute_v(epi_v) = cst_callbacks.visit(tiled_acc_v(epi_v, epi_m, epi_n), + epi_v, epi_m, epi_n); } - cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); - + + bool last_epi = (epi_m == EpiTilesM - 1) && (epi_n == EpiTilesN - 1); + cst_callbacks.reduce(nullptr, [=]{}, epi_m, epi_n, last_epi, tDrD_compute_v); + + // Reorder D (possibly including data conversion) and store. 
if constexpr (is_destination_supported) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); - } - copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, tCgD(_, epi_m, epi_n)); + reorder(tDrD_compute, tDrD); + copy(copy_d, tDrD, tDgD(_,_,_,epi_m,epi_n)); } + + cst_callbacks.end_loop(epi_m, epi_n); } } @@ -464,12 +433,12 @@ class CollectiveEpilogue< TensorC mC_mnl; TensorD mD_mnl; if constexpr (is_source_supported) { - ElementC const* ptr_C_curr_batch = reinterpret_cast(params.ptr_C[next_group]); + auto ptr_C_curr_batch = reinterpret_cast(params.ptr_C[next_group]); mC_mnl = make_tensor(make_gmem_ptr(ptr_C_curr_batch), make_layout(make_shape(M, N, L), params.dC[next_group])); } if constexpr (is_destination_supported) { - ElementD* ptr_D_curr_batch = reinterpret_cast(params.ptr_D[next_group]); + auto ptr_D_curr_batch = reinterpret_cast(params.ptr_D[next_group]); mD_mnl = make_tensor(make_gmem_ptr(ptr_D_curr_batch), make_layout(make_shape(M, N, L), params.dD[next_group])); } return cute::make_tuple(mC_mnl, mD_mnl); diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue_legacy.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue_legacy.hpp new file mode 100644 index 0000000000..f8c955316f --- /dev/null +++ b/include/cutlass/epilogue/collective/xe_array_epilogue_legacy.hpp @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileMNK_, + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2R_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpR2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + IntelXeXMX16Group, + CtaTileMNK_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2R_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpR2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = IntelXeXMX16Group; + using CtaTileMNK = CtaTileMNK_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpG2R = CopyOpG2R_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpR2G = CopyOpR2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2R; + using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, + CopyOpR2G, XE_2D_U32x8x16_ST_N>; + using ElementOutput = ElementD; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementAccumulator = ElementCompute; + using ElementSource = typename FusionCallbacks::ElementSource; + using ElementScalar = typename FusionCallbacks::ElementScalar; + static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; + + static_assert(cute::is_same_v>, + "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + + using CopyThreadShape = Shape<_1, Int>; + + using Trait_D = Copy_Traits; + using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, + Layout{}, + val_layout_store_D{})); +private: + constexpr static bool 
is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; + + using NonVoidElementC = conditional_t; + using Trait_C = Copy_Traits; + using NonVoidTrait_C = conditional_t; + using val_layout_load_C = decltype(make_layout(shape_div(typename NonVoidTrait_C::BlockShape{}, CopyThreadShape{}))); + using NonVoidValLayoutLoad_C = conditional_t; + using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, + Layout{}, + NonVoidValLayoutLoad_C{})); +public: + + using EmptyType = cute::tuple<>; + using SmemCStorage = EmptyType; + using SmemDStorage = EmptyType; + + struct TensorStorageImpl: cute::tuple { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + using TensorC = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideC{})); //(m, n) + using TensorD = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideD{})); //(m, n) + using EpilogueTensors = cute::tuple; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + XE_Copy_C xe_load_c; + XE_Copy_D xe_store_d; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNL = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto [M, N, L] = problem_shape_MNL; + + XE_Copy_C xe_load_c = {}; + if constexpr (is_source_supported) { + ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); + TensorC mC_mnl = make_tensor(make_gmem_ptr(ptr_C_first_batch), make_layout(make_shape(M, N, L), InternalStrideC{})); + xe_load_c = {xe_load_c.with(mC_mnl)}; + } + + XE_Copy_D xe_store_d = {}; + if constexpr (is_destination_supported) { + ElementD* ptr_D_first_batch = reinterpret_cast(args.ptr_D); + TensorD mD_mnl = make_tensor(make_gmem_ptr(ptr_D_first_batch), make_layout(make_shape(M, N, L), InternalStrideD{})); + xe_store_d = {xe_store_d.with(mD_mnl)}; + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + xe_load_c, + xe_store_d, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shape, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + + bool implementable = true; + bool fusion_implementable = true; + + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = 
append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + if constexpr (is_destination_supported) { + constexpr int min_aligned_elements_D = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + if (L > 1) { + constexpr int min_batch_aligned_elements_D = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(InternalStrideD{}) % min_batch_aligned_elements_D == 0; + } + } + + if constexpr (is_source_supported) { + constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + if (L > 1) { + constexpr int min_batch_aligned_elements_C = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(InternalStrideC{}) % min_batch_aligned_elements_C == 0; + } + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + return implementable && fusion_implementable; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) + : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class Accumulator, + class TiledMma, + class LoadStoreTensor + > + CUTLASS_DEVICE void + operator() ( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, + TiledMma tiled_mma, + int thread_idx, + LoadStoreTensor const& load_store_tensors) { + + (void) tiled_mma; + using namespace cute; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + static constexpr auto BLK_M = get<0>(CtaTileMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileMNK{}); + static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + // static_assert(is_same_v, "assertation fail"); + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static_assert( + BLK_M % ATOM_M == 0 && + BLK_N % ATOM_N == 0 && + BLK_K % ATOM_K == 0, + "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); + static constexpr auto SG_M = BLK_M / ATOM_M; + static constexpr auto SG_N = BLK_N / ATOM_N; + static constexpr auto SG_K = BLK_K / ATOM_K; + using SubgroupTileShape = Shape; + + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + + static 
constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + + auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); + + auto sg_local_m_coord = get_sub_group_id() / ATOM_N; + auto sg_local_n_coord = get_sub_group_id() % ATOM_N; + + auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; + auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; + auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); + + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Represent the full output tensor + Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); + + // Tile the output tensor per WG and select the tile for current WG + Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) + + // Tile the output tensor per SG and select tile for the current SG + Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + + auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); + Tensor tCgD = thread_xe_store_d.partition_D(gD); + + Tensor trC = make_tensor(Shape>{}); + Tensor trD_compute = make_tensor(Shape>{}); + + // Because Sm90 uses shared memory, they are not tied to using the same accumulator values + // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be + // sure that we are operating on the same values. + ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + // Get the fusion callbacks + // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + constexpr bool RefSrc = true; + auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + SubgroupTileShape{}, + sg_coord, + tiled_mma, + mn_shape, + params.xe_store_d, + cD, + residue_mn, + tRS_cD, + residue_mn, + trC, + thread_idx, + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + + cst_callbacks.begin(); + + auto acc_frag = recast>(accumulators); + auto trD_compute_frag = recast>(trD_compute); + + Tensor trD = make_tensor(Shape>{}); + auto trD_frag = recast>(trD); + + constexpr int ValuesLoaded = + FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; + constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); + static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + + auto synchronize = [&] () {}; + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < FragsN; epi_n++) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < FragsM; epi_m++) { 
+
+      // Check hierarchically instead of relying on is_C_load_needed alone, so that the
+      // runtime check is compiled out when ElementC is void.
+      if constexpr (is_source_supported) {
+        if (is_C_load_needed) {
+          // coordinates for C and D are the same
+          copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC);
+        }
+      }
+
+      cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
+
+      auto acc_frag_mn = acc_frag(_, epi_m, epi_n);
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) {
+        trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
+      }
+      cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);
+
+      if constexpr (is_destination_supported) {
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < size(trD_compute_frag); ++i) {
+          trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i));
+        }
+        copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, tCgD(_, epi_m, epi_n));
+      }
+    }
+  }
+
+  cst_callbacks.end();
+  }
+
+  template <class ProblemShape_MNKL>
+  CUTLASS_DEVICE auto update_tensor_shape_stride(
+      int32_t const& next_group,
+      ProblemShape_MNKL const& problem_shape_mnkl) {
+    auto [M, N, K, L] = problem_shape_mnkl;
+
+    TensorC mC_mnl;
+    TensorD mD_mnl;
+    if constexpr (is_source_supported) {
+      ElementC const* ptr_C_curr_batch = reinterpret_cast<ElementC const*>(params.ptr_C[next_group]);
+      mC_mnl = make_tensor(make_gmem_ptr(ptr_C_curr_batch), make_layout(make_shape(M, N, L), params.dC[next_group]));
+    }
+
+    if constexpr (is_destination_supported) {
+      ElementD* ptr_D_curr_batch = reinterpret_cast<ElementD*>(params.ptr_D[next_group]);
+      mD_mnl = make_tensor(make_gmem_ptr(ptr_D_curr_batch), make_layout(make_shape(M, N, L), params.dD[next_group]));
+    }
+    return cute::make_tuple(mC_mnl, mD_mnl);
+  }
+
+private:
+  Params const& params;
+  FusionCallbacks fusion_callbacks;
+};
+
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace collective
+} // namespace epilogue
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp
index 463498e8ca..2ed046f1ac 100644
--- a/include/cutlass/epilogue/collective/xe_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -1,5 +1,6 @@
 /***************************************************************************************************
 * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
+ * Copyright (C) 2025 Intel Corporation, All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
@@ -28,15 +29,13 @@
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
-/*! \file
-  \brief Functor performing elementwise operations used by epilogues.
-*/ #pragma once #include #include "cutlass/cutlass.h" #include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_epilogue.hpp" #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/epilogue/fusion/callbacks.hpp" @@ -54,105 +53,81 @@ namespace collective { ///////////////////////////////////////////////////////////////////////////////////////////////// template < - class CtaTileMNK_, + class WGTileMNK_, + class EpilogueTile_, class ElementC_, class StrideC_, class ElementD_, class StrideD_, class FusionCallbacks_, class CopyOpG2R_, - class SmemLayoutAtomC_, - class CopyOpS2R_, - class CopyOpR2G_, - class SmemLayoutAtomD_, - class CopyOpR2S_ + class CopyOpR2G_ > class CollectiveEpilogue< - IntelXeXMX16, - CtaTileMNK_, + IntelXeGeneric, + WGTileMNK_, + EpilogueTile_, ElementC_, StrideC_, ElementD_, StrideD_, FusionCallbacks_, CopyOpG2R_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpR2G_, - SmemLayoutAtomD_, - CopyOpR2S_ + CopyOpR2G_ > { public: // // Type Aliases // using DispatchPolicy = IntelXeXMX16; - using CtaTileMNK = CtaTileMNK_; - using FusionCallbacks = FusionCallbacks_; + + using WGTileMNK = WGTileMNK_; using ElementC = ElementC_; - using ElementAccumulator = ElementC_; using StrideC = StrideC_; using ElementD = ElementD_; using StrideD = StrideD_; + using FusionCallbacks = FusionCallbacks_; using CopyOpG2R = CopyOpG2R_; - using SmemLayoutAtomC = SmemLayoutAtomC_; - using CopyOpS2R = CopyOpS2R_; using CopyOpR2G = CopyOpR2G_; - using SmemLayoutAtomD = SmemLayoutAtomD_; - using CopyOpR2S = CopyOpR2S_; + + using NonVoidElementC = replace_void_t; using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = conditional_t, XE_2D_U32x8x16_LD_N, CopyOpG2R>; - using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, - CopyOpR2G, XE_2D_U32x8x16_ST_N>; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; using ElementOutput = ElementD; - using ElementCompute = ElementAccumulator; + + static constexpr int CopyBitsC = cute::min(sizeof(NonVoidElementC) * 8, 64); + static constexpr int CopyBitsD = cute::min(sizeof(ElementD) * 8, 64); + + // NOTE: GmemTiledCopy* may not be the actual C/D copy operations. They are declared here only so + // that GemmUniversalAdapter can inspect their alignment requirements. + // The real C/D copy operations are deduced inside operator() once we have access to + // the TiledMMA. 
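The CopyBitsC/CopyBitsD clamp above derives the block-2D copy element width from the C/D element size, capped at 64 bits. A minimal standalone sketch of that arithmetic (the 128-bit struct is a hypothetical stand-in for a wide element type):

```cpp
#include <algorithm>
#include <cstdint>

struct wide128 { std::uint64_t lo, hi; };  // hypothetical 128-bit element

// Width used for the 2D block copy: the element width, capped at 64 bits.
template <class Element>
constexpr int copy_bits() { return std::min<int>(sizeof(Element) * 8, 64); }

static_assert(copy_bits<std::uint16_t>() == 16);  // e.g. half/bf16 C or D
static_assert(copy_bits<float>()         == 32);
static_assert(copy_bits<wide128>()       == 64);  // clamped
```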
+ using GmemTiledCopyC = replace_void_t>; + using GmemTiledCopyD = replace_void_t>; static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(WGTileMNK{}) == 3, "WGTileMNK must be rank-3: [M, N, K]"); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); - - using CopyThreadShape = Shape<_1, Int>; - - using Trait_C = Copy_Traits; - using val_layout_load_C = decltype(make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_load_C{})); + using TensorC = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + Layout, StrideC>{})); - using Trait_D = Copy_Traits; - using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_store_D{})); + using TensorD = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + Layout, StrideD>{})); private: - constexpr static bool is_source_supported = not cute::is_void_v; - constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; - - constexpr static bool is_m_major_C = detail::is_m_major(); - constexpr static bool is_m_major_D = detail::is_m_major(); + constexpr static bool is_source_supported = !is_void_v; + constexpr static bool is_destination_supported = !is_void_v; public: - - using EmptyType = cute::tuple<>; - using SmemCStorage = EmptyType; - using SmemDStorage = EmptyType; - - struct TensorStorageImpl: cute::tuple { - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - }; - struct SharedStorage { - using TensorStorage = TensorStorageImpl; - - TensorStorage tensors; + using FusionSharedStorage = typename FusionCallbacks::SharedStorage; + FusionSharedStorage thread; }; - using TensorStorage = typename SharedStorage::TensorStorage; + using TensorStorage = SharedStorage; // Compatibility with legacy epilogues // Host side epilogue arguments struct Arguments { @@ -166,8 +141,8 @@ class CollectiveEpilogue< // Device side epilogue params struct Params { typename FusionCallbacks::Params thread{}; - XE_Copy_C xe_load_c; - XE_Copy_D xe_store_d; + TensorC mC; + TensorD mD; }; // @@ -180,26 +155,20 @@ class CollectiveEpilogue< ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) { - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - XE_Copy_C xe_load_c = {}; - if constexpr (is_source_supported) { - auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); - xe_load_c = {xe_load_c.with(mC)}; - } + // Optionally append 1s until problem shape is rank-4, in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto shape_CD = 
select<0,1,3>(problem_shape_MNKL); // (M,N,L) - XE_Copy_D xe_store_d = {}; - if constexpr (is_destination_supported) { - auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); - xe_store_d = {xe_store_d.with(mD)}; - } + // Create C/D tensors here; delay TiledCopy creation to the kernel. + auto non_void_ptr_C = reinterpret_cast(args.ptr_C); + auto mC = make_tensor(make_gmem_ptr(non_void_ptr_C), make_layout(shape_CD, args.dC)); + auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(shape_CD, args.dD)); return { FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), - xe_load_c, - xe_store_d, + mC, + mD, }; } @@ -211,8 +180,8 @@ class CollectiveEpilogue< template static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, - CudaHostAdapter* cuda_adapter = nullptr) { + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { return Status::kSuccess; } @@ -221,6 +190,8 @@ class CollectiveEpilogue< can_implement( ProblemShape const& problem_shapes, Arguments const& args) { + + // TODO: all these checks should be pushed down to individual copy atoms constexpr int copy_alignment_bits = 128; constexpr int batch_alignment_bits = 512; auto problem_shape_MNKL = append<4>(problem_shapes, 1); @@ -250,18 +221,18 @@ class CollectiveEpilogue< fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for Xe 2D copy.\n"); } if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum requirements for FusionCallbacks.\n"); } return implementable && fusion_implementable; } CUTLASS_HOST_DEVICE - CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) + CollectiveEpilogue(Params const& params_, SharedStorage const& shared_storage_) : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} CUTLASS_DEVICE @@ -272,162 +243,174 @@ class CollectiveEpilogue< template< class ProblemShapeMNKL, - class TileShapeMNK, + class TileShapeMNK, /* compatibility with legacy epilogues */ class TileCoordMNKL, class Accumulator, - class TiledMma + class TiledMMA > CUTLASS_DEVICE void operator() ( ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, + TileShapeMNK, /* compatibility with legacy epilogues */ TileCoordMNKL tile_coord_mnkl, - Accumulator accumulators, - TiledMma tiled_mma, + Accumulator accumulators, + TiledMMA, int thread_idx) { - (void) tiled_mma; using namespace cute; - static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); - - using MmaAtomShape = typename TiledMma::AtomShape_MNK; - static constexpr auto BLK_M = get<0>(CtaTileMNK{}); - static constexpr auto BLK_N = get<1>(CtaTileMNK{}); - static constexpr auto BLK_K = get<2>(CtaTileMNK{}); - // 
static_assert(is_same_v, "assertation fail"); - static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); - - static_assert( - BLK_M % ATOM_M == 0 && - BLK_N % ATOM_N == 0 && - BLK_K % ATOM_K == 0, - "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); - static constexpr auto SG_M = BLK_M / ATOM_M; - static constexpr auto SG_N = BLK_N / ATOM_N; - static constexpr auto SG_K = BLK_K / ATOM_K; - using SubgroupTileShape = Shape; - - static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group - static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group - - static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; - - // Indexing variables - auto [M, N, K, L] = problem_shape_mnkl; - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - auto m_sg = get_sub_group_id() / ATOM_N; - auto n_sg = get_sub_group_id() % ATOM_N; - - auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); - - auto sg_local_m_coord = get_sub_group_id() / ATOM_N; - auto sg_local_n_coord = get_sub_group_id() % ATOM_N; - - auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; - auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; - auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); - - bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - - // Represent the full output tensor - Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); + using MMATile = decltype(take<0,2>(typename TiledMMA::AtomShape_MNK{})); - // Tile the output tensor per WG and select the tile for current WG - Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) - - // Tile the output tensor per SG and select tile for the current SG - Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + static constexpr int EpiRPreferred = 8; + static constexpr int EpiCPreferred = 512 / cute::min(sizeof_bits_v, sizeof_bits_v); // 1 cache line + static constexpr int EpiR = cute::gcd(EpiRPreferred, get<0>(MMATile{})); + static constexpr int EpiC = cute::gcd(EpiCPreferred, get<1>(MMATile{})); - auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx); - Tensor tCgC = thread_xe_load_c.partition_S(gD); + using DefaultEpilogueTile = Shape, Int>; + using EpilogueTile = conditional_t || is_same_v, + DefaultEpilogueTile, + EpilogueTile_>; - auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); - Tensor tCgD = thread_xe_store_d.partition_D(gD); + using DefaultCopyOpG2R = XE_LOAD_2D(EpilogueTile{})), cute::gcd(512 / CopyBitsC, get<1>(EpilogueTile{}))>; + using DefaultCopyOpR2G = XE_STORE_2D(EpilogueTile{})), cute::gcd(512 / CopyBitsD, get<1>(EpilogueTile{}))>; - Tensor trC = make_tensor(Shape>{}); - Tensor trD_compute = make_tensor(Shape>{}); + using ActualGmemTiledCopyC = replace_void_t; + using ActualGmemTiledCopyD = replace_void_t; - // Because Sm90 uses shared memory, they are not tied to using the same accumulator values - // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be - // sure that we are operating on the same values. 
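The removed code here (kept verbatim in xe_epilogue_legacy.hpp later in this patch) derives per-subgroup fragment counts from the workgroup tile and the TiledMma thread layout. A standalone sketch with hypothetical but representative Xe sizes shows why the ValuesLoaded counting check holds:

```cpp
// Assumed sizes for illustration: 256x256 WG tile, 8x16x16 MMA atom,
// an 8x4x1 subgroup layout, and 16 work-items per subgroup.
constexpr int BLK_M = 256, BLK_N = 256;
constexpr int ATOM_M = 8, ATOM_N = 4, ATOM_K = 1;
constexpr int MMA_M = 8, MMA_N = 16;
constexpr int SubgroupSize = 16;

constexpr int SG_M = BLK_M / ATOM_M;                        // 32 rows per subgroup
constexpr int SG_N = BLK_N / ATOM_N;                        // 64 cols per subgroup
constexpr int FragsM = SG_M / MMA_M;                        // 4 MMA tiles in M
constexpr int FragsN = SG_N / MMA_N;                        // 4 MMA tiles in N
constexpr int FragmentSize = MMA_M * MMA_N / SubgroupSize;  // 8 values per work-item

// Every element of the WG tile is produced exactly once across all subgroups.
static_assert(FragsM * FragsN * FragmentSize * SubgroupSize *
              ATOM_M * ATOM_N * ATOM_K == BLK_M * BLK_N);
```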
- ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + auto batch_idx = get<3>(tile_coord_mnkl); - // OOB predication for tile quantization "residue" - // Absolute coordinate tensors (dynamic) - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); - Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - - Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); - // Get the fusion callbacks - // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + auto MN = take<0,2>(problem_shape_mnkl); + auto cCD = make_identity_tensor(MN); // (m,n) + auto gCD = local_tile(cCD, take<0,2>(WGTileMNK{}), take<0,2>(tile_coord_mnkl)); // (m_in_wg_tile, n_in_wg_tile) + + auto thr_mma = TiledMMA{}.get_slice(thread_idx); + auto tCDgCD = thr_mma.partition_C(gCD); // (mma_v,mma_m,mma_n) -> coord + + // Tile accumulator into epilogue tiles. + auto mma_per_epi = shape_div(EpilogueTile{}, MMATile{}); + auto tiled_acc_layout = group<0,3>(prepend(flat_divide(remove<0>(accumulators.layout()), mma_per_epi), + get<0>(accumulators.layout()))); + auto tiled_acc = make_tensor(accumulators.data(), tiled_acc_layout); // ((mma_v,mma_m,mma_n),epi_m,epi_n) + + // Tile subgroup's TV coord layout into epilogue tiles. + auto sg_v_coord = prepend(flat_divide(remove<0>(tCDgCD.layout()), mma_per_epi), + get<0>(tCDgCD.layout())); // (mma_v,mma_m,mma_n,epi_m,epi_n) -> coord + + // Copy C/D one epilogue tile at a time. Prepare: + // - subgroup-scope TiledCopy objects + // - global coordinate tensors, partitioned into epilogue tiles + // - copy fragments + // - compute fragments (same layout as accumulator) + // Both copy and compute fragments are SubgroupTensors, holding coordinate mappings + // within the epilogue tile. + auto copy_c = make_block_2d_copy(ActualGmemTiledCopyC{}, params.mC(_,_,batch_idx)); + auto copy_d = make_block_2d_copy(ActualGmemTiledCopyD{}, params.mD(_,_,batch_idx)); + + int wi_idx = thread_idx % intel::sg_size; + auto thr_copy_c = copy_c.get_slice(wi_idx); + auto thr_copy_d = copy_d.get_slice(wi_idx); + + // Partition global coordinate tensors into epilogue tiles, matching + // the work-division from the TiledMMA. + auto gCD_epi_layout = append(append(make_identity_layout(EpilogueTile{}), + get<3>(sg_v_coord)), get<4>(sg_v_coord)); + auto gCD_epi = make_tensor(tCDgCD.data(), gCD_epi_layout); // (m,n,epi_m,epi_n) -> coord + + auto tCgC = thr_copy_c.partition_S(gCD_epi); // (atom_v,atom_m,atom_n,epi_m,epi_n) + auto tDgD = thr_copy_d.partition_D(gCD_epi); // (atom_v,atom_m,atom_n,epi_m,epi_n) + + auto tCrC = thr_copy_c.partition_sg_fragment_D(gCD_epi(_,_,0,0)); // (atom_v,atom_m,atom_n,epi_m,epi_n) + auto tDrD = thr_copy_d.partition_sg_fragment_S(gCD_epi(_,_,0,0)); // (atom_v,atom_m,atom_n,epi_m,epi_n) + + // Create C subgroup fragments for epilogue compute. 
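The DefaultEpilogueTile selection earlier in this operator() prefers an 8-row tile whose columns cover roughly one 64-byte cache line of the narrower of C/D, then rounds both dimensions down via gcd so they divide the MMA atom's C tile. A sketch with assumed values (8x16 atom, 16-bit C and D):

```cpp
#include <numeric>  // std::gcd

constexpr int mma_m = 8, mma_n = 16;                  // MMA atom C tile (assumed)
constexpr int min_cd_bits = 16;                       // e.g. half/bf16 C and D
constexpr int EpiRPreferred = 8;
constexpr int EpiCPreferred = 512 / min_cd_bits;      // elements in one 64-byte cache line
constexpr int EpiR = std::gcd(EpiRPreferred, mma_m);  // 8
constexpr int EpiC = std::gcd(EpiCPreferred, mma_n);  // 16

static_assert(EpiR == 8 && EpiC == 16);               // -> 8x16 default epilogue tile
```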
+ using AccTVLayout = decltype(thr_mma.partition_sg_fragment_C(gCD).tv_layout()); + auto cd_compute_tv = make_layout(get<0>(AccTVLayout{}), + sg_v_coord(_,_,_,_0{},_0{})); + + auto tCrC_compute_wi = make_fragment_like(tiled_acc(_,_0{},_0{})); + auto tCrC_compute = make_subgroup_tensor(tCrC_compute_wi, cd_compute_tv); // (mma_v,mma_m,mma_n) + + // Calculate residues. + auto residue_gCD = MN - gCD(_0{}); // (res_m, res_n) + auto residue_tCDgCD = MN - tCDgCD(_0{}); // (res_m, res_n) + + // Pass data to fusions. + // FIXME: Some Xe visitors expect subgroup tiles/coordinates here and should be updated to accept + // workgroup tiles/coordinates, like the NV code. Note that CuTe has no concept of a "subgroup tile." + // Work division within a TiledMMA is flexible, and a subgroup's data need not be contiguous. + // Instead, visitors should retrieve data coordinates within the WG tile via tDgD. constexpr bool RefSrc = true; - auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, - SubgroupTileShape{}, - sg_coord, - tiled_mma, - mn_shape, - params.xe_store_d, - cD, - residue_mn, - tRS_cD, - residue_mn, - trC, - thread_idx, - }; + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs { + problem_shape_mnkl, + WGTileMNK{}, + tile_coord_mnkl, + TiledMMA{}, + EpilogueTile{}, + copy_d, + gCD, + residue_gCD, + tDgD, + residue_tCDgCD, + tCrC_compute, + thread_idx, + }; auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); - cst_callbacks.begin(); + // Epilogue visitors work on cutlass::Arrays of values for better vectorization. + // For now, choose array size so there is one array per MMA atom C tile -- later + // we might want to make it configurable (FragmentSize in NV code). + using ElementAccumulator = typename Accumulator::element_type; + constexpr int ComputeVectorLen = size<0>(Accumulator{}); + auto tiled_acc_v = recast>(tiled_acc); - auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + // Create D subgroup fragments for epilogue compute. + using FragmentVisit = decltype(cst_callbacks.visit(tiled_acc_v(0), 0, 0, 0)); + using ElementVisit = typename FragmentVisit::Element; - Tensor trD = make_tensor(Shape>{}); - auto trD_frag = recast>(trD); + auto tDrD_compute_wi = make_fragment_like(tiled_acc(_,_0{},_0{})); + auto tDrD_compute = make_subgroup_tensor(tDrD_compute_wi, cd_compute_tv); // (mma_v,mma_m,mma_n) + auto tDrD_compute_v = recast(tDrD_compute_wi); + + // Outer loops over epilogue tiles. + constexpr auto EpiTilesM = size<2>(gCD_epi); + constexpr auto EpiTilesN = size<3>(gCD_epi); + + cst_callbacks.begin(); - constexpr int ValuesLoaded = - FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; - constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); - static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); - - auto synchronize = [&] () {}; CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < FragsN; epi_n++) { + for (int epi_m = 0; epi_m < EpiTilesM; epi_m++) { CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < FragsM; epi_m++) { + for (int epi_n = 0; epi_n < EpiTilesN; epi_n++) { cst_callbacks.begin_loop(epi_m, epi_n); - if (is_C_load_needed) { - copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); + // Load C + reorder. 
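In the tile loop that follows, cst_callbacks.visit is where the fusion's arithmetic happens; for a plain LinearCombination it amounts to D = alpha * acc + beta * C applied element-wise to each fragment. A scalar sketch of that computation (the real visitor operates on cutlass::Array fragments generated by the fusion callbacks, not hand-written code like this):

```cpp
#include <array>

template <int N>
std::array<float, N> linear_combination(std::array<float, N> const& acc,
                                        std::array<float, N> const& c,
                                        float alpha, float beta) {
  std::array<float, N> d{};
  for (int i = 0; i < N; ++i) {
    d[i] = alpha * acc[i] + beta * c[i];  // what visit() returns per element
  }
  return d;  // subsequently reordered/converted to ElementD and stored
}
```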
+ if constexpr (is_source_supported) { + if (is_C_load_needed) { + copy(copy_c, tCgC(_,_,_,epi_m,epi_n), tCrC); + reorder(tCrC, tCrC_compute); + } } cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); - auto acc_frag_mn = acc_frag(_, epi_m, epi_n); - + // Epilogue computation, one ComputeVectorLen-sized array at a time. CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { - trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + for (int epi_v = 0; epi_v < size<0>(tiled_acc_v); ++epi_v) { + tDrD_compute_v(epi_v) = cst_callbacks.visit(tiled_acc_v(epi_v, epi_m, epi_n), + epi_v, epi_m, epi_n); } - cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); - + + bool last_epi = (epi_m == EpiTilesM - 1) && (epi_n == EpiTilesN - 1); + cst_callbacks.reduce(nullptr, [=]{}, epi_m, epi_n, last_epi, tDrD_compute_v); + + // Reorder D (possibly including data conversion) and store. if constexpr (is_destination_supported) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); - } - copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); + reorder(tDrD_compute, tDrD); + copy(copy_d, tDrD, tDgD(_,_,_,epi_m,epi_n)); } - + cst_callbacks.end_loop(epi_m, epi_n); } } diff --git a/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp b/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp new file mode 100644 index 0000000000..dfc8edff96 --- /dev/null +++ b/include/cutlass/epilogue/collective/xe_epilogue_legacy.hpp @@ -0,0 +1,459 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileMNK_, + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2R_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpR2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + IntelXeXMX16, + CtaTileMNK_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2R_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpR2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = IntelXeXMX16; + using CtaTileMNK = CtaTileMNK_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2R = CopyOpG2R_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpR2G = CopyOpR2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = conditional_t, XE_2D_U32x8x16_LD_N, CopyOpG2R>; + using GmemTiledCopyD = cute::conditional_t && not cute::is_void_v, + CopyOpR2G, XE_2D_U32x8x16_ST_N>; + using ElementOutput = ElementD; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementAccumulator = ElementCompute; + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + + using CopyThreadShape = Shape<_1, Int>; + + using Trait_D = Copy_Traits; + using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_store_D{})); + +private: + constexpr static bool is_source_supported = not cute::is_void_v && not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; + + using NonVoidElementC = conditional_t; + using Trait_C = Copy_Traits; + using NonVoidTrait_C = conditional_t; + using val_layout_load_C = decltype(make_layout(shape_div(typename NonVoidTrait_C::BlockShape{}, CopyThreadShape{}))); + using 
NonVoidValLayoutLoad_C = conditional_t; + using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, NonVoidValLayoutLoad_C{})); + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + +public: + + using EmptyType = cute::tuple<>; + using SmemCStorage = EmptyType; + using SmemDStorage = EmptyType; + + struct TensorStorageImpl: cute::tuple { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + XE_Copy_C xe_load_c; + XE_Copy_D xe_store_d; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + XE_Copy_C xe_load_c = {}; + if constexpr (is_source_supported) { + auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); + xe_load_c = {xe_load_c.with(mC)}; + } + + XE_Copy_D xe_store_d = {}; + if constexpr (is_destination_supported) { + auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); + xe_store_d = {xe_store_d.with(mD)}; + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + xe_load_c, + xe_store_d, + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + bool fusion_implementable = true; + + if constexpr (is_destination_supported) { + constexpr int min_aligned_elements_D = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), args.dD); + if (L > 1) { + constexpr int min_batch_aligned_elements_D = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dD) % min_batch_aligned_elements_D == 0; + } + } + + if constexpr (is_source_supported) { + constexpr int min_aligned_elements_C = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,N,L), args.dC); + if (L > 1) { + constexpr int min_batch_aligned_elements_C = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dC) % min_batch_aligned_elements_C == 0; + } + } + + fusion_implementable = fusion_implementable && 
FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + return implementable && fusion_implementable; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) + : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class Accumulator, + class TiledMma + > + CUTLASS_DEVICE void + operator() ( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, + TiledMma tiled_mma, + int thread_idx) { + + (void) tiled_mma; + using namespace cute; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + static constexpr auto BLK_M = get<0>(CtaTileMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileMNK{}); + static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + // static_assert(is_same_v, "assertation fail"); + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static_assert( + BLK_M % ATOM_M == 0 && + BLK_N % ATOM_N == 0 && + BLK_K % ATOM_K == 0, + "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); + static constexpr auto SG_M = BLK_M / ATOM_M; + static constexpr auto SG_N = BLK_N / ATOM_N; + static constexpr auto SG_K = BLK_K / ATOM_K; + using SubgroupTileShape = Shape; + + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + + static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + + auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); + + auto sg_local_m_coord = get_sub_group_id() / ATOM_N; + auto sg_local_n_coord = get_sub_group_id() % ATOM_N; + + auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; + auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; + auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); + + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Represent the full output tensor + Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); + + // Tile the output tensor per WG and select the tile for current WG + Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) + + // 
Tile the output tensor per SG and select tile for the current SG + Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) + + auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx); + Tensor tCgC = thread_xe_load_c.partition_S(gD); + + auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); + Tensor tCgD = thread_xe_store_d.partition_D(gD); + + Tensor trC = make_tensor(Shape>{}); + Tensor trD_compute = make_tensor(Shape>{}); + + // Because Sm90 uses shared memory, they are not tied to using the same accumulator values + // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be + // sure that we are operating on the same values. + ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord)); + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + // Get the fusion callbacks + // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles + constexpr bool RefSrc = true; + auto residue_mn = make_coord(M, N); //TODO(Codeplay): this is not correct + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + SubgroupTileShape{}, + sg_coord, + tiled_mma, + mn_shape, + params.xe_store_d, + cD, + residue_mn, + tRS_cD, + residue_mn, + trC, + thread_idx, + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + + cst_callbacks.begin(); + + auto acc_frag = recast>(accumulators); + using FragmentVisit = decltype(cst_callbacks.visit(acc_frag(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + auto trD_compute_frag = recast>(trD_compute); + + Tensor trD = make_tensor(Shape>{}); + auto trD_frag = recast>(trD); + + constexpr int ValuesLoaded = + FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; + constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); + static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + + auto synchronize = [&] () {}; + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < FragsN; epi_n++) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < FragsM; epi_m++) { + cst_callbacks.begin_loop(epi_m, epi_n); + + //Instead of calling is_C_load_needed. 
We do heirachical check + //so that runtime check not there when ElementC is void + if constexpr (is_source_supported) { + if (is_C_load_needed) { + copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); + } + } + + cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); + + auto acc_frag_mn = acc_frag(_, epi_m, epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { + trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } + cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); + + if constexpr (is_destination_supported) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(trD_compute_frag); ++i) { + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + } + copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); + } + + cst_callbacks.end_loop(epi_m, epi_n); + } + } + + cst_callbacks.end(); + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index fbb4a40b52..3be03749b7 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -63,11 +63,15 @@ struct NoSmemWarpSpecialized1Sm {}; struct NoSmemWarpSpecialized2Sm {}; struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; +struct BlockwiseNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; +struct BlockwiseNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; -// Blackwell TMA schedules +struct PtrArrayBlockwiseNoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; +struct PtrArrayBlockwiseNoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; +// Blackwell TMA schedules struct TmaWarpSpecialized1Sm {}; struct TmaWarpSpecialized2Sm {}; struct PtrArrayTmaWarpSpecialized1Sm : TmaWarpSpecialized1Sm {}; @@ -294,9 +298,15 @@ struct Sm120PtrArrayTmaWarpSpecialized { }; #if defined (SYCL_INTEL_TARGET) -// Specialization of the GEMM Epilogue for Intel Xe architectures. -// This version is tuned for operations with a subgroup size of 16. -// Suitable for use with Intel Battlemage (Xe2) and PVC (Xe) architectures. +// Standard Xe epilogue. +struct IntelXeGeneric { + static constexpr int SubgroupSize = 16; +}; + +struct IntelXeGenericGroup { + static constexpr int SubgroupSize = 16; +}; +// Legacy epilogues. 
struct IntelXeXMX16 { static constexpr int SubgroupSize = 16; }; diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 5173d77000..54c3e2ab73 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -56,6 +56,7 @@ namespace cutlass::epilogue::fusion { ///////////////////////////////////////////////////////////////////////////////////////////////// + template < class ElementOutput_, class ElementCompute_, @@ -66,7 +67,7 @@ template < class EpilogueTile_ > struct FusionCallbacks< - epilogue::IntelXeXMX16, + epilogue::IntelXeGeneric, fusion::LinearCombination, CtaTileShapeMNK_, EpilogueTile_ @@ -109,7 +110,6 @@ struct FusionCallbacks< using Impl::Impl; }; - template < template class ActivationFn_, class ElementOutput_, @@ -121,7 +121,7 @@ template < class EpilogueTile_ > struct FusionCallbacks< - epilogue::IntelXeXMX16, + epilogue::IntelXeGeneric, fusion::LinCombEltAct, CtaTileShapeMNK_, EpilogueTile_ @@ -187,6 +187,7 @@ using XeLinCombSplitK = Sm90LinearCombination // beta * C + (alpha * acc) >; +// TODO: update split-k for new epilogues template < // int FragmentSize, class ElementOutput_, @@ -245,6 +246,7 @@ struct FusionCallbacks< using Impl::Impl; }; +// TODO: update softmax for new epilogues // D = softmax(alpha * acc + beta * C) template< // int FragmentSize, @@ -352,7 +354,7 @@ template < class CopyOpG2R > struct FusionCallbacks< - epilogue::IntelXeXMX16, + epilogue::IntelXeGeneric, fusion::LinCombDeEltAct< GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle @@ -420,6 +422,45 @@ struct FusionCallbacks< using Impl::Impl; }; +// Temporary: provide default G2R operation for LinCombDeEltAct, for CollectiveBuilder. +// This will be removed once XeAuxLoad moves to new atoms and autodetects copy op. 
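The temporary alias defined next selects a 32-bit or 16-bit 2D block load based on the auxiliary element's width. A simplified model of that selection, with hypothetical tag types standing in for the Xe copy atoms:

```cpp
#include <cstdint>
#include <type_traits>

struct Load2D_U32x8x16 {};  // stand-in for a 32-bit block-load atom
struct Load2D_U16x8x16 {};  // stand-in for a 16-bit block-load atom

template <class Element>
using DefaultAuxCopyOp =
    std::conditional_t<sizeof(Element) * 8 == 32, Load2D_U32x8x16, Load2D_U16x8x16>;

static_assert(std::is_same_v<DefaultAuxCopyOp<float>, Load2D_U32x8x16>);
static_assert(std::is_same_v<DefaultAuxCopyOp<std::uint16_t>, Load2D_U16x8x16>);
```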
+template +using XeAuxLoadDefaultCopyOpG2R = conditional_t == 32, XE_2D_U32x8x16_LD_N, XE_2D_U16x8x16_LD_N>; + +template < + class GmemLayoutTagAux, template class ActivationFn, + class ElementOutput_, class ElementCompute_, class ElementAux, class ElementSource, class ElementScalar, + int AlignmentAux, FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, class EpilogueTile +> +struct FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + XeAuxLoadDefaultCopyOpG2R +> { + using FusionCallbacks< + epilogue::IntelXeGeneric, fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, EpilogueTile, XeAuxLoadDefaultCopyOpG2R + >::FusionCallbacks; +}; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // D = alpha * acc + beta * C + per-row bias template < @@ -434,7 +475,7 @@ template < class EpilogueTile_ > struct FusionCallbacks< - epilogue::IntelXeXMX16, + epilogue::IntelXeGeneric, fusion::LinCombPerRowBias, CtaTileShapeMNK_, EpilogueTile_ @@ -524,7 +565,7 @@ template < class EpilogueTile_ > struct FusionCallbacks< - epilogue::IntelXeXMX16, + epilogue::IntelXeGeneric, fusion::LinCombPerColBias, CtaTileShapeMNK_, EpilogueTile_ @@ -533,7 +574,7 @@ struct FusionCallbacks< using Impl = XeLinCombPerColBias< _1{}, CtaTileShapeMNK_, - EpilogueTile_, + EpilogueTile_, typename cutlass::detail::get_unpacked_element_type::type, ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_>; @@ -580,6 +621,7 @@ struct FusionCallbacks< using Impl::Impl; }; +// TODO: move to new epilogue template < int TopK, class ElementOutput_, @@ -602,8 +644,8 @@ struct FusionCallbacks< using ElementCompute = ElementCompute_; using ElementSource = ElementSource_; using ElementScalar = ElementScalar_; - using Impl = Sm90LinCombTopKSoftmaxCol::type, + using Impl = Sm90LinCombTopKSoftmaxCol::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; using Operation = fusion::LinCombTopKSoftmaxCol; @@ -655,7 +697,7 @@ template < class EpilogueTile_ > struct FusionCallbacks< - epilogue::IntelXeXMX16, + epilogue::IntelXeGeneric, fusion::LinCombPerRowBiasEltAct< ActivationFn_, ElementOutput_, ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_ >, @@ -733,7 +775,7 @@ template < class EpilogueTile_ > struct FusionCallbacks< - epilogue::IntelXeXMX16Group, + epilogue::IntelXeGenericGroup, fusion::LinearCombination, CtaTileShapeMNK_, EpilogueTile_ @@ -780,6 +822,233 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// Callbacks for legacy epilogues (deprecated). 
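Each of the legacy-tag specializations below follows the same deprecation idiom: the IntelXeXMX16 specialization derives from the corresponding IntelXeGeneric one and inherits its constructors, so existing call sites keep compiling while the implementation lives in one place. A minimal model with hypothetical policy and op types:

```cpp
struct GenericPolicy {};  // stands in for epilogue::IntelXeGeneric
struct LegacyPolicy {};   // stands in for epilogue::IntelXeXMX16

template <class Policy, class Op>
struct FusionCallbacksModel;                      // primary template

template <class Op>
struct FusionCallbacksModel<GenericPolicy, Op> {  // the maintained implementation
  explicit FusionCallbacksModel(Op const&) {}
};

template <class Op>
struct FusionCallbacksModel<LegacyPolicy, Op>     // legacy tag forwards to it...
    : FusionCallbacksModel<GenericPolicy, Op> {
  using FusionCallbacksModel<GenericPolicy, Op>::FusionCallbacksModel;  // ...and inherits ctors
};
```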
+ +template < + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ +> : FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ +> { + using FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_>::FusionCallbacks; +}; + +template < + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombEltAct, + CtaTileShapeMNK_, + EpilogueTile_ +> : FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombEltAct, + CtaTileShapeMNK_, + EpilogueTile_ +> { + using FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombEltAct, + CtaTileShapeMNK_, + EpilogueTile_>::FusionCallbacks; +}; + +template < + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput_, + class ElementCompute_, + class ElementAux, + class ElementSource, + class ElementScalar, + int AlignmentAux, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class CopyOpG2R +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + CopyOpG2R +> : FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + CopyOpG2R +> { + using FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + CopyOpG2R + >::FusionCallbacks; +}; + +template < + class ElementOutput_, + class ElementCompute_, + class ElementBias_, + class ElementSource_, + class ElementScalar_, + int AlignmentBias_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombPerRowBias, + CtaTileShapeMNK_, + EpilogueTile_ +> : FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombPerRowBias, + CtaTileShapeMNK_, + EpilogueTile_ +> { + using FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombPerRowBias, + CtaTileShapeMNK_, + EpilogueTile_ + >::FusionCallbacks; +}; + +template < + class ElementOutput_, + class ElementCompute_, + class ElementBias_, + class ElementSource_, + class ElementScalar_, + int AlignmentBias_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombPerColBias, + CtaTileShapeMNK_, + EpilogueTile_ +> : FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombPerColBias, + CtaTileShapeMNK_, + EpilogueTile_ +> { + using FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombPerColBias, + CtaTileShapeMNK_, + EpilogueTile_ + >::FusionCallbacks; +}; + +template < + template 
class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_, + class ElementSource_, + class ElementScalar_, + int AlignmentBias_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16, + fusion::LinCombPerRowBiasEltAct< + ActivationFn_, ElementOutput_, ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_ + >, + CtaTileShapeMNK_, + EpilogueTile_ +> : FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombPerRowBiasEltAct< + ActivationFn_, ElementOutput_, ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_ + >, + CtaTileShapeMNK_, + EpilogueTile_ +> { + using FusionCallbacks< + epilogue::IntelXeGeneric, + fusion::LinCombPerRowBiasEltAct< + ActivationFn_, ElementOutput_, ElementCompute_, ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_ + >, + CtaTileShapeMNK_, + EpilogueTile_ + >::FusionCallbacks; +}; + +template < + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelXeXMX16Group, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ +> : FusionCallbacks< + epilogue::IntelXeGenericGroup, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ +> { + using FusionCallbacks< + epilogue::IntelXeGenericGroup, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ + >::FusionCallbacks; +}; + } // namespace cutlass::epilogue::fusion ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/xe_visitor.hpp b/include/cutlass/epilogue/fusion/xe_visitor.hpp index ac78664561..d76dd1ac86 100644 --- a/include/cutlass/epilogue/fusion/xe_visitor.hpp +++ b/include/cutlass/epilogue/fusion/xe_visitor.hpp @@ -189,7 +189,7 @@ struct XeAuxLoad { CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto xe_copy_aux = params_ptr->xe_load_aux; - Tensor trAux = make_tensor_like(args.tCrC); + Tensor trAux = make_tensor_like(args.tCrC.tensor()); auto [M, N, K, L] = args.problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; @@ -430,7 +430,7 @@ struct XeRowBroadcast { Tensor tCgRow_static = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mRow_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, id_in_sg); Tensor tCrRow = make_tensor_like(tCgRow_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - + return ConsumerStoreCallbacks(tCgRow, tCrRow, args.tCcD, args.residue_tCcD, params); } }; diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl new file mode 100644 index 0000000000..72cc30616e --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl @@ -0,0 +1,274 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +#include "cutlass/gemm/collective/collective_builder_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int stages +> +constexpr int +sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCount stage_count) { + return stages; +} + +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int carveout_bytes +> +constexpr int +sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCountAutoCarveout stage_count) { + // For MXF8F6F4 MMA, ElementA/B will be passed in as uint8_t + // Each stage include (CollectiveMma::SharedStorage) + // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) + // 2. one MainloopPipeline = PipelineTmaUmmaAsync (CollectiveMma::SharedStorage::SharedStorage) + // 3. smem for SFB and smem for SFB (CollectiveMma::SharedStorage::TensorStorage, independent of input size b.c. 
sizeof(sf) is fixed) + constexpr auto mainloop_pipeline_bytes = + sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage) + + sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage); + + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto stage_sfa_bytes = size(filter_zeros(TileShapeSFA{})); + constexpr auto stage_sfb_bytes = size(filter_zeros(TileShapeSFB{})); + + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes + stage_sfa_bytes + stage_sfb_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementPairA, + class GmemLayoutATag, + int AlignmentA, + class ElementPairB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassBlockScaledTensorOp, + ElementPairA, + GmemLayoutATag, + AlignmentA, + ElementPairB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape (_1, _1, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t > +> +{ + using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementSFB = typename detail::blockscaled::blockscaled_type::sf_type; + using ElementA = typename detail::blockscaled::blockscaled_type::data_type; + using ElementB = typename detail::blockscaled::blockscaled_type::data_type; + using ElementSF = ElementSFA; + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(detail::blockscaled::check_input_datatypes(), "Incorrect input types"); + + static constexpr bool is_2sm = false; // detail::blockscaled::is_2sm(); + static constexpr auto Instr = detail::blockscaled::select_instr(); + + using TiledMma = typename cutlass::gemm::collective::detail::TrivialBlockscaledMma::type; + + static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8; + + static_assert(UseMxf8f6f4 || (cutlass::gemm::detail::is_k_major_A() && cutlass::gemm::detail::is_k_major_B()), "Only MMA.MXF8F6F4 supports non-K major inputs"); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + static_assert(detail::sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement(), + "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" ); + + static constexpr uint32_t SFVectorSize = TiledMma::SFVecSize; + + using ElementAMma_SmemAllocType = cute::conditional_t; + using ElementBMma_SmemAllocType = cute::conditional_t; + + // using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + // ElementAMma, ElementBMma, ElementAccumulator, + // 
decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + // UmmaMajorA, UmmaMajorB, BuilderScheduleTag>()); + + using AtomThrID = typename TiledMma::AtomThrID; + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + // Assigning 4 warps for mainloop load of B + static constexpr int NumLoadThreadsCpAsync = 128; + + + using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{})); + + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>()); + + using AlignmentTypeB = cute::uint_byte_t(cutlass::sizeof_bits::value) * AlignmentB / 8>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>()); + + using GmemTiledCopySFA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using GmemTiledCopySFB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_SFB( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopyPairA = decltype(cute::make_tuple(GmemTiledCopyA{}, GmemTiledCopySFA{})); + using GmemTiledCopyPairB = decltype(cute::make_tuple(GmemTiledCopyB{}, GmemTiledCopySFB{})); + + using SmemLayoutAtomSFA = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFA(TiledMma{}, TileShape_MNK{})); + using SmemLayoutAtomSFB = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFB(TiledMma{}, TileShape_MNK{})); + using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{})); + using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{})); + + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using InternalStrideA = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; + using InternalLayoutSFA = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFA()); + using InternalLayoutSFB = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFB()); + using LayoutSFA = cute::conditional_t, InternalLayoutSFA, InternalLayoutSFA *>; + using LayoutSFB = cute::conditional_t, InternalLayoutSFB, InternalLayoutSFB *>; + using StridePairA = decltype(cute::make_tuple(StrideA{}, LayoutSFA{})); + using StridePairB = decltype(cute::make_tuple(StrideB{}, LayoutSFB{})); 
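+  // Stage-count note: the pipeline depth below is derived from the SMEM remaining after the
+  // kernel-level carveout (accumulator pipeline, CLC pipeline, CLC response). Each mainloop stage
+  // holds one A tile, one B tile, the SFA/SFB tiles, and the mainloop pipeline barriers (TMA and cp.async).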
+ + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + CLCResponseStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. + static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, SFA, and SFB."); + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>, + TileShape_MNK, + cute::tuple, + StridePairA, + cute::tuple, + StridePairB, + TiledMma, + GmemTiledCopyPairA, + SmemLayoutAtomsA, + void, + cute::identity, + GmemTiledCopyPairB, + SmemLayoutAtomsB, + void, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl index 3556fad6dd..68600c6779 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -120,6 +120,7 @@ struct CollectiveBuilder< BuilderScheduleTag, cute::enable_if_t< // Blockscaled Gemm + (not cute::is_same_v) && (cute::is_base_of_v || cute::is_same_v) && diff --git a/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl new file mode 100644 index 0000000000..5fd1201a6e --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl @@ -0,0 +1,171 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +#include "cutlass/gemm/collective/collective_builder_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape (_1, _1, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t > +> +{ + static_assert(cute::is_static_v, "TileShape has to be static"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + using ElementAMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementAMma>; + using ElementBMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + ElementAMma, ElementBMma, ElementAccumulator, + decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + UmmaMajorA, UmmaMajorB, BuilderScheduleTag>()); + + using AtomThrID = typename TiledMma::AtomThrID; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using 
MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + // Assigning 4 warps for mainloop load of B + static constexpr int NumLoadThreadsCpAsync = 128; + + + using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{})); + + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>()); + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>()); + + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + CLCResponseStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. 
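+  // (The accumulator pipeline, CLC pipeline, and CLC response buffers are owned by the kernel
+  //  layer rather than by the collectives, so they are subtracted from the SMEM budget up front.)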
+ static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl index 5edf637e22..dfd4fece32 100644 --- a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -184,6 +184,7 @@ struct CollectiveBuilder< not cute::is_complex_v && not cute::is_complex_v && // Dense Gemm / PtrArrayDenseGemm ( + (not cute::is_same_v) && (not cute::is_same_v) && (cute::is_base_of_v || cute::is_same_v)) && diff --git a/include/cutlass/gemm/collective/builders/sm1xx_common.inl b/include/cutlass/gemm/collective/builders/sm1xx_common.inl index f63842c08b..629f95dcec 100644 --- a/include/cutlass/gemm/collective/builders/sm1xx_common.inl +++ b/include/cutlass/gemm/collective/builders/sm1xx_common.inl @@ -502,6 +502,7 @@ check_input_datatypes() { || (cute::is_same_v) || (cute::is_same_v) || (cute::is_same_v) + || (cute::is_same_v) // SM100 BS ptr_array || (cute::is_same_v) || (cute::is_same_v) @@ -578,6 +579,8 @@ check_input_datatypes() { ((SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_same_v) + || (SfVectorSizeA == 32 && cute::is_same_v) + || (SfVectorSizeA == 32 && cute::is_same_v) || (SfVectorSizeA == 32 && cute::is_base_of_v) || (SfVectorSizeA == 32 && cute::is_base_of_v) || (SfVectorSizeA == 64 && cute::is_base_of_v) diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index c75af3acb9..a1ea257e7f 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -1069,10 +1069,10 @@ struct CollectiveBuilder< StageCountType, KernelScheduleType, cute::enable_if_t< - (cute::is_same_v or - cute::is_same_v or - cute::is_same_v or - cute::is_same_v) and + (cute::is_same_v or + cute::is_same_v or + cute::is_same_v or + cute::is_same_v) and not detail::is_use_rmem_A() > > { @@ -1105,7 +1105,7 @@ struct CollectiveBuilder< cute::is_base_of_v); static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert(IsFP8Input, "Warp Specialized gemm with FP8 BlockScaled Accumulator is only compatible with FP8 Blocked Scaled version right now."); + static_assert(IsFP8Input, 
"Warp Specialized gemm with FP8 Blockwise (Software) Scaling is only compatible with FP8 inputs version right now."); // For fp32 types, map to tf32 MMA value type using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; @@ -1146,8 +1146,8 @@ struct CollectiveBuilder< static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale(StageCountType{}); using DispatchPolicy = cute::conditional_t, - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8>; + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise, + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8>; using SmemCopyAtomA = void; using SmemCopyAtomB = void; diff --git a/include/cutlass/gemm/collective/builders/xe_mma_builder.inl b/include/cutlass/gemm/collective/builders/xe_mma_builder.inl index c2ffaa5a5f..4afe206853 100644 --- a/include/cutlass/gemm/collective/builders/xe_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/xe_mma_builder.inl @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,113 +39,31 @@ namespace cutlass::gemm::collective { - // Intel Xe 3 stage pipeline, using prefetch - // Also the auto builder - -template -constexpr auto get_num_atoms(T_m tile_m, T_n tile_n){ - constexpr auto atom_m = get<0>(typename MMAAtom::Shape_MNK{}); - constexpr auto atom_n = get<1>(typename MMAAtom::Shape_MNK{}); - // try to create the biggest number of atoms possible, up to 32, trying to fit the most, up to 8 in m dimension - auto atoms_m_tmp = cute::min(tile_m / atom_m, _8{}); // at most 8 - auto atoms_n = cute::min(tile_n / atom_n, _32{} / atoms_m_tmp); // at most however many are not in m out of 32 - auto atoms_m = cute::min(tile_m / atom_m, _32{} / atoms_n); // at most however many are not in n out of 32 - return make_shape(atoms_m, atoms_n); +// Choose an appropriate XE_DPAS_TT overload for a given +// combination of types. 
+template +auto +xe_dpas_tt_op_selector() { + if constexpr (is_complete_v>) + return XE_DPAS_TT{}; + else if constexpr (is_same_v) + return XE_DPAS_TT{}; + else /* Use f16 by default as upconversion sequences are typically faster */ + return XE_DPAS_TT{}; } -template -constexpr auto select_copy_atom_16b(T_m tile_m, T_n tile_n){ - #define RETURN_ATOM(WIDTH, HEIGHT, LETTER) \ - return XE_2D_U16x##WIDTH##x##HEIGHT##_LD_##LETTER {}; - - if constexpr(is_t){ - // tile_m and tile_n have swapped role in case of _T - static_assert(tile_n % 16 == 0 && "Invalid tile_m"); - if constexpr(tile_m == 8){ - RETURN_ATOM(16, 8, T) - } else if constexpr(tile_m % 16 == 0){ - RETURN_ATOM(16, 16, T) - } else{ - static_assert(dependent_false && "Invalid tile_n"); - } - } else if constexpr(is_v){ - #define SELECT_HEIGHT_V(WIDTH) \ - if constexpr(tile_n == 16){ \ - RETURN_ATOM(WIDTH, 16, V) \ - } else if constexpr(tile_n % 32 == 0){ \ - RETURN_ATOM(WIDTH, 32, V) \ - } else{ \ - static_assert(dependent_false && "Invalid tile_n"); \ - } - - if constexpr(tile_m == 16){ - SELECT_HEIGHT_V(16) - } else if constexpr(tile_m % 32 == 0){ - SELECT_HEIGHT_V(32) - } else{ - static_assert(dependent_false && "Invalid tile_m"); - } - #undef SELECT_HEIGHT_V - } else{ // _N - #define SELECT_WIDTH_N(HEIGHT) \ - if constexpr(tile_m == 1){ \ - RETURN_ATOM(1, HEIGHT, N) \ - } else if constexpr(tile_m == 2){ \ - RETURN_ATOM(2, HEIGHT, N) \ - } else if constexpr(tile_m == 4){ \ - RETURN_ATOM(4, HEIGHT, N) \ - } else if constexpr(tile_m == 8){ \ - RETURN_ATOM(8, HEIGHT, N) \ - } else if constexpr(tile_m == 16){ \ - RETURN_ATOM(16, HEIGHT, N) \ - } else if constexpr(tile_m % 32 == 0){ \ - RETURN_ATOM(32, HEIGHT, N) \ - } else { \ - static_assert(dependent_false && "Invalid tile_m"); \ - } - - if constexpr(tile_n == 16){ - SELECT_WIDTH_N(16) - } else if constexpr(tile_n % 32 == 0){ - SELECT_WIDTH_N(32) - } else { - static_assert(dependent_false && "Invalid tile_n"); - } - #undef SELECT_WIDTH_N - } - #undef RETURN_ATOM -} - -namespace { -template -struct pick_mma_atom{ - static_assert(dependent_false && "no mma atom for this combination of types"); -}; - -#define PICK_MMA(ElementAB, ElementCD, ATOM) \ -template <> struct pick_mma_atom { \ - using atom = MMA_Atom; \ -}; - -PICK_MMA(bfloat16_t, float, XE_8x16x16_F32BF16BF16F32_TT); -PICK_MMA(bfloat16_t, bfloat16_t, XE_8x16x16_BF16BF16BF16BF16_TT); -PICK_MMA(half_t, float, XE_8x16x16_F32F16F16F32_TT); -PICK_MMA(half_t, half_t, XE_8x16x16_F16F16F16F16_TT); - -#undef PICK_MMA -} template < - class ElementA, - class GmemLayoutATag, - int AlignmentA, - class ElementB, - class GmemLayoutBTag, - int AlignmentB, - class ElementAccumulator, - class TileShape_MNK, - class KernelScheduleType - > + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class KernelScheduleType +> struct CollectiveBuilder< arch::IntelXe, arch::OpClassTensorOp, // Reusing opClassTensorOp for Intel devices @@ -159,81 +78,94 @@ struct CollectiveBuilder< Shape<_1, _1, _1>, // Cluster Shape cutlass::gemm::collective::StageCountAuto, KernelScheduleType, - cute::enable_if_t< - cute::is_any_of_v && - cute::is_any_of_v && - cute::is_any_of_v - > - >{ - - #ifdef SYCL_NVIDIA_TARGET - static_assert(cutlass::detail::dependent_false, - "Trying to use Intel pipeline on Non Intel hardware"); - #endif - static_assert(is_static::value); - static_assert(cute::is_any_of_v, - "Intel multi-stage pipeline requires ElementC to be of 
type float, bfloat or half"); - - static constexpr bool isAtypeBig = cute::sizeof_bits_v > cute::sizeof_bits_v; - using MMAType = std::conditional_t; - using MMAAtom = typename pick_mma_atom::atom; - - static constexpr auto tile_M = get<0>(TileShape_MNK{}); - static constexpr auto tile_N = get<1>(TileShape_MNK{}); - static constexpr auto tile_K = get<2>(TileShape_MNK{}); - - static constexpr auto n_atoms = get_num_atoms(tile_M, tile_N); - using atoms_M = decltype(get<0>(n_atoms)); - using atoms_N = decltype(get<1>(n_atoms)); - using TiledMma = - typename TiledMMAHelper, - Layout, Stride>>::TiledMMA; - - static constexpr bool IsGroup = cute::is_same_v; - - using KernelSchedule = std::conditional_t, KernelXe, KernelScheduleType>; - static constexpr int PipelineStages = IsGroup ? 2 : 3; - using DispatchPolicy = std::conditional_t, - std::conditional_t, cutlass::gemm::MainloopIntelXeXMX16Group, - cutlass::gemm::MainloopIntelXeXMX16GroupMixedPrecision>>; - - static constexpr bool isAtransposed = cute::is_same_v; - static constexpr bool isBtransposed = cute::is_same_v; - using GmemTiledCopyA = std::conditional_t == 8, std::conditional_t, - decltype(select_copy_atom_16b(tile_M/atoms_M{}, tile_K))>; - using GmemTiledCopyB = std::conditional_t == 4, std::conditional_t, - std::conditional_t == 8, std::conditional_t, - decltype(select_copy_atom_16b(tile_K, tile_N/atoms_N{}))>>; - - // Xe pipeline does not use shared memory - using SmemLayoutAtomA = void; - using SmemLayoutAtomB = void; - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; - - using TransformA = cute::identity; - using TransformB = cute::identity; - - using ElementA_ = std::conditional_t <= 8, cute::tuple, ElementA>; - using ElementB_ = std::conditional_t <= 8, cute::tuple, ElementB>; - - using CollectiveOp = cutlass::gemm::collective::CollectiveMma< - DispatchPolicy, - TileShape_MNK, - ElementA_, - cutlass::gemm::TagToStrideA_t>, - ElementB_, - cutlass::gemm::TagToStrideB_t>, - TiledMma, - GmemTiledCopyA, - SmemLayoutAtomA, - SmemCopyAtomA, - TransformA, - GmemTiledCopyB, - SmemLayoutAtomB, - SmemCopyAtomB, - TransformB - >; - }; -} + cute::enable_if_t> +> { +#ifdef SYCL_NVIDIA_TARGET + static_assert(cutlass::detail::dependent_false, + "Trying to use Xe pipeline on non-Xe hardware"); +#endif + static_assert(is_static::value); + + using DPAS_M = decltype(cute::gcd(_8{}, get<0>(TileShape_MNK{}))); + using MMAAtom = MMA_Atom())>; + + static constexpr auto MMAAtomGrid = shape_div(TileShape_MNK{}, typename MMAAtom::Shape_MNK{}); + + // Choose subgroup configuration. + static constexpr int MaxSG = 32; + static constexpr int SG_M0 = cute::min(get<0>(MMAAtomGrid), 8); + static constexpr int SG_N = cute::min(get<1>(MMAAtomGrid), MaxSG / SG_M0); + static constexpr int SG_M = cute::min(get<0>(MMAAtomGrid), MaxSG / SG_N); + + using TiledMMA = + typename TiledMMAHelper, + Layout, C, _1>, Stride, _1, _0>>>::TiledMMA; + + using KernelSchedule = std::conditional_t, KernelXe, KernelScheduleType>; + + static constexpr bool IsGroup = cute::is_same_v; + static constexpr int PipelineStages = IsGroup ? 
2 : 3; + using DispatchPolicy = std::conditional_t, + cutlass::gemm::MainloopXeL1StagedGroup>; + + using GmemTiledCopyA = void; // autoselect + using GmemTiledCopyB = void; // autoselect + + using SmemLayoutAtomA = void; // No shared memory usage + using SmemLayoutAtomB = void; + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using TransformA = cute::identity; + using TransformB = cute::identity; + + using ElementA_ = ElementA; + using ElementB_ = ElementB; + + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA_, + cutlass::gemm::TagToStrideA_t>, + ElementB_, + cutlass::gemm::TagToStrideB_t>, + TiledMMA, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + TransformA, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + TransformB + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Forward Xe12/Xe20 builders to IntelXe +///////////////////////////////////////////////////////////////////////////////////////////////// + +#define CUTLASS_FORWARD_XE_MMA_BUILDER(Arch) \ +template \ +struct CollectiveBuilder, \ + cutlass::gemm::collective::StageCountAuto, KernelScheduleType> \ + : CollectiveBuilder, \ + cutlass::gemm::collective::StageCountAuto, KernelScheduleType> {}; + +CUTLASS_FORWARD_XE_MMA_BUILDER(Xe12) +CUTLASS_FORWARD_XE_MMA_BUILDER(Xe20) + +#undef CUTLASS_FORWARD_XE_MMA_BUILDER + +} // namespace cutlass::gemm::collective + diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index 89192a95f6..2fe4f5cfbf 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -49,6 +49,8 @@ #include "cutlass/gemm/collective/builders/sm100_simt_builder.inl" #include "cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl" diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index b54776ece5..4e1656c2eb 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -65,6 +65,8 @@ #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp" #include "cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp" #include "cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp" #include "cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp" #include "cutlass/gemm/collective/sm120_mma_tma.hpp" @@ -78,8 +80,10 @@ #if defined(SYCL_INTEL_TARGET) #include "cutlass/gemm/collective/xe_mma.hpp" +#include "cutlass/gemm/collective/xe_mma_legacy.hpp" #include "cutlass/gemm/collective/xe_array_mma.hpp" -#include 
"cutlass/gemm/collective/xe_array_mma_fp8.hpp" +#include "cutlass/gemm/collective/xe_array_mma_legacy.hpp" +#include "cutlass/gemm/collective/xe_array_mma_fp8_legacy.hpp" #include "cutlass/gemm/collective/xe_mma_mixed_input.hpp" #include "cutlass/gemm/collective/xe_array_mma_mixed_input.hpp" #include "cutlass/gemm/collective/xe_mma_w8a8.hpp" diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp new file mode 100644 index 0000000000..9e9959aec4 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -0,0 +1,1043 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> { + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + static_assert(size(typename TiledMma::AtomThrID{}) == 1); + + using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster + using TileShape = TileShape_; + using TiledMma_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + static_assert(!IsOverlappingAccum, "TMA+CPASYNC kernel currently only supports non-overlapping accum."); + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = 
decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or + shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 64/128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static constexpr int IsCtaN64 = shape<1>(CtaShape_MNK{}) == 64; + static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{}; + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + + using ElementA = remove_cvref_t(ElementPairA{}))>; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = remove_cvref_t(StridePairA{}))>; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = remove_cvref_t(StridePairB{}))>; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync; + using 
MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState; + + using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState; + + // static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. 
+ // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + + using TmaInternalElementA = cute::conditional_t; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage; + using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage; + + struct PipelineStorage : cute::aligned_struct<16, _0> { + alignas(16) PipelineStorageTMA tma; + alignas(16) PipelineStorageCpAsync cpasync; + } pipelines; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
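+  // (The TMA and cp.async mainloop pipelines keep separate barrier storage; see PipelineStorage above.)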
+ using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + static constexpr uint32_t ATmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = ATmaTransactionBytes + SFTransactionBytes; + + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_SFA = decltype(make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_SFB = decltype(make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMma_SF{}, + ClusterLayoutSfb_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) + : layout_SFA_(params.layout_SFA) + , layout_SFB_(params.layout_SFB) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + auto cluster_layout_vmnk = 
tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma_SF::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + TileShape_SF{}, + TiledMma_SF{}, + cluster_layout_sfb_vmnk); + + return { + tma_load_a, + tma_load_sfa, + tma_load_sfb, + args.ptr_B, + args.dB, + args.layout_SFA, + args.layout_SFB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + // static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + + // Check for SFA SFB layout requirement + const auto layout_sfa_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + const auto layout_sfb_ref = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + implementable = implementable && (layout_sfa_ref == args.layout_SFA); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFA mismatch, layout_SFA needs to be K-major\n"); + } + + implementable = implementable && (layout_sfb_ref == args.layout_SFB); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: layout_SFB mismatch, layout_SFB needs to be K-major\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfa_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_sfb_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + 
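+    // Build TMEM-backed tensors for the double-buffered accumulators and the SFA/SFB copies;
+    // their TMEM offsets are assigned later by set_tmem_offsets().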
TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); + } + + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init_tma( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA_)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else if constexpr (IsCtaN64) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); + } + }(); + + Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + + ThrMMA cta_mma = TiledMma{}.get_slice(0); + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // 
(MMA,MMA_M,MMA_K,PIPE) + + ThrMMA cta_mma_sfb = TiledMma_SF{}.get_slice(BlockIdxX() % size(typename TiledMma_SF::AtomThrID{})); + Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(0); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl)); + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + + return cute::make_tuple( + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tAsA, // for input tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB // for input scale factor tensor values + ); + } + + template + CUTLASS_DEVICE auto + load_init_cpasync( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TileScheduler const& scheduler, + typename TileScheduler::WorkTileInfo const& work_tile_info) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // convert to subptr iterator if necessary + auto ptr_B = recast_ptr(params.ptr_B); + // Represent the full tensors + Tensor mB_nkl = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), params.dB); //(n,k,l) + // Partition for cpasync + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), LoadSmemLayoutB{}); + + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = ThreadIdxX() % NumLoadThreadsCpAsync; + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. 
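+  /// Returns a tuple of the TiledMma (with runtime instruction-descriptor formats applied when
+  /// IsRuntimeDataType), the SMEM-descriptor fragments tCrA/tCrB, the TMEM scale-factor tensors
+  /// tCtSFA/tCtSFB, and the UTCCP S2T tiled copies together with their partitioned
+  /// SMEM-source / TMEM-destination views for SFA and SFB.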
+ template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TmemStorage tmem_storage, + // [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + // + // Scale Factor + // + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. + auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. 
+ tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, + tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t + + // debug + // , sA, sB, tCsSFA, tCsSFB + ); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + // class KTileCount, + // class GTensorPartitionedA, + // class STensorA, + class TileCoordMNKL, + class KTileIterator, + class... TLoadParams // see load_init_tma + > + CUTLASS_DEVICE auto + load_tma( + MainloopPipelineTMA mainloop_pipeline, + MainloopPipelineTMAState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + // Unpack from load_inputs + // KTileCount k_tiles = get<0>(load_inputs); + // GTensorPartitionedA tAgA_mkl = get<1>(load_inputs); + // STensorA tAsA = get<2>(load_inputs); + + auto [k_tiles, + tAgA_mkl, tAsA, + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + template < + // class GTensorB, + // class CTensorB, + // class STensorB, + // class ProblemShape_MNKL, + // class TiledCopyB, + // class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator, + class ProblemShape_MNKL, + class... 
TParams + > + CUTLASS_DEVICE auto + load_cpasync( + Params const& params, + MainloopPipelineCpAsync mainloop_pipeline, + MainloopPipelineCpAsyncState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + ProblemShape_MNKL effective_shape + ) { + + // Unpack from load_inputs + // GTensorB tBgB_nkl = get<0>(load_inputs); + // CTensorB cgB_nk = get<1>(load_inputs); + // STensorB sB = get<2>(load_inputs); + // ProblemShape_MNKL problem_shape_MNKL = get<3>(load_inputs); + // TiledCopyB gmem_to_smem_b_tiled_copy = get<4>(load_inputs); + // ThreadCopyB thr_copy_b = get<5>(load_inputs); + + // auto [M,N,K,L] = problem_shape_MNKL; + + auto [ + tBgB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs; + + auto [M,N,K,L] = effective_shape; + + // Slice out the work coord from partitioned tensors + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + // Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative + + Tensor gB = gB_in; + Tensor cB = cgB_nk_in; + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for n + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tBgB + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // we will process the last tile after the mainloop + if (k_residue != 0) { + --k_tile_count; + } + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + --k_tile_count; + ++k_tile_iter; + ++mainloop_pipe_producer_state; + } + + // last tile with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + return 
cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + CUTLASS_DEVICE void + load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) { + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class CtaTileCoord, + class... TMmaParams + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB, tCtSFA, tCtSFB, + tiled_copy_s2t_SFA, thr_tCsSFA_s2t, + thr_tCtSFA_s2t, tiled_copy_s2t_SFB, + thr_tCsSFB_s2t, thr_tCtSFB_s2t + + // debug + // , sA, sB, tCsSFA, tCsSFB + ] = mma_inputs; + + auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB; + if (size<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else if constexpr (IsCtaN64) { + // Move in increments of 64 columns of SFB + auto tCtSFB_tmp = tCtSFB; + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + (size<1>(cta_tile_coord) % 2) * 2; + return tCtSFB_tmp; + } + else { + return tCtSFB; + } + }(); + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); + + int read_stage_tma = mainloop_pipe_tma_consumer_state.index(); + int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index(); + + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage_tma), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage_tma), thr_tCtSFB_s2t); + } + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + 
tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage_tma), + tCrB(_,_,k_block,read_stage_cpasync), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + --k_tile_count; + ++mainloop_pipe_tma_consumer_state; + ++mainloop_pipe_cpasync_consumer_state; + } + + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + + LayoutSFA layout_SFA_; + LayoutSFB layout_SFB_; + + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + // ClusterShape cluster_shape_; + // uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp new file mode 100644 index 0000000000..cbd6f1bcd7 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -0,0 +1,758 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +#include "cutlass/gemm/collective/collective_mma_decl.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + static_assert(size(typename TiledMma::AtomThrID{}) == 1); + + using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster + using TileShape = TileShape_; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using 
ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync; + using MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState; + + using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState; + + // static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. 
+ // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage; + using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage; + + struct PipelineStorage : cute::aligned_struct<16, _0> { + alignas(16) PipelineStorageTMA tma; + alignas(16) PipelineStorageCpAsync cpasync; + } pipelines; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. 
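+  // Note that PipelineStorage carries two independent pieces of pipeline state: one for the TMA
+  // pipeline that produces A and one for the cp.async pipeline that produces B, since the two
+  // operands are staged by different producer mechanisms in this mixed mainloop.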
+ using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) + : runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + + observed_tma_load_a_ = ¶ms.tma_load_a; + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + + auto cluster_layout_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + return { + tma_load_a, + args.ptr_B, + args.dB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance 
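+  /// Only the A-operand descriptor exists in this mainloop; B is staged with cp.async and
+  /// therefore has no TMA descriptor to prefetch.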
+ CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init_tma( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + + ThrMMA cta_mma = TiledMma{}.get_slice(0); + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(ClusterShape{}); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + return cute::make_tuple( + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tAsA // for input tensor values + ); + } + + template + CUTLASS_DEVICE auto + load_init_cpasync( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TileScheduler const& scheduler, + typename TileScheduler::WorkTileInfo const& work_tile_info) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), make_shape(N,K,L), params.dB); //(n,k,l) + // Partition for cpasync + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + 
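+    // These identity tensors carry (n,k) coordinates rather than data; load_cpasync() partitions
+    // them exactly like gB so each thread can predicate its cp.async copies against the true N
+    // extent and the K residue of the current problem.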
Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{}); + + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = ThreadIdxX() % NumLoadThreadsCpAsync; + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gB_nkl, cgB_nk, sB, + gmem_to_smem_b_tiled_copy, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] TmemStorage tmem_storage, + // [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class KTileCount, + class GTensorPartitionedA, + class STensorA, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_tma( + MainloopPipelineTMA mainloop_pipeline, + MainloopPipelineTMAState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + // Unpack from load_inputs + KTileCount k_tiles = get<0>(load_inputs); + GTensorPartitionedA tAgA_mkl = get<1>(load_inputs); + STensorA tAsA = get<2>(load_inputs); + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + } + + --k_tile_count; + ++k_tile_iter; + } + + return 
cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + template < + // class GTensorB, + // class CTensorB, + // class STensorB, + // class ProblemShape_MNKL, + // class TiledCopyB, + // class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator, + class ProblemShape_MNKL, + class... TParams + > + CUTLASS_DEVICE auto + load_cpasync( + Params const& params, + MainloopPipelineCpAsync mainloop_pipeline, + MainloopPipelineCpAsyncState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + ProblemShape_MNKL effective_shape + ) { + + // Unpack from load_inputs + // GTensorB tBgB_nkl = get<0>(load_inputs); + // CTensorB cgB_nk = get<1>(load_inputs); + // STensorB sB = get<2>(load_inputs); + // ProblemShape_MNKL problem_shape_MNKL = get<3>(load_inputs); + // TiledCopyB gmem_to_smem_b_tiled_copy = get<4>(load_inputs); + // ThreadCopyB thr_copy_b = get<5>(load_inputs); + + auto [ + tBgB_nkl, cgB_nk, sB, + // problem_shape_MNKL, + gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs; + + auto [M,N,K,L] = effective_shape; + + // Slice out the work coord from partitioned tensors + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + // Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative + + Tensor gB = gB_in; + Tensor cB = cgB_nk_in; + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for n + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tBgB + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // we will process the last tile after the mainloop + if (k_residue != 0) { + --k_tile_count; + } + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + --k_tile_count; + ++k_tile_iter; + ++mainloop_pipe_producer_state; + } + + // last tile with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + 
--k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + CUTLASS_DEVICE void + load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) { + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::tuple> const& accumulators_pair, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; + auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); + + int read_stage_tma = mainloop_pipe_tma_consumer_state.index(); + int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index(); + + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage_tma), tCrB(_,_,k_block,read_stage_cpasync), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + --k_tile_count; + ++mainloop_pipe_tma_consumer_state; + ++mainloop_pipe_cpasync_consumer_state; + } + + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + +}; + 
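+// ---------------------------------------------------------------------------------------------
+// Usage sketch (illustrative only, not part of the collective). A kernel layer drives this mixed
+// TMA/cp.async mainloop roughly as below; the local names (collective, pipe_tma, acc_pipe, ...)
+// are hypothetical, but the member functions and their argument order match the class above.
+//
+//   auto load_inputs_tma     = collective.load_init_tma(problem_shape_MNKL, shared_tensors);
+//   auto load_inputs_cpasync = collective.load_init_cpasync(problem_shape_MNKL, params,
+//                                                           shared_tensors, scheduler, work_tile_info);
+//   // Producer side: one warp issues TMA loads for A, the cp.async threads load B,
+//   // each feeding its own pipeline.
+//   auto [tma_state_next, k_iter_a] =
+//       collective.load_tma(pipe_tma, tma_state, load_inputs_tma, cta_coord_mnkl, k_iter, k_tiles);
+//   auto [cpasync_state_next, k_iter_b] =
+//       collective.load_cpasync(params, pipe_cpasync, cpasync_state, load_inputs_cpasync,
+//                               cta_coord_mnkl, k_iter, k_tiles, problem_shape_MNKL);
+//   // Consumer side: the MMA warp waits on both pipelines each k-tile and issues UMMA.
+//   auto mma_inputs = collective.mma_init(params, tmem_storage, shared_tensors);
+//   collective.mma(cute::make_tuple(pipe_tma, pipe_cpasync, acc_pipe),
+//                  cute::make_tuple(tma_consumer_state, cpasync_consumer_state, acc_producer_state),
+//                  collective.slice_accumulator(tmem_storage, /*stage=*/0),
+//                  mma_inputs, cta_tile_coord, k_tiles);
+// ---------------------------------------------------------------------------------------------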
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp index eec265a13f..ff57f4ff76 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp @@ -250,7 +250,7 @@ struct CollectiveMma< using Load2MmaPipeline = cutlass::PipelineTmaUmmaAsync< DispatchPolicy::Load2TransformPipelineStageCount, - ClusterShape, + ClusterShape, AtomThrShapeMNK>; using Load2MmaPipelineState = typename Load2MmaPipeline::PipelineState; @@ -318,7 +318,7 @@ struct CollectiveMma< using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( SmemLayoutAtomACompute{}, - append(CtaShapeA_MK{}, Int{}), + append(CtaShapeA_MK{}, Int{}), (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( @@ -387,7 +387,7 @@ struct CollectiveMma< struct TensorStorageUntransformed { alignas(512) cute::ArrayEngine> smem_A; - cute::ArrayEngine> smem_B; + alignas(1024) cute::ArrayEngine> smem_B; cute::ArrayEngine smem_scale; cute::ArrayEngine smem_zero; }; diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index bbfd357c88..9202ba4ef2 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -73,7 +73,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling, + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise, TileShape_, ElementA_, StridePairA_, @@ -92,7 +92,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling; + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = cute::tuple_element_t<0,StridePairA_>; @@ -382,8 +382,6 @@ struct CollectiveMma< auto [M,N,K,L] = problem_shape_MNKL; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); - // We expect full tiles in K - implementable = implementable && K % size<2>(TileShape{}) == 0; } } @@ -824,16 +822,13 @@ struct CollectiveMma< // Prologue GMMAs tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // fence_operand(); GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA)); - warpgroup_fence_operand(accumulation()); - { + if (k_tile_count > 0) { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_read, barrier_token); int read_stage = 
smem_pipe_read.index(); // Load per block scale values from shared memory to registers @@ -977,7 +972,7 @@ struct CollectiveMma< ++smem_pipe_release; } - if (k_tile_count) { + if (k_tile_count > 0) { pipeline.consumer_wait(smem_pipe_read, barrier_token); // @@ -1072,9 +1067,11 @@ struct CollectiveMma< /// Perform a Consumer Epilogue to release all buffers CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { - // The pipeline is not released in the first iteration - smem_pipe_release.advance(k_tile_count - 1); - pipeline.consumer_release(smem_pipe_release); + if (k_tile_count > 0) { + // The pipeline is not released in the first iteration + smem_pipe_release.advance(k_tile_count - 1); + pipeline.consumer_release(smem_pipe_release); + } } // diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index 7be1d5ab57..440f286a58 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -73,7 +73,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8, + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8, TileShape_, ElementA_, StridePairA_, @@ -91,7 +91,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = cute::tuple_element_t<0,StridePairA_>; @@ -391,12 +391,6 @@ struct CollectiveMma< implementable = false; CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale B.\n"); } - - // We expect full tiles in K - if (K % size<2>(TileShape{}) != 0) { - implementable = false; - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size K is incompatible with tile size.\n"); - } return implementable; } diff --git a/include/cutlass/gemm/collective/xe_array_mma.hpp b/include/cutlass/gemm/collective/xe_array_mma.hpp index 102a431008..574f38f3e5 100644 --- a/include/cutlass/gemm/collective/xe_array_mma.hpp +++ b/include/cutlass/gemm/collective/xe_array_mma.hpp @@ -47,63 +47,30 @@ using namespace cute; template -struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, +struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, - SmemCopyAtomB_, TransformB_> { + SmemCopyAtomB_, TransformB_> + : public CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_> +{ // // Type Aliases // - using DispatchPolicy = MainloopIntelXeXMX16Group; - using WorkgroupTileShape = TileShape_; + using Base = CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_>; + using BaseArguments = typename Base::Arguments; + using BaseParams = typename Base::Params; + + using 
DispatchPolicy = MainloopXeL1StagedGroup; using ElementA = ElementA_; using StrideA = StrideA_; - using InternalStrideA = cute::remove_pointer_t; + using InternalStrideA = typename Base::StrideA; using ElementB = ElementB_; using StrideB = StrideB_; - using InternalStrideB = cute::remove_pointer_t; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - - static_assert(platform::is_same::value, "MainloopIntelXeXMX16Array requires that A and B have same type."); - - static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); - static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); - - static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - - using MmaAtomShape = typename TiledMma::AtomShape_MNK; - - static constexpr int BLK_M = get<0>(WorkgroupTileShape{}); - static constexpr int BLK_N = get<1>(WorkgroupTileShape{}); - static constexpr int BLK_K = get<2>(WorkgroupTileShape{}); + using InternalStrideB = typename Base::StrideB; - static constexpr int ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr int ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr int ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); - - static constexpr int SG_M = ceil_div(BLK_M, ATOM_M); - static constexpr int SG_N = ceil_div(BLK_N, ATOM_N); - static constexpr int SG_K = ceil_div(BLK_K, ATOM_K); - using SubgroupTileShape = Shape, C, C>; - - static constexpr int Num_SGs = ATOM_N * ATOM_M * ATOM_K; - static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); - - using Copy_A = typename Copy_Traits::template DefaultTiledCopy; - using Copy_B = typename Copy_Traits::template DefaultTiledCopy; - - using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideA{})); //(m, k) - using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideB{})); //(n, k) - using MainloopTensors = cute::tuple; // Host side kernel arguments struct Arguments { ElementA const** ptr_A; @@ -112,12 +79,7 @@ struct CollectiveMma, TileShape_, El StrideB dB; }; - struct Params { - ElementA const** ptr_A; - StrideA dA; - ElementB const** ptr_B; - StrideB dB; - }; + using Params = Arguments; // // Methods @@ -129,18 +91,7 @@ struct CollectiveMma, TileShape_, El static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; - - auto problem_shape_MNK = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1));; - auto init_M = get<0>(problem_shape_MNK); - auto init_N = get<1>(problem_shape_MNK); - auto init_K = get<2>(problem_shape_MNK); - - return Params{ - args.ptr_A, - args.dA, - args.ptr_B, - args.dB - }; + return args; } template @@ -179,120 +130,12 @@ struct CollectiveMma, TileShape_, El return implementable; } - /// Perform a subgroup-scoped matrix multiply-accumulate - template - CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, FrgTensorC const &src_accum, - KTileIterator 
k_tile_iter, int const& k_tile_count, - BlkCoord const &blk_coord, int const &K_start, int const& thread_idx, - Params const &mainloop, LoadTensors const& load_tensors) { - static_assert(is_rmem::value, "D tensor must be rmem resident."); - static_assert(is_rmem::value, "C tensor must be rmem resident."); - - (void)thread_idx; - - Copy_A tiled_copy_a{Copy_A{}.with(get<0>(load_tensors))}; - Copy_B tiled_copy_b{Copy_B{}.with(get<1>(load_tensors))}; - - auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); - auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); - - // Instantiate the MMA object and get thread slice - TiledMma tiled_mma; - // TODO(Codeplay): see if we can make this nicer - // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup - auto sg = compat::get_nd_item<1>().get_sub_group(); - auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; - auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); - - // Partition global counting tensors for MMA - Tensor tCgA = thr_mma.partition_A(gA); - Tensor tCgB = thr_mma.partition_B(gB); - - Tensor tCrA = make_tensor(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape())); - Tensor tCrB = make_tensor(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape())); - - // Retile registers for copies - Tensor tArA = thr_copy_A.retile_D(tCrA); - Tensor tBrB = thr_copy_B.retile_D(tCrB); - - // Retile global counting tensors for copies - Tensor tAgA = thr_copy_A.retile_S(tCgA); - Tensor tBgB = thr_copy_B.retile_S(tCgB); - - auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_a); - auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_b); - auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); - auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); - - // Partition global tile for prefetch - auto pAgA = thr_prefetch_A.partition_S(gA); - auto pBgB = thr_prefetch_B.partition_S(gB); - -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { - print("======================= A: \n"); - print(" gA : "); print(gA); print("\n"); - print("tCgA : "); print(tCgA); print("\n"); - print("tAgA : "); print(tAgA); print("\n"); - - print("===================== B :\n"); - print(" gB : "); print(gB); print("\n"); - print("tCgB : "); print(tCgB); print("\n"); - print("tBgB : "); print(tBgB); print("\n"); - - print("===================== Config: \n"); - print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n"); - print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n"); - } -#endif - - // - // Mainloop - // - const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); - constexpr int barrier_scope = 2; - int prefetch_k = k_start_idx; - - CUTLASS_PRAGMA_UNROLL - for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) { - prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); - prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); - } - - for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { - barrier_arrive(barrier_scope); - // Copy gmem to rmem for the first k_tile - copy(tiled_copy_a, tAgA(_,_,_,k_tile), tArA); - copy(tiled_copy_b, tBgB(_,_,_,k_tile), tBrB); - - if (prefetch_k < k_tile_count) { - prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); - prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); - } - - cute::gemm(tiled_mma, tCrA, tCrB, accum); - barrier_wait(barrier_scope); - } + CUTLASS_DEVICE 
static constexpr BaseArguments + to_base_arguments(Arguments const &args, int idx) { + return BaseArguments{ args.ptr_A[idx], args.dA[idx], + args.ptr_B[idx], args.dB[idx]}; } - template - CUTLASS_DEVICE auto update_tensor_shape_stride( - Params const& mainloop_params, - int32_t const& next_group, - ProblemShape_MNKL const& problem_shape_mnkl) { - const int32_t M = get<0>(problem_shape_mnkl); - const int32_t N = get<1>(problem_shape_mnkl); - const int32_t K = get<2>(problem_shape_mnkl); - - ElementA const* ptr_A_curr_batch = reinterpret_cast(mainloop_params.ptr_A[next_group]); - ElementB const* ptr_B_curr_batch = reinterpret_cast(mainloop_params.ptr_B[next_group]); - - Tensor mA = make_tensor(make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K,(int32_t)1), mainloop_params.dA[next_group]); - Tensor mB = make_tensor(make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K,(int32_t)1), mainloop_params.dB[next_group]); - - return cute::make_tuple(mA, mB); - } }; } // namespace cutlass::gemm::collective diff --git a/include/cutlass/gemm/collective/xe_array_mma_fp8_legacy.hpp b/include/cutlass/gemm/collective/xe_array_mma_fp8_legacy.hpp new file mode 100644 index 0000000000..d71206e868 --- /dev/null +++ b/include/cutlass/gemm/collective/xe_array_mma_fp8_legacy.hpp @@ -0,0 +1,192 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cutlass/fp8_to_fp16.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_> + : public CollectiveMma< + MainloopIntelW8A8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> + { + // + // Type Aliases + // + using Base = CollectiveMma< MainloopIntelW8A8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_>; + using BaseArguments = typename Base::Arguments; + using BaseParams = typename Base::Params; + + using DispatchPolicy = MainloopIntelXeXMX16GroupFP8; + using WorkgroupTileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + using Params = Arguments; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + const int32_t mock_L = 1; + auto problem_shape_MNK = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, mock_L);; + auto init_M = get<0>(problem_shape_MNK); + auto init_N = get<1>(problem_shape_MNK); + auto init_K = get<2>(problem_shape_MNK); + + return Params{ + args.ptr_A, + args.dA, + args.ptr_B, + args.dB + }; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + + constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits::value; + constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_A = 
batch_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable &= cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + + if (L > 1) { + implementable &= get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; + implementable &= get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + return implementable; + } + + CUTLASS_DEVICE static constexpr BaseArguments + to_base_arguments(Arguments const &args, int idx) { + return BaseArguments{ args.ptr_A[idx], args.dA[idx], + args.ptr_B[idx], args.dB[idx]}; + } +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/xe_array_mma_legacy.hpp b/include/cutlass/gemm/collective/xe_array_mma_legacy.hpp new file mode 100644 index 0000000000..b9e548cd33 --- /dev/null +++ b/include/cutlass/gemm/collective/xe_array_mma_legacy.hpp @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_> + : public CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_> +{ + // + // Type Aliases + // + using Base = CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_>; + using BaseArguments = typename Base::Arguments; + using BaseParams = typename Base::Params; + + using DispatchPolicy = MainloopIntelXeXMX16Group; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = typename Base::StrideA; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = typename Base::StrideB; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + using Params = Arguments; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + return args; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + + constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits::value; + constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_A = batch_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable &= cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + + if (L > 1) { + implementable &= get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; + implementable &= get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + return implementable; + } + + CUTLASS_DEVICE static constexpr BaseArguments + to_base_arguments(Arguments const 
&args, int idx) { + return BaseArguments{ args.ptr_A[idx], args.dA[idx], + args.ptr_B[idx], args.dB[idx]}; + } +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/xe_array_mma_mixed_input.hpp b/include/cutlass/gemm/collective/xe_array_mma_mixed_input.hpp index 56161e9a37..334d31aca4 100644 --- a/include/cutlass/gemm/collective/xe_array_mma_mixed_input.hpp +++ b/include/cutlass/gemm/collective/xe_array_mma_mixed_input.hpp @@ -78,173 +78,67 @@ struct CollectiveMma< GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, - TransformB_> + TransformB_> : public CollectiveMma, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { -private: - enum class ConversionMode { - DirectConvert, - ConvertAndScale, - ConvertAndScaleWithZero - }; - public: // // Type Aliases // - using DispatchPolicy = MainloopIntelXeXMX16GroupMixedPrecision; - using WorkgroupTileShape = TileShape_; - - - static_assert(cute::is_tuple::value ^ cute::is_tuple::value, - "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," - "[ElementZero]}. Inputs in [] are optional."); - - using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; - using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + using Base = CollectiveMma, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_>; + using BaseArguments = typename Base::Arguments; + using BaseParams = typename Base::Params; + + using ElementA = typename Base::ElementA; + using ElementB = typename Base::ElementB; + using ElementScale = typename Base::ElementScale; + using ElementZero = typename Base::ElementZero; + + using StrideScale = typename Base::StrideScale; + using StrideZero = typename Base::StrideZero; - static constexpr bool IsATransformed = cute::is_tuple::value; - - using ElementMMA = cute::conditional_t; - using ElementQuant = cute::conditional_t; - - using ElementScale = cute::conditional_t, detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>>; - using StrideScale = cute::conditional_t, detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>>; + using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t> *, StrideScale>; + using NonVoidStrideZero = cute::conditional_t, cute::Stride<_1, int64_t, int64_t> *, StrideZero>; - using ElementZero = cute::conditional_t, detail::deduce_mixed_width_dtype_t<3, ElementBOptionalTuple>>; - using StrideZero = cute::conditional_t, detail::deduce_mixed_width_dtype_t<4, ElementBOptionalTuple>>; + using NonVoidElementScale = typename Base::NonVoidElementScale; + using NonVoidElementZero = typename Base::NonVoidElementZero; - // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. 
- using NonVoidElementScale = cute::conditional_t, ElementMMA, ElementScale>; - using NonVoidElementZero = cute::conditional_t, ElementMMA, ElementZero>; + using DispatchPolicy = MainloopIntelXeXMX16GroupMixedPrecision; - using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t> *, StrideScale>; - using NonVoidStrideZero = cute::conditional_t, cute::Stride<_1, int64_t, int64_t> *, StrideZero>; using InternalNonVoidStrideScale = cute::remove_pointer_t; using InternalNonVoidStrideZero = cute::remove_pointer_t; - static constexpr auto zero_elements_packed_along_k = get<0>(InternalNonVoidStrideZero{}); using StrideA = StrideA_; - using InternalStrideA = cute::remove_pointer_t; + using InternalStrideA = typename Base::StrideA; using StrideB = StrideB_; - using InternalStrideB = cute::remove_pointer_t; - - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - using MmaType = typename TiledMma::ValTypeA; // ValTypeA and ValTypeB are always same and reflects MMA type on intel Xe - using LargerElementType = std::conditional_t<(cute::sizeof_bits_v > cute::sizeof_bits_v), - ElementA, - ElementB>; - - static_assert(!cute::is_same_v, "Mixed precision GEMM requires different types for A and B!"); - static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); - static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); - -private: - - static constexpr ConversionMode - get_conversion_mode() { - if constexpr (cute::is_void_v) { - return ConversionMode::DirectConvert; - } - else if constexpr (cute::is_void_v) { - return ConversionMode::ConvertAndScale; - } - else { - return ConversionMode::ConvertAndScaleWithZero; - } - } - - static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); - static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || - KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; - static constexpr bool ModeHasScalesZero = KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; -public: - static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - - using MmaAtomShape = typename TiledMma::AtomShape_MNK; - - static constexpr int BLK_M = get<0>(WorkgroupTileShape{}); - static constexpr int BLK_N = get<1>(WorkgroupTileShape{}); - static constexpr int BLK_K = get<2>(WorkgroupTileShape{}); - - static constexpr int ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr int ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); - static constexpr int ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); - - static constexpr int SG_M = ceil_div(BLK_M, ATOM_M); - static constexpr int SG_N = ceil_div(BLK_N, ATOM_N); - static constexpr int SG_K = ceil_div(BLK_K, ATOM_K); - using SubgroupTileShape = Shape, C, C>; - - using GmemTiledCopyScale = typename scale_zero_copy_traits::type; - using GmemTiledCopyZero = typename scale_zero_copy_traits::type; - - static constexpr int Num_SGs = ATOM_N * ATOM_M * ATOM_K; - static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); - - using 
CopyThreadShape = Shape<_1, Int>; - using CopyThreadShapeRev = decltype(cute::reverse(CopyThreadShape{})); - - using traits_load_A = Copy_Traits; - using atom_load_A = Copy_Atom; - using val_layout_load_A = decltype(make_layout(shape_div(typename traits_load_A::BlockShape{}, CopyThreadShape{}))); - using Copy_A = decltype(make_tiled_copy(atom_load_A{}, Layout{}, val_layout_load_A{})); - - using traits_load_B = Copy_Traits; - using atom_load_B = Copy_Atom; - using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{}))); - using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout{}, val_layout_load_B{})); - - using traits_load_scale = Copy_Traits; - using atom_load_scale = Copy_Atom; - using val_layout_load_scale = decltype(make_layout(shape_div(typename traits_load_scale::BlockShape{}, CopyThreadShapeRev{}))); - using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout{}, val_layout_load_scale{})); - - using traits_load_zero = Copy_Traits; - using atom_load_zero = Copy_Atom; - using val_layout_load_zero = decltype(make_layout(shape_div(typename traits_load_zero::BlockShape{}, CopyThreadShapeRev{}))); - using Copy_Zero = decltype(make_tiled_copy(atom_load_zero{}, Layout{}, val_layout_load_zero{})); - - using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideA{})); //(m, k) - using TensorS = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalNonVoidStrideScale{})); - - // Purpose of this struct is to create a pointer of required type - // for creating the TensorNKL type - template - struct GenPtrType; - - template - struct GenPtrType < 8>> { - static constexpr auto get_pointer() { - // For int4_t type, subbyte_iterator does not accept nullptr, - // so need to create a pointer of type int4_t to get this code - // working. 
- T* ptr; - return cute::subbyte_iterator(ptr); - } - }; - - template - struct GenPtrType >= 8>> { - static constexpr auto get_pointer() { - return make_gmem_ptr(static_cast(nullptr)); - } - }; - - using TensorNKL = decltype(make_tensor(GenPtrType::get_pointer(), make_shape(0,0,0), InternalStrideB{})); //(n, k) - using TensorZ = decltype(make_tensor(GenPtrType::get_pointer(), make_shape(0,0,0), InternalNonVoidStrideScale{})); - using MainloopTensors = cute::tuple; + using InternalStrideB = typename Base::StrideB; // Host side kernel arguments struct Arguments { @@ -259,17 +153,7 @@ struct CollectiveMma< int group_size = 1; }; - struct Params { - ElementA const** ptr_A; - StrideA dA; - ElementB const** ptr_B; - StrideB dB; - NonVoidElementScale const** ptr_S = nullptr; - NonVoidStrideScale dS{}; - NonVoidElementZero const** ptr_Z = nullptr; - NonVoidStrideZero dZ{}; - int group_size; - }; + using Params = Arguments; // // Methods @@ -321,474 +205,15 @@ struct CollectiveMma< return implementable; } - // Helper functions to select packing for conversion - template - struct select_packing { // Naive packing policy - static constexpr auto value() { - return Int, sizeof_bits_v))>{}; - } - }; - - template - CUTLASS_DEVICE typename std::enable_if_t == 4> - transform_quant( - Tensor const& in, - Tensor& out, - Tensor& tCrS_input, - Tensor tCrZ_input - ) { - // TODO (Codeplay): add assert here because int4 is not currently supported - static_assert(!IsATransformed); - - static_assert(is_rmem::value, "Input tensor for conversion must come from registers"); - static_assert(size_v == cosize_v); - static_assert(size_v == cosize_v); - static_assert(std::is_same_v); - - using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; - using ZeroType = typename EngineZeros::value_type; - using ScaleType = typename EngineScales::value_type; - - static constexpr auto DPAS = decltype(size<0>(in))::value; - static constexpr auto N = decltype(size<1>(in))::value; - static constexpr auto K = decltype(size<2>(in))::value; - - using format_type = ushort; - static constexpr auto src_bits = sizeof_bits_v; - static constexpr auto scalar = sizeof_bits_v / src_bits; - static constexpr auto loop_cnt = decltype(size(out))::value / N; - static_assert((scalar % N) == 0); - - // for tuning performance - static constexpr auto vec_size = scalar; - static constexpr auto splits = loop_cnt / vec_size; - static_assert(vec_size <= scalar); - - // reshape tensors for easy access - auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape, Int>{}); - auto d_tensor = make_tensor(out.data(), Shape, Int, Int>{}); - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < N; n++) { - const auto ts = tCrS_input(n); - const auto tz = [&](){ - if constexpr (sizeof_bits_v >= 8) { - return tCrZ_input(n); - } else { - return tCrZ_input(n).get(); - } - }(); - - auto& src = *(cute::array*)(s_tensor(_, n).data()); - - CUTLASS_PRAGMA_UNROLL - for (int s = 0; s < splits; s++) { - auto idx = vec_size * s / scalar; - auto format_data = src[idx]; - - auto& dst = *(cute::array*)(d_tensor(_, s, n).data()); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < vec_size; i++) { - auto data = [&](){ - if constexpr (cutlass::platform::numeric_limits::is_signed) { - return static_cast((format_data >> (src_bits * i)) & 0xf); - } else { - return (format_data >> (src_bits * i)) & 0xf; - } - }(); - - if constexpr (ModeHasScales) { - if constexpr (IsATransformed) { - static_assert(dependent_false && "ATransform not support 
now"); - } else { - using ret_type = cute::conditional_t >= 8, ZeroType, int8_t>; - ret_type minus(data); - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - minus = static_cast(data) - static_cast(tz); - } - dst[i] = (static_cast(minus)) * ts; - } - } else { - dst[i] = static_cast(data); - } - } - } - } + CUTLASS_DEVICE static constexpr BaseArguments + to_base_arguments(Arguments const &args, int idx) { + return BaseArguments{ args.ptr_A[idx], args.dA[idx], + args.ptr_B[idx], args.dB[idx], + args.ptr_S[idx], args.dS[idx], + args.ptr_Z[idx], args.dZ[idx], + args.group_size}; } - /// Utilities to transform A. - template - CUTLASS_DEVICE typename std::enable_if_t >= 8> - transform_quant( - Tensor const& tCrA_load, - Tensor& tCrA_mma, - Tensor& tCrS_input, - Tensor& tCrZ_input - ) { - - static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); - static_assert(size_v == cosize_v); - static_assert(size_v == cosize_v); - static_assert(std::is_same_v); - static_assert(std::is_same_v); - - using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; - - if constexpr(cute::is_any_of_v - && cute::is_any_of_v) { - convert_FP8_to_FP16(make_tensor(reinterpret_cast(tCrA_load.data()), tCrA_load.layout()), tCrA_mma); - } else { - auto const& src = tCrA_load(_, _, _); - auto const& dst = tCrA_mma(_, _, _); - auto pSrc = const_cast(raw_pointer_cast(src.data())); - auto pDst = const_cast(raw_pointer_cast(dst.data())); - constexpr int num_elements = decltype(size(src))::value; - - // TODO(Codeplay): (perf) consider replacing `pack` with `num_elements` here - See xe_flash_attn_mma.hpp - constexpr int pack = decltype(select_packing::value())::value; - using Converter = cutlass::NumericArrayConverter; - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; - constexpr int iters = num_elements / pack; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < iters; ++i) { - SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; - DstArray* pDstArr = reinterpret_cast(pDst) + i; - *pDstArr = Converter::convert(*pSrcArr); - } - } - - if constexpr (ModeHasScales) { - if constexpr(IsATransformed){ - // The current scale load atom (1x32) gives 2 scale values to - // each thread. 
All threads need access to all other threads - // scale values, and each scale value is reused twice (unrolled) - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 16; ++i) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < 2; ++j) { - auto scale = shfl_sync(0xFFFFFFFF, tCrS_input(j), i); - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero){ - auto zero = shfl_sync(0xFFFFFFFF, tCrZ_input(j), i); - tCrA_mma(_, _, 0)[j * 16 + i] -= zero; - tCrA_mma(_, _, 1)[j * 16 + i] -= zero; - } - tCrA_mma(_, _, 0)[j * 16 + i] *= scale; - tCrA_mma(_, _, 1)[j * 16 + i] *= scale; - } - } - } else { - static constexpr auto N = decltype(size<1>(tCrA_load))::value; - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < N; ++n) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < decltype(size(tCrA_load))::value / N; ++i) { - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero){ - tCrA_mma(_, n, _)[i] -= tCrZ_input(n); - } - tCrA_mma(_, n, _)[i] *= tCrS_input(n); - } - } - } - } - } - -template -CUTLASS_DEVICE auto create_copies(LoadTensors const& load_tensors) { - Copy_A tiled_copy_a{Copy_A{}.with(get<0>(load_tensors))}; - Copy_B tiled_copy_b{Copy_B{}.with(get<1>(load_tensors))}; - - if constexpr(KernelConversionMode == ConversionMode::DirectConvert){ - return cute::make_tuple(tiled_copy_a, tiled_copy_b, Copy_Scale{}, Copy_Zero{}); - } - - Copy_Scale tiled_copy_scale{Copy_Scale{}.with(get<2>(load_tensors))}; - - if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale){ - return cute::make_tuple(tiled_copy_a, tiled_copy_b, tiled_copy_scale, Copy_Zero{}); - } - - Copy_Zero tiled_copy_zero{Copy_Zero{}.with(get<3>(load_tensors))}; - return cute::make_tuple(tiled_copy_a, tiled_copy_b, tiled_copy_scale, tiled_copy_zero); -} - - /// Perform a subgroup-scoped matrix multiply-accumulate - template - CUTLASS_DEVICE void - operator() ( - FrgTensorD &accum, - TensorA gA, - TensorB gB, - FrgTensorC const &src_accum, - KTileIterator k_tile_iter, int k_tile_count, - BlkCoord const &blk_coord, - int const &K_start, - int thread_idx, - Params const& mainloop, - LoadTensors const& load_tensors) - { - static_assert(is_rmem::value, "D tensor must be rmem resident."); - static_assert(is_rmem::value, "C tensor must be rmem resident."); - - auto [tiled_copy_a, tiled_copy_b, tiled_copy_scale, tiled_copy_zero] = create_copies(load_tensors); - - // Partition the copying of A and B tiles across the threads - auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); - auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); - auto thr_copy_scale = tiled_copy_scale.get_slice(thread_idx); - auto thr_copy_zero = tiled_copy_zero.get_slice(thread_idx); - - // Instantiate the MMA object and get thread slice - TiledMma tiled_mma; - auto sg = compat::get_nd_item<1>().get_sub_group(); - auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; - auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); - - // Partition - Tensor tCgA = thr_mma.partition_A(gA); - Tensor tCgB = thr_mma.partition_B(gB); - - // Create fragments - Tensor mma_A = make_tensor(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape())); - Tensor mma_B = make_tensor(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape())); - - // If IsATransformed, we need modes M_atom, and M_iter from fragment_A - // layout else we need mode N_iter from fragment_B layout. 
- static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize; - static constexpr auto scale_traits_num = SG_N / size<1>(typename GmemTiledCopyScale::BlockShape{}); - using FragScaleLayout = std::conditional_t>, - Layout, Int, _1>>>; - Tensor fragment_scale_input = make_tensor(FragScaleLayout{}); - - static constexpr auto zero_traits_size = decltype(size(typename GmemTiledCopyZero::BlockShape{}))::value / SubgroupSize; - static constexpr auto zero_traits_num = SG_N * zero_elements_packed_along_k / size<1>(typename GmemTiledCopyZero::BlockShape{}); - using FragZeroLayout = std::conditional_t>, - Layout, Int, _1>>>; - Tensor fragment_zero_input = make_tensor (FragZeroLayout{}); - - // narrow input fragment - Tensor quant_frag = make_tensor( - std::conditional_t{}); - - static_assert(std::is_same_v); - static_assert(std::is_same_v); - static_assert(std::is_same_v); - - // Retile for copy - auto [frag_copy_A, frag_copy_B] = [&](){ - if constexpr (IsATransformed) { - return std::make_pair(thr_copy_A.retile_D(quant_frag), thr_copy_B.retile_D(mma_B)); - } else { - return std::make_pair(thr_copy_A.retile_D(mma_A), thr_copy_B.retile_D(quant_frag)); - } - }(); - - Tensor copy_tCrS = thr_copy_scale.retile_D(fragment_scale_input); - Tensor copy_tCrZ = thr_copy_zero.retile_D(fragment_zero_input); - - // Retile global tile for copies - Tensor tAgA = thr_copy_A.retile_S(tCgA); - Tensor tBgB = thr_copy_B.retile_S(tCgB); - - auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_a);; - auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_b);; - auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); - auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); - - // Partition global tile for prefetch - auto pAgA = thr_prefetch_A.partition_S(gA); - auto pBgB = thr_prefetch_B.partition_S(gB); - - // - // Mainloop - // - // TODO(Codeplay): Define these coord tensors using proper cute logic - auto [m_idx, n_idx, k_idx, l_idx] = blk_coord; - const int m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; - const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; - const int l_coord = 0; - - Tensor copy_iter_s = [&](){ - if constexpr(IsATransformed){ - return make_tensor(make_inttuple_iter(make_coord(m_coord, 0, l_coord)), - make_layout(make_shape(_2{}, _1{}, _1{}, k_tile_count), - make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{}))); - }else{ - return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)), - make_layout(make_shape(Int{}, Int{}, _1{}, k_tile_count), - make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{}))); - } - }(); - - Tensor copy_iter_z = [&](){ - if constexpr(IsATransformed){ - return make_tensor(make_inttuple_iter(make_coord(m_coord, 0, l_coord)), - make_layout(make_shape(_2{}, _1{}, _1{}, k_tile_count), - make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{}))); - }else{ - return make_tensor(make_inttuple_iter(make_coord(n_coord * zero_elements_packed_along_k, 0, l_coord)), - make_layout(make_shape(Int{}, Int{}, _1{}, k_tile_count), - make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyZero::BlockShape{}), _0{}, E<1>{} * _1{}))); - } - }(); - - #if CUTLASS_ENABLE_DEBUG_PRINTS - #define PRINT(x) print(#x ": "); print(x); print("\n"); - if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { - print("======================= A: \n"); - PRINT(gA); - PRINT(tCgA); - 
PRINT(tAgA); - PRINT(mma_A); - PRINT(frag_copy_A); - - print("===================== B :\n"); - PRINT(gB); - PRINT(tCgB); - PRINT(tBgB); - PRINT(mma_B); - PRINT(frag_copy_B); - - print("===================== Config: \n"); - PRINT(MaxThreadsPerBlock); - PRINT(SubgroupTileShape{}); - - PRINT(tiled_prefetch_a); - PRINT(tiled_prefetch_b); - PRINT(pAgA); - PRINT(pBgB); - } - #undef PRINT - #endif - - const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); - constexpr int barrier_scope = 2; - int prefetch_k = k_start_idx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) { - prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k)); - prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k)); - } - - const int k_reload_factor = mainloop.group_size / BLK_K; - - for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { - barrier_arrive(barrier_scope); - - // Copy gmem to rmem for the first k_tile - copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A); - copy(tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B); - - if constexpr(ModeHasScales) { - copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), copy_tCrS); - } - if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - copy(tiled_copy_zero, copy_iter_z(_, _, _, k_tile / k_reload_factor / zero_elements_packed_along_k), copy_tCrZ); - } - - if(prefetch_k < k_tile_count) { - prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k)); - prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k)); - } - - if constexpr (IsATransformed) { - transform_quant(quant_frag, mma_A, fragment_scale_input, - fragment_zero_input); - } else { - if constexpr (ModeHasScalesZero && sizeof_bits_v < 8) { - transform_quant(quant_frag, mma_B, fragment_scale_input, fragment_zero_input((k_tile / k_reload_factor) % zero_traits_size, _, 0)); - } else { - transform_quant(quant_frag, mma_B, fragment_scale_input, fragment_zero_input); - } - } - - cute::gemm(tiled_mma, mma_A, mma_B, accum); - barrier_wait(barrier_scope); - } - } - - template - CUTLASS_DEVICE MainloopTensors update_tensor_shape_stride( - Params const& mainloop_params, - int32_t const& next_group, - ProblemShape_MNKL const& problem_shape_mnkl) { - const int32_t M = get<0>(problem_shape_mnkl); - const int32_t N = get<1>(problem_shape_mnkl); - const int32_t K = get<2>(problem_shape_mnkl); - - ElementA const* ptr_A_curr_batch = reinterpret_cast(mainloop_params.ptr_A[next_group]); - auto ptr_B_curr_batch = [&]() { - if constexpr (sizeof_bits_v < 8) { - return cute::subbyte_iterator(mainloop_params.ptr_B[next_group]); - } else { - return make_gmem_ptr(static_cast(mainloop_params.ptr_B[next_group])); - } - }(); - - TensorMKL mA_mkl = make_tensor(make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K, static_cast(1)), mainloop_params.dA[next_group]); - TensorNKL mB_nkl = make_tensor(ptr_B_curr_batch, make_shape(N, K,static_cast(1)), mainloop_params.dB[next_group]); - - if constexpr(KernelConversionMode == ConversionMode::DirectConvert){ - return cute::make_tuple(mA_mkl, mB_nkl, TensorS{}, TensorZ{}); - } - - auto scale_k = cute::ceil_div(K, mainloop_params.group_size); - TensorS mScale = make_tensor( - make_gmem_ptr(static_cast(mainloop_params.ptr_S[next_group])), - make_layout(make_shape(IsATransformed ? 
M : N, scale_k, static_cast(1)), mainloop_params.dS[next_group])); - - if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale){ - return cute::make_tuple(mA_mkl, mB_nkl, mScale, TensorZ{}); - } - - auto ptr_Z = [&]() { - if constexpr (sizeof_bits_v < 8) { - return cute::subbyte_iterator(mainloop_params.ptr_Z[next_group]); - } else { - return make_gmem_ptr(static_cast(mainloop_params.ptr_Z[next_group])); - } - }(); - - TensorZ mZero = make_tensor(ptr_Z, - make_layout(make_shape(zero_elements_packed_along_k * (IsATransformed ? M : N), scale_k / zero_elements_packed_along_k, static_cast(1)), - make_stride(_1{}, static_cast(zero_elements_packed_along_k) * (IsATransformed ? M : N), static_cast(IsATransformed ? M : N) * scale_k))); - - return cute::make_tuple(mA_mkl, mB_nkl, mScale, mZero); - } }; diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index 7d75825a3e..673e3c110f 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -47,18 +47,18 @@ using namespace cute; template -struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, +struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_> { // // Type Aliases // - using DispatchPolicy = MainloopIntelXeXMX16; + using DispatchPolicy = MainloopXeL1Staged; using WorkgroupTileShape = TileShape_; using ElementA = ElementA_; - using StrideA = StrideA_; + using StrideA = cute::remove_pointer_t; using ElementB = ElementB_; - using StrideB = StrideB_; + using StrideB = cute::remove_pointer_t; using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; @@ -71,7 +71,6 @@ struct CollectiveMma, TileShape_, Element using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - static_assert(platform::is_same::value, "MainloopIntelXeXMX16 requires that A and B have same type."); static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); @@ -100,8 +99,10 @@ struct CollectiveMma, TileShape_, Element static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); - using Copy_A = typename Copy_Traits::template DefaultTiledCopy; - using Copy_B = typename Copy_Traits::template DefaultTiledCopy; + // Helper to get tensor types + template + using TensorType = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_layout(make_shape(int{}, int{}, int{}), Stride{}))); // Host side kernel arguments struct Arguments { @@ -112,8 +113,8 @@ struct CollectiveMma, TileShape_, Element }; struct Params { - Copy_A tiled_copy_a; - Copy_B tiled_copy_b; + TensorType mA_mkl; + TensorType mB_nkl; }; // @@ -129,12 +130,11 @@ struct CollectiveMma, TileShape_, Element auto [M,N,K,L] = problem_shape; - auto mA_mkl = make_tensor(make_gmem_ptr(args.ptr_A), make_layout(make_shape(M, K, L), args.dA)); - auto mB_nkl = make_tensor(make_gmem_ptr(args.ptr_B), make_layout(make_shape(N, K, L), args.dB)); - Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)}; - Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)}; - - return Params{tiled_copy_a, tiled_copy_b}; + auto mA_mkl = make_tensor(make_gmem_ptr(args.ptr_A), + 
make_layout(make_shape(M, K, L), args.dA)); + auto mB_nkl = make_tensor(make_gmem_ptr(args.ptr_B), + make_layout(make_shape(N, K, L), args.dB)); + return Params{mA_mkl, mB_nkl}; } template @@ -177,38 +177,51 @@ struct CollectiveMma, TileShape_, Element static_assert(is_rmem::value, "D tensor must be rmem resident."); static_assert(is_rmem::value, "C tensor must be rmem resident."); - auto thr_copy_A = mainloop.tiled_copy_a.get_slice(thread_idx); - auto thr_copy_B = mainloop.tiled_copy_b.get_slice(thread_idx); + // Current implementation - Batch indexing + // Extract the batch index and slice the tensor before passing to make_block_2d_copy_* + // This is required because the current compiler doesn't properly handle full tensors + // in the make_block_2d_copy_* functions which makes it little slow. + // + // Future implementation - Full tensor support + // Once there is necessary compiler support, switch to: + // + // auto copy_a = get_block_2d_copy_A(TiledMma{}, mainloop.mA_mkl); + // auto copy_b = get_block_2d_copy_B(TiledMma{}, mainloop.mB_nkl); + // + // Required changes in copy_traits_xe_2d.hpp under make_block_2d_copy(): + // 1. tuple_repeat(_1{}) → make_tile(_1{}, _1{}) + // 2. Remove x_atom_shape, use elem_scale(ShapeTiler_MN{}, atom_shape) directly + auto batch_idx = get<3>(blk_coord); + auto copy_a = get_block_2d_copy_A(TiledMma{}, mainloop.mA_mkl(_,_,batch_idx)); + auto copy_b = get_block_2d_copy_B(TiledMma{}, mainloop.mB_nkl(_,_,batch_idx)); + + auto thr_copy_a = copy_a.get_slice(thread_idx); + auto thr_copy_b = copy_b.get_slice(thread_idx); // Instantiate the MMA object and get thread slice TiledMma tiled_mma; - // TODO(Codeplay): see if we can make this nicer - // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup - auto sg = compat::get_nd_item<1>().get_sub_group(); - auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; - auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); - - // Partition global counting tensors for MMA - Tensor tCgA = thr_mma.partition_A(gA); - Tensor tCgB = thr_mma.partition_B(gB); - - Tensor tCrA = make_tensor(make_fragment_layout(mainloop.tiled_copy_a, tCgA(_,_,_,0).shape())); - Tensor tCrB = make_tensor(make_fragment_layout(mainloop.tiled_copy_b, tCgB(_,_,_,0).shape())); - - // Retile registers for copies - Tensor tArA = thr_copy_A.retile_D(tCrA); - Tensor tBrB = thr_copy_B.retile_D(tCrB); - - // Retile global counting tensors for copies - Tensor tAgA = thr_copy_A.retile_S(tCgA); - Tensor tBgB = thr_copy_B.retile_S(tCgB); - - auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(mainloop.tiled_copy_a); - auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(mainloop.tiled_copy_b); - auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); - auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); - - // Partition global tile for prefetch + auto thr_mma = tiled_mma.get_slice(thread_idx); + + /* Register fragments for MMA */ + auto tCrA = thr_mma.partition_sg_fragment_A(gA(_,_,0)); + auto tCrB = thr_mma.partition_sg_fragment_B(gB(_,_,0)); + + /* Register fragments for copies */ + auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_,_,0)); + auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_,_,0)); + + /* Partition global tensor (proxies) for copies */ + Tensor tAgA = thr_copy_a.partition_S(gA); + Tensor tBgB = thr_copy_b.partition_S(gB); + + /* Create prefetch TiledCopy instances */ + auto prefetch_a = 
make_block_2d_prefetch(copy_a); + auto prefetch_b = make_block_2d_prefetch(copy_b); + + auto thr_prefetch_A = prefetch_a.get_slice(thread_idx); + auto thr_prefetch_B = prefetch_b.get_slice(thread_idx); + + /* Partition global tensor (proxies) for prefetch */ auto pAgA = thr_prefetch_A.partition_S(gA); auto pBgB = thr_prefetch_B.partition_S(gB); @@ -216,20 +229,18 @@ struct CollectiveMma, TileShape_, Element #define PRINT(x) print(#x ": "); print(x); print("\n"); if (cute::thread(LOG_THREAD, LOG_GROUP)) { print("======================= A: \n"); - PRINT(tCgA); PRINT(tAgA); PRINT(tCrA); PRINT(tArA); - PRINT(mainloop.tiled_copy_a); + PRINT(copy_a); print("======================= B: \n"); - PRINT(tCgB); PRINT(tBgB); PRINT(tCrB); PRINT(tBrB); - PRINT(mainloop.tiled_copy_b); + PRINT(copy_b); } #undef PRINT #endif @@ -243,21 +254,25 @@ struct CollectiveMma, TileShape_, Element CUTLASS_PRAGMA_UNROLL for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) { - prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); - prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + prefetch(prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(prefetch_b, pBgB(_, _, _, prefetch_k)); } for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { barrier_arrive(barrier_scope); // Copy gmem to rmem for the first k_tile - copy(mainloop.tiled_copy_a, tAgA(_,_,_,k_tile), tArA); - copy(mainloop.tiled_copy_b, tBgB(_,_,_,k_tile), tBrB); + copy(copy_a, tAgA(_,_,_,k_tile), tArA); + copy(copy_b, tBgB(_,_,_,k_tile), tBrB); if (prefetch_k < k_tile_count) { - prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); - prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + prefetch(prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(prefetch_b, pBgB(_, _, _, prefetch_k)); } + /* Shuffle data from copy fragments to MMA fragments */ + reorder(tArA, tCrA); + reorder(tBrB, tCrB); + cute::gemm(tiled_mma, tCrA, tCrB, accum); barrier_wait(barrier_scope); } @@ -267,3 +282,4 @@ struct CollectiveMma, TileShape_, Element } // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp b/include/cutlass/gemm/collective/xe_mma_legacy.hpp similarity index 61% rename from include/cutlass/gemm/collective/xe_array_mma_fp8.hpp rename to include/cutlass/gemm/collective/xe_mma_legacy.hpp index 36f8c85587..fac4108f6d 100644 --- a/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp +++ b/include/cutlass/gemm/collective/xe_mma_legacy.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * Copyright (C) 2025 Intel Corporation, All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * @@ -37,7 +37,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cutlass/fp8_to_fp16.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,20 +47,18 @@ using namespace cute; template -struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, +struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_> { // // Type Aliases // - using DispatchPolicy = MainloopIntelXeXMX16GroupFP8; + using DispatchPolicy = MainloopIntelXeXMX16; using WorkgroupTileShape = TileShape_; using ElementA = ElementA_; - using StrideA = StrideA_; - using InternalStrideA = cute::remove_pointer_t; + using StrideA = cute::remove_pointer_t; using ElementB = ElementB_; - using StrideB = StrideB_; - using InternalStrideB = cute::remove_pointer_t; + using StrideB = cute::remove_pointer_t; using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; @@ -74,8 +71,7 @@ struct CollectiveMma, TileShape_, using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - static_assert(platform::is_same::value, "MainloopIntelXeXMX16Array requires that A and B have same type."); - static_assert(std::is_same_v || std::is_same_v); + static_assert(platform::is_same::value, "MainloopIntelXeXMX16 requires that A and B have same type."); static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); @@ -91,33 +87,33 @@ struct CollectiveMma, TileShape_, static constexpr int ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); static constexpr int ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static_assert(BLK_M % TiledMma{}.template tile_size_mnk<0>() == 0, "TiledMma permutation size must match block size."); + static_assert(BLK_N % TiledMma{}.template tile_size_mnk<1>() == 0, "TiledMma permutation size must match block size."); + static_assert(BLK_K % TiledMma{}.template tile_size_mnk<2>() == 0, "TiledMma permutation size must match block size."); + static constexpr int SG_M = ceil_div(BLK_M, ATOM_M); static constexpr int SG_N = ceil_div(BLK_N, ATOM_N); static constexpr int SG_K = ceil_div(BLK_K, ATOM_K); using SubgroupTileShape = Shape, C, C>; - static constexpr int Num_SGs = ATOM_N * ATOM_M * ATOM_K; + // 32 + static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); - - using Copy_A = typename Copy_Traits::template DefaultTiledCopy; - using Copy_B = typename Copy_Traits::template DefaultTiledCopy; - using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideA{})); //(m, k) - using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideB{})); //(n, k) - using MainloopTensors = cute::tuple; + using Copy_A = typename Copy_Traits::template DefaultTiledCopy; + using Copy_B = typename Copy_Traits::template DefaultTiledCopy; + // Host side kernel arguments struct Arguments { - ElementA const** ptr_A; + ElementA const* ptr_A; StrideA dA; - ElementB const** ptr_B; + ElementB const* ptr_B; StrideB dB; }; struct 
Params { - ElementA const** ptr_A; - StrideA dA; - ElementB const** ptr_B; - StrideB dB; + Copy_A tiled_copy_a; + Copy_B tiled_copy_b; }; // @@ -131,18 +127,14 @@ struct CollectiveMma, TileShape_, to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; - const int32_t mock_L = 1; - auto problem_shape_MNK = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, mock_L);; - auto init_M = get<0>(problem_shape_MNK); - auto init_N = get<1>(problem_shape_MNK); - auto init_K = get<2>(problem_shape_MNK); - - return Params{ - args.ptr_A, - args.dA, - args.ptr_B, - args.dB - }; + auto [M,N,K,L] = problem_shape; + + auto mA_mkl = make_tensor(make_gmem_ptr(args.ptr_A), make_layout(make_shape(M, K, L), args.dA)); + auto mB_nkl = make_tensor(make_gmem_ptr(args.ptr_B), make_layout(make_shape(N, K, L), args.dB)); + Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)}; + Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)}; + + return Params{tiled_copy_a, tiled_copy_b}; } template @@ -158,20 +150,15 @@ struct CollectiveMma, TileShape_, bool implementable = true; constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,K,L), args.dA); constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits::value; - constexpr int min_batch_aligned_elements_A = batch_alignment_bits / sizeof_bits::value; - constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; - for (int i = 0; i < problem_shapes.groups(); i++) { - auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); - auto [M,N,K,L] = problem_shape_MNKL; - - implementable &= cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); - implementable &= cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); - - if (L > 1) { - implementable &= get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; - implementable &= get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; - } + implementable &= cutlass::detail::check_alignment(cute::make_shape(N,K,L), args.dB); + + if (L > 1) { + constexpr int min_batch_aligned_elements_A = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dA) % min_batch_aligned_elements_A == 0; + constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dB) % min_batch_aligned_elements_B == 0; } if (!implementable) { @@ -182,22 +169,16 @@ struct CollectiveMma, TileShape_, } /// Perform a subgroup-scoped matrix multiply-accumulate - template + template CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, FrgTensorC const &src_accum, - KTileIterator k_tile_iter, int const& k_tile_count, - BlkCoord const &blk_coord, int const &K_start, int const& thread_idx, - Params const &mainloop, LoadTensors const& load_tensors) { + KTileIterator k_tile_iter, int k_tile_count, BlkCoord const &blk_coord, int const &K_start, int thread_idx, + Params const &mainloop) { + (void)blk_coord; static_assert(is_rmem::value, "D tensor must be rmem resident."); static_assert(is_rmem::value, "C tensor must be rmem resident."); - (void)thread_idx; - - Copy_A tiled_copy_a{Copy_A{}.with(get<0>(load_tensors))}; - Copy_B tiled_copy_b{Copy_B{}.with(get<1>(load_tensors))}; - - auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); - auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); + auto thr_copy_A = 
mainloop.tiled_copy_a.get_slice(thread_idx); + auto thr_copy_B = mainloop.tiled_copy_b.get_slice(thread_idx); // Instantiate the MMA object and get thread slice TiledMma tiled_mma; @@ -211,11 +192,8 @@ struct CollectiveMma, TileShape_, Tensor tCgA = thr_mma.partition_A(gA); Tensor tCgB = thr_mma.partition_B(gB); - Tensor tCrA = make_tensor(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape())); - Tensor tCrB = make_tensor(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape())); - - Tensor tCrA_fp16 = make_fragment_like(tCrA); - Tensor tCrB_fp16 = make_fragment_like(tCrB); + Tensor tCrA = make_tensor(make_fragment_layout(mainloop.tiled_copy_a, tCgA(_,_,_,0).shape())); + Tensor tCrB = make_tensor(make_fragment_layout(mainloop.tiled_copy_b, tCgB(_,_,_,0).shape())); // Retile registers for copies Tensor tArA = thr_copy_A.retile_D(tCrA); @@ -224,32 +202,36 @@ struct CollectiveMma, TileShape_, // Retile global counting tensors for copies Tensor tAgA = thr_copy_A.retile_S(tCgA); Tensor tBgB = thr_copy_B.retile_S(tCgB); - - auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_a); - auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_b); + + auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(mainloop.tiled_copy_a); + auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(mainloop.tiled_copy_b); auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); - + // Partition global tile for prefetch auto pAgA = thr_prefetch_A.partition_S(gA); auto pBgB = thr_prefetch_B.partition_S(gB); #if CUTLASS_ENABLE_DEBUG_PRINTS - if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { - print("======================= A: \n"); - print(" gA : "); print(gA); print("\n"); - print("tCgA : "); print(tCgA); print("\n"); - print("tAgA : "); print(tAgA); print("\n"); - - print("===================== B :\n"); - print(" gB : "); print(gB); print("\n"); - print("tCgB : "); print(tCgB); print("\n"); - print("tBgB : "); print(tBgB); print("\n"); - - print("===================== Config: \n"); - print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n"); - print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n"); +#define PRINT(x) print(#x ": "); print(x); print("\n"); + if (cute::thread(LOG_THREAD, LOG_GROUP)) { + print("======================= A: \n"); + PRINT(tCgA); + PRINT(tAgA); + + PRINT(tCrA); + PRINT(tArA); + PRINT(mainloop.tiled_copy_a); + + print("======================= B: \n"); + PRINT(tCgB); + PRINT(tBgB); + + PRINT(tCrB); + PRINT(tBrB); + PRINT(mainloop.tiled_copy_b); } +#undef PRINT #endif // @@ -268,41 +250,20 @@ struct CollectiveMma, TileShape_, for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { barrier_arrive(barrier_scope); // Copy gmem to rmem for the first k_tile - copy(tiled_copy_a, tAgA(_,_,_,k_tile), tArA); - copy(tiled_copy_b, tBgB(_,_,_,k_tile), tBrB); - - convert_FP8_to_FP16(tCrA, tCrA_fp16); - convert_FP8_to_FP16(tCrB, tCrB_fp16); + copy(mainloop.tiled_copy_a, tAgA(_,_,_,k_tile), tArA); + copy(mainloop.tiled_copy_b, tBgB(_,_,_,k_tile), tBrB); if (prefetch_k < k_tile_count) { prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); } - cute::gemm(tiled_mma, tCrA_fp16, tCrB_fp16, accum); + cute::gemm(tiled_mma, tCrA, tCrB, accum); barrier_wait(barrier_scope); } } - - template - CUTLASS_DEVICE auto update_tensor_shape_stride( - Params const& mainloop_params, 
- int32_t const& next_group, - ProblemShape_MNKL const& problem_shape_mnkl) { - const int32_t M = get<0>(problem_shape_mnkl); - const int32_t N = get<1>(problem_shape_mnkl); - const int32_t K = get<2>(problem_shape_mnkl); - - ElementA const* ptr_A_curr_batch = reinterpret_cast(mainloop_params.ptr_A[next_group]); - ElementB const* ptr_B_curr_batch = reinterpret_cast(mainloop_params.ptr_B[next_group]); - - Tensor mA = make_tensor(make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K,(int32_t)1), mainloop_params.dA[next_group]); - Tensor mB = make_tensor(make_gmem_ptr(ptr_B_curr_batch), make_shape(N, K,(int32_t)1), mainloop_params.dB[next_group]); - - return cute::make_tuple(mA, mB); - } }; } // namespace cutlass::gemm::collective -///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp b/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp index d8b6843b65..a99d8b9793 100644 --- a/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp +++ b/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp @@ -166,8 +166,8 @@ struct CollectiveMma< using NonVoidElementScale = cute::conditional_t, ElementMMA, ElementScale>; using NonVoidElementZero = cute::conditional_t, ElementMMA, ElementZero>; - using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; - using NonVoidStrideZero = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideZero>; + using NonVoidStrideScale = cute::remove_pointer_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>>; + using NonVoidStrideZero = cute::remove_pointer_t, cute::Stride<_1, int64_t, int64_t>, StrideZero>>; static constexpr auto zero_elements_packed_along_k = get<0>(NonVoidStrideZero{}); // When stride is Stride<_0, _0, _1>, quantization can be determined as tensor-wise @@ -175,8 +175,8 @@ struct CollectiveMma< static constexpr auto is_groupwise = (quant_mode == QuantMode::GroupWise); static constexpr auto is_tensorwise = (quant_mode == QuantMode::TensorWise); - using StrideA = StrideA_; - using StrideB = StrideB_; + using StrideA = cute::remove_pointer_t; + using StrideB = cute::remove_pointer_t; using ElementAccumulator = typename TiledMma::ValTypeC; diff --git a/include/cutlass/gemm/collective/xe_mma_w8a8.hpp b/include/cutlass/gemm/collective/xe_mma_w8a8.hpp index e9698afdb2..37e7576c90 100644 --- a/include/cutlass/gemm/collective/xe_mma_w8a8.hpp +++ b/include/cutlass/gemm/collective/xe_mma_w8a8.hpp @@ -57,9 +57,9 @@ struct CollectiveMma, TileShape_, ElementA_, using DispatchPolicy = MainloopIntelW8A8; using WorkgroupTileShape = TileShape_; using ElementA = ElementA_; - using StrideA = StrideA_; + using StrideA = cute::remove_pointer_t; using ElementB = ElementB_; - using StrideB = StrideB_; + using StrideB = cute::remove_pointer_t; using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index b742dfd76f..46efe7dfcb 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -127,10 +127,15 @@ struct KernelPtrArrayTmaWarpSpecializedCooperative { }; struct KernelPtrArrayTmaWarpSpecializedPingpong { }; // FP8 related policies (including Blocked Scaled Accumulation) -struct 
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { }; -struct KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelTmaWarpSpecializedPingpong { }; -struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative { }; -struct KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedPingpong { }; +struct KernelTmaWarpSpecializedCooperativeFP8Blockwise: KernelTmaWarpSpecializedCooperative { }; +struct KernelTmaWarpSpecializedPingpongFP8Blockwise: KernelTmaWarpSpecializedPingpong { }; +struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise: KernelPtrArrayTmaWarpSpecializedCooperative { }; +struct KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise: KernelPtrArrayTmaWarpSpecializedPingpong { }; + +using KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelTmaWarpSpecializedCooperativeFP8Blockwise; +using KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelTmaWarpSpecializedPingpongFP8Blockwise; +using KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; +using KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; // Policies to opt into mixed type GEMMs struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { }; @@ -322,17 +327,17 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8 // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule -// For FP8 kernels with Block Scaling +// For FP8 kernels with Blockwise (Software) Scaling template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum + class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8Blockwise > -struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8 +struct MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8 : MainloopSm90TmaGmmaWarpSpecialized { static_assert( - cute::is_same_v || - cute::is_same_v, + cute::is_same_v || + cute::is_same_v, "KernelSchedule must be one of the warp specialized FP8 block scale policies"); }; @@ -414,15 +419,15 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput { template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum + class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise > -struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling +struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise : MainloopSm90ArrayTmaGmmaWarpSpecialized { static_assert( cute::is_any_of_v< KernelSchedule, - KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum, - KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum + KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise, + KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise >, "KernelSchedule must be one of the warp specialized FP8 block scale policies"); }; @@ -443,6 +448,15 @@ struct KernelWarpSpecializedSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelMixedTmaCpAsyncWarpSpecializedSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = 
AccumulatorPipelineStageCount_; +}; + template< int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_ @@ -656,7 +670,7 @@ template< class KernelSchedule > struct HasAuxiliaryLoad< - MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling< + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise< Stages, ClusterShape, KernelSchedule @@ -669,7 +683,7 @@ template< class KernelSchedule > struct HasAuxiliaryLoad< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8< + MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8< Stages, ClusterShape, KernelSchedule @@ -703,6 +717,7 @@ struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; // Base policy struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; // Use for 2SM Dense GEMM Kernels for Collective Mainloop Builder struct KernelWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder Without TMA +struct KernelMixedTmaCpAsyncWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Ptr-Array Dense GEMM Dispatch Policies @@ -798,6 +813,8 @@ struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1 struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { }; struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { }; struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { }; +struct KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledGemmSm100 {}; + /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 BlockScaled Ptr Array Dense GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -953,6 +970,34 @@ struct MainloopSm100UmmaCpAsyncWarpSpecialized { using Schedule = KernelWarpSpecializedSm100; }; +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100; + constexpr static bool IsOverlappingAccum = false; +}; + +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100; + constexpr static bool IsOverlappingAccum = false; +}; + // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int Stages_, @@ -1218,6 +1263,10 @@ template struct MainloopIntelXeXMX16Group : MainloopIntelXeXMX16 { }; +template +struct MainloopXeL1StagedGroup : MainloopIntelXeXMX16 { +}; + template struct 
MainloopIntelXeXMX16GroupMixedPrecision : MainloopIntelXeXMX16 { }; @@ -1247,6 +1296,20 @@ struct MainloopDeviceAgnostic { using Schedule = KernelMultistage; }; #endif + +#if defined(CUTLASS_ENABLE_SYCL) +// Note: This dispatch policy is specifically added for CollectiveMma to support +// the integration of new MMA atoms (XE_DPAS_TT) and copy atoms for Intel XE architecture +template +struct MainloopXeL1Staged { + constexpr static int Stages = Stages_; + constexpr static int SubgroupSize = 16; + using ArchTag = arch::IntelXe; + using Schedule = KernelSchedule; + using ClusterShape = Shape<_1,_1,_1>; +}; +#endif + // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int LoadABPipelineStageCount_, diff --git a/include/cutlass/gemm/group_array_problem_shape.hpp b/include/cutlass/gemm/group_array_problem_shape.hpp index fe5e4c5364..400f7e6b2d 100644 --- a/include/cutlass/gemm/group_array_problem_shape.hpp +++ b/include/cutlass/gemm/group_array_problem_shape.hpp @@ -79,6 +79,16 @@ struct GroupProblemShape { } }; +template +struct MoEProblemShape { + using UnderlyingProblemShape = ProblemShape_; + using MaxProblemShape = MaxProblemShape_; + + UnderlyingProblemShape problem_shape; + MaxProblemShape max_problem_shape; +}; + + template class ArrayProblemShape { public: @@ -120,4 +130,14 @@ class ArrayProblemShape { UnderlyingProblemShape problem_shape_{}; }; + +namespace detail { + +template +struct is_moe_problem_shape : cute::false_type {}; +template +struct is_moe_problem_shape> : cute::true_type {}; + +} + } // namespace cutlass::gemm diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 69137d2114..4b060e0c02 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -73,6 +73,7 @@ struct IsCutlass3ArrayKernel +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + using ProblemShape = ProblemShape_; + + static constexpr bool IsGroupedGemmKernel = cutlass::gemm::detail::is_moe_problem_shape::value; + static constexpr bool IsMoEScheduler = false; // stub for MoE scheduler, which accepts a MoEProblemShape instead of GroupProblemShape + + CUTLASS_HOST_DEVICE + static auto get_problem_shape_gemm(ProblemShape const& shape) { + if constexpr (IsGroupedGemmKernel) { + return shape.max_problem_shape; + } + else { + return shape; + } + } + CUTLASS_HOST_DEVICE + static auto get_problem_shape_scheduler(ProblemShape const& shape) { + if constexpr (IsMoEScheduler) { + return shape; + } + else if constexpr (IsGroupedGemmKernel) { + return shape.problem_shape; + } + else { + return shape; + } + } + + template + CUTLASS_HOST_DEVICE + static auto get_effective_shape(ProblemShape const& shape, WorkTileInfo const& work_tile_info) { + if constexpr (IsGroupedGemmKernel) { + return append<4>(shape.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } + else { + return append<4>(shape, Int<1>{}); + } + } + + using ProblemShapeGemm = decltype(get_problem_shape_gemm(ProblemShape{})); + using ProblemShapeScheduler = decltype(get_problem_shape_scheduler(ProblemShape{})); + + static_assert(rank(ProblemShapeGemm{}) == 3 or rank(ProblemShapeGemm{}) == 4, + "ProblemShapeGemm{} should be or "); + static constexpr bool IsGdcEnabled = false; + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + 
using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static constexpr bool IsComplex = CollectiveEpilogue::NumAccumulatorMtxs == 2; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + static_assert(!IsOverlappingAccum, "TMA+CPASYNC kernel currently only supports non-overlapping accum."); + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment kernel only supports 1x1x1 cluster shape."); + using TileSchedulerTag = cute::conditional_t; + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount, ProblemShapeScheduler>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEmptyThreads = 0; + static constexpr uint32_t NumMainloopTMALoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopCpAsyncLoadThreads = CollectiveMainloop::NumLoadThreadsCpAsync; // 4 warps + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopTMALoadThreads + NumMainloopCpAsyncLoadThreads + + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + NumEmptyThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_load_pipe_increment(CtaShape_MNK{}); + + static 
constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Pipelines and pipeline states + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + + // Pipeline and pipeline state types + using MainloopPipelineTMA = typename CollectiveMainloop::MainloopPipelineTMA; + using MainloopPipelineTMAState = typename CollectiveMainloop::MainloopPipelineTMAState; + using MainloopPipelineCpAsync = typename CollectiveMainloop::MainloopPipelineCpAsync; + using MainloopPipelineCpAsyncState = typename CollectiveMainloop::MainloopPipelineCpAsyncState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + // using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipeline = cute::conditional_t, + cutlass::PipelineAsync>; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + ProblemShapeGemm problem_shape_gemm{}; + ProblemShapeScheduler problem_shape_scheduler{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoadTMA = 2, + EpilogueLoad = 3, + Epilogue = 4, + MainloopLoadCpAsync = 8 + }; + + struct IsParticipant { + uint32_t mma = false; + 
uint32_t sched = false; + uint32_t main_load_tma = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_load_cpasync = false; + }; + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + // auto problem_shape = args.problem_shape; + // auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + auto problem_shape_scheduler = get_problem_shape_scheduler(args.problem_shape); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shape_scheduler, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shape_scheduler, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace); + } + else { + auto problem_shape = args.problem_shape; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + scheduler = TileScheduler::to_underlying_arguments( + problem_shape, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ); + } + + return { + args.mode, + args.problem_shape, + problem_shape_gemm, + problem_shape_scheduler, + CollectiveMainloop::to_underlying_arguments(problem_shape_gemm, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shape_gemm, args.epilogue, epilogue_workspace), + hw_info, + scheduler + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + + if constexpr (IsGroupedGemmKernel) { + implementable &= args.mode == GemmUniversalMode::kGrouped; + implementable &= rank(ProblemShapeGemm{}) == 4; + implementable &= rank(typename ProblemShape::UnderlyingProblemShape::UnderlyingProblemShape{}) == 3; + } + else { + implementable &= (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShapeGemm{}) == 4); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + + auto problem_shape_gemm = 
get_problem_shape_gemm(args.problem_shape); + implementable &= CollectiveMainloop::can_implement(problem_shape_gemm, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(problem_shape_gemm, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + static constexpr int MaxClusterSize = 16; + implementable &= size(ClusterShape{}) <= MaxClusterSize; + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + auto problem_shape_scheduler = get_problem_shape_scheduler(args.problem_shape); + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(problem_shape_gemm, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, problem_shape_scheduler, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); + auto problem_shape_scheduler = get_problem_shape_scheduler(args.problem_shape); + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(problem_shape_gemm, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shape_gemm, args.epilogue); + status = cutlass::Status::kSuccess; + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, problem_shape_scheduler, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shape_scheduler, args.hw_info, NumFixupBarriers); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape_scheduler, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + else { + auto problem_shape_MNKL = append<4>(params.problem_shape_scheduler, 1); + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + return grid_shape; + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is 
rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape_gemm, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::MainloopLoadCpAsync) ? WarpCategory::Epilogue + : WarpCategory::MainloopLoadCpAsync; + uint32_t lane_predicate = cute::elect_one_sync(); + auto tile_shape = TileShape{}; + auto cluster_shape = ClusterShape{}; + constexpr int cluster_size = size(ClusterShape{}); + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + int mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); + [[maybe_unused]] uint32_t mma_peer_cta_rank = cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + + // printf("is_epi_load_needed = %d", (int)is_epi_load_needed); + + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA) && is_mma_leader_cta, // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopLoadTMA), // main_load_tma + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopLoadCpAsync) // main_load_cpasync + }; + + // Mainloop Load pipeline (TMA) + typename MainloopPipelineTMA::Params mainloop_pipeline_tma_params; + if (WarpCategory::MainloopLoadTMA == warp_category) { + mainloop_pipeline_tma_params.role = MainloopPipelineTMA::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_tma_params.role = MainloopPipelineTMA::ThreadCategory::Consumer; + } + + mainloop_pipeline_tma_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load_tma; + mainloop_pipeline_tma_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_tma_params.initializing_warp = 0; + MainloopPipelineTMA mainloop_pipeline_tma(shared_storage.pipelines.mainloop.tma, + mainloop_pipeline_tma_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Mainloop Load pipeline (CpAsync) + typename MainloopPipelineCpAsync::Params mainloop_pipeline_cpasync_params; + if (WarpCategory::MainloopLoadCpAsync == warp_category) { + mainloop_pipeline_cpasync_params.role = MainloopPipelineCpAsync::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_cpasync_params.role = MainloopPipelineCpAsync::ThreadCategory::Consumer; + } + + mainloop_pipeline_cpasync_params.producer_arv_count = 
NumMainloopCpAsyncLoadThreads; + mainloop_pipeline_cpasync_params.consumer_arv_count = 1; // Only UMMA consumes the A and B buffers + mainloop_pipeline_cpasync_params.dst_blockid = cta_rank_in_cluster; + mainloop_pipeline_cpasync_params.initializing_warp = 0; + MainloopPipelineCpAsync mainloop_pipeline_cpasync(shared_storage.pipelines.mainloop.cpasync, mainloop_pipeline_cpasync_params, cluster_shape); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 3; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = IsSchedDynamicPersistent ? CLCPipeline::ThreadCategory::ProducerConsumer : CLCPipeline::ThreadCategory::Producer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_arv_count = 1; + + if constexpr (IsSchedDynamicPersistent) { + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopTMALoadThreads + NumMainloopCpAsyncLoadThreads + NumEpilogueThreads + NumMMAThreads); + clc_pipeline_params.transaction_bytes = CLCResponseSize; + } + else { + clc_pipeline_params.consumer_arv_count = NumMainloopTMALoadThreads + NumMainloopCpAsyncLoadThreads + NumEpilogueThreads + NumMMAThreads; + } + + clc_pipeline_params.initializing_warp = 1; + // CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + // Now declare the pipeline outside the if constexpr + CLCPipeline clc_pipeline = [&]() { + if constexpr (IsSchedDynamicPersistent) { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + } + else { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params); + } + }(); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. 
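+ // The consumer side below mirrors this: consumer_arv_count counts every epilogue
+ // thread across the CTAs of the MMA atom (size(AtomThrShapeMNK{}) * NumEpilogueThreads),
+ // all of which must release an accumulator stage before the MMA warp may reuse it.
+ // With size(AtomThrShapeMNK{}) == 1 enforced earlier, this reduces to NumEpilogueThreads.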
+ accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, accumulator_pipeline_params, cluster_shape); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + + MainloopPipelineTMAState mainloop_pipe_tma_consumer_state; + MainloopPipelineTMAState mainloop_pipe_tma_producer_state = cutlass::make_producer_start_state(); + MainloopPipelineCpAsyncState mainloop_pipe_cpasync_consumer_state; + MainloopPipelineCpAsyncState mainloop_pipe_cpasync_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + // auto acc_shape = collective_mainloop.partition_accumulator_shape(); + // auto bulk_tmem = TiledMma::make_fragment_C(append(acc_shape, + // Int{})); + auto tmem_storage = collective_mainloop.template init_tmem_tensors(EpilogueTile{}); + + // + // END PROLOGUE + // + + // Synchronization call. Blocks until barriers are initialized in shared memory. + pipeline_init_wait(cluster_size); + + // __syncwarp(); + // if (threadIdx.x % 32 == 0) { + // printf("warp %d start\n", warp_idx); + // } + + if (is_participant.main_load_tma) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + // bool do_load_order_arrive = is_epi_load_needed; + bool requires_clc_query = true; + + auto load_inputs = collective_mainloop.load_init_tma( + problem_shape_MNKL, shared_storage.tensors.mainloop); + auto k_tiles = cute::get<0>(load_inputs); + + do { + auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
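+ // effective_shape is resolved per work tile: on the grouped (MoE) path,
+ // get_effective_shape() returns the (M,N,K) of group work_tile_info.L_idx,
+ // so the K-tile iterator and count may differ from one work tile to the next.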
+ auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, effective_shape, CtaShape_MNK{}, k_tiles); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); + // auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); + + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_tma( + mainloop_pipeline_tma, + mainloop_pipe_tma_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_count // - k_tile_prologue + ); + mainloop_pipe_tma_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail_tma(mainloop_pipeline_tma, mainloop_pipe_tma_producer_state); + + } + + else if (is_participant.main_load_cpasync) { + auto load_inputs = collective_mainloop.load_init_cpasync( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, + scheduler, work_tile_info); + Tensor gA_mkl = get<0>(load_inputs); + + do { + // Get current work tile and fetch next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, effective_shape, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); + + auto [mainloop_producer_state_next, unused_] = collective_mainloop.load_cpasync( + params.mainloop, + mainloop_pipeline_cpasync, + mainloop_pipe_cpasync_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_count, + effective_shape + ); + mainloop_pipe_cpasync_producer_state = mainloop_producer_state_next; + + // Sync warp to prevent non-participating threads entering next wave early + syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + + collective_mainloop.load_tail_cpasync(mainloop_pipeline_cpasync, mainloop_pipe_cpasync_producer_state); + + } + + else if (is_participant.sched) { + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. 
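+ // (If IsSchedDynamicPersistent is false, the else-branch further below skips this
+ //  query/consume handshake: the scheduler warp calls advance_to_next_work() directly
+ //  and only advances its producer state when increment_pipe reports a new entry.)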
+ bool requires_clc_query = true; + + cutlass::arch::wait_on_dependent_grids(); + + do { + if (requires_clc_query) { + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + else { + + cutlass::arch::wait_on_dependent_grids(); + + do { + auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + work_tile_info = next_work_tile_info; + if (increment_pipe) { + ++clc_pipe_producer_state; + } + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + else if (is_participant.mma) { + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + // bulk_tmem.data() = tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + + + // Pass the acc with tuple type since the bgrad kernel change the mma_init API + auto mma_inputs = collective_mainloop.mma_init(params.mainloop, + tmem_storage, + shared_storage.tensors.mainloop); + do { + auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + // accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + int acc_stage = accumulator_pipe_producer_state.index(); + // Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + auto [mainloop_pipe_tma_consumer_state_next_, mainloop_pipe_cpasync_consumer_state_next_] = collective_mainloop.mma( + cute::make_tuple(mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline), + cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state), + // Pass the acc with tuple type since the bgrad kernel change the mma API + // cute::make_tuple(accumulators, accumulators), + collective_mainloop.slice_accumulator(tmem_storage, acc_stage), + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + mainloop_pipe_tma_consumer_state = mainloop_pipe_tma_consumer_state_next_; + 
mainloop_pipe_cpasync_consumer_state = mainloop_pipe_cpasync_consumer_state_next_; + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + + ++accumulator_pipe_producer_state; + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + bool do_tail_load = false; + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). 
+ if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + // bulk_tmem.data() = tmem_base_ptr; + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // Accumulator stage slice + int acc_stage = accumulator_pipe_consumer_state.index(); + // Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + auto accumulator = get<0>(collective_mainloop.slice_accumulator(tmem_storage, acc_stage)); + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulator, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulator, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + do_tail_store = true; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + } while (work_tile_info.is_valid()); + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). 
+ if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp index 07d00fb699..8cf885f890 100755 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp @@ -240,6 +240,27 @@ class PersistentTileSchedulerSm100Group { void fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t, uint32_t = 1) const { } + template < + bool IsComplex, + class TiledMma, + class AccEngine, + class AccLayout, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class CopyOpT2R + > + CUTLASS_DEVICE + AccumulatorPipelineState + fixup( + TiledMma const& , + WorkTileInfo const&, + cute::Tensor&, + AccumulatorPipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + CopyOpT2R) const { + return acc_pipe_consumer_state; + } + template static size_t get_workspace_size(Arguments const& args, ProblemShape problem_shape, KernelHardwareInfo const& hw_info, uint32_t, uint32_t = 1, uint32_t = 1) { diff --git a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp index ed5ffa95d0..fcbc01a919 100644 --- a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp @@ -991,7 +991,7 @@ class GemmUniversal< mainloop_sf_pipeline, mainloop_sf_pipe_producer_state, load_inputs, - cta_coord_mnkl, + cta_coord_mnk, k_tile_iter_next, k_tile_count - k_tile_prologue, false, /* did_batch_change - prologue loads handle tensormap acquire */ enable_prefetch ? 
k_tile_count - k_tile_prologue : 0 @@ -1164,7 +1164,7 @@ class GemmUniversal< } bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); - epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipe_producer_state = collective_epilogue.template load( epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, diff --git a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp index 29be739382..57b9a3f383 100644 --- a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp @@ -977,7 +977,7 @@ class GemmUniversal< } bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); - epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipe_producer_state = collective_epilogue.template load( epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 8c564edb93..3e91f6b16c 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -831,8 +831,6 @@ class GemmUniversal< collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } - bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; - epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, @@ -843,8 +841,7 @@ class GemmUniversal< lane_idx, shared_storage.tensors.epilogue, epi_load_tensormap, - work_tile_info.reduction_subtile_idx(), - wait + work_tile_info.reduction_subtile_idx() ); } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index 6ac24d34dd..fd7ff603b8 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -869,8 +869,6 @@ class GemmUniversal< collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } - bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; - epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, @@ -881,8 +879,7 @@ class GemmUniversal< lane_idx, shared_storage.tensors.epilogue, epi_load_tensormap, - work_tile_info.reduction_subtile_idx(), - wait + work_tile_info.reduction_subtile_idx() ); } diff --git a/include/cutlass/gemm/kernel/xe_gemm.hpp b/include/cutlass/gemm/kernel/xe_gemm.hpp index 1b1bebcc18..378cc39e89 100644 --- a/include/cutlass/gemm/kernel/xe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm.hpp @@ -96,8 +96,6 @@ class GemmUniversal< using StrideD = typename CollectiveEpilogue::StrideD; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(cute::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); // MSVC requires the cast to fix a warning-as-error. 
static constexpr int SharedStorageSize = 0; @@ -109,7 +107,7 @@ class GemmUniversal< // Kernel level shared memory storage struct SharedStorage { - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; // FIXME: rename to SharedStorage EpilogueTensorStorage epilogue; }; @@ -238,11 +236,11 @@ class GemmUniversal< constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) constexpr auto subgroup_shape = SubgroupTileShape{}; - Tensor mA_mkl = cute::get_xe_tensor(make_shape(M,K,L)); //(m,k,l) - Tensor mB_nkl = cute::get_xe_tensor(make_shape(N,K,L)); //(n,k,l) + Tensor cA = make_identity_tensor(make_shape(M,K,L)); // (M,K,L) + Tensor cB = make_identity_tensor(make_shape(N,K,L)); // (N,K,L) - Tensor gA = local_tile(mA_mkl, select<0,2>(blk_shape), make_coord(m_coord,_,l_coord)); - Tensor gB = local_tile(mB_nkl, select<1,2>(blk_shape), make_coord(n_coord,_,l_coord)); + Tensor gA = local_tile(cA, select<0,2>(blk_shape), make_coord(m_coord,_,l_coord)); + Tensor gB = local_tile(cB, select<1,2>(blk_shape), make_coord(n_coord,_,l_coord)); // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape TiledMma tiled_mma; diff --git a/include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp b/include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp index 61bae38de0..9dfc8a6a77 100644 --- a/include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp @@ -107,7 +107,6 @@ class GemmUniversal< using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; - using MainloopTensors = typename CollectiveMainloop::MainloopTensors; using EpilogueTensors = typename CollectiveEpilogue::EpilogueTensors; // Kernel level shared memory storage @@ -255,7 +254,7 @@ class GemmUniversal< int32_t curr_group = -1; using ProblemShapeMNKL = Shape; ProblemShapeMNKL problem_shape_MNKL; - MainloopTensors AB_tensors; + typename CollectiveMainloop::Base::Params base_params; EpilogueTensors CD_tensors; if (work_tile_info.is_valid()) { @@ -280,7 +279,9 @@ class GemmUniversal< CollectiveMainloop collective_mma; if(did_group_change) { - AB_tensors = collective_mma.update_tensor_shape_stride(params.mainloop, curr_group, problem_shape_MNKL); + base_params = CollectiveMainloop::Base::to_underlying_arguments(problem_shape_MNKL, + CollectiveMainloop::to_base_arguments(params.mainloop, curr_group), + params.workspace); } auto tile_coord = make_coord(m_coord, n_coord, _, 0); @@ -302,8 +303,7 @@ class GemmUniversal< tile_coord, K, thread_idx, - params.mainloop, - AB_tensors + base_params ); TileScheduler::fixup( diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index adc0882e91..69b582e1ec 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -365,6 +365,15 @@ CUTLASS_DEVICE T atomicAdd(T *address, T val) { return static_cast(0); } +template +CUTLASS_DEVICE T atomicSub(T *address, T val) { +#if defined(__SYCL_DEVICE_ONLY__) + return compat::atomic_fetch_sub(address, val); +#endif + return static_cast(0); +} + + CUTLASS_DEVICE int atomicCAS(int *address, int compare, int val) { int result = 0; #if defined(__SYCL_DEVICE_ONLY__) @@ -373,6 +382,15 @@ CUTLASS_DEVICE int atomicCAS(int *address, int compare, int val) { return result; } +CUTLASS_DEVICE int atomicLoad(int *address) { + int result = 0; +#if defined(__SYCL_DEVICE_ONLY__) 
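+ // Device path: the read is performed as an atomic load through sycl::atomic_ref;
+ // for non-SYCL-device compilation the #if body is skipped and the initial 0 is
+ // returned, following the same fallback pattern as atomicSub and atomicCAS above.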
+ auto atm = sycl::atomic_ref(address[0]); + result = atm.load(); +#endif + return result; +} + // Error using cudaError_t = unsigned int; constexpr cudaError_t cudaSuccess = 0; diff --git a/include/cutlass/kernel_hardware_info.h b/include/cutlass/kernel_hardware_info.h index 6c2f1f307b..537d403b17 100644 --- a/include/cutlass/kernel_hardware_info.h +++ b/include/cutlass/kernel_hardware_info.h @@ -72,6 +72,17 @@ struct KernelHardwareInfo { #endif return multiprocessor_count; } + // Query maximum number of active clusters that could co-exist on the target device + // based on kernel properties such as cluster dims and threadblock dims + // Return 0 for Intel Xe12 and Xe20 architectures for now + static inline int + query_device_max_active_clusters( + dim3 cluster_dims, + uint32_t threads_per_block, + void const* kernel_ptr) { + return 0; + } + #elif !defined(__CUDACC_RTC__) static inline int diff --git a/include/cutlass/matrix.h b/include/cutlass/matrix.h index 78e15859b5..00222c128d 100644 --- a/include/cutlass/matrix.h +++ b/include/cutlass/matrix.h @@ -429,8 +429,8 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; return m; } @@ -1023,9 +1023,9 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; return m; } @@ -1214,7 +1214,7 @@ struct Matrix { Matrix cross(Matrix const &rhs) const { return Matrix( data[1] * rhs.data[2] - data[2] * rhs.data[1], - data[0] * rhs.data[2] - data[2] * rhs.data[1], + data[2] * rhs.data[0] - data[0] * rhs.data[2], data[0] * rhs.data[1] - data[1] * rhs.data[0] ); } @@ -1689,10 +1689,10 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; return m; } @@ -2306,8 +2306,8 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; return m; } @@ -3040,10 +3040,10 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; return m; } @@ -3912,12 +3912,12 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; return m; } @@ -4884,14 +4884,14 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; return m; } @@ -5590,9 +5590,9 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - 
m.data[2] = -m.data[2]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; return m; } @@ -5768,7 +5768,7 @@ struct Matrix { Matrix cross(Matrix const &rhs) const { return Matrix( data[1] * rhs.data[2] - data[2] * rhs.data[1], - data[0] * rhs.data[2] - data[2] * rhs.data[1], + data[2] * rhs.data[0] - data[0] * rhs.data[2], data[0] * rhs.data[1] - data[1] * rhs.data[0] ); } @@ -6457,12 +6457,12 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; return m; } @@ -7514,15 +7514,15 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; - m.data[8] = -m.data[8]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; + m.data[8] = -data[8]; return m; } @@ -8905,18 +8905,18 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; - m.data[8] = -m.data[8]; - m.data[9] = -m.data[9]; - m.data[10] = -m.data[10]; - m.data[11] = -m.data[11]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; + m.data[8] = -data[8]; + m.data[9] = -data[9]; + m.data[10] = -data[10]; + m.data[11] = -data[11]; return m; } @@ -9723,10 +9723,10 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; return m; } @@ -10724,14 +10724,14 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; return m; } @@ -11996,18 +11996,18 @@ struct Matrix { Matrix operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; - m.data[8] = -m.data[8]; - m.data[9] = -m.data[9]; - m.data[10] = -m.data[10]; - m.data[11] = -m.data[11]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; + m.data[8] = -data[8]; + m.data[9] = -data[9]; + m.data[10] = -data[10]; + m.data[11] = -data[11]; return m; } @@ -13594,22 +13594,22 @@ struct Matrix { Matrix 
operator-() const { Matrix m; - m.data[0] = -m.data[0]; - m.data[1] = -m.data[1]; - m.data[2] = -m.data[2]; - m.data[3] = -m.data[3]; - m.data[4] = -m.data[4]; - m.data[5] = -m.data[5]; - m.data[6] = -m.data[6]; - m.data[7] = -m.data[7]; - m.data[8] = -m.data[8]; - m.data[9] = -m.data[9]; - m.data[10] = -m.data[10]; - m.data[11] = -m.data[11]; - m.data[12] = -m.data[12]; - m.data[13] = -m.data[13]; - m.data[14] = -m.data[14]; - m.data[15] = -m.data[15]; + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; + m.data[8] = -data[8]; + m.data[9] = -data[9]; + m.data[10] = -data[10]; + m.data[11] = -data[11]; + m.data[12] = -data[12]; + m.data[13] = -data[13]; + m.data[14] = -data[14]; + m.data[15] = -data[15]; return m; } diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index 7e3816394e..1547c48e0b 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -612,6 +612,52 @@ struct alignment_of { enum { value = 16 }; }; +#if !defined(CUDA_VECTOR_TYPE_ALIGNMENT_16_32_ENABLED) +#define CUDA_VECTOR_TYPE_ALIGNMENT_16_32_ENABLED (__CUDACC_VER_MAJOR__ >= 13) +#endif + +#if (CUDA_VECTOR_TYPE_ALIGNMENT_16_32_ENABLED) +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 32 }; +}; +template <> +struct alignment_of { + enum { value = 32 }; +}; +template <> +struct alignment_of { + enum { value = 32 }; +}; +template <> +struct alignment_of { + enum { value = 32 }; +}; +template <> +struct alignment_of { + enum { value = 32 }; +}; +#else template <> struct alignment_of { enum { value = 16 }; @@ -633,6 +679,7 @@ struct alignment_of { enum { value = 16 }; }; +#endif // Specializations for volatile/const qualified types template @@ -808,6 +855,7 @@ struct numeric_limits { static constexpr int32_t max() noexcept { return 2147483647;} static constexpr bool is_integer = true; static constexpr bool has_infinity = false; + static constexpr bool is_signed = true; }; template <> @@ -818,6 +866,7 @@ struct numeric_limits { static constexpr int16_t max() noexcept { return 32767;} static constexpr bool is_integer = true; static constexpr bool has_infinity = false; + static constexpr bool is_signed = true; }; template <> @@ -828,6 +877,7 @@ struct numeric_limits { static constexpr int8_t max() noexcept { return 127;} static constexpr bool is_integer = true; static constexpr bool has_infinity = false; + static constexpr bool is_signed = true; }; @@ -839,6 +889,7 @@ struct numeric_limits { static constexpr uint32_t max() noexcept { return 4294967295U;} static constexpr bool is_integer = true; static constexpr bool has_infinity = false; + static constexpr bool is_signed = false; }; template <> @@ -849,6 +900,7 @@ struct numeric_limits { static constexpr uint16_t max() noexcept { return 65535U;} static constexpr bool is_integer = true; static constexpr bool has_infinity = false; + static constexpr bool is_signed = false; }; template <> @@ -859,16 +911,20 @@ struct numeric_limits { static constexpr uint8_t max() noexcept { return 255U;} static constexpr bool is_integer = true; static constexpr bool 
has_infinity = false; + static constexpr bool is_signed = false; }; template <> struct numeric_limits { + CUTLASS_HOST_DEVICE + static constexpr float lowest() noexcept { return bit_cast(0xff7fffff);} CUTLASS_HOST_DEVICE static constexpr float infinity() noexcept { return bit_cast(0x7f800000);} CUTLASS_HOST_DEVICE static constexpr float max() noexcept { return bit_cast(0x7f7fffff);} static constexpr bool is_integer = false; static constexpr bool has_infinity = true; + static constexpr bool is_signed = true; }; /// Returns a value that curries the `std::maximum()` function into the identity diff --git a/include/cutlass/version.h b/include/cutlass/version.h index 71cdcae1df..57a73a5fbb 100644 --- a/include/cutlass/version.h +++ b/include/cutlass/version.h @@ -35,8 +35,8 @@ #include #define CUTLASS_MAJOR 4 -#define CUTLASS_MINOR 1 -#define CUTLASS_PATCH 0 +#define CUTLASS_MINOR 2 +#define CUTLASS_PATCH 1 #ifdef CUTLASS_VERSIONS_GENERATED #include "cutlass/version_extended.h" diff --git a/media/docs/cpp/build/building_with_sycl_support.md b/media/docs/cpp/build/building_with_sycl_support.md index e3cf10b483..a993ab90ab 100644 --- a/media/docs/cpp/build/building_with_sycl_support.md +++ b/media/docs/cpp/build/building_with_sycl_support.md @@ -10,17 +10,17 @@ resources for GPUs. ## Support for Intel GPUs -The CUTLASS-SYCL supports running on Intel GPUs. +The SYCL*TLA supports running on Intel GPUs. Currently, Intel Data Center Max 1550 and 1100 (a.k.a Ponte Vecchio - PVC) along with Intel Arc B580 (a.k.a BattleMage - BMG) are supported. The `examples` directory shows a number of GEMM algorithms and examples of -CUTLASS-SYCL running on PVC and BMG, including flash attention V2. +SYCL*TLA running on PVC and BMG, including flash attention V2. Only Linux platforms are supported. ### Requirements (SYCL for Intel GPU) -To build CUTLASS SYCL support for Intel GPUs, you need the DPC++ compiler; +To build SYCL*TLA support for Intel GPUs, you need the DPC++ compiler; you can use the latest open source [[nightly build](https://github.com/intel/llvm/releases)] or a oneAPI toolkit from 2025.1 onwards. Intel Compute Runtime 25.13 (with Intel Graphics Compiler 2.10.10) is required. At the time of the release it can be installed from [intel-graphics-staging](https://launchpad.net/~kobuk-team/+archive/ubuntu/intel-graphics-staging). Installation from [intel-graphics](https://launchpad.net/~kobuk-team/+archive/ubuntu/intel-graphics) is recommended when it is available there. @@ -87,7 +87,7 @@ purposes and not intended for production. ### Requirements -To build CUTLASS SYCL support you need the latest version of DPC++ compiler. You can either use a recent [nightly build](https://github.com/intel/llvm/releases) +To build SYCL*TLA support you need the latest version of DPC++ compiler. You can either use a recent [nightly build](https://github.com/intel/llvm/releases) or build the compiler from source as described in [oneAPI DPC++ guideline]((https://github.com/intel/llvm/blob/sycl/sycl/doc/GetStartedGuide.md#build-dpc-toolchain-with-support-for-nvidia-cuda)). 
### Building with SYCL for NVIDIA support diff --git a/media/docs/cpp/cute/02_layout_algebra.md b/media/docs/cpp/cute/02_layout_algebra.md index 8314495b2a..465d3aef73 100644 --- a/media/docs/cpp/cute/02_layout_algebra.md +++ b/media/docs/cpp/cute/02_layout_algebra.md @@ -151,7 +151,7 @@ For example, * `(3,6,2,8) / 9 => (1,2,2,8)` * `(3,6,2,8) / 72 => (1,1,1,4)` -To compute the strides of the strided layout, the residues of the above operation are used to scale the strides of `A`. For instance, the last example `(3,6,2,8):(w,x,y,z) / 72` with strides `(w,x,y,z)` produces `(3*w,6*x,2*x,2*z)` as the strides of the strided layout. +To compute the strides of the strided layout, the residues of the above operation are used to scale the strides of `A`. For instance, the last example `(3,6,2,8):(w,x,y,z) / 72` with strides `(w,x,y,z)` produces `(72*w,24*x,4*y,2*z)` as the strides of the strided layout. As you may have noticed, we can only divide shapes by certain values and get a sensible result. This is called the **stride divisibility condition** and is statically checked in CuTe when possible. @@ -171,7 +171,7 @@ This operation causes the result to have a shape that is compatible with `B`. Again, this operation must satisfy a **shape divisibility condition** to yield a sensible result and is statically checked in CuTe when possible. -From the above examples, we can construct the composition `(3,6,2,8):(w,x,y,z) o 16:9 = (1,2,2,4):(3*w,3*x,y,z)`. +From the above examples, we can construct the composition `(3,6,2,8):(w,x,y,z) o 16:9 = (1,2,2,4):(9*w,3*x,y,z)`. --- #### Example 1 -- Worked Example of Calculating a Composition diff --git a/media/docs/cpp/cute/03_tensor.md b/media/docs/cpp/cute/03_tensor.md index 413aa1b137..5f83e500af 100644 --- a/media/docs/cpp/cute/03_tensor.md +++ b/media/docs/cpp/cute/03_tensor.md @@ -217,7 +217,7 @@ for (int i = 0; i < A.size(); ++i) ## Tiling a Tensor -Many of the [`Layout` algebra operations](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/02_layout_algebra.md) can also be applied to `Tensor`. +Many of the [`Layout` algebra operations](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md) can also be applied to `Tensor`. ```cpp composition(Tensor, Tiler) logical_divide(Tensor, Tiler) diff --git a/media/docs/cpp/pipeline.md b/media/docs/cpp/pipeline.md index aa6473043b..2533f69799 100644 --- a/media/docs/cpp/pipeline.md +++ b/media/docs/cpp/pipeline.md @@ -90,7 +90,7 @@ before issuing other instructions associated with a particular pipeline stage (e.g., copy or write). This is a blocking instruction -which blocks further execution of consumer threads +which blocks further execution of producer threads unless the particular stage waiting to be acquired is released by a consumer. diff --git a/media/docs/cpp/profiler.md b/media/docs/cpp/profiler.md index 8331b75fb6..57949a949a 100644 --- a/media/docs/cpp/profiler.md +++ b/media/docs/cpp/profiler.md @@ -79,7 +79,7 @@ Instruction shape levels control the selection of WGMMA shapes used in kernel ge - **Level 2**: Includes shapes that are powers of 2. - **Level 3**: Includes all other shapes. -The detailed defination of the three instantiation levels controlling cluster shape, MMA shape multiplier, and instruction shape can be found in [sm90_shapes.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py). 
+The detailed definition of the three instantiation levels controlling cluster shape, MMA shape multiplier, and instruction shape can be found in [sm90_shapes.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py). Schedule pruning levels decide the epilogue schedule and mainloop schedule to stamp out a kernel instance. As defined in `get_valid_schedules` in [sm90_utils.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_utils.py), @@ -122,6 +122,55 @@ For each mixed dtype kernel, the kernel generator will generate combinations of For {4-bits-dtype, 8-bits-dtype} x 16-bits-dtype, the kernel generator will further generate kernels using shuffled layouts for the narrow data type matrix, which may have a better performance compared to its non-shuffle counter parts. +## Instantiating more kernels with Blackwell +Blackwell (SM100) and Blackwell Ultra similarly support +`CUTLASS_LIBRARY_INSTANTIATION_LEVEL`, in order to instantiate all possible combinations. +Due to this, `CUTLASS_LIBRARY_KERNELS` must be non-empty, since generating and filtering these +kernels alone can take hours. +You must also exercise caution, because not all of these configs are tested, and some may fail to +compile or fail to launch at runtime. + +```bash +$ cmake .. \ + -DCUTLASS_NVCC_ARCHS="100f" \ + -DCUTLASS_LIBRARY_KERNELS="cutlass3x_sm100_tensorop_gemm_f16_f16_f32_void_f32_*" \ + -DCUTLASS_LIBRARY_INSTANTIATION_LEVEL="max" \ + -DCUTLASS_UNITY_BUILD_ENABLED=ON +``` + +The CUTLASS profiler uses the same four-digit integer level (global instantiation level) mechanism to manage the generation of kernel configurations for Blackwell as well: + +0. **Instruction Shape** +1. **MMA Shape Multiplier** +2. **Cluster Shape** +3. **Data Type and Schedule Pruning** + +Note for Blackwell kernels an MMA shape multiplier is no longer necessary since Blackwell kernels do not have a different +ping pong or cooperative schedule. The profiler ignores this digit when instantiating. + +Cluster shape levels define the number of CTAs (Cooperative Thread Arrays) included in the kernel generation: + +- **Level 0**: Only dynamic cluster shapes. +- **Level 1**: For 1SM kernels `(1, 1, 1)` and `(2, 1, 1)` for 2SM kernels. +- **Level 2**: For 1SM kernels we also have `(1, 2, 1)` and for 2SM we have `(2, 2, 1)` and `(4, 1, 1)`. +- **Level 3**: For 1SM kernels we have `(1, 4, 1)` and for 2SM we have `(2, 4, 1)` and `(4, 2, 1)`. +- **Level 4**: For 1SM kernels we have `(4, 4, 1)` and for 2SM we have `(4, 4, 1)`. +- **Level 5**: For 1SM kernels we have `(2, 1, 1)`. +- **Level 6**: For 1SM kernels we have `(2, 2, 1)` and `(4, 1, 1)` and for 2SM kernels we have `(8, 1, 1)`. +- **Level 7**: For 1SM kernels we have `(2, 4, 1)` and `(4, 2, 1)` +- **Level 8**: For 1SM kernels we have `(1, 8, 1)` and `(8, 1, 1)` + +Instruction shape levels control the selection of MMA shapes used in kernel generation: + +- **Level 0**: Generates the "default" shape only. +- **Level 1**: Includes additional shapes for FP8, FP6, and FP4 as well as MX and NVFP4. +- **Level 2**: Includes small tile shapes. +- **Level 3**: Includes some non-power of 2 shapes. +- **Level 4**: Includes further small tile shapes and non-power of 2 shapes. +- **Level 5**: Includes all shapes. + +The detailed definition of the three instantiation levels controlling cluster shape and instruction shape can be found in [sm100_shapes.py](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm100_shapes.py). 
+ ## CUTLASS Profiler usage The CUTLASS Profiler usage statement may be obtained by executing `cutlass_profiler --help` and appears as follows. @@ -577,6 +626,10 @@ cutlass3x_sm90_tensorop_gemm_f16_f16_f16_void_f16_128x128x64_1x1x1_0_nnn_align8_ * `f16_f16_f16_void_f16`: In this case, C type is set to `void`, indicating that residual matrix support is disabled. +## Further Documentation + +For documentation on profiling blockwise and groupwise (software scaled) GEMMs see the [example 81 README](https://github.com/NVIDIA/cutlass/blob/main/examples/81_blackwell_gemm_blockwise/README.md). + # Convolution The CUTLASS Profiler is capable of executing 2-D and 3-D convolution problems for forwards and backwards diff --git a/media/docs/cpp/xe_rearchitecture.md b/media/docs/cpp/xe_rearchitecture.md index 7e5c49bfb4..24b21611af 100644 --- a/media/docs/cpp/xe_rearchitecture.md +++ b/media/docs/cpp/xe_rearchitecture.md @@ -147,7 +147,7 @@ struct Copy_Traits; Since it can be a tricky to correctly choose block 2D parameters and set up an appropriate tiling, we introduce several helpers for creating TiledCopy objects. -The high-level APIs `make_block_2d_copy_{A,B,C}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically. +The high-level APIs `make_block_2d_copy_{A,B,C,D}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically. Note that `make_block_2d_copy_C` and `make_block_2d_copy_D` only differ in their choice of a load (C) or store (D) operation. ```c++ template @@ -167,6 +167,12 @@ CUTE_DEVICE TiledCopy<...> make_block_2d_copy_C(const TiledMMA<...>&, const Tensor& gmem); // (M,N,...) + +template +CUTE_DEVICE +TiledCopy<...> +make_block_2d_copy_D(const TiledMMA<...>&, + const Tensor& gmem); // (M,N,...) ``` The user may also override the choice of copy operation: @@ -179,7 +185,15 @@ make_block_2d_copy_A(CopyOp const& op, // Copy operation TiledMMA const& mma, // TiledMMA instance Tensor const& gmem); // Global tensor -/* Similarly for B/C */ +/* Similarly for B */ + +/* Single routine for both C/D */ +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_CD(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem); // Global tensor ``` The `make_block_2d_copy_*` family of functions create TiledCopy objects that match the scope of the TiledMMA. That is, the set of threads participating in the TiledMMA will also participate in the TiledCopy. @@ -194,7 +208,7 @@ TiledCopy make_block_2d_copy(const CopyOp& op, const Tensor& gmem); ``` -For advanced usage, there are additional overloads of `make_block_2d_copy` that allow more general work distributions for copies (see `include/cute/atom/copy_traits_xe_2d.hpp`). +For advanced usage, there are additional overloads of `make_block_2d_copy` in which multiple subgroups participate (see `include/cute/atom/copy_traits_xe_2d.hpp`). As the `CUTE_DEVICE` decorators imply, all the APIs above should be called from device code only, as they set up internal state that cannot be transferred from host to device. @@ -250,10 +264,10 @@ In order to perform thread-level operations on subgroup-shared data, it's import ```math \text{thread\ } i\ \ \ \text{owns elements} \ \ \ i, i+16, i+32, \ldots ``` -That is to say, elements are assigned to threads in a round-robin fashion. 
Conversely, if we declare a vector variable (say `cute::intel::float8`) in SYCL, the compiler will interleave the vectors from each thread in the subgroup to form a length-128 (8 * 16) float array in registers. +That is to say, elements are assigned to threads in a round-robin fashion. Conversely, if we declare a vector variable (say `cute::intel::float8`) in SYCL, the compiler will interleave the vectors from each thread in the subgroup to form a length-128 (8 * 16) float array in registers. Caveat: types smaller than a byte have special rules, to be discussed later. > [!IMPORTANT] -> Note that the thread mapping _depends on the element data size._ If an array of 32-bit data, say, is reinterpreted _on a register level_ as an array of 16-bit data, data ownership will change -- i.e., in SIMT terms, it is a shuffle operation. Contrast this operation with a SIMT bitcast/reinterpret_cast, which does not change data ownership, but _does_ shuffle data in registers. +> Note that the thread mapping _depends on the element data size._ If an array of 32-bit data, say, is reinterpreted _on a register level_ as an array of 16-bit data, data ownership will change — i.e., in SIMT terms, it is a shuffle operation. Contrast this operation with a SIMT bitcast/reinterpret_cast, which does not change data ownership, but _does_ shuffle data in registers. Now that we have the basic thread mapping rule, let's apply it to a simple block 2D load, with height = 8 rows and width = 4 columns. Recalling that the width dimension is contiguous in both memory and registers, we deduce the following mapping: ```math @@ -285,6 +299,9 @@ Now that we have the basic thread mapping rule, let's apply it to a simple block \end{array} \end{array} ``` +The subgroup view shows the data that the entire subgroup owns. The idea here is that the subgroup owns 32 values, enumerated in the order shown. These indices represent the order of elements in registers. +Recall that Intel GPUs have no notion of a "register owned by a thread." Registers belong to subgroups, because it is a SIMD architecture. + (Following CuTe convention, `TxVy` means thread `x`, value `y`.) An individual DPAS atom's A matrix follows the same pattern, with height ranging from 1 to 8, and width equal to 8 (tf32), 16 (f16/bf16), or 32 (s8/u8). The DPAS C matrix is also organized this way, except that its width is always 16. @@ -314,6 +331,58 @@ As a more complicated example, let's consider a 16-bit VNNI load, with height = The DPAS B matrix follows the same pattern. +#### Sub-byte Types + +When the data type is smaller than a byte, we have to take special care with the interleaved data ownership rule. CUTLASS packs these elements into bytes, and these _bytes_ are interleaved between work-items. That is, work-item $i$ owns bytes $i$, $i+16$, $i+32$, etc. 
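+As a rough sanity check, the byte-interleaving rule above can be written out directly. The sketch below is illustrative only (it is not a CUTLASS API, and the names are invented for this example), assuming a 16-wide subgroup with two int4 elements packed per byte:
+
+```c++
+// Illustrative sketch only: map a linear int4 element index `e` of a subgroup-owned
+// block to its owning work-item (T) and value slot (V), assuming a 16-wide subgroup
+// and two int4 elements packed per byte, with bytes dealt out round-robin.
+struct OwnerValue { int thread; int value; };
+
+constexpr OwnerValue int4_owner(int e, int sg_size = 16) {
+  int elems_per_byte = 2;                       // int4: two elements per byte
+  int byte_index     = e / elems_per_byte;      // byte that holds element e
+  int thread         = byte_index % sg_size;    // bytes are interleaved across work-items
+  int value          = (byte_index / sg_size) * elems_per_byte + (e % elems_per_byte);
+  return {thread, value};                       // e.g. int4_owner(34) == {1, 2}, i.e. T1V2
+}
+```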
+
+Here's how that looks for a 16x4 int4 block 2D load:
+
+```math
+ \begin{array}{c}
+ \text{Subgroup view}\\
+ \begin{array}{cccc}
+ 0 & 1 & 2 & 3\\
+ 4 & 5 & 6 & 7\\
+ 8 & 9 & 10 & 11\\
+ 12 & 13 & 14 & 15\\
+ 16 & 17 & 18 & 19\\
+ 20 & 21 & 22 & 23\\
+ 24 & 25 & 26 & 27\\
+ 28 & 29 & 30 & 31\\
+ 32 & 33 & 34 & 35\\
+ 36 & 37 & 38 & 39\\
+ 40 & 41 & 42 & 43\\
+ 44 & 45 & 46 & 47\\
+ 48 & 49 & 50 & 51\\
+ 52 & 53 & 54 & 55\\
+ 56 & 57 & 58 & 59\\
+ 60 & 61 & 62 & 63\\
+ \end{array}
+ \end{array}
+ \rightarrow
+ \begin{array}{c}
+ \text{Thread view}\\
+ \begin{array}{cccc}
+ \text{T0V0} & \text{T0V1} & \text{T1V0} & \text{T1V1}\\
+ \text{T2V0} & \text{T2V1} & \text{T3V0} & \text{T3V1}\\
+ \text{T4V0} & \text{T4V1} & \text{T5V0} & \text{T5V1}\\
+ \text{T6V0} & \text{T6V1} & \text{T7V0} & \text{T7V1}\\
+ \text{T8V0} & \text{T8V1} & \text{T9V0} & \text{T9V1}\\
+ \text{T10V0} & \text{T10V1} & \text{T11V0} & \text{T11V1}\\
+ \text{T12V0} & \text{T12V1} & \text{T13V0} & \text{T13V1}\\
+ \text{T14V0} & \text{T14V1} & \text{T15V0} & \text{T15V1}\\
+ \text{T0V2} & \text{T0V3} & \text{T1V2} & \text{T1V3}\\
+ \text{T2V2} & \text{T2V3} & \text{T3V2} & \text{T3V3}\\
+ \text{T4V2} & \text{T4V3} & \text{T5V2} & \text{T5V3}\\
+ \text{T6V2} & \text{T6V3} & \text{T7V2} & \text{T7V3}\\
+ \text{T8V2} & \text{T8V3} & \text{T9V2} & \text{T9V3}\\
+ \text{T10V2} & \text{T10V3} & \text{T11V2} & \text{T11V3}\\
+ \text{T12V2} & \text{T12V3} & \text{T13V2} & \text{T13V3}\\
+ \text{T14V2} & \text{T14V3} & \text{T15V2} & \text{T15V3}
+ \end{array}
+ \end{array}
+```
+

 ### The SubgroupTensor Class

@@ -419,7 +488,7 @@ gemm_device(ATensor const& A, // (M,K)
   /* Create block 2D TiledCopies */
   auto copy_a = make_block_2d_copy_A(mma, A);
   auto copy_b = make_block_2d_copy_B(mma, B);
-  auto copy_c = make_block_2d_copy_C(mma, C);
+  auto copy_c = make_block_2d_copy_D(mma, C);

   /* Slice TiledCopy/TiledMMA operations to thread (work-item) level */
   auto thr_mma = mma.get_slice(local_id);
@@ -507,4 +576,4 @@ gemm_device(ATensor const& A, // (M,K)

 ## New Collective MMAs

-... coming later!
\ No newline at end of file
+... coming later!
diff --git a/media/docs/python/xe_cutlass_library.md b/media/docs/python/xe_cutlass_library.md
new file mode 100644
index 0000000000..b348db46d8
--- /dev/null
+++ b/media/docs/python/xe_cutlass_library.md
@@ -0,0 +1,192 @@
+
+
+# Kernel Generation and Manifest
+
+This is a code/kernel generation system that creates a searchable catalog of CUTLASS kernel operations, bridging build-time generation and runtime selection.
+
+## Architecture Overview
+
+**Two-Phase System:**
+1. **Build Time (Python)**: `manifest.py` generates C++ initialization code
+2.
**Runtime (C++)**: Generated code registers operations into a searchable `Manifest` + +``` +Python Generator → C++ Files → Compiled Library → Runtime Catalog +``` + +## Key Components + +### Python Generator (`manifest.py`) + +**Responsibilities:** +- Filter kernels by GPU architecture (SM/Xe), operation type, patterns +- Group operations by kind/architecture/instruction type +- Generate C++ initialization functions and CMake files + +### Generated File Structure +``` +build/tools/library/generated/ +├── initialize_all.cpp +├── gemm/20/tensorop/cutlass3x_xe20_tensorop_gemm_bf16_*.cpp +└── manifest.cmake +``` + +### Architecture Naming +| GPU | Prefix | ID | Example | +|-----|--------|----|---------| +| CUDA | `sm` | 70-90 | `sm80` | +| Intel Xe | `xe` | 12,20 | `xe20` | + +## Runtime API + +### Core Classes + +```cpp +// Manifest: Operation catalog +class Manifest { + Status initialize(); + void append(Operation *op); + OperationVector const& operations() const; +}; + +// Operation: Base kernel interface +class Operation { + virtual Status can_implement(void const *config, void const *args) const = 0; + virtual Status run(void const *args, void *workspace, Stream stream) const = 0; +}; +``` + +### Initialization Hierarchy +```cpp +namespace cutlass::library { + void initialize_all(Manifest &manifest); // All operations + void initialize_all_gemm_operations(Manifest &manifest); // GEMM only + void initialize_all_xe20_gemm_operations(Manifest &manifest); // XE20 GEMM +} +``` + +## Usage Examples + +### Basic Usage +```cpp +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +cutlass::library::Manifest manifest; +cutlass::library::initialize_all(manifest); + +// Find BF16 GEMM +for (auto& op : manifest.operations()) { + if (op->description().name.find("bf16") != std::string::npos) { + // Use operation... + } +} +``` + +### Python Integration +```python +# Use extern "C" wrappers for ctypes integration +from ctypes import CDLL +lib = CDLL("libcutlass_gemm_xe20_gemm.so") +# Call exported C functions that wrap C++ manifest APIs +``` + +**Example Implementation:** See `examples/11_xe20_cutlass_library/` for a complete CMake-based shared library that exports CUTLASS kernels for Python usage via ctypes. + +## Common Patterns + +### Lazy Initialization +```cpp +class LazyManifest { + cutlass::library::Manifest manifest_; + bool initialized_ = false; +public: + cutlass::library::Manifest& get() { + if (!initialized_) { + cutlass::library::initialize_all(manifest_); + initialized_ = true; + } + return manifest_; + } +}; +``` + +### Operation Caching +```cpp +class OperationCache { + std::map cache_; +public: + cutlass::library::Operation* find(const std::string& pattern) { + if (cache_.count(pattern)) return cache_[pattern]; + // Search manifest and cache result... + } +}; +``` + +## Build Integration + +### CMake Configuration +```bash +# Generate for Intel XE20 +cmake .. -DCUTLASS_LIBRARY_GENERATOR_ARCHS="20" +ninja cutlass_library +``` + +### Python Generator +```bash +python3 generator.py --operations=gemm --architectures=20 --build-dir=. 
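+# Note: --architectures=20 targets Xe2 (BMG); use 12 for Xe-HPC (PVC), per the Architecture Naming table above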
+``` + +## Performance Tips + +- **Selective Initialization**: Only initialize needed operation kinds +- **Operation Caching**: Cache frequently used operations +- **Kernel Filtering**: Use build-time filtering to reduce library size +- **Lazy Loading**: Initialize manifest only when needed + +## Debugging + +```bash +# List generated operations +nm -D libcutlass_gemm_xe20_gemm.so | grep initialize + +# Enable Python debug logging +python3 -c "import logging; logging.basicConfig(level=logging.DEBUG)" +``` + +## References + +- **Source**: `python/cutlass_library/manifest.py` +- **Headers**: `tools/library/include/cutlass/library/` +- **Generated**: `build/tools/library/generated/` +- **Examples**: + - `examples/11_xe20_cutlass_library/` - CMake-based shared library for Python integration + - `examples/python/cutlass_library/xe20_gemm_bf16.py` - Python test script using ctypes diff --git a/media/docs/python/xe_library_generation.md b/media/docs/python/xe_library_generation.md new file mode 100644 index 0000000000..2c4786bdc9 --- /dev/null +++ b/media/docs/python/xe_library_generation.md @@ -0,0 +1,218 @@ + + +# Intel SYCL*TLA Library Generation Guide + +**Complete Reference for Intel Xe GPU Architecture Support** + +--- + +## Quick Start + +```bash +# Configure for BMG (Xe2) +cd build +cmake .. -GNinja -DCUTLASS_NVCC_ARCHS="" -DCUTLASS_ENABLE_SYCL=ON -DSYCL_INTEL_TARGET -DCUTLASS_LIBRARY_GENERATOR_ARCHS="20" + +# Build libraries +ninja cutlass_library + +--- + +## Architecture Support + +| GPU | Arch | Compute Cap | File Ext | Arch Tag | +|-----|------|-------------|----------|----------| +| **BMG** (Xe2) | 20 | 12-50 | `.cpp` | `cutlass::arch::Xe20` | +| **PVC** (Xe-HPC) | 12 | 12-50 | `.cpp` | `cutlass::arch::Xe12` | + +**Key Differences from CUDA:** +- Architecture prefix: `xe` (not `sm`) +- File extension: `.cpp` (not `.cu`) +- Compute capability: 12-50 (vs 50-120 for CUDA) + +--- + +## Supported Kernel Types + +### ✅ Homogeneous Types (A == B) + +| Type | A × B → C/D | Math Inst | Tile | Align | Status | +|------|-------------|-----------|------|-------|--------| +| **FP16** | half × half → float | [8,16,16] | 256×256×32 | 8 | ✅ | +| **BF16** | bf16 × bf16 → float | [8,16,16] | 256×256×32 | 8 | ✅ | +| **FP8-E4M3** | e4m3 × e4m3 → float | [8,16,32] | 256×256×64 | 16 | ✅ | +| **FP8-E5M2** | e5m2 × e5m2 → float | [8,16,32] | 256×256×64 | 16 | ✅ | +| **INT8** | int8 × int8 → int32 | [8,16,32] | 256×256×64 | 16 | ✅ | + +**Layout Combinations:** RR, RC, CR, CC (4 variants per type) + +### ❌ Mixed Precision (A ≠ B) + +Mixed precision infrastructure is not supported now: +- FP16 × E4M3/E5M2 → FP32 +- BF16 × E4M3/E5M2 → FP32 +- FP16 × INT4 → FP32 + +--- + +## Generated Libraries + +```bash +$ ls -lh build/tools/library/libcutlass*.so +-rwxrwxr-x 186K libcutlass_gemm_xe20_gemm_bf16.so # BF16 kernels +-rwxrwxr-x 186K libcutlass_gemm_xe20_gemm_e4m3.so # FP8 E4M3 +-rwxrwxr-x 186K libcutlass_gemm_xe20_gemm_e5m2.so # FP8 E5M2 +-rwxrwxr-x 186K libcutlass_gemm_xe20_gemm_f16.so # FP16 kernels +-rwxrwxr-x 186K libcutlass_gemm_xe20_gemm_s8.so # INT8 kernels +-rwxrwxr-x 186K libcutlass_gemm_xe20_gemm.so # Generic +-rwxrwxr-x 19K libcutlass.so # Base library +``` + +### Kernel Naming Convention + +``` +cutlass3x_xe{arch}_{opclass}_{operation}_{dtype}_{tile}_{warp}_{layout}_align{N} +``` + +**Examples:** +```cpp +cutlass3x_xe20_tensorop_gemm_f16_256x256_32x0_nn_align8 // FP16, Row×Row +cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nt_align8 // BF16, Row×Column 
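+cutlass3x_xe20_tensorop_gemm_s8_256x256_64x0_tt_align16   // INT8, Column×Column (illustrative, following the same convention)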
+cutlass3x_xe20_tensorop_gemm_e4m3_256x256_64x0_tn_align16 // E4M3, Column×Row +``` + +**Layout Codes:** `nn`=Row×Row, `nt`=Row×Column, `tn`=Column×Row, `tt`=Column×Column + +--- + +## Build & Usage + +### CMake Configuration + +```bash +# BMG (Xe2) +cmake .. -GNinja -DCUTLASS_ENABLE_SYCL=ON -DCUTLASS_LIBRARY_GENERATOR_ARCHS="20" + +# PVC (Xe-HPC) +cmake .. -GNinja -DCUTLASS_ENABLE_SYCL=ON -DCUTLASS_LIBRARY_GENERATOR_ARCHS="12" +``` + +### Build Targets + +```bash +ninja cutlass_library # All libraries +ninja cutlass_library_gemm_xe20_gemm_bf16 # BF16 only +ninja cutlass_library_gemm_xe20_gemm_f16 # FP16 only +``` + +### Python Generator (Direct) + +```bash +cd build +python3 ../python/cutlass_library/generator.py --operations=gemm --architectures=20 --build-dir=. +``` + +### Python Integration Example + +For Python integration via ctypes, see: +- **`examples/11_xe20_cutlass_library/`** - Complete CMake-based shared library example +- **`examples/python/cutlass_library/xe20_gemm_bf16.py`** - Python test script using ctypes + +**Build and test:** +```bash +# Build the shared library +ninja xe20_cutlass_library_bf16 + +# Test with Python +cd examples/python/cutlass_library +python3 xe20_gemm_bf16.py +``` + +## Troubleshooting + +### No Operations Generated +**Check:** `GenerateIntelXe()` called for arch in [12, 20] in `generator.py` + +### Library Link Errors +``` +undefined reference to `initialize_all_xe20_gemm_bf16_gemm_operations()` +``` +**Solution:** Build and link the specific library: `-lcutlass_gemm_xe20_gemm_bf16` + +## Summary + +### ✅ What Works +- **5 data type libraries** (FP16, BF16, E4M3, E5M2, INT8) +- **~24 operations, 31 .cpp files** generated +- **Homogeneous type kernels** compile cleanly +- **INT32 accumulator** for INT8 +- **FP8→FP16 conversion** in MMA + +### ❌ Limitations +- **Mixed precision** requires grouped GEMM +- **Regular library** only supports ElementA == ElementB +- **No INT4** in regular GEMM + +### 📊 Quick Reference +| Feature | Value | +|---------|-------| +| Arch Numbers | BMG=20, PVC=12 | +| File Ext | `.cpp` | +| Arch Prefix | `xe` | +| CC Range | 12-50 | +| Total Libraries | 7 | +| Total Kernels | ~24 | +| Supported Types | FP16, BF16, E4M3, E5M2, INT8 | + +## Examples and References + +### Practical Examples +- **`examples/11_xe20_cutlass_library/`** - CMake-based shared library for Python integration + - Exports `sycl_tla_gemm_xe20_bf16()` function via extern "C" + - Builds `libxe20_cutlass_library_bf16.so` with proper CMake integration + - Integrated into main examples build system (`ninja cutlass_examples`) + +- **`examples/python/cutlass_library/xe20_gemm_bf16.py`** - Python ctypes integration + - Complete test script using the shared library + - Demonstrates workspace querying, execution, and benchmarking + - Shows proper error handling and performance measurement + +### Build Integration +```bash +# Build the example library +ninja xe20_cutlass_library_bf16 + +# Run Python test +cd examples/python/cutlass_library +python3 xe20_gemm_bf16.py +``` + +--- diff --git a/pyproject.toml b/pyproject.toml index 58d8d0ec61..433433e892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,9 +3,9 @@ requires = ["setuptools"] build-backend = "setuptools.build_meta" [project] -name = "cutlass-sycl" -version = "0.5.0" -description = "CUTLASS-SYCL" +name = "sycl-tla" +version = "0.6.0" +description = "sycl templates for linear algebra" readme = "README.md" requires-python = ">=3.8" license = {text = "BSD-3-Clause"} @@ -23,5 +23,5 @@ dependencies = [ ] 
[project.urls] -"Homepage" = "https://github.com/intel/cutlass-sycl" -"Bug Tracker" = "https://github.com/intel/cutlass-sycl/issues" +"Homepage" = "https://github.com/intel/sycl-tla" +"Bug Tracker" = "https://github.com/intel/sycl-tla/issues" diff --git a/python/README.md b/python/README.md index 3f6365f166..70de228123 100644 --- a/python/README.md +++ b/python/README.md @@ -1,18 +1,18 @@ -![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") +![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete Intel Xe GEMM decomposition") -# Python packages associated with CUTLASS +# Python packages associated with SYCL*TLA (with CUTLASS compatibility) -This directory contains Python packages that are associated with CUTLASS: +This directory contains Python packages that are associated with SYCL*TLA (Intel Xe CUTLASS backend): -* `cutlass_cppgen`: the CUTLASS Python interface, which enables one to compile and run CUTLASS kernels from within Python. Note that this was previously named `cutlass`, but was renamed to disambiguate with the CuTe Python DSL. -* `cutlass_library`: utilities used for enumerating and emitting C++ code for CUTLASS kernels +* `cutlass_cppgen`: the SYCL*TLA Python interface, which enables one to compile and run CUTLASS kernels on Intel GPUs from within Python. Note that this was previously named `cutlass`, but was renamed to disambiguate with the CuTe Python DSL. +* `cutlass_library`: utilities used for enumerating and emitting C++ code for CUTLASS kernels targeting Intel Xe architecture -## CUTLASS Python Interface +## SYCL*TLA Python Interface -The CUTLASS Python interface enables one to compile and run CUTLASS operations from within Python. +The SYCL*TLA Python interface enables one to compile and run CUTLASS operations on Intel GPUs from within Python. ```python -import cutlass +import cutlass_cppgen as cutlass import numpy as np plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor) @@ -22,42 +22,38 @@ plan.run(A, B, C, D) ### Overview -The CUTLASS Python interface prioritizes ease of use. +The SYCL*TLA Python interface prioritizes ease of use for Intel GPU development. It has the following features that support this goal. -* It presents high-level interfaces for operators, that require only few parameters. +* It presents high-level interfaces for operators targeting Intel Xe architecture, that require only few parameters. * It selects sensible default configurations for an operator given the parameters that have been specified. -* It enumerates configurations for users that are known to work in a given setting. +* It enumerates configurations for users that are known to work on Intel GPUs in a given setting. * It favors emitting descriptive Python run-time exceptions instead of C++ compile-time errors, where possible. -* It simplifies exporting CUTLASS kernels to framework extensions (e.g., PyTorch CUDA extensions). +* It simplifies exporting CUTLASS kernels for Intel GPUs to framework extensions (e.g., PyTorch XPU extensions). #### Non-goals -The CUTLASS Python interface does not intend to: +The SYCL*TLA Python interface does not intend to: 1. select optimal kernel configurations, 2. act as a fast container for CUTLASS kernels, or -3. act as a Python-to-CUDA-kernel just-in-time (JIT) compilation engine. +3. act as a Python-to-SYCL-kernel just-in-time (JIT) compilation engine. 
Regarding selection of optimal kernel configurations, the interface favors ease-of-use over maximum configurability. Thus, its default selections for operator parameters may not achieve the highest possible performance in all scenarios. Users wishing to achieve the highest performance possible should either -* select parameters by profiling different combinations of them, or -* use a library such as [cuBLAS](https://developer.nvidia.com/cublas) - that contains heuristics for selecting kernels. - Regarding acting as a fast container for CUTLASS kernels: the interface does not strive to minimize overhead in its Python functions surrounding the running of a kernel. Those wishing to deploy a CUTLASS kernel should either * use the C++ emitted by the Python interface directly, or -* use one of the CUTLASS emitters for automatically creating a framework extension for the kernel (e.g., a PyTorch CUDA extension). +* use one of the CUTLASS emitters for automatically creating a framework extension for the kernel (e.g., a PyTorch XPU extension). -Regarding acting as a Python-to-CUDA-kernel JIT compilation engine: -the interface enables use of CUTLASS in Python code. +Regarding acting as a Python-to-SYCL-kernel JIT compilation engine: +the interface enables use of CUTLASS on Intel GPUs in Python code. It can be used by frameworks for JIT compiling -Python to CUDA kernels, but does not set out to be such a framework. +Python to SYCL kernels, but does not set out to be such a framework. #### Comparison to PyCUTLASS @@ -70,56 +66,97 @@ to operators -- similar to what one must do in specifying template parameters to In contrast, the CUTLASS Python interface aims to provide a higher-level API for declaring, emitting, and compiling kernels that does not require exhaustively defining template parameters. -### Current functionality -The CUTLASS Python interface currently supports the following operations: -* GEMMs -* GEMMs with fused elementwise epilogues (e.g., ReLU) (for pre-SM90 kernels) -* Stream K swizzling (for pre-SM90 kernels) -* Grouped GEMM (for pre-SM90 kernels) +### Current functionality with python interface +The SYCL*TLA Python interface currently supports the following operations on Intel GPUs: +* GEMMs with Intel Xe DPAS (Dot Product Accumulate Systolic) operations +* GEMMs with fused elementwise epilogues (e.g., ReLU, GELU) + ### Getting started -We recommend using the CUTLASS Python interface via an [NGC PyTorch Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch): +We recommend using the SYCL*TLA Python interface with an Intel oneAPI environment and Intel GPU drivers: ```bash -docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.08-py3 -p 8888:8888 +# Install Intel oneAPI toolkit and GPU drivers +# Refer to Intel's documentation for your specific system ``` -The CUTLASS Python interface has been tested with CUDA 11.8, 12.0, and 12.1 on Python 3.8 and 3.9. +#### Prerequisites + +Before installing the SYCL*TLA Python interface, ensure you have the following: + +**System Requirements:** +* Intel GPU (Intel Arc, Intel Data Center GPU Max series, or compatible Intel Xe architecture) +* Intel GPU drivers (latest recommended) +* Intel oneAPI DPC++ compiler +* Python 3.12 (recommended) + +**Required packages:** +* `dpctl`: Intel's Data Parallel Control library for Python +* `torch` with Intel XPU support (if using PyTorch integration) + +The SYCL*TLA Python interface has been tested with Intel oneAPI 2024.2+ and Python 3.12 on Intel Xe GPUs. 
+ +#### Environment setup + +Set up the required environment variables: + +```bash +# Source Intel oneAPI environment +source /opt/intel/oneapi/setvars.sh + +# Set SYCL environment variables +export CUTLASS_USE_SYCL=1 +export ONEAPI_DEVICE_SELECTOR=level_zero:gpu + +# Verify Intel GPU is detected +sycl-ls +``` #### Optional environment variables -Prior to installing the CUTLASS Python interface, one may optionally set the following environment variables: +Prior to installing the SYCL*TLA Python interface, one may optionally set the following environment variables: -* `CUTLASS_PATH`: the path to the cloned CUTLASS repository -* `CUDA_INSTALL_PATH`: the path to the installation of CUDA +* `CUTLASS_PATH`: the path to the cloned SYCL*TLA repository +* `ONEAPI_ROOT`: the path to the Intel oneAPI installation (typically `/opt/intel/oneapi`) If these environment variables are not set, the installation process will infer them to be the following: * `CUTLASS_PATH`: either one directory level above the current directory (i.e., `$(pwd)/..`) if installed locally or in the `source` directory of the location in which `cutlass_library` was installed -* `CUDA_INSTALL_PATH`: the directory holding `/bin/nvcc` for the first version of `nvcc` on `$PATH` (i.e., `which nvcc | awk -F'/bin/nvcc' '{print $1}'`) - -**NOTE:** The version of `cuda-python` installed must match the CUDA version in `CUDA_INSTALL_PATH`. +* `ONEAPI_ROOT`: the default Intel oneAPI installation path #### Installation -Stable releases of the CUTLASS Python interface are available via the `nvidia-cutlass` PyPI package. Any other packages with the name `cutlass` are not affiliated with NVIDIA CUTLASS. +Stable releases of the SYCL*TLA Python interface are available via the `sycl-tla` PyPI package. ```bash -pip install nvidia-cutlass +pip install sycl-tla ``` -The CUTLASS Python interface can also be installed from source by navigating to the root of the CUTLASS directory and performing +You will also need to install the required dependencies: + +```bash +# Install dpctl (Intel Data Parallel Control library) +pip install dpctl + +# Install dpctl for Pytorch 2.9 +pip install dpctl intel-cmplr-lib-rt==2025.2.1 + +# Install Intel PyTorch XPU support (optional, for PyTorch integration) +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu +``` + +The SYCL*TLA Python interface can also be installed from source by navigating to the root of the SYCL*TLA directory and performing ```bash pip install . ``` -If you would like to be able to make changes to the CUTLASS Python interface and have them reflected when using the interface, perform: +If you would like to be able to make changes to the SYCL*TLA Python interface and have them reflected when using the interface, perform: ```bash pip install -e . ``` To test that your installation was successful, you can run: ```python -import cutlass +import cutlass_cppgen as cutlass import numpy as np plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor) @@ -127,65 +164,31 @@ A, B, C, D = [np.ones((128, 128), dtype=np.float16) for i in range(4)] plan.run(A, B, C, D) ``` -### Deep learning framework CUDA extensions -The CUTLASS Python interface provides utilities for exporting a CUTLASS kernel to a deep learning framework CUDA extensions. Currently, PyTorch CUDA extensions can be exported, but a similar pattern could be applied for other frameworks as well. 
An example of this is provided [here](/examples/python/02_pytorch_extension_grouped_gemm.ipynb). - -Currently, the following operations can be exported to a PyTorch CUDA extension: -* GEMM -* Grouped GEMM -* Conv2d - -### Examples - -Jupyter notebook examples of using the CUTLASS Python interface are located in [examples/python](/examples/python). - -To launch these notebooks from this directory, run: -```bash -jupyter-lab ../examples/python -``` - -### Building documentation - -The CUTLASS Python interface uses [Sphinx](https://www.sphinx-doc.org/en/master/) for documentation. - -Building the documentation requires additional packages. The following commands will install them. -```bash -sudo apt-get install pandoc -pip install --upgrade Sphinx furo pandoc myst-parser sphinx-copybutton nbsphinx nbsphinx-link sphinx-inline-tabs -# Needed for jupyter notebooks -pip install ipykernel -# Needed as latest versions are not compatible with Sphinx as of now. -pip install docutils~=0.20.0 -``` +### Deep learning framework XPU extensions +The SYCL*TLA Python interface provides utilities for exporting a CUTLASS kernel to deep learning framework Intel XPU extensions. Currently, PyTorch XPU extensions can be exported, but a similar pattern could be applied for other frameworks as well. -To build documentation, you must first have installed the CUTLASS Python interface via the -[installation instructions](#installation). +Currently, the following operations can be exported to a PyTorch XPU extension: +* GEMM for Intel Xe -Documentation can then be built via the following commands. -```bash -sphinx-apidoc -o docs_src/source/ cutlass/ cutlass/backend* -cd docs_src -make html -mv _build/* ../docs -``` -## CUTLASS library package +## SYCL*TLA library package (with CUTLASS compatibility) -[cutlass_library](/python/cutlass_library) contains utilities for enumerating and emitting CUTLASS C++ kernels. -It is used by the CUTLASS CMake system to construct a library of kernels that can be profiled using the CUTLASS profiler. +[cutlass_library](/python/cutlass_library) contains utilities for enumerating and emitting CUTLASS C++ kernels for Intel Xe architecture. +It is used by the SYCL*TLA CMake system to construct a library of kernels that can be profiled using the CUTLASS profiler on Intel GPUs. To install the `cutlass_library` package, run ```bash python setup_library.py develop --user ``` -Alternatively, `cutlass_library` will automatically be installed if you install the CUTLASS Python interface package. +Alternatively, `cutlass_library` will automatically be installed if you install the SYCL*TLA Python interface package. You can also use the [generator.py](/python/cutlass_library/generator.py) script directly without installing the module. # Copyright Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +Copyright (c) 2024 - 2025 Intel Corporation. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause ``` diff --git a/python/cutlass/__init__.py b/python/cutlass_cppgen/__init__.py similarity index 99% rename from python/cutlass/__init__.py rename to python/cutlass_cppgen/__init__.py index e3172aae76..7c9203f344 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass_cppgen/__init__.py @@ -135,7 +135,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '4.1.0' +this.__version__ = '4.2.1' from cutlass_cppgen.backend import create_memory_pool from cutlass_cppgen.emit.pytorch import pytorch diff --git a/python/cutlass/backend/__init__.py b/python/cutlass_cppgen/backend/__init__.py similarity index 100% rename from python/cutlass/backend/__init__.py rename to python/cutlass_cppgen/backend/__init__.py diff --git a/python/cutlass/backend/arguments.py b/python/cutlass_cppgen/backend/arguments.py similarity index 100% rename from python/cutlass/backend/arguments.py rename to python/cutlass_cppgen/backend/arguments.py diff --git a/python/cutlass/backend/c_types.py b/python/cutlass_cppgen/backend/c_types.py similarity index 99% rename from python/cutlass/backend/c_types.py rename to python/cutlass_cppgen/backend/c_types.py index c6400296b6..8cec99eb42 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass_cppgen/backend/c_types.py @@ -251,6 +251,9 @@ class _HardwareInfo(ctypes.Structure): _fields_ = [ ("device_id", ctypes.c_int), ("sm_count", ctypes.c_int), + ("max_active_clusters", ctypes.c_int), + ("cluster_shape", dim3_), + ("cluster_shape_fallback", dim3_), ] class _GemmArguments(ctypes.Structure): diff --git a/python/cutlass/backend/compiler.py b/python/cutlass_cppgen/backend/compiler.py similarity index 99% rename from python/cutlass/backend/compiler.py rename to python/cutlass_cppgen/backend/compiler.py index 33c6ae9698..0b7c7d04c5 100644 --- a/python/cutlass/backend/compiler.py +++ b/python/cutlass_cppgen/backend/compiler.py @@ -104,8 +104,8 @@ def _encode(self): arch_flag = f"-fsycl-targets={self.arch}" else: arch_flag = f"-arch=sm_{self.arch}" - if self.arch == 90 and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12: - arch_flag += "a" + if self.arch in [90, 100, 101, 103, 120, 121] and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12: + arch_flag += "a" opts.append(arch_flag) diff --git a/python/cutlass/backend/conv2d_operation.py b/python/cutlass_cppgen/backend/conv2d_operation.py similarity index 100% rename from python/cutlass/backend/conv2d_operation.py rename to python/cutlass_cppgen/backend/conv2d_operation.py diff --git a/python/cutlass/backend/epilogue.py b/python/cutlass_cppgen/backend/epilogue.py similarity index 100% rename from python/cutlass/backend/epilogue.py rename to python/cutlass_cppgen/backend/epilogue.py diff --git a/python/cutlass/backend/evt/__init__.py b/python/cutlass_cppgen/backend/evt/__init__.py similarity index 100% rename from python/cutlass/backend/evt/__init__.py rename to python/cutlass_cppgen/backend/evt/__init__.py diff --git a/python/cutlass/backend/evt/backend/__init__.py b/python/cutlass_cppgen/backend/evt/backend/__init__.py similarity index 93% rename from python/cutlass/backend/evt/backend/__init__.py rename to python/cutlass_cppgen/backend/evt/backend/__init__.py index a165454834..945dcf80e3 100644 --- a/python/cutlass/backend/evt/backend/__init__.py +++ b/python/cutlass_cppgen/backend/evt/backend/__init__.py @@ -34,3 +34,5 @@ import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes from 
cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes +from cutlass_cppgen.backend.evt.backend.sm100_emitter import Sm100Emitter +import cutlass_cppgen.backend.evt.backend.sm100_nodes as sm100_nodes diff --git a/python/cutlass/backend/evt/backend/emitter_base.py b/python/cutlass_cppgen/backend/evt/backend/emitter_base.py similarity index 98% rename from python/cutlass/backend/evt/backend/emitter_base.py rename to python/cutlass_cppgen/backend/evt/backend/emitter_base.py index 39723844e8..72a7d8c04d 100644 --- a/python/cutlass/backend/evt/backend/emitter_base.py +++ b/python/cutlass_cppgen/backend/evt/backend/emitter_base.py @@ -52,6 +52,7 @@ def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None: self.dag_ir = dag_ir self.emit_CD = emit_CD self.cc = cc + self.evt_cc = 90 if cc >= 90 else cc if self.cc < 90: self.namespace = "threadblock" else: @@ -103,7 +104,7 @@ def emit_evt(self, node): return "" evt_tmp = f""" -using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.cc}EVT< +using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT< {node.name_camel}, """ sorted_children = self.dag_ir.get_all_inputs(node.name) @@ -140,7 +141,7 @@ def emit_dag(self, node): dag_nodes = ",\n".join(dag_node_strs) return f""" -using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.cc}TopologicalVisitor< +using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor< {DataTypeTag[node.subgraph.element_compute]}, {edge_tuples}, {dag_nodes} diff --git a/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py b/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py new file mode 100644 index 0000000000..db521e5279 --- /dev/null +++ b/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py @@ -0,0 +1,116 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Emitter for Sm100 Epilogue Visitor +""" + +from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag +from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape +from cutlass_cppgen.backend import GemmOperationUniversal +from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks +from cutlass_cppgen.backend.evt.ir.node import TupleEmitter + + +class Sm100CollectiveEpilogue: + def __init__(self, tile_description, + kernel_schedule, + epilogue_schedule, + element_accumulator, + element_d, + fusion_callbacks) -> None: + + self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(tile_description, tile_description.cluster_shape, kernel_schedule) + self.element_accumulator = element_accumulator + if fusion_callbacks.dag_ir.has_node("C"): + self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element + else: + self.element_c = DataType.void + self.element_d = element_d + self.schedule = epilogue_schedule + self.fusion_callbacks = fusion_callbacks + self.opclass = tile_description.math_instruction.opcode_class + + @property + def CtaTileMNK(self) -> str: + """ + The threadblock shape + """ + return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>" + + @property + def EpilogueTileType(self) -> str: + """ + The epilogue tile type + """ + return "cutlass::epilogue::collective::EpilogueTileAuto" + + @property + def Schedule(self) -> str: + return EpilogueScheduleTag[self.schedule] + + def emit(self): + tuple_emitter = TupleEmitter("int64_t") + stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta("D").underlying_impl.stride_mnl + stride_C_str = stride_D_str + if self.fusion_callbacks.dag_ir.has_node("C"): + stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta("C").underlying_impl.stride_mnl + + callback_decl, callback_name = self.fusion_callbacks.emit() + return callback_name, f""" +using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor< + {OpcodeClassTag[self.opclass]}, + {self.CtaTileMNK}, {self.EpilogueTileType}, + {DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]}, + {self.Schedule}, {stride_C_str}, {stride_D_str}, + false /* IsPerColScaleSupported */, + false /* IsBlockScaleSupported */ +>; +{callback_decl} +""" + + +class Sm100Emitter: + def __init__(self, operation: GemmOperationUniversal, graph) -> None: + fusion_callbacks = FusionCallbacks(graph, cc=100, emit_CD=False) + + self.collective_epilogue = Sm100CollectiveEpilogue( + tile_description=operation.tile_description, + kernel_schedule=operation.tile_description.kernel_schedule, + epilogue_schedule=operation.tile_description.epilogue_schedule, + element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_d=fusion_callbacks.dag_ir.get_node_meta("D").element, + 
fusion_callbacks=fusion_callbacks + ) + + def emit(self): + return self.collective_epilogue.emit() diff --git a/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py b/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py new file mode 100644 index 0000000000..33e77b4c9f --- /dev/null +++ b/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py @@ -0,0 +1,134 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +from pycute import product + +from cutlass_library import DataTypeSize, DataTypeTag + +from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl +import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes + +from cutlass_cppgen.backend.library import FloatRoundStyleTag + + +Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl +Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl +Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl +Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl +Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl +Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl +Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl +Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl +Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl +Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl + + +class Sm100AuxLoadImpl(AuxLoadImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor;\n" + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128) + + +class Sm100AuxStoreImpl(AuxStoreImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f""" +using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor< + EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]} +>; +""" + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, + typename {self.descriptor}::CopyOpR2S +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128) diff --git a/python/cutlass/backend/evt/backend/sm80_emitter.py b/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py similarity index 100% rename from 
python/cutlass/backend/evt/backend/sm80_emitter.py rename to python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py diff --git a/python/cutlass/backend/evt/backend/sm80_nodes.py b/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py similarity index 100% rename from python/cutlass/backend/evt/backend/sm80_nodes.py rename to python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py diff --git a/python/cutlass/backend/evt/backend/sm90_emitter.py b/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py similarity index 100% rename from python/cutlass/backend/evt/backend/sm90_emitter.py rename to python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py diff --git a/python/cutlass/backend/evt/backend/sm90_nodes.py b/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py similarity index 100% rename from python/cutlass/backend/evt/backend/sm90_nodes.py rename to python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py diff --git a/python/cutlass/backend/evt/epilogue.py b/python/cutlass_cppgen/backend/evt/epilogue.py similarity index 98% rename from python/cutlass/backend/evt/epilogue.py rename to python/cutlass_cppgen/backend/evt/epilogue.py index 92f71a1082..da446e76d9 100644 --- a/python/cutlass/backend/evt/epilogue.py +++ b/python/cutlass_cppgen/backend/evt/epilogue.py @@ -70,7 +70,7 @@ def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None: # Epilogue Thread Type epilogue_thread_type = self.visitor.epilogue_thread_type - if cc == 90: + if cc_map[cc] in [90, 100]: self.arg_c_type = self.visitor.arg_c_type self.arg_d_type = self.visitor.arg_d_type output_names = self.visitor.return_names @@ -114,7 +114,7 @@ def get_tensor_ptr(self, tensor_name, kwargs, is_output=False): Helper function for extracting device pointer """ # Skip the special tensors - if cc == 90: + if cc in [90, 100]: if tensor_name in ["C", "D"]: return 0 if tensor_name not in kwargs.keys(): diff --git a/python/cutlass/backend/evt/frontend/__init__.py b/python/cutlass_cppgen/backend/evt/frontend/__init__.py similarity index 100% rename from python/cutlass/backend/evt/frontend/__init__.py rename to python/cutlass_cppgen/backend/evt/frontend/__init__.py diff --git a/python/cutlass/backend/evt/frontend/frontend_base.py b/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py similarity index 98% rename from python/cutlass/backend/evt/frontend/frontend_base.py rename to python/cutlass_cppgen/backend/evt/frontend/frontend_base.py index c150bf2046..213aafdbe3 100644 --- a/python/cutlass/backend/evt/frontend/frontend_base.py +++ b/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py @@ -56,6 +56,7 @@ PassPreprocessRed, PassShapeTypePropagation, ) +from cutlass_cppgen.backend.evt.passes.util import cc_map from cutlass_cppgen.backend.utils import device_cc from cutlass_cppgen.epilogue.evt_ops import permute, reshape from cutlass_cppgen.utils.datatypes import library_type @@ -119,7 +120,7 @@ def trace(self, *args, **kwargs): self.pass_manager() # Set the epilogue type self.epilogue_thread_type = self.dag_ir.epilogue_thread_type - if self.cc == 90: + if cc_map[self.cc] in [90, 100]: self.arg_c_type = self.dag_ir.arg_c_type self.arg_d_type = self.dag_ir.arg_d_type self.reduction_names = self.dag_ir.reduction_names diff --git a/python/cutlass/backend/evt/frontend/python_ast.py b/python/cutlass_cppgen/backend/evt/frontend/python_ast.py similarity index 100% rename from python/cutlass/backend/evt/frontend/python_ast.py rename to python/cutlass_cppgen/backend/evt/frontend/python_ast.py diff --git 
a/python/cutlass/backend/evt/ir/__init__.py b/python/cutlass_cppgen/backend/evt/ir/__init__.py similarity index 100% rename from python/cutlass/backend/evt/ir/__init__.py rename to python/cutlass_cppgen/backend/evt/ir/__init__.py diff --git a/python/cutlass/backend/evt/ir/compute_nodes.py b/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py similarity index 100% rename from python/cutlass/backend/evt/ir/compute_nodes.py rename to python/cutlass_cppgen/backend/evt/ir/compute_nodes.py diff --git a/python/cutlass/backend/evt/ir/dag_ir.py b/python/cutlass_cppgen/backend/evt/ir/dag_ir.py similarity index 100% rename from python/cutlass/backend/evt/ir/dag_ir.py rename to python/cutlass_cppgen/backend/evt/ir/dag_ir.py diff --git a/python/cutlass/backend/evt/ir/layout_algorithm.py b/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py similarity index 100% rename from python/cutlass/backend/evt/ir/layout_algorithm.py rename to python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py diff --git a/python/cutlass/backend/evt/ir/layout_nodes.py b/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py similarity index 100% rename from python/cutlass/backend/evt/ir/layout_nodes.py rename to python/cutlass_cppgen/backend/evt/ir/layout_nodes.py diff --git a/python/cutlass/backend/evt/ir/load_nodes.py b/python/cutlass_cppgen/backend/evt/ir/load_nodes.py similarity index 100% rename from python/cutlass/backend/evt/ir/load_nodes.py rename to python/cutlass_cppgen/backend/evt/ir/load_nodes.py diff --git a/python/cutlass/backend/evt/ir/node.py b/python/cutlass_cppgen/backend/evt/ir/node.py similarity index 93% rename from python/cutlass/backend/evt/ir/node.py rename to python/cutlass_cppgen/backend/evt/ir/node.py index e2b3a34a7b..606591b8e7 100644 --- a/python/cutlass/backend/evt/ir/node.py +++ b/python/cutlass_cppgen/backend/evt/ir/node.py @@ -43,6 +43,28 @@ from cutlass_cppgen.backend.evt.ir.tensor import Tensor +class TupleEmitter: + """ + Emit the cute tuple to C++ code + """ + def __init__(self, stride_dtype): + self.stride_dtype = stride_dtype + + def emit(self, py_tuple): + if isinstance(py_tuple, int): + if py_tuple in [0, 1]: + return f"cute::Int<{py_tuple}>" + else: + return f"{self.stride_dtype}" + elif isinstance(py_tuple, tuple): + decl = "cute::Stride<" + for item in py_tuple: + decl += self.emit(item) + ", " + return decl[:-2] + ">" + else: + raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}") + + class ImplBase: """ Base class for Node Implementation @@ -52,7 +74,15 @@ def __init__(self, node) -> None: self.name = node.name self.tensor = node.tensor self._type_decl = None - self.stride_dtype = "int64_t" + self.tuple_emitter = TupleEmitter("int64_t") + + @property + def stride_dtype(self): + return self.tuple_emitter.stride_dtype + + @stride_dtype.setter + def stride_dtype(self, stride_dtype): + self.tuple_emitter.stride_dtype = stride_dtype @staticmethod def match(node, problem_size: tuple): @@ -81,30 +111,13 @@ def name_camel(self) -> str: """ return sub(r"(_|-)+", " ", self.name).title().replace(" ", "") - def _emit_cute_tuple(self, py_tuple): - """ - Emit the cute tuple to C++ code - """ - if isinstance(py_tuple, int): - if py_tuple in [0, 1]: - return f"cute::Int<{py_tuple}>" - else: - return f"{self.stride_dtype}" - elif isinstance(py_tuple, tuple): - decl = "cute::Stride<" - for item in py_tuple: - decl += self._emit_cute_tuple(item) + ", " - return decl[:-2] + ">" - else: - raise ValueError(f"_emit_cute_tuple only accepts tuple or int, 
got {type(py_tuple).__name__}") - @property def stride_mnl(self): """ Typename StrideMNL """ stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2])))) - return self._emit_cute_tuple(stride) + return self.tuple_emitter.emit(stride) def get_non_constant_stride(self, py_tuple): if isinstance(py_tuple, int): diff --git a/python/cutlass/backend/evt/ir/store_nodes.py b/python/cutlass_cppgen/backend/evt/ir/store_nodes.py similarity index 100% rename from python/cutlass/backend/evt/ir/store_nodes.py rename to python/cutlass_cppgen/backend/evt/ir/store_nodes.py diff --git a/python/cutlass/backend/evt/ir/tensor.py b/python/cutlass_cppgen/backend/evt/ir/tensor.py similarity index 100% rename from python/cutlass/backend/evt/ir/tensor.py rename to python/cutlass_cppgen/backend/evt/ir/tensor.py diff --git a/python/cutlass/backend/evt/passes/__init__.py b/python/cutlass_cppgen/backend/evt/passes/__init__.py similarity index 100% rename from python/cutlass/backend/evt/passes/__init__.py rename to python/cutlass_cppgen/backend/evt/passes/__init__.py diff --git a/python/cutlass/backend/evt/passes/graph_drawer.py b/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py similarity index 100% rename from python/cutlass/backend/evt/passes/graph_drawer.py rename to python/cutlass_cppgen/backend/evt/passes/graph_drawer.py diff --git a/python/cutlass/backend/evt/passes/pass_argument_type.py b/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py similarity index 93% rename from python/cutlass/backend/evt/passes/pass_argument_type.py rename to python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py index c458f79978..b0c3cdbde6 100644 --- a/python/cutlass/backend/evt/passes/pass_argument_type.py +++ b/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py @@ -40,6 +40,7 @@ from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.passes.util import cc_map class PassGetArgumentType(EVTPassBase): @@ -54,9 +55,9 @@ class PassGetArgumentType(EVTPassBase): def requires(self) -> None: # Check "D" is in the node list - if self.cc == 90 and (not self.dag_ir.has_node("D")): + if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")): raise SyntaxError( - "Sm90 EVT requires the epilogue to have a returned tensor D, " + "Sm90+ EVT requires the epilogue to have a returned tensor D, " "but the variable 'D' is not found in the return values.") def call(self): @@ -66,7 +67,7 @@ def call(self): meta = self.dag_ir.get_node_meta(node) if not meta.disabled: self.argument_types[node] = meta.underlying_impl.argument_type - if node == "D" and self.cc == 90: + if node == "D" and cc_map[self.cc] in [90, 100]: continue if isinstance(meta, TopoVisitorNode): self.get_dag_argument_type(node) @@ -111,6 +112,9 @@ def sm90_set_argument_type(self): else: self.dag_ir.arg_c_type = self.dag_ir.arg_d_type + def sm100_set_argument_type(self): + self.sm90_set_argument_type() + def sm80_set_argument_type(self): nodes = self.dag_ir.nodes_topological_order() self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]] diff --git a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py b/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py similarity index 99% rename from python/cutlass/backend/evt/passes/pass_dag_2_tree.py rename to 
python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py index 5eae2f92e8..469769664a 100644 --- a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py +++ b/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py @@ -105,7 +105,7 @@ def call(self): output_node = None if (self.dag_ir.cc >= 90): - # For SM90, the lca should be the input node of D + # For SM90+, the lca should be the input node of D if (not self.dag_ir.has_node("D")): raise RuntimeError(f"D is not a node in the DAG IR.") output_node = "D" diff --git a/python/cutlass/backend/evt/passes/pass_fix_element_d.py b/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py similarity index 100% rename from python/cutlass/backend/evt/passes/pass_fix_element_d.py rename to python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py diff --git a/python/cutlass/backend/evt/passes/pass_get_impl.py b/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py similarity index 100% rename from python/cutlass/backend/evt/passes/pass_get_impl.py rename to python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py diff --git a/python/cutlass/backend/evt/passes/pass_layout_elimination.py b/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py similarity index 100% rename from python/cutlass/backend/evt/passes/pass_layout_elimination.py rename to python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py diff --git a/python/cutlass/backend/evt/passes/pass_manager.py b/python/cutlass_cppgen/backend/evt/passes/pass_manager.py similarity index 100% rename from python/cutlass/backend/evt/passes/pass_manager.py rename to python/cutlass_cppgen/backend/evt/passes/pass_manager.py diff --git a/python/cutlass/backend/evt/passes/pass_no_op_elimination.py b/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py similarity index 100% rename from python/cutlass/backend/evt/passes/pass_no_op_elimination.py rename to python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py diff --git a/python/cutlass/backend/evt/passes/pass_preprocess_red.py b/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py similarity index 100% rename from python/cutlass/backend/evt/passes/pass_preprocess_red.py rename to python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py diff --git a/python/cutlass/backend/evt/passes/pass_shape_type_propagation.py b/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py similarity index 100% rename from python/cutlass/backend/evt/passes/pass_shape_type_propagation.py rename to python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py diff --git a/python/cutlass/backend/evt/passes/smem_size_calculator.py b/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py similarity index 61% rename from python/cutlass/backend/evt/passes/smem_size_calculator.py rename to python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py index 4896840e5a..8168c59733 100644 --- a/python/cutlass/backend/evt/passes/smem_size_calculator.py +++ b/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py @@ -34,12 +34,14 @@ Compute the shared memory size in bytes """ +from math import gcd + import cutlass_library -from pycute import shape_div, product +from pycute import flatten, shape_div, product import cutlass_cppgen from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR -from cutlass_cppgen.backend.library import DataTypeSize +from cutlass_cppgen.backend.library import DataType, DataTypeSize class GetSmemSize: @@ -58,12 +60,15 @@ def 
sm90_epilogue_tile(self, tile_description): # Get the epilogue tile size schedule = tile_description.epilogue_schedule if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized: - epilogue_tile_mn = (64, 32) + element_d = self.dag_ir.get_node_meta("D").element + nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32 + epi_tile_m = min(64, tile_description.threadblock_shape[0]) + epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) + epilogue_tile_mn = (epi_tile_m, epi_tile_n) elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative: - if tile_description.threadblock_shape[0] >= 128: - epilogue_tile_mn = (128, 32) - else: - epilogue_tile_mn = (64, 32) + epi_tile_m = min(128, tile_description.threadblock_shape[0]) + epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) + epilogue_tile_mn = (epi_tile_m, epi_tile_n) else: raise NotImplementedError(f"Unsupported schedule: {schedule}") @@ -93,11 +98,7 @@ def sm90_epilogue_tile(self, tile_description): self.element_d = element_d self.is_source_supported = element_c is not None - def sm90_epilogue_smem_size(self, tile_description): - """ - Compute the shared memory size of sm90 collective epilogue - """ - self.sm90_epilogue_tile(tile_description) + def sm90_or_sm100_epilogue_smem_size(self, tile_description): # Get the Fusion Storage nodes = self.dag_ir.nodes_topological_order() self.smem_types = {} @@ -139,6 +140,120 @@ def sm90_epilogue_smem_size(self, tile_description): smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size]) return smem_size[0] + def sm90_epilogue_smem_size(self, tile_description): + """ + Compute the shared memory size of sm90 collective epilogue + """ + self.sm90_epilogue_tile(tile_description) + return self.sm90_or_sm100_epilogue_smem_size(tile_description) + + # + # Sm100 epilogue specific + # + + def sm100_epilogue_tile(self, tile_description): + cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1]) + mma_tile = cta_tile + + if tile_description.is_2sm: + cta_tile = (cta_tile[0] // 2, cta_tile[1]) + + if tile_description.is_2sm and mma_tile[0] == 128: + tmem_warps = (2, 2) + else: + tmem_warps = (4, 1) + + if self.dag_ir.has_node("C"): + element_c = self.dag_ir.get_node_meta("C").element + element_c_size = DataTypeSize[element_c] + else: + element_c = None + element_c_size = 0 + + element_d = self.dag_ir.get_node_meta("D").element + + DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void + + CtaM = cta_tile[0] + CtaN = cta_tile[1] + WarpM = tmem_warps[0] + WarpN = tmem_warps[1] + MaxBits = max(element_c_size, DataTypeSize[element_d]) + DpFull = 32 + M = min(CtaM, DpFull * WarpM) + + if DisableSource: + # Epilogues w/o residual load are less sensitive to smem allocation + # Target a fixed amount of compute per epilogue iteration + if MaxBits == 4: + # Make epilogue tile larger to reduce the epilogue iterations. + # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. 
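+ # The element budget chosen below becomes a target tile N extent via Nperf = ComputeElts // M, where M = min(CtaM, DpFull * WarpM) above.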
+ ComputeElts = 8192 + Nperf = ComputeElts // M + else: + ComputeElts = 4096 + Nperf = ComputeElts // M + else: + # Epilogues w/ residual load are more sensitive to smem allocation + # Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize + if MaxBits == 32: + Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32 + elif MaxBits == 16: + Nperf = 32 if CtaN <= 128 else 64 + else: + Nperf = 64 + + def is_m_major(layout): + return flatten(layout.stride[0]) == 1 + + if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout): + N_min_C = 8 * WarpN + elif element_c_size == 6: + N_min_C = 128 * WarpN + else: + N_min_C = (128 // element_c_size) * WarpN + + if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout): + N_min_D = 8 * WarpN + elif DataTypeSize[element_d] == 6: + N_min_D = 128 * WarpN + else: + N_min_D = (128 // DataTypeSize[element_d]) * WarpN + + N = min(CtaN, max(Nperf, N_min_C, N_min_D)) + + tile_m = M + tile_n_size = N // WarpN * WarpN + + epilogue_tile_mn = (tile_m, tile_n_size) + epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)) + + stages_d = min(epi_tiles, 2) + reuse_smem_c = (element_c_size > 8) + + if reuse_smem_c: + stages_c = max(min(epi_tiles, 4), stages_d + 1) + else: + stages_c = min(epi_tiles, 4) + + # Record the epilogue tile + self.cta_tile_mnk = tuple(tile_description.threadblock_shape) + self.epilogue_tile_mn = epilogue_tile_mn + self.epi_tiles = epi_tiles + self.stages_c = stages_c + self.stages_d = stages_d + self.reuse_smem_c = reuse_smem_c + self.element_c = element_c + self.element_d = element_d + self.is_source_supported = not DisableSource + + def sm100_epilogue_smem_size(self, tile_description): + """ + Compute the shared memory size of sm100 collective epilogue + """ + self.sm100_epilogue_tile(tile_description) + return self.sm90_or_sm100_epilogue_smem_size(tile_description) + def __call__(self, tile_description): return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description) diff --git a/python/cutlass/backend/evt/passes/util.py b/python/cutlass_cppgen/backend/evt/passes/util.py similarity index 92% rename from python/cutlass/backend/evt/passes/util.py rename to python/cutlass_cppgen/backend/evt/passes/util.py index ad014bf53b..c3a0a8ebbb 100644 --- a/python/cutlass/backend/evt/passes/util.py +++ b/python/cutlass_cppgen/backend/evt/passes/util.py @@ -36,8 +36,13 @@ # Map from the CC of the kernel to the EVT implementation that the CC targets cc_map = { - 80: 80, - 86: 80, - 89: 80, - 90: 90, + 12: 12, # Intel Xe12 PVC + 20: 20, # Intel Xe20 BMG + 80: 80, + 86: 80, + 89: 80, + 90: 90, + 100: 100, + 101: 100, + 103: 100, } diff --git a/python/cutlass/backend/frontend.py b/python/cutlass_cppgen/backend/frontend.py similarity index 100% rename from python/cutlass/backend/frontend.py rename to python/cutlass_cppgen/backend/frontend.py diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass_cppgen/backend/gemm_operation.py similarity index 99% rename from python/cutlass/backend/gemm_operation.py rename to python/cutlass_cppgen/backend/gemm_operation.py index b836df50ac..b6b30188a8 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass_cppgen/backend/gemm_operation.py @@ -39,6 +39,7 @@ cuda = lazy_import("cuda.cuda") cudart = lazy_import("cuda.cudart") from cutlass_library import SubstituteTemplate +from cutlass_library.arch_constants import is_intel_xe_arch import numpy as np import dpctl @@ -188,7 +189,7 @@ def __init__(self, 
operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalM if operation.C.layout in [LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32]: raise Exception("Interleaved layout not currently supported") - if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch != 90: + if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch not in [90, 100, 101, 103]: super().__init__(A, B, None, None, **kwargs) else: super().__init__(A, B, C, D, **kwargs) @@ -573,7 +574,9 @@ def get_arguments(self): # Set hardware info hw_info_ = hw_info( - 0, device_sm_count(), + 0, device_sm_count(), 0, + dim3_(0,0,0), + dim3_(0,0,0), ) self.arguments = argument_type( @@ -915,7 +918,7 @@ def get_device_workspace_size(self, arguments): return 0 def initialize(self): - if self.operation.arch == 11: + if is_intel_xe_arch(self.operation.arch): return err, = cuda.cuFuncSetAttribute( @@ -1318,7 +1321,7 @@ def __init__(self, operation_suffix=""): def emit(self, operation): # Support built-in epilogue functors or user-defined functions - if operation.arch == 11: + if is_intel_xe_arch(operation.arch): stage_count_type = "cutlass::gemm::collective::StageCountAuto" elif operation.tile_description.stages is None or operation.tile_description.stages == 0: stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>" @@ -1340,7 +1343,7 @@ def emit(self, operation): if operation.tile_description.tile_scheduler is not None: tschedule = operation.tile_description.tile_scheduler - arch = "cutlass::arch::IntelXe" if operation.arch == 11 else f"cutlass::arch::Sm{operation.arch}" + arch = f"cutlass::arch::Xe{operation.arch}" if is_intel_xe_arch(operation.arch) else f"cutlass::arch::Sm{operation.arch}" values = { "operation_name": operation.procedural_name(), "operation_suffix": self.operation_suffix, @@ -1718,10 +1721,15 @@ def epilogue_schedule_name_3x(self): def procedural_name(self): """The full procedural name indicates architecture, extended name, tile size, and layout.""" opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - if self.api == ApiVersion.v3x and (self.arch >= 90 or self.arch == 11): - kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}" + if self.api == ApiVersion.v3x and (self.arch >= 90 or is_intel_xe_arch(self.arch)): + arch_prefix="sm" + if is_intel_xe_arch(self.arch): + arch_prefix="Xe" + + kernel_name_template = "cutlass{p}_{sm_or_xe}{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}" return kernel_name_template.format( p=self.prefix, + sm_or_xe=arch_prefix, ar=self.arch, op=opcode_class_name, ex=self.extended_name_3x(), diff --git a/python/cutlass/backend/library.py b/python/cutlass_cppgen/backend/library.py similarity index 94% rename from python/cutlass/backend/library.py rename to python/cutlass_cppgen/backend/library.py index 1b9eddb469..4bdae90b7c 100644 --- a/python/cutlass/backend/library.py +++ b/python/cutlass_cppgen/backend/library.py @@ -41,12 +41,13 @@ DataType, DataTypeSize, EpilogueScheduleType, + KernelScheduleSuffixes, KernelScheduleType, MathOperation, OpcodeClass, TileSchedulerType ) - +from cutlass_library.arch_constants import is_intel_xe_arch # The following block implements enum.auto() for Python 3.5 variants that don't include it such # as the default 3.5.2 on Ubuntu 16.04. 
@@ -238,6 +239,22 @@ def __init__( self.math_operation = math_operation +def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule): + blackwell_threadblock_shape = tile_description.threadblock_shape + is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule]) + if cluster_shape[0] > 0: + blackwell_threadblock_shape = [ + tile_description.threadblock_shape[0] // cluster_shape[0], + tile_description.threadblock_shape[1] // cluster_shape[1], + tile_description.threadblock_shape[2] // cluster_shape[2] + ] + if is_2sm: + blackwell_threadblock_shape[0] *= 2 + else: + blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape + return blackwell_threadblock_shape, is_2sm + + class TileDescription: """ Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes, @@ -290,6 +307,8 @@ def __init__( # Number of warps along x, y, z directions self.warp_count = warp_count + self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule) + def clone_and_update(self, td: dict): attrs = { "cluster_shape": None, @@ -473,10 +492,10 @@ def api_version(arch, opclass, dtype): :return: API version to be used in code emission :rtype: ApiVersion """ - if opclass == OpcodeClass.TensorOp and arch == 11: + if opclass == OpcodeClass.TensorOp and is_intel_xe_arch(arch): return ApiVersion.v3x - if (arch >= 90 and + if (arch in [90, 100, 101, 103] and opclass == OpcodeClass.TensorOp and (dtype != DataType.f64)): return ApiVersion.v3x diff --git a/python/cutlass/backend/memory_manager.py b/python/cutlass_cppgen/backend/memory_manager.py similarity index 100% rename from python/cutlass/backend/memory_manager.py rename to python/cutlass_cppgen/backend/memory_manager.py diff --git a/python/cutlass/backend/operation.py b/python/cutlass_cppgen/backend/operation.py similarity index 98% rename from python/cutlass/backend/operation.py rename to python/cutlass_cppgen/backend/operation.py index 19fdb88c02..7f1325c1b8 100644 --- a/python/cutlass/backend/operation.py +++ b/python/cutlass_cppgen/backend/operation.py @@ -47,7 +47,7 @@ def supports_cluster_launch(): global _supports_cluster_launch if _supports_cluster_launch is None: major, minor = _version_splits[0], _version_splits[1] - _supports_cluster_launch = device_cc() >= 90 and (major > 11 or (major == 11 and minor >= 8)) + _supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8)) return _supports_cluster_launch diff --git a/python/cutlass/backend/reduction_operation.py b/python/cutlass_cppgen/backend/reduction_operation.py similarity index 100% rename from python/cutlass/backend/reduction_operation.py rename to python/cutlass_cppgen/backend/reduction_operation.py diff --git a/python/cutlass/backend/type_hint.py b/python/cutlass_cppgen/backend/type_hint.py similarity index 100% rename from python/cutlass/backend/type_hint.py rename to python/cutlass_cppgen/backend/type_hint.py diff --git a/python/cutlass/backend/utils/__init__.py b/python/cutlass_cppgen/backend/utils/__init__.py similarity index 100% rename from python/cutlass/backend/utils/__init__.py rename to python/cutlass_cppgen/backend/utils/__init__.py diff --git a/python/cutlass/backend/utils/device.py b/python/cutlass_cppgen/backend/utils/device.py similarity index 98% rename from python/cutlass/backend/utils/device.py rename to 
python/cutlass_cppgen/backend/utils/device.py index c0c33700ff..f96071a3dd 100644 --- a/python/cutlass/backend/utils/device.py +++ b/python/cutlass_cppgen/backend/utils/device.py @@ -81,8 +81,8 @@ def device_cc(device: int = -1) -> int: device = cutlass_cppgen.device_id() if cutlass_cppgen._use_sycl: - # Using '11' to encode Intel PVC as an integer in the expected format. - return 11 + # Using '12' to encode Intel PVC as an integer in the expected format. + return 12 deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device)) major = str(deviceProp.major) diff --git a/python/cutlass/emit/__init__.py b/python/cutlass_cppgen/emit/__init__.py similarity index 100% rename from python/cutlass/emit/__init__.py rename to python/cutlass_cppgen/emit/__init__.py diff --git a/python/cutlass/emit/common.py b/python/cutlass_cppgen/emit/common.py similarity index 100% rename from python/cutlass/emit/common.py rename to python/cutlass_cppgen/emit/common.py diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass_cppgen/emit/pytorch.py similarity index 99% rename from python/cutlass/emit/pytorch.py rename to python/cutlass_cppgen/emit/pytorch.py index 86374b8b0c..fe96f3ede1 100644 --- a/python/cutlass/emit/pytorch.py +++ b/python/cutlass_cppgen/emit/pytorch.py @@ -689,10 +689,10 @@ def _jit(name: str, cc: int, cpp_file: str, cuda_file: str): from torch.utils.cpp_extension import load extra_cuda_cflags = ["-std=c++17"] - if cc == 90: + if cc in [90, 100, 101, 103]: # PyTorch does not currently add the sm_90a target when compute capability # 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target. - extra_cuda_cflags.append("-gencode=arch=compute_90a,code=sm_90a") + extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a") with _ArchListSetter(cc): jitmodule = load( @@ -768,8 +768,8 @@ def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = "" outfile.write(cpp_source) extra_compile_args = "" - if cc == 90: - extra_compile_args = "'--generate-code=arch=compute_90a,code=[sm_90a]'" + if cc in [90, 100, 101, 103]: + extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'" _generate_setup(name, sourcedir, extra_compile_args) if jit: diff --git a/python/cutlass/epilogue/__init__.py b/python/cutlass_cppgen/epilogue/__init__.py similarity index 100% rename from python/cutlass/epilogue/__init__.py rename to python/cutlass_cppgen/epilogue/__init__.py diff --git a/python/cutlass/epilogue/epilogue.py b/python/cutlass_cppgen/epilogue/epilogue.py similarity index 89% rename from python/cutlass/epilogue/epilogue.py rename to python/cutlass_cppgen/epilogue/epilogue.py index d3887c4842..130b829036 100644 --- a/python/cutlass/epilogue/epilogue.py +++ b/python/cutlass_cppgen/epilogue/epilogue.py @@ -118,7 +118,7 @@ def trace(fn, example_tensors, **kwargs): """ Trace `fn(**example_tensors)` and generates epilogue visitor - :param fn: Python callables + :param fn or str: Python callable or string of the epilogue function :param example_tensors: example inputs for fn :type example_tensors: dict @@ -153,6 +153,22 @@ def __init__(self, cc=None, **kwargs): pass setattr(EpilogueFunctor, "__call__", staticmethod(fn)) + epilogue_functor = EpilogueFunctor(**kwargs) + epilogue_functor.trace(example_tensors) + return epilogue_functor + elif isinstance(fn, str): + class EpilogueFunctor(PythonASTFrontend): + def __init__(self, cc=None, **kwargs): + self.source = textwrap.dedent(fn) + if not cc: + cc = device_cc() + super().__init__(cc, **kwargs) 
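+ # parse() below rebuilds the AST from the stored source string before visiting it.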
+ + def parse(self, example_inputs) -> None: + self.example_inputs = example_inputs + self.ast = ast.parse(self.source) + self.visit(self.ast) + epilogue_functor = EpilogueFunctor(**kwargs) epilogue_functor.trace(example_tensors) return epilogue_functor diff --git a/python/cutlass/epilogue/evt_ops.py b/python/cutlass_cppgen/epilogue/evt_ops.py similarity index 100% rename from python/cutlass/epilogue/evt_ops.py rename to python/cutlass_cppgen/epilogue/evt_ops.py diff --git a/python/cutlass/library_defaults.py b/python/cutlass_cppgen/library_defaults.py similarity index 94% rename from python/cutlass/library_defaults.py rename to python/cutlass_cppgen/library_defaults.py index 7ddf96e02f..3691ab5469 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass_cppgen/library_defaults.py @@ -40,14 +40,24 @@ import cutlass_library from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode +from cutlass_library.arch_constants import ( + INTEL_XE_ARCH_MIN, + INTEL_XE_ARCH_MAX, + INTEL_XE12, + INTEL_XE20, + INTEL_XE35, + is_intel_xe_arch +) import cutlass_cppgen from cutlass_cppgen.utils.check import valid_stage_count from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op -# The value '11' is used to encode Intel PVC GPU in the expected format. -_generator_ccs = [11, 50, 60, 61, 70, 75, 80, 90] +# Intel Xe architectures and supported NVIDIA architectures +# Intel Xe: 12 (PVC/Xe-HPC), 20 (BMG/Xe2), 30 (future) +# NVIDIA architectures: 50, 60, 61, 70, 75, 80, 90, 100 +_generator_ccs = [INTEL_XE12, INTEL_XE20] #50, 60, 61, 70, 75, 80, 90, 100] class KernelsForDataType: """ @@ -259,9 +269,17 @@ def __init__( self.op_class = None self.allowed_math_operations = allowed_math_operations + if target_cc == 100 and kernel_cc == 90 or target_cc == 90 and kernel_cc == 100: + return + # Identify the method within CUTLASS generator script that generates kernel # descriptions for the target CC - generate_function_name = "GeneratePVC" if kernel_cc == 11 else "GenerateSM" + str(kernel_cc) + # Intel Xe architectures use GenerateIntelXe, NVIDIA uses GenerateSM{cc} + if is_intel_xe_arch(kernel_cc): + generate_function_name = "GenerateIntelXe" + else: + generate_function_name = "GenerateSM" + str(kernel_cc) + if not hasattr(cutlass_library.generator, generate_function_name): cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}") return @@ -273,13 +291,20 @@ def __init__( "--kernels=all", f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}" ] - if self.cc == 11: - args.append("--architectures=11") + # For Intel Xe architectures, specify the architecture number + if is_intel_xe_arch(kernel_cc): + args.append(f"--architectures={kernel_cc}") manifest_args = cutlass_library.generator.define_parser().parse_args(args) manifest = cutlass_library.manifest.Manifest(manifest_args) - generate_function(manifest, cutlass_cppgen._nvcc_version) - + + # For Intel Xe architectures, pass the architecture number to the generator + if is_intel_xe_arch(kernel_cc): + print(f"Calling {generate_function_name} with arch={kernel_cc}") + generate_function(manifest, cutlass_cppgen._nvcc_version, arch=kernel_cc) + else: + generate_function(manifest, cutlass_cppgen._nvcc_version) + if operation_kind not in manifest.operations: # No kernels generated for this architecture, this could be because the CUDA # toolkit is insufficient to support operations in this CC @@ -296,6 +321,7 @@ def __init__( # find available opclasses and data types for 
name, op_list in manifest.operations[operation_kind][kernel_cc].items(): for op in op_list: + if operation_kind == cutlass_library.OperationKind.Gemm: if op.gemm_kind not in gemm_kinds: continue @@ -320,7 +346,7 @@ def __init__( # TF32 kernels only supported on SM80 and beyond if self.cc < 80: continue - elif self.cc == 90: + elif self.cc == 90 or self.cc == 100: if (op.A.element != cutlass_library.DataType.f32 or op.B.element != cutlass_library.DataType.f32 or op.C.element != cutlass_library.DataType.f32): @@ -554,8 +580,10 @@ class OptionRegistry: def __init__(self, target_cc: int): self.registry = {} - if target_cc > 90: - raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to 90.") + # Intel Xe architectures: 12-20 (PVC, BMG, etc.) + # NVIDIA architectures: 50-121 + if (target_cc > 100 and (target_cc not in [101, 103, 120, 121])) or (not is_intel_xe_arch(target_cc)): + raise Exception(f"Unsupported compute capability {target_cc}. Supported: NVIDIA SM 50-121, Intel Xe 12-20.") gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x] operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d] diff --git a/python/cutlass/op/__init__.py b/python/cutlass_cppgen/op/__init__.py similarity index 100% rename from python/cutlass/op/__init__.py rename to python/cutlass_cppgen/op/__init__.py diff --git a/python/cutlass/op/conv.py b/python/cutlass_cppgen/op/conv.py similarity index 99% rename from python/cutlass/op/conv.py rename to python/cutlass_cppgen/op/conv.py index 4f21d85436..711b27da13 100644 --- a/python/cutlass/op/conv.py +++ b/python/cutlass_cppgen/op/conv.py @@ -212,7 +212,7 @@ def __init__( ): super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d) # Verify the kernel cc - if self.current_cc == 90: + if self.current_cc in [90, 100, 101, 103]: # The Conv2d kernel on Hopper (SM90) is currently unsupported # Revert to use SM80-tagged kernels cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. 
Opclass may change.") diff --git a/python/cutlass/op/gemm.py b/python/cutlass_cppgen/op/gemm.py similarity index 98% rename from python/cutlass/op/gemm.py rename to python/cutlass_cppgen/op/gemm.py index b4fb710028..052c7145b0 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass_cppgen/op/gemm.py @@ -123,6 +123,7 @@ DataType, DataTypeSize, GemmUniversalMode, + KernelScheduleSuffixes, ) import cutlass_cppgen @@ -324,8 +325,8 @@ def swizzling_functor(self, swizzling_functor): if self.op_class == cutlass_cppgen.OpcodeClass.Simt: raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp') - if self.current_cc == 90: - raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90') + if self.current_cc in [90, 100, 101, 103]: + raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+') self._swizzling_functor = swizzling_functor # @@ -395,6 +396,11 @@ def _valid_tile_description(self, td: TileDescription) -> tuple: return (valid, msg) valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler) + + if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0: + valid = False + msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103" + return valid, msg def tile_descriptions(self) -> list: diff --git a/python/cutlass/op/gemm_grouped.py b/python/cutlass_cppgen/op/gemm_grouped.py similarity index 99% rename from python/cutlass/op/gemm_grouped.py rename to python/cutlass_cppgen/op/gemm_grouped.py index 594106f2d1..59f90535c2 100644 --- a/python/cutlass/op/gemm_grouped.py +++ b/python/cutlass_cppgen/op/gemm_grouped.py @@ -133,7 +133,7 @@ def __init__( ) # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80 - if self.current_cc == 90: + if self.current_cc in [90, 100, 101, 103]: self._reset_options(80) self._reset_operations(reset_epilogue=False) diff --git a/python/cutlass/op/op.py b/python/cutlass_cppgen/op/op.py similarity index 95% rename from python/cutlass/op/op.py rename to python/cutlass_cppgen/op/op.py index 88ccd26e07..bebf07a7e5 100644 --- a/python/cutlass/op/op.py +++ b/python/cutlass_cppgen/op/op.py @@ -47,6 +47,7 @@ import cutlass_cppgen from cutlass_cppgen import get_option_registry from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.evt.passes.util import cc_map from cutlass_cppgen.backend.utils.device import device_cc from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs @@ -251,13 +252,13 @@ def math_operation(self, mo: cutlass_cppgen.MathOperation): mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo) if not self.specified_kernel_cc: - if self.current_cc == 90: + if self.current_cc in [90, 100, 101, 103]: # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") self._reset_options(80) self._reset_operations(reset_epilogue=False) - elif self.current_cc == 90: + elif self.current_cc in [90, 100, 101, 103]: raise Exception("CUTLASS 3.0 kernels do not use different math operations. 
" "To use 2.x kernels with a specific math operation, do not set the `kernel_cc`" "parameter when constructing the plan.") @@ -283,7 +284,7 @@ def _create_epilogue_functor_activation(self, activation): elements_per_access = self.epilogue_functor.epilogue_vector_length if not self.specified_kernel_cc: - if self.current_cc == 90 and activation != identity: + if self.current_cc in [90, 100, 101, 103] and activation != identity: # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation, # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") @@ -291,13 +292,13 @@ def _create_epilogue_functor_activation(self, activation): raise Exception("CUTLASS 2.x kernels require element C to be the same as element D") self._reset_options(80) self._reset_operations(reset_epilogue=False) - elif (self.cc == 90 and self.current_cc != 90 and activation == identity and self._math_operation is None): + elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None): # SM80 fallback kernels are currently used. Since an identity activation is requested, # we can switch back to using SM90 kernels. - self._reset_options(90) + self._reset_options(self.cc) self._reset_operations(reset_epilogue=False) else: - if self.current_cc == 90 and activation != identity: + if self.current_cc in [90, 100, 101, 103] and activation != identity: raise Exception("Epilogues with elementwise fusion are not currently supported " "in the Python interface for 3.x kernels. To use 2.x kernels " "with fused elementwise epilogues, do not set the `kernel_cc` " @@ -385,12 +386,12 @@ def epilogue_visitor(self, visitor): """ Create the epilogue visitor """ - self.epilogue_functor = EpilogueFunctorVisitor(self.cc, visitor) + self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor) # The epilogue_functor may consume too much shared memory # Reset the possible operations - if self.cc != 90: - # The shared memory is only a concern for sm90 epilogue + if self.cc not in [90, 100, 101, 103]: + # The shared memory is only a concern for sm90+ epilogue # In sm80, the epilogue and mainloop share the shared memory return @@ -400,7 +401,7 @@ def epilogue_visitor(self, visitor): for operation in self.possible_operations.all_operations: td = datatypes.td_from_profiler_op(operation) # Filter invalid epilogue schedules - if td.epilogue_schedule not in [ + if cc_map[self.cc] == 90 and td.epilogue_schedule not in [ cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized, cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]: continue diff --git a/python/cutlass/shape.py b/python/cutlass_cppgen/shape.py similarity index 100% rename from python/cutlass/shape.py rename to python/cutlass_cppgen/shape.py diff --git a/python/cutlass/swizzle.py b/python/cutlass_cppgen/swizzle.py similarity index 100% rename from python/cutlass/swizzle.py rename to python/cutlass_cppgen/swizzle.py diff --git a/python/cutlass/utils/__init__.py b/python/cutlass_cppgen/utils/__init__.py similarity index 100% rename from python/cutlass/utils/__init__.py rename to python/cutlass_cppgen/utils/__init__.py diff --git a/python/cutlass/utils/check.py b/python/cutlass_cppgen/utils/check.py similarity index 88% rename from python/cutlass/utils/check.py rename to python/cutlass_cppgen/utils/check.py index aa6cb804a2..89725bc6e0 100644 
--- a/python/cutlass/utils/check.py +++ b/python/cutlass_cppgen/utils/check.py @@ -36,8 +36,15 @@ import ctypes -from cutlass_library import DataTypeSize, OperationKind, SharedMemPerCC - +from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC +from cutlass_library.arch_constants import ( + INTEL_XE_ARCH_MIN, + INTEL_XE_ARCH_MAX, + INTEL_XE12, + INTEL_XE20, + INTEL_XE35, + is_intel_xe_arch +) import cutlass_cppgen from cutlass_cppgen.backend.library import TileDescription @@ -54,7 +61,9 @@ def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: Operatio :return: number of bytes of shared memory consumed by a single stage :rtype: int """ - m, n, k = td.threadblock_shape + m, n, k = td.blackwell_threadblock_shape + if td.is_2sm: + m //= 2 if operation_kind == OperationKind.Gemm: stage_barrier_bytes = 32 @@ -106,7 +115,7 @@ def valid_stage_count( valid for the provided device and the second element being an error message :rtype: tuple """ - if kernel_cc == 90: + if kernel_cc in [90, 100, 101, 103]: if (td.stages is None or td.stages == 0): # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically # determines the stage count to use. Thus, all settings are valid in these scenarios. @@ -117,16 +126,16 @@ def valid_stage_count( "result in compilation errors if the combination of tile shape, " "stage count, and shared memory requirement of the epilogue exceeds " "the available shared memory per SM.") - - if kernel_cc == 11: + print(f"KernelCC: {kernel_cc}") + if is_intel_xe_arch(kernel_cc): if (td.stages is None or td.stages == 0): - # Support for Intel PVC GPU currently does not allow explicit + # Support for Intel Xe GPUs currently does not allow explicit # specification of the stage count. With None or 0, the # CollectiveBuilder automatically determines the stage count to use. return (True, "") elif verbose: - cutlass.logger.warning( - "Setting an explicit stage count for Intel PVC GPU is currently " + cutlass_cppgen.logger.warning( + "Setting an explicit stage count for Intel Xe GPUs is currently " "not supported.") if td.stages <= 0: @@ -168,10 +177,10 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: :rtype: tuple """ - if cc < 90: + if cc < 90 or cc in [120, 121]: if cluster_shape != [1, 1, 1]: return (False, - f"Cluster shape for pre-SM90 architectures must be [1, 1, 1]. Received cluster shape of " + f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of " f"{cluster_shape} for SM{cc}.") else: return (True, "") @@ -185,15 +194,6 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: "CUTLASS kernels currently require the third dimension of cluster shape to be 1. " f"Received cluster shape of {cluster_shape}.") - # The CUDA programming guide currently defines a maximum of 8 thread blocks per cluster - # as being portably supported (https://docs.nvidia.com/cuda/cuda-c-programming-guide/#thread-block-clusters). - # Current CUTLASS kernels only have non-unit cluster dimensions within the first two dimensions, - # so we check that the first two dimensions of the cluster shape do not exceed 8 thread blocks in total. - blocks_in_2d = cluster_shape[0] * cluster_shape[1] - if blocks_in_2d > 8: - return (False, - f"Thread block clusters with more than 8 thread blocks are currently unsupported on SM{cc}. 
" - f"Received cluster shape {cluster_shape}, which has {blocks_in_2d} thread blocks.") return (True, "") @@ -222,16 +222,16 @@ def valid_schedule( kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto) epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto) tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default) - if cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default): - return (False, "Non-default schedules are only supported on SM90 and beyond") + if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default): + return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)") - if (kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto): + if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)): return (False, "Kernel and epilogue schedules must either both be auto or neither be auto") if not tile_scheduler_default: cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative] - if (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): + if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule") return (True, "") diff --git a/python/cutlass/utils/datatypes.py b/python/cutlass_cppgen/utils/datatypes.py similarity index 100% rename from python/cutlass/utils/datatypes.py rename to python/cutlass_cppgen/utils/datatypes.py diff --git a/python/cutlass/utils/lazy_import.py b/python/cutlass_cppgen/utils/lazy_import.py similarity index 100% rename from python/cutlass/utils/lazy_import.py rename to python/cutlass_cppgen/utils/lazy_import.py diff --git a/python/cutlass/utils/profiler.py b/python/cutlass_cppgen/utils/profiler.py similarity index 100% rename from python/cutlass/utils/profiler.py rename to python/cutlass_cppgen/utils/profiler.py diff --git a/python/cutlass_library/arch_constants.py b/python/cutlass_library/arch_constants.py new file mode 100644 index 0000000000..7df19f6440 --- /dev/null +++ b/python/cutlass_library/arch_constants.py @@ -0,0 +1,85 @@ +################################################################################################# +# +# Copyright (C) 2025 Intel Corporation, All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Architecture range constants for CUTLASS library generation. +Shared across manifest.py and gemm_operation.py to avoid circular imports. +""" + +################################################################################################### +# Architecture range constants +# Intel Xe architectures use the range [INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX) +# CUDA architectures use values >= CUDA_ARCH_MIN +################################################################################################### +INTEL_XE_ARCH_MIN = 12 # Minimum Intel Xe architecture (PVC = 12, BMG = 20) +INTEL_XE_ARCH_MAX = 50 # Upper bound (exclusive) for Intel Xe range +CUDA_ARCH_MIN = 50 # Minimum CUDA architecture (sm_50, sm_60, etc.) + +################################################################################################### +# Specific Intel Xe architecture constants +################################################################################################### +# Intel Xe12 - PVC (Ponte Vecchio) HPC architecture +INTEL_XE12 = 12 + +# Intel Xe20 - BMG (Battlemage) gaming architecture +INTEL_XE20 = 20 + +# Intel Xe35 - Future architecture placeholder +INTEL_XE35 = 35 + +################################################################################################### +# Architecture validation helpers +################################################################################################### +def is_intel_xe_arch(arch): + """Check if the given architecture is an Intel Xe architecture.""" + return INTEL_XE_ARCH_MIN <= arch < INTEL_XE_ARCH_MAX + +def is_cuda_arch(arch): + """Check if the given architecture is a CUDA architecture.""" + return arch >= CUDA_ARCH_MIN + +def get_arch_name(arch): + """Get a human-readable name for the architecture.""" + if arch == INTEL_XE12: + return "Intel Xe12 (PVC)" + elif arch == INTEL_XE20: + return "Intel Xe20 (BMG)" + elif arch == INTEL_XE35: + return "Intel Xe35 (CRI)" + elif is_intel_xe_arch(arch): + return f"Intel Xe{arch}" + elif is_cuda_arch(arch): + return f"CUDA SM{arch}" + else: + return f"Unknown({arch})" + +################################################################################################### diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index 3121b7b0fe..fbe52eb587 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -348,11 +348,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', ] + block_scaled_tile_k = ['x128_', 'x256_'] + sm103_block_scaled_data_type = [ 
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', ] + sm103_block_scaled_tile_k = ['x768_'] + block_scaled_cluster_size = [ '4x4x1', '2x1x1', '0x0x1' # dynamic cluster @@ -360,11 +364,12 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode block_scaled_layouts = ['tnt'] # regex list must be in kernel procedural name order - block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" - block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" - sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" - sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + sm103_block_scaled_prefetch_policy = ['tmapf'] + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" if arch in ["100a", "100f"]: kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 6dc9a0456b..fa37c5f17c 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -1,6 +1,7 @@ - +################################################################################################# # # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (C) 2025 Intel Corporation, All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -47,8 +48,16 @@ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: raise ImportError("Disabling attempt to import cutlass_library") from cutlass_library.library import * + from cutlass_library.arch_constants import ( + INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN, + INTEL_XE12, INTEL_XE20, INTEL_XE35 + ) except ImportError: from library import * + from arch_constants import ( + INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN, + INTEL_XE12, INTEL_XE20, INTEL_XE35 + ) _LOGGER = logging.getLogger(__name__) @@ -87,7 +96,8 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, self.B = B self.C = C self.D = D - self.is_xe = self.arch == 11 + # Intel Xe architectures: PVC (12), BMG/Xe2 (20), etc. + self.is_xe = self.arch >= INTEL_XE_ARCH_MIN and self.arch < INTEL_XE_ARCH_MAX if is_block_scaled(gemm_kind): self.ScaleFactorA = ScaleFactorA @@ -388,15 +398,48 @@ def _procedural_name(self): l = self.layout_name(), a = str(max(self.A.alignment, self.B.alignment))) else: - threadblock = self.tile_description.procedural_name() - return "cutlass{p}_xe{ar}_{op}_{ex}_{tb}_{l}_align{a}".format( - p = self.prefix, - ar = self.arch, - op = opcode_class_name, - ex = self.extended_name(), - tb = threadblock, - l = self.layout_name(), - a = str(max(self.A.alignment, self.B.alignment))) + # Intel Xe architectures use xe{cc} naming with similar detail level as NVIDIA + # Format: cutlass{p}_xe{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e} + if self.is_3x: + # Use 3x naming convention with full details like NVIDIA SM90+ + tile_shape = self.get_collective_tile_shape() + extended = self.extended_name_3x() + + # Add D type suffix if different from C type to distinguish mixed precision variants + if self.D.element != self.C.element: + extended += f"_d{DataTypeNames[self.D.element]}" + + kernel_name_template = "cutlass{p}_xe{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}" + return kernel_name_template.format( + p = self.prefix, + ar = self.arch, + op = opcode_class_name, + ex = extended, + ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "", + cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]), + l = self.tile_description.stages, + s = self.layout_name_3x(), + al = str(max(self.A.alignment, self.B.alignment)), + t = TileSchedulerSuffixes[self.tile_scheduler], + k = self.kernel_schedule_name_3x(), + e = self.epilogue_schedule_name_3x()) + else: + # Legacy naming for non-3x Intel Xe operations + threadblock = self.tile_description.procedural_name() + extended = self.extended_name() + + # Add D type suffix if different from C type to distinguish mixed precision variants + if self.D.element != self.C.element: + extended += f"_d{DataTypeNames[self.D.element]}" + + return "cutlass{p}_xe{ar}_{op}_{ex}_{tb}_{l}_align{a}".format( + p = self.prefix, + ar = self.arch, + op = opcode_class_name, + ex = extended, + tb = threadblock, + l = self.layout_name(), + a = str(max(self.A.alignment, self.B.alignment))) # def configuration_name(self): @@ -998,33 +1041,38 @@ def emit(self, operation): epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] if opcode_class_main == OpcodeClass.BlockScaledTensorOp: - is_no_smem_epilogue = operation.epilogue_schedule in [EpilogueScheduleType.NoSmemWarpSpecialized1Sm, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] grouped = 
is_grouped(operation.gemm_kind) if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped): epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: + if is_tma_epilogue(operation.epilogue_schedule): epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped): epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: + if is_tma_epilogue(operation.epilogue_schedule): epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: - epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm] - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: - epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm] - - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: + # SM103 FP4 Ultra + is_sm103_fp4_ultra_1sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped) + ] + is_sm103_fp4_ultra_2sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped) + ] + if cta_n == 256 and is_sm103_fp4_ultra_1sm_kernel_schedule: epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm] - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] + if cta_n == 256 and is_sm103_fp4_ultra_2sm_kernel_schedule: epi_tile_mn = "cute::Shape" - if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm] + if 
is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' @@ -1156,9 +1204,11 @@ def emit(self, operation): 'blockwise_prepare_code' : blockwise_prepare_code } - # Overriding values for Intel Xe + # Overriding values for Intel Xe architectures if operation.is_xe: - values['arch'] = "cutlass::arch::IntelXe" + # Use specific compute capability for Intel Xe GPUs + # e.g., cutlass::arch::Xe20 for BMG, cutlass::arch::Xe12 for PVC + values['arch'] = "cutlass::arch::Xe%d" % operation.arch return SubstituteTemplate(self.gemm_template, values) @@ -1473,7 +1523,13 @@ def emit(self, operation): class EmitGemmConfigurationLibrary: def __init__(self, operation_path, configuration_name): self.configuration_name = configuration_name - self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + # Determine file extension based on architecture + # Intel Xe architectures (12=PVC, 20=BMG) use .cpp, CUDA uses .cu + # Check if operation_path contains /12/, /20/, xe12, or xe20 + is_xe_arch = any(marker in operation_path for marker in ['/12/', '\\12\\', 'xe12', '/20/', '\\20\\', 'xe20']) + file_extension = "cpp" if is_xe_arch else "cu" + self.configuration_path = os.path.join(operation_path, "%s.%s" % (configuration_name, file_extension)).replace('\\', '/') self.instance_emitter = { GemmKind.Gemm: EmitGemmInstance, diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index aa73fb8b13..48822094d5 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -1,6 +1,7 @@ ################################################################################################# # # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (C) 2025 Intel Corporation, All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -92,6 +93,7 @@ def _add_package_disablement_flag(argparser): from cutlass_library.manifest import * from cutlass_library.heuristics import * from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist + from cutlass_library.arch_constants import INTEL_XE12, INTEL_XE20, INTEL_XE35 except ImportError: from library import * from manifest import * @@ -199,11 +201,12 @@ def CreateGemmUniversal3xOperator( operations = [] # by default, only generate the largest tile and largest alignment - if manifest.kernel_filter == '': + # but generate all tiles when --kernels=all is specified + if manifest.kernel_filter == '' or manifest.kernel_filter == 'all': if len(tile_descriptions) == 0: return operations tile_descriptions = [tile_descriptions[0]] - + combinations = product(layouts, tile_descriptions, data_types, complex_transforms, schedules, tile_schedulers) for layout, tile_description, data_type, complex_transform, schedules, tile_scheduler in combinations: kernel_schedule, epilogue_schedule = schedules @@ -5239,7 +5242,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments. @@ -5305,7 +5308,7 @@ def GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) is_aligned = False # layouts for ABC and their alignments. @@ -5366,7 +5369,7 @@ def GenerateSM90_SparseTensorOp_16b_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments. @@ -5431,7 +5434,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5489,7 +5492,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) is_aligned = False # layouts for ABC and their alignments. 
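The hunks above (and those that follow) replace every call to `manifest.get_sm90_instantiation_level` with the architecture-agnostic `manifest.get_instantiation_level`, keeping the same `pruned_level` / `default_level` / `exhaustive_level` arguments. As a rough orientation only, the sketch below illustrates how such a three-way selector could behave; the method name and its arguments come from the diff, but the class body, the `exhaustive` flag, and the selection rules are assumptions made purely for illustration and are not the actual `manifest.py` implementation.

```python
# Minimal sketch, NOT the real manifest.py: illustrates one plausible way a
# get_instantiation_level(pruned_level, default_level, exhaustive_level)
# selector could pick a level. Attribute names and thresholds are assumed.
class ManifestSketch:
    def __init__(self, kernel_filter="", exhaustive=False):
        self.kernel_filter = kernel_filter   # '' or 'all' keeps the default set
        self.exhaustive = exhaustive         # assumed flag for exhaustive builds

    def get_instantiation_level(self, pruned_level, default_level, exhaustive_level):
        if self.exhaustive:                  # emit everything the generator knows
            return exhaustive_level
        if self.kernel_filter not in ("", "all"):
            return pruned_level              # a user-supplied filter prunes the set
        return default_level                 # normal library build

# Mirrors a call site from the diff above.
level = ManifestSketch().get_instantiation_level(
    pruned_level=100, default_level=101, exhaustive_level=9992)
assert level == 101
```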
@@ -5546,7 +5549,7 @@ def GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5601,7 +5604,7 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5653,7 +5656,7 @@ def GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) is_aligned = False # layouts for ABC and their alignments @@ -5705,7 +5708,7 @@ def GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5760,7 +5763,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5826,7 +5829,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -5908,7 +5911,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=0, default_level=101, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=0, default_level=101, exhaustive_level=9992) is_aligned = False # layouts for ABC and their alignments @@ -5965,7 +5968,7 @@ def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 1): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999) + 
instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999) is_aligned = True # layouts for ABC, their alignments will be fixed later based on the data type @@ -6056,7 +6059,7 @@ def GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): return - instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) is_aligned = True # layouts for ABC and their alignments @@ -6721,6 +6724,31 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): # Blackwell SM 100 generators +try: + import cutlass_library.sm100_utils + from cutlass_library.sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) +except ImportError: + import sm100_utils + from sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) + ################################################################################################### def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int: @@ -6743,6 +6771,8 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + # layouts for ABC and their alignments. 
layouts = [ [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], @@ -6779,36 +6809,18 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): min_cc = 100 max_cc = thor_sm - math_instructions_1sm = [ - # tf32 -> f32 - MathInstruction( - [64, 128, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] + math_instructions_1sm, math_instructions_2sm = generate_tf32_math_instructions_sm100(instantiation_level) - cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline: - cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1] - , DynamicClusterShape - ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) tile_schedulers = [ - TileSchedulerType.Default + TileSchedulerType.Default, TileSchedulerType.StreamK ] # 1xSM MMA kernels @@ -6828,38 +6840,6 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): tile_schedulers=tile_schedulers) # 2xSM MMA kernels - math_instructions_2sm = [ - MathInstruction( - [128, 128, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: @@ -6883,6 +6863,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + # layouts for ABC and their alignments. 
C alignment will be set later based on output type layouts = [ [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], @@ -6897,76 +6879,22 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK thor_sm = ThorSMRenumbering(cuda_version) + math_instructions_1sm, math_instructions_2sm = generate_16b_math_instructions_sm100(instantiation_level) + min_cc = 100 max_cc = thor_sm grouped = is_grouped(gemm_kind) - math_instructions_1sm = [ - # f16 -> f16 - #MathInstruction( - # [64, 64, 16], - # DataType.f16, DataType.f16, DataType.f16, - # OpcodeClass.TensorOp, - # MathOperation.multiply_add), - MathInstruction( - [64, 128, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # f16 -> f32 - MathInstruction( - [64, 128, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # bf16 -> f32 - MathInstruction( - [64, 128, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1],[4,4,1] - , DynamicClusterShape - ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1] - , DynamicClusterShape - ] + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) tile_schedulers = [ - TileSchedulerType.Default + TileSchedulerType.Default, TileSchedulerType.StreamK ] # 1xSM MMA kernels @@ -7039,90 +6967,6 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) # 2xSM MMA kernels - math_instructions_2sm = [ - # 128x64x16 - #MathInstruction( - # [128, 64, 16], - # DataType.f16, DataType.f16, DataType.f16, - # OpcodeClass.TensorOp, - # MathOperation.multiply_add), - # 128x128x16 - MathInstruction( - [128, 128, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - - # 128x256x16 - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - 
MathOperation.multiply_add), - MathInstruction( - [128, 256, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - - # 256x128x16 - MathInstruction( - [256, 128, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - - # 256x256x16 - MathInstruction( - [256, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 16], - DataType.f16, DataType.f16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: @@ -7199,6 +7043,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=591 , default_level=591 , exhaustive_level=9999) + # layouts for ABC and their alignments. layouts = [ [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], @@ -7219,77 +7065,18 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK epi_type = DataType.f32 grouped = is_grouped(gemm_kind) - math_instructions_1sm = [ - # inst 64x128 - MathInstruction( - [64, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x128 - MathInstruction( - [128, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x256 - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - 
MathOperation.multiply_add)] + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] - , DynamicClusterShape - ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1] - , DynamicClusterShape - ] + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] # 1xSM MMA kernels @@ -7418,86 +7205,6 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) # 2xSM MMA kernels - math_instructions_2sm = [ - # inst 128x128 - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x256 - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 256x128 - MathInstruction( - [256, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 256x256 - MathInstruction( - [256, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ] - - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] - , DynamicClusterShape - ] for math_inst in math_instructions_2sm: tile_descriptions = [] @@ -7633,6 +7340,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=593, default_level=593, exhaustive_level=9999) + grouped = is_grouped(gemm_kind) # layouts for ABC and their alignments. 
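The long inline `math_instructions_1sm` / `math_instructions_2sm` tables deleted above are now produced by the `sm100_utils` helpers imported earlier in `generator.py` (`generate_fp8_math_instructions_sm100`, `generate_cluster_shapes_sm100`, and friends). The sketch below shows the kind of `(1sm, 2sm)` instruction lists such a helper returns, reusing the `MathInstruction` fields and the "A/B both static or both runtime" rule visible in the deleted code; the mapping from `instantiation_level` to instruction shapes is an assumption for illustration, not the real `sm100_utils` logic.

```python
# Hypothetical sketch, NOT sm100_utils: shows the (1sm, 2sm) return shape of a
# helper like generate_fp8_math_instructions_sm100(level, enable_runtime_dtype).
# The level-to-shape mapping below is assumed for illustration only.
from itertools import product
from cutlass_library.library import (
    DataType, MathInstruction, MathOperation, OpcodeClass)

def sketch_fp8_math_instructions_sm100(instantiation_level, enable_runtime_dtype=True):
    sizes_1sm = [[64, 128, 32], [128, 128, 32]]
    sizes_2sm = [[128, 128, 32], [256, 128, 32]]
    if instantiation_level >= 9999:          # assumed: exhaustive adds wider tiles
        sizes_1sm.append([128, 256, 32])
        sizes_2sm.append([256, 256, 32])

    ab_types = [DataType.e4m3, DataType.e5m2]
    if enable_runtime_dtype:
        ab_types.append(DataType.f8)         # runtime A/B data type
    is_runtime = lambda t: t == DataType.f8

    def build(sizes):
        insts = []
        for shape, a, b in product(sizes, ab_types, ab_types):
            if is_runtime(a) != is_runtime(b):   # A/B must both be static or both runtime
                continue
            insts.append(MathInstruction(shape, a, b, DataType.f32,
                                         OpcodeClass.TensorOp,
                                         MathOperation.multiply_add))
        return insts

    return build(sizes_1sm), build(sizes_2sm)
```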
@@ -7651,111 +7360,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, max_cc = 100 epi_type = DataType.f32 - math_instructions_1sm = [ - # inst 64x128 - MathInstruction( - [64, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [64, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x32 - MathInstruction( - [128, 32, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 32, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 32, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 32, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x64 - MathInstruction( - [128, 64, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 64, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 64, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 64, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x128 - MathInstruction( - [128, 128, 32], - DataType.f8, DataType.f8, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 128, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - # inst 128x256 - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] + pruning_level = get_pruning_level_from_global_level(instantiation_level) - cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] - , DynamicClusterShape - ] + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_compile_time_dtype=grouped or pruning_level >= 1, enable_runtime_dtype=not grouped) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) tile_schedulers = [ TileSchedulerType.Default, @@ -7861,40 +7470,36 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, kernel_schedule = 
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped) epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[kernel_schedule, epi_schedule]], + [[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) -def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + # SM100 MMA with mixed F4/F6/F8 inputs + without block scale if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + # layouts for ABC and their alignments. layouts = [ [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] - instruction_sizes_1sm = [ - # [64, 128, 32], - [128, 128, 32], - # [64, 256, 32], - [128, 256, 32], - ] - - instruction_sizes_2sm = [ - # [128, 128, 32], - # [128, 256, 32], - [256, 128, 32], - [256, 256, 32], - ] + math_instructions_1sm, math_instructions_2sm = generate_f8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - ab_types = [ - DataType.f4, DataType.f6, DataType.f8, - DataType.e2m1, DataType.e3m2, DataType.e4m3, - ] + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 - acc_types = [ DataType.f32 ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) tile_schedulers = [ TileSchedulerType.Default, TileSchedulerType.StreamK @@ -7907,61 +7512,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): epi_type = DataType.f32 - math_instructions_1sm = [] - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - # Usage: - - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_1sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - math_instructions_2sm = [] - - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_2sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.TensorOp, - MathOperation.multiply_add) - ) - - cluster_shapes_1sm = [ - # [1,2,1], - [2,1,1], - [1,1,1], - # [1,4,1], - [4,4,1] - , DynamicClusterShape - ] - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [2,1,1], - [1,1,1] - , DynamicClusterShape - ] + if [4,4,1] in 
cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) # 1xSM MMA kernels for math_inst in math_instructions_1sm: @@ -8022,22 +7579,6 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers) - cluster_shapes_2sm = [ - [2,1,1], - # [2,2,1], - # [2,4,1], - # [4,1,1], - # [4,2,1], - [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: @@ -8101,25 +7642,31 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers) def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + # SM100 MMA with mixed F4/F6/F8 inputs + block scale if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + grouped = is_grouped(gemm_kind) layouts = [ [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 0]], [[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]], ] - instruction_sizes_1sm = [ - [128, 128, 32], [128, 256, 32], # Block scaled kernels only support M=128 for 1SM cases - ] + math_instructions_1sm, math_instructions_2sm = generate_mxf8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - instruction_sizes_2sm = [ - [256, 128, 32], - [256, 256, 32], - ] + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) ab_types = [ DataType.f4, DataType.f6, @@ -8147,73 +7694,26 @@ def tile_schedulers(sfdtype): epi_type = DataType.f32 - math_instructions_1sm = [] - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_1sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - - math_instructions_2sm = [] + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = 
is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_2sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) - ) - - cluster_shapes_1sm = [ - [1,1,1], - # [1,2,1], - [2,1,1], - # [1,4,1], - [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,1,1], - [2,1,1] - , DynamicClusterShape - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) data_types = [ { @@ -8276,24 +7776,8 @@ def tile_schedulers(sfdtype): [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]] , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - cluster_shapes_2sm = [ - [2,1,1], - # [2,2,1], - # [2,4,1], - [4,1,1], - # [4,2,1], - [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], - [4,1,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) @@ -8403,6 +7887,8 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + instantiation_level = manifest.get_instantiation_level(pruned_level=591, default_level=591, exhaustive_level=9999) + grouped = is_grouped(gemm_kind) # layouts for ABC and their alignments. 
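Several of the SM100 generators above now obtain their cluster shapes from `generate_cluster_shapes_sm100(instantiation_level, change_priority_func)`, where the optional callback demotes specific shapes (for example `shapes_1sm[(1,2,1)] = 6`). A minimal sketch of how such a priority mechanism might work is given below, assuming the two dictionaries map cluster-shape tuples to an instantiation priority; the default priorities, the level cut-off, and the `DynamicClusterShape` placeholder are assumptions for illustration and do not reproduce the actual `sm100_utils` behaviour.

```python
# Hypothetical sketch, NOT sm100_utils: one way
# generate_cluster_shapes_sm100(level, change_priority_func) could use a
# shape -> priority dict that callers adjust before filtering.
DynamicClusterShape = [0, 0, 1]   # stand-in for the library's dynamic-cluster marker

def sketch_cluster_shapes_sm100(instantiation_level, change_priority_func=None):
    # Lower priority value == emitted even at more heavily pruned levels.
    shapes_1sm = {(1, 1, 1): 0, (2, 1, 1): 1, (1, 2, 1): 2, (1, 4, 1): 2, (4, 4, 1): 3}
    shapes_2sm = {(2, 1, 1): 0, (4, 1, 1): 1, (2, 2, 1): 2, (2, 4, 1): 2,
                  (4, 2, 1): 2, (4, 4, 1): 3}
    if change_priority_func is not None:
        change_priority_func(shapes_1sm, shapes_2sm)      # e.g. demote (1,2,1) to 6
    cutoff = 9 if instantiation_level >= 9999 else 3      # assumed level mapping
    pick = lambda shapes: [list(s) for s, p in shapes.items() if p <= cutoff]
    return pick(shapes_1sm) + [DynamicClusterShape], pick(shapes_2sm) + [DynamicClusterShape]

# The callback used by the block-scaled generators in the diff: these shapes are
# demoted so they only appear at higher instantiation levels.
def change_priority_func(shapes_1sm, shapes_2sm):
    shapes_1sm[(1, 2, 1)] = 6
    shapes_1sm[(1, 4, 1)] = 6
    shapes_2sm[(2, 2, 1)] = 6
    shapes_2sm[(2, 4, 1)] = 6
    shapes_2sm[(4, 2, 1)] = 6

cluster_shapes_1sm, cluster_shapes_2sm = sketch_cluster_shapes_sm100(591, change_priority_func)
```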
@@ -8411,21 +7897,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], ] - instruction_sizes_1sm = [ - [128, 64, 64], - [128, 128, 64], - ] + math_instructions_1sm, math_instructions_2sm = generate_mxf4nvf4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) - instruction_sizes_2sm = [ - [256, 64, 64], - [256, 128, 64], - [256, 192, 64], [256, 256, 64] - ] + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 - ab_types = [ - DataType.f4, - DataType.e2m1, - ] + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func=change_priority_func) acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions @@ -8444,80 +7925,17 @@ def tile_schedulers(sfdtype): epi_type = DataType.f32 - math_instructions_1sm = [] - is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_1sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) # UE8M0 scale factor - ) - math_instructions_1sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue4m3) # UE4M3 scale factor - ) - - math_instructions_2sm = [] - - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): - is_runtime_datatype_a = is_runtime_datatype(a_type) - is_runtime_datatype_b = is_runtime_datatype(b_type) - - # A/B datatypes should be both static or dynamic - if (is_runtime_datatype_a != is_runtime_datatype_b): - continue - - math_instructions_2sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue8m0) # UE8M0 scale factor - ) - math_instructions_2sm.append( - MathInstruction( - instr_size, - a_type, b_type, acc_type, - OpcodeClass.BlockScaledTensorOp, - MathOperation.multiply_add, - DataType.ue4m3) # UE4M3 scale factor - ) - - cluster_shapes_1sm = [ - [1,1,1], - # [1,2,1], - [2,1,1], - # [1,4,1], - [4,4,1] - , DynamicClusterShape - ] - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,1,1], - [2,1,1] - , DynamicClusterShape - ] + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) # 1xSM MMA kernels for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp tile_descriptions = [] for cluster_shape in cluster_shapes_1sm: multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape @@ -8527,6 +7945,7 @@ def tile_schedulers(sfdtype): math_inst.instruction_shape[1] * multiplier_1sm[1], math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], 0, [4, 1, 1], 
math_inst, min_cc, max_cc, cluster_shape)) + assert math_inst.instruction_shape[2] * 4 == 256 data_types = [ { @@ -8626,37 +8045,22 @@ def tile_schedulers(sfdtype): # E2M1 x E2M1, vector size 16, UE4M3 isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped) fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped) - nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule] - fp4_schedule = [fp4_kernel_schedule, epi_schedule] - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule] + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind ) if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind ) - cluster_shapes_2sm = [ - [2,1,1], - # [2,2,1], - # [2,4,1], - [4,1,1], - # [4,2,1], - [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], - [4,1,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) @@ -8765,27 +8169,29 @@ def tile_schedulers(sfdtype): isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped) fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped) - nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule] - fp4_schedule = [fp4_kernel_schedule, epi_schedule] - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule] + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule] + CreateGemmUniversal3xOperator(manifest, [layout], 
tile_descriptions, data_type, fp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) - - def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): # SM100 MMA with F4 + block scale - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + if not CudaToolkitVersionSatisfies(cuda_version, 13, 0): return + grouped = is_grouped(gemm_kind) + # layouts for ABC and their alignments. layouts = [ [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], ] instruction_sizes_1sm = [ @@ -8794,14 +8200,32 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ instruction_sizes_2sm = [ [256, 128, 96], + [256, 192, 96], + [256, 256, 96] ] ab_types = [ + DataType.f4, DataType.e2m1, ] + sf_types = [ + DataType.ue4m3, + DataType.ue8m0 + ] + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if grouped: + return [TileSchedulerType.Default] + if sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + min_cc = 103 max_cc = 103 epi_type = DataType.f32 @@ -8810,7 +8234,7 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, sf_types, acc_types): is_runtime_datatype_a = is_runtime_datatype(a_type) is_runtime_datatype_b = is_runtime_datatype(b_type) @@ -8824,12 +8248,12 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ a_type, b_type, acc_type, OpcodeClass.BlockScaledTensorOp, MathOperation.multiply_add, - DataType.ue8m0) # UE8M0 scale factor + sf_type) ) math_instructions_2sm = [] - for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, sf_types, acc_types): is_runtime_datatype_a = is_runtime_datatype(a_type) is_runtime_datatype_b = is_runtime_datatype(b_type) @@ -8843,7 +8267,7 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ a_type, b_type, acc_type, OpcodeClass.BlockScaledTensorOp, MathOperation.multiply_add, - DataType.ue8m0) # UE8M0 scale factor + sf_type) ) cluster_shapes_1sm = [ @@ -8851,15 +8275,15 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ # [1,2,1], [2,1,1], # [1,4,1], - [4,4,1] - , DynamicClusterShape + [4,4,1], + DynamicClusterShape ] # 1xSM MMA kernels for math_inst in math_instructions_1sm: tile_descriptions = [] for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = cluster_shape + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape tile_descriptions.append( TileDescription([ math_inst.instruction_shape[0] * multiplier_1sm[0], @@ -8898,8 
+8322,69 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } ] + # Set alignment d based on Destination format. for layout in layouts: for data_type in data_types: # Set alignment d based on Destination format. @@ -8908,21 +8393,29 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ else: layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) - if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. 
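> Editor's note (not part of the patch): the D-alignment branch in the loop above presumably divides a 256-bit access granularity by the element width of D (and of C when C is not void), and the FP4 (`e2m1`) output path additionally ties the SFD layout to the D layout. Below is a minimal standalone sketch of that arithmetic; the bit widths and the `d_alignment` helper are illustrative stand-ins, not the library's actual `DataTypeSize` table.

```python
# Illustrative only: re-derives the alignment rule used above with assumed bit widths.
# The real generator reads element sizes from cutlass_library's DataTypeSize dictionary.
ASSUMED_BITS = {"void": 0, "f16": 16, "e5m2": 8, "e2m1": 4}  # e2m1 == FP4

def d_alignment(c_type: str, d_type: str, access_bits: int = 256) -> int:
    """Alignment (in elements) for the D tensor, matching the branch above."""
    if ASSUMED_BITS[c_type] == 0:                      # void C: only D constrains it
        return access_bits // ASSUMED_BITS[d_type]
    return min(access_bits // ASSUMED_BITS[d_type],
               access_bits // ASSUMED_BITS[c_type])

assert d_alignment("void", "e2m1") == 64   # FP4 D, no C tensor
assert d_alignment("f16", "e2m1") == 16    # FP4 D with an FP16 C tensor
assert d_alignment("void", "e5m2") == 32   # FP8 (E5M2) D, no C tensor
```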
+ if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue # E2M1 x E2M1, vector size 32, E8 isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized1Sm] - fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm] - fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm] - # For FP4 inputs + epilogue_1sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), epilogue_1sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), epilogue_1sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch - ,fp4_schedule_enable_prefetch - ] - , gemm_kind=gemm_kind - ) + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) cluster_shapes_2sm = [ [2,1,1], @@ -8930,14 +8423,14 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ # [2,4,1], [4,1,1], # [4,2,1], - [4,4,1] - , DynamicClusterShape + [4,4,1], + DynamicClusterShape ] for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) tile_descriptions.append( TileDescription([ math_inst.instruction_shape[0] * multiplier_2sm[0], @@ -8954,7 +8447,7 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ "acc_type" : math_inst.element_accumulator, "epi_type" : epi_type, "sf_type" : math_inst.element_scale_factor, - "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + "sfd_type" : {"type": 
DataType.void, "vector_size": None, "layout" : None} }, { "a_type" : math_inst.element_a, @@ -8966,7 +8459,6 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} }, - { "a_type" : math_inst.element_a, "b_type" : math_inst.element_b, @@ -8977,59 +8469,129 @@ def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_ "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} }, - ] - - for layout in layouts: - for data_type in data_types: - # Set alignment d based on Destination format. - if DataTypeSize[data_type["c_type"]] == 0 : - layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] - else: - layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) - - if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): - data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. - # E2M1 x E2M1, vector size 32, E8 - isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - - fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] - fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] - fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] - # For FP4 inputs - if isFp4: - CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch - ,fp4_schedule_enable_prefetch - ] - , gemm_kind=gemm_kind - ) - -def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - # layouts for ABC and their alignments. 
- layouts = [ - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - epi_type = DataType.f32 - - math_instructions_1sm = [ - MathInstruction( - [64, 128, 32], - DataType.s8, DataType.s8, DataType.s32, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + for data_type in data_types: + # Set alignment d based on Destination format. 
+ if DataTypeSize[data_type["c_type"]] == 0 : + layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] + else: + layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) + + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epilogue_2sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), epilogue_2sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), epilogue_2sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + +def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + math_instructions_1sm = [ + MathInstruction( + [64, 128, 32], + DataType.s8, DataType.s8, DataType.s32, OpcodeClass.TensorOp, MathOperation.multiply_add), MathInstruction( @@ -9053,7 +8615,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): ] tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] # 1xSM MMA kernels @@ -9242,7 +8804,7 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] kernel_data_types = [ @@ -9370,7 +8932,7 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] kernel_data_types = [ @@ -9498,7 +9060,7 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] kernel_data_types = [ @@ -9625,7 +9187,7 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, + TileSchedulerType.Default, TileSchedulerType.StreamK ] kernel_data_types = [ @@ -9766,595 +9328,183 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): max_cc = thor_sm tile_schedulers = [ - TileSchedulerType.Default, - ] - - math_instructions_1sm = [ - # Runtime Dtype - MathInstruction( - [128, 128, 64], - DataType.f4, DataType.f4, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 64], - DataType.f4, DataType.f4, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - - MathInstruction( - [128, 128, 64], - DataType.f6, DataType.f6, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [128, 256, 64], - DataType.f6, DataType.f6, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - ] - - math_instructions_2sm = [ - # Runtime DType - MathInstruction( - [256, 128, 64], - DataType.f4, DataType.f4, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 64], - DataType.f4, DataType.f4, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - - MathInstruction( - [256, 128, 64], - DataType.f6, DataType.f6, DataType.f32, - OpcodeClass.SparseTensorOp, - MathOperation.multiply_add), - MathInstruction( - [256, 256, 64], - DataType.f6, DataType.f6, DataType.f32, - OpcodeClass.SparseTensorOp, 
- MathOperation.multiply_add), - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_1sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - kernel_data_types = [ - # void_c - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32, - }, - # none void_c - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32, - }, - ] - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_filtered = [] - for layout in layouts: - layout_filter = copy.deepcopy(layout) - # * A_K : Logical TileShape_K % 256 == 0 - # * A_M : TileShape_M % 128 == 0 - # * B_N : TileSize_N % 128 == 0 - # * B_K : TileSize_K % 128 == 0 - if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ - (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ - ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 128 == 0) or \ - (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): - # alignment for a, 2 for sparsity - layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - layouts_filtered.append(layout_filter) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in sm100_cluster_shape_2sm: - if thor_sm in manifest.compute_capabilities_baseline : - if cluster_shape == [4,4,1] : - continue - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - kernel_data_types = [ - # void_c - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.f16, - "d_type" : DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32, - }, - # none void_c - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : DataType.void, - "d_type" : 
DataType.f16, - "acc_type" : math_inst.element_accumulator, - "epi_type" : DataType.f32, - }, - ] - - for kernel_data_type in kernel_data_types: - # Update layout alignment - # alignment for d might be different for each kernel_data_type - layouts_filtered = [] - for layout in layouts: - layout_filter = copy.deepcopy(layout) - # * A_K : Logical TileShape_K % 256 == 0 - # * A_M : TileShape_M % 128 == 0 - # * B_N : TileSize_N % 256 == 0 - # * B_K : TileSize_K % 128 == 0 - if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ - (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ - ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 256 == 0) or \ - (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): - # alignment for a, 2 for sparsity - layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) - # alignment for b - layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) - # alignment for d - layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) - layouts_filtered.append(layout_filter) - - CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], - [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], - tile_schedulers=tile_schedulers) - - -# -# Kernels using the stream-K tile scheduler. -# A reduced set of kernels is generated for these schedulers to reduce functional -# and perofrmance testing time. -# - -def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - # layouts for ABC and their alignments. 
- layouts = [ - [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], - [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], - - ] - - data_types = [ - { - "a_type" : DataType.f32, - "b_type" : DataType.f32, - "c_type" : DataType.f32, - "d_type" : DataType.f32, - "acc_type" : DataType.f32, - "epi_type" : DataType.f32, - } - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - math_instructions_1sm = [ - MathInstruction( - [128, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - - cluster_shapes_1sm = [ - [1,2,1], [1,1,1], [1,4,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,2,1], [1,1,1], [1,4,1] - , DynamicClusterShape - ] - - tile_schedulers = [ - TileSchedulerType.StreamK, - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - math_instructions_2sm = [ - MathInstruction( - [256, 256, 8], - DataType.tf32, DataType.tf32, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1] - , DynamicClusterShape - ] - - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - if math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - -def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): - return - - # layouts for ABC and their alignments. 
C alignment will be set later based on output type - layouts = [ - [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - math_instructions_1sm = [ - MathInstruction( - [128, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_1sm = [ - [1,2,1], [1,1,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,2,1], [1,1,1] - , DynamicClusterShape - ] - - tile_schedulers = [ - TileSchedulerType.StreamK - ] - - # 1xSM MMA kernels - for math_inst in math_instructions_1sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: - multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_1sm[0], - math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - } - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_types_mixed = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_a, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - } - ] - # Set alignment d based on Destination format. 
- for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) - - # 2xSM MMA kernels - math_instructions_2sm = [ - MathInstruction( - [256, 256, 16], - DataType.f16, DataType.f16, DataType.f16, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1] - , DynamicClusterShape - ] - - for math_inst in math_instructions_2sm: - tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: - multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) - tile_descriptions.append( - TileDescription([ - math_inst.instruction_shape[0] * multiplier_2sm[0], - math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], - 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - - data_types = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - } - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - - if math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - - # for mixed precision kernels, also generate kernels that write output matrix in the A/B format - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - data_types_mixed = [ - { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_a, - "d_type" : math_inst.element_a, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator, - } - ] - # Set alignment d based on Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] - - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - - -def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): - return - - # layouts for ABC and their alignments. 
- layouts = [ - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], - [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], - ] - - thor_sm = ThorSMRenumbering(cuda_version) - - min_cc = 100 - max_cc = thor_sm - - epi_type = DataType.f32 - - math_instructions_1sm = [ - MathInstruction( - [128, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add)] - - cluster_shapes_1sm = [ - [1,2,1], [2,1,1], [1,1,1], [4,4,1] - , DynamicClusterShape + TileSchedulerType.Default, TileSchedulerType.StreamK ] - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_1sm = [ - [1,2,1], [2,1,1], [1,1,1] - , DynamicClusterShape - ] + math_instructions_1sm = [ + # Runtime Dtype + MathInstruction( + [128, 128, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + + MathInstruction( + [128, 128, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] - tile_schedulers = [ - TileSchedulerType.StreamK, + math_instructions_2sm = [ + # Runtime DType + MathInstruction( + [256, 128, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + + MathInstruction( + [256, 128, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), ] # 1xSM MMA kernels for math_inst in math_instructions_1sm: tile_descriptions = [] - for cluster_shape in cluster_shapes_1sm: + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape tile_descriptions.append( TileDescription([ math_inst.instruction_shape[0] * multiplier_1sm[0], math_inst.instruction_shape[1] * multiplier_1sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - data_types = [ + kernel_data_types = [ + # void_c { "a_type" : math_inst.element_a, "b_type" : math_inst.element_b, "c_type" : DataType.f16, - "d_type" : DataType.e4m3, + "d_type" : DataType.f16, "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }] + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + ] - # Set alignment d based on 
Destination format. - for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_filtered = [] + for layout in layouts: + layout_filter = copy.deepcopy(layout) + # * A_K : Logical TileShape_K % 256 == 0 + # * A_M : TileShape_M % 128 == 0 + # * B_N : TileSize_N % 128 == 0 + # * B_K : TileSize_K % 128 == 0 + if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ + (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ + ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 128 == 0) or \ + (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): + # alignment for a, 2 for sparsity + layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + layouts_filtered.append(layout_filter) - for data_type in data_types: - if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ - ( data_type["d_type"] == DataType.e5m2 ): - continue - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers) # 2xSM MMA kernels - math_instructions_2sm = [ - MathInstruction( - [256, 256, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - ] - - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] - , DynamicClusterShape - ] - - if thor_sm in manifest.compute_capabilities_baseline : - cluster_shapes_2sm = [ - [2,1,1], [2,2,1], [2,4,1], [4,1,1] - , DynamicClusterShape - ] - for math_inst in math_instructions_2sm: tile_descriptions = [] - for cluster_shape in cluster_shapes_2sm: + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) tile_descriptions.append( TileDescription([ math_inst.instruction_shape[0] * multiplier_2sm[0], math_inst.instruction_shape[1] * multiplier_2sm[1], - math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) - data_types = [ + kernel_data_types = [ + # void_c { "a_type" : math_inst.element_a, "b_type" : math_inst.element_b, "c_type" : DataType.f16, - "d_type" : DataType.e4m3, + "d_type" : DataType.f16, "acc_type" : math_inst.element_accumulator, - "epi_type" : epi_type, - }] - - # Set alignment d based on Destination format. 
- for layout in layouts: - layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + ] - for data_type in data_types: - if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ - ( data_type["d_type"] == DataType.e5m2 ): - continue + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_filtered = [] + for layout in layouts: + layout_filter = copy.deepcopy(layout) + # * A_K : Logical TileShape_K % 256 == 0 + # * A_M : TileShape_M % 128 == 0 + # * B_N : TileSize_N % 256 == 0 + # * B_K : TileSize_K % 128 == 0 + if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ + (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ + ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 256 == 0) or \ + (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): + # alignment for a, 2 for sparsity + layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + layouts_filtered.append(layout_filter) - if math_inst.instruction_shape[0] == 128: - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm - else: - epi_schedule = EpilogueScheduleType.ScheduleAuto + CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) # Conv Utility functions def make_dims_and_alignments_triple(dim: int, bit_per_element_A: int, bit_per_element_B: int, bit_per_element_C: int): bit_alignment_required_by_tma = 128 @@ -11240,9 +10390,6 @@ def GenerateSM100(manifest, cuda_version): GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version) GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version) - GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version) - - GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version) if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version) @@ -11252,8 +10399,6 @@ def GenerateSM100(manifest, cuda_version): GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) - GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version) - # StreamK is included in regular generation GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) @@ -11280,6 +10425,7 @@ def GenerateSM100(manifest, cuda_version): GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, 
cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) # # Conv # @@ -11709,67 +10855,314 @@ def GenerateSM90(manifest, cuda_version): ################################################################################################### -def GeneratePVC_TensorOp_16b_gemm(manifest, cuda_version): +def GeneratePVC(manifest, cuda_version): + """ + Generate CUTLASS kernels for PVC (Ponte Vecchio) architecture. + + PVC is Intel's Xe-HPC GPU architecture with compute capability 12. + + This is a legacy wrapper that calls GenerateIntelXe with arch=INTEL_XE12. + """ + GenerateIntelXe(manifest, cuda_version, arch=INTEL_XE12) + +################################################################################################### +def GenerateXe_TensorOp_16b_DPAS_gemm(manifest, cuda_version, min_cc=20): + """Generate FP16/BF16 GEMM kernels for Intel Xe architecture using DPAS. + + :param min_cc: Architecture number (12 for PVC, 20 for BMG) + """ + layout_list = [ + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8]], + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8]], + ] + + math_instructions = [ + MathInstruction( + [8, 16, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [8, 16, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [8, 16, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [8, 16, 16], + DataType.bf16, DataType.bf16, DataType.bf16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ] + + max_cc = min_cc + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 256, 32], + 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([128, 256, 32], + 0, [4, 8, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([256, 128, 32], + 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([128, 128, 32], + 0, [4, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([64, 128, 32], + 0, [2, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + ] + + # Generate kernels for different output (D) types + # Default: accumulator type (FP32 for mixed precision, same as input for native precision) + # For mixed precision (a_type != accumulator): also generate output in input precision + valid_d_types = [math_inst.element_accumulator] + if math_inst.element_a != math_inst.element_accumulator: + valid_d_types.append(math_inst.element_a) + + for d_type in valid_d_types: + # Generate operations both with and without bias (ElementC) + for c_type in [math_inst.element_accumulator, DataType.void]: + data_type = { + "a_type": math_inst.element_a, + "b_type": math_inst.element_b, + "c_type": c_type, + "d_type": d_type, + "acc_type": math_inst.element_accumulator, + "epi_type": math_inst.element_accumulator + } + + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]] + + CreateGemmUniversal3xOperator(manifest, layout_list, 
tile_descriptions, data_type, schedules, tile_schedulers=[TileSchedulerType.Persistent]) + +def GenerateXe_TensorOp_fp8_DPAS_gemm(manifest, cuda_version, min_cc=20): + """Generate FP8 (E4M3/E5M2) GEMM kernels for Intel Xe architecture using DPAS. + + Supported combinations for regular GEMM: + - [e4m3, e4m3, fp32]: E4M3 x E4M3 -> FP32 (homogeneous) + - [e5m2, e5m2, fp32]: E5M2 x E5M2 -> FP32 (homogeneous) + + Note: Mixed precision (FP16/BF16 x FP8) requires grouped GEMM infrastructure + and is NOT supported for regular library generation. + + :param min_cc: Architecture number (12 for PVC, 20 for BMG) + """ layout_list = [ - [ - [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.RowMajor, 4]], - [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 4]], - [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.RowMajor, 4]], - [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 4]], - ], - [ - [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2]], - [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2]], - [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2]], - [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2]], - ] + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 8]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 8]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 8]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 8]], ] + # FP8 math instructions for Intel Xe + # Only homogeneous types (same A and B type) for regular GEMM math_instructions = [ - MathInstruction( - [8, 16, 16], - DataType.bf16, DataType.bf16, DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add), - MathInstruction( - [8, 16, 16], - DataType.bf16, DataType.bf16, DataType.bf16, - OpcodeClass.TensorOp, - MathOperation.multiply_add) + # Homogeneous FP8 (same type for A and B) - SUPPORTED + MathInstruction( + [8, 16, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, # E4M3 x E4M3 -> FP32 + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [8, 16, 32], + DataType.e5m2, DataType.e5m2, DataType.f32, # E5M2 x E5M2 -> FP32 + OpcodeClass.TensorOp, + MathOperation.multiply_add), + + # DISABLED: Mixed precision FP16/BF16 x FP8 requires grouped GEMM + # These would need MainloopIntelXeXMX16GroupMixedPrecision which is only + # activated when IsGroup=true (KernelXePtrArrayCooperative schedule). + # Regular library GEMMs use MainloopIntelXeXMX16 which requires ElementA == ElementB. 
+ # + # MathInstruction([8, 16, 32], DataType.f16, DataType.e5m2, DataType.f32, ...), + # MathInstruction([8, 16, 32], DataType.f16, DataType.e4m3, DataType.f32, ...), + # MathInstruction([8, 16, 32], DataType.bf16, DataType.e5m2, DataType.f32, ...), + # MathInstruction([8, 16, 32], DataType.bf16, DataType.e4m3, DataType.f32, ...), ] - min_cc = 11 - max_cc = 11 - - for math_inst, layouts in zip(math_instructions, layout_list): - tile_descriptions = [ - TileDescription([256, 256, 32], - 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), - TileDescription([128, 512, 32], - 0, [4, 8, 1], math_inst, min_cc, max_cc, [1, 1, 1]), - TileDescription([256, 128, 32], - 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), - TileDescription([128, 256, 16], - 0, [4, 8, 1], math_inst, min_cc, max_cc, [1, 1, 1]), - TileDescription([8, 128, 32], - 0, [1, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), - ] + max_cc = min_cc - data_type = { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator - } + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 256, 64], + 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([128, 256, 64], + 0, [4, 8, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([256, 128, 64], + 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([128, 128, 64], + 0, [4, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + ] + + # Generate kernels for different output (D) types + # Valid D types for FP8: fp32 (accumulator), bf16, fp16, e4m3, e5m2 + valid_d_types = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + + for d_type in valid_d_types: + data_type = { + "a_type": math_inst.element_a, + "b_type": math_inst.element_b, + "c_type": math_inst.element_accumulator, + "d_type": d_type, + "acc_type": math_inst.element_accumulator, + "epi_type": math_inst.element_accumulator + } + + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]] + + CreateGemmUniversal3xOperator(manifest, layout_list, tile_descriptions, data_type, schedules, tile_schedulers=[TileSchedulerType.Persistent]) + +def GenerateXe_TensorOp_int8_DPAS_gemm(manifest, cuda_version, min_cc=20): + """Generate INT8 GEMM kernels for Intel Xe architecture using DPAS. 
+ + Supported: [int8, int8, int32] -> INT32 accumulator (hardware requirement) + + :param min_cc: Architecture number (12 for PVC, 20 for BMG) + """ + layout_list = [ + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 4]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 4]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 4]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 4]], + ] + + # INT8 x INT8 -> INT32 (hardware requirement for Intel Xe) + math_instructions = [ + MathInstruction( + [8, 16, 32], + DataType.s8, DataType.s8, DataType.s32, # Changed from f32 to s32 + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + max_cc = min_cc + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 256, 64], + 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([128, 256, 64], + 0, [4, 8, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([256, 128, 64], + 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([128, 128, 64], + 0, [4, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + ] + + # Generate kernels for different output (D) types + # Default: accumulator type (INT32) + # Also generate output in input precision (INT8) for quantized workflows + valid_d_types = [math_inst.element_accumulator, math_inst.element_a] + + for d_type in valid_d_types: + data_type = { + "a_type": math_inst.element_a, + "b_type": math_inst.element_b, + "c_type": math_inst.element_accumulator, + "d_type": d_type, + "acc_type": math_inst.element_accumulator, + "epi_type": math_inst.element_accumulator + } - schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]] + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]] - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules, tile_schedulers=[TileSchedulerType.Persistent]) + CreateGemmUniversal3xOperator(manifest, layout_list, tile_descriptions, data_type, schedules, tile_schedulers=[TileSchedulerType.Persistent]) -def GeneratePVC(manifest, cuda_version): - GeneratePVC_TensorOp_16b_gemm(manifest, cuda_version) + +def GenerateXe_TensorOp_mixed_dtype_DPAS_gemm(manifest, cuda_version, min_cc=20): + """Generate mixed-precision GEMM kernels for Intel Xe architecture using DPAS. 
+ + Supported: [fp16, int4, fp32] -> FP16 x INT4 with FP32 accumulator + + :param min_cc: Architecture number (12 for PVC, 20 for BMG) + """ + layout_list = [ + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 32], [LayoutType.RowMajor, 8]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 8]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 32], [LayoutType.RowMajor, 8]], + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 8]], + ] + + # Mixed precision: FP16 x INT4 -> FP32 (hardware requirement for Intel Xe) + math_instructions = [ + MathInstruction( + [8, 16, 32], + DataType.f16, DataType.s4, DataType.f32, # Changed from [s8, f16, f32] to [f16, s4, f32] + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + max_cc = min_cc + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 256, 64], + 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([128, 256, 64], + 0, [4, 8, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + TileDescription([256, 128, 64], + 0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]), + ] + + data_type = { + "a_type": math_inst.element_a, + "b_type": math_inst.element_b, + "c_type": math_inst.element_accumulator, + "d_type": math_inst.element_accumulator, + "acc_type": math_inst.element_accumulator, + "epi_type": math_inst.element_accumulator + } + + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]] + + CreateGemmUniversal3xOperator(manifest, layout_list, tile_descriptions, data_type, schedules, tile_schedulers=[TileSchedulerType.Persistent]) + + +def GenerateBMG(manifest, cuda_version): + """ + Generate CUTLASS kernels for BMG (Battlemage/Xe2) architecture. + + BMG is Intel's Xe2 GPU architecture with compute capability 20. + Supports DPAS operations with FP16, BF16, FP8, and INT8 data types. + + This is a legacy wrapper that calls GenerateIntelXe with arch=INTEL_XE20. + """ + GenerateIntelXe(manifest, cuda_version, arch=INTEL_XE20) + +def GenerateIntelXe(manifest, cuda_version, arch): + """ + Unified generator for Intel Xe GPU architectures. + + Supports both PVC (arch 12) and BMG (arch 20) with the same generation code. + The operations are identical, only the architecture number differs. + + Supported data types: + - FP16/BF16: [fp16/bf16, fp16/bf16, fp32] + - INT8: [int8, int8, int32] + - FP8: [fp8, fp8, fp32] (E4M3 or E5M2, same types only) + - Mixed: [fp16, int4, fp32] + + :param manifest: Manifest object to add operations to + :param cuda_version: CUDA version string (used for compatibility) + :param arch: Architecture number (12 for PVC, 20 for BMG) + """ + if arch not in [INTEL_XE12, INTEL_XE20]: + raise ValueError(f"Unsupported Intel Xe architecture: {arch}. Supported: {INTEL_XE12} (PVC), {INTEL_XE20} (BMG)") + + # All Intel Xe architectures use the same generation functions + # Only the min_cc (architecture number) differs + GenerateXe_TensorOp_16b_DPAS_gemm(manifest, cuda_version, min_cc=arch) + #DISABLED: FP8 GEMMs are not yet ready. 
Will be enabled once the tests are ready + #GenerateXe_TensorOp_fp8_DPAS_gemm(manifest, cuda_version, min_cc=arch) + #GenerateXe_TensorOp_int8_DPAS_gemm(manifest, cuda_version, min_cc=arch) + # DISABLED: Mixed precision (FP16 x INT4) requires grouped GEMM infrastructure + # Regular library generation uses MainloopIntelXeXMX16 which requires ElementA == ElementB + # GenerateXe_TensorOp_mixed_dtype_DPAS_gemm(manifest, cuda_version, min_cc=arch) ################################################################################################### @@ -11797,7 +11190,7 @@ def define_parser(): parser.add_argument("--build-dir", default=".", required=False, help="CUTLASS top-level build directory") parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory") parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") - parser.add_argument("--architectures", default='53;60;61;70;75;80;90', help="Target compute architectures") + parser.add_argument("--architectures", default='53;60;61;70;75;80;90;100', help="Target compute architectures") parser.add_argument("--kernels", default='', help='Comma-delimited list to filter kernels by name. ' + 'Specifying this as \"all\" includes ALL the kernels, ' + 'while not specifying this includes only the default set of kernels.') @@ -11865,6 +11258,21 @@ def define_parser(): GenerateSM100(manifest, args.cuda_version) GenerateSM120(manifest, args.cuda_version) + # Intel Xe GPU architectures - unified handling for PVC and BMG + # Both architectures share the same generation code, just different arch numbers + + # Check for BMG (architecture INTEL_XE20) + bmg_arch_list = [str(INTEL_XE20), "bmg", "xe2", "intel_gpu_bmg_g21"] + bmg_enabled_arch = any(arch.lower() in [x.lower() for x in bmg_arch_list] for arch in archs) + if bmg_enabled_arch: + GenerateIntelXe(manifest, args.cuda_version, arch=INTEL_XE20) + + # Check for PVC (architecture INTEL_XE12) + pvc_arch_list = [str(INTEL_XE12), "pvc", "intel_gpu_pvc"] + pvc_enabled_arch = any(arch.lower() in [x.lower() for x in pvc_arch_list] for arch in archs) + if pvc_enabled_arch: + GenerateIntelXe(manifest, args.cuda_version, arch=INTEL_XE12) + if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library) diff --git a/python/cutlass_library/heuristics_provider.py b/python/cutlass_library/heuristics_provider.py index b3f6e5c583..01a4112a34 100644 --- a/python/cutlass_library/heuristics_provider.py +++ b/python/cutlass_library/heuristics_provider.py @@ -41,6 +41,7 @@ import ctypes import functools + try: import builtins if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index 1e0998d5e6..b8163d9877 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -65,7 +65,13 @@ class GeneratorTarget(enum.Enum): ################################################################################################### -# +# Architecture constants import with fallback for relative imports +try: + from cutlass_library.arch_constants import INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN, INTEL_XE12, INTEL_XE20 +except ImportError: + from arch_constants import INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN, INTEL_XE12, INTEL_XE20 + + class DataType(enum.Enum): void = enum_auto() # primarily used to disable C tensor for epilogues b1 = enum_auto() 
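For illustration only (not part of this patch): a minimal, self-contained sketch of the architecture-string dispatch that the generator.py hunk above introduces. It assumes INTEL_XE12 == 12 (PVC) and INTEL_XE20 == 20 (BMG), matching the mapping documented elsewhere in this change; `dispatch_intel_xe` and its `generate` callback are hypothetical stand-ins for the real call into `GenerateIntelXe(manifest, args.cuda_version, arch=...)`.

```python
# Hypothetical stand-alone sketch of the Intel Xe dispatch logic (not library code).
# Assumes INTEL_XE12 = 12 (PVC) and INTEL_XE20 = 20 (BMG), as documented in this patch.
INTEL_XE12, INTEL_XE20 = 12, 20

BMG_ALIASES = [str(INTEL_XE20), "bmg", "xe2", "intel_gpu_bmg_g21"]
PVC_ALIASES = [str(INTEL_XE12), "pvc", "intel_gpu_pvc"]

def dispatch_intel_xe(architectures: str, generate):
    """Call generate(arch) once per Intel Xe family named in a semicolon-delimited arch list."""
    archs = architectures.split(';') if architectures else []
    if any(a.lower() in [x.lower() for x in BMG_ALIASES] for a in archs):
        generate(INTEL_XE20)   # e.g. GenerateIntelXe(manifest, cuda_version, arch=INTEL_XE20)
    if any(a.lower() in [x.lower() for x in PVC_ALIASES] for a in archs):
        generate(INTEL_XE12)   # e.g. GenerateIntelXe(manifest, cuda_version, arch=INTEL_XE12)

# "90;bmg;pvc" triggers both Xe generators; SM90 is handled by the existing CUDA path.
dispatch_intel_xe("90;bmg;pvc", lambda arch: print(f"GenerateIntelXe(arch={arch})"))
```

CUDA SM entries in the same list are left untouched, mirroring how the patch only appends the Xe checks after the existing SM generators have run.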
@@ -509,6 +515,8 @@ class KernelScheduleType(enum.Enum): BlockwiseTmaWarpSpecializedCooperative = enum_auto() PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto() + BlockwiseTmaWarpSpecializedPingpong = enum_auto() + PtrArrayBlockwiseTmaWarpSpecializedPingpong = enum_auto() TmaWarpSpecialized1SmSm100 = enum_auto() TmaWarpSpecialized2SmSm100 = enum_auto() @@ -548,20 +556,35 @@ class KernelScheduleType(enum.Enum): Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() # FP4 Ultra - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() - - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() - - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() - BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto() Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto() @@ -590,7 +613,8 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90', - 
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum', + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8Blockwise', KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100', KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100', @@ -621,27 +645,28 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', # FP4 Ultra - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', - - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + 
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', - KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise', KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100", KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100", @@ -652,6 +677,19 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100", KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100", + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + 
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120', KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120', KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120', @@ -682,7 +720,8 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized', KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', - + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm', KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm', @@ -710,20 +749,20 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_1sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_2sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_1sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_2sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_1sm_nopf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_2sm_nopf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_1sm_nopf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_2sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + 
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_1sm_tmapf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_2sm_tmapf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_1sm_tmapf', - KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_2sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', @@ -731,6 +770,7 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm', KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm', @@ -741,6 +781,21 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: '_cooperative_q', KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: '_pingpong_q', 
KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs16', @@ -763,10 +818,14 @@ class EpilogueScheduleType(enum.Enum): NoSmemWarpSpecialized2Sm = enum_auto() FastF32NoSmemWarpSpecialized1Sm = enum_auto() FastF32NoSmemWarpSpecialized2Sm = enum_auto() + BlockwiseNoSmemWarpSpecialized1Sm = enum_auto() + BlockwiseNoSmemWarpSpecialized2Sm = enum_auto() PtrArrayNoSmemWarpSpecialized1Sm = enum_auto() PtrArrayNoSmemWarpSpecialized2Sm = enum_auto() PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto() PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecialized1Sm = enum_auto() @@ -786,10 +845,14 @@ class EpilogueScheduleType(enum.Enum): EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm', EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm', EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm', EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm', EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm', EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm', @@ -810,16 +873,20 @@ class EpilogueScheduleType(enum.Enum): EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem', EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem', EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', EpilogueScheduleType.TmaWarpSpecialized1Sm: '', EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma', - 
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm', - EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma', EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma', EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma', } @@ -856,6 +923,7 @@ def to_grouped_schedule(schedule, grouped): # SM90 KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative, KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative, + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong, KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong, KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum, @@ -875,6 +943,23 @@ def to_grouped_schedule(schedule, grouped): KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm, + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm, + # SM103 + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, + 
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, } return group_schedule_map[schedule] @@ -1009,7 +1094,8 @@ class Target(enum.Enum): library = enum_auto() # ArchitectureNames = { - 11: 'pvc', + INTEL_XE12: 'pvc', + INTEL_XE20: 'bmg', 50: 'maxwell', 60: 'pascal', 61: 'pascal', @@ -1022,7 +1108,8 @@ class Target(enum.Enum): # SharedMemPerCC = { - 11: 128, # 128 KiB of SMEM on Intel PVC + INTEL_XE12: 128, # 128 KiB of SMEM on Intel PVC + INTEL_XE20: 128, # 128 KiB of SMEM on Intel BMG 70: 96, # 96KB of SMEM 72: 96, # 96KB of SMEM 75: 64, # 64KB of SMEM @@ -1031,6 +1118,7 @@ class Target(enum.Enum): 87: 163, # 163KB of SMEM - 1KB reserved for the driver 89: 99, # 99KB of SMEM - 1KB reserved for the driver 90: 227, # 227KB of SMEM - 1KB reserved for the driver + 100: 227, # 227KB of SMEM - 1KB reserved for the driver } ################################################################################################### diff --git a/python/cutlass_library/manifest.py b/python/cutlass_library/manifest.py index baaaac28a8..91010f577a 100644 --- a/python/cutlass_library/manifest.py +++ b/python/cutlass_library/manifest.py @@ -1,6 +1,7 @@ ################################################################################################# # # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (C) 2025 Intel Corporation, All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -65,6 +66,16 @@ ################################################################################################### _LOGGER = logging.getLogger(__name__) +################################################################################################### +# Import architecture range constants from shared module +################################################################################################### +try: + from cutlass_library.arch_constants import INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN, INTEL_XE12, INTEL_XE20 +except ImportError: + from arch_constants import INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, CUDA_ARCH_MIN, INTEL_XE12, INTEL_XE20 + +################################################################################################### + class EmitOperationKindAll: """ @@ -136,7 +147,27 @@ def __enter__(self): str(self.operation_path)); os.makedirs(self.operation_path, exist_ok=True) - self.top_level_path = os.path.join(self.operation_path, f"all_{OperationKindNames[self.kind]}_operations.cu") + # Determine file extension based on architecture + # Check if any Intel Xe target is present in the architectures + file_extension = "cu" # Default to CUDA + if self.args and hasattr(self.args, 'architectures'): + archs = self.args.architectures.split(';') if len(self.args.architectures) else [] + for arch in archs: + arch_lower = arch.lower() + # Check for Intel Xe targets + if any(xe_target in arch_lower for xe_target in ['pvc', 'bmg', 'intel_gpu']): + file_extension = "cpp" + break + # Check for numeric Xe architecture in the Intel Xe range + try: + arch_num = int(arch.split('a')[0].split('f')[0]) + if arch_num >= INTEL_XE_ARCH_MIN and arch_num < INTEL_XE_ARCH_MAX: + file_extension = "cpp" + break + except (ValueError, AttributeError): + 
pass + + self.top_level_path = os.path.join(self.operation_path, f"all_{OperationKindNames[self.kind]}_operations.{file_extension}") _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}") self.top_level_file = open(self.top_level_path, "w") @@ -184,9 +215,10 @@ class EmitOperationKindLibrary: for min_cc=90 and OperationKind=Gemm), in the file all_sm{min_cc}_{operation_kind}_operations.cu (e.g., all_sm90_gemm_operations.cu for min_cc=90 and OperationKind=Gemm). + For Intel Xe targets, uses xe{min_cc} prefix instead of sm{min_cc}. The min_cc variable here indicates the minimum GPU architecture version that the things to be initialized require. - For example, min_cc=90 indicates sm90. + For example, min_cc=90 indicates sm90 for CUDA, min_cc=20 indicates Xe2/BMG for Intel. That file declares several functions in namespace cutlass::library. The functions all have this form, @@ -207,11 +239,23 @@ class EmitOperationKindLibrary: of what happens in each of those subdirectories. """ + @staticmethod + def get_arch_prefix(min_cc): + """Get architecture prefix based on compute capability. + Returns 'sm' for CUDA architectures, 'xe' for Intel Xe architectures. + Intel Xe: 12 (PVC), 20 (BMG) + CUDA: 50+ for CUDA architectures""" + if min_cc >= INTEL_XE_ARCH_MIN and min_cc < INTEL_XE_ARCH_MAX: + return 'xe' + else: + return 'sm' + def __init__(self, generated_path, min_cc, kind, args): self.generated_path = generated_path self.min_cc = min_cc self.kind = kind self.args = args + self.arch_prefix = self.get_arch_prefix(min_cc) self.emitters = { OperationKind.Gemm: EmitGemmConfigurationLibrary, OperationKind.Conv2d: EmitConv2dConfigurationLibrary, @@ -242,12 +286,12 @@ def __init__(self, generated_path, min_cc, kind, args): // // Entry point to construct operations // -void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest) { +void initialize_all_${arch_prefix}${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest) { """ self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" self.configuration_template = " initialize_${configuration_name}(manifest);\n" - self.subclass_call_template = " initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(manifest);\n" - self.subclass_prototype_template = "void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest);\n" + self.subclass_call_template = " initialize_all_${arch_prefix}${min_cc}_${subclass_name}_${operation_name}_operations(manifest);\n" + self.subclass_prototype_template = "void initialize_all_${arch_prefix}${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest);\n" self.epilogue_template ="""} /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -268,7 +312,9 @@ def __enter__(self): _LOGGER.debug(f"*** operation_path (directory to make): {str(self.operation_path)}") os.makedirs(self.operation_path) - self.top_level_path = os.path.join(self.operation_path, f"all_sm{self.min_cc}_{OperationKindNames[self.kind]}_operations.cu") + # Use .cpp extension for Intel Xe architectures, .cu for CUDA + file_extension = "cpp" if (self.min_cc >= INTEL_XE_ARCH_MIN and self.min_cc < INTEL_XE_ARCH_MAX) else "cu" + self.top_level_path = os.path.join(self.operation_path, f"all_{self.arch_prefix}{self.min_cc}_{OperationKindNames[self.kind]}_operations.{file_extension}") _LOGGER.debug(f"*** top_level_path (file to write): 
{str(self.top_level_path)}") self.top_level_file = open(self.top_level_path, "w") @@ -307,9 +353,11 @@ def emit(self, configuration_name, operations): self.subclass_configurations[extended_name] = [] + # Use .cpp extension for Intel Xe architectures, .cu for CUDA + file_extension = "cpp" if (self.min_cc >= INTEL_XE_ARCH_MIN and self.min_cc < INTEL_XE_ARCH_MAX) else "cu" # Open a new top-level file for this sub class subclass_top_level_path = os.path.join( - subclass_path, f"all_sm{self.min_cc}_{extended_name}_{OperationKindNames[self.kind]}_operations.cu") + subclass_path, f"all_{self.arch_prefix}{self.min_cc}_{extended_name}_{OperationKindNames[self.kind]}_operations.{file_extension}") _LOGGER.debug('*** subclass_top_level_path (min_cc, extended_name, ' + 'OperationKind): ' + str(subclass_top_level_path)) @@ -337,6 +385,7 @@ def __exit__(self, exception_type, exception_value, traceback): _LOGGER.debug("*** EmitOperationKindLibrary::__exit__") for subclass_name, subclass_file in sorted(self.subclass_files.items()): subclass_cfg = { + 'arch_prefix': self.arch_prefix, 'min_cc': str(self.min_cc), 'subclass_name': subclass_name, 'operation_name': OperationKindNames[self.kind] @@ -345,6 +394,7 @@ def __exit__(self, exception_type, exception_value, traceback): self.top_level_file.write( SubstituteTemplate(self.entry_template, { + 'arch_prefix': self.arch_prefix, 'min_cc': str(self.min_cc), 'subclass_name': '', 'operation_name': OperationKindNames[self.kind] @@ -353,6 +403,7 @@ def __exit__(self, exception_type, exception_value, traceback): # Finish and close all subclass files for subclass_name, subclass_file in sorted(self.subclass_files.items()): subclass_cfg = { + 'arch_prefix': self.arch_prefix, 'min_cc': str(self.min_cc), 'subclass_name': subclass_name, 'operation_name': OperationKindNames[self.kind] @@ -511,6 +562,7 @@ def __init__(self, args = None): self.compute_capabilities_feature_set = ['50',] self.curr_build_dir = '.' self.filter_by_cc = True + self.is_xe_target = False # Track if building for Intel Xe if self.args: self.kernel_filter = self.args.kernels @@ -518,10 +570,43 @@ def __init__(self, args = None): # A common user error is to use commas instead of semicolons. if ',' in args.architectures: - raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures) + raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS or DPCPP_SYCL_TARGET) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures) self.compute_capabilities_feature_set = args.architectures.split(';') if len(args.architectures) else ['50',] - self.compute_capabilities_baseline = sorted(set(int(arch.split('a')[0].split('f')[0]) for arch in self.compute_capabilities_feature_set)) + + # Parse architecture identifiers - support both CUDA SM and Intel Xe targets + baseline_archs = [] + for arch in self.compute_capabilities_feature_set: + # Check if this is an Intel Xe target (pvc, bmg, etc.) 
+ # Support both string names ('pvc', 'bmg') and numeric values + arch_lower = arch.lower() + is_xe_named = any(xe_target in arch_lower for xe_target in ['pvc', 'bmg', 'intel_gpu']) + + # Also check if it's a numeric Xe architecture in the Intel Xe range + try: + arch_num = int(arch.split('a')[0].split('f')[0]) + is_xe_numeric = (arch_num >= INTEL_XE_ARCH_MIN and arch_num < INTEL_XE_ARCH_MAX) + except (ValueError, AttributeError): + arch_num = None + is_xe_numeric = False + + if is_xe_named or is_xe_numeric: + self.is_xe_target = True + # Map Intel Xe architectures to numeric identifiers for compatibility + # PVC (Ponte Vecchio) -> 12 + # BMG (Battlemage/Xe2) -> 20 + if 'pvc' in arch_lower or arch_num == INTEL_XE12: + baseline_archs.append(INTEL_XE12) + elif 'bmg' in arch_lower or 'xe2' in arch_lower or arch_num == INTEL_XE20: + baseline_archs.append(INTEL_XE20) + else: + # Generic Intel GPU target - default to BMG + baseline_archs.append(INTEL_XE20) + else: + # CUDA SM architecture + baseline_archs.append(arch_num if arch_num is not None else int(arch.split('a')[0].split('f')[0])) + + self.compute_capabilities_baseline = sorted(set(baseline_archs)) if args.filter_by_cc in ['false', 'False', '0']: self.filter_by_cc = False @@ -570,7 +655,7 @@ def add_kernel_filter(self, filter_str): self.kernel_filter_list.append(filter_re) - def get_sm90_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992): + def get_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992): # Non-negative integer which determines how many kernels are instantiated. # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations. # increasing first digit reduces schedule / mixed type pruning, @@ -740,18 +825,24 @@ def emit_manifest_cmake(self, manifest_path, top_level_path, source_files): manifest_file.write(target_text + '\n\n') manifest_file.write(" %s\n" % str(top_level_path.replace('\\', '/'))) generated_path = os.path.join(self.curr_build_dir, 'generated') + + # Determine file extension based on whether we're targeting Intel Xe + file_extension = "cpp" if self.is_xe_target else "cu" + for kind in self.operations.keys(): kind_str = OperationKindNames[kind] - all_kind_file = os.path.join(generated_path, kind_str, f"all_{kind_str}_operations.cu").replace('\\', '/') + all_kind_file = os.path.join(generated_path, kind_str, f"all_{kind_str}_operations.{file_extension}").replace('\\', '/') manifest_file.write(f" {all_kind_file}\n") manifest_file.write(')\n\n') for kind in self.operations.keys(): for min_cc in sorted(self.operations[kind].keys()): for subclass in sorted(source_files[kind][min_cc].keys()): + # Use appropriate prefix (sm for CUDA, xe for Intel) + arch_prefix = 'xe' if (min_cc >= INTEL_XE_ARCH_MIN and min_cc < INTEL_XE_ARCH_MAX) else 'sm' target_text = SubstituteTemplate("""cutlass_add_cutlass_library( - SUFFIX ${kind}_sm${min_cc}_${subclass} -""", { 'min_cc': str(min_cc), 'kind': OperationKindNames[kind], 'subclass': subclass }) + SUFFIX ${kind}_${arch_prefix}${min_cc}_${subclass} +""", { 'arch_prefix': arch_prefix, 'min_cc': str(min_cc), 'kind': OperationKindNames[kind], 'subclass': subclass }) manifest_file.write(target_text + '\n\n') for source_file in source_files[kind][min_cc][subclass]: @@ -759,7 +850,8 @@ def emit_manifest_cmake(self, manifest_path, top_level_path, source_files): manifest_file.write(")\n") - if self.disable_full_archs_compilation: + # Only apply CUDA-specific arch compilation settings for CUDA targets + if 
self.disable_full_archs_compilation and min_cc < INTEL_XE_ARCH_MIN: + self.emit_disable_full_archs_compilation(manifest_file, source_files) def emit_disable_full_archs_compilation(manifest_file, source_files): diff --git a/python/cutlass_library/sm100_shapes.py b/python/cutlass_library/sm100_shapes.py new file mode 100644 index 0000000000..32e4376513 --- /dev/null +++ b/python/cutlass_library/sm100_shapes.py @@ -0,0 +1,342 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Valid tcgen05 shapes and cluster sizes for SM100, associated with levels. +These shape and level pairs are defined as dicts, where keys are shapes and values are their +associated levels. If the user-supplied level for a category (tcgen05 shape, cluster +size) is smaller than a shape's associated level, the shape is excluded; otherwise it is included. +Shapes with higher associated levels are therefore emitted less often, and shapes with lower +levels are emitted more frequently. Level 0 is always emitted. 
+""" + +try: + from .library import DynamicClusterShape +except: + from library import DynamicClusterShape + +SM100_CLUSTER_SHAPES_1SM = { + tuple(DynamicClusterShape) : 0, + # size 1 cluster + (1, 1, 1): 1, + # size 2 cluster + (1, 2, 1): 2, + (2, 1, 1): 5, + # size 4 clusters + (2, 2, 1): 6, + (1, 4, 1): 3, + (4, 1, 1): 6, + # size 8 clusters + (2, 4, 1): 7, + (4, 2, 1): 7, + (1, 8, 1): 8, + (8, 1, 1): 8, + # size 16 cluster + (4, 4, 1): 4, +} + +SM100_CLUSTER_SHAPES_2SM = { + tuple(DynamicClusterShape) : 0, + # size 2 cluster + (2, 1, 1): 1, + # size 4 clusters + (2, 2, 1): 2, + (4, 1, 1): 2, + # size 8 clusters + (2, 4, 1): 3, + (4, 2, 1): 3, + (8, 1, 1): 6, + # size 16 cluster + (4, 4, 1): 4, +} + +# MMA shapes + +# 16b Dense + +SM100_MMA_SHAPES_16b_DENSE_1SM = { + (64, 8, 16): 5, + (64, 16, 16): 2, + (64, 24, 16): 5, + (64, 32, 16): 2, + (64, 40, 16): 5, + (64, 48, 16): 5, + (64, 56, 16): 5, + (64, 64, 16): 2, + (64, 72, 16): 5, + (64, 80, 16): 5, + (64, 88, 16): 5, + (64, 96, 16): 5, + (64, 104, 16): 5, + (64, 112, 16): 5, + (64, 120, 16): 5, + (64, 128, 16): 0, + (64, 136, 16): 5, + (64, 144, 16): 5, + (64, 152, 16): 5, + (64, 160, 16): 5, + (64, 168, 16): 5, + (64, 176, 16): 5, + (64, 184, 16): 5, + (64, 192, 16): 3, + (64, 200, 16): 5, + (64, 208, 16): 5, + (64, 216, 16): 5, + (64, 224, 16): 5, + (64, 232, 16): 5, + (64, 240, 16): 5, + (64, 248, 16): 5, + (64, 256, 16): 3, + + (128, 16, 16): 2, + (128, 32, 16): 2, + (128, 48, 16): 5, + (128, 64, 16): 2, + (128, 80, 16): 5, + (128, 96, 16): 5, + (128, 112, 16): 5, + (128, 128, 16): 0, + (128, 144, 16): 5, + (128, 160, 16): 5, + (128, 176, 16): 5, + (128, 192, 16): 3, + (128, 208, 16): 5, + (128, 224, 16): 5, + (128, 240, 16): 5, + (128, 256, 16): 0, + +} + + +SM100_MMA_SHAPES_16b_DENSE_2SM = { + (128, 32, 16): 2, + (128, 64, 16): 2, + (128, 96, 16): 5, + (128, 128, 16): 0, + (128, 160, 16): 5, + (128, 192, 16): 5, + (128, 224, 16): 5, + (128, 256, 16): 0, + + (256, 32, 16): 2, + (256, 64, 16): 2, + (256, 96, 16): 5, + (256, 128, 16): 0, + (256, 160, 16): 5, + (256, 192, 16): 3, + (256, 224, 16): 5, + (256, 256, 16): 0, +} + +# TF32 Dense + +SM100_MMA_SHAPES_TF32_DENSE_1SM = { + (64, 8, 8): 5, + (64, 16, 8): 2, + (64, 24, 8): 5, + (64, 32, 8): 2, + (64, 40, 8): 5, + (64, 48, 8): 5, + (64, 56, 8): 5, + (64, 64, 8): 1, + (64, 72, 8): 5, + (64, 80, 8): 5, + (64, 88, 8): 5, + (64, 96, 8): 5, + (64, 104, 8): 5, + (64, 112, 8): 5, + (64, 120, 8): 5, + (64, 128, 8): 0, + (64, 136, 8): 5, + (64, 144, 8): 5, + (64, 152, 8): 5, + (64, 160, 8): 5, + (64, 168, 8): 5, + (64, 176, 8): 5, + (64, 184, 8): 5, + (64, 192, 8): 3, + (64, 200, 8): 5, + (64, 208, 8): 5, + (64, 216, 8): 5, + (64, 224, 8): 5, + (64, 232, 8): 5, + (64, 240, 8): 5, + (64, 248, 8): 5, + (64, 256, 8): 3, + + (128, 16, 8): 2, + (128, 32, 8): 2, + (128, 48, 8): 5, + (128, 64, 8): 2, + (128, 80, 8): 5, + (128, 96, 8): 5, + (128, 112, 8): 5, + (128, 128, 8): 0, + (128, 144, 8): 5, + (128, 160, 8): 5, + (128, 176, 8): 5, + (128, 192, 8): 3, + (128, 208, 8): 5, + (128, 224, 8): 5, + (128, 240, 8): 5, + (128, 256, 8): 0, + +} + +SM100_MMA_SHAPES_TF32_DENSE_2SM = { + (128, 32, 8): 2, + (128, 64, 8): 1, + (128, 96, 8): 5, + (128, 128, 8): 0, + (128, 160, 8): 5, + (128, 192, 8): 5, + (128, 224, 8): 5, + (128, 256, 8): 0, + + (256, 32, 8): 2, + (256, 64, 8): 1, + (256, 96, 8): 5, + (256, 128, 8): 0, + (256, 160, 8): 5, + (256, 192, 8): 5, + (256, 224, 8): 5, + (256, 256, 8): 0, +} + +# F8F6F4 +SM100_MMA_SHAPES_F8F6F4_DENSE_1SM = { + (64, 8, 32): 4, + (64, 16, 32): 4, + (64, 24, 32): 5, 
+ (64, 32, 32): 3, + (64, 40, 32): 5, + (64, 48, 32): 5, + (64, 56, 32): 5, + (64, 64, 32): 2, + (64, 72, 32): 5, + (64, 80, 32): 5, + (64, 88, 32): 5, + (64, 96, 32): 5, + (64, 104, 32): 5, + (64, 112, 32): 5, + (64, 120, 32): 5, + (64, 128, 32): 0, + (64, 136, 32): 5, + (64, 144, 32): 5, + (64, 152, 32): 5, + (64, 160, 32): 5, + (64, 168, 32): 5, + (64, 176, 32): 5, + (64, 184, 32): 5, + (64, 192, 32): 5, + (64, 200, 32): 5, + (64, 208, 32): 5, + (64, 216, 32): 5, + (64, 224, 32): 5, + (64, 232, 32): 5, + (64, 240, 32): 5, + (64, 248, 32): 5, + (64, 256, 32): 0, + + (128, 16, 32): 4, + (128, 32, 32): 3, + (128, 48, 32): 5, + (128, 64, 32): 2, + (128, 80, 32): 5, + (128, 96, 32): 5, + (128, 112, 32): 5, + (128, 128, 32): 0, + (128, 144, 32): 5, + (128, 160, 32): 5, + (128, 176, 32): 5, + (128, 192, 32): 5, + (128, 208, 32): 5, + (128, 224, 32): 5, + (128, 240, 32): 5, + (128, 256, 32): 0, + +} + +SM100_MMA_SHAPES_F8F6F4_DENSE_2SM = { + (128, 32, 32): 3, + (128, 64, 32): 2, + (128, 96, 32): 5, + (128, 128, 32): 1, + (128, 160, 32): 5, + (128, 192, 32): 5, + (128, 224, 32): 5, + (128, 256, 32): 1, + + (256, 32, 32): 2, + (256, 64, 32): 2, + (256, 96, 32): 5, + (256, 128, 32): 0, + (256, 160, 32): 5, + (256, 192, 32): 5, + (256, 224, 32): 5, + (256, 256, 32): 0, +} + +# MXF8F6F4 +SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM = { + (128, 64, 32): 1, + (128, 128, 32): 0, + (128, 192, 32): 1, + (128, 256, 32): 0, +} + + +SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = { + (256, 64, 32): 1, + (256, 128, 32): 0, + (256, 192, 32): 1, + (256, 256, 32): 0, + + +} + +# MXF4NVF4 +SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = { + (128, 64, 64): 1, + (128, 128, 64): 0, + (128, 192, 64): 1, + (128, 256, 64): 0, +} + +SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = { + # Multiples of 16 for N + (256, 64, 64): 1, + (256, 128, 64): 0, + (256, 192, 64): 1, + (256, 256, 64): 0, + +} diff --git a/python/cutlass_library/sm100_utils.py b/python/cutlass_library/sm100_utils.py new file mode 100644 index 0000000000..9bf24fe7f5 --- /dev/null +++ b/python/cutlass_library/sm100_utils.py @@ -0,0 +1,661 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library SM100 kernels +""" + +import argparse +import enum +from itertools import product +import math +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Optional, Sequence, Tuple, List, Union, Callable + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +#### Step 0: define levels + +# One integer level controls multiple "generators" and how many +# combinations they generate. That is the "global" level. +# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and +# anything that is eventually involved in the Cartesian product +# which yields our kernel configurations. +# For simplicity, each generator defines their own levels, +# starting from 0. As a rule we assume 10 or fewer levels, making +# their level a digit. +# The "global" level simply stacks these digits and represents them +# as a single integer. +# +# For example, level 500 indicates cluster sizes are at level 5, MMA +# multipliers are at level 0, and WGMMA shapes are at level 0 as well. +# +# Here we define the global level to generator level mappings. + + +def get_tcgen05_level_from_global_level(global_level: int): + return global_level % 10 + +def get_mma_level_from_global_level(global_level: int): + return (global_level // 10) % 10 + + +def get_cluster_level_from_global_level(global_level: int): + return (global_level // 100) % 10 + + +def get_pruning_level_from_global_level(global_level: int): + return (global_level // 1000) % 10 + + +#### Step 1: generate MMA instruction shapes based on levels + +try: + from .sm100_shapes import * +except: + from sm100_shapes import * + +########### + +def generate_tf32_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for TF32 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. 
+ """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_16b_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for 16b MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + + +def generate_fp8_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. 
+ """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if enable_compile_time_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + if enable_runtime_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if enable_compile_time_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_f8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 FP6 and FP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. 
+ """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP8, MXFP6, and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. 
+ """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf4nvf4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP4 and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. 
+ """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + return math_instructions_1sm, math_instructions_2sm + + +def generate_cluster_shapes_sm100(level: int, change_priority_func : Union[Callable, None] = None): + """ + Generate all cluster shapes for SM100 at or above the given level. + + Args: + level: The global level to generate cluster shapes for. + + Returns: + A tuple of two lists of cluster shapes. + The first list contains the cluster shapes for 1SM, and the second list contains the cluster shapes for 2SM. 
+ """ + cluster_level = get_cluster_level_from_global_level(level) + + assert cluster_level >= 4 + + if change_priority_func is not None: + SM100_CLUSTER_SHAPES_1SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_1SM) + SM100_CLUSTER_SHAPES_2SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_2SM) + change_priority_func(SM100_CLUSTER_SHAPES_1SM_CPY, SM100_CLUSTER_SHAPES_2SM_CPY) + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM_CPY.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM_CPY.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm + + else: + + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py index 3bf3edb23c..fc5fdf14ab 100644 --- a/python/cutlass_library/sm90_utils.py +++ b/python/cutlass_library/sm90_utils.py @@ -637,7 +637,10 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, if CudaToolkitVersionSatisfies(cuda_version, 12, 1): # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue if not is_fp8 or level >= 1: - schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + if not is_blockwise(gemm_kind): + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + else: + schedules.append([to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) if can_do_fp8_fast_accum: if not grouped: diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index 8122b7a6a1..0b22d31b81 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -1,6 +1,7 @@ ################################################################################################# # # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (C) 2025 Intel Corporation, All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -50,9 +51,9 @@ setup( - name='cutlass_cppgen', - version='4.0.0', - description='CUTLASS Pythonic Interface', + name='sycl_tla_cppgen', + version='0.6.0', + description='SYCL*TLA Pythonic Interface based on CUTLASS', package_dir={'': '.'}, packages=[ 'cutlass_cppgen', @@ -65,7 +66,6 @@ setup_requires=['pybind11'], install_requires=[ 'bfloat16', - 'cuda-python>=11.8.0', 'pybind11', 'scikit-build', 'treelib', diff --git a/python/setup_library.py b/python/setup_library.py index 875ba62d55..621decf370 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -1,6 +1,7 @@ ################################################################################################# # # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (C) 2025 Intel Corporation, All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -35,9 +36,9 @@ def perform_setup(): setup( - name='cutlass_library', - version='4.1.0', - description='CUTLASS library generation scripts', + name='cutlass_library_xe', + version='0.6.0', + description='SYCL*TLA library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 7e8a99e034..0bad050fca 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ def perform_setup(): setup( name='pycute', - version='4.1.0', + version='4.2.1', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/setup.cfg b/setup.cfg index 310be28b54..a412cd39bf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] -name = cutlass-sycl -version = 0.5 +name = sycl-tla +version = 0.6 [options] packages = @@ -20,7 +20,7 @@ packages = cutlass_library.source pycute package_dir = - cutlass_cppgen=python/cutlass + cutlass_cppgen=python/cutlass_cppgen cutlass_library=python/cutlass_library cutlass_library.source=. pycute=python/pycute diff --git a/test/python/cutlass/gemm/gemm_batched.py b/test/python/cutlass/gemm/gemm_batched.py index 116597f910..be62f6df7f 100644 --- a/test/python/cutlass/gemm/gemm_batched.py +++ b/test/python/cutlass/gemm/gemm_batched.py @@ -88,7 +88,7 @@ def initialize(rows, cols, batch): return tensor.reshape(rows, cols) -@unittest.skipIf(device_cc() == 11, "Batched GEMM test not supported on PVC") +@unittest.skipIf(device_cc() >= 12 and device_cc() <= 20, "Batched GEMM test not supported on Xe") class GemmF16Batched(unittest.TestCase): def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool): M = 512 diff --git a/test/python/cutlass/gemm/gemm_bf16_pvc.py b/test/python/cutlass/gemm/gemm_bf16_pvc.py index baeefd8e76..87c602715d 100644 --- a/test/python/cutlass/gemm/gemm_bf16_pvc.py +++ b/test/python/cutlass/gemm/gemm_bf16_pvc.py @@ -1,6 +1,7 @@ ################################################################################################# # # Copyright (c) 2023 - 2025 Codeplay Software Limited. All rights reserved. +# Copyright (c) 2025 Intel Corporation. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -40,16 +41,16 @@ import cutlass_cppgen from cutlass_cppgen.backend.utils.device import device_cc - +from cutlass_library.arch_constants import ( INTEL_XE12, is_intel_xe_arch) from utils import LayoutCombination, add_test_gemm cutlass_cppgen.set_log_level(logging.WARNING) -cc = 11 +cc = INTEL_XE12 # PVC architecture is 12 (Xe-HPC) dtype = cutlass_cppgen.DataType.bf16 -@unittest.skipIf(device_cc() != cc, 'Device compute capability is insufficient for PVC tests.') +@unittest.skipIf(not is_intel_xe_arch(device_cc()), 'Device compute capability is insufficient for PVC tests.') @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmBF16PVC(unittest.TestCase): """ @@ -58,7 +59,7 @@ class GemmBF16PVC(unittest.TestCase): pass -add_test_pvc_bf16 = partial(add_test_gemm, cls=GemmBF16PVC, cc=11, +add_test_pvc_bf16 = partial(add_test_gemm, cls=GemmBF16PVC, cc=INTEL_XE12, element=dtype, compilation_modes=["dpcpp"], opclass=cutlass_cppgen.OpcodeClass.TensorOp, diff --git a/test/python/cutlass/gemm/gemm_bf16_xe20.py b/test/python/cutlass/gemm/gemm_bf16_xe20.py new file mode 100644 index 0000000000..a811a65664 --- /dev/null +++ b/test/python/cutlass/gemm/gemm_bf16_xe20.py @@ -0,0 +1,107 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 Codeplay Software Limited. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Low-level functionality tests for GEMM with BF16 operands on Xe20 (BMG) +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_library.arch_constants import ( INTEL_XE12, is_intel_xe_arch) + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 20 # BMG architecture is 20 (Xe2) +dtype = cutlass_cppgen.DataType.bf16 + + +@unittest.skipIf(not is_intel_xe_arch(device_cc()), 'Device compute capability is insufficient for Xe20 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmBF16Xe20(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_xe20_bf16 = partial(add_test_gemm, cls=GemmBF16Xe20, cc=INTEL_XE20, + element=dtype, + compilation_modes=["dpcpp"], + opclass=cutlass_cppgen.OpcodeClass.TensorOp, + stages=0, + cluster_shape=[1, 1, 1]) + +add_test_f32_acc = partial(add_test_xe20_bf16, alignments=[16, 16, 4], + element_C=cutlass_cppgen.DataType.f32, + element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32) +add_test_bf16_acc = partial(add_test_xe20_bf16, alignments=[16, 16, 2], + element_C=cutlass_cppgen.DataType.bf16, + element_output=cutlass_cppgen.DataType.bf16, + element_accumulator=cutlass_cppgen.DataType.bf16) + +add_test_f32_acc(layouts=LayoutCombination.TTT, + threadblock_shape=[256, 256, 32], warp_count=[8, 4, 1]) + +add_test_f32_acc(layouts=LayoutCombination.TTT, + threadblock_shape=[128, 512, 32], warp_count=[4, 8, 1]) + +add_test_f32_acc(layouts=LayoutCombination.TTT, + threadblock_shape=[256, 128, 32], warp_count=[8, 4, 1]) + +add_test_f32_acc(layouts=LayoutCombination.TTT, + threadblock_shape=[128, 256, 16], warp_count=[4, 8, 1]) + +add_test_bf16_acc(layouts=LayoutCombination.TTT, + threadblock_shape=[256, 256, 32], warp_count=[8, 4, 1]) + +add_test_bf16_acc(layouts=LayoutCombination.TTT, + threadblock_shape=[128, 512, 32], warp_count=[4, 8, 1]) + +add_test_bf16_acc(layouts=LayoutCombination.TTT, + threadblock_shape=[256, 128, 32], warp_count=[8, 4, 1]) + +add_test_bf16_acc(layouts=LayoutCombination.TTT, + threadblock_shape=[128, 256, 16], warp_count=[4, 8, 1]) + + +# TODO: Test more configurations as soon as they're supported by the +# CollectiveBuilder + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/cutlass/gemm/gemm_testbed.py b/test/python/cutlass/gemm/gemm_testbed.py index f714bfcd60..922c067855 100644 --- a/test/python/cutlass/gemm/gemm_testbed.py +++ b/test/python/cutlass/gemm/gemm_testbed.py @@ -38,6 +38,7 @@ import cutlass_cppgen import torch import os +from cutlass_library.arch_constants import ( INTEL_XE_ARCH_MIN, INTEL_XE_ARCH_MAX, INTEL_XE12, INTEL_XE20, is_intel_xe_arch) if not os.getenv("CUTLASS_USE_SYCL"): import cuda @@ -88,8 +89,9 @@ def __init__( raise Exception(f"Unexpected compiler string {compiler_mode}") op_list = [operation] - if operation.arch < 90 and operation.arch > 11: + if operation.arch < 90 and not is_intel_xe_arch(operation.arch): # Split K via Python is currently only supported for pre-SM90 kernels + # Exclude Intel Xe architectures as reduction is not implemented for Intel Xe self.reduction_operation: ReductionOperation = 
ReductionOperation( shape=MatrixCoord(4, 32 * operation.C.alignment), C=operation.C, @@ -99,6 +101,9 @@ def __init__( count=operation.C.alignment, ) op_list.append(self.reduction_operation) + else: + # No reduction operation for Intel Xe architectures or SM90+ + self.reduction_operation = None compiler.add_module(op_list, bypass_cache=False) @@ -277,6 +282,8 @@ def transpose(layout): ) if mode == GemmUniversalMode.GemmSplitKParallel: + if self.reduction_operation is None: + raise RuntimeError("GemmSplitKParallel mode is not supported for Intel Xe architectures (reduction operation not implemented)") reduction_arguments = ReductionArguments( self.reduction_operation, problem_size=[problem_size.m, problem_size.n], @@ -290,7 +297,8 @@ def transpose(layout): self.operation.run(arguments) if mode == GemmUniversalMode.GemmSplitKParallel: - self.reduction_operation.run(reduction_arguments) + if self.reduction_operation is not None: + self.reduction_operation.run(reduction_arguments) passed = True diff --git a/test/python/cutlass/gemm/utils.py b/test/python/cutlass/gemm/utils.py index 716c5ae9da..a569b2e7be 100644 --- a/test/python/cutlass/gemm/utils.py +++ b/test/python/cutlass/gemm/utils.py @@ -31,6 +31,14 @@ ################################################################################################# from cutlass_library import SubstituteTemplate +from cutlass_library.arch_constants import ( + INTEL_XE_ARCH_MIN, + INTEL_XE_ARCH_MAX, + INTEL_XE12, + INTEL_XE20, + INTEL_XE35, + is_intel_xe_arch +) import cutlass_cppgen from cutlass_library import ( @@ -117,11 +125,18 @@ def get_name( :return: str """ - name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}" + name_format = "test_${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}" + + # Map Intel Xe architectures to names + if is_intel_xe_arch(arch): + arch_name = f"Xe{str(arch)}" # Generic Xe naming + else: + arch_name = f"SM{str(arch)}" # NVIDIA SM naming + return SubstituteTemplate( name_format, { - "arch": "PVC" if arch == 11 else f"SM{str(arch)}", + "arch": arch_name, "eA": DataTypeNames[element_a], "eB": DataTypeNames[element_b], "eC": DataTypeNames[element_c], @@ -248,8 +263,8 @@ def run(self): td.stages = stages td.cluster_shape = cluster_shape - # For Intel PVC (CC 11), ensure we use auto schedules and default tile scheduler - if cc == 11: + # For Intel Xe architectures, ensure we use auto schedules and default tile scheduler + if is_intel_xe_arch(cc): td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto td.tile_scheduler = cutlass_cppgen.TileSchedulerType.Default diff --git a/test/python/cutlass/interface/gemm_interface.py b/test/python/cutlass/interface/gemm_interface.py index 723c4c07ee..2913d5933f 100644 --- a/test/python/cutlass/interface/gemm_interface.py +++ b/test/python/cutlass/interface/gemm_interface.py @@ -240,8 +240,8 @@ def test_tensorop_availability(self): """ cc = device_cc() - # F64 Tensor Core operations are only avaiable on devices with CC >= 80 - supports_tensorop_f64 = cc >= 80 + # F64 Tensor Core operations are only avaiable on certain devices + supports_tensorop_f64 = cc in [80, 89, 90] plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, 
layout=cutlass_cppgen.LayoutType.RowMajor) error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}' @@ -288,7 +288,7 @@ def test_invalid_tile_description(self): with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): td.stages = 3 plan.construct(td) - else: + elif cc == 90: original_kschedule = td.kernel_schedule original_eschedule = td.epilogue_schedule with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): @@ -296,10 +296,13 @@ def test_invalid_tile_description(self): td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized td.stages = 3 plan.construct(td) - # Reset schedules td.kernel_schedule = original_kschedule td.epilogue_schedule = original_eschedule + elif cc in [100, 101, 103]: + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.stages = 3 + plan.construct(td) with ExpectException(True, f'Requested too many stages'): td.stages = 100 @@ -321,12 +324,12 @@ def test_invalid_tile_description(self): td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized plan.construct(td) - with ExpectException(True, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): + with ExpectException(cc == 90, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto plan.construct(td) - with ExpectException(True, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): + with ExpectException(cc == 90, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized plan.construct(td) diff --git a/test/unit/cute/intel_xe/CMakeLists.txt b/test/unit/cute/intel_xe/CMakeLists.txt index 89a58e2015..5f4c74a188 100755 --- a/test/unit/cute/intel_xe/CMakeLists.txt +++ b/test/unit/cute/intel_xe/CMakeLists.txt @@ -26,6 +26,18 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
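+
+# IGC_VERSION_MAJOR/IGC_VERSION_MINOR default to 2.18, the minimum IGC release
+# required by the new DPAS and 2D block-copy unit tests (see the IGC_VERSION
+# guards in mma.cpp and xe_copy_2d_test.cpp). Pass -DIGC_VERSION_MAJOR=<major>
+# -DIGC_VERSION_MINOR=<minor> at configure time to match the installed compiler.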
+if(NOT DEFINED IGC_VERSION_MAJOR) + set(IGC_VERSION_MAJOR 2) +endif() +if(NOT DEFINED IGC_VERSION_MINOR) + set(IGC_VERSION_MINOR 18) +endif() + +add_compile_definitions( + IGC_VERSION_MAJOR=${IGC_VERSION_MAJOR} + IGC_VERSION_MINOR=${IGC_VERSION_MINOR} +) + if(SYCL_INTEL_TARGET) cutlass_test_unit_add_executable( cutlass_test_unit_cute_intel_xe @@ -35,6 +47,11 @@ cutlass_test_unit_add_executable( copy_scatter.cpp mma.cpp tiled_mma.cpp + xe_copy_2d_test.cpp + reorder.cpp + xe_copy_prefetch_2d.cpp + xe_vnni_2d.cpp + xe_transpose_2d.cpp ) else() cutlass_test_unit_add_executable( diff --git a/test/unit/cute/intel_xe/mma.cpp b/test/unit/cute/intel_xe/mma.cpp index d30c8ae8d7..23f327e3d4 100755 --- a/test/unit/cute/intel_xe/mma.cpp +++ b/test/unit/cute/intel_xe/mma.cpp @@ -311,3 +311,116 @@ TEST(PVC_CuTe_Xe, MMA_XE_8x16x8_F32TF32TF32F32_TT) { MMA_Test(512, 512, 256); } + +#if (IGC_VERSION_MAJOR > 2) || (IGC_VERSION_MAJOR == 2 && IGC_VERSION_MINOR >= 18) + +TEST(PVC_CuTe_Xe, MMA_DPAS_S8_8x16) { + MMA_Test, 64, 64, 8, 16, 32, int8_t, int8_t, + int32_t>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_S8_4x16) { + MMA_Test, 32, 64, 4, 16, 32, int8_t, int8_t, + int32_t>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_S8_2x16) { + MMA_Test, 16, 64, 2, 16, 32, int8_t, int8_t, + int32_t>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_S8_1x16) { + MMA_Test, 8, 64, 1, 16, 32, int8_t, int8_t, + int32_t>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_U8_8x16) { + MMA_Test, 64, 64, 8, 16, 32, uint8_t, uint8_t, + int32_t>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_U8_4x16) { + MMA_Test, 32, 64, 4, 16, 32, uint8_t, uint8_t, + int32_t>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_U8_2x16) { + MMA_Test, 16, 64, 2, 16, 32, uint8_t, uint8_t, + int32_t>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_U8_1x16) { + MMA_Test, 8, 64, 1, 16, 32, uint8_t, uint8_t, + int32_t>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_BF16_8x16) { + MMA_Test, 256, 256, 32, 64, 32, bfloat16_t, + bfloat16_t, float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_BF16_4x16) { + MMA_Test, 32, 64, 4, 16, 16, bfloat16_t, + bfloat16_t, float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_BF16_2x16) { + MMA_Test, 16, 64, 2, 16, 16, bfloat16_t, + bfloat16_t, float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_BF16_1x16) { + MMA_Test, 8, 64, 1, 16, 16, bfloat16_t, + bfloat16_t, float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_F16_8x16) { + MMA_Test, 64, 64, 8, 16, 16, half_t, half_t, + float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_F16_4x16) { + MMA_Test, 32, 64, 4, 16, 16, half_t, half_t, + float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_F16_2x16) { + MMA_Test, 16, 64, 2, 16, 16, half_t, half_t, + float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_F16_1x16) { + MMA_Test, 8, 64, 1, 16, 16, half_t, half_t, + float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_TF32_8x16) { + MMA_Test, 64, 64, 8, 16, 32, tfloat32_t, + tfloat32_t, float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_TF32_4x16) { + MMA_Test, 64, 64, 8, 16, 16, tfloat32_t, + tfloat32_t, float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_TF32_2x16) { + MMA_Test, 64, 64, 8, 16, 16, tfloat32_t, + tfloat32_t, float>(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_DPAS_TF32_1x16) { + MMA_Test, 64, 64, 8, 16, 16, tfloat32_t, + tfloat32_t, float>(512, 512, 256); +} + +#else + +// For the fallback case +#include "cutlass_unit_test.h" + +TEST(PVC_CuTe_Xe, MMA_DPAS_TESTS) { + GTEST_SKIP() << "MMA DPAS tests require IGC version 2.18 
or higher. skipped"; +} + +#endif diff --git a/test/unit/cute/intel_xe/reorder.cpp b/test/unit/cute/intel_xe/reorder.cpp new file mode 100644 index 0000000000..68bf37c6c1 --- /dev/null +++ b/test/unit/cute/intel_xe/reorder.cpp @@ -0,0 +1,468 @@ +/* Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cutlass/detail/layout.hpp" + +#include +#include +#include +#include +#include + +#include "cutlass_unit_test.h" + +using namespace cute; +using namespace cutlass; +using namespace compat::experimental; + +// ============================================================================ +// Test Helpers +// ============================================================================ + +template class ReorderKernelName; + +// Generic reorder test kernel for SubgroupTensor +template +void reorder_kernel_subgroup_tensor(SrcType* src_global, DstType* dst_global) +{ + const int tid = ThreadIdxX(); + constexpr int total_size = M * N; + constexpr int values_per_thread = total_size / intel::sg_size; + + // Each thread owns a slice of values (round-robin pattern) + SrcType src_local[values_per_thread]; + DstType dst_local[values_per_thread]; + + // Load from global memory (each thread loads its values) + for (int i = 0; i < values_per_thread; ++i) { + src_local[i] = src_global[tid + i * intel::sg_size]; + } + + // Create fragments + auto src_tensor = make_tensor(make_rmem_ptr(src_local), + make_layout(Shape>{})); + auto dst_tensor = make_tensor(make_rmem_ptr(dst_local), + make_layout(Shape>{})); + + // Subgroup TV layout for round-robin ownership + constexpr auto sg_tv_layout = make_layout(Shape, Int>{}, + Stride<_1, Int>{}); + + // Create SubgroupTensors and perform reorder + auto src_sg = make_subgroup_tensor(src_tensor, sg_tv_layout); + auto dst_sg = make_subgroup_tensor(dst_tensor, sg_tv_layout); + reorder(src_sg, dst_sg); + + // Store back to global memory + for (int i = 0; i < values_per_thread; ++i) { + dst_global[tid + i * intel::sg_size] = dst_local[i]; + } +} + +// Helper function to run a reorder test +template +void run_reorder_test(cutlass::host_vector& host_src, + cutlass::host_vector& host_dst) +{ + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_dst(M * N); + + launch, + ReorderKernelName>>( + launch_policy{compat::dim3(1), compat::dim3(intel::sg_size), + kernel_properties{sycl_exp::sub_group_size}}, + device_src.data(), device_dst.data()); + + compat::wait_and_throw(); + host_dst = device_dst; +} + +// Helper function to initialize test data +template +void initialize_test_data(cutlass::host_vector& host_src) { + for (size_t i = 0; i < host_src.size(); ++i) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + host_src[i] = static_cast(i); + } else if constexpr (std::is_same_v) { + host_src[i] = half_t(static_cast(i)); + } else if constexpr (std::is_same_v) { + host_src[i] = bfloat16_t(static_cast(i) * 0.25f); + } else if constexpr (std::is_same_v) { + host_src[i] = uint4_t(static_cast(i % 16)); + } else if constexpr (std::is_same_v) { + host_src[i] = int4_t(static_cast((i % 16) - 8)); + } + } +} + +// Helper function to initialize source data for type conversions +template +void initialize_conversion_source(cutlass::host_vector& host_src) { + for (size_t i = 0; i < host_src.size(); ++i) { + if constexpr (std::is_same_v) { + host_src[i] = static_cast(i) * 0.5f; + } else if constexpr (std::is_same_v) { + host_src[i] = half_t(static_cast(i) * 0.5f); + } else if constexpr (std::is_same_v) { + host_src[i] = static_cast(i); + } else if constexpr (std::is_same_v) { + host_src[i] = static_cast(i % 128); + } + } +} + +// Test kernel for tensor-based reorder (without SubgroupTensor) +template +void 
reorder_kernel_tensor( + SrcType* src_global, DstType* dst_global) +{ + const int tid = ThreadIdxX(); + constexpr int total_size = M * N; + constexpr int values_per_thread = total_size / intel::sg_size; + + // Each thread owns a slice of the data + SrcType src_local[values_per_thread]; + DstType dst_local[values_per_thread]; + + // Load from global memory + for (int i = 0; i < values_per_thread; ++i) { + src_local[i] = src_global[tid + i * intel::sg_size]; + } + + auto src_fragment = make_tensor( + make_rmem_ptr(src_local), + make_layout(Shape>{})); + + auto dst_fragment = make_tensor( + make_rmem_ptr(dst_local), + make_layout(Shape>{})); + + // Subgroup TV layout (same for src and dst) + constexpr auto sg_layout = make_layout( + Shape, Int>{}, + Stride<_1, Int>{}); + + // Perform reorder with explicit TV layouts + reorder(src_fragment, dst_fragment, sg_layout, sg_layout); + + // Store back to global memory + for (int i = 0; i < values_per_thread; ++i) { + dst_global[tid + i * intel::sg_size] = dst_local[i]; + } +} + +// Helper function to run a tensor-based reorder test +template +void run_tensor_reorder_test(cutlass::host_vector& host_src, + cutlass::host_vector& host_dst) +{ + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_dst(M * N); + + launch, + ReorderKernelName>>( + launch_policy{compat::dim3(1), compat::dim3(intel::sg_size), + kernel_properties{sycl_exp::sub_group_size}}, + device_src.data(), device_dst.data()); + + compat::wait_and_throw(); + host_dst = device_dst; +} + +// Generic template-based test helper for reorder operations +template class KernelRunner> +struct ReorderTestBase { + static void run() { + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_dst(M * N); + + // Initialize with sequential values + initialize_test_data(host_src); + + // Run the appropriate kernel via KernelRunner + KernelRunner::template execute(host_src, host_dst); + + // Verify reorder preserves all values + for (size_t i = 0; i < host_src.size(); ++i) { + EXPECT_EQ(host_dst[i], host_src[i]); + } + } +}; + +// Kernel runner for SubgroupTensor-based reorder +template +struct SubgroupTensorRunner { + template + static void execute(cutlass::host_vector& host_src, + cutlass::host_vector& host_dst) { + run_reorder_test(host_src, host_dst); + } +}; + +// Kernel runner for Tensor-based reorder +template +struct TensorRunner { + template + static void execute(cutlass::host_vector& host_src, + cutlass::host_vector& host_dst) { + run_tensor_reorder_test(host_src, host_dst); + } +}; + +// Generic template-based test helper for cross-type reorder operations +template class KernelRunner> +struct ConversionReorderTestBase { + static void run() { + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_dst(M * N); + + // Initialize source data + initialize_conversion_source(host_src); + + // Expected result after conversion + cutlass::host_vector host_expected(M * N); + for (size_t i = 0; i < host_src.size(); ++i) { + host_expected[i] = static_cast(host_src[i]); + } + + // Run the reorder kernel + KernelRunner::template execute(host_src, host_dst); + + // Verify conversion correctness (with tolerance for floating point) + for (size_t i = 0; i < host_expected.size(); ++i) { + if constexpr (std::is_floating_point_v) { + EXPECT_NEAR(static_cast(host_dst[i]), + static_cast(host_expected[i]), 1e-4f); + } else { + EXPECT_EQ(host_dst[i], host_expected[i]); + } + } + } +}; + +// Template-based test helper for cross-type reorders using SubgroupTensor +template 
+struct ConversionSubgroupTest : ConversionReorderTestBase {}; + +// Template-based test helper for cross-type reorders using Tensor +template +struct ConversionTensorTest : ConversionReorderTestBase {}; + +// Template-based test helper for identity reorders using SubgroupTensor +template +struct IdentityReorderTest : ReorderTestBase {}; + +// ============================================================================ +// SubgroupTensor-based Reorder Tests +// ============================================================================ +// These tests verify identity reorder operations using SubgroupTensor, +// ensuring data integrity through round-robin subgroup ownership patterns. + +// Test: Basic Types with SubgroupTensor (8x16 matrices) +TEST(PVC_CuTe_Xe_Reorder, subgroup_basic_float) { + IdentityReorderTest::run(); +} + +TEST(PVC_CuTe_Xe_Reorder, subgroup_basic_int32) { + IdentityReorderTest::run(); +} + +// Test: Half-precision and BFloat16 types +TEST(PVC_CuTe_Xe_Reorder, subgroup_half_precision) { + IdentityReorderTest::run(); +} + +TEST(PVC_CuTe_Xe_Reorder, subgroup_bfloat16) { + IdentityReorderTest::run(); +} + +// Test: Integer types (8-bit and sub-byte) +TEST(PVC_CuTe_Xe_Reorder, subgroup_int8) { + IdentityReorderTest::run(); +} + +// Test: Sub-byte types (uint4_t, int4_t) identity tests +// NOTE: These tests are disabled due to sub-byte packing issues in device memory. +// Sub-byte types (uint4_t, int4_t) require special byte-packing handling that is not +// properly supported in the current reorder kernel implementation. When two 4-bit values +// are packed into a single byte, the kernel's round-robin memory access pattern causes +// data corruption. For example, int4_t value -5 gets corrupted to 65 due to improper +// byte packing/unpacking during the reorder operation. +TEST(PVC_CuTe_Xe_Reorder, DISABLED_subbyte_uint4_identity) { + IdentityReorderTest::run(); +} + +TEST(PVC_CuTe_Xe_Reorder, DISABLED_subbyte_int4_identity) { + IdentityReorderTest::run(); +} + +// Note: Sub-byte types (uint4_t, int4_t) require special handling due to bit-packing +// and are covered in conversion tests with appropriate value ranges. + +// Test: Varied matrix sizes +TEST(PVC_CuTe_Xe_Reorder, subgroup_small_matrix_4x4) { + IdentityReorderTest::run(); +} + +TEST(PVC_CuTe_Xe_Reorder, subgroup_large_matrix_16x32) { + IdentityReorderTest::run(); +} + +TEST(PVC_CuTe_Xe_Reorder, subgroup_minimal_1x16) { + IdentityReorderTest::run(); +} + +TEST(PVC_CuTe_Xe_Reorder, subgroup_power_of_two_32x16) { + IdentityReorderTest::run(); +} + +// Test: Layout identity verification (same layout in and out) +TEST(PVC_CuTe_Xe_Reorder, subgroup_layout_identity_int32) { + IdentityReorderTest::run(); +} + +// Test: VNNI-compatible patterns +TEST(PVC_CuTe_Xe_Reorder, subgroup_vnni_pattern_half) { + IdentityReorderTest::run(); +} + +// ============================================================================ +// Tensor-based Reorder Tests +// ============================================================================ +// These tests verify identity reorder operations using explicit tensor layouts +// with TV (Thread Value) semantics, ensuring correct data handling without +// SubgroupTensor abstraction. 
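+// Illustrative note: with a 16-lane subgroup (intel::sg_size on PVC) and an
+// MxN fragment, the round-robin TV layout used here gives lane t the
+// flattened elements t, t + 16, t + 32, ..., i.e. (thread, value) maps to
+// linear index thread + value * 16. An identity reorder must return every
+// element to the same lane/value slot it started in.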
+ +// Test: Basic Types with Tensor Layouts (8x16 matrices) +TEST(PVC_CuTe_Xe_Reorder_Tensor, tensor_basic_float) { + ReorderTestBase::run(); +} + +TEST(PVC_CuTe_Xe_Reorder_Tensor, tensor_basic_int32) { + ReorderTestBase::run(); +} + +// Test: Half-precision and BFloat16 types +TEST(PVC_CuTe_Xe_Reorder_Tensor, tensor_half_precision) { + ReorderTestBase::run(); +} + +TEST(PVC_CuTe_Xe_Reorder_Tensor, tensor_bfloat16) { + ReorderTestBase::run(); +} + +// Test: Integer types with larger matrices +TEST(PVC_CuTe_Xe_Reorder_Tensor, tensor_int8) { + ReorderTestBase::run(); +} + +// Note: Sub-byte types (uint4_t, int4_t) require special handling due to bit-packing +// and are covered in conversion tests with appropriate value ranges. + +// ============================================================================ +// Cross-Type Conversion Tests (SubgroupTensor-based) +// ============================================================================ +// These tests verify reorder operations with type conversions using SubgroupTensor, +// ensuring correct data conversion from source to destination types. + +// Test: float to half_t conversion +TEST(PVC_CuTe_Xe_Reorder_Conversion, conversion_float_to_half) { + ConversionSubgroupTest::run(); +} + +// Test: int32_t to int8_t conversion +TEST(PVC_CuTe_Xe_Reorder_Conversion, conversion_int32_to_int8) { + ConversionSubgroupTest::run(); +} + +// Test: float to int32_t conversion +TEST(PVC_CuTe_Xe_Reorder_Conversion, conversion_float_to_int32) { + ConversionSubgroupTest::run(); +} + +// Test: half_t to float conversion +TEST(PVC_CuTe_Xe_Reorder_Conversion, conversion_half_to_float) { + ConversionSubgroupTest::run(); +} + +// ============================================================================ +// Cross-Type Conversion Tests (Tensor-based) +// ============================================================================ +// These tests verify reorder operations with type conversions using explicit tensor layouts, +// ensuring correct data conversion without SubgroupTensor abstraction. + +// Test: float to half_t conversion +TEST(PVC_CuTe_Xe_Reorder_Conversion_Tensor, tensor_conversion_float_to_half) { + ConversionTensorTest::run(); +} + +// Test: int32_t to int8_t conversion +TEST(PVC_CuTe_Xe_Reorder_Conversion_Tensor, tensor_conversion_int32_to_int8) { + ConversionTensorTest::run(); +} + +// Test: float to int32_t conversion +TEST(PVC_CuTe_Xe_Reorder_Conversion_Tensor, tensor_conversion_float_to_int32) { + ConversionTensorTest::run(); +} + +// Test: half_t to float conversion +TEST(PVC_CuTe_Xe_Reorder_Conversion_Tensor, tensor_conversion_half_to_float) { + ConversionTensorTest::run(); +} + +// ============================================================================ +// Sub-byte Type Conversion Tests (Expected Failures) +// ============================================================================ +// Note: Sub-byte types (uint4_t, int4_t) trigger SYCL kernel recursion errors +// in the CuTe reorder algorithm due to the recursive nature of the reorder() +// algorithm which is incompatible with SYCL kernel constraints. 
+ +// SubgroupTensor sub-byte conversions (expected failures) + +// TEST(PVC_CuTe_Xe_Reorder_Conversion, DISABLED_conversion_uint8_to_uint4_subgroup) { +// ConversionSubgroupTest::run(); +// } + +// TEST(PVC_CuTe_Xe_Reorder_Conversion, DISABLED_conversion_int8_to_int4_subgroup) { +// ConversionSubgroupTest::run(); +// } + +// Tensor-based sub-byte conversions (expected failures) +// TEST(PVC_CuTe_Xe_Reorder_Conversion_Tensor, DISABLED_conversion_uint8_to_uint4_tensor) { +// ConversionTensorTest::run(); +// } + +// TEST(PVC_CuTe_Xe_Reorder_Conversion_Tensor, DISABLED_conversion_int8_to_int4_tensor) { +// ConversionTensorTest::run(); +// } diff --git a/test/unit/cute/intel_xe/tiled_mma.cpp b/test/unit/cute/intel_xe/tiled_mma.cpp index 1625d0ad35..c8a0d4f16b 100644 --- a/test/unit/cute/intel_xe/tiled_mma.cpp +++ b/test/unit/cute/intel_xe/tiled_mma.cpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -69,3 +70,45 @@ TEST(PVC_CuTe_Xe, tiled_mma_2) { check_tiled_mma, TileShape, SubgroupLayout, ExpectedTiledMMA>(); } + +TEST(PVC_CuTe_Xe, tiled_mma_dpas_3) { + + using TileShape = Shape<_256, _256, _32>; + using SubgroupLayout = Layout, Stride<_4, _1, _0>>; + using ExpectedTiledMMA = TiledMMA< + MMA_Atom>, + Layout, Stride<_4, _1, _0>>, + const Tile, Stride<_1, _32, _8>>, + Layout, Stride<_1, _64, _16>>, + decltype(coalesce(Layout, Stride<_1>>{}))>>; + check_tiled_mma>, TileShape, + SubgroupLayout, ExpectedTiledMMA>(); +} + +TEST(PVC_CuTe_Xe, tiled_mma_dpas_4) { + + using TileShape = Shape<_128, _64, _32>; + using SubgroupLayout = Layout, Stride<_2, _1, _0>>; + using ExpectedTiledMMA = TiledMMA< + MMA_Atom>, + Layout, Stride<_2, _1, _0>>, + const Tile, Stride<_1, _32, _8>>, + Layout, Stride<_1, _32, _16>>, + decltype(coalesce(Layout, Stride<_1>>{}))>>; + check_tiled_mma>, TileShape, + SubgroupLayout, ExpectedTiledMMA>(); +} + +TEST(PVC_CuTe_Xe, tiled_mma_dpas_5) { + + using TileShape = Shape<_128, _64, _32>; + using SubgroupLayout = Layout, Stride<_2, _1, _8>>; + using ExpectedTiledMMA = TiledMMA< + MMA_Atom>, + Layout, Stride<_2, _1, _8>>, + const Tile, Stride<_1, _32, _8>>, + Layout, Stride<_1, _32, _16>>, + decltype(coalesce(Layout, Stride<_1>>{}))>>; + check_tiled_mma>, TileShape, + SubgroupLayout, ExpectedTiledMMA>(); +} diff --git a/test/unit/cute/intel_xe/xe_copy_2d_test.cpp b/test/unit/cute/intel_xe/xe_copy_2d_test.cpp new file mode 100644 index 0000000000..53fa971471 --- /dev/null +++ b/test/unit/cute/intel_xe/xe_copy_2d_test.cpp @@ -0,0 +1,291 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass/detail/layout.hpp" + +#include +#include +#include +#include +#include +#include + +#include "cutlass_unit_test.h" +#include "utils.hpp" + +using namespace cute; +using namespace cutlass; +using namespace compat::experimental; + +#define SUBGROUP_SIZE (16) + +#if (IGC_VERSION_MAJOR > 2) || (IGC_VERSION_MAJOR == 2 && IGC_VERSION_MINOR >= 18) + +// Kernel name for unique identification +template class XECopy2DKernelName; + +// Device kernel for XE_LOAD_2D testing +template +void xe_copy_2d_kernel(SrcTensor src, DstTensor dst) { + using namespace cute; + using Element = typename SrcTensor::value_type; + + // Only execute with the first subgroup to avoid race conditions + if (sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_group(0) == 0) { + // Get thread/subgroup information + auto local_id = int(sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_local_id(0)); + + // Create block 2D copy inside kernel (device-only operation) + using CopyOp = XE_LOAD_2D; + auto tiled_copy = make_block_2d_copy(CopyOp{}, src); + + // Get thread slice of the tiled copy + auto thr_copy = tiled_copy.get_slice(local_id); + + // Create coordinate tensor for a single tile + auto coord_shape = make_shape(Int{}, Int>{}); + Tensor coord_tile = make_identity_tensor(coord_shape); + + // Partition source coordinates and create destination fragment + auto thr_src_coord = thr_copy.partition_S(coord_tile); + auto thr_dst_frag = thr_copy.partition_fragment_D(coord_tile); + + // Perform the copy operation from global memory to registers + copy(tiled_copy, thr_src_coord, thr_dst_frag); + + // For verification, create a 2D store operation to write registers back to destination + using StoreOp = XE_STORE_2D; + auto tiled_store = make_block_2d_copy(StoreOp{}, dst); + auto thr_store = tiled_store.get_slice(local_id); + + // Create destination coordinates for the store operation + auto thr_dst_coord = thr_store.partition_D(coord_tile); + auto thr_src_frag = thr_store.partition_fragment_S(coord_tile); + + // Copy the loaded data from registers to the fragment for storing + copy(thr_dst_frag, thr_src_frag); + + // Perform the store operation from registers to global memory + copy(tiled_store, thr_src_frag, thr_dst_coord); + + // Synchronize to ensure all threads complete their operations + sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_group()); + } +} + +// Host test function template +template 
+void test_xe_copy_2d() { + using namespace cute; + + // Matrix dimensions - must be compatible with block 2D constraints + constexpr int M = Height; + constexpr int N = Width * sizeof_bits_v / Bits; + + // Ensure proper alignment (required for block 2D operations) + constexpr int elem_alignment = 16 / sizeof(Element); + constexpr int aligned_N = ((N + elem_alignment - 1) / elem_alignment) * elem_alignment; + + // Allocate and initialize host data + cutlass::host_vector host_src(M * aligned_N); + cutlass::host_vector host_dst(M * aligned_N); + + + // Initialize source with test pattern + for (size_t i = 0; i < host_src.size(); ++i) { + // Use a safe conversion that works for all numeric types + if constexpr (std::is_floating_point_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + + // For floating-point types, convert through float + float val = static_cast(i % 256) / 255.0f; // Normalize to [0,1] + host_src[i] = Element(val); + } else { + // For integer types (including uint64_t) and char, direct conversion is safe + host_src[i] = static_cast(i % 256); + } + } + + // Copy to device + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_dst(M * aligned_N); + + // Create tensors with proper layout + Tensor tensor_src = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + Tensor tensor_dst = + make_tensor(make_gmem_ptr(device_dst.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + // Launch kernel - copy creation happens on device + auto blockDim = compat::dim3(SUBGROUP_SIZE); + auto gridDim = compat::dim3(1); + + launch, + XECopy2DKernelName>( + launch_policy{ + gridDim, blockDim, + kernel_properties{sycl_exp::sub_group_size} + }, + tensor_src, tensor_dst); + + compat::wait_and_throw(); + host_dst = device_dst; + for (int i = 0; i < M * N; ++i) { + // printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(host_dst[i], host_src[i]); + } +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_uint8) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_int8) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_uint16) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_int16) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_half) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_bfloat16) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_uint32) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_int32) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_float) { + 
test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_tfloat32) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +TEST(PVC_CuTe_Xe, XE_COPY_2D_char) { + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); + test_xe_copy_2d(); +} + +#else + +// For the fallback case +#include "cutlass_unit_test.h" + +TEST(PVC_CuTe_Xe, XE_COPY_2D_SKIPPED) { + GTEST_SKIP() << "XE_COPY_2D tests require IGC version 2.18 or higher. skipped"; +} + +#endif diff --git a/test/unit/cute/intel_xe/xe_copy_prefetch_2d.cpp b/test/unit/cute/intel_xe/xe_copy_prefetch_2d.cpp new file mode 100644 index 0000000000..675b8d1c43 --- /dev/null +++ b/test/unit/cute/intel_xe/xe_copy_prefetch_2d.cpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cutlass/detail/layout.hpp" + +#include +#include +#include +#include +#include +#include + +#include "cutlass_unit_test.h" +#include "utils.hpp" + +using namespace cute; +using namespace cutlass; +using namespace compat::experimental; + +#if (IGC_VERSION_MAJOR > 2) || (IGC_VERSION_MAJOR == 2 && IGC_VERSION_MINOR >= 18) + +// Kernel name for unique identification +template +class XEPrefetch2DKernelName; + +// Device kernel for XE_PREFETCH_2D testing +template +void xe_prefetch_2d_kernel(SrcTensor src) { + using namespace cute; + using namespace sycl::ext::oneapi::this_work_item; + using Element = typename SrcTensor::value_type; + + // Only execute with the first subgroup to avoid race conditions + if (get_nd_item<1>().get_group(0) == 0) { + // Get thread/subgroup information + auto local_id = int(get_nd_item<1>().get_local_id(0)); + + // Create block 2D prefetch inside kernel (device-only operation) + using PrefetchOp = XE_PREFETCH_2D; + auto tiled_prefetch = make_block_2d_copy(PrefetchOp{}, src); + + // Get thread slice of the tiled prefetch + auto thr_prefetch = tiled_prefetch.get_slice(local_id); + + // Create coordinate tensor for a single tile + auto coord_shape = make_shape(Int{}, Int>{}); + Tensor coord_tile = make_identity_tensor(coord_shape); + + // Partition source coordinates for prefetch + auto thr_src_coord = thr_prefetch.partition_S(coord_tile); + + // Create dummy destination fragment (prefetch ignores destination) + auto thr_dst_frag = thr_prefetch.partition_fragment_D(coord_tile); + + // Perform the prefetch operation + copy(tiled_prefetch, thr_src_coord, thr_dst_frag); + + // Synchronize to ensure all threads complete their operations + sycl::group_barrier(get_nd_item<1>().get_group()); + } +} + +// Host test function template for XE_PREFETCH_2D +template +void test_xe_prefetch_2d() { + using namespace cute; + + // Matrix dimensions - must be compatible with block 2D constraints + constexpr int M = Height; + constexpr int N = (Width * sizeof_bits_v) / Bits; + + // Ensure proper alignment (required for block 2D operations) + constexpr int elem_alignment = 16 / sizeof(Element); + constexpr int aligned_N = ((N + elem_alignment - 1) / elem_alignment) * elem_alignment; + + // Allocate and initialize host data + cutlass::host_vector host_src(M * aligned_N); + + // Initialize source with test pattern + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i % 256); + } + + // Copy to device + cutlass::device_vector device_src = host_src; + + // Create tensors with proper layout + Tensor tensor_src = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + // Launch kernel - prefetch happens on device + auto blockDim = compat::dim3(intel::sg_size); + auto gridDim = compat::dim3(1); + + launch, + XEPrefetch2DKernelName>( + launch_policy{ + gridDim, blockDim, + kernel_properties{sycl_exp::sub_group_size} + }, + tensor_src); + + compat::wait_and_throw(); + + // Note: XE_PREFETCH_2D just prefetches to cache, no verification needed + EXPECT_TRUE(true) << "XE_PREFETCH_2D operation completed successfully"; +} + +TEST(CuTe_Xe, XE_PREFETCH_2D_uint8) { + test_xe_prefetch_2d(); + test_xe_prefetch_2d(); +} + +TEST(CuTe_Xe, XE_PREFETCH_2D_int16) { + test_xe_prefetch_2d(); + test_xe_prefetch_2d(); +} + +TEST(CuTe_Xe, XE_PREFETCH_2D_float) { + test_xe_prefetch_2d(); + test_xe_prefetch_2d(); +} + +#else + +// For the 
fallback case +#include "cutlass_unit_test.h" + +TEST(CuTe_Xe, XE_PREFETCH_2D_SKIPPED) { + GTEST_SKIP() << "XE_PREFETCH_2D tests require IGC version 2.18 or higher. skipped"; +} + +#endif diff --git a/test/unit/cute/intel_xe/xe_transpose_2d.cpp b/test/unit/cute/intel_xe/xe_transpose_2d.cpp new file mode 100644 index 0000000000..d2375d2fc8 --- /dev/null +++ b/test/unit/cute/intel_xe/xe_transpose_2d.cpp @@ -0,0 +1,100 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include +#include +#include +#include +#include +#include "cutlass_unit_test.h" + +using namespace cute; + +#if (IGC_VERSION_MAJOR > 2) || (IGC_VERSION_MAJOR == 2 && IGC_VERSION_MINOR >= 18) + +TEST(CuTe_Xe, XE_LOAD_2D_TRANSPOSE_API_Declaration) { + // Template: XE_LOAD_2D_TRANSPOSE + // Constraints: Bits == 32 || Bits == 64, Width <= 8 + // For 64-bit: Height == 8 && Width < 4 + + // Test 32-bit transpose operations + using TransposeOp_32bit_2x4 = XE_LOAD_2D_TRANSPOSE<32, 2, 4>; + using TransposeOp_32bit_4x8 = XE_LOAD_2D_TRANSPOSE<32, 4, 8>; + using TransposeOp_32bit_8x2 = XE_LOAD_2D_TRANSPOSE<32, 8, 2>; + + // Test 64-bit transpose operations (limited constraints) + using TransposeOp_64bit_8x2 = XE_LOAD_2D_TRANSPOSE<64, 8, 2>; + using TransposeOp_64bit_8x3 = XE_LOAD_2D_TRANSPOSE<64, 8, 3>; + + // Test that the operations have the required static members from XE_Copy_Op_2D_Base + static_assert(TransposeOp_32bit_2x4::AtomHeight == 2); + static_assert(TransposeOp_32bit_2x4::AtomWidth == 4); + static_assert(TransposeOp_32bit_2x4::CopyBits == 32); + + static_assert(TransposeOp_32bit_4x8::AtomHeight == 4); + static_assert(TransposeOp_32bit_4x8::AtomWidth == 8); + static_assert(TransposeOp_32bit_4x8::CopyBits == 32); + + static_assert(TransposeOp_64bit_8x2::AtomHeight == 8); + static_assert(TransposeOp_64bit_8x2::AtomWidth == 2); + static_assert(TransposeOp_64bit_8x2::CopyBits == 64); + + EXPECT_TRUE(true) << "XE_LOAD_2D_TRANSPOSE API types declared successfully"; +} + +TEST(CuTe_Xe, XE_LOAD_2D_TRANSPOSE_Constraints) { + // Test that the compile-time constraints are enforced + + // Valid 32-bit operations + using Valid32_1 = XE_LOAD_2D_TRANSPOSE<32, 1, 1>; + using Valid32_2 = XE_LOAD_2D_TRANSPOSE<32, 16, 8>; // Width <= 8 + + // Valid 64-bit operations (Height == 8 && Width < 4) + using Valid64_1 = XE_LOAD_2D_TRANSPOSE<64, 8, 1>; + using Valid64_2 = XE_LOAD_2D_TRANSPOSE<64, 8, 2>; + using Valid64_3 = XE_LOAD_2D_TRANSPOSE<64, 8, 3>; + + static_assert(Valid32_1::CopyBits == 32); + static_assert(Valid32_2::CopyBits == 32); + static_assert(Valid64_1::CopyBits == 64); + static_assert(Valid64_2::CopyBits == 64); + static_assert(Valid64_3::CopyBits == 64); + + EXPECT_TRUE(true) << "XE_LOAD_2D_TRANSPOSE constraint validation successful"; +} + +#else + +TEST(CuTe_Xe, XE_LOAD_2D_TRANSPOSE_SKIPPED) { + GTEST_SKIP() << "XE_LOAD_2D_TRANSPOSE tests require IGC version 2.18 or higher. skipped"; +} + +#endif diff --git a/test/unit/cute/intel_xe/xe_vnni_2d.cpp b/test/unit/cute/intel_xe/xe_vnni_2d.cpp new file mode 100644 index 0000000000..2112e474b0 --- /dev/null +++ b/test/unit/cute/intel_xe/xe_vnni_2d.cpp @@ -0,0 +1,69 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF POSSIBILITY OF SUCH DAMAGE. + **************************************************************************************************/ + +#include +#include +#include +#include +#include +#include "cutlass_unit_test.h" + +using namespace cute; + +#if (IGC_VERSION_MAJOR > 2) || (IGC_VERSION_MAJOR == 2 && IGC_VERSION_MINOR >= 18) + +TEST(CuTe_Xe, XE_LOAD_2D_VNNI_API_Declaration) { + // Template: XE_LOAD_2D_VNNI + + // Test that the VNNI operation types can be declared + using VNNIOp_8bit_2x32 = XE_LOAD_2D_VNNI<8, 2, 32>; + using VNNIOp_8bit_4x32 = XE_LOAD_2D_VNNI<8, 4, 32>; + using VNNIOp_16bit_2x16 = XE_LOAD_2D_VNNI<16, 2, 16>; + using VNNIOp_16bit_4x16 = XE_LOAD_2D_VNNI<16, 4, 16>; + + // Test that the operations have the required static members from XE_Copy_Op_2D_Base + static_assert(VNNIOp_8bit_2x32::AtomHeight == 2); + static_assert(VNNIOp_8bit_2x32::AtomWidth == 32); + static_assert(VNNIOp_8bit_2x32::CopyBits == 8); + + static_assert(VNNIOp_16bit_2x16::AtomHeight == 2); + static_assert(VNNIOp_16bit_2x16::AtomWidth == 16); + static_assert(VNNIOp_16bit_2x16::CopyBits == 16); + + EXPECT_TRUE(true) << "XE_LOAD_2D_VNNI API types declared successfully"; +} + +#else + +TEST(CuTe_Xe, XE_LOAD_2D_VNNI_SKIPPED) { + GTEST_SKIP() << "XE_LOAD_2D_VNNI tests require IGC version 2.18 or higher. 
skipped"; +} + +#endif diff --git a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp index ece31f6f7a..442d17ad37 100644 --- a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp +++ b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp @@ -283,6 +283,10 @@ struct TestbedImpl { block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); + // Zero-initialize output buffer for the kernel result + // block_ref_O is fully written in verify() before being read, so no initialization needed + compat::memset(block_O.get(), 0, block_O.size() * sizeof(ElementOutput)); + initialize_block(block_Q, seed + 2023); initialize_block(block_K, seed + 2022); initialize_block(block_V, seed + 2021); diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp index b758d1b8fd..f879af895e 100644 --- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp +++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp @@ -220,6 +220,12 @@ struct TestbedImpl { block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo); block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); + + // Zero-initialize output buffer for the kernel result + // block_ref_O is fully written in verify() before being read, so no initialization needed + if (block_O.size() > 0) { + compat::memset(block_O.get(), 0, block_O.size() * sizeof(ElementOutput)); + } if constexpr (UsePagedKV) { std::vector num_pages_per_seq{0}; diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index ad6f93b3e8..1012c5cc3e 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -34,6 +34,7 @@ if(CUTLASS_ENABLE_SYCL) xe_gemm_bf16_bf16_bf16_tensor_op_bf16.cpp xe_gemm_fp16_fp16_fp16_tensor_op_fp16.cpp xe_gemm_bf16_bf16_bf16_tensor_op_fp32.cpp + xe_gemm_bf16_bf16_fp32_tensor_op_bf16.cpp xe_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp xe_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp @@ -42,12 +43,14 @@ if(CUTLASS_ENABLE_SYCL) xe_gemm_f8_f8_fp32_tensor_op_fp32.cpp xe_gemm_fp16_s8_fp32_tensor_op_fp32.cpp gemm_universal_bf16n_bf16t_f32n_tensor_op_f32_xe.cpp + gemm_universal_fp8_fp8_fp32_tensor_op_f32_xe_models.cpp ) cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_cooperative_xe xe_gemm_bf16_bf16_fp32_tensor_op_fp32_cooperative.cpp xe_gemm_fp16_fp16_fp32_tensor_op_fp32_cooperative.cpp + # xe_gemm_fp16_fp16_f32_ptr_array_cooperative.cpp # TODO (Codeplay): fix gemm cooperative tests for s8 and tf32 # xe_gemm_s8_s8_s32_tensor_op_s32_cooperative.cpp # xe_gemm_tf32_tf32_fp32_tensor_op_fp32_cooperative.cpp diff --git a/test/unit/gemm/device/default_gemm_configuration.hpp b/test/unit/gemm/device/default_gemm_configuration.hpp index a851e6110f..9b786027fc 100644 --- a/test/unit/gemm/device/default_gemm_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_configuration.hpp @@ -62,6 +62,18 @@ struct DefaultGemmConfigurationToCutlass3Types { static_assert(sizeof(ElementA) == 0, "No valid 
DefaultGemmConfigurationToCutlass3Types configuration exists."); }; +// This type is only intended to demonstrate porting 2.x kernels to 3.0 +template< + class OperatorClass, class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator, + class ElementOutput> +struct XeDefaultGemmConfigurationToCutlass3Types { + static_assert(sizeof(ElementA) == 0, "No valid XeDefaultGemmConfigurationToCutlass3Types configuration exists."); +}; + /////////////////////////////////////////////////////////////////////////////// namespace detail { @@ -1486,6 +1498,141 @@ struct DefaultGemmConfigurationToCutlass3Types< >::CollectiveOp; }; +/////////////////////////////////////////////////////////////////////////////// + +// Intel XE MMA F32BF16 +// ElementC - > void +// ElementCompute and ElementOutput different in LinearCombination +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::IntelXe, + bfloat16_t, LayoutA, + bfloat16_t, LayoutB, + void, LayoutC, + ElementOutput> +{ + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, + Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + // A + static constexpr int kAlignmentA = 32; + using DefaultOperandA = detail::DefaultGemm_TensorOpXe_OperandA< + bfloat16_t, LayoutA, kAlignmentA, 32>; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 32; + using DefaultOperandB = detail::DefaultGemm_TensorOpXe_OperandB< + bfloat16_t, LayoutB, kAlignmentB, 32>; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, + cute::bfloat16_t, LayoutA, 1, + cute::bfloat16_t, LayoutB, 1, + float, + TileShape, Shape<_1, _1, _1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + //using EpilogueOp = epilogue::fusion::LinearCombination; + using EpilogueOp = epilogue::fusion::LinearCombination; + + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< + epilogue::IntelXeXMX16, + EpilogueOp, + TileShape, + decltype(tile_shape(TiledMma())) + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1, _1, _1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 1, + cute::bfloat16_t, LayoutC, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto, + EpilogueOp + >::CollectiveOp; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Intel XE MMA F32BF16 +// D=Ax B + C; => BF16=BF16xBF16+BF16 <=>BF16=FP32+BF16 +// ElementAccumulator and ElementC are different types. 
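+// Editor's note (illustrative, not part of the original change): with ElementAccumulator = float and
+// ElementC = ElementOutput = bfloat16_t, the mainloop accumulates in f32 and the LinearCombination
+// epilogue computes D_bf16 = bf16( alpha * acc_f32 + beta * f32(C_bf16) ), i.e. the bf16 C operand is
+// up-converted, combined in f32, and only the final D is rounded back to bf16.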
+template < + typename LayoutA, + typename LayoutB, + typename LayoutC, + typename ElementAccumulator, + typename ElementOutput> +struct XeDefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::IntelXe, + bfloat16_t, LayoutA, + bfloat16_t, LayoutB, + bfloat16_t, LayoutC, + ElementAccumulator, + ElementOutput> +{ + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, + Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + // A + static constexpr int kAlignmentA = 32; + using DefaultOperandA = detail::DefaultGemm_TensorOpXe_OperandA< + bfloat16_t, LayoutA, kAlignmentA, 32>; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 32; + using DefaultOperandB = detail::DefaultGemm_TensorOpXe_OperandB< + bfloat16_t, LayoutB, kAlignmentB, 32>; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, + cute::bfloat16_t, LayoutA, 1, + cute::bfloat16_t, LayoutB, 1, + ElementAccumulator, + TileShape, Shape<_1, _1, _1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< + epilogue::IntelXeXMX16, + EpilogueOp, + TileShape, + decltype(tile_shape(TiledMma())) + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1, _1, _1>, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, float, + bfloat16_t, LayoutC, 1, + ElementOutput, LayoutC, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto, + EpilogueOp + >::CollectiveOp; +}; + + /////////////////////////////////////////////////////////////////////////////// namespace detail { diff --git a/test/unit/gemm/device/default_gemm_group_configuration.hpp b/test/unit/gemm/device/default_gemm_group_configuration.hpp index 9dd5ab06f0..4238566ef0 100644 --- a/test/unit/gemm/device/default_gemm_group_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_group_configuration.hpp @@ -90,7 +90,7 @@ struct DefaultGemmGroupConfiguration< using EpilogueOp = epilogue::fusion::LinearCombination; using FusionCallBacks = epilogue::fusion::FusionCallbacks< - epilogue::IntelXeXMX16Group, + epilogue::IntelXeGenericGroup, EpilogueOp, TileShape, decltype(tile_shape(TiledMma())) @@ -103,7 +103,7 @@ struct DefaultGemmGroupConfiguration< float, float, float, LayoutC, 1, ElementOutput, LayoutC, 1, - epilogue::IntelXeXMX16Group, + epilogue::IntelXeGenericGroup, EpilogueOp >::CollectiveOp; }; diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index fb2e3dc6bd..b9861f56a8 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -4187,6 +4187,48 @@ bool TestXe( } // m return passed; } + +template class ActivationFunctor = + cutlass::epilogue::thread::Identity> +bool TestXe( + int m, int n, int k, int l, + double alpha = 1.0, + double beta = cute::is_same_v ? 
0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed( + check_relative_equality, ScalarLoc::ON_HOST, VectorScale::DISABLED); + + bool passed = true; + ProblemShapeType problem_size{m, n, k, l}; + try { + passed = testbed.run(problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestXe: testbed.run threw an exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestXe: testbed.run threw an unknown exception for MNKL = " + << m << " " << n << " " << k << " " << l; + throw; + } + + EXPECT_TRUE(passed) << "TestXe: testbed.run failed for MNKL = " + << m << " " << n << " " << k << " " << l + << ", alpha: " << alpha << ", beta: " << beta; + + return passed; +} + #endif template diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index 9341c58c02..b0107e52e6 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -2301,13 +2301,16 @@ bool TestSmall(double alpha = 1.0, double beta = 1.0, using ElementA = typename Gemm::GemmKernel::ElementA; using ElementB = typename Gemm::GemmKernel::ElementB; using TiledMma = typename Gemm::GemmKernel::TiledMma; - int alignment_bits = 128; static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); - alignment_bits = cutlass::detail::get_input_alignment_bits(); - // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. - int alignment_input = (alignment_bits / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits / cute::sizeof_bits::value); - + // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. + int alignment_bits_a = cutlass::detail::get_input_alignment_bits(); + int alignment_input_a = (alignment_bits_a / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_a / cute::sizeof_bits::value); + + int alignment_bits_b = cutlass::detail::get_input_alignment_bits(); + int alignment_input_b = (alignment_bits_b / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_b / cute::sizeof_bits::value); + + int alignment_input = (alignment_input_a == 0 || alignment_input_b == 0) ? 
0 : std::max(alignment_input_a, alignment_input_b); if constexpr (apply_alignment_offset) { // If BlockScaled, then min alignment is SFVecSize @@ -2424,6 +2427,50 @@ bool TestSmallFusion(double alpha = 1.0, double beta = 0.0, alpha, beta, check_relative_equality, use_device_scalars, vector_scale_mode); } +///////////////////////////////////////////////////////////////////////////////////////////////// + +// TestAll template function overload for grouped GEMM testing with explicit problem sizes +template class ActivationFunctor = cutlass::epilogue::thread::Identity> +bool TestAll(const std::vector& problem_sizes, + double alpha = 1.0, double beta = 0.0) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + if (problem_sizes.empty()) { + std::cerr << "Error: problem_sizes vector cannot be empty.\n"; + return false; + } + + Testbed3x testbed( + CheckEquality::RELATIVE, + ScalarLoc::ON_DEVICE, + VectorScale::DISABLED + ); + + // Convert vector of GemmCoord to the format needed by grouped GEMM testbed + std::vector problem_sizes_host; + for (const auto& coord : problem_sizes) { + problem_sizes_host.push_back({coord.m(), coord.n(), coord.k()}); + } + + cutlass::DeviceAllocation problem_sizes_device; + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + + bool passed = testbed.run( + ProblemShapeType{ + static_cast(problem_sizes_host.size()), + problem_sizes_device.get(), + problem_sizes_host.data() + }, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + } // namespace device } // namespace gemm } // namespace test diff --git a/test/unit/gemm/device/gemm_universal_fp8_fp8_fp32_tensor_op_f32_xe_models.cpp b/test/unit/gemm/device/gemm_universal_fp8_fp8_fp32_tensor_op_f32_xe_models.cpp new file mode 100644 index 0000000000..4953be3fa8 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_fp8_fp8_fp32_tensor_op_f32_xe_models.cpp @@ -0,0 +1,179 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ +#include +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "default_gemm_configuration.hpp" +#include "gemm_testbed_3x.hpp" + +using namespace cutlass; + +namespace { + +template +struct MainloopIntelW8A8_GemmConfig { + using ElementA = float_e5m2_t; + using ElementB = float_e5m2_t; + using TileShape = Shape<_256, _256, _32>; + constexpr static int PipelineStages = 2; + using Schedule = gemm::KernelXe; + using TiledMma = typename TiledMMAHelper< + MMA_Atom, + Layout, + Layout, Stride<_4, _1, _0>> + >::TiledMMA; + using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U8x32x32_LD_V; + + using DispatchPolicy = gemm::MainloopIntelW8A8; + + using CollectiveMainloop = gemm::collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, cutlass::gemm::TagToStrideA_t, + ElementB, cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< + float, float + >; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< + cutlass::epilogue::IntelXeXMX16, + EpilogueOp, + TileShape, + decltype(tile_shape(TiledMma())) + >; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + cutlass::epilogue::IntelXeXMX16, + TileShape, + float, cutlass::gemm::TagToStrideC_t, + float, cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, void, void, + XE_2D_U32x8x16_ST_N, void, void + >; + + using GemmKernel = gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = gemm::device::GemmUniversalAdapter; +}; + +TEST(MainloopIntelW8A8_Special, LargeModel_LLaMA2_7B) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(4096, 4096, 11008, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, LargeModel_Mistral_7B) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(4096, 4096, 14336, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, TensorParallel) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(4096, 1024, 4096, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, ModelParallel) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(1024, 4096, 4096, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, MicroBatch) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(128, 128, 8192, 4, 1.0, 
0.0)); +} + +TEST(MainloopIntelW8A8_Special, LargeBatch) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(512, 512, 2048, 32, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, SquareSmall) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(64, 64, 64, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, SquareMedium) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(512, 512, 512, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, SquareLarge) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(2048, 2048, 2048, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, TallMatrix) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(4096, 512, 4096, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, WideMatrix) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(512, 4096, 4096, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, Batch8) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(512, 512, 512, 8, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, Batch16Large) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(1024, 1024, 1024, 16, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, LargeK) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(64, 64, 8192, 1, 1.0, 0.0)); +} + +TEST(MainloopIntelW8A8_Special, LargeN) { + using Gemm = typename MainloopIntelW8A8_GemmConfig::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe(64, 8192, 64, 1, 1.0, 0.0)); +} + +} // namespace \ No newline at end of file diff --git a/test/unit/gemm/device/xe_gemm_bf16_bf16_bf16_tensor_op_fp32.cpp b/test/unit/gemm/device/xe_gemm_bf16_bf16_bf16_tensor_op_fp32.cpp index f78dbf7734..aa6cb2d8ad 100644 --- a/test/unit/gemm/device/xe_gemm_bf16_bf16_bf16_tensor_op_fp32.cpp +++ b/test/unit/gemm/device/xe_gemm_bf16_bf16_bf16_tensor_op_fp32.cpp @@ -85,5 +85,51 @@ TEST(XE_Device_Gemm_bf16n_bf16n_bf16t_tensor_op_f32, 256x256x32) { EXPECT_TRUE(test::gemm::device::TestXe()); } + +// ElementC ---> void +// ElementOutput != ElementCompute in LinearCombination + +template +struct XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void { + using Config = + gemm::device::DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::IntelXe, + cute::bfloat16_t, LayoutA, + cute::bfloat16_t, LayoutB, + void, layout::RowMajor, + cute::bfloat16_t>; + + using Gemm = gemm::device::GemmUniversalAdapter< + gemm::kernel::GemmUniversal< + cute::Shape, + typename Config::CollectiveMainloop, + typename Config::CollectiveEpilogue>>; +}; + +TEST(XE_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32_void, 256x256x32) { + using Gemm = XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void< + layout::RowMajor, layout::RowMajor>::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe()); +} + +TEST(XE_Device_Gemm_bf16n_bf16t_bf16t_tensor_op_f32_void, 256x256x32) { + using Gemm = XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void< + layout::ColumnMajor, layout::RowMajor>::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe()); +} + +TEST(XE_Device_Gemm_bf16t_bf16n_bf16t_tensor_op_f32_void, 256x256x32) { + using Gemm = XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void< + layout::RowMajor, layout::ColumnMajor>::Gemm; + 
EXPECT_TRUE(test::gemm::device::TestXe()); +} + +TEST(XE_Device_Gemm_bf16n_bf16n_bf16t_tensor_op_f32_void, 256x256x32) { + using Gemm = XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void< + layout::ColumnMajor, layout::ColumnMajor>::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe()); +} + + } } // namespace cutlass diff --git a/test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_bf16.cpp b/test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_bf16.cpp new file mode 100644 index 0000000000..a50695d289 --- /dev/null +++ b/test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_bf16.cpp @@ -0,0 +1,90 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ***************************************************************************************************/ + +/*! 
\file + \brief Tests for Xe bf16_bf16_fp32 and C is bf16 +*/ + + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "default_gemm_configuration.hpp" + +#include "gemm_testbed_3x.hpp" + +namespace cutlass { +namespace { +template +struct XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16 { + using Config = + gemm::device::XeDefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::IntelXe, + cute::bfloat16_t, LayoutA, + cute::bfloat16_t, LayoutB, + cute::bfloat16_t, layout::RowMajor, + float, + cute::bfloat16_t>; + + using Gemm = gemm::device::GemmUniversalAdapter< + gemm::kernel::GemmUniversal< + cute::Shape, + typename Config::CollectiveMainloop, + typename Config::CollectiveEpilogue>>; +}; + +TEST(XE_Device_Gemm_bf16t_bf16t_f32t_tensor_op_bf16, 256x256x32) { + using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16< + layout::RowMajor, layout::RowMajor>::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe()); +} + +TEST(XE_Device_Gemm_bf16n_bf16t_f32t_tensor_op_bf16, 256x256x32) { + using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16< + layout::ColumnMajor, layout::RowMajor>::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe()); +} + +TEST(XE_Device_Gemm_bf16t_bf16n_f32t_tensor_op_bf16, 256x256x32) { + using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16< + layout::RowMajor, layout::ColumnMajor>::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe()); +} + +TEST(XE_Device_Gemm_bf16n_bf16n_f32t_tensor_op_bf16, 256x256x32) { + using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16< + layout::ColumnMajor, layout::ColumnMajor>::Gemm; + EXPECT_TRUE(test::gemm::device::TestXe()); +} + +} +} // namespace cutlass diff --git a/test/unit/gemm/device/xe_gemm_fp16_fp16_f32_ptr_array_cooperative.cpp b/test/unit/gemm/device/xe_gemm_fp16_fp16_f32_ptr_array_cooperative.cpp new file mode 100644 index 0000000000..685c1fe421 --- /dev/null +++ b/test/unit/gemm/device/xe_gemm_fp16_fp16_f32_ptr_array_cooperative.cpp @@ -0,0 +1,199 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/layout.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "default_gemm_configuration.hpp" +#include "gemm_testbed_3x_ptr_array.hpp" +using namespace cute; +namespace cutlass { +namespace { + +// Default GEMM group configuration for Intel XE architecture +template< + class OperatorClass, class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmGroupConfiguration { + static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmGroupConfiguration configuration exists."); +}; +// Intel XE MMA f16s8f32 +template +struct DefaultGemmGroupConfiguration< + arch::OpClassTensorOp, arch::IntelXe, + ElementA, LayoutA, + ElementB, LayoutB, + float, LayoutC, + ElementOutput> +{ + static_assert(cute::is_any_of_v, "ElementA needs to be of 16 or 8 bit type"); + static_assert(cute::is_any_of_v, "ElementB needs to be of 16, 8 or 4 bit type"); + using TileShape = cute::Shape, cute::C<256>, cute::C<32>>; + + using CollectiveMainloop = typename gemm::collective::CollectiveBuilder< + arch::IntelXe, arch::OpClassTensorOp, + ElementA, LayoutA, 1, + ElementB, LayoutB, 1, + float, + TileShape, cute::Shape, cute::C<1>, cute::C<1>>, + gemm::collective::StageCountAuto, + gemm::KernelXePtrArrayCooperative + >::CollectiveOp; + + using TiledMma = typename CollectiveMainloop::TiledMma; + + using EpilogueOp = epilogue::fusion::LinearCombination; + + using CollectiveEpilogue = typename epilogue::collective::CollectiveBuilder< + arch::IntelXe, arch::OpClassTensorOp, + TileShape, cute::Shape, cute::C<1>, cute::C<1>>, + epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 1, + ElementOutput, LayoutC, 1, + epilogue::IntelXeXMX16Group, + EpilogueOp + >::CollectiveOp; +}; + +} // namespace (unnamed) +} // namespace cutlass + +namespace cutlass { +namespace { + +template +struct XE_Device_Gemm_fp16_fp16_f32_group { + using ProblemShape = gemm::GroupProblemShape>; + using ElementA = cute::half_t; + using ElementB = cute::half_t; + using ElementC = float; + using ElementAccumulator = float; + using LayoutC = layout::RowMajor; + + using Config = DefaultGemmGroupConfiguration< + arch::OpClassTensorOp, 
arch::IntelXe, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator>; + + using Gemm = gemm::device::GemmUniversalAdapter< + gemm::kernel::GemmUniversal< + ProblemShape, + typename Config::CollectiveMainloop, + typename Config::CollectiveEpilogue, + gemm::GroupScheduler + >>; +}; + + +// Test: Small uniform problem sizes with same dimensions (RowMajor A, RowMajor B) +TEST(XE_Device_Gemm_fp16_PtrArray_Cooperative, small_uniform_rowmajor) { + using Gemm = XE_Device_Gemm_fp16_fp16_f32_group::Gemm; + std::vector problems = {{256, 256, 256}, {256, 256, 256}}; + EXPECT_TRUE(test::gemm::device::TestAll(problems, 1.0, 1.0)); +} + +// Test: Varied problem sizes with different M, N, K dimensions +TEST(XE_Device_Gemm_fp16_PtrArray_Cooperative, varied_sizes_rowmajor) { + using Gemm = XE_Device_Gemm_fp16_fp16_f32_group::Gemm; + std::vector problems = {{128, 256, 512}, {256, 512, 256}, {512, 256, 128}}; + EXPECT_TRUE(test::gemm::device::TestAll(problems, 1.0, 0.0)); +} + +// Test: Large group with mixed layout (ColumnMajor A, RowMajor B) +TEST(XE_Device_Gemm_fp16_PtrArray_Cooperative, large_group_colrow) { + using Gemm = XE_Device_Gemm_fp16_fp16_f32_group::Gemm; + std::vector problems = {{128, 128, 64}, {256, 256, 128}, {512, 512, 256}, {256, 128, 64}}; + EXPECT_TRUE(test::gemm::device::TestAll(problems, 1.0, 1.0)); +} + +// Test: Both inputs in column-major format +TEST(XE_Device_Gemm_fp16_PtrArray_Cooperative, both_colmajor) { + using Gemm = XE_Device_Gemm_fp16_fp16_f32_group::Gemm; + std::vector problems = {{256, 256, 256}, {512, 256, 128}}; + EXPECT_TRUE(test::gemm::device::TestAll(problems, 1.0, 1.0)); +} + +// Test: Multiple identical problems with scaling variations +TEST(XE_Device_Gemm_fp16_PtrArray_Cooperative, scaling_variations_rowmajor) { + using Gemm = XE_Device_Gemm_fp16_fp16_f32_group::Gemm; + std::vector problems = {{256, 256, 256}, {256, 256, 256}, {256, 256, 256}}; + EXPECT_TRUE(test::gemm::device::TestAll(problems, 2.0, 1.0)); +} + +// Test: Five problem batch with various dimensions and mixed layouts +TEST(XE_Device_Gemm_fp16_PtrArray_Cooperative, five_problem_batch_colrow) { + using Gemm = XE_Device_Gemm_fp16_fp16_f32_group::Gemm; + std::vector problems = {{128, 128, 128}, {256, 256, 128}, {256, 128, 256}, {128, 256, 256}, {256, 256, 256}}; + EXPECT_TRUE(test::gemm::device::TestAll(problems, 1.0, 1.0)); +} + +// Test: Single problem (edge case - minimal batch size) +TEST(XE_Device_Gemm_fp16_PtrArray_Cooperative, single_problem_rowmajor) { + using Gemm = XE_Device_Gemm_fp16_fp16_f32_group::Gemm; + std::vector problems = {{256, 256, 256}}; + EXPECT_TRUE(test::gemm::device::TestAll(problems, 1.0, 1.0)); +} + +// Test: Very small matrices (below typical tile size) +TEST(XE_Device_Gemm_fp16_PtrArray_Cooperative, tiny_matrices_rowmajor) { + using Gemm = XE_Device_Gemm_fp16_fp16_f32_group::Gemm; + std::vector problems = {{32, 32, 32}, {64, 64, 32}}; + EXPECT_TRUE(test::gemm::device::TestAll(problems, 1.0, 0.0)); +} + +// Test: Non-square matrices with various aspect ratios +TEST(XE_Device_Gemm_fp16_PtrArray_Cooperative, nonsquare_matrices_rowmajor) { + using Gemm = XE_Device_Gemm_fp16_fp16_f32_group::Gemm; + std::vector problems = {{128, 256, 512}, {512, 128, 256}, {256, 512, 128}}; + EXPECT_TRUE(test::gemm::device::TestAll(problems, 1.0, 1.0)); +} + +} // namespace +} // namespace cutlass diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index 98e97bc5da..ef6f4c784c 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt 
@@ -83,6 +83,11 @@ target_link_libraries( ################################################################################ +function(cutlass_target_sources target) + # Wrapper function for target_sources to maintain compatibility with generated manifests + target_sources(${target} ${ARGN}) +endfunction() + function(cutlass_add_cutlass_library) # # Generates static and shared libraries with the given SOURCES. The public CMake @@ -120,6 +125,23 @@ function(cutlass_add_cutlass_library) PRIVATE cutlass_library_internal_interface ) + # If oneMKL is provided via the downloaded ExternalProject (onemkl_project), + # ensure the generated object target depends on that project. Otherwise + # compilation of object files can run before the onemkl headers are available + # and cause "oneapi/mkl/rng/device.hpp file not found" errors. + if (CUTLASS_ENABLE_SYCL) + if (NOT CUTLASS_USING_SYSTEM_ONEMKL) + if (TARGET onemkl_project) + add_dependencies(${__NAME}_objs onemkl_project) + endif() + endif() + endif() + + # Add SYCL-specific compile options when building for SYCL + if (CUTLASS_ENABLE_SYCL) + target_compile_options(${__NAME}_objs PRIVATE -fsycl) + endif() + if (CUTLASS_BUILD_MONO_LIBRARY AND __SUFFIX) # If we're only building a single monolithic library then we @@ -150,9 +172,29 @@ function(cutlass_add_cutlass_library) ${__NAME} PUBLIC cutlass_library_includes PRIVATE $ - cuda_driver ) + # Link with appropriate runtime library + if (CUTLASS_ENABLE_SYCL) + # For SYCL builds, explicitly link with libsycl.so + # We use find_library to locate it in the oneAPI installation + find_library(SYCL_LIBRARY NAMES sycl sycl8 PATHS ENV LD_LIBRARY_PATH NO_DEFAULT_PATH) + if(NOT SYCL_LIBRARY) + find_library(SYCL_LIBRARY NAMES sycl sycl8) + endif() + if(SYCL_LIBRARY) + target_link_libraries(${__NAME} PRIVATE ${SYCL_LIBRARY}) + else() + message(WARNING "libsycl.so not found - runtime may fail to load") + endif() + + # Add oneMKL for SYCL builds (needed for sycl_tensor_fill.h runtime) + add_onemkl_to_target(TARGET ${__NAME}) + else() + # For CUDA builds, link with cuda_driver + target_link_libraries(${__NAME} PRIVATE cuda_driver) + endif() + set_target_properties(${__NAME} PROPERTIES DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}") cutlass_add_library( @@ -181,9 +223,20 @@ function(cutlass_add_cutlass_library) ${__NAME}_static PUBLIC cutlass_library_includes PRIVATE $ - cuda_driver ) + # Link with appropriate runtime library + if (CUTLASS_ENABLE_SYCL) + # For SYCL builds, explicitly link with libsycl.so + # Note: SYCL_LIBRARY should already be found from shared library linking above + if(SYCL_LIBRARY) + target_link_libraries(${__NAME}_static PRIVATE ${SYCL_LIBRARY}) + endif() + else() + # For CUDA builds, link with cuda_driver + target_link_libraries(${__NAME}_static PRIVATE cuda_driver) + endif() + set_target_properties(${__NAME}_static PROPERTIES DEBUG_POSTFIX "${CUTLASS_LIBRARY_DEBUG_POSTFIX}") install( @@ -272,6 +325,24 @@ if (NOT CUTLASS_ENABLE_SYCL) # For backward compatibility with the old name add_library(cutlass_lib ALIAS cutlass_library) add_library(cutlass_lib_static ALIAS cutlass_library_static) + +else() + # SYCL-enabled library generation + # Create base library targets for SYCL that will be populated by generated kernels + # Note: .cu files will be compiled with SYCL compiler (icpx) for Intel Xe GPUs + + cutlass_add_cutlass_library( + src/handle.cu + src/manifest.cpp + src/operation_table.cu + src/singleton.cu + src/util.cu + ) + + # For backward compatibility with the old name + add_library(cutlass_lib 
ALIAS cutlass_library) + add_library(cutlass_lib_static ALIAS cutlass_library_static) + endif() ################################################################################ @@ -307,6 +378,13 @@ if(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE) endif() endif() +# Set architecture parameter based on whether SYCL or CUDA is enabled +if (CUTLASS_ENABLE_SYCL) + set(CUTLASS_LIBRARY_GENERATOR_ARCHS "20" CACHE STRING "Intel Xe architectures (12=PVC, 20=BMG)") +else() + set(CUTLASS_LIBRARY_GENERATOR_ARCHS "${CUTLASS_NVCC_ARCHS_ENABLED}") +endif() + # --log-level is set to DEBUG to enable printing information about which kernels were excluded # from generation in /python/cutlass_library/manifest.py. To avoid having this information appear # in ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log, set this parameter to INFO @@ -318,7 +396,7 @@ execute_process( --build-dir ${PROJECT_BINARY_DIR} --curr-build-dir ${CMAKE_CURRENT_BINARY_DIR} --generator-target library - --architectures "${CUTLASS_NVCC_ARCHS_ENABLED}" + --architectures "${CUTLASS_LIBRARY_GENERATOR_ARCHS}" --kernels "${CUTLASS_LIBRARY_KERNELS}" --instantiation-level "${CUTLASS_LIBRARY_INSTANTIATION_LEVEL}" --ignore-kernels "${CUTLASS_LIBRARY_IGNORE_KERNELS}" @@ -341,14 +419,13 @@ endif() message(STATUS "Completed generation of library instances. See ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log for more information.") -if (NOT CUTLASS_ENABLE_SYCL) - # include auto-instantiated kernels in he CUTLASS Deliverables Library - set(CUTLASS_LIBRARY_MANIFEST_CMAKE_FILE ${CMAKE_CURRENT_BINARY_DIR}/generated/manifest.cmake) - if(EXISTS "${CUTLASS_LIBRARY_MANIFEST_CMAKE_FILE}") - include(${CUTLASS_LIBRARY_MANIFEST_CMAKE_FILE}) - else() - message(STATUS "auto-generated library manifest cmake file (${CUTLASS_LIBRARY_MANIFEST_CMAKE_FILE}) not found.") - endif() +# Include auto-instantiated kernels in the CUTLASS Deliverables Library +# Now enabled for both CUDA and SYCL +set(CUTLASS_LIBRARY_MANIFEST_CMAKE_FILE ${CMAKE_CURRENT_BINARY_DIR}/generated/manifest.cmake) +if(EXISTS "${CUTLASS_LIBRARY_MANIFEST_CMAKE_FILE}") + include(${CUTLASS_LIBRARY_MANIFEST_CMAKE_FILE}) +else() + message(STATUS "auto-generated library manifest cmake file (${CUTLASS_LIBRARY_MANIFEST_CMAKE_FILE}) not found.") endif() ################################################################################ diff --git a/tools/library/include/cutlass/library/arch_mappings.h b/tools/library/include/cutlass/library/arch_mappings.h index df241e3ca6..e6e31f0f9f 100644 --- a/tools/library/include/cutlass/library/arch_mappings.h +++ b/tools/library/include/cutlass/library/arch_mappings.h @@ -1,5 +1,7 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
+ * * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -148,6 +150,39 @@ template struct ArchMap { static int const kMax = 121; }; +// Intel Xe architecture mappings +template struct ArchMap { + static int const kMin = 12; + static int const kMax = 50; +}; + +template <> struct ArchMap { + static int const kMin = 12; + static int const kMax = 50; +}; + +// Xe12 (PVC) alias +template struct ArchMap { + static int const kMin = 12; + static int const kMax = 50; +}; + +template <> struct ArchMap { + static int const kMin = 12; + static int const kMax = 50; +}; + +// Xe20 (BMG) alias +template struct ArchMap { + static int const kMin = 20; + static int const kMax = 50; +}; + +template <> struct ArchMap { + static int const kMin = 20; + static int const kMax = 50; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 6764d9a6d8..5564325d4f 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -52,7 +52,10 @@ #include #include #include + +#if !defined(CUTLASS_ENABLE_SYCL) #include +#endif #include "cutlass/cutlass.h" #include "cutlass/library/types.h" diff --git a/tools/library/include/cutlass/library/manifest.h b/tools/library/include/cutlass/library/manifest.h index c4fb0ee8ca..9d2cf41be2 100644 --- a/tools/library/include/cutlass/library/manifest.h +++ b/tools/library/include/cutlass/library/manifest.h @@ -80,6 +80,8 @@ class Manifest { public: Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } + Provider get_provider() const { return provider_; } + /// Top-level initialization Status initialize(); diff --git a/tools/library/include/cutlass/library/util.h b/tools/library/include/cutlass/library/util.h index f537421751..efd788e22f 100644 --- a/tools/library/include/cutlass/library/util.h +++ b/tools/library/include/cutlass/library/util.h @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -224,21 +225,35 @@ NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type); } \ } while (0) -// RAII CUDA buffer container +// RAII device buffer container (CUDA/SYCL compatible) class CudaBuffer { public: CudaBuffer() : size_(0), d_ptr_(nullptr) {} explicit CudaBuffer(size_t size) : size_(size), d_ptr_(nullptr) { +#if defined(CUTLASS_ENABLE_SYCL) + // SYCL memory allocation using malloc_device + auto q = compat::get_default_queue(); + d_ptr_ = sycl::malloc_device(size_, q); + if (d_ptr_ == nullptr) { + throw std::runtime_error("sycl::malloc_device failed"); + } +#else cudaError_t err = cudaMalloc(&d_ptr_, size_); if (err != cudaSuccess) { throw std::runtime_error("cudaMalloc failed: " + std::string(cudaGetErrorString(err))); } +#endif } ~CudaBuffer() { if (d_ptr_) { +#if defined(CUTLASS_ENABLE_SYCL) + auto q = compat::get_default_queue(); + sycl::free(d_ptr_, q); +#else cudaFree(d_ptr_); +#endif } } @@ -253,7 +268,12 @@ class CudaBuffer { CudaBuffer& operator=(CudaBuffer&& other) noexcept { if (this != &other) { if (d_ptr_) { +#if defined(CUTLASS_ENABLE_SYCL) + auto q = compat::get_default_queue(); + sycl::free(d_ptr_, q); +#else cudaFree(d_ptr_); +#endif } d_ptr_ = other.d_ptr_; size_ = other.size_; diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index 880cb4bf34..1d87f3ecf0 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -36,13 +37,18 @@ #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm.h" + +#if !defined(CUTLASS_ENABLE_SYCL) +// CUDA-only kernel types - not compatible with SYCL #include "cutlass/gemm/device/gemm_sparse.h" #include "cutlass/gemm/device/gemm_complex.h" #include "cutlass/gemm/device/gemm_batched.h" #include "cutlass/gemm/device/gemm_array.h" +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" +#endif + #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/default_gemm_universal.h" -#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" #include "cutlass/library/library.h" #include "library_internal.h" diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index 7b27913df9..32c4acb29f 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -46,6 +47,7 @@ #include "cutlass/util/device_memory.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/mixed_dtype_utils.hpp" #include "cute/tensor.hpp" #include @@ -237,8 +239,8 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { return Status::kSuccess; } else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { - fusion_args.alpha = 0; - fusion_args.beta = 0; + fusion_args.alpha = ElementCompute(0); + fusion_args.beta = ElementCompute(0); fusion_args.alpha_ptr = static_cast(arguments.alpha); fusion_args.beta_ptr = static_cast(arguments.beta); diff --git a/tools/library/src/grouped_gemm_operation_3x.hpp b/tools/library/src/grouped_gemm_operation_3x.hpp index 91f618d4fa..d59a43ea3c 100644 --- a/tools/library/src/grouped_gemm_operation_3x.hpp +++ b/tools/library/src/grouped_gemm_operation_3x.hpp @@ -441,13 +441,14 @@ class GroupedGemmUniversal3xOperation : public GroupedGemmOperation3xBase); + args->max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( cluster_dims, threads_per_block, kernel_ptr); if (args->max_active_clusters == 0) { - std::cerr << "Max Active Clusters could not be queried. " + std::cerr << "Max Active Clusters could not be queried. " << "Falling back to heuristics mode (static cluster shape) or preferred cluster mode.\n"; } diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index e8bd77397f..a6f343be08 100644 --- a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -181,7 +181,11 @@ template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kTF32; }; - +// Handle cute::tuple-wrapped types (used in some collectives) +template +struct NumericTypeMap> { + static NumericTypeID const kId = NumericTypeMap::kId; +}; template <> struct NumericTypeMap { diff --git a/tools/library/src/manifest.cpp b/tools/library/src/manifest.cpp index b9c04de71d..d622060b83 100644 --- a/tools/library/src/manifest.cpp +++ b/tools/library/src/manifest.cpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -43,7 +44,27 @@ namespace library { ////////////////////////////////////////////////////////////////////////////////////////////////////////// +#ifndef CUTLASS_ENABLE_SYCL +// For CUDA builds, reference operations are defined in initialize_reference_operations.cu void initialize_reference_operations(Manifest &manifest); +#else +// For SYCL builds, provide a stub implementation since reference ops are not yet supported +inline void initialize_reference_operations(Manifest &manifest) { + // Reference operations not yet implemented for SYCL + // This is a stub to allow the library to compile +} +#endif + +#ifndef CUTLASS_ENABLE_SYCL +// For CUDA builds, reduction operations are defined in init_reduction_operations.cu +// Declaration is in manifest.h +#else +// For SYCL builds, provide a stub implementation since reduction ops are not yet supported +inline void initialize_all_reduction_op(Manifest &manifest) { + // Reduction operations not yet implemented for SYCL + // This is a stub to allow the library to compile +} +#endif ////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/sparse_gemm_operation_3x.hpp b/tools/library/src/sparse_gemm_operation_3x.hpp index 34da25b9a6..6cb836b89a 100644 --- a/tools/library/src/sparse_gemm_operation_3x.hpp +++ b/tools/library/src/sparse_gemm_operation_3x.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,6 +35,9 @@ #pragma once +// Sparse GEMM operations are CUDA-only (not supported in SYCL) +#if !defined(CUTLASS_ENABLE_SYCL) + #include "cutlass/cutlass.h" #include "cutlass/detail/collective.hpp" #include "cutlass/array.h" @@ -501,4 +505,6 @@ class SparseGemmUniversal3xOperation : public GemmOperation3xBase { } // namespace cutlass::library +#endif // !defined(CUTLASS_ENABLE_SYCL) + /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/CMakeLists.txt b/tools/util/CMakeLists.txt index b69ea02347..3344bef997 100644 --- a/tools/util/CMakeLists.txt +++ b/tools/util/CMakeLists.txt @@ -26,6 +26,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
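For context on the conditional-compilation approach in the manifest.cpp and sparse_gemm_operation_3x.hpp hunks above, the fragment below is a minimal, self-contained sketch of the same guard-or-stub pattern; DemoManifest, demo_initialize_reference_ops and demo_initialize_all are hypothetical names used only for illustration, not library symbols.

// Sketch of the pattern used above: CUDA builds keep the out-of-line
// declaration (defined in a CUDA-only .cu translation unit), while SYCL
// builds substitute an empty inline definition so the shared call site
// still compiles and links.
#include <string>
#include <vector>

struct DemoManifest {
  std::vector<std::string> kernels;  // stand-in for the real operation list
};

#if !defined(CUTLASS_ENABLE_SYCL)
void demo_initialize_reference_ops(DemoManifest &manifest);   // defined elsewhere (CUDA .cu file)
#else
inline void demo_initialize_reference_ops(DemoManifest &) {}  // no-op stub for SYCL builds
#endif

inline void demo_initialize_all(DemoManifest &manifest) {
  demo_initialize_reference_ops(manifest);  // identical call site in both build modes
}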
cmake_policy(SET CMP0112 NEW) + +# Include oneMKL support for SYCL builds +if (CUTLASS_ENABLE_SYCL) + include(${CMAKE_SOURCE_DIR}/cmake/onemkl.cmake) +endif() + add_library(cutlass_tools_util_includes INTERFACE) add_library(nvidia::cutlass::tools::util ALIAS cutlass_tools_util_includes) set_target_properties(cutlass_tools_util_includes PROPERTIES EXPORT_NAME tools::util) @@ -43,6 +49,16 @@ target_link_libraries( $<$:cublas> ) +# Add oneMKL for SYCL builds (sycl_tensor_fill.h needs oneapi/mkl/rng/device.hpp) +if (CUTLASS_ENABLE_SYCL) + if (CUTLASS_USING_SYSTEM_ONEMKL) + target_link_libraries(cutlass_tools_util_includes INTERFACE MKL::MKL) + else() + target_include_directories(cutlass_tools_util_includes INTERFACE + $) + endif() +endif() + install( DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ diff --git a/tools/util/include/cutlass/util/packed_stride.hpp b/tools/util/include/cutlass/util/packed_stride.hpp index 811ba152ab..faa3b4aaac 100644 --- a/tools/util/include/cutlass/util/packed_stride.hpp +++ b/tools/util/include/cutlass/util/packed_stride.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -108,6 +109,54 @@ make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape return s_copy; } +// Strides with 2 batch modes. +// All this code should be replaced with a generic implementation. + +template +CUTLASS_HOST_DEVICE +auto +make_cute_packed_stride(cute::Stride,int,int> s, + cute::Shape shape) +{ + using namespace cute; + + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + + int batch_count0 = get<2>(shape); + int batch_count1 = get<3>(shape) * batch_count0; + + get<0>(s_copy) = static_cast(get<1>(shape)); + get<2>(s_copy) = (batch_count0 <= 1) ? 0 : product(take<0,2>(shape)); + get<3>(s_copy) = (batch_count1 <= 1) ? 0 : product(take<0,3>(shape)); + + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +auto +make_cute_packed_stride(cute::Stride,IntT,int,int> s, + cute::Shape shape) +{ + using namespace cute; + + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + + int batch_count0 = get<2>(shape); + int batch_count1 = get<3>(shape) * batch_count0; + + get<1>(s_copy) = static_cast(get<0>(shape)); + get<2>(s_copy) = (batch_count0 <= 1) ? 0 : product(take<0,2>(shape)); + get<3>(s_copy) = (batch_count1 <= 1) ? 
0 : product(take<0,3>(shape)); + + return s_copy; +} + + ///////////////////////////////////////////////////////////////////////////////////////////////// // Strides with group mode diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h index 5cae58e4f1..4a4c94882e 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -102,6 +102,9 @@ __global__ void Element b = cutlass::ReferenceFactory::get(ptr_B, idx); if (!relatively_equal(a, b, epsilon, nonzero_floor)) { +#ifdef SHOW_DIFF + printf("[%zu]: %f vs %f\n", idx, (double) a, (double) b); +#endif *equal = 0; return; }
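The SHOW_DIFF hook above is compiled out unless the macro is defined by the build; nothing in this patch defines it. The fragment below is a hypothetical host-side sketch of how a debugging build might enable it, assuming the BlockCompareRelativelyEqual helper already present in this header keeps its upstream signature; the 1e-3 / 1e-6 thresholds are illustrative values only.

// Hypothetical debugging sketch (not added by this patch): defining SHOW_DIFF
// before including tensor_compare.h makes the comparison kernel above print
// each mismatching element pair before flagging inequality.
#define SHOW_DIFF
#include "cutlass/util/reference/device/tensor_compare.h"

bool verify_block(float const *d_ref, float const *d_test, size_t count) {
  // Relative tolerance of 1e-3; magnitudes below 1e-6 are treated as zero
  // for the relative comparison (same semantics as relatively_equal above).
  return cutlass::reference::device::BlockCompareRelativelyEqual(
      d_ref, d_test, count, 1e-3f, 1e-6f);
}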