26 changes: 13 additions & 13 deletions python/triton/tools/ragged_tma.py
@@ -12,7 +12,7 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0):
of potentially unequal size.
The load_ragged and store_ragged device functions can be used to read
- and write from subarrays T[batch_offset : batch_offset + batch_size]
+ and write from subarrays T[slice_off : slice_off + slice_size]
with hardware bounds-checking preventing any sort of leakage outside
the subarray.
"""
@@ -46,22 +46,22 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0):


@triton.jit
- def to_ragged_indices(batch_offset, batch_size, row):
+ def to_ragged_indices(slice_off, slice_size, row):
"""
Helper function for load_ragged and store_ragged.
"""

billion = 0x40000000 # == 2**30
- x = billion - batch_size + row
- y = batch_offset + batch_size
+ x = billion - slice_size + row
+ y = slice_off + slice_size

return billion, y, x
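
The arithmetic above packs the bounds check into TMA coordinates: `x` stays below 2**30 exactly when `row < slice_size`, and the three coordinates recombine to the global row `slice_off + row`. A pure-Python sketch of these invariants (of the index arithmetic only, not of the descriptor layout that `create_ragged_descriptor` sets up):

    # Invariants of to_ragged_indices, checked in plain Python.
    BILLION = 0x40000000  # 2**30, the extent of the bounds-checked dimension

    def to_ragged_indices(slice_off, slice_size, row):
        x = BILLION - slice_size + row  # x < BILLION iff row < slice_size
        y = slice_off + slice_size      # one past the end of the slice
        return BILLION, y, x

    # In-bounds row: x stays below the 2**30 extent, and the coordinates
    # recombine to the global row slice_off + row.
    c0, y, x = to_ragged_indices(slice_off=100, slice_size=7, row=3)
    assert x < BILLION and y + x - c0 == 103  # 103 == slice_off + row

    # Out-of-bounds row: x spills past the extent, so the TMA hardware
    # masks the access (zeros on load, no write on store).
    c0, y, x = to_ragged_indices(slice_off=100, slice_size=7, row=7)
    assert x >= BILLION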


@triton.jit
- def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
+ def load_ragged(TMA, slice_off, slice_size, coords, ragged_dim: tl.constexpr = 0):
"""
- Read from a subarray T[batch_offset : batch_offset + batch_size] with
+ Read from a subarray T[slice_off : slice_off + slice_size] with
hardware bounds-checking, where reading outside the subarray gives zeros.
Coords should be an appropriately-sized list of integers, just like in
@@ -70,39 +70,39 @@ def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr

tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")

- c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+ c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim])
data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
data = tl.reshape(data, data.shape[2:])
return data


@triton.jit
- def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
+ def store_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0):
"""
- Write to a subarray T[batch_offset : batch_offset + batch_size] with
+ Write to a subarray T[slice_off : slice_off + slice_size] with
hardware bounds-checking, where writes outside the subarray are masked
correctly.
Coords should be an appropriately-sized list of integers, just like in
TMA.store().
"""

- c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+ c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim])
data = tl.reshape(data, [1, 1] + data.shape)
TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)


@triton.jit
- def atomic_add_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
+ def atomic_add_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0):
"""
- Atomic add into a subarray T[batch_offset : batch_offset + batch_size] with
+ Atomic add into a subarray T[slice_off : slice_off + slice_size] with
hardware bounds-checking, where adds outside the subarray are masked
correctly.
Coords should be an appropriately-sized list of integers, just like in
TMA.atomic_add().
"""

- c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+ c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim])
data = tl.reshape(data, [1, 1] + data.shape)
TMA.atomic_add([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
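
For reference, the renamed API in use — a minimal sketch with hypothetical shapes and kernel (the grid, sizes, and copy_slice kernel are illustrative, not from this PR; assumes a GPU whose TMA hardware Triton supports):

    import torch
    import triton
    import triton.language as tl
    from triton.tools.ragged_tma import create_ragged_descriptor, load_ragged, store_ragged

    @triton.jit
    def copy_slice(in_desc, out_desc, slice_off, slice_size, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        # Rows past slice_size load as zeros and stores to them are masked,
        # so the ragged tail needs no explicit handling.
        data = load_ragged(in_desc, slice_off, slice_size, [pid * BLOCK, 0])
        store_ragged(out_desc, slice_off, slice_size, [pid * BLOCK, 0], data)

    T = torch.randn(1024, 256, device="cuda")
    out = torch.zeros_like(T)
    in_desc = create_ragged_descriptor(T, [64, 256])
    out_desc = create_ragged_descriptor(out, [64, 256])
    # Copy the subarray T[100:170] without touching anything outside it.
    copy_slice[(2,)](in_desc, out_desc, 100, 70, BLOCK=64)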
11 changes: 5 additions & 6 deletions python/triton_kernels/bench/bench_mlp.py
@@ -7,7 +7,7 @@
import triton_kernels
import triton_kernels.roofline as roofline
import triton_kernels.swiglu
- from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
+ from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
from triton_kernels.target_info import get_cdna_version
import distributed as triton_dist
from triton_kernels.tensor_details import layout
@@ -71,7 +71,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d

input_x = torch.randn((batch // DP, dim1), device=dev)
expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev))
- triton_dist.initialize_matmul_ogs(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype)
+ triton_dist.initialize_matmul(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype)

# run layer
fpath = Path(tempfile.mktemp())
@@ -80,17 +80,16 @@
xg = input_x.to(wg.dtype if n_expts_tot > 1 else input_x.dtype)
for i in range(100):
if n_expts_tot > 1: # sparse
- logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
+ logits = matmul(xg, wg, bg, precision_config=pcg)
x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act, EP=EP,
TP=TP, expt_assignment=expt_assignment,
mode="ep_sharding")
else: # dense
x = triton_dist.all_gather(input_x, dim=0)
rdata, gather_indx, scatter_indx, metadata = None, None, None, None
if x.nelement() > 0:
- x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
- x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx,
-                precision_config=pc2)
+ x = matmul(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
+ x = matmul(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx, precision_config=pc2)
x = triton_dist.reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment)
proton.finalize()
return roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*")
51 changes: 22 additions & 29 deletions python/triton_kernels/bench/distributed.py
@@ -10,12 +10,11 @@
import triton_kernels
import triton_kernels.swiglu
from triton_kernels.reduce import reduce
- from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx
from triton_kernels.topk import topk
- from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
+ from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq
from triton_kernels.tensor_details import layout
- from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata
+ from triton_kernels.tensor import RaggedTensorMetadata, make_ragged_tensor_metadata, remap_ragged_tensor_metadata
from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, convert_dp_to_ep, convert_ep_to_dp, ExptAssignment, symm_mem_pool

from bench_utils import quantize_weight
@@ -40,7 +39,7 @@ def create_expt_assignment(EP: int, n_expts_tot: int, device: torch.device) -> O
return make_expt_assignment(EP, n_expts_tot, expt_dict, device)


- def initialize_matmul_ogs(
+ def initialize_matmul(
batch: int,
dim1: int,
dim2: int,
@@ -52,7 +51,7 @@ def initialize_matmul_ogs(
return
world_size = dist.get_world_size()
device = torch.cuda.current_device()
- symm_mem_pool.initialize_matmul_ogs(
+ symm_mem_pool.initialize_matmul(
n_tokens_global=batch,
d_input=dim1,
d_model=dim2,
@@ -146,8 +145,7 @@ def routing(
TP: int = 1,
expt_assignment: Optional[ExptAssignment] = None,
mode: Optional[str] = None,
- ) -> Tuple[torch.Tensor, RoutingData, GatherIndx, ScatterIndx, Optional[ReduceScatterMetadata]]:
-     n_expts_tot = logits.shape[-1]
+ ) -> Tuple[torch.Tensor, RaggedTensorMetadata, torch.Tensor, torch.Tensor, Optional[ReduceScatterMetadata]]:
if _is_distributed_launch() and mode:
if mode == "ep_sharding":
if not expt_assignment:
@@ -170,29 +168,24 @@
logits_global_metadata = make_ragged_tensor_metadata(expt_sizes, dispatch_indx.shape[0])
x = convert_dp_to_ep(x, expt_assignment, active_indx, dispatch_indx)
logits_local_metadata = remap_ragged_tensor_metadata(logits_global_metadata, expt_map)
- gate_scal = logits_global.vals.flatten()[combine_indx]
- rdata = RoutingData(gate_scal, expt_sizes, n_expts_tot // EP, n_expts_act, logits_local_metadata)
reduce_scatter_metadata = ReduceScatterMetadata(
mode=mode,
active_indx=active_indx,
dispatch_indx=dispatch_indx,
combine_indx=combine_indx,
)
- return x, rdata, None, None, reduce_scatter_metadata
+ return x, logits_local_metadata, None, None, reduce_scatter_metadata
else:
raise NotImplementedError(f"Distributed routing mode {mode} is not implemented yet.")
else:
# If mode is not specified or we have a single process, we do single-GPU routing.
logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first)
dispatch_indx = logits.mask_metadata.row_sorted_indx
combine_indx = logits.mask_metadata.col_sorted_indx
- ragged_batch_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0])
- gate_scal = logits.vals.flatten()[combine_indx]
- routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act,
-                ragged_batch_metadata)
- gather_indx = GatherIndx(combine_indx, dispatch_indx)
- scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
- return x, routing_data, gather_indx, scatter_indx, None
+ ragged_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0])
+ gather_indx = combine_indx // n_expts_act
+ scatter_indx = combine_indx
+ return x, ragged_metadata, gather_indx, scatter_indx, None
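
The replacement for `GatherIndx`/`ScatterIndx` leans on the flattened topk layout: slot `t * n_expts_act + k` holds token `t`'s k-th expert choice, so integer division by `n_expts_act` maps an expert-sorted slot back to its source token. A toy illustration with made-up values (the real indices come from `topk`'s mask_metadata):

    import torch

    n_expts_act = 2
    # Expert chosen for each flattened (token, slot); three tokens, two slots each.
    expert = torch.tensor([1, 2, 0, 2, 0, 1])
    # combine_indx: flattened slots reordered so rows are grouped by expert.
    combine_indx = torch.argsort(expert, stable=True)  # tensor([2, 4, 0, 5, 1, 3])
    # Row i of the expert-sorted activations is fed by this token:
    gather_indx = combine_indx // n_expts_act          # tensor([1, 2, 0, 2, 0, 1])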


def gather_ep(rank, world_size, param, TP, EP):
@@ -276,14 +269,14 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
w1_full = w2_full = w1_flex_full = w2_flex_full = w1_scale_full = w2_scale_full = None

# precision configs
- pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale)
+ pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), b_mx_scale=wg_scale)
act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
(1.0, 1.0))
- pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale)
- pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale)
+ pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), b_mx_scale=w1_scale)
+ pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), b_mx_scale=w2_scale)
if rank == 0:
- pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), weight_scale=w1_scale_full)
- pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), weight_scale=w2_scale_full)
+ pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), b_mx_scale=w1_scale_full)
+ pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), b_mx_scale=w2_scale_full)
else:
pc1_full = pc2_full = None

Expand All @@ -296,7 +289,7 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
xd = torch.randn((batch // world_size, dim1), device=dev).to(dtype_map[x_dtype])
x0 = all_gather(xd, dim=0)
expt_assignment = create_expt_assignment(EP, n_expts_tot, torch.device(dev))
- symm_mem_pool.initialize_matmul_ogs(
+ symm_mem_pool.initialize_matmul(
n_tokens_global=batch,
d_input=dim1,
d_model=dim2,
@@ -312,25 +305,25 @@
def single(x):
xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype)
if n_expts_tot > 1:
- logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
+ logits = matmul(xg, wg, bg, precision_config=pcg)
x, rdata, gi, si, _ = routing(x, logits, n_expts_act)
else:
rdata = gi = si = None
- x = matmul_ogs(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act)
- return matmul_ogs(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)
+ x = matmul(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act)
+ return matmul(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)

# distributed pass
def distributed(x):
xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype)
if n_expts_tot > 1: # sparse
- logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
+ logits = matmul(xg, wg, bg, precision_config=pcg)
x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment,
mode="ep_sharding")
else: # dense
x = all_gather(x, dim=0)
rdata = gi = si = metadata = None
- x = matmul_ogs(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act)
- x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2)
+ x = matmul(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act)
+ x = matmul(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2)
x = reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment)
# gather the result from all GPUs, just for verification
return all_gather(x, dim=0)
23 changes: 11 additions & 12 deletions python/triton_kernels/tests/test_distributed.py
@@ -9,7 +9,7 @@
from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_dict_random, make_expt_assignment, symm_mem_pool
from triton_kernels.reduce import reduce
from triton_kernels.topk import topk
- from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx
+ from triton_kernels.matmul import matmul
from triton_kernels.target_info import is_hip
from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata
import pytest
@@ -122,17 +122,18 @@ def routing(logits, n_expts_act, all_gather=False, y_indx=None):
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0])
- gate_scal = sparse_logits.vals.flatten()[combine_indx]
- routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, logits.shape[-1], n_expts_act,
-                ragged_batch_metadata)
- gather_idx = GatherIndx(combine_indx, dispatch_indx)
- scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
- return routing_data, gather_idx, scatter_idx, sparse_logits.indx
+ gather_idx = torch.div(combine_indx, n_expts_act, rounding_mode="trunc")
+ scatter_idx = combine_indx
+ return ragged_batch_metadata, gather_idx, scatter_idx, sparse_logits.indx


def mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act, y_indx=None):
rdata, combine_indx, dispatch_indx, _ = routing(l_global, n_expts_act, y_indx=y_indx)
- y_global = matmul_ogs(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx)
+ y_global = matmul(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx)
+ y_mask = (dispatch_indx != -1).view(y_global.shape[-2] // n_expts_act, n_expts_act, 1)
+ y_global = y_global.view(y_global.shape[-2] // n_expts_act, n_expts_act, -1)
+ y_mask = y_mask.expand_as(y_global)
+ y_global, _ = reduce(y_global, dim=1, mask=y_mask)
return y_global
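
With `scatter_indx=combine_indx`, the matmul now writes one output row per (token, active-expert) slot, so the combine is explicit: reshape to (n_tokens, n_expts_act, d), mask padding slots (where `dispatch_indx == -1`), and reduce over the expert axis. A shape-level sketch with hypothetical sizes, assuming `reduce` performs a masked sum (its exact semantics live in triton_kernels.reduce):

    import torch

    n_tokens, n_expts_act, d = 4, 2, 8
    y = torch.randn(n_tokens * n_expts_act, d)               # one row per (token, slot)
    dispatch_indx = torch.randint(-1, 8, (n_tokens * n_expts_act,))
    y_mask = (dispatch_indx != -1).view(n_tokens, n_expts_act, 1)
    y = y.view(n_tokens, n_expts_act, d)
    # Stand-in for reduce(y, dim=1, mask=y_mask.expand_as(y)):
    y_combined = (y * y_mask).sum(dim=1)                     # one row per token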


@@ -153,9 +154,7 @@ def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, ex
y_ep_local = convert_dp_to_ep(x_dp_local, expt_assignment, active_indx, dispatch_indx)
y_ep_local_metadata = remap_ragged_tensor_metadata(x_global_metadata, expt_map)
# matrix multiply
- # TODO: clean-up API. `RoutingData` should not exist; we should be passing `y_ep_local_metadata`.
- rdata_ep_local = RoutingData(None, expt_sizes, w_ep_local.shape[0], n_expts_act, y_ep_local_metadata)
- y_ep_local = matmul_ogs(y_ep_local, w_ep_local, b_ep_local, rdata_ep_local)
+ y_ep_local = matmul(y_ep_local, w_ep_local, b_ep_local, a_ragged_metadata=y_ep_local_metadata)
# convert x from expert-sorted, ep-local to token-sorted, dp-local
y_dp_local = convert_ep_to_dp(y_ep_local, expt_assignment, active_indx, combine_indx)
# weighted average of the output token from experts
@@ -208,7 +207,7 @@ def run_mixture():
y_indx=y_indx_global,
)

- symm_mem_pool.initialize_matmul_ogs(
+ symm_mem_pool.initialize_matmul(
n_tokens_global=n_tokens_global,
d_input=d_model,
d_model=d_model,