32 changes: 24 additions & 8 deletions model/kronos.py
@@ -386,8 +386,20 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l
     return x


-def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False):
-    with torch.no_grad():
+def _resolve_amp_dtype(amp_dtype):
+    if amp_dtype is None:
+        return torch.float32, False
+    if amp_dtype == "bfloat16":
+        return torch.bfloat16, True
+    raise ValueError(
+        f"Unsupported amp_dtype {amp_dtype!r}; expected 'bfloat16' or None."
+    )
+
+
+def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, amp_dtype=None):
+    autocast_dtype, amp_enabled = _resolve_amp_dtype(amp_dtype)
+
+    with torch.no_grad(), torch.autocast(device_type=x.device.type, dtype=autocast_dtype, enabled=amp_enabled):
         x = torch.clip(x, -clip, clip)

         device = x.device
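Quick illustration of what the new flag does end to end. The sketch below is not part of the patch; the toy `layer` and random input are placeholders. It mirrors `_resolve_amp_dtype` plus the `torch.autocast` wrapper to show that `amp_dtype=None` keeps a plain float32 path while `"bfloat16"` turns autocast on:

```python
import torch

def _resolve_amp_dtype(amp_dtype):
    if amp_dtype is None:
        return torch.float32, False   # autocast stays disabled; plain fp32 path
    if amp_dtype == "bfloat16":
        return torch.bfloat16, True   # autocast enabled with bf16 compute
    raise ValueError(f"Unsupported amp_dtype {amp_dtype!r}; expected 'bfloat16' or None.")

x = torch.randn(4, 8)                 # placeholder input, stands in for the real batch
layer = torch.nn.Linear(8, 8)         # placeholder module, stands in for model/tokenizer

for flag in (None, "bfloat16"):
    dtype, enabled = _resolve_amp_dtype(flag)
    with torch.no_grad(), torch.autocast(device_type=x.device.type, dtype=dtype, enabled=enabled):
        out = layer(x)
    print(flag, out.dtype)            # torch.float32 when disabled, torch.bfloat16 under autocast
```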
@@ -396,7 +408,7 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
         y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)

         x_token = tokenizer.encode(x, half=True)

         initial_seq_len = x.size(1)
         batch_size = x_token[0].size(0)
         total_seq_len = initial_seq_len + pred_len
@@ -463,7 +475,7 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
         ]
         z = tokenizer.decode(input_tokens, half=True)
         z = z.reshape(-1, sample_count, z.size(1), z.size(2))
-        preds = z.cpu().numpy()
+        preds = z.float().cpu().numpy()
         preds = np.mean(preds, axis=1)

     return preds
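Note on the `.float()` hop added above: NumPy has no bfloat16 dtype, so a bfloat16 tensor cannot be handed to `.numpy()` directly; upcasting to float32 first keeps the bf16 autocast path from breaking at the prediction boundary. A minimal repro, separate from the patch:

```python
import torch

z = torch.randn(2, 3, dtype=torch.bfloat16)
try:
    z.cpu().numpy()                  # raises: NumPy has no bfloat16 dtype
except TypeError as e:
    print("direct conversion fails:", e)

preds = z.float().cpu().numpy()      # upcast to float32 first, then hand off to NumPy
print(preds.dtype)                   # float32
```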
@@ -481,7 +493,7 @@ def calc_time_stamps(x_timestamp):

 class KronosPredictor:

-    def __init__(self, model, tokenizer, device=None, max_context=512, clip=5):
+    def __init__(self, model, tokenizer, device=None, max_context=512, clip=5, amp_dtype=None):
         self.tokenizer = tokenizer
         self.model = model
         self.max_context = max_context
@@ -490,7 +502,11 @@ def __init__(self, model, tokenizer, device=None, max_context=512, clip=5):
         self.vol_col = 'volume'
         self.amt_vol = 'amount'
         self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month']

+        # Validate eagerly so the wrong value fails at construction, not on first predict.
+        _resolve_amp_dtype(amp_dtype)
+        self.amp_dtype = amp_dtype
+
         # Auto-detect device if not specified
         if device is None:
             if torch.cuda.is_available():
@@ -499,7 +515,7 @@ def __init__(self, model, tokenizer, device=None, max_context=512, clip=5):
device = "mps"
else:
device = "cpu"

self.device = device

self.tokenizer = self.tokenizer.to(self.device)
@@ -512,7 +528,7 @@ def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count,
         y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device)

         preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len,
-                                          self.clip, T, top_k, top_p, sample_count, verbose)
+                                          self.clip, T, top_k, top_p, sample_count, verbose, amp_dtype=self.amp_dtype)
         preds = preds[:, -pred_len:, :]
         return preds

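Usage sketch for the new constructor argument, assuming `model` and `tokenizer` are already-constructed Kronos objects and that the import path matches the file location; both are assumptions and not shown in this diff:

```python
from model.kronos import KronosPredictor  # import path assumed from the file location

# model, tokenizer: assumed to be constructed elsewhere, exactly as before this change.

# Default: amp_dtype=None keeps the previous float32 behaviour.
predictor_fp32 = KronosPredictor(model, tokenizer, device="cuda", amp_dtype=None)

# Opt in to bfloat16 autocast for the autoregressive sampling loop.
predictor_bf16 = KronosPredictor(model, tokenizer, device="cuda", amp_dtype="bfloat16")

# Anything else fails fast in __init__ rather than on the first prediction call.
try:
    KronosPredictor(model, tokenizer, device="cuda", amp_dtype="float16")
except ValueError as e:
    print(e)  # Unsupported amp_dtype 'float16'; expected 'bfloat16' or None.
```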