
feat(hstu): enable HSTU + DynamicEmb E2E training on Blackwell (sm_100)#399

Draft
JacoCheung wants to merge 18 commits into NVIDIA:main from JacoCheung:junzhang/hstu_blackwell_enable

Collaborator

@JacoCheung JacoCheung commented May 19, 2026

Summary

Enables HSTU + DynamicEmb E2E training on NVIDIA Blackwell (sm_100). Three commits:

  1. HSTU attention kernel on sm_100 — bump FBGEMM submodule so Path A hstu_varlen_fwd_80/90 resolves on Blackwell; bump cutlass-dsl to the 4.4.1 line (earliest one whose AST preprocessor accepts hstu_blackwell relative imports); route sm10 attention through FusedHSTUAttention.
  2. E2E training pipeline on sm_100 — add 10.0 to TORCH_CUDA_ARCH_LIST for FBGEMM Path A / dynamicemb / commons / kvcache_manager; nvcomp tarball arch select for arm64 vs x86; slurm_job.sub exports RANK/WORLD_SIZE so torchrec's local-size detection doesn't fall back on multi-tray; arm64 skips inference-only FlexKV / NVE.
  3. Base image 26.02 → 26.04 — nvcr pytorch 26.04 ships cutlass-dsl 4.4.1 directly, so the in-Dockerfile uninstall/reinstall drops out. FBGEMM submodule pinned to immutable tag. TorchRec install bundled into the same Dockerfile stage as FBGEMM (FBGEMM C++ ops version-coupled); FBGEMM stage uses python3 not python so tritonserver base (no python alias) also builds.
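The arch-dependent nvcomp tarball selection in item 2 could be sketched as below. The `linux-sbsa` and `linux-x86_64` suffixes come from the commit message in this PR; the function name and the lookup table are hypothetical, and the actual Dockerfile logic is shell rather than Python.

```python
import platform

# Platform suffixes named in the commit message; the mapping itself is illustrative.
_NVCOMP_PLATFORM = {
    "aarch64": "linux-sbsa",
    "arm64": "linux-sbsa",
    "x86_64": "linux-x86_64",
    "amd64": "linux-x86_64",
}


def nvcomp_platform(machine=None):
    """Return the nvcomp tarball platform suffix for a machine type (hypothetical helper)."""
    machine = (machine or platform.machine()).lower()
    try:
        return _NVCOMP_PLATFORM[machine]
    except KeyError:
        raise ValueError(f"unsupported machine type: {machine!r}") from None
```

Failing loudly on an unknown machine type is a deliberate choice here: silently defaulting to x86_64 would presumably surface later as a link error, as happened on GB200.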

Verified

test_pipeline -k "contextual_feature_names1" passes on:

| Arch | Result |
| --- | --- |
| arm64 (GB200) | 8 passed |
| x86 (B200) | 8 passed |

Covers DynamicEmb on/off × prefetch/native sparse pipeline × max_num_candidates 0/10, bf16, checkpoint save/load. The -k filter excludes the context-mask cases the sm_100 kernel doesn't yet support.

CI

Companion CI MR (multi-arch base_fbgemm + devel manifests, plus B200 smoke job, separate tritonserver FBGEMM cache): !130
(internal: https://gitlab-master.nvidia.com/Devtech-Compute/distributed-recommender/-/merge_requests/130 )

Test plan

  • CI passes on existing arch (A100 / H100 / L20).
  • B200 smoke job in MR !130 passes on landing.

Contributor

greptile-apps Bot commented May 19, 2026

Greptile Summary

This PR enables HSTU + DynamicEmb end-to-end training on NVIDIA Blackwell (sm_100) across three commits: routing sm10 through FusedHSTUAttention via a bumped FBGEMM submodule, adding 10.0 to all CUDA arch lists and fixing multi-node SLURM rank detection, and upgrading the base image to nvcr pytorch 26.04 with a restructured three-stage Dockerfile.

  • sm_100 kernel routing (fused_hstu_op.py, hstu_attention.py): The forward uses hstu.hstu_attn_varlen_func (FBGEMM Path A, now resolving on sm_100 after the submodule bump); the backward dispatches to hstu_varlen_bwd_100 for head_dim ∈ {64, 128} and falls back to a pytorch_hstu_mha re-computation for other head dims. Prior review threads flag the missing head_dim guard in the forward path and the broken pass/return backward guard in test_fbgemm_hstu_smoke.py.
  • Dockerfile refactor (3-stage: base_fbgemm → devel → build): Allows pre-built FBGEMM to be injected via BASE_FBGEMM_IMAGE, adds arm64-aware nvcomp tarball selection, and removes the nvidia-cutlass-dsl reinstall block now that 26.04 ships 4.4.1 directly.
  • cal_hstu_flops refactor (perf.py): Replaces the gather-to-rank-0 pattern with a per-rank local computation followed by all_reduce, so every rank now receives the total FLOPS — used correctly in training.py via print_rank_0.
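
To illustrate the difference between the two patterns, the sketch below simulates them in plain Python, with a list standing in for the ranks. It does not use torch.distributed; the function names and the per-rank numbers are made up for illustration, and modelling "no total" as `None` is an assumption about the old behaviour.

```python
def gather_to_rank0(local_flops):
    """Old pattern: only rank 0 ends up holding the total (others modelled as None)."""
    total = sum(local_flops)
    return [total if rank == 0 else None for rank in range(len(local_flops))]


def all_reduce_sum(local_flops):
    """New pattern: every rank receives the same summed value."""
    total = sum(local_flops)
    return [total for _ in local_flops]


# Made-up local FLOP counts, one entry per rank.
per_rank = [1.0e12, 2.0e12, 1.5e12, 0.5e12]
```

With the second pattern any rank can log the total, which is why a rank-0-only print remains correct without special handling on the other ranks.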

Confidence Score: 3/5

The sm_100 forward/backward routing in fused_hstu_op.py has multiple open correctness issues identified across prior review rounds that are not yet resolved in this diff.

Several defects flagged in earlier rounds remain unaddressed: the backward guard in test_fbgemm_hstu_smoke.py uses pass instead of return, so the Blackwell backward runs unconditionally; the sm8/sm9 head_dim and contextual parametrization was collapsed, silently removing non-Blackwell coverage; the forward on sm10 lacks a head_dim guard while the backward has one, so a CUTLASS forward run with head_dim=32 would get its gradients from pytorch_hstu_mha (a different computation); and slurm_job.sub requests 4 GPUs per node for 8 tasks per node. The core training path (fused_hstu_op, dynamicemb, commons) is otherwise structurally sound.

examples/hstu/test/hstu_attn/test_fbgemm_hstu_smoke.py (backward guard, parametrization), examples/hstu/ops/fused_hstu_op.py (head_dim guard asymmetry), examples/hstu/training/benchmark/scripts/slurm_job.sub (GPU/task count)

Important Files Changed

| Filename | Overview |
| --- | --- |
| examples/hstu/ops/fused_hstu_op.py | Adds sm10 forward and backward paths; forward uses hstu.hstu_attn_varlen_func (FBGEMM Path A), backward uses hstu_varlen_bwd_100 for head_dim in (64,128) and falls back to pytorch_hstu_mha otherwise. This asymmetry and the missing head_dim guard in the forward are flagged in prior review threads. |
| docker/Dockerfile | Splits into three stages (base_fbgemm, devel, build); adds sm_100 arch lists; handles arm64 vs x86 nvcomp and NVE/AVX512 correctly; removes build-time import validation tests, reducing silent-failure visibility. |
| examples/hstu/test/hstu_attn/test_fbgemm_hstu_smoke.py | Adds Blackwell smoke tests but collapses parametrization to head_dim=[64] and max_num_contextuals=[0] for all architectures (sm8/sm9 coverage lost), and the backward guard on sm10 uses pass instead of return, running backward unconditionally. Both are flagged in prior threads. |
| examples/hstu/training/benchmark/scripts/slurm_job.sub | Exports RANK/WORLD_SIZE/LOCAL_RANK/LOCAL_WORLD_SIZE for torchrec; adds --gpus-per-node=4 while --ntasks-per-node=8 is still set two lines earlier. The GPU/task mismatch is flagged in a prior thread. |
| examples/commons/utils/perf.py | Adds GB200/B200 specs; refactors cal_hstu_flops from gather-to-rank-0 to local-compute + all_reduce; GB200 is correctly ordered before B200 in all lookup dicts to prevent substring mismatches; behavior change (all ranks now return total FLOPS) is handled correctly in training.py. |
| examples/hstu/modules/hstu_attention.py | Routes sm10 through FusedHSTUAttention; missing num_contextuals guard on Blackwell in FusedHSTUAttention.forward flagged as outside-diff comment. |
| examples/hstu/training/trainer/training.py | Adds _warm_up_data_parallel_collective() before the training loop to pre-initialize NCCL for the DP group; implementation is correct and guards on dist.is_initialized(). |
| examples/tests/commons/test_perf.py | New test file validating GB200 and B200 peak TFLOPS lookups, including a test that explicitly verifies GB200 is matched before B200 (substring collision guard). |
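
The GB200-before-B200 ordering matters because "B200" is a substring of "GB200". A minimal sketch of the collision, assuming a substring-based lookup over an insertion-ordered dict; the TFLOPS values are placeholders (not real hardware specs), and the device-name strings and function name are hypothetical:

```python
# Placeholder numbers, NOT real hardware specs.
PEAK_TFLOPS = {
    "GB200": 2.0,  # must come first: "B200" is a substring of "GB200"
    "B200": 1.0,
}


def lookup_peak(device_name, table=PEAK_TFLOPS):
    """Return (key, value) for the first key contained in device_name."""
    # dicts preserve insertion order, so the first matching key wins.
    for key, value in table.items():
        if key in device_name:
            return key, value
    raise KeyError(device_name)
```

With the entries reversed, a name containing "GB200" would match the "B200" key first and report the wrong peak, which is the regression the new test is described as guarding against.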

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[FusedHSTULayerFunction.forward/backward] --> B{sm_major_version?}
    B -->|8| C[torch.ops.fbgemm.hstu_varlen_fwd_80 / hstu_varlen_bwd_80]
    B -->|9| D[torch.ops.fbgemm.hstu_varlen_fwd_90 / hstu_varlen_bwd_90]
    B -->|10| E{head_dim in 64,128?}
    B -->|other| F[ValueError: Unsupported SM]
    E -->|Forward: always| G[hstu.hstu_attn_varlen_func FBGEMM Path A]
    E -->|Backward: yes| H[sm100_ops.hstu_varlen_bwd_100 Blackwell DSL backward]
    E -->|Backward: no| I[pytorch_hstu_mha re-compute grad]
    G --> J[_blackwell_num_contexts_or_none raises if contextuals != 0]
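
The routing in the flowchart can be restated as a small dispatch function. This is only a sketch of the decision logic: the kernel names are copied from the flowchart, while the function itself and its return shape are hypothetical and do not call any FBGEMM op.

```python
def select_hstu_kernels(sm_major, head_dim):
    """Return (forward, backward) kernel names following the flowchart above."""
    if sm_major == 8:
        return ("hstu_varlen_fwd_80", "hstu_varlen_bwd_80")
    if sm_major == 9:
        return ("hstu_varlen_fwd_90", "hstu_varlen_bwd_90")
    if sm_major == 10:
        # The forward has no head_dim branch; only the backward does.
        if head_dim in (64, 128):
            backward = "hstu_varlen_bwd_100"
        else:
            backward = "pytorch_hstu_mha"
        return ("hstu_attn_varlen_func", backward)
    raise ValueError(f"Unsupported SM major version: {sm_major}")
```

Written this way, the asymmetry is easy to see: for sm_major 10 and head_dim 32 the forward and backward entries refer to different implementations.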

Reviews (16): Last reviewed commit: "fix(training): reduce HSTU flops over DP..."

Comment on lines +200 to +218
if major == 10:
    # hstu_blackwell backward (hstu_varlen_bwd_100) currently raises
    # "stride_order is not consistent with the layout" inside cutlass DSL
    # for the contiguous (L, H, D) tensors this test produces. Until the
    # upstream sm100 backward kernel handles this stride pattern, skip
    # the backward comparison on Blackwell.
    pass
    # return

dout = torch.rand_like(new_out)
new_out.backward(dout)
ref_out.backward(dout)
ref_out_fp32.backward(dout.float())
torch.cuda.synchronize()

assert_hstu_close(nq.grad, ref_q.grad, ref_q_fp32.grad, fwd=False)
assert_hstu_close(nk.grad, ref_k.grad, ref_k_fp32.grad, fwd=False)
assert_hstu_close(nv.grad, ref_v.grad, ref_v_fp32.grad, fwd=False)
print(f"[BWD] sm{arch_sm} head_dim={head_dim} ctx={max_num_contextuals} PASS")
# ref_out.backward(dout)
# ref_out_fp32.backward(dout.float())
# torch.cuda.synchronize()

# assert_hstu_close(nq.grad, ref_q.grad, ref_q_fp32.grad, fwd=False)
# assert_hstu_close(nk.grad, ref_k.grad, ref_k_fp32.grad, fwd=False)
# assert_hstu_close(nv.grad, ref_v.grad, ref_v_fp32.grad, fwd=False)
# print(f"[BWD] sm{arch_sm} head_dim={head_dim} ctx={max_num_contextuals} PASS")
Contributor

P1 The backward guard for Blackwell is broken: pass is a no-op, so the commented-out # return is never executed and new_out.backward(dout) runs unconditionally on sm10. The comment directly above explains this triggers "stride_order is not consistent with the layout" in the cutlass DSL backward kernel. The fix is to replace pass with return so execution halts before the backward call. Additionally, all backward-vs-reference assertions are commented out for every architecture (sm8/sm9 too), silently removing gradient correctness coverage from the test suite.

Suggested change
- if major == 10:
-     # hstu_blackwell backward (hstu_varlen_bwd_100) currently raises
-     # "stride_order is not consistent with the layout" inside cutlass DSL
-     # for the contiguous (L, H, D) tensors this test produces. Until the
-     # upstream sm100 backward kernel handles this stride pattern, skip
-     # the backward comparison on Blackwell.
-     pass
-     # return
- dout = torch.rand_like(new_out)
- new_out.backward(dout)
- ref_out.backward(dout)
- ref_out_fp32.backward(dout.float())
- torch.cuda.synchronize()
- assert_hstu_close(nq.grad, ref_q.grad, ref_q_fp32.grad, fwd=False)
- assert_hstu_close(nk.grad, ref_k.grad, ref_k_fp32.grad, fwd=False)
- assert_hstu_close(nv.grad, ref_v.grad, ref_v_fp32.grad, fwd=False)
- print(f"[BWD] sm{arch_sm} head_dim={head_dim} ctx={max_num_contextuals} PASS")
- # ref_out.backward(dout)
- # ref_out_fp32.backward(dout.float())
- # torch.cuda.synchronize()
- # assert_hstu_close(nq.grad, ref_q.grad, ref_q_fp32.grad, fwd=False)
- # assert_hstu_close(nk.grad, ref_k.grad, ref_k_fp32.grad, fwd=False)
- # assert_hstu_close(nv.grad, ref_v.grad, ref_v_fp32.grad, fwd=False)
- # print(f"[BWD] sm{arch_sm} head_dim={head_dim} ctx={max_num_contextuals} PASS")
+ if major == 10:
+     # hstu_blackwell backward (hstu_varlen_bwd_100) currently raises
+     # "stride_order is not consistent with the layout" inside cutlass DSL
+     # for the contiguous (L, H, D) tensors this test produces. Until the
+     # upstream sm100 backward kernel handles this stride pattern, skip
+     # the backward comparison on Blackwell.
+     return
+ dout = torch.rand_like(new_out)
+ new_out.backward(dout)
+ ref_out.backward(dout)
+ ref_out_fp32.backward(dout.float())
+ torch.cuda.synchronize()
+ assert_hstu_close(nq.grad, ref_q.grad, ref_q_fp32.grad, fwd=False)
+ assert_hstu_close(nk.grad, ref_k.grad, ref_k_fp32.grad, fwd=False)
+ assert_hstu_close(nv.grad, ref_v.grad, ref_v_fp32.grad, fwd=False)
+ print(f"[BWD] sm{arch_sm} head_dim={head_dim} ctx={max_num_contextuals} PASS")
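
The pass-versus-return difference can be reproduced in isolation. The two toy functions below are hypothetical stand-ins for the test body; the boolean result stands in for whether `new_out.backward(dout)` would have been reached.

```python
def backward_ran_with_pass(major):
    """Guard written with `pass`: the branch does nothing, so execution continues."""
    ran = False
    if major == 10:
        pass  # no-op; a commented-out `# return` below it has no effect
    ran = True  # stands in for new_out.backward(dout)
    return ran


def backward_ran_with_return(major):
    """Guard written with `return`: the function exits before the backward call."""
    if major == 10:
        return False
    return True  # stands in for new_out.backward(dout)
```

In a pytest test, calling `pytest.skip(...)` inside the guard is another option, since it records the case as skipped rather than passed.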

Comment on lines 21 to 23
#SBATCH --exclusive
#SBATCH --gpus-per-node=4
#SBATCH --output=hstu-e2e-benchmark-%j.out
Contributor

P2 --gpus-per-node=4 conflicts with --ntasks-per-node=8 set two lines above. SLURM will allocate only 4 GPUs per node for 8 tasks, so two processes will share each GPU. Distributed training jobs typically require one GPU per task; this mismatch would cause each rank to see the same CUDA_VISIBLE_DEVICES slot as a sibling rank and likely produce incorrect DynamicEmb placement or OOM. If this is intentional for a specific topology (e.g. NVL72 with a non-default GPU-per-node count), please add a comment explaining the expected hardware layout.

Suggested change
- #SBATCH --exclusive
- #SBATCH --gpus-per-node=4
- #SBATCH --output=hstu-e2e-benchmark-%j.out
+ #SBATCH --exclusive
+ #SBATCH --gpus-per-node=8
+ #SBATCH --output=hstu-e2e-benchmark-%j.out
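
A mismatch like this can be caught mechanically before submission. The sketch below parses `#SBATCH` directives from a script and reports the tasks-per-GPU ratio; the helper names are hypothetical, and the script fragment is an illustrative excerpt, not the full slurm_job.sub.

```python
import re


def sbatch_options(script_text):
    """Collect '#SBATCH --key[=value]' directives into a dict (flags map to None)."""
    opts = {}
    for line in script_text.splitlines():
        m = re.match(r"\s*#SBATCH\s+--([\w-]+)(?:=(\S+))?", line)
        if m:
            opts[m.group(1)] = m.group(2)
    return opts


def tasks_per_gpu(script_text):
    """Ratio of tasks to GPUs on each node; 1.0 means one GPU per task."""
    opts = sbatch_options(script_text)
    return int(opts["ntasks-per-node"]) / int(opts["gpus-per-node"])


# Illustrative fragment with the values discussed in this thread.
SCRIPT = """\
#SBATCH --ntasks-per-node=8
#SBATCH --exclusive
#SBATCH --gpus-per-node=4
"""
```

For the fragment above the ratio is 2.0, meaning two ranks per GPU; with `--gpus-per-node=8` it becomes 1.0.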

@JacoCheung JacoCheung force-pushed the junzhang/hstu_blackwell_enable branch 5 times, most recently from f3ab011 to 7612cfe on May 20, 2026 03:09
Comment on lines +43 to +46
  @pytest.mark.parametrize("head_dim", [64])
  @pytest.mark.parametrize("num_heads", [4])
  @pytest.mark.parametrize("max_num_targets", [10])
- @pytest.mark.parametrize("max_num_contextuals", [0, 4])
+ @pytest.mark.parametrize("max_num_contextuals", [0])
Contributor

P1 sm8/sm9 test coverage silently lost

The parametrization was collapsed from head_dim=[32, 64, 128] / max_num_contextuals=[0, 4] to head_dim=[64] / max_num_contextuals=[0] to avoid triggering Blackwell kernel assertions. However, this collapses the parameter space for ALL architectures: on sm8 and sm9, head_dim=32, head_dim=128, and max_num_contextuals=4 were previously tested and are now gone from CI. Future regressions in Ampere/Hopper forward or gradient correctness for those input shapes will pass silently.

The sm10-specific skips (if major == 10: pytest.skip(...)) are the right mechanism to exclude unsupported cases on Blackwell while keeping the full matrix for older GPUs. Restoring the original [32, 64, 128] / [0, 4] parametrization and relying on those guards would preserve the existing coverage.
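
One way to keep the full matrix while excluding only the Blackwell-unsupported cases is a small predicate that returns a skip reason. This is a sketch under the constraints stated in this PR (sm_100 supports head_dim 64 and 128 and no contextual features); the function names are hypothetical, and in the real test the reason would be passed to `pytest.skip`.

```python
def blackwell_skip_reason(sm_major, head_dim, max_num_contextuals):
    """Return a skip reason for cases sm_100 does not support, else None."""
    if sm_major != 10:
        return None  # Ampere/Hopper keep the full matrix
    if head_dim not in (64, 128):
        return f"sm_100 kernel does not support head_dim={head_dim}"
    if max_num_contextuals > 0:
        return "sm_100 kernel does not support contextual features"
    return None


# The original parametrization: head_dim in [32, 64, 128], max_num_contextuals in [0, 4].
FULL_MATRIX = [(hd, ctx) for hd in (32, 64, 128) for ctx in (0, 4)]


def cases_run(sm_major):
    """Cases from the full matrix that would actually execute on this architecture."""
    return [case for case in FULL_MATRIX if blackwell_skip_reason(sm_major, *case) is None]
```

On sm8/sm9 all six combinations run; on sm10 only the two supported ones do, and the rest show up as skipped instead of disappearing from the matrix.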

Comment on lines +367 to +373
elif sm_major_version == 10:
    assert q.dtype in (
        torch.bfloat16,
        torch.float16,
    ), f"Blackwell fwd expects bfloat16 or float16, got {q.dtype}"
    num_contexts = _blackwell_num_contexts_or_none(num_contexts)
    jagged_attn_output = hstu.hstu_attn_varlen_func(
Contributor

P1 Missing head_dim guard in the sm10 forward path creates an asymmetry with the backward. The backward explicitly checks q.shape[-1] in (64, 128) and falls back to pytorch_hstu_mha for other head_dims. If the CUTLASS forward kernel accepts an unsupported head_dim (e.g., 32) without raising — even if it silently misfires — the backward will compute gradients of pytorch_hstu_mha, not of the actual CUTLASS forward computation, violating the autograd.Function contract and producing wrong gradients. Adding the same guard in the forward ensures a clear ValueError rather than a cryptic kernel assertion and keeps forward/backward consistent.

Suggested change
- elif sm_major_version == 10:
-     assert q.dtype in (
-         torch.bfloat16,
-         torch.float16,
-     ), f"Blackwell fwd expects bfloat16 or float16, got {q.dtype}"
-     num_contexts = _blackwell_num_contexts_or_none(num_contexts)
-     jagged_attn_output = hstu.hstu_attn_varlen_func(
+ elif sm_major_version == 10:
+     assert q.dtype in (
+         torch.bfloat16,
+         torch.float16,
+     ), f"Blackwell fwd expects bfloat16 or float16, got {q.dtype}"
+     if q.shape[-1] not in (64, 128):
+         raise ValueError(
+             f"Blackwell fwd only supports head_dim in (64, 128), got {q.shape[-1]}"
+         )
+     num_contexts = _blackwell_num_contexts_or_none(num_contexts)
+     jagged_attn_output = hstu.hstu_attn_varlen_func(

@JacoCheung
Collaborator Author

pytorch-26.04 has a Triton bug. We need a newer base image (26.05), which is not publicly available.

@JacoCheung JacoCheung marked this pull request as draft May 26, 2026 01:26
@JacoCheung
Collaborator Author

Another constraint is that Blackwell does not support contextual features. Waiting for #395.

@JacoCheung JacoCheung force-pushed the junzhang/hstu_blackwell_enable branch from 0d66e6d to ca5ef37 on May 27, 2026 09:52
Path B fbgemm_gpu_hstu sm_100 forward+backward smoke validated on GB200:
1 passed in 26.83s
  pytest -xvs examples/hstu/test/hstu_attn/test_fbgemm_hstu_smoke.py
  → test_fbgemm_hstu_fwd_bwd[0-10-4-64-200-32]

Minimal changes:

- third_party/FBGEMM: 65bad42a → 5f13f139
  Branch jiayus-nvidia/FBGEMM:junzhang/dev_with_export_and_patch, off upstream
  dev with two cherry-picks from torch_exportable_cuda13 (setup.py fix +
  register_fake for torch.export) and one new commit dropping the
  load_library < (10,0) gate so register_fake can resolve
  fbgemm::hstu_varlen_fwd_{80,90} on Blackwell. hstu_blackwell/* kept on
  relative imports so the same source installs under both Path A
  (fbgemm_gpu.experimental.hstu.*) and Path B (top-level hstu.*).

- docker/Dockerfile: bump nvidia-cutlass-dsl 4.3.0 → 4.4.1 (uninstall + fresh
  install, since 4.x layout changes break a straight upgrade). 4.4.1 is the
  earliest version whose AST preprocessor handles `from . import X` —
  verified by reading ast_preprocessor.py source from both wheels.

- docker/Dockerfile: HSTU_ARCH_LIST 8.0 9.0 → 8.0 9.0 10.0 (Path B build).

- examples/hstu/modules/hstu_attention.py: route sm10 through
  FusedHSTUAttention (cutlass backend) instead of the print-and-fallback
  TorchHSTUAttention path.

- test_fbgemm_hstu_smoke.py: skip sm10-unsupported cases (head_dim not in
  {64,128}, num_contextuals>0) and run BWD section.
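
The 4.4.1 floor for nvidia-cutlass-dsl could be enforced with a small version check. A minimal sketch, assuming plain numeric release strings; both helper names are hypothetical, and pre-release or local version suffixes are not handled.

```python
def parse_version(text):
    """Turn '4.4.1' into (4, 4, 1); only plain numeric releases are handled."""
    return tuple(int(part) for part in text.split("."))


def cutlass_dsl_ok(installed, minimum="4.4.1"):
    """True if the installed version is at least the minimum named in this PR."""
    # Tuple comparison is numeric per component, so 4.10.0 correctly exceeds 4.4.1.
    return parse_version(installed) >= parse_version(minimum)
```

The installed string could be obtained with `importlib.metadata.version("nvidia-cutlass-dsl")`, so that a build fails early on an image that still ships the 4.3.x line.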
@JacoCheung JacoCheung force-pushed the junzhang/hstu_blackwell_enable branch from ca5ef37 to b1d927a on June 1, 2026 09:27
JacoCheung and others added 14 commits June 1, 2026 02:37
Extends the kernel-only enablement to the full E2E training stack
(DynamicEmb + commons jagged ops + kvcache_manager + multi-node SLURM
launcher). HSTU attention kernel itself is enabled in the previous commit.

- docker/Dockerfile Layer 3: add 10.0 to fbgemm_gpu Path A's
  TORCH_CUDA_ARCH_LIST (DynamicEmb's torch_binding stage links against
  Path A's libs, so sm_100 has to be in there).

- docker/Dockerfile build stage: thread TORCH_CUDA_ARCH_LIST with 10.0
  into dynamicemb / examples/commons / corelib/recsys_kvcache_manager
  setup.py invocations so their CUDA C++ ops compile for sm_100.

- docker/Dockerfile build stage: nvcomp tarball URL now picks linux-sbsa
  on aarch64 (GB200), linux-x86_64 on x86_64. The x86_64-only tarball
  used to fail to link with libnvcomp_static.a on GB200.

- corelib/dynamicemb/CMakeLists.txt: add 100 to CMAKE_CUDA_ARCHITECTURES
  default.

- corelib/dynamicemb/setup.py: add -gencode arch=compute_100,code=sm_100
  to nvcc flags.

- examples/hstu/training/benchmark/scripts/slurm_job.sub: export
  RANK / LOCAL_RANK / WORLD_SIZE / LOCAL_WORLD_SIZE from SLURM env. On
  multi-tray runs (e.g. NVL72 18×4) torchrec.get_local_size() falls back
  to torch.cuda.device_count() if LOCAL_WORLD_SIZE is unset, which
  mismatches the actual tasks-per-node and produces DynamicEmb sharding
  placement errors.
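
The export in slurm_job.sub is shell, but the mapping can be sketched in Python. The SLURM variable names below are standard SLURM task variables; which ones the script actually reads is an assumption, and the function name is hypothetical.

```python
def torch_env_from_slurm(env):
    """Map SLURM task variables onto the names torch-style launchers read."""
    return {
        "RANK": env["SLURM_PROCID"],  # global task index
        "LOCAL_RANK": env["SLURM_LOCALID"],  # task index within the node
        "WORLD_SIZE": env["SLURM_NTASKS"],  # total number of tasks
        # Only set by SLURM when --ntasks-per-node is given.
        "LOCAL_WORLD_SIZE": env["SLURM_NTASKS_PER_NODE"],
    }


# Illustrative values for an 18-node x 4-task layout such as the NVL72 case above.
example = torch_env_from_slurm(
    {
        "SLURM_PROCID": "5",
        "SLURM_LOCALID": "1",
        "SLURM_NTASKS": "72",
        "SLURM_NTASKS_PER_NODE": "4",
    }
)
```

Setting LOCAL_WORLD_SIZE explicitly is the part that matters here, because it removes the fallback to the visible device count described in the commit message.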
…nstall

- docker/Dockerfile:
    * Bump nvcr.io/nvidia/pytorch base from 26.02-py3 to 26.04-py3.
      26.04 already ships nvidia-cutlass-dsl 4.4.1, the earliest version
      whose AST preprocessor handles 'from . import X' in hstu_blackwell.
    * Drop the uninstall + reinstall of nvidia-cutlass-dsl{,-libs-base,
      -libs-cu13} that PR NVIDIA#379 (Beam search) added to handle the 4.3.x
      shipped by 26.02. On 26.04 it's a same-version round-trip that costs
      ~1-2 min of build time with no correctness benefit.
    * Keep the cute_arch import check as a light verification.
    * Keep the quack-kernels / apache-tvm-ffi / torch-c-dlpack-ext install
      (added by PR NVIDIA#379, needed by beam search regardless of base version).

- .gitmodules: add 'branch = recsys-examples-v26.05' for third_party/FBGEMM
  so the submodule tracks the immutable tag (jiayus-nvidia/FBGEMM:
  recsys-examples-v26.05 → commit 5f13f139) instead of the moving
  junzhang/dev_with_export_and_patch branch tip. Submodule SHA is
  unchanged.

Hoist the slow Path A fbgemm_gpu build (~55 min) out of devel into a
dedicated base_fbgemm stage at the top of the Dockerfile. The new layout:

    BASE_IMAGE  ->  base_fbgemm  ->  devel  ->  build

base_fbgemm contains a single RUN that installs scikit-build + the
explicit cutlass-DSL build deps, clones FBGEMM v1.5.0, and runs the
fbgemm_gpu setup.py against TORCH_CUDA_ARCH_LIST='7.5 8.0 9.0 10.0'.
Everything else (system setup, Megatron-LM, pip deps, TorchRec,
flash-attention cute, FlexKV/NVE, HSTU Path B) now lives in devel,
which FROMs ${BASE_FBGEMM_IMAGE}.

Override BASE_FBGEMM_IMAGE via --build-arg to consume a pre-built
fbgemm image and skip the ~55 min compile (CI side: build_fbgemm
stage gated on BUILD_FBGEMM=1 produces it).

Dockerfile comments also trimmed — removed running session/PR-history
references and cutlass-dsl version-bump notes.

TorchRec links FBGEMM C++ ops at install time, so the two are
version-coupled. Moving TorchRec into base_fbgemm keeps them in lockstep and
removes one layer from the devel stage.

nvcr.io/nvidia/tritonserver:25.11-py3 ships only python3, no python
alias. Stage 1 base_fbgemm runs before stage 2's 'ln /bin/python3
/bin/python' (which is in the tritonserver-only branch of Layer 1),
so 'python setup.py install' fails with command not found.

Switching to 'python3 setup.py' works on both bases (pytorch base
has both python and python3; tritonserver only python3).

The triton pin was added when the base was an older pytorch image; 26.04 now ships a
newer triton (3.7+). Downgrading to 3.6.0 breaks
torch._higher_order_ops.triton_kernel_wrap which references
'create_tma_experimental_metadata' (new in triton 3.7+), causing
ImportError during unit_test_1gpu_a100 (job 322376439).

tritonserver:25.11-py3 lacks python alias, cmake/patchelf, and torch
itself. FBGEMM Path A (stage 1) needs all three before it can run
'python3 setup.py install' (setup.py does 'import torch' at line 21).

Move the TRITONSERVER_BUILD bootstrap (python symlink, cmake/patchelf
apt install, pandas/rich/cloudpickle/psutil/cython pip, torch pip via
cu130 index) from stage 2 Layer 1 to stage 1 'Layer 0-pre'. Pytorch
base skips this branch.

stage 2 Layer 1 now only handles apt liburing/xxhash/ssl + arch
symlinks (unchanged for both bases). ARG TRITONSERVER_BUILD kept in
stage 2 as a guard for future re-use; current Layer 1 no longer
references it.

Stage 3 still gates the libcuda symlink / dynamicemb cmake on
TRITONSERVER_BUILD, unchanged.

FBGEMM v1.5.0 fbgemm_gpu/setup.py line 24 imports 'tabulate'. pytorch
base image happens to ship it; tritonserver base does not, so the
stage-1 FBGEMM Path A build fails with ModuleNotFoundError on tritonserver.

Add tabulate to the Layer 0a pip install alongside
setuptools-git-versioning and scikit-build.

nvcr/pytorch:26.04-py3 has a circular import in
torch._higher_order_ops.triton_kernel_wrap: when wrap_triton() is first
called it lazy-loads triton_kernel_wrap, which at line 24 imports
torch._inductor.dependencies; that triggers full _inductor / _dynamo init,
which transitively executes torch._dynamo.variables.functions line 2881
'from torch._higher_order_ops.triton_kernel_wrap import
(create_tma_experimental_metadata, ...)' while triton_kernel_wrap is still
partially initialised -> ImportError.

Pre-loading torch._inductor.dependencies at the top of every triton-using
module forces the cycle to resolve in a sequence that avoids the
partial-init lookup (entry is via _inductor instead of triton_kernel_wrap).
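
The workaround amounts to importing torch._inductor.dependencies before anything touches triton_kernel_wrap. A minimal sketch, assuming a hypothetical helper that also tolerates environments where the module cannot be imported; the PR itself may simply use a plain import statement.

```python
import importlib


def preload(module_name="torch._inductor.dependencies"):
    """Import module_name eagerly; return False if it is not importable."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


# Called once at the top of a triton-using module, before any wrap_triton() use.
PRELOADED = preload()
```

In a module that always runs with torch installed, a bare `import torch._inductor.dependencies` line at the top has the same effect as a successful `preload()` call.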

Verified on s4124-0071 against devel_pt26.04_v1.5.0:
  - test_addmm.py             12 passed (was 12F before)
  - test_hstu_layer.py       192 passed
  - test_hstu_op.py         3536 passed (16 skipped)
  - test_hstu_preprocess.py    8 passed
  - test_ln_mul_dropout.py    34 passed
  - test_ln_silu.py            8 passed
  - test_triton_silu.py       27 passed

tabulate is a FBGEMM v1.5.0 setup.py import that nvcr/pytorch:26.04-py3
already ships but nvcr/tritonserver:25.11-py3 does not. Move from the
unconditional Layer 0a pip install into the TRITONSERVER_BUILD=1
bootstrap block alongside the other tritonserver-specific deps.

Pytorch base no longer pays the (tiny) pip dependency-check cost; both
bases land at the same effective state going into FBGEMM Path A.
…E env

Previously arm64 skipped the entire FlexKV + NVE block. Restore the
x86 build chain on both archs; only difference is the PYNVE build env:

  arm64: AVX512 doesn't exist, leave PYNVE_DISABLE_AVX512 unset
  x86:   PYNVE_DISABLE_AVX512=1 (existing behavior)

Also bump FlexKV/NVE TORCH_CUDA_ARCH_LIST to include 10.0 (Blackwell),
matching stage 1 / stage 3.

Drop the post-install 'from flexkv.kvmanager import KVManager' check:
that import dlopens C++ extensions linked to libcuda, which is not
guaranteed to be visible inside the buildkit container (it relies on
runtime nvidia-container-toolkit driver injection that may not apply
on arm64 builders). setup.py returning 0 already verifies the build.
nvcr.io/nvidia/tritonserver:25.11-py3 doesn't ship ninja. Without it,
torch BuildExtension silently falls back to distutils' serial _compile
(MAX_JOBS unused), so HSTU's ~150 .cu files build one-at-a-time. On
a4u8g-0015 the running build_tritonserver_devel_x86 takes ~3h with a
single nvcc; ninja restores >100 parallel nvcc — observed in container
test: nvcc=108, cc1plus=135 within seconds of launch.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
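
Because the fallback to a serial build is silent, a preflight check for ninja may be worth having. A minimal sketch with a hypothetical helper name; the `which` parameter exists only so the check can be exercised without touching the real PATH.

```python
import shutil


def ninja_on_path(which=shutil.which):
    """True if a 'ninja' executable is visible on PATH."""
    return which("ninja") is not None
```

PyTorch also exposes `torch.utils.cpp_extension.is_ninja_available()` for the same purpose; the stdlib version above is shown only because it runs without torch installed.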
@JacoCheung JacoCheung force-pushed the junzhang/hstu_blackwell_enable branch from a7f45b5 to e972f7f on June 1, 2026 09:37