@@ -4,18 +4,19 @@
 import warnings
 import os
 import textwrap
+from types import ModuleType
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+
 from .. import language
 from .._C.libtriton import ir
-from ..language import constexpr, tensor, str_to_ty
+from ..language import constexpr, semantic, str_to_ty, tensor
 from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value
 from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
 # ideally we wouldn't need any runtime component
 from ..runtime import JITFunction
+from .._utils import list_list_flatten, list_list_unflatten, find_paths_if, get_iterable_path, set_iterable_path
+
 from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
-from types import ModuleType
-from triton._utils import list_list_flatten, list_list_unflatten
-from .._utils import find_paths_if, get_iterable_path, set_iterable_path


 def mangle_ty(ty):
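
Note on the import hunk above: besides importing `semantic` directly, it folds the absolute `from triton._utils import ...` line and the relative `from .._utils import ...` line into a single relative import. A minimal sketch of why that is behavior-preserving (the helper names and the `triton._utils` module come from the diff; this assumes an installed Triton build matching this branch):

    # Sketch, not part of the PR: the absolute and package-relative spellings bind
    # the same module object, so merging the two import lines changes nothing at runtime.
    import triton._utils as utils_absolute
    from triton import _utils as utils_attribute  # analogous to `from .._utils import ...` inside the package

    assert utils_absolute is utils_attribute
    assert utils_absolute.list_list_flatten is utils_attribute.list_list_flatten
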
@@ -412,12 +413,12 @@ def visit_Return(self, node):
             self.builder.ret([])
             ret_ty = language.void
         elif isinstance(ret_value, language.tuple):
-            ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value.values]
+            ret_values = [semantic.to_tensor(v, self.builder) for v in ret_value.values]
             ret_types = [v.type for v in ret_values]
             self.builder.ret([v.handle for v in ret_values])
             ret_ty = language.tuple_type(ret_types)
         else:
-            ret = language.semantic.to_tensor(ret_value, self.builder)
+            ret = semantic.to_tensor(ret_value, self.builder)
             self.builder.ret([ret.handle])
             ret_ty = ret.type
         if self.ret_type is None:
@@ -536,7 +537,7 @@ def _sanitize_value(value):
            if value is not None and \
               not _is_triton_value(value) and \
               not isinstance(value, native_nontensor_types):
-                value = language.semantic.to_tensor(value, self.builder)
+                value = semantic.to_tensor(value, self.builder)
            return value

        values = _sanitize_value(self.visit(node.value))
@@ -762,14 +763,14 @@ def visit_IfExp(self, node):

         then_block = self.builder.create_block()
         self.builder.set_insertion_point_to_start(then_block)
-        then_val = language.semantic.to_tensor(self.visit(node.body), self.builder)
+        then_val = semantic.to_tensor(self.visit(node.body), self.builder)
         then_block = self.builder.get_insertion_block()

         else_block = self.builder.create_block()
         self.builder.set_insertion_point_to_start(else_block)
         # do not need to reset lscope since
         # ternary expressions cannot define new variables
-        else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder)
+        else_val = semantic.to_tensor(self.visit(node.orelse), self.builder)
         else_block = self.builder.get_insertion_block()

         self._set_insertion_point_and_loc(ip, last_loc)
@@ -998,14 +999,14 @@ def visit_For(self, node):
             step = constexpr(-step.value)
             negative_step = True
             lb, ub = ub, lb
-        lb = language.semantic.to_tensor(lb, self.builder)
-        ub = language.semantic.to_tensor(ub, self.builder)
-        step = language.semantic.to_tensor(step, self.builder)
+        lb = semantic.to_tensor(lb, self.builder)
+        ub = semantic.to_tensor(ub, self.builder)
+        step = semantic.to_tensor(step, self.builder)
         # induction variable type
         if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
             raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
-        iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
-        iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype)
+        iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype)
+        iv_type = semantic.integer_promote_impl(iv_type, step.dtype)
         iv_ir_type = iv_type.to_ir(self.builder)
         iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
         # lb/ub/step might be constexpr, we need to cast them to tensor
@@ -1076,7 +1077,7 @@ def visit_For(self, node):
             if name in liveins:
                 local = self.local_defs[name]
                 if isinstance(local, constexpr):
-                    local = language.semantic.to_tensor(local, self.builder)
+                    local = semantic.to_tensor(local, self.builder)
                 yields.append(local)

         # create YieldOp
@@ -1231,7 +1232,7 @@ def visit_BoolOp(self, node: ast.BoolOp):
     def visit_Attribute(self, node):
         lhs = self.visit(node.value)
         if _is_triton_tensor(lhs) and node.attr == "T":
-            return language.semantic.permute(lhs, (1, 0), builder=self.builder)
+            return semantic.permute(lhs, (1, 0), builder=self.builder)
         return getattr(lhs, node.attr)

     def visit_Expr(self, node):
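
The remaining hunks only change the spelling of the `semantic` helpers at each call site. A quick sketch (not part of this PR) showing that the two spellings resolve to the same callables, assuming an installed Triton build matching this branch that still exposes `triton.language.semantic` as a module:

    # Sketch: the old `language.semantic.<fn>` attribute chain and the newly
    # imported `semantic.<fn>` are the same function objects, so the rewrite is
    # a pure refactor with no behavior change.
    import triton.language as language
    from triton.language import semantic

    assert semantic.to_tensor is language.semantic.to_tensor
    assert semantic.integer_promote_impl is language.semantic.integer_promote_impl
    assert semantic.permute is language.semantic.permute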
|