[MLIR][TORCH] Add E2E support for aten.sqrt.int op

This commit adds lowering of `aten.sqrt.int` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/872/head
Vivek Khandelwal 2022-05-19 20:24:16 +05:30
parent f791b2ecae
commit bc9b2156e3
7 changed files with 119 additions and 5 deletions

View File

@ -7142,6 +7142,30 @@ def Torch_AtenAddOp : Torch_Op<"aten.add", [
}];
}
def Torch_AtenSqrtIntOp : Torch_Op<"aten.sqrt.int", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sqrt.int : (int) -> (float)`";
let arguments = (ins
Torch_IntType:$a
);
let results = (outs
Torch_FloatType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSqrtIntOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenSqrtIntOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -80,18 +80,22 @@ public:
namespace {
template <typename AtenOp, typename UnaryOp>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOp> {
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
public:
using OpConversionPattern<AtenOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenOp op,
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = adaptor.a();
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
Value result = rewriter.create<UnaryOp>(op.getLoc(), adaptor.a());
rewriter.replaceOp(
op, convertScalarToDtype(rewriter, op.getLoc(), result, resultType));
if (!input.getType().isa<mlir::FloatType>())
input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type());
Value result = rewriter.create<UnaryOp>(loc, input);
rewriter.replaceOp(op,
convertScalarToDtype(rewriter, loc, result, resultType));
return success();
}
};
@ -272,7 +276,11 @@ public:
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
typeConverter, context);
target.addIllegalOp<AtenCeilFloatOp>();
patterns.add<ConvertAtenUnaryOp<AtenCeilFloatOp, math::CeilOp>>(
patterns
.add<ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(
typeConverter, context);
target.addIllegalOp<AtenSqrtIntOp>();
patterns.add<ConvertAtenUnaryOpToFloatMathOp<AtenSqrtIntOp, math::SqrtOp>>(
typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,

View File

@ -1581,6 +1581,17 @@ OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenSqrtIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getF64FloatAttr(getContext(), std::sqrt(c));
return nullptr;
}
//===----------------------------------------------------------------------===//
// PrimDtypeOp
//===----------------------------------------------------------------------===//

View File

@ -521,6 +521,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
emit("aten::div : (Scalar, Scalar) -> (float)")
emit("aten::add : (Scalar, Scalar) -> (Scalar)")
emit("aten::sqrt.int : (int) -> (float)", has_folder=True)
emit("aten::eq.device : (Device, Device) -> (bool)")
emit("aten::ceil.float : (float) -> (int)", has_folder=True)

View File

@ -151,3 +151,43 @@ class CeilFloatModule(torch.nn.Module):
@register_test_case(module_factory=lambda: CeilFloatModule())
def CeilFloatModule_basic(module, tu: TestUtils):
module.forward(torch.rand(()).double(), torch.rand(()).double())
# ==============================================================================
class SqrtIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.int64, True),
])
def forward(self, a):
return float(torch.ops.aten.sqrt(int(a)))
@register_test_case(module_factory=lambda: SqrtIntModule())
def SqrtIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, ()))
class SqrtIntConstantModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return float(torch.ops.aten.sqrt(5))
@register_test_case(module_factory=lambda: SqrtIntConstantModule())
def SqrtIntConstantModule_basic(module, tu: TestUtils):
module.forward()

View File

@ -233,3 +233,15 @@ func.func @torch.aten.gt.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !t
%0 = torch.aten.gt.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func.func @torch.aten.sqrt.int(
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.float {
// CHECK: %[[ARG_I64:.*]] = torch_c.to_i64 %[[ARG]]
// CHECK: %[[ARG_F64:.*]] = arith.sitofp %[[ARG_I64]] : i64 to f64
// CHECK: %[[SQRT:.*]] = math.sqrt %[[ARG_F64]] : f64
// CHECK: %[[SQRT_TORCH_FLOAT:.*]] = torch_c.from_f64 %[[SQRT]]
// CHECK: return %[[SQRT_TORCH_FLOAT]] : !torch.float
func.func @torch.aten.sqrt.int(%arg0: !torch.int) -> !torch.float {
%0 = torch.aten.sqrt.int %arg0 : !torch.int -> !torch.float
return %0 : !torch.float
}

View File

@ -1231,3 +1231,21 @@ func.func @torch.aten.ceil.float$no_fold(%arg0 : !torch.float) -> !torch.int {
%1 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
return %1 : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.sqrt.int$fold_cst() -> !torch.float {
// CHECK: %[[CST:.*]] = torch.constant.float 2.2360679774997898
// CHECK: return %[[CST]] : !torch.float
func.func @torch.aten.sqrt.int$fold_cst() -> !torch.float {
%int = torch.constant.int 5
%0 = torch.aten.sqrt.int %int : !torch.int -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func.func @torch.aten.sqrt.int$no_fold(
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.float {
// CHECK: %[[RESULT:.*]] = torch.aten.sqrt.int %[[ARG]] : !torch.int -> !torch.float
// CHECK: return %[[RESULT]] : !torch.float
func.func @torch.aten.sqrt.int$no_fold(%arg0 : !torch.int) -> !torch.float {
%0 = torch.aten.sqrt.int %arg0 : !torch.int -> !torch.float
return %0 : !torch.float
}