From 4f3cd236dd443ba714b24ea10d7151f0e0fb7625 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=AD=A6=E5=AE=B6=E4=BC=9F?= <73166454+Vremold@users.noreply.github.com>
Date: Tue, 20 Sep 2022 12:40:19 +0800
Subject: [PATCH] Strengthen the shape inference for aten.arange-like ops
 (#1367)

Strengthen the shape inference for aten.arange-like ops by
1. registering the aten.sub and aten.ceil.Scalar ops and designing folders
   for them;
2. registering a new constant-like op, Torch::ConstantNumberOp, and
   designing a canonicalizer for it.
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  49 ++++++++
 .../torch-mlir/Dialect/Torch/IR/TorchOps.td   |  22 ++++
 lib/Dialect/Torch/IR/TorchDialect.cpp         |   8 ++
 lib/Dialect/Torch/IR/TorchOps.cpp             | 108 ++++++++++++++++--
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |  10 +-
 .../jit_ir/build_tools/torch_ods_gen.py       |   2 +
 6 files changed, 183 insertions(+), 16 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 87d52ff94..4eb65b80c 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8914,6 +8914,55 @@ def Torch_AtenAddOp : Torch_Op<"aten.add", [
   }];
 }
 
+def Torch_AtenSubOp : Torch_Op<"aten.sub", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::sub : (Scalar, Scalar) -> (Scalar)`";
+  let arguments = (ins
+    AnyTorchScalarType:$a,
+    AnyTorchScalarType:$b
+  );
+  let results = (outs
+    AnyTorchScalarType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSubOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenSubOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
+def Torch_AtenCeilScalarOp : Torch_Op<"aten.ceil.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::ceil.Scalar : (Scalar) -> (Scalar)`";
+  let arguments = (ins
+    AnyTorchScalarType:$a
+  );
+  let results = (outs
+    AnyTorchScalarType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCeilScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenCeilScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenSqrtIntOp : Torch_Op<"aten.sqrt.int", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index d46431e6c..df73855a3 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -753,6 +753,28 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float", [
   let hasFolder = 1;
 }
 
+def Torch_ConstantNumberOp : Torch_Op<"constant.number",
+    [ConstantLike, NoSideEffect]> {
+  let summary = "Materialize a constant `number` value.";
+  let description = [{
+    This op is used as a workaround for the fact that constant
+    materialization in MLIR must materialize a constant with a single op.
+    To materialize ops with a static `!torch.number` type, we must use this
+    op, even though we statically know whether it is an integer or a float.
+
+    Note: This op unconditionally canonicalizes to
+    `torch.constant.{float,int}` + `torch.derefine`.
+  }];
+  let arguments = (ins
+    AnyAttrOf<[F64Attr, I64Attr]>:$value
+  );
+  let results = (outs
+    Torch_NumberType:$result
+  );
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
 def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [
     ConstantLike,
     NoSideEffect,
diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp
index 835be0316..a29c2e16a 100644
--- a/lib/Dialect/Torch/IR/TorchDialect.cpp
+++ b/lib/Dialect/Torch/IR/TorchDialect.cpp
@@ -149,6 +149,14 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
   if (auto floatType = type.dyn_cast<Torch::FloatType>())
     return builder.create<Torch::ConstantFloatOp>(loc,
                                                   value.cast<FloatAttr>());
 
+  if (auto numberType = type.dyn_cast<Torch::NumberType>()) {
+    if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
+      return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
+    } else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
+      return builder.create<Torch::ConstantNumberOp>(loc, intValue);
+    }
+  }
+
   if (type.isa<Torch::BoolType>()) {
     return builder.create<Torch::ConstantBoolOp>(loc,
                                                  value.cast<IntegerAttr>());
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 2a54c55ec..0cae61f20 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1591,6 +1591,34 @@ void Torch::ConstantFloatOp::getAsmResultNames(
   setNameFn(getResult(), StringRef(buf.data(), buf.size()));
 }
 
+//===----------------------------------------------------------------------===//
+// ConstantNumberOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
+  return valueAttr();
+}
+
+void Torch::ConstantNumberOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](Torch::ConstantNumberOp op, PatternRewriter &rewriter) {
+    Location loc = op->getLoc();
+
+    Value constValue;
+    Attribute value = op.valueAttr();
+    if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
+      constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
+    } else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
+      constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
+    } else {
+      return failure();
+    }
+    rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, op.getType(),
+                                                   constValue);
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantBoolOp
 //===----------------------------------------------------------------------===//
@@ -1886,15 +1914,39 @@ OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
 }
 
 using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
-template <typename OpTy>
-static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op,
-                                                    BinaryIntOperatorFn f) {
-  int64_t lhs, rhs;
-  if (!matchPattern(op.getOperand(0), m_TorchConstantInt(&lhs)) ||
-      !matchPattern(op.getOperand(1), m_TorchConstantInt(&rhs)))
+static OpFoldResult
+atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
+                                BinaryIntOperatorFn f) {
+  auto intLhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  auto intRhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!intLhs || !intRhs) {
     return nullptr;
+  }
+  return IntegerAttr::get(
+      intLhs.getType(),
+      f(intLhs.getValue().getSExtValue(), intRhs.getValue().getSExtValue()));
+}
 
-  return getI64IntegerAttr(op.getContext(), f(lhs, rhs));
+using BinaryFloatOperatorFn = std::function<double(double, double)>;
+static OpFoldResult
+atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
+                                  BinaryFloatOperatorFn f) {
+  double lhs, rhs;
+  auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
+    if (auto intLhs = attr.dyn_cast_or_null<IntegerAttr>()) {
+      value =
+          static_cast<double>(intLhs.getValue().getSExtValue());
+    } else if (auto floatLhs = attr.dyn_cast_or_null<FloatAttr>()) {
+      value = floatLhs.getValue().convertToDouble();
+    } else {
+      return false;
+    }
+    return true;
+  };
+  if (!parseDoubleAttribute(operands[0], lhs) ||
+      !parseDoubleAttribute(operands[1], rhs)) {
+    return nullptr;
+  }
+  return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs));
 }
 
 //===----------------------------------------------------------------------===//
@@ -1903,7 +1955,7 @@ static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op,
 
 OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
   return atenBinaryIntOperatorFoldHelper(
-      *this, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
+      operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1912,7 +1964,7 @@ OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
   return atenBinaryIntOperatorFoldHelper(
-      *this, [](int64_t a, int64_t b) { return a % b; });
+      operands, [](int64_t a, int64_t b) { return a % b; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1921,7 +1973,7 @@ OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
   return atenBinaryIntOperatorFoldHelper(
-      *this, [](int64_t a, int64_t b) { return a + b; });
+      operands, [](int64_t a, int64_t b) { return a + b; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1930,7 +1982,7 @@ OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
   return atenBinaryIntOperatorFoldHelper(
-      *this, [](int64_t a, int64_t b) { return a - b; });
+      operands, [](int64_t a, int64_t b) { return a - b; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1948,6 +2000,40 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenSubOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1]) {
+    return nullptr;
+  }
+
+  if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
+    return atenBinaryIntOperatorFoldHelper(
+        operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
+  }
+  return atenBinaryFloatOperatorFoldHelper(
+      operands, [](double a, double b) -> double { return a - b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AtenCeilScalarOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0]) {
+    return nullptr;
+  }
+  auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
+  if (!floatValue) {
+    return nullptr;
+  }
+  return getI64IntegerAttr(
+      getContext(),
+      static_cast<int64_t>(std::ceil(floatValue.getValue().convertToDouble())));
+}
+
 //===----------------------------------------------------------------------===//
 // AtenNegIntOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index a5e0872a8..034228ebf 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -969,7 +969,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
-"    %1 = torch.operator \"aten.ceil.Scalar\"(%arg0) : (!torch.union<float, int>) -> !torch.number\n"
+"    %1 = torch.aten.ceil.Scalar %arg0 : !torch.union<float, int> -> !torch.number\n"
 "    %2 = torch.aten.Int.Scalar %1 : !torch.number -> !torch.int\n"
 "    %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n"
 "    return %3 : !torch.list<int>\n"
@@ -992,8 +992,8 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
-"    %2 = torch.operator \"aten.sub\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.number\n"
-"    %3 = torch.operator \"aten.ceil.Scalar\"(%2) : (!torch.number) -> !torch.number\n"
+"    %2 = torch.aten.sub %arg1, %arg0 : !torch.union<float, int>, !torch.union<float, int> -> !torch.number\n"
+"    %3 = torch.aten.ceil.Scalar %2 : !torch.number -> !torch.number\n"
 "    %4 = torch.aten.Int.Scalar %3 : !torch.number -> !torch.int\n"
 "    %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
 "    return %5 : !torch.list<int>\n"
@@ -1029,7 +1029,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "      }\n"
 "      torch.prim.If.yield\n"
 "    }\n"
-"    %2 = torch.operator \"aten.sub\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.number\n"
+"    %2 = torch.aten.sub %arg1, %arg0 : !torch.union<float, int>, !torch.union<float, int> -> !torch.number\n"
 "    %3 = torch.aten.div %2, %arg2 : !torch.number, !torch.union<float, int> -> !torch.float\n"
 "    %4 = torch.aten.ceil.float %3 : !torch.float -> !torch.int\n"
 "    %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
@@ -7821,4 +7821,4 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 #ifndef _MSC_VER
 #pragma clang diagnostic pop
 #endif
-}
+}
\ No newline at end of file
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index bdd07110e..01dbfeb4b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -592,6 +592,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::_set_item.t : (t[], int, t) -> (t[])")
     emit("aten::div : (Scalar, Scalar) -> (float)")
     emit("aten::add : (Scalar, Scalar) -> (Scalar)")
+    emit("aten::sub : (Scalar, Scalar) -> (Scalar)", has_folder=True)
+    emit("aten::ceil.Scalar : (Scalar) -> (Scalar)", has_folder=True)
     emit("aten::sqrt.int : (int) -> (float)", has_folder=True)
     emit("aten::Bool.float : (float) -> (bool)", has_folder=True)
     emit("aten::Bool.int : (int) -> (bool)", has_folder=True)
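
Illustration: the pieces above combine so that a shape function for an
aten.arange-like op can fold its result size down to a constant. A rough
sketch of the new canonicalization (the SSA names and the printed form of
torch.constant.number below are illustrative, not taken from this patch):

    // Before canonicalization: a materialized !torch.number constant.
    %num = torch.constant.number 4.5

    // After: the statically typed constant plus a derefine back to
    // !torch.number, as described in the op's TableGen description.
    %float = torch.constant.float 4.5
    %num = torch.derefine %float : !torch.float to !torch.number

Together with the new folders, a chain such as torch.aten.sub of the end and
start scalars followed by torch.aten.ceil.Scalar can then fold to a plain
torch.constant.int, letting shape inference produce a static result size for
the arange.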