Added 2 Ops: Floor divide scalar and Floor divide scalar mode (#3156)

- Added linalg lowering for `AtenFloorDivideScalarOp`
  - Needed `AtenDivScalarModeOp` for the decomp.
- Added linalg lowering for `AtenDivScalarModeOp`
- Moved linalg payload logic to `createDivModePayload()` since the logic
was nearly identical for both `AtenDivScalarModeOp` and
`AtenDivTensorModeOp`. Just a template function
 -  Added `AtenDivScalarModeOp` lowering for stablehlo
 

Pytorch's
[`torch.floor_divide()`](https://pytorch.org/docs/stable/generated/torch.floor_divide.html)
in a previous version (for a reason unknown to me) preformed a
truncation instead of "floor". The already implemented op
`AtenFloorDivideTensorOp` was done before this change. However, this
wasn't caught because our testcases only tested positive floor division.
I changed this to floor as well as adding a few test cases.
pull/3167/head
IanWood1 2024-04-15 13:45:10 -07:00 committed by GitHub
parent 83cba8c696
commit 5708ee7ec9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 565 additions and 159 deletions

View File

@ -3397,6 +3397,56 @@ def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [
}];
}
def Torch_AtenDivScalarModeOp : Torch_Op<"aten.div.Scalar_mode", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other,
AnyTorchOptionalStringType:$rounding_mode
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDivScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenDivScalarModeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenDiv_ScalarModeOp : Torch_Op<"aten.div_.Scalar_mode", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::div_.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchScalarType:$other,
AnyTorchOptionalStringType:$rounding_mode
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDiv_ScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenDiv_ScalarModeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -26,6 +26,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>
#include <type_traits>
using namespace mlir;
using namespace mlir::torch;
@ -213,6 +214,78 @@ createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
return success();
}
template <typename OpT>
Value createDivModePayload(OpBuilder &b, Location loc,
const TypeConverter *converter,
ValueRange payloadArgs, OpT op,
ArrayRef<Value> operands) {
static_assert(std::is_same_v<OpT, AtenDivTensorModeOp> ||
std::is_same_v<OpT, AtenDivScalarModeOp>,
"template type must be a tensor/scalar div mode");
typename OpT::Adaptor adaptor(operands);
Type dtype = cast<RankedTensorType>(converter->convertType(op.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(
b, loc,
std::is_same_v<OpT, AtenDivScalarModeOp> ? operands[1] : payloadArgs[1],
dtype);
Value quotient;
if (isa<mlir::FloatType>(dtype)) {
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
} else if (dtype.isUnsignedInteger()) {
quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
} else {
assert(dtype.isInteger() &&
"dtype should be an integer (signless or signed)");
quotient = b.create<arith::DivSIOp>(loc, lhs, rhs);
}
if (isa<Torch::NoneType>(op.getRoundingMode().getType()))
return quotient;
std::string roundingMode;
if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundingMode))) {
op.emitError("only support constant str rounding mode");
return nullptr;
}
assert((roundingMode == "trunc" || roundingMode == "floor") &&
"unsupported rounding mode");
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
if (!isa<mlir::FloatType>(dtype)) {
// nothing to do for integers
return quotient;
}
// float
Value ceil = b.create<math::CeilOp>(loc, quotient);
Value floor = b.create<math::FloorOp>(loc, quotient);
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
quotient, cstZero);
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (isa<mlir::FloatType>(dtype))
return b.create<math::FloorOp>(loc, quotient);
if (!dtype.isUnsignedInteger()) {
Type defaultIntToFloatType = b.getF64Type();
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
Value floor = b.create<math::FloorOp>(loc, quotient);
Value convert = convertScalarToDtype(b, loc, floor, dtype);
return convert;
}
}
return quotient;
}
static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, const TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
@ -769,66 +842,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
div.emitError("unimplemented: non-floating point and non-integer dtype");
return nullptr;
}
if (auto divScalarMode = dyn_cast<AtenDivScalarModeOp>(op)) {
return createDivModePayload(b, loc, converter, payloadArgs, divScalarMode,
operands);
}
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
AtenDivTensorModeOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(divTensorMode.getType())
.cast<RankedTensorType>()
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value div;
if (isa<mlir::FloatType>(dtype))
div = b.create<arith::DivFOp>(loc, lhs, rhs);
else {
if (dtype.isUnsignedInteger())
div = b.create<arith::DivUIOp>(loc, lhs, rhs);
else
div = b.create<arith::DivSIOp>(loc, lhs, rhs);
return createDivModePayload(b, loc, converter, payloadArgs, divTensorMode,
operands);
}
if (divTensorMode.getRoundingMode().getType().isa<Torch::NoneType>())
return div;
std::string roundingMode;
if (!matchPattern(divTensorMode.getRoundingMode(),
m_TorchConstantStr(roundingMode))) {
divTensorMode.emitError("only support constant str rounding mode");
return nullptr;
}
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
if (isa<mlir::FloatType>(dtype)) {
Value ceil = b.create<math::CeilOp>(loc, div);
Value floor = b.create<math::FloorOp>(loc, div);
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
div, cstZero);
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
} else
return div;
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (isa<mlir::FloatType>(dtype))
return b.create<math::FloorOp>(loc, div);
else if (!dtype.isUnsignedInteger()) {
Type defaultIntToFloatType = b.getF64Type();
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
div = b.create<arith::DivFOp>(loc, lhs, rhs);
Value floor = b.create<math::FloorOp>(loc, div);
Value convert = convertScalarToDtype(b, loc, floor, dtype);
return convert;
} else {
return div;
}
}
divTensorMode.emitError("invalid rounding mode");
return nullptr;
}
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
if (!isa<mlir::FloatType>(dtype)) {
@ -1579,12 +1600,13 @@ public:
if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp,
AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenSubTensorOp, AtenAtan2Op, AtenLerpTensorOp, AtenSigmoidOp,
AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp,
AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenRsubScalarOp,
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp,
AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenDivScalarModeOp, AtenSubTensorOp, AtenAtan2Op,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenClampTensorOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
@ -2617,25 +2639,25 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp,
AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp,
AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
AtenRemainderScalarOp, AtenFmodTensorOp, AtenRemainderTensorOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp>();
AtenDivScalarModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -27,8 +27,8 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <iostream>
#include <numeric>
#include <type_traits>
using namespace mlir;
using namespace mlir::torch;
@ -409,9 +409,9 @@ public:
if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outType = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
@ -432,18 +432,23 @@ public:
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
if (!isa<AtenDivTensorModeOp>(op)) {
if (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
rewriter.replaceOp(op, result);
return success();
}
AtenDivTensorModeOp divTensorModeOp =
llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation());
auto tensorOp = dyn_cast<AtenDivTensorModeOp>(op.getOperation());
auto opRoundingMode =
tensorOp
? tensorOp.getRoundingMode()
: cast<AtenDivScalarModeOp>(op.getOperation()).getRoundingMode();
std::string roundingMode;
if (!matchPattern(divTensorModeOp.getRoundingMode(),
m_TorchConstantStr(roundingMode)))
if (!matchPattern(opRoundingMode, m_TorchConstantStr(roundingMode))) {
return rewriter.notifyMatchFailure(
op, "only support constant str rounding mode");
}
// if trunc and int, do nothing
if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
@ -1845,6 +1850,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarModeOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp);
#undef INSERT_BINARY_MULDIV_PATTERN

