From 7928505d70c73ce4a137bc8d36a280e8c97ebffc Mon Sep 17 00:00:00 2001 From: AndreSlavescu Date: Fri, 20 Mar 2026 15:01:28 -0400 Subject: [PATCH] docs updates with diagrams --- docs/_static/compilation-pipeline.svg | 90 ++++++++++++++ docs/_static/gemm-tiling.svg | 161 ++++++++++++++++++++++++++ docs/_static/morton-swizzle.svg | 150 ++++++++++++++++++++++++ docs/_static/simdgroup-layout.svg | 92 +++++++++++++++ docs/_static/tiling-overview.svg | 105 +++++++++++++++++ docs/_static/unified-memory.svg | 60 ++++++++++ docs/conf.py | 2 +- docs/examples/fused-activations.rst | 19 ++- docs/examples/layernorm.rst | 2 +- docs/examples/matmul.rst | 20 ++-- docs/examples/softmax.rst | 28 +---- docs/examples/vector-add.rst | 12 +- docs/getting-started/first-kernel.rst | 14 ++- docs/getting-started/install.rst | 2 +- docs/guide/autotuning.rst | 16 ++- docs/guide/language.rst | 8 +- docs/guide/memory.rst | 20 +++- docs/guide/tile-ops.rst | 40 +++---- 18 files changed, 753 insertions(+), 88 deletions(-) create mode 100644 docs/_static/compilation-pipeline.svg create mode 100644 docs/_static/gemm-tiling.svg create mode 100644 docs/_static/morton-swizzle.svg create mode 100644 docs/_static/simdgroup-layout.svg create mode 100644 docs/_static/tiling-overview.svg create mode 100644 docs/_static/unified-memory.svg diff --git a/docs/_static/compilation-pipeline.svg b/docs/_static/compilation-pipeline.svg new file mode 100644 index 0000000..0b7d635 --- /dev/null +++ b/docs/_static/compilation-pipeline.svg @@ -0,0 +1,90 @@ + + + + + + + + + + + + + Input Python code + + + + @metile.kernel + def matmul(A, B, C, M, N, K, ...): + acc = metile.zeros((BM, BN)) + for k in metile.tile_range(...): + acc = metile.dot(a, b, acc) + + + + + trace + + + Tile IR + + + Hardware-agnostic operations + + + Dot, TileLoad, TileStore, ForRange, Zeros, ... 
+ + + + lower + + + Metal IR + + + Apple GPU primitives + + + + Simdgroup MMA + M1 / M2 / M3 + + + Tensor Ops + M4+ + + + + optimize + + + Optimization Passes + + + + vectorize, serpentine MMA, double-buffer, split-K, swizzle, fold + + + + emit + + + MSL Source + + + + [[kernel]] void mtile_matmul( + device float* A [[buffer(0)]], ...) + { ... simdgroup_multiply_accumulate ... } + + + + + xcrun metal -O2 + + + + Metal Compute Pipeline (.metallib) + diff --git a/docs/_static/gemm-tiling.svg b/docs/_static/gemm-tiling.svg new file mode 100644 index 0000000..809a7d9 --- /dev/null +++ b/docs/_static/gemm-tiling.svg @@ -0,0 +1,161 @@ + + + + + + + + + + + + GEMM Tiling: C = A x B + + + + + A + M x K + + + + BM rows + + K + M + + + x + + + B + K x N + + + + BN + + N + K + + + = + + + C + + + + + + + + + + + + + + + (0,0) + (0,1) + (0,2) + (0,3) + (1,0) + (1,1) + (1,2) + (1,3) + (2,0) + (2,1) + (2,2) + (2,3) + (3,0) + (3,1) + (3,2) + (3,3) + + + N + M + + + grid = (ceil(M/BM), ceil(N/BN)) + one program instance per tile + + + + + + Inside pid=(0,0): the K-loop + + + + Iteration k = 0 + + + + A tile + BM x BK + + + x + + + + B tile + BK x BN + + + += + + + + acc + BM x BN + (in registers) + + + + metile.dot() + + + + k += BK + + + + Iteration k = BK + + + A tile + BM x BK + + x + + + B tile + BK x BN + + += + + + acc + BM x BN + (accumulated) + + + metile.dot() + + + . . . repeated K / BK times . . . 
+ + + + + + + metile.tile_store(C, ..., acc) + write BM x BN result to global memory + diff --git a/docs/_static/morton-swizzle.svg b/docs/_static/morton-swizzle.svg new file mode 100644 index 0000000..9c86b7f --- /dev/null +++ b/docs/_static/morton-swizzle.svg @@ -0,0 +1,150 @@ + + + + + + + + + + + Tile Scheduling: Morton vs Linear + + + + + + + + + Linear (row-major) + + + + + + 0 + + + 1 + + + 2 + + + 3 + + + + 4 + + + 5 + + + 6 + + + 7 + + + + 8 + + + 9 + + + 10 + + + 11 + + + + 12 + + + 13 + + + 14 + + + 15 + + tiles read in row order, + poor L2 reuse for A and B + + + Morton (Z-order) + + + + + + 0 + + + 1 + + + 4 + + + 5 + + + + 2 + + + 3 + + + 6 + + + 7 + + + + 8 + + + 9 + + + 12 + + + 13 + + + + 10 + + + 11 + + + 14 + + + 15 + + + + + + + 2x2 blocks processed together, + neighbors share A-rows and B-cols in L2 + + + + Why Morton? + Tiles 0,1,2,3 share two A-rows and two B-columns in L2 cache. Less memory traffic. + diff --git a/docs/_static/simdgroup-layout.svg b/docs/_static/simdgroup-layout.svg new file mode 100644 index 0000000..4c8e2f1 --- /dev/null +++ b/docs/_static/simdgroup-layout.svg @@ -0,0 +1,92 @@ + + + + + + + + Simdgroup Layout (WM=4, WN=4) + 128 x 128 output tile, 16 simdgroups, each handles 32 x 32 + + + + + + col 0..31 + col 32..63 + col 64..95 + + + + sg(0,0) + 32 x 32 + + + sg(0,1) + 32 x 32 + + + sg(0,2) + 32 x 32 + + + sg(0,3) + 32 x 32 + + + + sg(1,0) + 32 x 32 + + + sg(1,1) + 32 x 32 + + + sg(1,2) + 32 x 32 + + + sg(1,3) + 32 x 32 + + + + sg(2,0) + 32 x 32 + + + sg(2,1) + 32 x 32 + + + sg(2,2) + 32 x 32 + + + sg(2,3) + 32 x 32 + + + + sg(3,0) + 32 x 32 + + + sg(3,1) + 32 x 32 + + + sg(3,2) + 32 x 32 + + + sg(3,3) + 32 x 32 + + + Each simdgroup = 32 threads computing one subtile independently. 
+ sg_row = sgid / WN, sg_col = sgid % WN + diff --git a/docs/_static/tiling-overview.svg b/docs/_static/tiling-overview.svg new file mode 100644 index 0000000..7a856de --- /dev/null +++ b/docs/_static/tiling-overview.svg @@ -0,0 +1,105 @@ + + + + + + + + + + + How Tiling Works + + + Output matrix C (M x N) + + + + + + + + + pid=(0,0) + BM x BN + pid=(0,1) + pid=(1,0) + pid=(1,1) + + + + + + N + M + + grid = (ceil(M/BM), ceil(N/BN)) + + + + zoom in + + + + Inside pid=(0,0) + + + k = 0 + + + A tile + BM x BK + + x + + + B tile + BK x BN + + += + + + acc + BM x BN + + + + k += BK + + + k = BK + + + A tile + BM x BK + + x + + + B tile + BK x BN + + += + + + acc + accumulated + + + . . . K / BK iterations . . . + + + + + + + tile_store(C, ..., acc) + + + + dot() + + dot() + diff --git a/docs/_static/unified-memory.svg b/docs/_static/unified-memory.svg new file mode 100644 index 0000000..bf6bdde --- /dev/null +++ b/docs/_static/unified-memory.svg @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + Unified Memory on Apple Silicon + + + + CPU + + + Python / numpy + + + out.numpy() + + + + Apple GPU + + + Metal compute kernel + + + device float* A + + + + Shared Physical Memory + + + metile.Buffer(data=np_array) + + + + + + + + + + + + zero copy + + + Both CPU and GPU read/write the same physical address. No transfers needed. 
+ diff --git a/docs/conf.py b/docs/conf.py index 4b0035a..e74c011 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,5 @@ project = "meTile" -copyright = "2025, Andre Slavescu" +copyright = "2026, Andre Slavescu" author = "Andre Slavescu" extensions = [ diff --git a/docs/examples/fused-activations.rst b/docs/examples/fused-activations.rst index 565a159..70c3166 100644 --- a/docs/examples/fused-activations.rst +++ b/docs/examples/fused-activations.rst @@ -8,7 +8,7 @@ to run different computations on different simdgroup subsets within a single ker Simple Activations ------------------ -Element-wise kernels follow the same pattern as vector add — load, compute, store: +Element-wise kernels follow the same pattern as vector add (load, compute, store): .. code-block:: python @@ -39,7 +39,7 @@ Fused GEMM + Activation ------------------------ When an activation follows a ``dot`` operation, the compiler fuses it into the GEMM epilogue. -The activation runs on register-resident data — no global memory round-trip: +The activation runs on register-resident data, no global memory round-trip: .. code-block:: python @@ -54,7 +54,7 @@ The activation runs on register-resident data — no global memory round-trip: a = metile.tile_load(A, pid_m * BLOCK_M, k, K, (BLOCK_M, BLOCK_K)) b = metile.tile_load(B, k, pid_n * BLOCK_N, N, (BLOCK_K, BLOCK_N)) acc = metile.dot(a, b, acc) - # Fused GELU epilogue — runs on accumulator registers + # Fused GELU epilogue, runs on accumulator registers acc = acc / (1.0 + metile.exp(-1.702 * acc)) metile.tile_store(C, pid_m * BLOCK_M, pid_n * BLOCK_N, N, acc, (BLOCK_M, BLOCK_N)) @@ -64,7 +64,7 @@ Simdgroup Roles Apple GPUs organize threads into 32-thread **simdgroups**. A threadgroup can contain multiple simdgroups. With ``simdgroup_role``, you can assign different work to different -simdgroup subsets — useful for computing multiple outputs in a single dispatch: +simdgroup subsets, useful for computing multiple outputs in a single dispatch: .. 
code-block:: python @@ -85,16 +85,13 @@ simdgroup subsets — useful for computing multiple outputs in a single dispatch metile.store(out_sqrt + offs, metile.sqrt(metile.abs(x)), mask=mask) With ``num_roles=2``, the threadgroup's simdgroups are split in half. Role 0 computes -exponentials while role 1 computes square roots — simultaneously, in the same kernel launch. - -This is useful when you need multiple derived outputs from the same input and want to -avoid the overhead of multiple kernel dispatches. +exponentials while role 1 computes square roots, simultaneously, in the same kernel launch. GEGLU (Gated GELU) ------------------- -A practical use of simdgroup roles — computing the gate and up projections of GEGLU +A practical use of simdgroup roles for computing the gate and up projections of GEGLU in parallel: .. code-block:: python @@ -121,6 +118,6 @@ Concepts Introduced - Element-wise activation patterns - ``metile.exp`` for activation functions -- Fused GEMM epilogues — zero-cost post-GEMM operations -- ``metile.simdgroup_role`` — split work across simdgroup subsets +- Fused GEMM epilogues: zero-cost post-GEMM operations +- ``metile.simdgroup_role``: split work across simdgroup subsets - Multiple outputs from a single kernel diff --git a/docs/examples/layernorm.rst b/docs/examples/layernorm.rst index 282f34e..bbc3fb3 100644 --- a/docs/examples/layernorm.rst +++ b/docs/examples/layernorm.rst @@ -68,5 +68,5 @@ Concepts Introduced - Three-pass algorithm (mean, variance, normalize) - Scalar accumulators across tiled loops - ``metile.sum`` reduction -- ``metile.sqrt`` — element-wise square root +- ``metile.sqrt``: element-wise square root - Loading separate weight/bias arrays (shared across all rows) diff --git a/docs/examples/matmul.rst b/docs/examples/matmul.rst index 20ca67d..3f49e57 100644 --- a/docs/examples/matmul.rst +++ b/docs/examples/matmul.rst @@ -30,7 +30,7 @@ Basic GEMM Launching --------- -The grid is 2D — one program instance per output tile: +The 
grid is 2D, one program instance per output tile: .. code-block:: python @@ -64,8 +64,8 @@ The compiler maps ``dot`` to the appropriate hardware: Fused GEMM + ReLU ------------------ -Element-wise operations after the GEMM loop are fused into the kernel's epilogue — -they run on register-resident data with zero extra memory traffic: +Element-wise operations after the GEMM loop are fused into the kernel's epilogue. +They run on register-resident data with zero extra memory traffic: .. code-block:: python @@ -80,7 +80,7 @@ they run on register-resident data with zero extra memory traffic: a = metile.tile_load(A, pid_m * BLOCK_M, k, K, (BLOCK_M, BLOCK_K)) b = metile.tile_load(B, k, pid_n * BLOCK_N, N, (BLOCK_K, BLOCK_N)) acc = metile.dot(a, b, acc) - acc = metile.where(acc > 0, acc, 0) # fused ReLU — no global memory round-trip + acc = metile.where(acc > 0, acc, 0) # fused ReLU, no global memory round-trip metile.tile_store(C, pid_m * BLOCK_M, pid_n * BLOCK_N, N, acc, (BLOCK_M, BLOCK_N)) @@ -132,9 +132,9 @@ See :doc:`/guide/autotuning` for the full autotuning guide. 
Concepts Introduced ------------------- -- ``metile.zeros`` — register-resident accumulator initialization -- ``metile.dot`` — tile-level matrix multiply-accumulate -- ``metile.tile_load`` / ``metile.tile_store`` — 2D strided memory access -- 2D grids — ``kernel[(grid_m, grid_n)]`` -- Fused epilogues — element-wise ops after GEMM are free -- Tile swizzle — cache-friendly scheduling patterns +- ``metile.zeros``: register-resident accumulator initialization +- ``metile.dot``: tile-level matrix multiply-accumulate +- ``metile.tile_load`` / ``metile.tile_store``: 2D strided memory access +- 2D grids: ``kernel[(grid_m, grid_n)]`` +- Fused epilogues: element-wise ops after GEMM are free +- Tile swizzle: cache-friendly scheduling patterns diff --git a/docs/examples/softmax.rst b/docs/examples/softmax.rst index 9c5fc28..e172b1f 100644 --- a/docs/examples/softmax.rst +++ b/docs/examples/softmax.rst @@ -53,30 +53,12 @@ Each program instance handles one row. The grid is 1D with one instance per row: softmax[(rows,)](X, Out, cols, BLOCK=256) -How It Works ------------- - -The kernel makes three passes over each row: - -1. **Find max** — ``metile.maximum`` computes element-wise max across tiles, then - ``metile.max`` reduces the tile to a scalar. This is needed for numerical stability - (subtracting the max prevents overflow in ``exp``). - -2. **Sum exponentials** — accumulates ``exp(x - m)`` across all tiles, then - ``metile.sum`` reduces to a scalar denominator. - -3. **Normalize** — divides each ``exp(x - m)`` by the sum. - -Each pass iterates over the row in chunks of ``BLOCK`` elements using ``tile_range``. -The ``mask`` ensures correctness when ``N`` is not a multiple of ``BLOCK``. 
- - Concepts Introduced ------------------- -- ``metile.tile_range`` — tiling loop for iterating over a dimension -- ``metile.maximum`` / ``metile.max`` — element-wise max and reduction -- ``metile.sum`` — sum reduction -- ``metile.exp`` — element-wise exponential -- Multi-pass algorithms — reading the same data multiple times in different passes +- ``metile.tile_range``: tiling loop for iterating over a dimension +- ``metile.maximum`` / ``metile.max``: element-wise max and reduction +- ``metile.sum``: sum reduction +- ``metile.exp``: element-wise exponential +- Multi-pass algorithms: reading the same data multiple times in different passes - Scalar accumulators (``m``, ``s``) carried across loop iterations diff --git a/docs/examples/vector-add.rst b/docs/examples/vector-add.rst index 6d0f08b..c28f0e7 100644 --- a/docs/examples/vector-add.rst +++ b/docs/examples/vector-add.rst @@ -37,9 +37,9 @@ The simplest meTile kernel: add two arrays element by element. Concepts Introduced ------------------- -- ``@metile.kernel`` — compile a Python function to Metal -- ``metile.program_id`` — which program instance am I? -- ``metile.arange`` — tile of consecutive indices -- ``metile.load`` / ``metile.store`` — masked memory access -- ``metile.Buffer`` — zero-copy GPU memory -- ``kernel[grid]()`` — launch with a grid of instances +- ``@metile.kernel``: compile a Python function to Metal +- ``metile.program_id``: which program instance am I? +- ``metile.arange``: tile of consecutive indices +- ``metile.load`` / ``metile.store``: masked memory access +- ``metile.Buffer``: zero-copy GPU memory +- ``kernel[grid]()``: launch with a grid of instances diff --git a/docs/getting-started/first-kernel.rst b/docs/getting-started/first-kernel.rst index f703011..3dace45 100644 --- a/docs/getting-started/first-kernel.rst +++ b/docs/getting-started/first-kernel.rst @@ -30,7 +30,7 @@ Let's break this down line by line. Device pointers to GPU memory. These map to ``device float*`` in Metal. 
``N`` - A runtime scalar — passed as a ``constant int&`` to the shader. + A runtime scalar, passed as a ``constant int&`` to the shader. ``BLOCK: metile.constexpr`` A **compile-time constant**. The value is baked directly into the shader. Changing it @@ -75,7 +75,7 @@ Launching ``metile.Buffer`` Wraps a Metal buffer in unified memory. CPU and GPU share the same physical memory on - Apple Silicon — there is no copy between host and device. + Apple Silicon, so there is no copy between host and device. ``metile.Buffer.zeros((N,))`` Allocates a zeroed buffer of ``N`` float32 elements. @@ -100,6 +100,10 @@ When you call ``add[grid](...)``, meTile: 5. **Compiles** with ``xcrun metal -O2`` (or JIT if Xcode is unavailable) 6. **Dispatches** the compute pipeline on the GPU +.. image:: /_static/compilation-pipeline.svg + :alt: meTile compilation pipeline: Python to Tile IR to Metal IR to MSL to GPU + :width: 100% + You can inspect any stage with the ``METILE_DEBUG`` environment variable: .. code-block:: bash @@ -112,6 +116,6 @@ You can inspect any stage with the ``METILE_DEBUG`` environment variable: What's Next ----------- -- :doc:`/guide/language` — full language reference for what you can write inside ``@metile.kernel`` -- :doc:`/examples/softmax` — a more complex kernel with reductions and multiple passes -- :doc:`/examples/matmul` — tile-level matrix multiply with ``dot`` and ``tile_load`` +- :doc:`/guide/language` for the full language reference +- :doc:`/examples/softmax` for a more complex kernel with reductions and multiple passes +- :doc:`/examples/matmul` for tile-level matrix multiply with ``dot`` and ``tile_load`` diff --git a/docs/getting-started/install.rst b/docs/getting-started/install.rst index 62f2551..0927668 100644 --- a/docs/getting-started/install.rst +++ b/docs/getting-started/install.rst @@ -5,7 +5,7 @@ Requirements ------------ - macOS 13 (Ventura) or later -- Apple Silicon (M1, M2, M3, M4 — any variant) +- Apple Silicon (M1 or later) - Python
3.10+ Install diff --git a/docs/guide/autotuning.rst b/docs/guide/autotuning.rst index 3adb846..68d0f7a 100644 --- a/docs/guide/autotuning.rst +++ b/docs/guide/autotuning.rst @@ -52,6 +52,20 @@ On the first call with new key values, the autotuner: Subsequent calls with the same key values use the cached winner with zero overhead. +.. code-block:: text + + First call (M=1024, N=1024, K=1024): + +--------------------------------------------------+ + | Config(BM=64, BN=64, BK=32): 1.26ms | + | Config(BM=128, BN=128, BK=64): 0.62ms <-- | winner cached + | Config(BM=128, BN=128, BK=128): 0.91ms | + +--------------------------------------------------+ + + Subsequent calls (same M, N, K): + +--------------------------------------------------+ + | cached -> Config(BM=128, BN=128, BK=64) | no re-tuning + +--------------------------------------------------+ + Config Object ------------- @@ -117,6 +131,6 @@ fast dispatcher that skips all Python overhead on subsequent calls: dispatch = autotuned_matmul[grid].prepare(A, B, C, M, N, K) - # Hot loop — minimal Python overhead per call + # hot path with minimal Python overhead for _ in range(1000): dispatch() diff --git a/docs/guide/language.rst b/docs/guide/language.rst index 18cb2f2..fb3a45b 100644 --- a/docs/guide/language.rst +++ b/docs/guide/language.rst @@ -2,7 +2,7 @@ Language Reference ================== meTile provides a Python eDSL (embedded domain-specific language) for writing GPU kernels. Functions -decorated with ``@metile.kernel`` are traced and compiled to Metal shaders — they are not executed +decorated with ``@metile.kernel`` are traced and compiled to Metal shaders. They are not executed as regular Python. This page documents every construct available inside a ``@metile.kernel`` function.
@@ -19,9 +19,9 @@ Kernel Definition Parameters are either: -- **Pointers** — numpy arrays or ``metile.Buffer`` objects become ``device float*`` in Metal -- **Scalars** — Python ints/floats become ``constant int&`` or ``constant float&`` -- **Constexprs** — annotated with ``metile.constexpr``, baked into the shader at compile time +- **Pointers**: numpy arrays or ``metile.Buffer`` objects become ``device float*`` in Metal +- **Scalars**: Python ints/floats become ``constant int&`` or ``constant float&`` +- **Constexprs**: annotated with ``metile.constexpr``, baked into the shader at compile time Constexprs are passed as keyword arguments at launch: diff --git a/docs/guide/memory.rst b/docs/guide/memory.rst index 94ba84f..afe95f4 100644 --- a/docs/guide/memory.rst +++ b/docs/guide/memory.rst @@ -1,9 +1,13 @@ Memory Model ============ -Apple Silicon has a **unified memory architecture** — the CPU and GPU share the same physical +Apple Silicon has a **unified memory architecture** where the CPU and GPU share the same physical memory. meTile exposes this directly through ``metile.Buffer``. +.. image:: /_static/unified-memory.svg + :alt: Unified memory: CPU and GPU both access the same physical memory through metile.Buffer + :width: 100% + Buffers ------- @@ -13,7 +17,7 @@ Buffers import numpy as np import metile - # Create from numpy (zero-copy — the GPU reads the same memory) + # Create from numpy (zero-copy, the GPU reads the same memory) x = metile.Buffer(data=np.random.randn(1024).astype(np.float32)) # Allocate zeroed @@ -71,6 +75,16 @@ memory access: x = metile.load(X + offs, mask=mask) # masked-off lanes read 0 metile.store(Out + offs, x, mask=mask) # masked-off lanes are skipped +.. 
code-block:: text + + N = 10, BLOCK = 4, pid = 2 (last instance) + + offs = [8, 9, 10, 11] + mask = [T, T, F, F] # values 10 and 11 are out of bounds + + load: reads x[8], x[9], returns 0 for indices 10, 11 + store: writes out[8], out[9], skips indices 10, 11 + Masking is essential for correctness. Without it, the last program instance would read/write past the end of the array. @@ -85,5 +99,5 @@ For kernels that need inter-thread communication within a threadgroup, use share buf = metile.shared(size=256, dtype="f32") metile.barrier() # synchronize all threads in the threadgroup -Shared memory is threadgroup-local — it is not visible to other threadgroups. Use +Shared memory is threadgroup-local and not visible to other threadgroups. Use ``metile.barrier()`` to synchronize access within a threadgroup. diff --git a/docs/guide/tile-ops.rst b/docs/guide/tile-ops.rst index fd68120..7fbe675 100644 --- a/docs/guide/tile-ops.rst +++ b/docs/guide/tile-ops.rst @@ -12,13 +12,13 @@ The Two Backends meTile automatically selects the best backend for your hardware when compiling GEMM kernels: **Simdgroup Matrix (M1/M2/M3)** - Uses ``simdgroup_matrix`` — Apple's 8x8 matrix multiply-accumulate + Uses ``simdgroup_matrix``, Apple's 8x8 matrix multiply-accumulate primitive. Each simdgroup (32 threads) collaboratively computes an 8x8 tile. The compiler tiles the output across multiple simdgroups and uses threadgroup (shared) memory to stage data. **Metal 4 Tensor Ops (M4+)** - Uses ``matmul2d`` with ``cooperative_tensor`` — Metal 4's hardware matrix multiply + Uses ``matmul2d`` with ``cooperative_tensor``, Metal 4's hardware matrix multiply descriptors. Each simdgroup independently loads data from device memory into register-resident cooperative tensors and runs the MMA. No threadgroup memory needed. @@ -29,20 +29,12 @@ your hardware and chooses the right path. 
How Tiling Works ---------------- -A GEMM kernel tiles the computation into blocks: +A GEMM kernel tiles the computation into blocks. Each program instance computes +one output tile, iterating over K to accumulate partial products: -.. code-block:: text - - Output C (M x N) Each tile is BLOCK_M x BLOCK_N - ┌─────────┬─────────┐ - │ (0,0) │ (0,1) │ Each program instance computes one tile. - │ 128x128 │ 128x128 │ The K dimension is tiled with BLOCK_K. - ├─────────┼─────────┤ - │ (1,0) │ (1,1) │ grid = (ceil(M/BLOCK_M), ceil(N/BLOCK_N)) - │ 128x128 │ 128x128 │ - └─────────┴─────────┘ - -Inside each tile, the K-loop accumulates partial results: +.. image:: /_static/tiling-overview.svg + :alt: Output matrix tiled into blocks, with K-loop detail showing tile_load and dot accumulation + :width: 100% .. code-block:: python @@ -81,7 +73,11 @@ The tile sizes are compile-time constants that control how the hardware is used: - 2, 4 ``WM`` and ``WN`` control how many simdgroups tile the output block. With ``WM=4, WN=4``, -16 simdgroups each handle a ``(BLOCK_M/WM) x (BLOCK_N/WN)`` = 32x32 subtile. +16 simdgroups each handle a ``(BLOCK_M/WM) x (BLOCK_N/WN)`` = 32x32 subtile: + +.. image:: /_static/simdgroup-layout.svg + :alt: 4x4 simdgroup grid layout, 16 simdgroups each handling a 32x32 subtile + :width: 100% Fused Epilogues @@ -94,7 +90,7 @@ and fuses them into the kernel. No extra memory traffic: acc = metile.dot(a, b, acc) - # These are fused into the GEMM — no global memory round-trip + # These are fused into the GEMM, no global memory round-trip acc = metile.where(acc > 0, acc, 0) # ReLU acc = acc * scale # scale acc = metile.exp(acc) # unary @@ -109,14 +105,14 @@ Tile Scheduling For 2D grids, the order in which tiles are assigned to threadgroups affects L2 cache locality. meTile supports several scheduling patterns: -**Morton (Z-order)** — default - Tiles are assigned in 2x2 blocks following a Z-curve. Adjacent threadgroups share - A-row and B-column data in L2 cache. +.. 
image:: /_static/morton-swizzle.svg + :alt: Morton Z-order vs linear tile scheduling, showing how 2x2 blocks share L2 cache + :width: 100% -**Diagonal** +**Diagonal**: Column assignment is rotated by the row index. Distributes memory traffic. -**Linear** +**Linear**: Simple row-major assignment. No locality optimization. The compiler applies Morton scheduling by default. You can override it: