[MLIR][TORCH] Add E2E support for aten.ceil.float op

This commit adds lowering of `aten.ceil.float` op.
This commit also fixes formatting for the file scalar.py.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/811/head
Vivek Khandelwal 2022-04-27 17:27:14 +05:30
parent 86eb493a44
commit 78f5747568
7 changed files with 136 additions and 9 deletions

View File

@ -6883,6 +6883,30 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
}]; }];
} }
def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::ceil.float : (float) -> (int)`";
let arguments = (ins
Torch_FloatType:$a
);
let results = (outs
Torch_IntType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCeilFloatOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenCeilFloatOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [ def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
@ -77,6 +78,25 @@ public:
}; };
} // namespace } // namespace
namespace {
template <typename AtenOp, typename UnaryOp>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOp> {
public:
using OpConversionPattern<AtenOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenOp op,
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
Value result = rewriter.create<UnaryOp>(op.getLoc(), adaptor.a());
rewriter.replaceOp(
op, convertScalarToDtype(rewriter, op.getLoc(), result, resultType));
return success();
}
};
} // namespace
namespace { namespace {
// Lowers aten integer comparison ops. // Lowers aten integer comparison ops.
template <typename AtenOp, arith::CmpIPredicate Pred> template <typename AtenOp, arith::CmpIPredicate Pred>
@ -182,6 +202,7 @@ public:
registry.insert<arith::ArithmeticDialect>(); registry.insert<arith::ArithmeticDialect>();
registry.insert<tensor::TensorDialect>(); registry.insert<tensor::TensorDialect>();
registry.insert<cf::ControlFlowDialect>(); registry.insert<cf::ControlFlowDialect>();
registry.insert<math::MathDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry); TorchConversion::getBackendTypeConversionDependentDialects(registry);
} }
@ -190,7 +211,7 @@ public:
ConversionTarget target(*context); ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect, func::FuncDialect, target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
arith::ArithmeticDialect, tensor::TensorDialect, arith::ArithmeticDialect, tensor::TensorDialect,
cf::ControlFlowDialect>(); cf::ControlFlowDialect, math::MathDialect>();
TypeConverter typeConverter; TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
@ -246,6 +267,9 @@ public:
target.addIllegalOp<AtenDivFloatOp>(); target.addIllegalOp<AtenDivFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>( patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
typeConverter, context); typeConverter, context);
target.addIllegalOp<AtenCeilFloatOp>();
patterns.add<ConvertAtenUnaryOp<AtenCeilFloatOp, math::CeilOp>>(
typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) std::move(patterns))))

View File

@ -1545,6 +1545,16 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
return nullptr; return nullptr;
} }
// AtenCeilFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
double c;
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
return getI64IntegerAttr(getContext(), std::ceil(c));
return nullptr;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -509,6 +509,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::_set_item.t : (t[], int, t) -> (t[])") emit("aten::_set_item.t : (t[], int, t) -> (t[])")
emit("aten::div : (Scalar, Scalar) -> (float)") emit("aten::div : (Scalar, Scalar) -> (float)")
emit("aten::eq.device : (Device, Device) -> (bool)") emit("aten::eq.device : (Device, Device) -> (bool)")
emit("aten::ceil.float : (float) -> (int)", has_folder=True)
# backprop ops # backprop ops
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")

View File

@ -11,7 +11,9 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ============================================================================== # ==============================================================================
class AddIntModule(torch.nn.Module): class AddIntModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -22,16 +24,19 @@ class AddIntModule(torch.nn.Module):
([], torch.int64, True), ([], torch.int64, True),
]) ])
def forward(self, lhs, rhs): def forward(self, lhs, rhs):
return int(lhs)+int(rhs) return int(lhs) + int(rhs)
@register_test_case(module_factory=lambda: AddIntModule()) @register_test_case(module_factory=lambda: AddIntModule())
def AddIntModule_basic(module, tu: TestUtils): def AddIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,())) module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
# ============================================================================== # ==============================================================================
class SubIntModule(torch.nn.Module): class SubIntModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -42,16 +47,19 @@ class SubIntModule(torch.nn.Module):
([], torch.int64, True), ([], torch.int64, True),
]) ])
def forward(self, lhs, rhs): def forward(self, lhs, rhs):
return int(lhs)-int(rhs) return int(lhs) - int(rhs)
@register_test_case(module_factory=lambda: SubIntModule()) @register_test_case(module_factory=lambda: SubIntModule())
def SubIntModule_basic(module, tu: TestUtils): def SubIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,())) module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
# ============================================================================== # ==============================================================================
class SubFloatModule(torch.nn.Module): class SubFloatModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -62,16 +70,19 @@ class SubFloatModule(torch.nn.Module):
([], torch.float64, True), ([], torch.float64, True),
]) ])
def forward(self, lhs, rhs): def forward(self, lhs, rhs):
return float(lhs)-float(rhs) return float(lhs) - float(rhs)
@register_test_case(module_factory=lambda: SubFloatModule()) @register_test_case(module_factory=lambda: SubFloatModule())
def SubFloatModule_basic(module, tu: TestUtils): def SubFloatModule_basic(module, tu: TestUtils):
module.forward(torch.rand(()).double(), torch.rand(()).double()) module.forward(torch.rand(()).double(), torch.rand(()).double())
# ============================================================================== # ==============================================================================
class MulIntModule(torch.nn.Module): class MulIntModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -82,16 +93,19 @@ class MulIntModule(torch.nn.Module):
([], torch.int64, True), ([], torch.int64, True),
]) ])
def forward(self, lhs, rhs): def forward(self, lhs, rhs):
return int(lhs)*int(rhs) return int(lhs) * int(rhs)
@register_test_case(module_factory=lambda: MulIntModule()) @register_test_case(module_factory=lambda: MulIntModule())
def MulIntModule_basic(module, tu: TestUtils): def MulIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,())) module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
# ============================================================================== # ==============================================================================
class DivFloatModule(torch.nn.Module): class DivFloatModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -102,9 +116,33 @@ class DivFloatModule(torch.nn.Module):
([], torch.float64, True), ([], torch.float64, True),
]) ])
def forward(self, lhs, rhs): def forward(self, lhs, rhs):
return float(lhs)/float(rhs) return float(lhs) / float(rhs)
@register_test_case(module_factory=lambda: DivFloatModule()) @register_test_case(module_factory=lambda: DivFloatModule())
def DivFloatModule_basic(module, tu: TestUtils): def DivFloatModule_basic(module, tu: TestUtils):
module.forward(torch.rand(()).double(), torch.rand(()).double()) module.forward(torch.rand(()).double(), torch.rand(()).double())
# ==============================================================================
class CeilFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.float64, True),
([], torch.float64, True),
])
def forward(self, lhs, rhs):
sub = float(lhs) - float(rhs)
return torch.ops.aten.ceil(float(sub))
@register_test_case(module_factory=lambda: CeilFloatModule())
def CeilFloatModule_basic(module, tu: TestUtils):
module.forward(torch.rand(()).double(), torch.rand(()).double())

View File

@ -207,3 +207,15 @@ func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.
%0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool %0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// CHECK-LABEL: func @torch.aten.ceil.float(
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG_F64]] : f64
// CHECK: %[[CEIL_I64:.*]] = arith.fptosi %[[CEIL]] : f64 to i64
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[CEIL_I64]]
// CHECK: return %[[OUT]] : !torch.int
func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int {
%0 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
return %0 : !torch.int
}

View File

@ -1191,3 +1191,21 @@ func @torch.aten.ge.float$different_value() -> !torch.bool {
%2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool %2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool
return %2 : !torch.bool return %2 : !torch.bool
} }
// CHECK-LABEL: func @torch.aten.ceil.float$fold_cst() -> !torch.int {
// CHECK: %[[CST2:.*]] = torch.constant.int 2
// CHECK: return %[[CST2]] : !torch.int
func @torch.aten.ceil.float$fold_cst() -> !torch.int {
%float = torch.constant.float 1.5
%1 = torch.aten.ceil.float %float : !torch.float -> !torch.int
return %1 : !torch.int
}
// CHECK-LABEL: func @torch.aten.ceil.float$no_fold(
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
// CHECK: %[[RESULT:.*]] = torch.aten.ceil.float %[[ARG]] : !torch.float -> !torch.int
// CHECK: return %[[RESULT]] : !torch.int
func @torch.aten.ceil.float$no_fold(%arg0 : !torch.float) -> !torch.int {
%1 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
return %1 : !torch.int
}