diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 21977d3f5..5931ca2df 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -767,6 +767,7 @@ STABLEHLO_PASS_SET = {
     "ToDtypeLayoutNoneModule_basic",
     "ToDtypeLayoutStridedModule_basic",
     "TypeAsSameModule_basic",
+    "TypeAsDifferentModule_basic",
     "TypeConversionF32ToF64Module_basic",
     "TypeConversionF64ToF32Module_basic",
     "TypeConversionI1ToF32Module_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index f3175678b..860b674bb 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8001,7 +8001,6 @@ def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
-  let hasFolder = 1;
 }
 
 def Torch_AtenViewOp : Torch_Op<"aten.view", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 6c48e5b26..e615945f6 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -712,20 +712,6 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
-//===----------------------------------------------------------------------===//
-// AtenTypeAsOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
-  Type inType = getSelf().getType();
-  Type newType = getOther().getType();
-
-  if (inType == newType)
-    return getSelf();
-
-  return nullptr;
-}
-
 //===----------------------------------------------------------------------===//
 // AtenToDtypeOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 51a107f43..13996d6db 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4590,6 +4590,29 @@ public:
 };
 } // namespace
 
+namespace {
+// Unconditionally decompose `torch.type_as` into `prim.dtype` +
+// `torch.to.dtype`.
+class DecomposeAtenTypeAsOp : public OpRewritePattern<AtenTypeAsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTypeAsOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input = op.getSelf();
+    auto other = op.getOther();
+    Location loc = op.getLoc();
+
+    Value targetDtype = rewriter.create<Torch::PrimDtypeOp>(loc, other);
+    Value nonBlocking = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    Value copy = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    Value memoryFormat = rewriter.create<Torch::ConstantNoneOp>(loc);
+    rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+        op, op.getType(), input, targetDtype, nonBlocking, copy, memoryFormat);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -4759,6 +4782,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
 
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 7fa9c26f8..fc9b845cf 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenTypeAsOp>();
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
         OperationName(kTorchOpPrefix + opName.first().str(), context));
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index da314cbb5..94938d2a4 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -521,7 +521,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True)
     emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
     emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
-> (Tensor)") - emit("aten::type_as : (Tensor, Tensor) -> (Tensor)", has_folder=True) + emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 6e15da5a4..5a4a19c50 100644 --- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -235,6 +235,27 @@ class TypeAsSameModule(torch.nn.Module): def TypeAsSameModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5), tu.rand(3, 5)) +class TypeAsDifferentModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.type_as(x, y) + + +@register_test_case(module_factory=lambda: TypeAsDifferentModule()) +def TypeAsDifferentModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 5, low=0, high=10, dtype=torch.int), + tu.randint(3, 5, low=0, high=10, dtype=torch.int64) + ) # ============================================================================== diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index d89092f8d..315c5a4fe 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1412,14 +1412,6 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to return %0 : !torch.tensor<[],f32> } -// CHECK-LABEL: func.func @torch.aten.type_as$same( -// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> { -// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32> -func.func @torch.aten.type_as$same(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> { - %0 = torch.aten.type_as %arg0, %arg0 : !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32> - return %0 : !torch.tensor<[?,?],f32> -} - // CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> { // CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32> diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index abaa2860c..17767f9f4 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -79,3 +79,27 @@ func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor %0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.type_as$basic( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int +// CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor +// CHECK: return %[[VAR]] : !torch.tensor +func.func @torch.aten.type_as$basic(%arg0: !torch.tensor, %arg1: 
+  %0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor
+  return %0 : !torch.tensor
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.type_as$fold(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor<[?],f16>, %[[ARG_1:.*]]: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
+// CHECK: return %[[ARG_0]] : !torch.tensor<[?],f16>
+func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
+  %0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
+  return %0 : !torch.tensor<[?], f16>
+}
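
Note: `aten.type_as` follows PyTorch's `Tensor.type_as` semantics, where `x.type_as(y)` casts `x` to `y`'s type; for tensors on the same device this amounts to `x.to(y.dtype)`, which is exactly what the decomposition above emits (`prim.dtype` to read the target dtype, then `aten.to.dtype` with `non_blocking=False`, `copy=False`, `memory_format=None`). Below is a minimal PyTorch sketch of that equivalence; the shapes and dtypes are illustrative, chosen to mirror the new e2e test rather than taken verbatim from it.

import torch

# x.type_as(y) casts x to y's dtype (and device class); the same-device
# case reduces to a plain dtype conversion.
x = torch.randint(0, 10, (3, 5), dtype=torch.int32)
y = torch.randint(0, 10, (3, 5), dtype=torch.int64)

out = torch.ops.aten.type_as(x, y)
ref = x.to(dtype=y.dtype, non_blocking=False, copy=False)
assert out.dtype == torch.int64
assert torch.equal(out, ref)

# Same-dtype case: type_as is a no-op. The removed AtenTypeAsOp folder
# special-cased this; after decomposition it is covered by the existing
# aten.to.dtype folding (see the @torch.aten.type_as$fold test above).
assert x.type_as(x).dtype == x.dtype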