feat(inference): add bf16 mixed-precision via KronosPredictor.amp_dtype #289
Open
Burton-David wants to merge 1 commit into
Conversation
Wraps the body of auto_regressive_inference in torch.autocast, gated by a new amp_dtype argument on KronosPredictor and the inference function. "bfloat16" enables bf16 autocast on the active device; None (default) keeps the existing FP32 path bit-exact. The dtype is validated eagerly in KronosPredictor.__init__ so a typo fails at construction rather than on the first predict call. z.float() before .cpu().numpy() handles the case where bf16 autocast leaves the decoded tensor in bf16 (numpy has no bf16 dtype).
Summary
Adds optional bf16 mixed-precision to `KronosPredictor`, gated by a new `amp_dtype` argument on the constructor. When `amp_dtype="bfloat16"`, the body of `auto_regressive_inference` runs inside `torch.autocast(device_type=x.device.type, dtype=torch.bfloat16)`. `None` (default) keeps the existing FP32 path bit-exact. The dtype is validated in `__init__` so a typo fails at construction rather than on the first `predict()` call. `z.float().cpu().numpy()` replaces `z.cpu().numpy()` at the end of inference because numpy has no bf16 dtype; the explicit cast is a no-op when amp is off.
Numbers
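The eager validation might look something like this sketch (the class here is a stand-in, since the real `KronosPredictor` constructor takes more arguments):

```python
import torch

# Hypothetical stand-in for KronosPredictor's constructor-time validation.
_ALLOWED = {"bfloat16": torch.bfloat16}

class PredictorSketch:
    def __init__(self, amp_dtype=None):
        if amp_dtype is not None and amp_dtype not in _ALLOWED:
            # Fail at construction, not on the first predict() call.
            raise ValueError(
                f"amp_dtype must be None or one of {sorted(_ALLOWED)}, "
                f"got {amp_dtype!r}"
            )
        self.amp_dtype = None if amp_dtype is None else _ALLOWED[amp_dtype]
```

A typo like `amp_dtype="bf16"` then raises `ValueError` immediately instead of surfacing deep inside the autoregressive loop.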
RTX 4090, torch 2.4.1+cu124. Real SPY 1-min bars from yfinance (7 days), 480-bar history window, `sample_count=5`, `max_context=512`. Median of 6 timed iters after 2 warmup, calling `KronosPredictor.predict` end-to-end (tokenize, autoregressive sampling, detokenize).
The bf16 path's peak GPU memory is ~160 MB above FP32 because autocast keeps a cached bf16 copy of the FP32 weights during the forward pass; the overhead disappears for users who pre-cast the weights to bf16 themselves. On a 24 GB card this is well under 1% of capacity.
Last-bar predicted close diverges 0.01-0.23% from the FP32 path; the 0.23% at `pred_len=120` is accumulated stochastic-sampling divergence over the longer autoregressive horizon.
Compatibility
`amp_dtype=None` is the default; existing `KronosPredictor(...)` calls are unchanged. Opt in via `KronosPredictor(..., amp_dtype="bfloat16")`. Composes with #288 (training-side bf16); both share the same `amp_dtype="bfloat16"` API.
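For reference, the "median of 6 timed iters after 2 warmup" methodology from the Numbers section can be sketched as a plain-Python helper (the function name is mine; for CUDA workloads you would also call `torch.cuda.synchronize()` around each timed call so queued kernels are included):

```python
import statistics
import time

def median_latency(fn, warmup=2, iters=6):
    """Median wall-clock seconds of fn() over `iters` runs after `warmup`
    untimed runs, matching the benchmark setup described above."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)
```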