View File

@ -1095,9 +1095,9 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
}
if (isa<AtenDivTensorModeOp>(op)) {
// None rounding mode
if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
if (op->getOperand(2).getType().isa<Torch::NoneType>()) {
// None rounding mode
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
quotient);
@ -1858,6 +1858,16 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
});
}
//===----------------------------------------------------------------------===//
// AtenDivScalarModeOp
//===----------------------------------------------------------------------===//
void AtenDivScalarModeOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenDivScalarModeOp op, PatternRewriter &rewriter) {
return rewrite0DBinaryTensorOp(op, rewriter);
});
}
//===----------------------------------------------------------------------===//
// AtenNumelOp
//===----------------------------------------------------------------------===//

View File

@ -8496,6 +8496,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.div.Scalar_mode\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<str>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.floor_divide\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -11166,6 +11170,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar_mode\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<str>) -> !torch.int {\n"
" %str = torch.constant.str \"trunc\"\n"
" %int6 = torch.constant.int 6\n"
" %true = torch.constant.bool true\n"
" %false = torch.constant.bool false\n"
" %str_0 = torch.constant.str \"floor\"\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
" %4 = torch.aten.eq.str %3, %str_0 : !torch.str, !torch.str -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" %3 = func.call @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0, %arg1) : (!torch.tuple<int, int>, !torch.number) -> !torch.int\n"
" torch.prim.If.yield %3 : !torch.int\n"
" } else {\n"
" %3:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %4 = torch.prim.ListConstruct %none, %3#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
" %5 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %6 = torch.prim.ListConstruct %5, %3#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%4, %6) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%7) : (!torch.int) -> !torch.bool\n"
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %12 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %14 = torch.aten.ne.int %7, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %14 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" }\n"
" %10 = torch.prim.If %9 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %12 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %14 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
" %15 = torch.aten.eq.str %14, %str : !torch.str, !torch.str -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" }\n"
" %11 = torch.prim.If %10 -> (!torch.int) {\n"
" torch.prim.If.yield %7 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" torch.prim.If.yield %11 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@ -11688,24 +11769,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

View File

@ -5800,7 +5800,7 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
// PyTorch aten.floorDivide is a misnomer because it actually rounds
// the quotient towards zero instead of taking its floor.
Value cstStrFloor =
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
op, op.getType(), op.getSelf(), op.getOther(),
/*roundingMode=*/cstStrFloor);
@ -5809,6 +5809,22 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
};
} // namespace
namespace {
class DecomposeAtenFloorDivideScalarOp
: public OpRewritePattern<AtenFloorDivideScalarOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFloorDivideScalarOp op,
PatternRewriter &rewriter) const override {
Value cstStrFloor =
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
rewriter.replaceOpWithNewOp<AtenDivScalarModeOp>(
op, op.getType(), op.getSelf(), op.getOther(),
/*roundingMode=*/cstStrFloor);
return success();
}
};
} // namespace
namespace {
// Decompose `aten.numpyT` op into `aten.permute` op.
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
@ -7560,6 +7576,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);

View File

@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenClampMaxOp>();
target.addIllegalOp<AtenBaddbmmOp>();
target.addIllegalOp<AtenFloorDivideOp>();
target.addIllegalOp<AtenFloorDivideScalarOp>();
target.addIllegalOp<AtenNumpyTOp>();
target.addIllegalOp<AtenSelectScatterOp>();
target.addIllegalOp<AtenVarDimOp>();

View File

@ -244,12 +244,20 @@ TORCHDYNAMO_XFAIL_SET = {
"ElementwiseSubScalarIntModule_basic",
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
"ElementwiseDivRoundingModeFloorModule_basic",
"ElementwiseDivRoundingModeTruncModule_basic",
"ElementwiseDivRoundingModeFloorStaticModule_basic",
"ElementwiseDivRoundingModeTruncStaticModule_basic",
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
"ElementwiseAtenFloorDivideScalarModule_basic",
"ElementwiseDivTensorRoundingModeFloorModule_basic",
"ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorModule_basic",
"ElementwiseDivScalarRoundingModeTruncModule_basic",
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
@ -799,10 +807,14 @@ STABLEHLO_PASS_SET = {
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseDivRoundingModeFloorStaticModule_basic",
"ElementwiseDivRoundingModeTruncStaticModule_basic",
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
"ElementwiseErfModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseFloorIntModule_basic",
@ -1354,6 +1366,8 @@ TOSA_PASS_SET = {
"ElementwiseDivScalarModule_basic",
"ElementwiseDivTensorIntegerModule_basic",
"ElementwiseDivTensorUnsignedIntegerModule_basic",
"ElementwiseDivScalarIntegerModule_basic",
"ElementwiseDivScalarUnsignedIntegerModule_basic",
"ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic",
"ElementwiseEqBoolScalarModule_basic",
@ -2450,10 +2464,11 @@ ONNX_XFAIL_SET = {
"ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncModule_basic",
"ElementwiseDivRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseLogIntModule_basic",
@ -2484,7 +2499,12 @@ ONNX_XFAIL_SET = {
"TensorsStackPromoteDTypeModule_basic",
# Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1"
"AtenLinalgCrossDynamic_basic"
"AtenLinalgCrossDynamic_basic",
# Failure - value not close to golden value (op is incorrectly truncating)
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
}
ONNX_CRASHING_SET = {

View File

@ -1209,6 +1209,9 @@ def atendivTensor〡shape(self: List[int], other: List[int]) -> List[int]:
def atendivTensor_mode〡shape(self: List[int], other: List[int], rounding_mode: Optional[str]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def atendivScalar_mode〡shape(self: List[int], other: float, rounding_mode: Optional[str]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenfloor_divide〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
@ -3152,6 +3155,45 @@ def atendivTensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_ran
else:
return torch.float32
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
def atenfloor_divideScalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_complex_dtype(self_dtype)
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode=None, error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode=None, error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="floor", error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="floor", error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="trunc", error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="trunc", error_types={torch.complex64, torch.complex128}))
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode=None) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode=None) +
_check_tensors_with_the_same_dtype(error_types={torch.complex64, torch.complex128}, num_of_tensors=1, other=0.0, rounding_mode="floor") +
_check_tensors_with_the_same_dtype(error_types={torch.complex64, torch.complex128}, num_of_tensors=1, other=0, rounding_mode="floor") +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="trunc") +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="trunc"))
def atendivScalar_mode〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], rounding_mode: Optional[str]) -> int:
if rounding_mode is not None and rounding_mode == "floor":
return atenfloor_divideScalar〡dtype(self_rank_dtype, other)
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [None, self_rank]
dtypes = [get_dtype_of_scalar(other), self_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
if is_complex_dtype(promoted_dtype) or \
(is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32) or \
(rounding_mode is not None and rounding_mode == "trunc"):
return promoted_dtype
else:
return torch.float32
@check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
# Different width
@ -3631,15 +3673,6 @@ def atenfmodTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dt
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
def atenfloor_divideScalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_complex_dtype(self_dtype)
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
return promote_dtypes(ranks, dtypes)
def atenpowScalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int:
exponent_rank, exponent_dtype = exponent_rank_dtype

View File

@ -342,6 +342,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# Elementwise tensor compute ops that don't have the standard mutating
# variants.
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)

