mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] E2E support for [ge|ceil].float, [ge|ne|gt].float_int op
This commit adds lowering of `aten.ge.float`, `aten.ge.float_int`, `aten.ne.float_int`, `aten.gt.float_int` and `aten.ceil.float` op. This commit also fixes formatting for the file scalar.py and scalar_comparison.py. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/821/head snapshot-20220505.433
parent
e682b1d0f3
commit
96fabc0036
|
@ -6678,6 +6678,31 @@ def Torch_AtenGtFloatOp : Torch_Op<"aten.gt.float", [
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenGeFloatOp : Torch_Op<"aten.ge.float", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::ge.float : (float, float) -> (bool)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_FloatType:$a,
|
||||||
|
Torch_FloatType:$b
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_BoolType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenGeFloatOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenGeFloatOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenLtFloatOp : Torch_Op<"aten.lt.float", [
|
def Torch_AtenLtFloatOp : Torch_Op<"aten.lt.float", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -6727,6 +6752,78 @@ def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenGeFloatIntOp : Torch_Op<"aten.ge.float_int", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::ge.float_int : (float, int) -> (bool)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_FloatType:$a,
|
||||||
|
Torch_IntType:$b
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_BoolType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenGeFloatIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenGeFloatIntOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenNeFloatIntOp : Torch_Op<"aten.ne.float_int", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::ne.float_int : (float, int) -> (bool)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_FloatType:$a,
|
||||||
|
Torch_IntType:$b
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_BoolType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenNeFloatIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenNeFloatIntOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenGtFloatIntOp : Torch_Op<"aten.gt.float_int", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::gt.float_int : (float, int) -> (bool)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_FloatType:$a,
|
||||||
|
Torch_IntType:$b
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_BoolType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenGtFloatIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenGtFloatIntOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
|
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -6970,6 +7067,30 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::ceil.float : (float) -> (int)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_FloatType:$a
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_IntType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenCeilFloatOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenCeilFloatOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
|
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -13,9 +13,11 @@
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Traits.h"
|
#include "mlir/Dialect/Traits.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
@ -76,6 +78,25 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename AtenOp, typename UnaryOp>
|
||||||
|
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AtenOp>::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOp op,
|
||||||
|
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type resultType =
|
||||||
|
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||||
|
Value result = rewriter.create<UnaryOp>(op.getLoc(), adaptor.a());
|
||||||
|
rewriter.replaceOp(
|
||||||
|
op, convertScalarToDtype(rewriter, op.getLoc(), result, resultType));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Lowers aten integer comparison ops.
|
// Lowers aten integer comparison ops.
|
||||||
template <typename AtenOp, arith::CmpIPredicate Pred>
|
template <typename AtenOp, arith::CmpIPredicate Pred>
|
||||||
|
@ -93,6 +114,24 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Lowers aten float and float_int comparison ops.
|
||||||
|
template <typename AtenOp, arith::CmpFPredicate Pred>
|
||||||
|
class ConvertAtenFloatComparisonOp : public OpConversionPattern<AtenOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AtenOp>::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOp op,
|
||||||
|
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Value lhs = adaptor.a(), rhs = adaptor.b();
|
||||||
|
rhs = convertScalarToDtype(rewriter, op.getLoc(), rhs, lhs.getType());
|
||||||
|
rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, Pred, lhs, rhs);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Tensors with integer types need to be converted to signless integer
|
// Tensors with integer types need to be converted to signless integer
|
||||||
// element type. All tensors with element types other than integer can reuse
|
// element type. All tensors with element types other than integer can reuse
|
||||||
// existing elements attribute.
|
// existing elements attribute.
|
||||||
|
@ -163,6 +202,7 @@ public:
|
||||||
registry.insert<arith::ArithmeticDialect>();
|
registry.insert<arith::ArithmeticDialect>();
|
||||||
registry.insert<tensor::TensorDialect>();
|
registry.insert<tensor::TensorDialect>();
|
||||||
registry.insert<cf::ControlFlowDialect>();
|
registry.insert<cf::ControlFlowDialect>();
|
||||||
|
registry.insert<math::MathDialect>();
|
||||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,7 +211,7 @@ public:
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
|
target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
|
||||||
arith::ArithmeticDialect, tensor::TensorDialect,
|
arith::ArithmeticDialect, tensor::TensorDialect,
|
||||||
cf::ControlFlowDialect>();
|
cf::ControlFlowDialect, math::MathDialect>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
@ -192,6 +232,20 @@ public:
|
||||||
patterns.add<
|
patterns.add<
|
||||||
ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
|
ConvertAtenIntComparisonOp<AtenGtIntOp, arith::CmpIPredicate::sgt>>(
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenGeFloatOp, AtenGeFloatIntOp, AtenNeFloatIntOp,
|
||||||
|
AtenGtFloatIntOp>();
|
||||||
|
patterns.add<
|
||||||
|
ConvertAtenFloatComparisonOp<AtenGeFloatOp, arith::CmpFPredicate::UGE>>(
|
||||||
|
typeConverter, context);
|
||||||
|
patterns.add<ConvertAtenFloatComparisonOp<AtenGeFloatIntOp,
|
||||||
|
arith::CmpFPredicate::UGE>>(
|
||||||
|
typeConverter, context);
|
||||||
|
patterns.add<ConvertAtenFloatComparisonOp<AtenNeFloatIntOp,
|
||||||
|
arith::CmpFPredicate::UNE>>(
|
||||||
|
typeConverter, context);
|
||||||
|
patterns.add<ConvertAtenFloatComparisonOp<AtenGtFloatIntOp,
|
||||||
|
arith::CmpFPredicate::UGT>>(
|
||||||
|
typeConverter, context);
|
||||||
target.addIllegalOp<ValueTensorLiteralOp>();
|
target.addIllegalOp<ValueTensorLiteralOp>();
|
||||||
patterns.add<ConvertTorchTensorLiteralOp>(typeConverter, context);
|
patterns.add<ConvertTorchTensorLiteralOp>(typeConverter, context);
|
||||||
|
|
||||||
|
@ -217,6 +271,9 @@ public:
|
||||||
target.addIllegalOp<AtenDivFloatOp>();
|
target.addIllegalOp<AtenDivFloatOp>();
|
||||||
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
|
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
|
||||||
typeConverter, context);
|
typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenCeilFloatOp>();
|
||||||
|
patterns.add<ConvertAtenUnaryOp<AtenCeilFloatOp, math::CeilOp>>(
|
||||||
|
typeConverter, context);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
|
|
|
@ -857,6 +857,15 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||||
[](double a, double b) { return a > b; });
|
[](double a, double b) { return a > b; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenGeFloatOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
return floatComparatorFoldHelper(*this,
|
||||||
|
[](double a, double b) { return a >= b; });
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenEqFloatOp
|
// AtenEqFloatOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1604,6 +1613,16 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AtenCeilFloatOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
double c;
|
||||||
|
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
|
||||||
|
return getI64IntegerAttr(getContext(), std::ceil(c));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -501,8 +501,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::neg.float : (float) -> (float)")
|
emit("aten::neg.float : (float) -> (float)")
|
||||||
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
|
emit("aten::eq.float : (float, float) -> (bool)", has_folder=True)
|
||||||
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
|
emit("aten::gt.float : (float, float) -> (bool)", has_folder=True)
|
||||||
|
emit("aten::ge.float : (float, float) -> (bool)", has_folder=True)
|
||||||
emit("aten::lt.float : (float, float) -> (bool)", has_folder=True)
|
emit("aten::lt.float : (float, float) -> (bool)", has_folder=True)
|
||||||
emit("aten::lt.float_int : (float, int) -> (bool)")
|
emit("aten::lt.float_int : (float, int) -> (bool)")
|
||||||
|
emit("aten::ge.float_int : (float, int) -> (bool)")
|
||||||
|
emit("aten::ne.float_int : (float, int) -> (bool)")
|
||||||
|
emit("aten::gt.float_int : (float, int) -> (bool)")
|
||||||
emit("aten::__and__.bool : (bool, bool) -> (bool)")
|
emit("aten::__and__.bool : (bool, bool) -> (bool)")
|
||||||
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
|
emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True)
|
||||||
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True)
|
||||||
|
@ -515,6 +519,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
||||||
emit("aten::div : (Scalar, Scalar) -> (float)")
|
emit("aten::div : (Scalar, Scalar) -> (float)")
|
||||||
emit("aten::eq.device : (Device, Device) -> (bool)")
|
emit("aten::eq.device : (Device, Device) -> (bool)")
|
||||||
|
emit("aten::ceil.float : (float) -> (int)", has_folder=True)
|
||||||
|
|
||||||
# backprop ops
|
# backprop ops
|
||||||
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||||
|
|
|
@ -11,7 +11,9 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AddIntModule(torch.nn.Module):
|
class AddIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -22,16 +24,19 @@ class AddIntModule(torch.nn.Module):
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
])
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return int(lhs)+int(rhs)
|
return int(lhs) + int(rhs)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AddIntModule())
|
@register_test_case(module_factory=lambda: AddIntModule())
|
||||||
def AddIntModule_basic(module, tu: TestUtils):
|
def AddIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
|
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class SubIntModule(torch.nn.Module):
|
class SubIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -42,16 +47,19 @@ class SubIntModule(torch.nn.Module):
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
])
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return int(lhs)-int(rhs)
|
return int(lhs) - int(rhs)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: SubIntModule())
|
@register_test_case(module_factory=lambda: SubIntModule())
|
||||||
def SubIntModule_basic(module, tu: TestUtils):
|
def SubIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
|
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class SubFloatModule(torch.nn.Module):
|
class SubFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -62,16 +70,19 @@ class SubFloatModule(torch.nn.Module):
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
])
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return float(lhs)-float(rhs)
|
return float(lhs) - float(rhs)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: SubFloatModule())
|
@register_test_case(module_factory=lambda: SubFloatModule())
|
||||||
def SubFloatModule_basic(module, tu: TestUtils):
|
def SubFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class MulIntModule(torch.nn.Module):
|
class MulIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -82,16 +93,19 @@ class MulIntModule(torch.nn.Module):
|
||||||
([], torch.int64, True),
|
([], torch.int64, True),
|
||||||
])
|
])
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return int(lhs)*int(rhs)
|
return int(lhs) * int(rhs)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: MulIntModule())
|
@register_test_case(module_factory=lambda: MulIntModule())
|
||||||
def MulIntModule_basic(module, tu: TestUtils):
|
def MulIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
|
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DivFloatModule(torch.nn.Module):
|
class DivFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -102,9 +116,33 @@ class DivFloatModule(torch.nn.Module):
|
||||||
([], torch.float64, True),
|
([], torch.float64, True),
|
||||||
])
|
])
|
||||||
def forward(self, lhs, rhs):
|
def forward(self, lhs, rhs):
|
||||||
return float(lhs)/float(rhs)
|
return float(lhs) / float(rhs)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: DivFloatModule())
|
@register_test_case(module_factory=lambda: DivFloatModule())
|
||||||
def DivFloatModule_basic(module, tu: TestUtils):
|
def DivFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class CeilFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.float64, True),
|
||||||
|
([], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
sub = float(lhs) - float(rhs)
|
||||||
|
return torch.ops.aten.ceil(float(sub))
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: CeilFloatModule())
|
||||||
|
def CeilFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
||||||
|
|
|
@ -11,7 +11,9 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class NeIntModule(torch.nn.Module):
|
class NeIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -29,9 +31,12 @@ class NeIntModule(torch.nn.Module):
|
||||||
def NeIntModule_basic(module, tu: TestUtils):
|
def NeIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class EqIntModule(torch.nn.Module):
|
class EqIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -49,9 +54,12 @@ class EqIntModule(torch.nn.Module):
|
||||||
def EqIntModule_basic(module, tu: TestUtils):
|
def EqIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class GtIntModule(torch.nn.Module):
|
class GtIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -69,3 +77,94 @@ class GtIntModule(torch.nn.Module):
|
||||||
def GtIntModule_basic(module, tu: TestUtils):
|
def GtIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class GeFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.float64, True),
|
||||||
|
([], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return float(lhs) >= float(rhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: GeFloatModule())
|
||||||
|
def GeFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(()).double(), torch.randn(()).double())
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class GeFloatIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.float64, True),
|
||||||
|
([], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return float(lhs) >= int(rhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: GeFloatIntModule())
|
||||||
|
def GeFloatIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(()).double(), torch.randint(-100, 100, ()))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class NeFloatIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.float64, True),
|
||||||
|
([], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return float(lhs) != int(rhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NeFloatIntModule())
|
||||||
|
def NeFloatIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(()).double(), torch.randint(-100, 100, ()))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class GtFloatIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.float64, True),
|
||||||
|
([], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return float(lhs) > int(rhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: GtFloatIntModule())
|
||||||
|
def GtFloatIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(()).double(), torch.randint(-100, 100, ()))
|
||||||
|
|
|
@ -166,3 +166,70 @@ func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.f
|
||||||
%0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
|
%0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
|
||||||
return %0 : !torch.float
|
return %0 : !torch.float
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ge.float(
|
||||||
|
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
|
||||||
|
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.bool {
|
||||||
|
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
|
||||||
|
// CHECK: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
|
||||||
|
// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64
|
||||||
|
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
||||||
|
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
||||||
|
func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bool {
|
||||||
|
%0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ge.float_int(
|
||||||
|
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
|
||||||
|
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||||
|
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
|
||||||
|
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
||||||
|
// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64
|
||||||
|
// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64
|
||||||
|
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
||||||
|
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
||||||
|
func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
|
||||||
|
%0 = torch.aten.ge.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ne.float_int(
|
||||||
|
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
|
||||||
|
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||||
|
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
|
||||||
|
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
||||||
|
// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64
|
||||||
|
// CHECK: %[[CMP:.*]] = arith.cmpf une, %[[LHS_F64]], %[[RHS_F64]] : f64
|
||||||
|
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
||||||
|
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
||||||
|
func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
|
||||||
|
%0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ceil.float(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
|
||||||
|
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
|
||||||
|
// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG_F64]] : f64
|
||||||
|
// CHECK: %[[CEIL_I64:.*]] = arith.fptosi %[[CEIL]] : f64 to i64
|
||||||
|
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[CEIL_I64]]
|
||||||
|
// CHECK: return %[[OUT]] : !torch.int
|
||||||
|
func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int {
|
||||||
|
%0 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
|
||||||
|
return %0 : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.gt.float_int(
|
||||||
|
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
|
||||||
|
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||||
|
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
|
||||||
|
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
||||||
|
// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64
|
||||||
|
// CHECK: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS_F64]], %[[RHS_F64]] : f64
|
||||||
|
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
|
||||||
|
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
|
||||||
|
func @torch.aten.gt.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
|
||||||
|
%0 = torch.aten.gt.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
|
@ -1186,3 +1186,50 @@ func @torch.aten.to.dtype_layout$same_dtype(%arg0: !torch.tensor<[?,?],f32>) ->
|
||||||
%0 = torch.aten.to.dtype_layout %arg0, %int6, %none, %none, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
|
%0 = torch.aten.to.dtype_layout %arg0, %int6, %none, %none, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
|
||||||
return %0 : !torch.tensor<[?,?],f32>
|
return %0 : !torch.tensor<[?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ge.float$same_operand(
|
||||||
|
// CHECK-SAME: %{{.*}}: !torch.float) -> !torch.bool {
|
||||||
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: return %[[TRUE]] : !torch.bool
|
||||||
|
func @torch.aten.ge.float$same_operand(%arg0: !torch.float) -> !torch.bool {
|
||||||
|
%2 = torch.aten.ge.float %arg0, %arg0: !torch.float, !torch.float -> !torch.bool
|
||||||
|
return %2 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ge.float$same_value() -> !torch.bool {
|
||||||
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: return %[[TRUE]] : !torch.bool
|
||||||
|
func @torch.aten.ge.float$same_value() -> !torch.bool {
|
||||||
|
%float4 = torch.constant.float 4.0
|
||||||
|
%float4_0 = torch.constant.float 4.0
|
||||||
|
%2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool
|
||||||
|
return %2 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ge.float$different_value() -> !torch.bool {
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: return %[[FALSE]] : !torch.bool
|
||||||
|
func @torch.aten.ge.float$different_value() -> !torch.bool {
|
||||||
|
%float4 = torch.constant.float 4.0
|
||||||
|
%float4_0 = torch.constant.float 5.0
|
||||||
|
%2 = torch.aten.ge.float %float4, %float4_0: !torch.float, !torch.float -> !torch.bool
|
||||||
|
return %2 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ceil.float$fold_cst() -> !torch.int {
|
||||||
|
// CHECK: %[[CST2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: return %[[CST2]] : !torch.int
|
||||||
|
func @torch.aten.ceil.float$fold_cst() -> !torch.int {
|
||||||
|
%float = torch.constant.float 1.5
|
||||||
|
%1 = torch.aten.ceil.float %float : !torch.float -> !torch.int
|
||||||
|
return %1 : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ceil.float$no_fold(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.ceil.float %[[ARG]] : !torch.float -> !torch.int
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.int
|
||||||
|
func @torch.aten.ceil.float$no_fold(%arg0 : !torch.float) -> !torch.int {
|
||||||
|
%1 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
|
||||||
|
return %1 : !torch.int
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue