From 11cd7cd9e7705fd69f40fabdad2e0e5b5b738914 Mon Sep 17 00:00:00 2001
From: Ze Zhang
Date: Thu, 2 May 2024 00:03:41 -0700
Subject: [PATCH] Folder and Canonicalizer for PrimsConvertElementTypeOp and
 AtenMaxPool2dWithIndicesOp (#3272)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

While playing with TorchDynamo on ResNet18, I noticed the following issues:

- `prims.convert_element_type` is not canonicalized away even when the input
  and the output share the same type
- `aten.max_pool2d_with_indices` is always used instead of `aten.max_pool2d`,
  even when the second result (the indices) has no users

This PR fixes the above issues by adding a folder to PrimsConvertElementTypeOp
and a canonicalizer to AtenMaxPool2dWithIndicesOp. (An illustrative
before/after IR sketch follows the patch.)

Lit test:

`cmake --build build --target check-torch-mlir-all`

---------

Co-authored-by: Ze Zhang
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td |  2 +
 lib/Dialect/Torch/IR/TorchOps.cpp         | 39 ++++++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py    |  5 ---
 .../build_tools/torch_ods_gen.py          |  5 ++-
 test/Dialect/Torch/canonicalize.mlir      | 41 +++++++++++++++++++
 5 files changed, 85 insertions(+), 7 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index cb08ffd53..95d92af99 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6720,6 +6720,7 @@ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices",
       printDefaultTorchOp(printer, *this, 6, 2);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
@@ -15982,6 +15983,7 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_PrimsVarOp : Torch_Op<"prims.var", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 29911961d..1d0ff41f7 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4715,6 +4715,45 @@ LogicalResult AtenPermuteOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PrimsConvertElementTypeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
+  auto inputType = cast<BaseTensorType>(getA().getType());
+  auto outputType = cast<BaseTensorType>(getResult().getType());
+  if (inputType != outputType)
+    return nullptr;
+  if (!inputType.hasDtype() || !outputType.hasDtype())
+    return nullptr;
+  if (inputType.getDtype() != outputType.getDtype())
+    return nullptr;
+  return getA();
+}
+
+//===----------------------------------------------------------------------===//
+// AtenMaxPool2dWithIndicesOp
+//===----------------------------------------------------------------------===//
+
+void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
+    if (!op.getResult1().use_empty()) {
+      return rewriter.notifyMatchFailure(
+          op, "result1 of MaxPool2dWithIndices should be unused");
+    }
+
+    Value result = rewriter.create<AtenMaxPool2dOp>(
+        op->getLoc(), op.getResult0().getType(), op.getSelf(),
+        op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
+        op.getCeilMode());
+
+    op.getResult0().replaceAllUsesWith(result);
+    rewriter.eraseOp(op);
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLinalgCrossOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 33f1ed702..d8529cb38 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1924,11 +1924,6 @@ MAKE_FX_TOSA_PASS_SET = (
     # Dynamic shape, has extra unsupported broadcast ops
     "Matmul_3d",
     "MatmulStaticBroadcast_basic",
-    # failed to legalize operation 'torch.aten.max_pool2d_with_indices
-    "MaxPool2dEmptyStrideStaticModule_basic",
-    "MaxPool2dStaticCeilModeTrueModule_basic",
-    "MaxPool2dStaticModule_basic",
-    "ResNet18StaticModule_basic",
     # Unimplemented operator 'aten._index_put_impl_.hacked_twin'
     "IndexPutImpl1DFloatNonAccumulateModule_basic",
     "IndexPutImpl1DIntNonAccumulateModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index e0329c8df..d4d547456 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -594,7 +594,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit(
-        "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
+        "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
+        has_canonicalizer=True,
     )
     emit(
         "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
@@ -1104,7 +1105,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     # `prims::` namespace.
     # ==========================================================================
 
-    emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
+    emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True)
     emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
     emit("prims::sqrt : (Tensor) -> (Tensor)")
     emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index a317e4011..e7605f661 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -2974,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
   %result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
   return %result : !torch.vtensor<[4], f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
+// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32>
+func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
+  %int6 = torch.constant.int 6
+  %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+  return %0 : !torch.vtensor<[64],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
+// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
+// CHECK: return %[[RET]] : !torch.vtensor<[64],si32>
+func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
+  %int6 = torch.constant.int 6
+  %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
+  return %0 : !torch.vtensor<[64],si32>
+}
+
+// -----
+
+// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
+// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]]
+// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32>
+func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
+  return %result0 : !torch.vtensor<[10,64,56,56],f32>
+}
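
Note (illustration, not part of the patch): a minimal before/after sketch of
what the new AtenMaxPool2dWithIndicesOp canonicalization does, using the shapes
from the lit test above. The value names (%x, %kernel, %stride, %pad,
%dilation, %false, %out, %indices) are placeholders, not names taken from the
patch.

  // Before: the two-result op is emitted even though %indices has no users.
  %out, %indices = torch.aten.max_pool2d_with_indices %x, %kernel, %stride, %pad, %dilation, %false
      : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>,
        !torch.list<int>, !torch.list<int>, !torch.bool
      -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>

  // After canonicalization: only the single-result op remains, which is why
  // the max_pool2d entries could be dropped from the exclusion list in
  // xfail_sets.py above.
  %out = torch.aten.max_pool2d %x, %kernel, %stride, %pad, %dilation, %false
      : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>,
        !torch.list<int>, !torch.list<int>, !torch.bool
      -> !torch.vtensor<[10,64,56,56],f32>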