Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions lib/Conversion/TorchToSCF/TorchToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
targetType = Torch::IntType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfWhileOp.getLoc(), targetType, {to});
} else if (auto tty = dyn_cast<RankedTensorType>(targetType)) {
targetType = op.getIterArgsInit()[barg.index()].getType();
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfWhileOp.getLoc(), targetType, {to});
}
if (!torchArg)
return rewriter.notifyMatchFailure(op,
Expand All @@ -173,14 +177,6 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
"unsupported type of the operand");
loopConditionIterArgs.push_back(shouldContinue);
for (auto torchArg : primLoopConditionOp.getIterArgs()) {
Type torchType = torchArg.getType();

// If the argument is a torch tensor, directly add it in the list of
// iter args.
if (isa<Torch::BaseTensorType>(torchType)) {
loopConditionIterArgs.push_back(torchArg);
continue;
}
Value arg = typeConverter->materializeTargetConversion(
rewriter, scfWhileOp->getLoc(),
typeConverter->convertType(torchArg.getType()), {torchArg});
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Torch/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(

void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// Inline func.call operations created by higher-order ops like while_loop
// to conform to the linalg-on-tensors backend contract.
pm.addPass(createInlinerPass());
pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
Expand Down
12 changes: 12 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@
"IsFloatingPointInt_False",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
# torch._dynamo.exc.BackendCompilerFailed: Unsupported op: get_attr
"TorchPrimLoopWhileLikeHOPModule_basic",
"ScalarConstantTupleModule_basic",
# END tests failing due to: empty graph in dynamo
# ERROR due to: backend never runs because of empty frame
Expand Down Expand Up @@ -481,6 +483,7 @@
"TensorToBoolZeroRank_basic",
"TensorToBool_basic",
"ThresholdBackward2dMixedModule_basic",
"TorchPrimLoopWhileLikeHOPModule_basic", # Compilation error: failed to legalize operation 'func.call'
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"UpSampleNearest2dDynamicFactor_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
Expand Down Expand Up @@ -993,6 +996,8 @@
"ElementwiseClampMinModule_bfloat16",
"ElementwiseClampModule_bfloat16",
"ElementwiseReluModule_bfloat16",
# Runtime error: failed to legalize operation 'torch.constant.int'
"TorchPrimLoopWhileLikeHOPModule_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
Expand Down Expand Up @@ -2575,6 +2580,7 @@

LTC_XFAIL_SET = {
"TorchPrimLoopForLikeTensorArgModule_basic" "CollapseAllDimensionsModule_basic",
"TorchPrimLoopWhileLikeHOPModule_basic",
"CollapseRank1DynamicModule_basic",
"CollapseStaticModule_basic",
"CollapsePartialDynamicModule_basic",
Expand Down Expand Up @@ -3261,6 +3267,8 @@
"ToCopyWithDTypeModule_basic",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
# RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function
"TorchPrimLoopWhileLikeHOPModule_basic",
"TraceModule_basic",
"TraceModule_empty",
"TraceModule_nonsquare",
Expand Down Expand Up @@ -3957,6 +3965,8 @@
"ThresholdBackward2dMixedModule_basic",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
# Runtime error: failed to legalize operation 'torch.aten.Bool.Tensor'
"TorchPrimLoopWhileLikeHOPModule_basic",
"TraceModule_empty",
"TraceUnsignedIntModule_empty",
"TransposedConv1dNegativePadding_basic",
Expand Down Expand Up @@ -5036,6 +5046,8 @@
"ToDtypeFloatFromIntModule_basic",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
# RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function
"TorchPrimLoopWhileLikeHOPModule_basic",
"TraceModule_basic",
"TraceModule_empty",
"TraceModule_nonsquare",
Expand Down
34 changes: 34 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
from torch._higher_order_ops.while_loop import while_loop

# ==============================================================================

Expand Down Expand Up @@ -78,3 +79,36 @@ def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
x_test = torch.zeros([7, 9]).float()

module.forward(x_test)


# ==============================================================================


class TorchPrimLoopWhileLikeHOPModule(torch.nn.Module):
    """E2e test module exercising the `while_loop` higher-order op.

    `torch._higher_order_ops.while_loop` carries `(i, x)` through the
    loop: `body_fn` advances both, `cond_fn` keeps iterating while
    `i < 3`.  For the zero-initialized counter used in `forward`, the
    loop runs exactly 3 times, so the result is the counter tensor 3
    and `x + 3`.
    """

    def __init__(self):
        super().__init__()

    def body_fn(self, i, x):
        # One iteration: increment the counter and the carried tensor.
        return i + 1, x + 1

    def cond_fn(self, i, x):
        # Continue looping while the counter is below 3.
        return i < 3

    @export
    @annotate_args(
        [
            None,
            ([7, 9], torch.float32, True),
        ]
    )
    # Fixed return annotation: the method returns the (counter, tensor)
    # pair produced by while_loop, not a single tensor.  A string
    # annotation avoids needing a typing import at any Python version.
    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        i0 = torch.tensor(0)
        out_i, out_x = while_loop(self.cond_fn, self.body_fn, (i0, x))
        return out_i, out_x


@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeHOPModule())
def TorchPrimLoopWhileLikeHOPModule_basic(module, tu: TestUtils):
    # Drive the module with an all-zeros float32 input matching the
    # annotated [7, 9] shape.
    module.forward(torch.zeros(7, 9, dtype=torch.float32))
Loading
Loading