mirror of https://github.com/llvm/torch-mlir
Strength the shape inference for aten.arange-like op (#1367)
Strength the shape inference for aten.arange-like op by 1. registering aten.sub and aten.ceil.Scalar op and design folders for them. 2. register a new constant-like op: Torch::ConstantNumberOp and design canonicalizer for it.pull/1385/head snapshot-20220920.602
parent
bb47b36eac
commit
4f3cd236dd
|
@ -8914,6 +8914,55 @@ def Torch_AtenAddOp : Torch_Op<"aten.add", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenSubOp : Torch_Op<"aten.sub", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::sub : (Scalar, Scalar) -> (Scalar)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchScalarType:$a,
|
||||||
|
AnyTorchScalarType:$b
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchScalarType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenSubOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenSubOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenCeilScalarOp : Torch_Op<"aten.ceil.Scalar", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::ceil.Scalar : (Scalar) -> (Scalar)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchScalarType:$a
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchScalarType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenCeilScalarOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenCeilScalarOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenSqrtIntOp : Torch_Op<"aten.sqrt.int", [
|
def Torch_AtenSqrtIntOp : Torch_Op<"aten.sqrt.int", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -753,6 +753,28 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float", [
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_ConstantNumberOp : Torch_Op<"constant.number",
|
||||||
|
[ConstantLike, NoSideEffect]> {
|
||||||
|
let summary = "Materialize a constant `number` value.";
|
||||||
|
let description = [{
|
||||||
|
This op is used as a workaround to the fact that the constant
|
||||||
|
materialization in MLIR must materialize a constant with a single op.
|
||||||
|
To materialize ops with a static `!torch.number` type, we must use this op,
|
||||||
|
even though we statically know if it is an integer or a float.
|
||||||
|
|
||||||
|
Note: This op unconditionally canonicalizes to
|
||||||
|
`torch.constant.{float,int}` + `torch.derefine`
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
AnyAttrOf<[F64Attr, I64Attr]>:$value
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_NumberType:$result
|
||||||
|
);
|
||||||
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [
|
def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [
|
||||||
ConstantLike,
|
ConstantLike,
|
||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
|
|
|
@ -149,6 +149,14 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
|
||||||
if (auto floatType = type.dyn_cast<Torch::FloatType>())
|
if (auto floatType = type.dyn_cast<Torch::FloatType>())
|
||||||
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
|
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
|
||||||
|
|
||||||
|
if (auto numberType = type.dyn_cast<Torch::NumberType>()) {
|
||||||
|
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
|
||||||
|
return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
|
||||||
|
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
|
||||||
|
return builder.create<Torch::ConstantNumberOp>(loc, intValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (type.isa<Torch::BoolType>()) {
|
if (type.isa<Torch::BoolType>()) {
|
||||||
return builder.create<Torch::ConstantBoolOp>(loc,
|
return builder.create<Torch::ConstantBoolOp>(loc,
|
||||||
value.cast<IntegerAttr>());
|
value.cast<IntegerAttr>());
|
||||||
|
|
|
@ -1591,6 +1591,34 @@ void Torch::ConstantFloatOp::getAsmResultNames(
|
||||||
setNameFn(getResult(), StringRef(buf.data(), buf.size()));
|
setNameFn(getResult(), StringRef(buf.data(), buf.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstantNumberOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
return valueAttr();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Torch::ConstantNumberOp::getCanonicalizationPatterns(
|
||||||
|
RewritePatternSet &patterns, MLIRContext *context) {
|
||||||
|
patterns.add(+[](Torch::ConstantNumberOp op, PatternRewriter &rewriter) {
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
|
||||||
|
Value constValue;
|
||||||
|
Attribute value = op.valueAttr();
|
||||||
|
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
|
||||||
|
constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
|
||||||
|
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
|
||||||
|
constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
|
||||||
|
} else {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, op.getType(),
|
||||||
|
constValue);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ConstantBoolOp
|
// ConstantBoolOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1886,15 +1914,39 @@ OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
|
||||||
}
|
}
|
||||||
|
|
||||||
using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
|
using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
|
||||||
template <typename OpTy>
|
static OpFoldResult
|
||||||
static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op,
|
atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
|
||||||
BinaryIntOperatorFn f) {
|
BinaryIntOperatorFn f) {
|
||||||
int64_t lhs, rhs;
|
auto intLhs = operands[0].dyn_cast_or_null<IntegerAttr>();
|
||||||
if (!matchPattern(op.getOperand(0), m_TorchConstantInt(&lhs)) ||
|
auto intRhs = operands[1].dyn_cast_or_null<IntegerAttr>();
|
||||||
!matchPattern(op.getOperand(1), m_TorchConstantInt(&rhs)))
|
if (!intLhs || !intRhs) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
}
|
||||||
|
return IntegerAttr::get(
|
||||||
|
intLhs.getType(),
|
||||||
|
f(intLhs.getValue().getSExtValue(), intRhs.getValue().getSExtValue()));
|
||||||
|
}
|
||||||
|
|
||||||
return getI64IntegerAttr(op.getContext(), f(lhs, rhs));
|
using BinaryFloatOperatorFn = std::function<double(double, double)>;
|
||||||
|
static OpFoldResult
|
||||||
|
atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
|
||||||
|
BinaryFloatOperatorFn f) {
|
||||||
|
double lhs, rhs;
|
||||||
|
auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
|
||||||
|
if (auto intLhs = attr.dyn_cast_or_null<IntegerAttr>()) {
|
||||||
|
value = static_cast<double>(intLhs.getValue().getSExtValue());
|
||||||
|
} else if (auto floatLhs = attr.dyn_cast_or_null<FloatAttr>()) {
|
||||||
|
value = floatLhs.getValue().convertToDouble();
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
if (!parseDoubleAttribute(operands[0], lhs) ||
|
||||||
|
!parseDoubleAttribute(operands[1], rhs)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1903,7 +1955,7 @@ static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op,
|
||||||
|
|
||||||
OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return atenBinaryIntOperatorFoldHelper(
|
return atenBinaryIntOperatorFoldHelper(
|
||||||
*this, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1912,7 +1964,7 @@ OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
|
||||||
OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return atenBinaryIntOperatorFoldHelper(
|
return atenBinaryIntOperatorFoldHelper(
|
||||||
*this, [](int64_t a, int64_t b) { return a % b; });
|
operands, [](int64_t a, int64_t b) { return a % b; });
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1921,7 +1973,7 @@ OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
|
||||||
OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return atenBinaryIntOperatorFoldHelper(
|
return atenBinaryIntOperatorFoldHelper(
|
||||||
*this, [](int64_t a, int64_t b) { return a + b; });
|
operands, [](int64_t a, int64_t b) { return a + b; });
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1930,7 +1982,7 @@ OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
|
||||||
OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return atenBinaryIntOperatorFoldHelper(
|
return atenBinaryIntOperatorFoldHelper(
|
||||||
*this, [](int64_t a, int64_t b) { return a - b; });
|
operands, [](int64_t a, int64_t b) { return a - b; });
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1948,6 +2000,40 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenSubOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
if (!operands[0] || !operands[1]) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
|
||||||
|
return atenBinaryIntOperatorFoldHelper(
|
||||||
|
operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
|
||||||
|
}
|
||||||
|
return atenBinaryFloatOperatorFoldHelper(
|
||||||
|
operands, [](double a, double b) -> double { return a - b; });
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenCeilScalarOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
if (!operands[0]) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
|
||||||
|
if (!floatValue) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return getI64IntegerAttr(
|
||||||
|
getContext(),
|
||||||
|
static_cast<int64_t>(std::ceil(floatValue.getValue().convertToDouble())));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenNegIntOp
|
// AtenNegIntOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -969,7 +969,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
" torch.prim.If.yield\n"
|
" torch.prim.If.yield\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" %1 = torch.operator \"aten.ceil.Scalar\"(%arg0) : (!torch.union<float, int>) -> !torch.number\n"
|
" %1 = torch.aten.ceil.Scalar %arg0 : !torch.union<float, int> -> !torch.number\n"
|
||||||
" %2 = torch.aten.Int.Scalar %1 : !torch.number -> !torch.int\n"
|
" %2 = torch.aten.Int.Scalar %1 : !torch.number -> !torch.int\n"
|
||||||
" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n"
|
" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n"
|
||||||
" return %3 : !torch.list<int>\n"
|
" return %3 : !torch.list<int>\n"
|
||||||
|
@ -992,8 +992,8 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
" torch.prim.If.yield\n"
|
" torch.prim.If.yield\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" %2 = torch.operator \"aten.sub\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.number\n"
|
" %2 = torch.aten.sub %arg1, %arg0 : !torch.union<float, int>, !torch.union<float, int> -> !torch.number\n"
|
||||||
" %3 = torch.operator \"aten.ceil.Scalar\"(%2) : (!torch.number) -> !torch.number\n"
|
" %3 = torch.aten.ceil.Scalar %2 : !torch.number -> !torch.number\n"
|
||||||
" %4 = torch.aten.Int.Scalar %3 : !torch.number -> !torch.int\n"
|
" %4 = torch.aten.Int.Scalar %3 : !torch.number -> !torch.int\n"
|
||||||
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
|
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
|
||||||
" return %5 : !torch.list<int>\n"
|
" return %5 : !torch.list<int>\n"
|
||||||
|
@ -1029,7 +1029,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" torch.prim.If.yield\n"
|
" torch.prim.If.yield\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" %2 = torch.operator \"aten.sub\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.number\n"
|
" %2 = torch.aten.sub %arg1, %arg0 : !torch.union<float, int>, !torch.union<float, int> -> !torch.number\n"
|
||||||
" %3 = torch.aten.div %2, %arg2 : !torch.number, !torch.union<float, int> -> !torch.float\n"
|
" %3 = torch.aten.div %2, %arg2 : !torch.number, !torch.union<float, int> -> !torch.float\n"
|
||||||
" %4 = torch.aten.ceil.float %3 : !torch.float -> !torch.int\n"
|
" %4 = torch.aten.ceil.float %3 : !torch.float -> !torch.int\n"
|
||||||
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
|
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
|
||||||
|
|
|
@ -592,6 +592,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
||||||
emit("aten::div : (Scalar, Scalar) -> (float)")
|
emit("aten::div : (Scalar, Scalar) -> (float)")
|
||||||
emit("aten::add : (Scalar, Scalar) -> (Scalar)")
|
emit("aten::add : (Scalar, Scalar) -> (Scalar)")
|
||||||
|
emit("aten::sub : (Scalar, Scalar) -> (Scalar)", has_folder=True)
|
||||||
|
emit("aten::ceil.Scalar : (Scalar) -> (Scalar)", has_folder=True)
|
||||||
emit("aten::sqrt.int : (int) -> (float)", has_folder=True)
|
emit("aten::sqrt.int : (int) -> (float)", has_folder=True)
|
||||||
emit("aten::Bool.float : (float) -> (bool)", has_folder=True)
|
emit("aten::Bool.float : (float) -> (bool)", has_folder=True)
|
||||||
emit("aten::Bool.int : (int) -> (bool)", has_folder=True)
|
emit("aten::Bool.int : (int) -> (bool)", has_folder=True)
|
||||||
|
|
Loading…
Reference in New Issue