View File

@ -2793,7 +2793,129 @@ def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
class ElementwiseDivScalarRoundingModeTruncModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
])
def forward(self, a):
return torch.div(a, 0.5, rounding_mode="trunc")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncModule())
def ElementwiseDivScalarRoundingModeTruncModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4))
class ElementwiseDivScalarRoundingModeFloorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.div(a, 0.5, rounding_mode="floor")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorModule())
def ElementwiseDivScalarRoundingModeFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class ElementwiseDivScalarRoundingModeTruncStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4], torch.float32, True),
])
def forward(self, a):
return torch.div(a, 0.5, rounding_mode="trunc")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncStaticModule())
def ElementwiseDivScalarRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4))
class ElementwiseDivScalarRoundingModeFloorStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.float32, True),
])
def forward(self, a):
return torch.div(a, 0.5, rounding_mode="floor")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorStaticModule())
def ElementwiseDivScalarRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class ElementwiseDivScalarRoundingModeTruncIntStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int32, True),
])
def forward(self, a):
return torch.div(a, 3, rounding_mode="trunc")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncIntStaticModule())
def ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32))
class ElementwiseDivScalarRoundingModeFloorIntStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int32, True),
])
def forward(self, a):
return torch.div(a, 3, rounding_mode="floor")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorIntStaticModule())
def ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32))
# ==============================================================================
class ElementwiseDivTensorRoundingModeTruncModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2809,12 +2931,12 @@ class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeTruncModule())
def ElementwiseDivRoundingModeTruncModule_basic(module, tu: TestUtils):
module_factory=lambda: ElementwiseDivTensorRoundingModeTruncModule())
def ElementwiseDivTensorRoundingModeTruncModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
class ElementwiseDivTensorRoundingModeFloorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2830,11 +2952,11 @@ class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeFloorModule())
def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils):
module_factory=lambda: ElementwiseDivTensorRoundingModeFloorModule())
def ElementwiseDivTensorRoundingModeFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
class ElementwiseDivTensorRoundingModeTruncStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2850,12 +2972,12 @@ class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeTruncStaticModule())
def ElementwiseDivRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
module_factory=lambda: ElementwiseDivTensorRoundingModeTruncStaticModule())
def ElementwiseDivTensorRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
class ElementwiseDivTensorRoundingModeFloorStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2871,11 +2993,11 @@ class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeFloorStaticModule())
def ElementwiseDivRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
module_factory=lambda: ElementwiseDivTensorRoundingModeFloorStaticModule())
def ElementwiseDivTensorRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
class ElementwiseDivTensorRoundingModeTruncIntStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2891,12 +3013,12 @@ class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeTruncIntStaticModule())
def ElementwiseDivRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
module_factory=lambda: ElementwiseDivTensorRoundingModeTruncIntStaticModule())
def ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
class ElementwiseDivTensorRoundingModeFloorIntStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2912,8 +3034,8 @@ class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeFloorIntStaticModule())
def ElementwiseDivRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
module_factory=lambda: ElementwiseDivTensorRoundingModeFloorIntStaticModule())
def ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
@ -4194,7 +4316,48 @@ def ElementwiseAtenLogicalNotOpPromoteModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseAtenFloorDivideModule(torch.nn.Module):
class ElementwiseAtenFloorDivideScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.floor_divide(x, 0.14)
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideScalarModule())
def ElementwiseAtenFloorDivideScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3))
class ElementwiseAtenFloorDivideScalarNegativeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.floor_divide(x, 0.14)
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideScalarNegativeModule())
def ElementwiseAtenFloorDivideScalarNegativeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, low=-10.0, high=10.0))
# ==============================================================================
class ElementwiseAtenFloorDivideTensorNegativeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -4209,8 +4372,28 @@ class ElementwiseAtenFloorDivideModule(torch.nn.Module):
return torch.ops.aten.floor_divide(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideModule())
def ElementwiseAtenFloorDivideModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorNegativeModule())
def ElementwiseAtenFloorDivideTensorNegativeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, low= -1, high=0), tu.rand(4, 3, low= 0, high=1))
class ElementwiseAtenFloorDivideTensorPositiveModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, x, y):
return torch.ops.aten.floor_divide(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorPositiveModule())
def ElementwiseAtenFloorDivideTensorPositiveModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3), tu.rand(4, 3))