Skip to content

Commit d7ebf79

Browse files
authored
[FRONTEND] Clean up imports in code_generator.py (#5440)
## The What In code_generator we previously had a mix of absolute and relative imports (in fact `_utils` was imported from absolutely and relatively). Let's unify to change all the imports to use relative imports since the majority were already. Along the way I regrouped the import statements to be more clear which imports are from our package vs. system packages. I also imported semantic directly because it was previously being used implcitly (i.e. `semantic` isn't exported from `language/__init__.py` so just relying on `import ..language` doesn't explicitly import `semantic`). ## The Why I'm starting to investigate adding typehints throughout the repo. I started with `code_generator` because they'll inherently play with the frontend of the compiler. I noticed the imports were a little confusing so I decided to clean them up.
1 parent 053921b commit d7ebf79

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

python/triton/compiler/code_generator.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
import warnings
55
import os
66
import textwrap
7+
from types import ModuleType
78
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
9+
810
from .. import language
911
from .._C.libtriton import ir
10-
from ..language import constexpr, tensor, str_to_ty
12+
from ..language import constexpr, semantic, str_to_ty, tensor
1113
from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value
1214
from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
1315
# ideally we wouldn't need any runtime component
1416
from ..runtime import JITFunction
17+
from .._utils import list_list_flatten, list_list_unflatten, find_paths_if, get_iterable_path, set_iterable_path
18+
1519
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
16-
from types import ModuleType
17-
from triton._utils import list_list_flatten, list_list_unflatten
18-
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
1920

2021

2122
def mangle_ty(ty):
@@ -412,12 +413,12 @@ def visit_Return(self, node):
412413
self.builder.ret([])
413414
ret_ty = language.void
414415
elif isinstance(ret_value, language.tuple):
415-
ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value.values]
416+
ret_values = [semantic.to_tensor(v, self.builder) for v in ret_value.values]
416417
ret_types = [v.type for v in ret_values]
417418
self.builder.ret([v.handle for v in ret_values])
418419
ret_ty = language.tuple_type(ret_types)
419420
else:
420-
ret = language.semantic.to_tensor(ret_value, self.builder)
421+
ret = semantic.to_tensor(ret_value, self.builder)
421422
self.builder.ret([ret.handle])
422423
ret_ty = ret.type
423424
if self.ret_type is None:
@@ -536,7 +537,7 @@ def _sanitize_value(value):
536537
if value is not None and \
537538
not _is_triton_value(value) and \
538539
not isinstance(value, native_nontensor_types):
539-
value = language.semantic.to_tensor(value, self.builder)
540+
value = semantic.to_tensor(value, self.builder)
540541
return value
541542

542543
values = _sanitize_value(self.visit(node.value))
@@ -762,14 +763,14 @@ def visit_IfExp(self, node):
762763

763764
then_block = self.builder.create_block()
764765
self.builder.set_insertion_point_to_start(then_block)
765-
then_val = language.semantic.to_tensor(self.visit(node.body), self.builder)
766+
then_val = semantic.to_tensor(self.visit(node.body), self.builder)
766767
then_block = self.builder.get_insertion_block()
767768

768769
else_block = self.builder.create_block()
769770
self.builder.set_insertion_point_to_start(else_block)
770771
# do not need to reset lscope since
771772
# ternary expressions cannot define new variables
772-
else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder)
773+
else_val = semantic.to_tensor(self.visit(node.orelse), self.builder)
773774
else_block = self.builder.get_insertion_block()
774775

775776
self._set_insertion_point_and_loc(ip, last_loc)
@@ -998,14 +999,14 @@ def visit_For(self, node):
998999
step = constexpr(-step.value)
9991000
negative_step = True
10001001
lb, ub = ub, lb
1001-
lb = language.semantic.to_tensor(lb, self.builder)
1002-
ub = language.semantic.to_tensor(ub, self.builder)
1003-
step = language.semantic.to_tensor(step, self.builder)
1002+
lb = semantic.to_tensor(lb, self.builder)
1003+
ub = semantic.to_tensor(ub, self.builder)
1004+
step = semantic.to_tensor(step, self.builder)
10041005
# induction variable type
10051006
if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
10061007
raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
1007-
iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
1008-
iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype)
1008+
iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype)
1009+
iv_type = semantic.integer_promote_impl(iv_type, step.dtype)
10091010
iv_ir_type = iv_type.to_ir(self.builder)
10101011
iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
10111012
# lb/ub/step might be constexpr, we need to cast them to tensor
@@ -1076,7 +1077,7 @@ def visit_For(self, node):
10761077
if name in liveins:
10771078
local = self.local_defs[name]
10781079
if isinstance(local, constexpr):
1079-
local = language.semantic.to_tensor(local, self.builder)
1080+
local = semantic.to_tensor(local, self.builder)
10801081
yields.append(local)
10811082

10821083
# create YieldOp
@@ -1231,7 +1232,7 @@ def visit_BoolOp(self, node: ast.BoolOp):
12311232
def visit_Attribute(self, node):
12321233
lhs = self.visit(node.value)
12331234
if _is_triton_tensor(lhs) and node.attr == "T":
1234-
return language.semantic.permute(lhs, (1, 0), builder=self.builder)
1235+
return semantic.permute(lhs, (1, 0), builder=self.builder)
12351236
return getattr(lhs, node.attr)
12361237

12371238
def visit_Expr(self, node):

0 commit comments

Comments
 (0)