-
Notifications
You must be signed in to change notification settings - Fork 153
Expand file tree
/
Copy path · template.py
More file actions
31 lines (22 loc) · 903 Bytes
/
template.py
File metadata and controls
31 lines (22 loc) · 903 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from task import input_t, output_t
import torch
import helion
import helion.language as hl
# Per-shape tuning table: input-shape tuple -> tuned helion.Config.
#
# Workflow: autotune each shape locally, then paste the winning config in as
# an entry below. Every test and benchmark shape listed in task.yml needs an
# entry, otherwise no kernel is built for it.
SHAPE_CONFIGS: dict[tuple, helion.Config] = {
    # (shape_dim_1, shape_dim_2, ...): helion.Config(...),  # TODO: replace with your config
}
def _make_kernel(config: helion.Config):
    """Build a Helion kernel bound to a single, pre-tuned ``config``.

    Intended to be called once per entry in ``SHAPE_CONFIGS`` (the module
    builds ``_KERNELS`` this way), so each tuned input shape gets its own
    decorated kernel object.

    NOTE(review): ``kernel(...) -> ...`` below is template placeholder
    syntax, not valid Python. The module will fail to import with
    ``SyntaxError`` until the real parameter list, return annotation and
    body are filled in.

    Args:
        config: The tuned ``helion.Config`` the kernel is compiled with.

    Returns:
        The ``helion.kernel``-decorated ``kernel`` function.
    """
    # static_shapes=True presumably makes Helion specialize on exact input
    # shapes, which fits the one-config-per-shape scheme. Confirm against
    # the Helion settings documentation.
    @helion.kernel(static_shapes=True, config=config)
    def kernel(...) -> ...:
        # Your Helion kernel implementation
        ...
    return kernel
# Pre-build one kernel per tuned shape at import time, so dispatch in
# custom_kernel is a plain dict lookup on the shape key. Dict keys and values
# iterate in the same insertion order, so zip pairs each shape with the
# kernel built from its own config.
_KERNELS = dict(
    zip(SHAPE_CONFIGS.keys(), map(_make_kernel, SHAPE_CONFIGS.values()))
)
def custom_kernel(data: input_t) -> output_t:
    """Run the pre-built Helion kernel that matches the shape of ``data``.

    Template entry point. Derive a shape key from the input tensors, look up
    the matching kernel in ``_KERNELS``, and return that kernel's result.

    Args:
        data: The task input (``input_t`` from ``task``).

    Returns:
        The task output (``output_t`` from ``task``).

    Raises:
        NotImplementedError: Until the dispatch below is implemented. This
            replaces silently returning ``None``, which a caller would
            otherwise only notice as a confusing output mismatch later.
    """
    # TODO: extract the shape key from the input tensors and dispatch, e.g.
    #   shape_key = (...)  # must match a key in SHAPE_CONFIGS
    #   kernel = _KERNELS[shape_key]
    #   return kernel(...)
    raise NotImplementedError(
        "custom_kernel is a template stub: compute shape_key from `data`, "
        "look up _KERNELS[shape_key] and return its result."
    )