
feat(inference): add bf16 mixed-precision via KronosPredictor.amp_dtype #289

Open

Burton-David wants to merge 1 commit into shiyu-coder:master from Burton-David:feat/bf16-inference

Conversation

@Burton-David

Summary

Adds optional bf16 mixed-precision inference to KronosPredictor, gated by a new amp_dtype argument on the constructor.

predictor = KronosPredictor(model, tokenizer, device="cuda", amp_dtype="bfloat16")
preds = predictor.predict(df, x_ts, y_ts, pred_len=60, sample_count=5)

When amp_dtype="bfloat16", the body of auto_regressive_inference runs inside torch.autocast(device_type=x.device.type, dtype=torch.bfloat16). None (default) keeps the existing FP32 path bit-exact. The dtype is validated in __init__ so a typo fails at construction rather than on the first predict() call.
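
A minimal sketch of the gating, assuming the signatures described above; the _AMP_DTYPES helper and the elided bodies are illustrative, only amp_dtype and the class/function names come from this PR:

import contextlib
import torch

class KronosPredictor:
    _AMP_DTYPES = {"bfloat16": torch.bfloat16}  # hypothetical lookup table

    def __init__(self, model, tokenizer, device="cuda", amp_dtype=None):
        # eager validation: a typo fails here, not on the first predict()
        if amp_dtype is not None and amp_dtype not in self._AMP_DTYPES:
            raise ValueError(f"amp_dtype must be None or 'bfloat16', got {amp_dtype!r}")
        self.amp_dtype = self._AMP_DTYPES.get(amp_dtype)  # torch.bfloat16 or None
        self.model, self.tokenizer, self.device = model, tokenizer, device

def auto_regressive_inference(model, x, amp_dtype=None):
    # nullcontext keeps the default path bit-exact FP32; autocast only on opt-in
    ctx = (torch.autocast(device_type=x.device.type, dtype=amp_dtype)
           if amp_dtype is not None else contextlib.nullcontext())
    with ctx:
        ...  # existing tokenize -> sample -> detokenize body, unchanged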

z.float().cpu().numpy() replaces z.cpu().numpy() at the end of inference because numpy has no bf16 dtype; the explicit cast is a no-op when amp is off.
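
Illustration of why the cast is needed (standalone snippet, not the PR's diff):

import torch

z = torch.randn(4, dtype=torch.bfloat16)
# z.cpu().numpy() would raise: TypeError: Got unsupported ScalarType BFloat16
arr = z.float().cpu().numpy()   # upcast to fp32 first, then hand off to numpy
z32 = torch.randn(4)
assert z32.float() is z32       # on an fp32 tensor, .float() returns self: a no-op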

Numbers

RTX 4090, torch 2.4.1+cu124. Real SPY 1-min bars from yfinance (7 days), a 480-bar history window, sample_count=5, max_context=512. Each number is the median of 6 timed iterations after 2 warmup runs, timing KronosPredictor.predict end-to-end (tokenize, autoregressive sampling, detokenize).
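
For reference, a harness along these lines matches that methodology (bench is my name, not part of the PR; predictor, df, x_ts, y_ts are the objects from the summary snippet):

import statistics
import time
import torch

def bench(predictor, df, x_ts, y_ts, pred_len, iters=6, warmup=2):
    torch.cuda.reset_peak_memory_stats()
    times = []
    for i in range(warmup + iters):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        predictor.predict(df, x_ts, y_ts, pred_len=pred_len, sample_count=5)
        torch.cuda.synchronize()          # wait for GPU work before stopping the clock
        if i >= warmup:
            times.append((time.perf_counter() - t0) * 1e3)  # ms
    return statistics.median(times), torch.cuda.max_memory_allocated() / 2**30  # (ms, GB)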

| pred_len | fp32 (ms) | bf16 (ms) | speedup | fp32 peak (GB) | bf16 peak (GB) | close[-1] rel diff |
|---------:|----------:|----------:|--------:|---------------:|---------------:|-------------------:|
| 10       | 228.2     | 146.2     | 1.56x   | 0.54           | 0.70           | 0.01%              |
| 30       | 690.1     | 540.2     | 1.28x   | 0.54           | 0.70           | 0.02%              |
| 60       | 1371.0    | 1033.6    | 1.33x   | 0.54           | 0.70           | 0.01%              |
| 120      | 2716.0    | 1728.2    | 1.57x   | 0.54           | 0.70           | 0.23%              |

The bf16 path's peak GPU memory is ~160 MB above FP32 because PyTorch's autocast caches a bf16 copy of the FP32 weights during the forward pass; the overhead disappears for users who pre-cast the model weights to bf16 themselves rather than relying on autocast. On a 24 GB card this is well under 1% of capacity.
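
(For comparison, the pre-cast alternative stores the weights in bf16 outright; that drops the FP32 master weights entirely, a different trade-off and not what this PR does:)

import torch
import torch.nn as nn

model = nn.Linear(8, 8)               # stand-in module
model = model.to(torch.bfloat16)      # weights live in bf16; autocast keeps no extra copy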

The last-bar predicted close diverges by 0.01-0.23% from the FP32 path; the 0.23% at pred_len=120 reflects stochastic-sampling divergence accumulating over the longer autoregressive horizon.

Compatibility

amp_dtype = None is the default; existing KronosPredictor(...) calls are unchanged. Opt-in via KronosPredictor(..., amp_dtype="bfloat16").

Composes with #288 (training-side bf16); both share the same amp_dtype = "bfloat16" API.

Wraps the body of auto_regressive_inference in torch.autocast, gated by
a new amp_dtype argument on KronosPredictor and the inference function.
"bfloat16" enables bf16 autocast on the active device; None (default)
keeps the existing FP32 path bit-exact. The dtype is validated eagerly
in KronosPredictor.__init__ so a typo fails at construction rather than
on the first predict call.

z.float() before .cpu().numpy() handles the case where bf16 autocast
leaves the decoded tensor in bf16 (numpy has no bf16 dtype).