mirror of https://github.com/llvm/torch-mlir
Strength the shape inference for aten.arange-like op (#1367)
Strength the shape inference for aten.arange-like op by 1. registering aten.sub and aten.ceil.Scalar op and design folders for them. 2. register a new constant-like op: Torch::ConstantNumberOp and design canonicalizer for it.pull/1385/head snapshot-20220920.602
parent
bb47b36eac
commit
4f3cd236dd
|
@ -8914,6 +8914,55 @@ def Torch_AtenAddOp : Torch_Op<"aten.add", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSubOp : Torch_Op<"aten.sub", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::sub : (Scalar, Scalar) -> (Scalar)`";
|
||||
let arguments = (ins
|
||||
AnyTorchScalarType:$a,
|
||||
AnyTorchScalarType:$b
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchScalarType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSubOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenSubOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenCeilScalarOp : Torch_Op<"aten.ceil.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::ceil.Scalar : (Scalar) -> (Scalar)`";
|
||||
let arguments = (ins
|
||||
AnyTorchScalarType:$a
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchScalarType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenCeilScalarOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenCeilScalarOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenSqrtIntOp : Torch_Op<"aten.sqrt.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -753,6 +753,28 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_ConstantNumberOp : Torch_Op<"constant.number",
|
||||
[ConstantLike, NoSideEffect]> {
|
||||
let summary = "Materialize a constant `number` value.";
|
||||
let description = [{
|
||||
This op is used as a workaround to the fact that the constant
|
||||
materialization in MLIR must materialize a constant with a single op.
|
||||
To materialize ops with a static `!torch.number` type, we must use this op,
|
||||
even though we statically know if it is an integer or a float.
|
||||
|
||||
Note: This op unconditionally canonicalizes to
|
||||
`torch.constant.{float,int}` + `torch.derefine`
|
||||
}];
|
||||
let arguments = (ins
|
||||
AnyAttrOf<[F64Attr, I64Attr]>:$value
|
||||
);
|
||||
let results = (outs
|
||||
Torch_NumberType:$result
|
||||
);
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [
|
||||
ConstantLike,
|
||||
NoSideEffect,
|
||||
|
|
|
@ -149,6 +149,14 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
|
|||
if (auto floatType = type.dyn_cast<Torch::FloatType>())
|
||||
return builder.create<Torch::ConstantFloatOp>(loc, value.cast<FloatAttr>());
|
||||
|
||||
if (auto numberType = type.dyn_cast<Torch::NumberType>()) {
|
||||
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
|
||||
return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
|
||||
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
|
||||
return builder.create<Torch::ConstantNumberOp>(loc, intValue);
|
||||
}
|
||||
}
|
||||
|
||||
if (type.isa<Torch::BoolType>()) {
|
||||
return builder.create<Torch::ConstantBoolOp>(loc,
|
||||
value.cast<IntegerAttr>());
|
||||
|
|
|
@ -1591,6 +1591,34 @@ void Torch::ConstantFloatOp::getAsmResultNames(
|
|||
setNameFn(getResult(), StringRef(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantNumberOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
|
||||
return valueAttr();
|
||||
}
|
||||
|
||||
void Torch::ConstantNumberOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](Torch::ConstantNumberOp op, PatternRewriter &rewriter) {
|
||||
Location loc = op->getLoc();
|
||||
|
||||
Value constValue;
|
||||
Attribute value = op.valueAttr();
|
||||
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
|
||||
constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
|
||||
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
|
||||
constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, op.getType(),
|
||||
constValue);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantBoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1886,15 +1914,39 @@ OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
|
|||
}
|
||||
|
||||
using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
|
||||
template <typename OpTy>
|
||||
static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op,
|
||||
BinaryIntOperatorFn f) {
|
||||
int64_t lhs, rhs;
|
||||
if (!matchPattern(op.getOperand(0), m_TorchConstantInt(&lhs)) ||
|
||||
!matchPattern(op.getOperand(1), m_TorchConstantInt(&rhs)))
|
||||
static OpFoldResult
|
||||
atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
|
||||
BinaryIntOperatorFn f) {
|
||||
auto intLhs = operands[0].dyn_cast_or_null<IntegerAttr>();
|
||||
auto intRhs = operands[1].dyn_cast_or_null<IntegerAttr>();
|
||||
if (!intLhs || !intRhs) {
|
||||
return nullptr;
|
||||
}
|
||||
return IntegerAttr::get(
|
||||
intLhs.getType(),
|
||||
f(intLhs.getValue().getSExtValue(), intRhs.getValue().getSExtValue()));
|
||||
}
|
||||
|
||||
return getI64IntegerAttr(op.getContext(), f(lhs, rhs));
|
||||
using BinaryFloatOperatorFn = std::function<double(double, double)>;
|
||||
static OpFoldResult
|
||||
atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
|
||||
BinaryFloatOperatorFn f) {
|
||||
double lhs, rhs;
|
||||
auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
|
||||
if (auto intLhs = attr.dyn_cast_or_null<IntegerAttr>()) {
|
||||
value = static_cast<double>(intLhs.getValue().getSExtValue());
|
||||
} else if (auto floatLhs = attr.dyn_cast_or_null<FloatAttr>()) {
|
||||
value = floatLhs.getValue().convertToDouble();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (!parseDoubleAttribute(operands[0], lhs) ||
|
||||
!parseDoubleAttribute(operands[1], rhs)) {
|
||||
return nullptr;
|
||||
}
|
||||
return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1903,7 +1955,7 @@ static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op,
|
|||
|
||||
OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
*this, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
||||
operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1912,7 +1964,7 @@ OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
|
||||
OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
*this, [](int64_t a, int64_t b) { return a % b; });
|
||||
operands, [](int64_t a, int64_t b) { return a % b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1921,7 +1973,7 @@ OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
|
||||
OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
*this, [](int64_t a, int64_t b) { return a + b; });
|
||||
operands, [](int64_t a, int64_t b) { return a + b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1930,7 +1982,7 @@ OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
|
||||
OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
*this, [](int64_t a, int64_t b) { return a - b; });
|
||||
operands, [](int64_t a, int64_t b) { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1948,6 +2000,40 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSubOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0] || !operands[1]) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
|
||||
}
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
operands, [](double a, double b) -> double { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenCeilScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0]) {
|
||||
return nullptr;
|
||||
}
|
||||
auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
|
||||
if (!floatValue) {
|
||||
return nullptr;
|
||||
}
|
||||
return getI64IntegerAttr(
|
||||
getContext(),
|
||||
static_cast<int64_t>(std::ceil(floatValue.getValue().convertToDouble())));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenNegIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -969,7 +969,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
|||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %1 = torch.operator \"aten.ceil.Scalar\"(%arg0) : (!torch.union<float, int>) -> !torch.number\n"
|
||||
" %1 = torch.aten.ceil.Scalar %arg0 : !torch.union<float, int> -> !torch.number\n"
|
||||
" %2 = torch.aten.Int.Scalar %1 : !torch.number -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n"
|
||||
" return %3 : !torch.list<int>\n"
|
||||
|
@ -992,8 +992,8 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
|||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.operator \"aten.sub\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.number\n"
|
||||
" %3 = torch.operator \"aten.ceil.Scalar\"(%2) : (!torch.number) -> !torch.number\n"
|
||||
" %2 = torch.aten.sub %arg1, %arg0 : !torch.union<float, int>, !torch.union<float, int> -> !torch.number\n"
|
||||
" %3 = torch.aten.ceil.Scalar %2 : !torch.number -> !torch.number\n"
|
||||
" %4 = torch.aten.Int.Scalar %3 : !torch.number -> !torch.int\n"
|
||||
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
|
||||
" return %5 : !torch.list<int>\n"
|
||||
|
@ -1029,7 +1029,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
|||
" }\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.operator \"aten.sub\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.number\n"
|
||||
" %2 = torch.aten.sub %arg1, %arg0 : !torch.union<float, int>, !torch.union<float, int> -> !torch.number\n"
|
||||
" %3 = torch.aten.div %2, %arg2 : !torch.number, !torch.union<float, int> -> !torch.float\n"
|
||||
" %4 = torch.aten.ceil.float %3 : !torch.float -> !torch.int\n"
|
||||
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
|
||||
|
@ -7821,4 +7821,4 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
|||
#ifndef _MSC_VER
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
}
|
||||
}
|
|
@ -592,6 +592,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
||||
emit("aten::div : (Scalar, Scalar) -> (float)")
|
||||
emit("aten::add : (Scalar, Scalar) -> (Scalar)")
|
||||
emit("aten::sub : (Scalar, Scalar) -> (Scalar)", has_folder=True)
|
||||
emit("aten::ceil.Scalar : (Scalar) -> (Scalar)", has_folder=True)
|
||||
emit("aten::sqrt.int : (int) -> (float)", has_folder=True)
|
||||
emit("aten::Bool.float : (float) -> (bool)", has_folder=True)
|
||||
emit("aten::Bool.int : (int) -> (bool)", has_folder=True)
|
||||
|
|
Loading…
Reference in New Issue