mirror of https://github.com/llvm/torch-mlir
Add canonicalization pattern for maxpool3d with indices op (#3704)
As discussed in https://github.com/llvm/torch-mlir/pull/3652, we should replace maxpool3dwithindices with maxpool3d if indices have no user.pull/3814/head
parent
55ff110dc2
commit
2f9a68cc1e
|
@ -7352,6 +7352,7 @@ def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices",
|
||||||
printDefaultTorchOp(printer, *this, 6, 2);
|
printDefaultTorchOp(printer, *this, 6, 2);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [
|
def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [
|
||||||
|
|
|
@ -5188,18 +5188,38 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenMaxPool2dWithIndicesOp
|
// AtenMaxPoolWithIndicesOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
|
namespace {
|
||||||
RewritePatternSet &patterns, MLIRContext *context) {
|
|
||||||
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
|
template <typename OpTy> struct MaxPoolWithoutIndices {
|
||||||
|
using type = OpTy;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct MaxPoolWithoutIndices<AtenMaxPool2dWithIndicesOp> {
|
||||||
|
using type = AtenMaxPool2dOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct MaxPoolWithoutIndices<AtenMaxPool3dWithIndicesOp> {
|
||||||
|
using type = AtenMaxPool3dOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
struct SimplifyMaxPoolWithIndices : public mlir::OpRewritePattern<OpTy> {
|
||||||
|
SimplifyMaxPoolWithIndices(mlir::MLIRContext *context)
|
||||||
|
: OpRewritePattern<OpTy>(context, /*benefit=*/1) {}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(OpTy op, mlir::PatternRewriter &rewriter) const override {
|
||||||
if (!op.getResult1().use_empty()) {
|
if (!op.getResult1().use_empty()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "result1 of MaxPool2dWithIndices should be unused");
|
op, "result1 of MaxPoolWithIndices should be unused");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
|
Value result = rewriter.create<typename MaxPoolWithoutIndices<OpTy>::type>(
|
||||||
op->getLoc(), op.getResult0().getType(), op.getSelf(),
|
op->getLoc(), op.getResult0().getType(), op.getSelf(),
|
||||||
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
|
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
|
||||||
op.getCeilMode());
|
op.getCeilMode());
|
||||||
|
@ -5207,7 +5227,17 @@ void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
|
||||||
op.getResult0().replaceAllUsesWith(result);
|
op.getResult0().replaceAllUsesWith(result);
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
});
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
|
||||||
|
RewritePatternSet &patterns, MLIRContext *context) {
|
||||||
|
patterns.add<SimplifyMaxPoolWithIndices<AtenMaxPool2dWithIndicesOp>>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AtenMaxPool3dWithIndicesOp::getCanonicalizationPatterns(
|
||||||
|
RewritePatternSet &patterns, MLIRContext *context) {
|
||||||
|
patterns.add<SimplifyMaxPoolWithIndices<AtenMaxPool3dWithIndicesOp>>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -636,7 +636,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||||
emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
|
emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
|
||||||
|
has_canonicalizer=True,
|
||||||
)
|
)
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
"aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||||
|
|
|
@ -3136,6 +3136,24 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @torch.aten.max_pool3d_with_indices$canonicalize(
|
||||||
|
// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> {
|
||||||
|
// CHECK: %[[RET:.*]] = torch.aten.max_pool3d %[[ARG]]
|
||||||
|
// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56,56],f32>
|
||||||
|
func.func @torch.aten.max_pool3d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> {
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%result0, %result1 = torch.aten.max_pool3d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56,56],f32>, !torch.vtensor<[10,64,56,56,56],si64>
|
||||||
|
return %result0 : !torch.vtensor<[10,64,56,56,56],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @torch.aten.clone$no_fold(
|
// CHECK-LABEL: @torch.aten.clone$no_fold(
|
||||||
func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) {
|
func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) {
|
||||||
// CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
|
// CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
|
||||||
|
|
Loading…
Reference in New Issue