mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.sqrt.int op
This commit adds lowering of `aten.sqrt.int` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/872/head
parent
f791b2ecae
commit
bc9b2156e3
|
@ -7142,6 +7142,30 @@ def Torch_AtenAddOp : Torch_Op<"aten.add", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSqrtIntOp : Torch_Op<"aten.sqrt.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::sqrt.int : (int) -> (float)`";
|
||||
let arguments = (ins
|
||||
Torch_IntType:$a
|
||||
);
|
||||
let results = (outs
|
||||
Torch_FloatType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSqrtIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenSqrtIntOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -80,18 +80,22 @@ public:
|
|||
|
||||
namespace {
|
||||
template <typename AtenOp, typename UnaryOp>
|
||||
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOp> {
|
||||
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOp op,
|
||||
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.a();
|
||||
Type resultType =
|
||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||
Value result = rewriter.create<UnaryOp>(op.getLoc(), adaptor.a());
|
||||
rewriter.replaceOp(
|
||||
op, convertScalarToDtype(rewriter, op.getLoc(), result, resultType));
|
||||
if (!input.getType().isa<mlir::FloatType>())
|
||||
input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type());
|
||||
Value result = rewriter.create<UnaryOp>(loc, input);
|
||||
rewriter.replaceOp(op,
|
||||
convertScalarToDtype(rewriter, loc, result, resultType));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -272,7 +276,11 @@ public:
|
|||
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenCeilFloatOp>();
|
||||
patterns.add<ConvertAtenUnaryOp<AtenCeilFloatOp, math::CeilOp>>(
|
||||
patterns
|
||||
.add<ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenSqrtIntOp>();
|
||||
patterns.add<ConvertAtenUnaryOpToFloatMathOp<AtenSqrtIntOp, math::SqrtOp>>(
|
||||
typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
|
|
|
@ -1581,6 +1581,17 @@ OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSqrtIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
int64_t c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
||||
return getF64FloatAttr(getContext(), std::sqrt(c));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimDtypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -521,6 +521,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
||||
emit("aten::div : (Scalar, Scalar) -> (float)")
|
||||
emit("aten::add : (Scalar, Scalar) -> (Scalar)")
|
||||
emit("aten::sqrt.int : (int) -> (float)", has_folder=True)
|
||||
|
||||
emit("aten::eq.device : (Device, Device) -> (bool)")
|
||||
emit("aten::ceil.float : (float) -> (int)", has_folder=True)
|
||||
|
|
|
@ -151,3 +151,43 @@ class CeilFloatModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: CeilFloatModule())
|
||||
def CeilFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SqrtIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return float(torch.ops.aten.sqrt(int(a)))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SqrtIntModule())
|
||||
def SqrtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, ()))
|
||||
|
||||
|
||||
class SqrtIntConstantModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return float(torch.ops.aten.sqrt(5))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SqrtIntConstantModule())
|
||||
def SqrtIntConstantModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
|
|
@ -233,3 +233,15 @@ func.func @torch.aten.gt.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !t
|
|||
%0 = torch.aten.gt.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sqrt.int(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.float {
|
||||
// CHECK: %[[ARG_I64:.*]] = torch_c.to_i64 %[[ARG]]
|
||||
// CHECK: %[[ARG_F64:.*]] = arith.sitofp %[[ARG_I64]] : i64 to f64
|
||||
// CHECK: %[[SQRT:.*]] = math.sqrt %[[ARG_F64]] : f64
|
||||
// CHECK: %[[SQRT_TORCH_FLOAT:.*]] = torch_c.from_f64 %[[SQRT]]
|
||||
// CHECK: return %[[SQRT_TORCH_FLOAT]] : !torch.float
|
||||
func.func @torch.aten.sqrt.int(%arg0: !torch.int) -> !torch.float {
|
||||
%0 = torch.aten.sqrt.int %arg0 : !torch.int -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
|
|
@ -1231,3 +1231,21 @@ func.func @torch.aten.ceil.float$no_fold(%arg0 : !torch.float) -> !torch.int {
|
|||
%1 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
|
||||
return %1 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sqrt.int$fold_cst() -> !torch.float {
|
||||
// CHECK: %[[CST:.*]] = torch.constant.float 2.2360679774997898
|
||||
// CHECK: return %[[CST]] : !torch.float
|
||||
func.func @torch.aten.sqrt.int$fold_cst() -> !torch.float {
|
||||
%int = torch.constant.int 5
|
||||
%0 = torch.aten.sqrt.int %int : !torch.int -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sqrt.int$no_fold(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.float {
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.sqrt.int %[[ARG]] : !torch.int -> !torch.float
|
||||
// CHECK: return %[[RESULT]] : !torch.float
|
||||
func.func @torch.aten.sqrt.int$no_fold(%arg0 : !torch.int) -> !torch.float {
|
||||
%0 = torch.aten.sqrt.int %arg0 : !torch.int -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue