Strength the shape inference for aten.arange-like op (#1367)

Strength the shape inference for aten.arange-like op by
1. registering aten.sub and aten.ceil.Scalar op and design folders for them.
2. register a new constant-like op: Torch::ConstantNumberOp and design canonicalizer for it.
pull/1385/head snapshot-20220920.602
武家伟 2022-09-20 12:40:19 +08:00 committed by GitHub
parent bb47b36eac
commit 4f3cd236dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 183 additions and 16 deletions

View File

@ -8914,6 +8914,55 @@ def Torch_AtenAddOp : Torch_Op<"aten.add", [
}]; }];
} }
def Torch_AtenSubOp : Torch_Op<"aten.sub", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sub : (Scalar, Scalar) -> (Scalar)`";
let arguments = (ins
AnyTorchScalarType:$a,
AnyTorchScalarType:$b
);
let results = (outs
AnyTorchScalarType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSubOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenSubOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenCeilScalarOp : Torch_Op<"aten.ceil.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::ceil.Scalar : (Scalar) -> (Scalar)`";
let arguments = (ins
AnyTorchScalarType:$a
);
let results = (outs
AnyTorchScalarType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCeilScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenCeilScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenSqrtIntOp : Torch_Op<"aten.sqrt.int", [ def Torch_AtenSqrtIntOp : Torch_Op<"aten.sqrt.int", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -753,6 +753,28 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float", [
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_ConstantNumberOp : Torch_Op<"constant.number",
[ConstantLike, NoSideEffect]> {
let summary = "Materialize a constant `number` value.";
let description = [{
This op is used as a workaround to the fact that the constant
materialization in MLIR must materialize a constant with a single op.
To materialize ops with a static `!torch.number` type, we must use this op,
even though we statically know if it is an integer or a float.
Note: This op unconditionally canonicalizes to
`torch.constant.{float,int}` + `torch.derefine`
}];
let arguments = (ins
AnyAttrOf<[F64Attr, I64Attr]>:$value
);
let results = (outs
Torch_NumberType:$result
);
let hasFolder = 1;
let hasCanonicalizer = 1;
}
def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [ def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [
ConstantLike, ConstantLike,
NoSideEffect, NoSideEffect,

View File

@ -149,6 +149,14 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
if (auto floatType = type.dyn_cast<Torch::FloatType>()) if (auto floatType = type.dyn_cast<Torch::FloatType>())
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>()); return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
if (auto numberType = type.dyn_cast<Torch::NumberType>()) {
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
return builder.create<Torch::ConstantNumberOp>(loc, intValue);
}
}
if (type.isa<Torch::BoolType>()) { if (type.isa<Torch::BoolType>()) {
return builder.create<Torch::ConstantBoolOp>(loc, return builder.create<Torch::ConstantBoolOp>(loc,
value.cast<IntegerAttr>()); value.cast<IntegerAttr>());

View File

@ -1591,6 +1591,34 @@ void Torch::ConstantFloatOp::getAsmResultNames(
setNameFn(getResult(), StringRef(buf.data(), buf.size())); setNameFn(getResult(), StringRef(buf.data(), buf.size()));
} }
//===----------------------------------------------------------------------===//
// ConstantNumberOp
//===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
return valueAttr();
}
void Torch::ConstantNumberOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](Torch::ConstantNumberOp op, PatternRewriter &rewriter) {
Location loc = op->getLoc();
Value constValue;
Attribute value = op.valueAttr();
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
} else {
return failure();
}
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, op.getType(),
constValue);
return success();
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConstantBoolOp // ConstantBoolOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1886,15 +1914,39 @@ OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
} }
using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>; using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
template <typename OpTy> static OpFoldResult
static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op, atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
BinaryIntOperatorFn f) { BinaryIntOperatorFn f) {
int64_t lhs, rhs; auto intLhs = operands[0].dyn_cast_or_null<IntegerAttr>();
if (!matchPattern(op.getOperand(0), m_TorchConstantInt(&lhs)) || auto intRhs = operands[1].dyn_cast_or_null<IntegerAttr>();
!matchPattern(op.getOperand(1), m_TorchConstantInt(&rhs))) if (!intLhs || !intRhs) {
return nullptr; return nullptr;
}
return IntegerAttr::get(
intLhs.getType(),
f(intLhs.getValue().getSExtValue(), intRhs.getValue().getSExtValue()));
}
return getI64IntegerAttr(op.getContext(), f(lhs, rhs)); using BinaryFloatOperatorFn = std::function<double(double, double)>;
static OpFoldResult
atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
BinaryFloatOperatorFn f) {
double lhs, rhs;
auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
if (auto intLhs = attr.dyn_cast_or_null<IntegerAttr>()) {
value = static_cast<double>(intLhs.getValue().getSExtValue());
} else if (auto floatLhs = attr.dyn_cast_or_null<FloatAttr>()) {
value = floatLhs.getValue().convertToDouble();
} else {
return false;
}
return true;
};
if (!parseDoubleAttribute(operands[0], lhs) ||
!parseDoubleAttribute(operands[1], rhs)) {
return nullptr;
}
return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1903,7 +1955,7 @@ static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op,
OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) { OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
*this, [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1912,7 +1964,7 @@ OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) { OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
*this, [](int64_t a, int64_t b) { return a % b; }); operands, [](int64_t a, int64_t b) { return a % b; });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1921,7 +1973,7 @@ OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) { OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
*this, [](int64_t a, int64_t b) { return a + b; }); operands, [](int64_t a, int64_t b) { return a + b; });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1930,7 +1982,7 @@ OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) { OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
*this, [](int64_t a, int64_t b) { return a - b; }); operands, [](int64_t a, int64_t b) { return a - b; });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1948,6 +2000,40 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// AtenSubOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0] || !operands[1]) {
return nullptr;
}
if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
}
return atenBinaryFloatOperatorFoldHelper(
operands, [](double a, double b) -> double { return a - b; });
}
//===----------------------------------------------------------------------===//
// AtenCeilScalarOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0]) {
return nullptr;
}
auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
if (!floatValue) {
return nullptr;
}
return getI64IntegerAttr(
getContext(),
static_cast<int64_t>(std::ceil(floatValue.getValue().convertToDouble())));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenNegIntOp // AtenNegIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -969,7 +969,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n" " torch.prim.If.yield\n"
" }\n" " }\n"
" %1 = torch.operator \"aten.ceil.Scalar\"(%arg0) : (!torch.union<float, int>) -> !torch.number\n" " %1 = torch.aten.ceil.Scalar %arg0 : !torch.union<float, int> -> !torch.number\n"
" %2 = torch.aten.Int.Scalar %1 : !torch.number -> !torch.int\n" " %2 = torch.aten.Int.Scalar %1 : !torch.number -> !torch.int\n"
" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n" " %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n"
" return %3 : !torch.list<int>\n" " return %3 : !torch.list<int>\n"
@ -992,8 +992,8 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n" " torch.prim.If.yield\n"
" }\n" " }\n"
" %2 = torch.operator \"aten.sub\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.number\n" " %2 = torch.aten.sub %arg1, %arg0 : !torch.union<float, int>, !torch.union<float, int> -> !torch.number\n"
" %3 = torch.operator \"aten.ceil.Scalar\"(%2) : (!torch.number) -> !torch.number\n" " %3 = torch.aten.ceil.Scalar %2 : !torch.number -> !torch.number\n"
" %4 = torch.aten.Int.Scalar %3 : !torch.number -> !torch.int\n" " %4 = torch.aten.Int.Scalar %3 : !torch.number -> !torch.int\n"
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n" " %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
" return %5 : !torch.list<int>\n" " return %5 : !torch.list<int>\n"
@ -1029,7 +1029,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" }\n" " }\n"
" torch.prim.If.yield\n" " torch.prim.If.yield\n"
" }\n" " }\n"
" %2 = torch.operator \"aten.sub\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.number\n" " %2 = torch.aten.sub %arg1, %arg0 : !torch.union<float, int>, !torch.union<float, int> -> !torch.number\n"
" %3 = torch.aten.div %2, %arg2 : !torch.number, !torch.union<float, int> -> !torch.float\n" " %3 = torch.aten.div %2, %arg2 : !torch.number, !torch.union<float, int> -> !torch.float\n"
" %4 = torch.aten.ceil.float %3 : !torch.float -> !torch.int\n" " %4 = torch.aten.ceil.float %3 : !torch.float -> !torch.int\n"
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n" " %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"

View File

@ -592,6 +592,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::_set_item.t : (t[], int, t) -> (t[])") emit("aten::_set_item.t : (t[], int, t) -> (t[])")
emit("aten::div : (Scalar, Scalar) -> (float)") emit("aten::div : (Scalar, Scalar) -> (float)")
emit("aten::add : (Scalar, Scalar) -> (Scalar)") emit("aten::add : (Scalar, Scalar) -> (Scalar)")
emit("aten::sub : (Scalar, Scalar) -> (Scalar)", has_folder=True)
emit("aten::ceil.Scalar : (Scalar) -> (Scalar)", has_folder=True)
emit("aten::sqrt.int : (int) -> (float)", has_folder=True) emit("aten::sqrt.int : (int) -> (float)", has_folder=True)
emit("aten::Bool.float : (float) -> (bool)", has_folder=True) emit("aten::Bool.float : (float) -> (bool)", has_folder=True)
emit("aten::Bool.int : (int) -> (bool)", has_folder=True) emit("aten::Bool.int : (int) -> (bool)", has_folder=True)