diff --git a/model/kronos.py b/model/kronos.py index ce4494ee..24b14dc9 100644 --- a/model/kronos.py +++ b/model/kronos.py @@ -386,8 +386,20 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l return x -def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False): - with torch.no_grad(): +def _resolve_amp_dtype(amp_dtype): + if amp_dtype is None: + return torch.float32, False + if amp_dtype == "bfloat16": + return torch.bfloat16, True + raise ValueError( + f"Unsupported amp_dtype {amp_dtype!r}; expected 'bfloat16' or None." + ) + + +def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, amp_dtype=None): + autocast_dtype, amp_enabled = _resolve_amp_dtype(amp_dtype) + + with torch.no_grad(), torch.autocast(device_type=x.device.type, dtype=autocast_dtype, enabled=amp_enabled): x = torch.clip(x, -clip, clip) device = x.device @@ -396,7 +408,7 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device) x_token = tokenizer.encode(x, half=True) - + initial_seq_len = x.size(1) batch_size = x_token[0].size(0) total_seq_len = initial_seq_len + pred_len @@ -463,7 +475,7 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context ] z = tokenizer.decode(input_tokens, half=True) z = z.reshape(-1, sample_count, z.size(1), z.size(2)) - preds = z.cpu().numpy() + preds = z.float().cpu().numpy() preds = np.mean(preds, axis=1) return preds @@ -481,7 +493,7 @@ def calc_time_stamps(x_timestamp): class KronosPredictor: - def __init__(self, model, tokenizer, device=None, max_context=512, clip=5): + def __init__(self, model, tokenizer, device=None, max_context=512, clip=5, amp_dtype=None): self.tokenizer = tokenizer self.model = model self.max_context = max_context @@ -490,7 +502,11 @@ def __init__(self, model, tokenizer, device=None, max_context=512, clip=5): self.vol_col = 'volume' self.amt_vol = 'amount' self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month'] - + + # Validate eagerly so the wrong value fails at construction, not on first predict. + _resolve_amp_dtype(amp_dtype) + self.amp_dtype = amp_dtype + # Auto-detect device if not specified if device is None: if torch.cuda.is_available(): @@ -499,7 +515,7 @@ def __init__(self, model, tokenizer, device=None, max_context=512, clip=5): device = "mps" else: device = "cpu" - + self.device = device self.tokenizer = self.tokenizer.to(self.device) @@ -512,7 +528,7 @@ def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device) preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len, - self.clip, T, top_k, top_p, sample_count, verbose) + self.clip, T, top_k, top_p, sample_count, verbose, amp_dtype=self.amp_dtype) preds = preds[:, -pred_len:, :] return preds