MLX support #8

@kyr0

Description

As commented on LinkedIn, I'm working on an MLX port of this project - mainly to speed up the local models that I'm running.

My prior work regarding the MLX/ANE platform spans:

  • a custom Bonsai inference server for Qwen3-8B with just a 1.1 GB memory footprint + specific KV quantization, including support for custom speculative decoding with a working TurboQuant integration; 40-50 t/s on a MacBook Air M4 (24 GB) while running in the background alongside many other processes
  • contributions to mlx-audio, fixing Qwen3-TTS implementation issues
  • several custom ANE FluidAudio ASR/TTS inference server implementations (to be released)
  • my typical coding/working style (C++): https://github.com/kyr0/libsharedmemory

I'm opening this issue for:

a) reporting/tracking progress before opening a PR (keeping the noise level low)
b) asking the community for feature requests
c) checking if anyone would be interested to collaborate
d) asking the maintainer(s) whether you have any concrete ideas about what cross-backend support should look like

Broader Motivation/Context:

I'm working on this in my spare time (I'm currently on vacation). My mid-term goal is to build a really good local voice agent: a generic, multilingual, intelligent real-time audio (TTS + ASR) / video interface on Mac, using a multi-backend inference approach (clever resource management/scheduling to limit memory bandwidth, plus smart timing). Because of this, I need to speed up the MLX models that are suitable for the task. My interest is therefore a bit narrowly focused, but since I'm targeting MLX kernels here, the outcome should still generalize. In parallel, I'm also working on improving model quantization (a separate project; a research paper on optimal Gemma 4 quantization is currently pending). To get the maximum out of any new model until the goal is reached, I'm also checking every corner of the inference pipeline where I could improve cache quantization and maximize throughput (see the Bonsai server).

Implementation plan:

  • I froze my MLX fork, including 1-bit inference support for PrinsmML/Bonsai-8B, commit ref
  • I'm adding an MLX backend to AutoKernel instead of trying to force the current Triton/CUDA flow onto MLX
  • Making the optimization unit a target manifest, not a single editable kernel file
  • I created calibration files with prompt/response pairs and use only three frozen replay/calibration packs:
    • calibration_gemma4_e4b.json
    • calibration_gemma4_e2b.json
    • calibration_bonsai_8b.json
  • I'm running real inference for the matching target model against its fixed pack
  • Treating the stored response, tokens, time_s, and tok_per_sec fields as the frozen baseline for replay validation and throughput comparison (IMO, AutoKernel should not optimize for theoretical results, but for real-world results)
  • As almost all models that users run in practice are quantized, I started with explicit pipeline support for:
    • quantized.cpp
    • quantized.h / quantized.metal
    • quantized_nax.h / quantized_nax.metal
    • quantized_utils.h
    • scaled_dot_product_attention.cpp
    • sdpa_vector.h
    • scaled_dot_product_attention.metal
  • I'm currently building bench_mlx.py on top of public MLX ops only:
    • mlx.core.quantized_matmul
    • mlx.core.fast.scaled_dot_product_attention
    • mlx.core.fast.rms_norm
    • mlx.core.fast.rope
  • I'll add verify_mlx.py for fixed-pack replay; it must run prompt-by-prompt inference and record:
    • decoded continuation
    • generated token count
    • wall time
    • tok/s
    • memory / TTFT if available
  • I'll then add rebuild automation so every experiment can:
    • patch
    • rebuild
    • microbench
    • keep or revert
  • Then record the actual dispatch path for every run
  • Optimization, IMO, needs to go in this order:
    1. quantized decode dispatch (qmv_fast / qmv / qmv_quad)
    2. quantized prefill dispatch (qmm_t / qmm_n / NAX-gated paths)
    3. qdot / qdot_safe
    4. QMM loader + GEMM loop pipeline
    5. sdpa_vector / 2-pass SDPA
    6. RMSNorm
    7. single-token RoPE
    8. qvm_split_k
  • I'm treating standalone softmax, LayerNorm, GEMV-masked, and Hadamard as secondary until traces prove otherwise
  • AutoKernel-MLX "Agentic DoD" will require every kept change to pass:
    • compile
    • op-level correctness
    • determinism
    • fixed-pack replay on the relevant model
    • no unacceptable response drift
    • throughput improvement or trace-backed hot-path win
  • It will capture Metal traces only on kept/promoted candidates, using the same fixed prompt packs
  • Future research (V2): After a single-kernel path is solid, it will move to the higher-ROI work: reducing Metal dispatch overhead and fusing decode-path ops where it actually pays off
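To make the microbench step concrete, here's a minimal sketch of the timing harness I have in mind for bench_mlx.py. The `median_time_s` helper name and the warmup/repeat counts are assumptions, not a final API; on-device it would wrap the public MLX ops listed above (e.g. mlx.core.quantized_matmul), but a plain callable stands in here so the harness itself stays backend-agnostic:

```python
import time
import statistics

def median_time_s(op, *args, warmup=3, repeats=10):
    """Time a callable; report the median over several repeats to
    suppress scheduler noise. For MLX ops, op must force evaluation
    (e.g. via mx.eval(...)) so lazy graphs don't make the timings
    meaningless -- that detail is assumed, not shown, here."""
    for _ in range(warmup):          # warm caches / JIT before measuring
        op(*args)
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        op(*args)
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)
```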
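For the fixed-pack replay in verify_mlx.py, the core check could look like the sketch below: rerun each prompt and compare against the frozen response / tokens / time_s / tok_per_sec fields stored in the calibration pack. The pack schema, the `run_inference` hook, and the keep criterion are illustrative assumptions, not the final format:

```python
import json

def replay_pack(pack_path, run_inference, min_speedup=1.0):
    """Replay every prompt in a frozen calibration pack against the
    stored baseline. run_inference(prompt) -> (response, tokens, time_s)
    is a hypothetical hook into the real model; the field names follow
    the pack description in this issue."""
    with open(pack_path) as f:
        pack = json.load(f)
    results = []
    for entry in pack["entries"]:
        response, tokens, time_s = run_inference(entry["prompt"])
        tok_per_sec = tokens / time_s if time_s > 0 else 0.0
        match = response == entry["response"]        # response-drift check
        results.append({
            "prompt": entry["prompt"],
            "response_match": match,
            "speedup": tok_per_sec / entry["tok_per_sec"],
            "kept": match and tok_per_sec >= min_speedup * entry["tok_per_sec"],
        })
    return results
```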
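The patch / rebuild / microbench / keep-or-revert cycle can be sketched as a small driver. All four hooks are hypothetical callables; in the real automation "revert" would be a git checkout plus rebuild, and the baseline would come from the frozen pack:

```python
def run_experiment(patch, rebuild, microbench, revert, baseline_s):
    """One experiment cycle: apply a candidate kernel patch, rebuild,
    microbench, and keep the change only if it beats the baseline time.
    patch/rebuild/microbench/revert are illustrative hooks; baseline_s
    is the frozen baseline time for the benched op, in seconds."""
    patch()
    if not rebuild():                # compile gate: a broken build is reverted
        revert()
        return {"kept": False, "reason": "build failed"}
    t = microbench()
    if t < baseline_s:
        return {"kept": True, "time_s": t, "speedup": baseline_s / t}
    revert()                         # no win -> roll the patch back
    return {"kept": False, "reason": "no speedup", "time_s": t}
```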
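The "Agentic DoD" above can be expressed as a hard gate over per-candidate check results; the dict keys mirror the checklist and are naming assumptions on my side, not an existing AutoKernel interface:

```python
# Hard gates every kept change must pass (mirrors the DoD checklist).
REQUIRED_GATES = (
    "compile", "op_correct", "deterministic",
    "replay_pass", "no_response_drift",
)

def dod_pass(checks):
    """A candidate passes the DoD only if all hard gates hold AND it
    shows either a throughput improvement or a trace-backed hot-path
    win. Missing keys count as failed checks."""
    if not all(checks.get(g, False) for g in REQUIRED_GATES):
        return False
    return checks.get("throughput_win", False) or checks.get("hot_path_win", False)
```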

WDYT?

Edits: mostly grammar and some wording
