
Commit c7bfa8f

Update interp lib
Signed-off-by: zjgarvey <[email protected]>
1 parent 3143abd commit c7bfa8f

File tree

2 files changed: +6 additions, -41 deletions

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 6 additions & 35 deletions
@@ -15960,9 +15960,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " return %0 : !torch.int\n"
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
-" %int6 = torch.constant.int 6\n"
-" %int15 = torch.constant.int 15\n"
-" %int5 = torch.constant.int 5\n"
 " %true = torch.constant.bool true\n"
 " %none = torch.constant.none\n"
 " %str = torch.constant.str \"AssertionError: \"\n"
@@ -16011,22 +16008,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " torch.prim.If.yield %9 : !torch.int\n"
 " } else {\n"
-" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
-" %7 = torch.prim.If %6 -> (!torch.int) {\n"
-" torch.prim.If.yield %int6 : !torch.int\n"
-" } else {\n"
-" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-" torch.prim.If.yield %8 : !torch.int\n"
-" }\n"
-" torch.prim.If.yield %7 : !torch.int\n"
+" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+" torch.prim.If.yield %5 : !torch.int\n"
 " }\n"
 " return %4 : !torch.int\n"
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<number>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
-" %int6 = torch.constant.int 6\n"
-" %int15 = torch.constant.int 15\n"
-" %int5 = torch.constant.int 5\n"
 " %true = torch.constant.bool true\n"
 " %none = torch.constant.none\n"
 " %str = torch.constant.str \"AssertionError: \"\n"
@@ -16075,15 +16062,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " torch.prim.If.yield %9 : !torch.int\n"
 " } else {\n"
-" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
-" %7 = torch.prim.If %6 -> (!torch.int) {\n"
-" torch.prim.If.yield %int6 : !torch.int\n"
-" } else {\n"
-" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-" torch.prim.If.yield %8 : !torch.int\n"
-" }\n"
-" torch.prim.If.yield %7 : !torch.int\n"
+" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+" torch.prim.If.yield %5 : !torch.int\n"
 " }\n"
 " return %4 : !torch.int\n"
 " }\n"
@@ -16107,8 +16087,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
 " %true = torch.constant.bool true\n"
-" %int6 = torch.constant.int 6\n"
-" %int15 = torch.constant.int 15\n"
 " %int5 = torch.constant.int 5\n"
 " %int8 = torch.constant.int 8\n"
 " %none = torch.constant.none\n"
@@ -16126,15 +16104,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %4 = torch.prim.If %3 -> (!torch.int) {\n"
 " torch.prim.If.yield %int5 : !torch.int\n"
 " } else {\n"
-" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
-" %7 = torch.prim.If %6 -> (!torch.int) {\n"
-" torch.prim.If.yield %int6 : !torch.int\n"
-" } else {\n"
-" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-" torch.prim.If.yield %8 : !torch.int\n"
-" }\n"
-" torch.prim.If.yield %7 : !torch.int\n"
+" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+" torch.prim.If.yield %5 : !torch.int\n"
 " }\n"
 " return %4 : !torch.int\n"
 " }\n"

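Reading aid (not part of the diff): the deleted %int5, %int15, and %int6 constants are torch ScalarType codes, where 5 is float16, 15 is bfloat16, and 6 is float32, so each removed torch.prim.If arm was the generated form of a "promote half precision to float32" special case; the surviving else branch now calls the aten.std dtype function directly. A small Python sketch of what the removed arm computed; the function name and the None fall-through marker are purely illustrative, not torch-mlir API:

# Sketch of the removed else-branch logic, using torch ScalarType integer codes.
# The codes below come from PyTorch's c10::ScalarType enum.
FLOAT16_CODE, BFLOAT16_CODE, FLOAT32_CODE = 5, 15, 6

def removed_promotion_arm(self_dtype_code: int):
    # Old behavior: half-precision inputs were promoted to float32; any other
    # dtype fell through to @"__torch_mlir_dtype_fn.aten.std".
    if self_dtype_code in (FLOAT16_CODE, BFLOAT16_CODE):
        return FLOAT32_CODE
    return None  # stands for "defer to the aten.std dtype rule"

assert removed_promotion_arm(5) == 6     # float16 -> float32 (old behavior)
assert removed_promotion_arm(15) == 6    # bfloat16 -> float32 (old behavior)
assert removed_promotion_arm(7) is None  # float64 defers either way
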
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 0 additions & 6 deletions
@@ -5544,8 +5544,6 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function(
@@ -5569,8 +5567,6 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
@@ -5604,8 +5600,6 @@ def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int,
     # Should possibly be added to aten〇std〡dtype.
     if self_dtype == torch.complex32:
         return torch.half
-    if self_dtype in [torch.float16, torch.bfloat16]:
-        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function([Invocation(0.0),
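
Net effect of the generator change, for reference: float16 and bfloat16 inputs to linalg_vector_norm, linalg_norm, and norm.Scalar are no longer special-cased to a float32 result dtype and instead go through aten〇std〡dtype like other real floating-point types. A self-contained before/after sketch; the _std_dtype helper is a hypothetical stand-in for aten〇std〡dtype, assumed here to preserve the input dtype, which may not match the real rule in abstract_interp_lib_gen.py:

import torch

def _std_dtype(dtype: torch.dtype) -> torch.dtype:
    # Hypothetical stand-in for aten〇std〡dtype: assumed to keep the input
    # dtype for real floating-point inputs.
    assert not dtype.is_complex
    return dtype

def norm_result_dtype_before(self_dtype: torch.dtype) -> torch.dtype:
    # Rule removed by this commit: special-case half precision.
    if self_dtype in [torch.float16, torch.bfloat16]:
        return torch.float32
    return _std_dtype(self_dtype)

def norm_result_dtype_after(self_dtype: torch.dtype) -> torch.dtype:
    # Rule after this commit: defer to the aten.std dtype function directly.
    return _std_dtype(self_dtype)

print(norm_result_dtype_before(torch.float16))  # torch.float32
print(norm_result_dtype_after(torch.float16))   # torch.float16, under the assumption above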
