mirror of https://github.com/llvm/torch-mlir
Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp (#3272)
While playing with TorchDynamo on ResNet18. I notice following issues: - `prims.convert_element_type` can’t be canonicalized even if the input and the output share the same type - `aten.max_pool2d_with_indices` is always used instead of `aten.max_pool2d`, even if the second returned output (indices) has no user This PR fixes above issues by adding a folder to the PrimsConvertElementTypeOp and a canonicalizer to the AtenMaxPool2dWithIndicesOp Lit test: `cmake --build build --target check-torch-mlir-all` --------- Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>pull/3275/head
parent
8c48135a42
commit
11cd7cd9e7
|
@ -6720,6 +6720,7 @@ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices",
|
|||
printDefaultTorchOp(printer, *this, 6, 2);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
|
||||
|
@ -15982,6 +15983,7 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimsVarOp : Torch_Op<"prims.var", [
|
||||
|
|
|
@ -4715,6 +4715,45 @@ LogicalResult AtenPermuteOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimsConvertElementTypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
|
||||
auto inputType = cast<BaseTensorType>(getA().getType());
|
||||
auto outputType = cast<BaseTensorType>(getResult().getType());
|
||||
if (inputType != outputType)
|
||||
return nullptr;
|
||||
if (!inputType.hasDtype() || !outputType.hasDtype())
|
||||
return nullptr;
|
||||
if (inputType.getDtype() != outputType.getDtype())
|
||||
return nullptr;
|
||||
return getA();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenMaxPool2dWithIndicesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
|
||||
if (!op.getResult1().use_empty()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "result1 of MaxPool2dWithIndices should be unused");
|
||||
}
|
||||
|
||||
Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
|
||||
op->getLoc(), op.getResult0().getType(), op.getSelf(),
|
||||
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
|
||||
op.getCeilMode());
|
||||
|
||||
op.getResult0().replaceAllUsesWith(result);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenLinalgCrossOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1924,11 +1924,6 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
# Dynamic shape, has extra unsupported broadcast ops
|
||||
"Matmul_3d",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
|
||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool2dStaticModule_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||
|
|
|
@ -594,7 +594,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
||||
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit(
|
||||
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||
|
@ -1104,7 +1105,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
# `prims::` namespace.
|
||||
# ==========================================================================
|
||||
|
||||
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
|
||||
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||
emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
|
||||
emit("prims::sqrt : (Tensor) -> (Tensor)")
|
||||
emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
|
||||
|
|
|
@ -2974,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
|
|||
%result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
|
||||
return %result : !torch.vtensor<[4], f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
|
||||
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
|
||||
// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32>
|
||||
func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
|
||||
%int6 = torch.constant.int 6
|
||||
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
|
||||
return %0 : !torch.vtensor<[64],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold(
|
||||
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
|
||||
// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[64],si32>
|
||||
func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
|
||||
%int6 = torch.constant.int 6
|
||||
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
|
||||
return %0 : !torch.vtensor<[64],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize(
|
||||
// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
|
||||
// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]]
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32>
|
||||
func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
|
||||
%false = torch.constant.bool false
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
|
||||
return %result0 : !torch.vtensor<[10,64,56,56],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue