diff --git a/docs/_static/compilation-pipeline.svg b/docs/_static/compilation-pipeline.svg
new file mode 100644
index 0000000..0b7d635
--- /dev/null
+++ b/docs/_static/compilation-pipeline.svg
@@ -0,0 +1,90 @@
+
diff --git a/docs/_static/gemm-tiling.svg b/docs/_static/gemm-tiling.svg
new file mode 100644
index 0000000..809a7d9
--- /dev/null
+++ b/docs/_static/gemm-tiling.svg
@@ -0,0 +1,161 @@
+
diff --git a/docs/_static/morton-swizzle.svg b/docs/_static/morton-swizzle.svg
new file mode 100644
index 0000000..9c86b7f
--- /dev/null
+++ b/docs/_static/morton-swizzle.svg
@@ -0,0 +1,150 @@
+
diff --git a/docs/_static/simdgroup-layout.svg b/docs/_static/simdgroup-layout.svg
new file mode 100644
index 0000000..4c8e2f1
--- /dev/null
+++ b/docs/_static/simdgroup-layout.svg
@@ -0,0 +1,92 @@
+
diff --git a/docs/_static/tiling-overview.svg b/docs/_static/tiling-overview.svg
new file mode 100644
index 0000000..7a856de
--- /dev/null
+++ b/docs/_static/tiling-overview.svg
@@ -0,0 +1,105 @@
+
diff --git a/docs/_static/unified-memory.svg b/docs/_static/unified-memory.svg
new file mode 100644
index 0000000..bf6bdde
--- /dev/null
+++ b/docs/_static/unified-memory.svg
@@ -0,0 +1,60 @@
+
diff --git a/docs/conf.py b/docs/conf.py
index 4b0035a..e74c011 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,5 +1,5 @@
project = "meTile"
-copyright = "2025, Andre Slavescu"
+copyright = "2026, Andre Slavescu"
author = "Andre Slavescu"
extensions = [
diff --git a/docs/examples/fused-activations.rst b/docs/examples/fused-activations.rst
index 565a159..70c3166 100644
--- a/docs/examples/fused-activations.rst
+++ b/docs/examples/fused-activations.rst
@@ -8,7 +8,7 @@ to run different computations on different simdgroup subsets within a single ker
Simple Activations
------------------
-Element-wise kernels follow the same pattern as vector add — load, compute, store:
+Element-wise kernels follow the same pattern as vector add (load, compute, store):
.. code-block:: python
@@ -39,7 +39,7 @@ Fused GEMM + Activation
------------------------
When an activation follows a ``dot`` operation, the compiler fuses it into the GEMM epilogue.
-The activation runs on register-resident data — no global memory round-trip:
+The activation runs on register-resident data, with no global memory round-trip:
.. code-block:: python
@@ -54,7 +54,7 @@ The activation runs on register-resident data — no global memory round-trip:
a = metile.tile_load(A, pid_m * BLOCK_M, k, K, (BLOCK_M, BLOCK_K))
b = metile.tile_load(B, k, pid_n * BLOCK_N, N, (BLOCK_K, BLOCK_N))
acc = metile.dot(a, b, acc)
- # Fused GELU epilogue — runs on accumulator registers
+ # Fused GELU epilogue, runs on accumulator registers
acc = acc / (1.0 + metile.exp(-1.702 * acc))
metile.tile_store(C, pid_m * BLOCK_M, pid_n * BLOCK_N, N, acc, (BLOCK_M, BLOCK_N))
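The epilogue expression ``acc / (1.0 + exp(-1.702 * acc))`` is the sigmoid approximation of GELU, ``x * sigmoid(1.702 * x)``, rather than the exact erf-based form. A minimal NumPy sketch, independent of meTile, of how far the two drift apart; the function names ``gelu_exact`` and ``gelu_sigmoid`` are illustrative and not part of the library:

```python
import math

import numpy as np


def gelu_exact(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    erf = np.array([math.erf(v / math.sqrt(2.0)) for v in x])
    return 0.5 * x * (1.0 + erf)


def gelu_sigmoid(x):
    """Same expression as the kernel epilogue: x / (1 + exp(-1.702 * x))."""
    return x / (1.0 + np.exp(-1.702 * x))


x = np.linspace(-6.0, 6.0, 1201)
max_abs_error = float(np.max(np.abs(gelu_exact(x) - gelu_sigmoid(x))))
print(f"max |exact - approx| on [-6, 6]: {max_abs_error:.4f}")
```

When validating the fused kernel against a reference, compare against the same approximation, or use a tolerance that allows for this gap.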
@@ -64,7 +64,7 @@ Simdgroup Roles
Apple GPUs organize threads into 32-thread **simdgroups**. A threadgroup can contain
multiple simdgroups. With ``simdgroup_role``, you can assign different work to different
-simdgroup subsets — useful for computing multiple outputs in a single dispatch:
+simdgroup subsets, which is useful for computing multiple outputs in a single dispatch:
.. code-block:: python
@@ -85,16 +85,13 @@ simdgroup subsets — useful for computing multiple outputs in a single dispatch
metile.store(out_sqrt + offs, metile.sqrt(metile.abs(x)), mask=mask)
With ``num_roles=2``, the threadgroup's simdgroups are split in half. Role 0 computes
-exponentials while role 1 computes square roots — simultaneously, in the same kernel launch.
-
-This is useful when you need multiple derived outputs from the same input and want to
-avoid the overhead of multiple kernel dispatches.
+exponentials while role 1 computes square roots, simultaneously, in the same kernel launch.
GEGLU (Gated GELU)
-------------------
-A practical use of simdgroup roles — computing the gate and up projections of GEGLU
+A practical use of simdgroup roles is computing the gate and up projections of GEGLU
in parallel:
.. code-block:: python
@@ -121,6 +118,6 @@ Concepts Introduced
- Element-wise activation patterns
- ``metile.exp`` for activation functions
-- Fused GEMM epilogues — zero-cost post-GEMM operations
-- ``metile.simdgroup_role`` — split work across simdgroup subsets
+- Fused GEMM epilogues: zero-cost post-GEMM operations
+- ``metile.simdgroup_role``: split work across simdgroup subsets
- Multiple outputs from a single kernel
diff --git a/docs/examples/layernorm.rst b/docs/examples/layernorm.rst
index 282f34e..bbc3fb3 100644
--- a/docs/examples/layernorm.rst
+++ b/docs/examples/layernorm.rst
@@ -68,5 +68,5 @@ Concepts Introduced
- Three-pass algorithm (mean, variance, normalize)
- Scalar accumulators across tiled loops
- ``metile.sum`` reduction
-- ``metile.sqrt`` — element-wise square root
+- ``metile.sqrt``: element-wise square root
- Loading separate weight/bias arrays (shared across all rows)
diff --git a/docs/examples/matmul.rst b/docs/examples/matmul.rst
index 20ca67d..3f49e57 100644
--- a/docs/examples/matmul.rst
+++ b/docs/examples/matmul.rst
@@ -30,7 +30,7 @@ Basic GEMM
Launching
---------
-The grid is 2D — one program instance per output tile:
+The grid is 2D, with one program instance per output tile:
.. code-block:: python
@@ -64,8 +64,8 @@ The compiler maps ``dot`` to the appropriate hardware:
Fused GEMM + ReLU
------------------
-Element-wise operations after the GEMM loop are fused into the kernel's epilogue —
-they run on register-resident data with zero extra memory traffic:
+Element-wise operations after the GEMM loop are fused into the kernel's epilogue.
+They run on register-resident data with zero extra memory traffic:
.. code-block:: python
@@ -80,7 +80,7 @@ they run on register-resident data with zero extra memory traffic:
a = metile.tile_load(A, pid_m * BLOCK_M, k, K, (BLOCK_M, BLOCK_K))
b = metile.tile_load(B, k, pid_n * BLOCK_N, N, (BLOCK_K, BLOCK_N))
acc = metile.dot(a, b, acc)
- acc = metile.where(acc > 0, acc, 0) # fused ReLU — no global memory round-trip
+ acc = metile.where(acc > 0, acc, 0) # fused ReLU, no global memory round-trip
metile.tile_store(C, pid_m * BLOCK_M, pid_n * BLOCK_N, N, acc, (BLOCK_M, BLOCK_N))
@@ -132,9 +132,9 @@ See :doc:`/guide/autotuning` for the full autotuning guide.
Concepts Introduced
-------------------
-- ``metile.zeros`` — register-resident accumulator initialization
-- ``metile.dot`` — tile-level matrix multiply-accumulate
-- ``metile.tile_load`` / ``metile.tile_store`` — 2D strided memory access
-- 2D grids — ``kernel[(grid_m, grid_n)]``
-- Fused epilogues — element-wise ops after GEMM are free
-- Tile swizzle — cache-friendly scheduling patterns
+- ``metile.zeros``: register-resident accumulator initialization
+- ``metile.dot``: tile-level matrix multiply-accumulate
+- ``metile.tile_load`` / ``metile.tile_store``: 2D strided memory access
+- 2D grids: ``kernel[(grid_m, grid_n)]``
+- Fused epilogues: element-wise ops after GEMM are free
+- Tile swizzle: cache-friendly scheduling patterns
diff --git a/docs/examples/softmax.rst b/docs/examples/softmax.rst
index 9c5fc28..e172b1f 100644
--- a/docs/examples/softmax.rst
+++ b/docs/examples/softmax.rst
@@ -53,30 +53,12 @@ Each program instance handles one row. The grid is 1D with one instance per row:
softmax[(rows,)](X, Out, cols, BLOCK=256)
-How It Works
-------------
-
-The kernel makes three passes over each row:
-
-1. **Find max** — ``metile.maximum`` computes element-wise max across tiles, then
- ``metile.max`` reduces the tile to a scalar. This is needed for numerical stability
- (subtracting the max prevents overflow in ``exp``).
-
-2. **Sum exponentials** — accumulates ``exp(x - m)`` across all tiles, then
- ``metile.sum`` reduces to a scalar denominator.
-
-3. **Normalize** — divides each ``exp(x - m)`` by the sum.
-
-Each pass iterates over the row in chunks of ``BLOCK`` elements using ``tile_range``.
-The ``mask`` ensures correctness when ``N`` is not a multiple of ``BLOCK``.
-
-
Concepts Introduced
-------------------
-- ``metile.tile_range`` — tiling loop for iterating over a dimension
-- ``metile.maximum`` / ``metile.max`` — element-wise max and reduction
-- ``metile.sum`` — sum reduction
-- ``metile.exp`` — element-wise exponential
-- Multi-pass algorithms — reading the same data multiple times in different passes
+- ``metile.tile_range``: tiling loop for iterating over a dimension
+- ``metile.maximum`` / ``metile.max``: element-wise max and reduction
+- ``metile.sum``: sum reduction
+- ``metile.exp``: element-wise exponential
+- Multi-pass algorithms: reading the same data multiple times in different passes
- Scalar accumulators (``m``, ``s``) carried across loop iterations
diff --git a/docs/examples/vector-add.rst b/docs/examples/vector-add.rst
index 6d0f08b..c28f0e7 100644
--- a/docs/examples/vector-add.rst
+++ b/docs/examples/vector-add.rst
@@ -37,9 +37,9 @@ The simplest meTile kernel: add two arrays element by element.
Concepts Introduced
-------------------
-- ``@metile.kernel`` — compile a Python function to Metal
-- ``metile.program_id`` — which program instance am I?
-- ``metile.arange`` — tile of consecutive indices
-- ``metile.load`` / ``metile.store`` — masked memory access
-- ``metile.Buffer`` — zero-copy GPU memory
-- ``kernel[grid]()`` — launch with a grid of instances
+- ``@metile.kernel``: compile a Python function to Metal
+- ``metile.program_id``: which program instance am I?
+- ``metile.arange``: tile of consecutive indices
+- ``metile.load`` / ``metile.store``: masked memory access
+- ``metile.Buffer``: zero-copy GPU memory
+- ``kernel[grid]()``: launch with a grid of instances
diff --git a/docs/getting-started/first-kernel.rst b/docs/getting-started/first-kernel.rst
index f703011..3dace45 100644
--- a/docs/getting-started/first-kernel.rst
+++ b/docs/getting-started/first-kernel.rst
@@ -30,7 +30,7 @@ Let's break this down line by line.
Device pointers to GPU memory. These map to ``device float*`` in Metal.
``N``
- A runtime scalar — passed as a ``constant int&`` to the shader.
+ A runtime scalar, passed as a ``constant int&`` to the shader.
``BLOCK: metile.constexpr``
A **compile-time constant**. The value is baked directly into the shader. Changing it
@@ -75,7 +75,7 @@ Launching
``metile.Buffer``
Wraps a Metal buffer in unified memory. CPU and GPU share the same physical memory on
- Apple Silicon — there is no copy between host and device.
+ Apple Silicon, so there is no copy between host and device.
``metile.Buffer.zeros((N,))``
Allocates a zeroed buffer of ``N`` float32 elements.
@@ -100,6 +100,10 @@ When you call ``add[grid](...)``, meTile:
5. **Compiles** with ``xcrun metal -O2`` (or JIT if Xcode is unavailable)
6. **Dispatches** the compute pipeline on the GPU
+.. image:: /_static/compilation-pipeline.svg
+ :alt: meTile compilation pipeline: Python to Tile IR to Metal IR to MSL to GPU
+ :width: 100%
+
You can inspect any stage with the ``METILE_DEBUG`` environment variable:
.. code-block:: bash
@@ -112,6 +116,6 @@ You can inspect any stage with the ``METILE_DEBUG`` environment variable:
What's Next
-----------
-- :doc:`/guide/language` — full language reference for what you can write inside ``@metile.kernel``
-- :doc:`/examples/softmax` — a more complex kernel with reductions and multiple passes
-- :doc:`/examples/matmul` — tile-level matrix multiply with ``dot`` and ``tile_load``
+- :doc:`/guide/language` for the full language reference
+- :doc:`/examples/softmax` for a more complex kernel with reductions and multiple passes
+- :doc:`/examples/matmul` for tile-level matrix multiply with ``dot`` and ``tile_load``
diff --git a/docs/getting-started/install.rst b/docs/getting-started/install.rst
index 62f2551..0927668 100644
--- a/docs/getting-started/install.rst
+++ b/docs/getting-started/install.rst
@@ -5,7 +5,7 @@ Requirements
------------
- macOS 13 (Ventura) or later
-- Apple Silicon (M1, M2, M3, M4 — any variant)
+- Apple Silicon (M1 or later)
- Python 3.10+
Install
diff --git a/docs/guide/autotuning.rst b/docs/guide/autotuning.rst
index 3adb846..68d0f7a 100644
--- a/docs/guide/autotuning.rst
+++ b/docs/guide/autotuning.rst
@@ -52,6 +52,20 @@ On the first call with new key values, the autotuner:
Subsequent calls with the same key values use the cached winner with zero overhead.
+.. code-block:: text
+
+ First call (M=1024, N=1024, K=1024):
+ +--------------------------------------------------+
+ | Config(BM=64, BN=64, BK=32): 1.26ms |
+ | Config(BM=128, BN=128, BK=64): 0.62ms <-- | winner cached
+ | Config(BM=128, BN=128, BK=128): 0.91ms |
+ +--------------------------------------------------+
+
+ Subsequent calls (same M, N, K):
+ +--------------------------------------------------+
+ | cached -> Config(BM=128, BN=128, BK=64) | no re-tuning
+ +--------------------------------------------------+
+
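The tune-once, cache-the-winner behaviour in the diagram can be sketched in plain Python. This is a toy model under stated assumptions, not meTile's autotuner: ``TinyAutotuner``, its ``run`` callback, and the ``tuning_runs`` counter are hypothetical names introduced only for illustration.

```python
import time
from typing import Callable, Dict, Hashable, Sequence


class TinyAutotuner:
    """Toy autotuner: benchmark every config once per key, then reuse the winner."""

    def __init__(self, configs: Sequence[dict], run: Callable[..., None]):
        self.configs = list(configs)
        self.run = run                       # run(config, *args) executes one candidate
        self.cache: Dict[Hashable, dict] = {}
        self.tuning_runs = 0                 # number of benchmark executions so far

    def _time(self, config: dict, args: tuple) -> float:
        start = time.perf_counter()
        self.run(config, *args)
        self.tuning_runs += 1
        return time.perf_counter() - start

    def __call__(self, key: Hashable, *args) -> dict:
        if key not in self.cache:
            # First call for this key: time every candidate and keep the fastest.
            timings = [(self._time(c, args), i) for i, c in enumerate(self.configs)]
            _, best = min(timings)
            self.cache[key] = self.configs[best]
        winner = self.cache[key]
        self.run(winner, *args)              # the actual launch uses the cached config
        return winner


configs = [
    {"BM": 64, "BN": 64, "BK": 32},
    {"BM": 128, "BN": 128, "BK": 64},
    {"BM": 128, "BN": 128, "BK": 128},
]
tuner = TinyAutotuner(configs, run=lambda cfg, *args: None)
first = tuner((1024, 1024, 1024))    # benchmarks all three configs
second = tuner((1024, 1024, 1024))   # cache hit, no re-tuning
```

Keying the cache on the problem shape means a new ``(M, N, K)`` triggers one more round of benchmarking, while repeated shapes pay only a dictionary lookup.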
Config Object
-------------
@@ -117,6 +131,6 @@ fast dispatcher that skips all Python overhead on subsequent calls:
dispatch = autotuned_matmul[grid].prepare(A, B, C, M, N, K)
- # Hot loop — minimal Python overhead per call
+ # Hot path with minimal Python overhead
for _ in range(1000):
dispatch()
diff --git a/docs/guide/language.rst b/docs/guide/language.rst
index 18cb2f2..fb3a45b 100644
--- a/docs/guide/language.rst
+++ b/docs/guide/language.rst
@@ -2,7 +2,7 @@ Language Reference
==================
meTile provides a Python eDSL (embedded domain-specific language) for writing GPU kernels. Functions
-decorated with ``@metile.kernel`` are traced and compiled to Metal shaders — they are not executed
+decorated with ``@metile.kernel`` are traced and compiled to Metal shaders. They are not executed
as regular Python.
This page documents every construct available inside a ``@metile.kernel`` function.
@@ -19,9 +19,9 @@ Kernel Definition
Parameters are either:
-- **Pointers** — numpy arrays or ``metile.Buffer`` objects become ``device float*`` in Metal
-- **Scalars** — Python ints/floats become ``constant int&`` or ``constant float&``
-- **Constexprs** — annotated with ``metile.constexpr``, baked into the shader at compile time
+- **Pointers**: numpy arrays or ``metile.Buffer`` objects become ``device float*`` in Metal
+- **Scalars**: Python ints/floats become ``constant int&`` or ``constant float&``
+- **Constexprs**: annotated with ``metile.constexpr``, baked into the shader at compile time
Constexprs are passed as keyword arguments at launch:
diff --git a/docs/guide/memory.rst b/docs/guide/memory.rst
index 94ba84f..afe95f4 100644
--- a/docs/guide/memory.rst
+++ b/docs/guide/memory.rst
@@ -1,9 +1,13 @@
Memory Model
============
-Apple Silicon has a **unified memory architecture** — the CPU and GPU share the same physical
+Apple Silicon has a **unified memory architecture** where the CPU and GPU share the same physical
memory. meTile exposes this directly through ``metile.Buffer``.
+.. image:: /_static/unified-memory.svg
+ :alt: Unified memory: CPU and GPU both access the same physical memory through metile.Buffer
+ :width: 100%
+
Buffers
-------
@@ -13,7 +17,7 @@ Buffers
import numpy as np
import metile
- # Create from numpy (zero-copy — the GPU reads the same memory)
+ # Create from numpy (zero-copy, the GPU reads the same memory)
x = metile.Buffer(data=np.random.randn(1024).astype(np.float32))
# Allocate zeroed
@@ -71,6 +75,16 @@ memory access:
x = metile.load(X + offs, mask=mask) # masked-off lanes read 0
metile.store(Out + offs, x, mask=mask) # masked-off lanes are skipped
+.. code-block:: text
+
+ N = 10, BLOCK = 4, pid = 2 (last instance)
+
+ offs = [8, 9, 10, 11]
+ mask = [T, T, F, F] # values 10 and 11 are out of bounds
+
+ load: reads x[8], x[9], returns 0 for indices 10, 11
+ store: writes out[8], out[9], skips indices 10, 11
+
Masking is essential for correctness. Without it, the last program instance would read/write
past the end of the array.
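The same boundary case can be reproduced on the CPU. The sketch below emulates masked load and store with NumPy; ``masked_load`` and ``masked_store`` are illustrative helpers, not meTile functions, and they only mirror the semantics described above (masked-off lanes read 0 and are skipped on store).

```python
import numpy as np


def masked_load(x, offs, mask, other=0.0):
    """Read x[offs] where mask is True; masked-off lanes get `other`."""
    out = np.full(offs.shape, other, dtype=x.dtype)
    out[mask] = x[offs[mask]]
    return out


def masked_store(out, offs, values, mask):
    """Write values to out[offs] only where mask is True."""
    out[offs[mask]] = values[mask]


N, BLOCK, pid = 10, 4, 2                 # last program instance
x = np.arange(N, dtype=np.float32)
out = np.zeros(N, dtype=np.float32)

offs = pid * BLOCK + np.arange(BLOCK)    # [8, 9, 10, 11]
mask = offs < N                          # [True, True, False, False]

tile = masked_load(x, offs, mask)        # [8., 9., 0., 0.]
masked_store(out, offs, tile, mask)      # touches only out[8] and out[9]
```

Indexing ``x[offs]`` without the mask would raise ``IndexError`` in NumPy; on the GPU the equivalent access is an out-of-bounds read or write with no such safety net.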
@@ -85,5 +99,5 @@ For kernels that need inter-thread communication within a threadgroup, use share
buf = metile.shared(size=256, dtype="f32")
metile.barrier() # synchronize all threads in the threadgroup
-Shared memory is threadgroup-local — it is not visible to other threadgroups. Use
+Shared memory is threadgroup-local and not visible to other threadgroups. Use
``metile.barrier()`` to synchronize access within a threadgroup.
diff --git a/docs/guide/tile-ops.rst b/docs/guide/tile-ops.rst
index fd68120..7fbe675 100644
--- a/docs/guide/tile-ops.rst
+++ b/docs/guide/tile-ops.rst
@@ -12,13 +12,13 @@ The Two Backends
meTile automatically selects the best backend for your hardware when compiling GEMM kernels:
**Simdgroup Matrix (M1/M2/M3)**
- Uses ``simdgroup_matrix`` — Apple's 8x8 matrix multiply-accumulate
+ Uses ``simdgroup_matrix``, Apple's 8x8 matrix multiply-accumulate
primitive. Each simdgroup (32 threads) collaboratively computes an 8x8 tile.
The compiler tiles the output across multiple simdgroups and uses threadgroup
(shared) memory to stage data.
**Metal 4 Tensor Ops (M4+)**
- Uses ``matmul2d`` with ``cooperative_tensor`` — Metal 4's hardware matrix multiply
+ Uses ``matmul2d`` with ``cooperative_tensor``, Metal 4's hardware matrix multiply
descriptors. Each simdgroup independently loads data from device memory into
register-resident cooperative tensors and runs the MMA. No threadgroup memory needed.
@@ -29,20 +29,12 @@ your hardware and chooses the right path.
How Tiling Works
----------------
-A GEMM kernel tiles the computation into blocks:
+A GEMM kernel tiles the computation into blocks. Each program instance computes
+one output tile, iterating over K to accumulate partial products:
-.. code-block:: text
-
- Output C (M x N) Each tile is BLOCK_M x BLOCK_N
- ┌─────────┬─────────┐
- │ (0,0) │ (0,1) │ Each program instance computes one tile.
- │ 128x128 │ 128x128 │ The K dimension is tiled with BLOCK_K.
- ├─────────┼─────────┤
- │ (1,0) │ (1,1) │ grid = (ceil(M/BLOCK_M), ceil(N/BLOCK_N))
- │ 128x128 │ 128x128 │
- └─────────┴─────────┘
-
-Inside each tile, the K-loop accumulates partial results:
+.. image:: /_static/tiling-overview.svg
+ :alt: Output matrix tiled into blocks, with K-loop detail showing tile_load and dot accumulation
+ :width: 100%
.. code-block:: python
@@ -81,7 +73,11 @@ The tile sizes are compile-time constants that control how the hardware is used:
- 2, 4
``WM`` and ``WN`` control how many simdgroups tile the output block. With ``WM=4, WN=4``,
-16 simdgroups each handle a ``(BLOCK_M/WM) x (BLOCK_N/WN)`` = 32x32 subtile.
+16 simdgroups each handle a ``(BLOCK_M/WM) x (BLOCK_N/WN)`` = 32x32 subtile:
+
+.. image:: /_static/simdgroup-layout.svg
+ :alt: 4x4 simdgroup grid layout, 16 simdgroups each handling a 32x32 subtile
+ :width: 100%
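The subtile arithmetic can be checked with a few lines of Python. The block and warp-grid sizes come from the text above; the 1024 x 1024 problem size is an arbitrary example, the thread count assumes every simdgroup in the ``WM x WN`` grid is launched, and the 8x8 fragment count applies only to the ``simdgroup_matrix`` backend.

```python
import math

M, N = 1024, 1024                 # example problem size (arbitrary)
BLOCK_M, BLOCK_N = 128, 128       # output tile per program instance
WM, WN = 4, 4                     # simdgroup grid inside one tile
SIMD_WIDTH = 32                   # threads per simdgroup on Apple GPUs
FRAG = 8                          # simdgroup_matrix fragment edge

# One program instance per output tile.
grid = (math.ceil(M / BLOCK_M), math.ceil(N / BLOCK_N))

num_simdgroups = WM * WN
sub_m, sub_n = BLOCK_M // WM, BLOCK_N // WN
threads_per_threadgroup = num_simdgroups * SIMD_WIDTH
frags_per_subtile = (sub_m // FRAG) * (sub_n // FRAG)

print(f"grid={grid}, simdgroups={num_simdgroups}, subtile={sub_m}x{sub_n}")
print(f"threads/threadgroup={threads_per_threadgroup}, 8x8 fragments/subtile={frags_per_subtile}")
```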
Fused Epilogues
@@ -94,7 +90,7 @@ and fuses them into the kernel. No extra memory traffic:
acc = metile.dot(a, b, acc)
- # These are fused into the GEMM — no global memory round-trip
+ # These are fused into the GEMM, no global memory round-trip
acc = metile.where(acc > 0, acc, 0) # ReLU
acc = acc * scale # scale
acc = metile.exp(acc) # unary
@@ -109,14 +105,14 @@ Tile Scheduling
For 2D grids, the order in which tiles are assigned to threadgroups affects L2 cache locality.
meTile supports several scheduling patterns:
-**Morton (Z-order)** — default
- Tiles are assigned in 2x2 blocks following a Z-curve. Adjacent threadgroups share
- A-row and B-column data in L2 cache.
+.. image:: /_static/morton-swizzle.svg
+ :alt: Morton Z-order vs linear tile scheduling, showing how 2x2 blocks share L2 cache
+ :width: 100%
-**Diagonal**
+**Diagonal**
Column assignment is rotated by the row index. Distributes memory traffic.
-**Linear**
+**Linear**
Simple row-major assignment. No locality optimization.
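To make the difference between Morton and linear ordering concrete, here is a small sketch that maps a linear threadgroup index to a tile coordinate by de-interleaving its bits. It assumes the common Z-order convention (even bits select the column, odd bits the row); meTile's actual swizzle may use a different convention, and ``morton_decode`` is a hypothetical helper.

```python
def morton_decode(index: int) -> tuple[int, int]:
    """De-interleave the bits of `index`: even bits -> column, odd bits -> row."""
    row = col = 0
    bit = 0
    while index >> (2 * bit):
        col |= ((index >> (2 * bit)) & 1) << bit
        row |= ((index >> (2 * bit + 1)) & 1) << bit
        bit += 1
    return row, col


GRID = 4  # 4x4 grid of output tiles

morton_order = [morton_decode(i) for i in range(GRID * GRID)]
linear_order = [divmod(i, GRID) for i in range(GRID * GRID)]

print("morton:", morton_order[:8])
print("linear:", linear_order[:8])
```

In Morton order the first four threadgroups cover a 2x2 block of tiles, so they reuse two rows of A and two columns of B; in linear order the first four cover a single row and share only one row of A.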
The compiler applies Morton scheduling by default. You can override it: