4 changes: 2 additions & 2 deletions scripts/ci/ci_data.sh
@@ -1,5 +1,5 @@
#!/usr/bin/env bash

# kuairand-1k
-wget https://tzrec.oss-cn-beijing.aliyuncs.com/data/test/kuairand-1k-train-c4096-s100-154691868ffa7a07a54aa17ef2fdbb96.parquet -O data/test/kuairand-1k-train-c4096-s100.parquet
-wget https://tzrec.oss-cn-beijing.aliyuncs.com/data/test/kuairand-1k-eval-c4096-s100-e6426c57d0ee213283cff106c433452f.parquet -O data/test/kuairand-1k-eval-c4096-s100.parquet
+wget https://tzrec.oss-cn-beijing.aliyuncs.com/data/test/kuairand-1k-train-c4096-s100-3c725f3b7de8d38ed281d229e56fab37.parquet -O data/test/kuairand-1k-train-c4096-s100.parquet
+wget https://tzrec.oss-cn-beijing.aliyuncs.com/data/test/kuairand-1k-eval-c4096-s100-7e841625beda7501876ea8e2ea76523f.parquet -O data/test/kuairand-1k-eval-c4096-s100.parquet
18 changes: 13 additions & 5 deletions tzrec/modules/embedding.py
@@ -46,6 +46,7 @@
fx_int_item,
fx_mark_keyed_tensor,
fx_mark_seq_len,
+fx_mark_seq_tensor,
fx_mark_tensor,
)

@@ -55,6 +56,7 @@
torch.fx.wrap(fx_int_item)
torch.fx.wrap(fx_mark_keyed_tensor)
torch.fx.wrap(fx_mark_tensor)
+torch.fx.wrap(fx_mark_seq_tensor)
torch.fx.wrap(fx_mark_seq_len)


@@ -918,10 +920,10 @@ def __init__(

self._group_to_shared_query = OrderedDict()
self._group_to_shared_sequence = OrderedDict()
-self._group_to_shared_feature = OrderedDict()
self._group_total_dim = dict()
self._group_output_dims = dict()
self._group_to_is_jagged = dict()
+self._group_to_sequence_length = OrderedDict()

feat_to_group_to_emb_name = defaultdict(dict)
for feature_group in feature_groups:
@@ -955,7 +957,7 @@ def __init__(
feature_names = list(feature_group.feature_names)
shared_query = []
shared_sequence = []
-shared_feature = []
+group_sequence_length = None

for name in feature_names:
shared_name = name
@@ -1046,16 +1048,16 @@ def __init__(
shared_sequence.append(shared_info)
sequence_dim += output_dim
sequence_dims.append(output_dim)
+group_sequence_length = feature.sequence_length
else:
shared_query.append(shared_info)
query_dim += output_dim
query_dims.append(output_dim)
-shared_feature.append(shared_info)
output_dims.append(output_dim)

self._group_to_shared_query[group_name] = shared_query
self._group_to_shared_sequence[group_name] = shared_sequence
-self._group_to_shared_feature[group_name] = shared_feature
+self._group_to_sequence_length[group_name] = group_sequence_length
self._group_total_dim[f"{group_name}.query"] = query_dim
self._group_total_dim[f"{group_name}.sequence"] = sequence_dim
self._group_output_dims[f"{group_name}.query"] = query_dims
@@ -1320,7 +1322,13 @@ def forward(

if seq_t_list:
seq_cat_t = torch.cat(seq_t_list, dim=-1)
fx_mark_tensor(f"{group_name}__sequence", seq_cat_t, keys=seq_t_keys)
fx_mark_seq_tensor(
f"{group_name}__sequence",
seq_cat_t,
keys=seq_t_keys,
max_seq_len=self._group_to_sequence_length[group_name],
is_jagged_seq=self._group_to_is_jagged[group_name],
)
results[f"{group_name}.sequence"] = seq_cat_t

return results
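Note on this embedding.py change: the concatenated sequence output is now routed through fx_mark_seq_tensor, which carries the group's max sequence length and jagged flag alongside the tensor, and the module replaces the _group_to_shared_feature bookkeeping with _group_to_sequence_length. The helper's definition is not part of this diff; below is a minimal sketch of what such a torch.fx marker function typically looks like (the signature and body are assumptions, the real tzrec helper may differ).

```python
from typing import List, Optional

import torch


def fx_mark_seq_tensor(
    name: str,
    tensor: torch.Tensor,
    keys: Optional[List[str]] = None,
    max_seq_len: Optional[int] = None,
    is_jagged_seq: bool = False,
) -> torch.Tensor:
    """Identity pass-through that tags a grouped sequence tensor with metadata.

    Registered with torch.fx.wrap, the call survives symbolic tracing as a
    single node, so export passes can read name/keys/max_seq_len/is_jagged_seq
    off the node's kwargs instead of re-deriving them from the graph.
    """
    return tensor


# must be registered at module level so tracing treats it as a leaf call
torch.fx.wrap(fx_mark_seq_tensor)
```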
5 changes: 1 addition & 4 deletions tzrec/ops/jagged_tensors.py
@@ -26,9 +26,6 @@
pytorch_split_2D_jagged,
)

torch.fx.wrap("pytorch_concat_2D_jagged")
torch.fx.wrap("pytorch_split_2D_jagged")

if has_triton():
from tzrec.ops.triton.triton_jagged_tensors import (
triton_concat_2D_jagged,
@@ -38,7 +35,7 @@
else:
triton_concat_2D_jagged = pytorch_concat_2D_jagged
triton_jagged_dense_bmm_broadcast_add = pytorch_jagged_dense_bmm_broadcast_add
-pytorch_split_2D_jagged = pytorch_split_2D_jagged
+triton_split_2D_jagged = pytorch_split_2D_jagged


def concat_2D_jagged(
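Note on this jagged_tensors.py change: besides dropping the string-based torch.fx.wrap registrations, it corrects the non-Triton fallback, where the old line pytorch_split_2D_jagged = pytorch_split_2D_jagged was a self-assignment, so the else branch never bound triton_split_2D_jagged. A self-contained sketch of the fallback-alias pattern (the toy kernel and the my_triton_ops import path are illustrative, not tzrec's real API):

```python
import torch


def pytorch_split_2D_jagged(values: torch.Tensor, split_sizes: list) -> list:
    # stand-in for the pure-PyTorch reference kernel used as the fallback;
    # the real tzrec op takes offsets/max lengths, this toy version just
    # splits rows so the example runs end to end
    return list(torch.split(values, split_sizes, dim=0))


try:
    # hypothetical Triton-backed implementation; not a real import path
    from my_triton_ops import triton_split_2D_jagged  # type: ignore[import]
except ImportError:
    # the bug fixed in this hunk: the old fallback re-assigned the pytorch
    # name to itself, so triton_split_2D_jagged stayed undefined without
    # Triton and any call site dispatching through it would hit a NameError
    triton_split_2D_jagged = pytorch_split_2D_jagged


print(triton_split_2D_jagged(torch.arange(12).reshape(6, 2), [2, 4]))
```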
3 changes: 0 additions & 3 deletions tzrec/ops/position.py
@@ -25,9 +25,6 @@
pytorch_add_timestamp_positional_embeddings,
)

torch.fx.wrap("pytorch_add_position_embeddings")
torch.fx.wrap("pytorch_add_timestamp_positional_embeddings")

if has_triton():
from tzrec.ops.triton.triton_position import (
triton_add_position_embeddings,
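Note on this position.py change: it mirrors the jagged_tensors.py cleanup by dropping the string-form torch.fx.wrap("...") calls; elsewhere in this PR (embedding.py) wrap is passed the function object directly. Both forms are accepted by torch.fx.wrap. For reference, a small standalone example of how wrapping keeps a helper opaque under symbolic tracing (the names here are made up for illustration):

```python
import torch
from torch.fx import symbolic_trace


def leaf_item(x: torch.Tensor) -> int:
    # .item() on a Proxy would fail during symbolic tracing, so the helper
    # is registered as a leaf and kept as one opaque call_function node
    return int(x.item())


# either form works when called at module level:
#   torch.fx.wrap(leaf_item)      - by callable (as embedding.py now does)
#   torch.fx.wrap("leaf_item")    - by name, resolved in this module's globals
torch.fx.wrap(leaf_item)


class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        return x * leaf_item(s)


gm = symbolic_trace(Scale())
print(gm.graph)  # leaf_item shows up as a single call_function node
```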
8 changes: 4 additions & 4 deletions tzrec/ops/pytorch/pt_jagged_tensors.py
@@ -69,14 +69,14 @@ def pytorch_concat_2D_jagged(
) -> torch.Tensor:
if offsets_left is None:
B = values_left.shape[0] // max_len_left
-offsets_left_non_optional = max_len_left * torch.arange(
+offsets_left_non_optional = max_len_left * fx_arange(
B + 1, device=values_left.device
)
else:
offsets_left_non_optional = offsets_left
if offsets_right is None:
B = values_right.shape[0] // max_len_right
-offsets_right_non_optional = max_len_right * torch.arange(
+offsets_right_non_optional = max_len_right * fx_arange(
B + 1, device=values_left.device
)
else:
@@ -129,15 +129,15 @@ def pytorch_split_2D_jagged(
if offsets_left is None:
assert max_len_left is not None
assert offsets_right is not None
-offsets_left_non_optional = max_len_left * torch.arange(
+offsets_left_non_optional = max_len_left * fx_arange(
offsets_right.shape[0], device=values.device
)
else:
offsets_left_non_optional = offsets_left
if offsets_right is None:
assert max_len_right is not None
assert offsets_left is not None
-offsets_right_non_optional = max_len_right * torch.arange(
+offsets_right_non_optional = max_len_right * fx_arange(
offsets_left.shape[0], device=values.device
)
else:
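Note on this pt_jagged_tensors.py change: the offsets synthesized for the dense (non-jagged) side now go through fx_arange instead of torch.arange. fx_arange's definition is not shown in this diff; a plausible minimal form, assumed here, is a torch.fx-wrapped thin wrapper over torch.arange so that the data-dependent length (e.g. B + 1 derived from a runtime shape) survives tracing as a call rather than being specialized into a constant:

```python
from typing import Optional

import torch


def fx_arange(end: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """Assumed shape of the helper: a thin, fx-wrapped torch.arange.

    Keeping it as a leaf call means a runtime-dependent `end` is not baked
    into the traced graph as a constant.
    """
    return torch.arange(end, device=device)


torch.fx.wrap(fx_arange)

# usage mirroring the diff: synthesize offsets for a dense left-hand side
# offsets_left = max_len_left * fx_arange(B + 1, device=values_left.device)
print(fx_arange(5))
```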
12 changes: 5 additions & 7 deletions tzrec/ops/pytorch/pt_position.py
@@ -42,7 +42,7 @@ def pytorch_add_position_embeddings(
return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
jagged,
[jagged_offsets],
-dense_values,
+dense_values.to(jagged.dtype),
)[0]


@@ -127,12 +127,10 @@ def pytorch_add_timestamp_positional_embeddings(
ts = torch.log(ts)
else:
ts = torch.sqrt(ts)
-ts = (ts / time_bucket_divisor).clamp(min=0).int()
-ts = torch.clamp(
-    ts,
-    min=0,
-    max=num_time_buckets,
-)
+ts = (
+    (ts / time_bucket_divisor / num_time_buckets).clamp(min=0, max=1)
+    * num_time_buckets
+).int()
position_embeddings = torch.index_select(
pos_embeddings, 0, pos_inds.reshape(-1)
).view(B, max_seq_len, -1)
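Note on these pt_position.py changes: dense_values is cast to the jagged tensor's dtype before the fbgemm jagged/dense elementwise add (guarding against a mixed-precision mismatch), and the time-bucket computation now clamps the normalized value to [0, 1] in float space before scaling by num_time_buckets and casting to int. A quick toy comparison of the old and new bucket formulas (constants are arbitrary, not tzrec defaults):

```python
import torch

# toy comparison of the old and new time-bucket formulas
ts = torch.tensor([0.0, 3.0, 50.0, 1.0e12])  # already log/sqrt-transformed gaps
time_bucket_divisor, num_time_buckets = 10.0, 8

# old: divide, cast to int32, then clamp the integer bucket ids
old = torch.clamp(
    (ts / time_bucket_divisor).clamp(min=0).int(), min=0, max=num_time_buckets
)

# new: clamp the normalized value to [0, 1] in float, then scale and cast
new = (
    (ts / time_bucket_divisor / num_time_buckets).clamp(min=0, max=1)
    * num_time_buckets
).int()

print(old.tolist())
print(new.tolist())
# the small gaps land in the same buckets either way; for the huge gap the new
# form saturates cleanly at num_time_buckets, while the old form casts an
# out-of-range float to int32 before clamping, which is not a well-defined
# saturating conversion
```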