Added 2 Ops: Floor divide scalar and Floor divide scalar mode (#3156)

- Added linalg lowering for `AtenFloorDivideScalarOp`
  - Needed `AtenDivScalarModeOp` for the decomp.
- Added linalg lowering for `AtenDivScalarModeOp`
- Moved linalg payload logic to `createDivModePayload()` since the logic
was nearly identical for both `AtenDivScalarModeOp` and
`AtenDivTensorModeOp`. Just a template function
 -  Added `AtenDivScalarModeOp` lowering for stablehlo
 

Pytorch's
[`torch.floor_divide()`](https://pytorch.org/docs/stable/generated/torch.floor_divide.html)
in a previous version (for a reason unknown to me) preformed a
truncation instead of "floor". The already implemented op
`AtenFloorDivideTensorOp` was done before this change. However, this
wasn't caught because our testcases only tested positive floor division.
I changed this to floor as well as adding a few test cases.
pull/3167/head
IanWood1 2024-04-15 13:45:10 -07:00 committed by GitHub
parent 83cba8c696
commit 5708ee7ec9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 565 additions and 159 deletions

View File

@ -3397,6 +3397,56 @@ def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [
}]; }];
} }
def Torch_AtenDivScalarModeOp : Torch_Op<"aten.div.Scalar_mode", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other,
AnyTorchOptionalStringType:$rounding_mode
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDivScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenDivScalarModeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenDiv_ScalarModeOp : Torch_Op<"aten.div_.Scalar_mode", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::div_.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchScalarType:$other,
AnyTorchOptionalStringType:$rounding_mode
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDiv_ScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenDiv_ScalarModeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -26,6 +26,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h" #include "llvm/ADT/APSInt.h"
#include <numeric> #include <numeric>
#include <type_traits>
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
@ -213,6 +214,78 @@ createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
return success(); return success();
} }
template <typename OpT>
Value createDivModePayload(OpBuilder &b, Location loc,
const TypeConverter *converter,
ValueRange payloadArgs, OpT op,
ArrayRef<Value> operands) {
static_assert(std::is_same_v<OpT, AtenDivTensorModeOp> ||
std::is_same_v<OpT, AtenDivScalarModeOp>,
"template type must be a tensor/scalar div mode");
typename OpT::Adaptor adaptor(operands);
Type dtype = cast<RankedTensorType>(converter->convertType(op.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(
b, loc,
std::is_same_v<OpT, AtenDivScalarModeOp> ? operands[1] : payloadArgs[1],
dtype);
Value quotient;
if (isa<mlir::FloatType>(dtype)) {
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
} else if (dtype.isUnsignedInteger()) {
quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
} else {
assert(dtype.isInteger() &&
"dtype should be an integer (signless or signed)");
quotient = b.create<arith::DivSIOp>(loc, lhs, rhs);
}
if (isa<Torch::NoneType>(op.getRoundingMode().getType()))
return quotient;
std::string roundingMode;
if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundingMode))) {
op.emitError("only support constant str rounding mode");
return nullptr;
}
assert((roundingMode == "trunc" || roundingMode == "floor") &&
"unsupported rounding mode");
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
if (!isa<mlir::FloatType>(dtype)) {
// nothing to do for integers
return quotient;
}
// float
Value ceil = b.create<math::CeilOp>(loc, quotient);
Value floor = b.create<math::FloorOp>(loc, quotient);
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
quotient, cstZero);
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (isa<mlir::FloatType>(dtype))
return b.create<math::FloorOp>(loc, quotient);
if (!dtype.isUnsignedInteger()) {
Type defaultIntToFloatType = b.getF64Type();
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
Value floor = b.create<math::FloorOp>(loc, quotient);
Value convert = convertScalarToDtype(b, loc, floor, dtype);
return convert;
}
}
return quotient;
}
static Value createLinalgPayloadCalculationForElementwiseOp( static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, const TypeConverter *converter, OpBuilder &b, Location loc, const TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) { ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
@ -769,66 +842,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
div.emitError("unimplemented: non-floating point and non-integer dtype"); div.emitError("unimplemented: non-floating point and non-integer dtype");
return nullptr; return nullptr;
} }
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) { if (auto divScalarMode = dyn_cast<AtenDivScalarModeOp>(op)) {
AtenDivTensorModeOp::Adaptor adaptor(operands); return createDivModePayload(b, loc, converter, payloadArgs, divScalarMode,
Type dtype = converter->convertType(divTensorMode.getType()) operands);
.cast<RankedTensorType>() }
.getElementType(); if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); return createDivModePayload(b, loc, converter, payloadArgs, divTensorMode,
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); operands);
Value div;
if (isa<mlir::FloatType>(dtype))
div = b.create<arith::DivFOp>(loc, lhs, rhs);
else {
if (dtype.isUnsignedInteger())
div = b.create<arith::DivUIOp>(loc, lhs, rhs);
else
div = b.create<arith::DivSIOp>(loc, lhs, rhs);
}
if (divTensorMode.getRoundingMode().getType().isa<Torch::NoneType>())
return div;
std::string roundingMode;
if (!matchPattern(divTensorMode.getRoundingMode(),
m_TorchConstantStr(roundingMode))) {
divTensorMode.emitError("only support constant str rounding mode");
return nullptr;
}
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
if (isa<mlir::FloatType>(dtype)) {
Value ceil = b.create<math::CeilOp>(loc, div);
Value floor = b.create<math::FloorOp>(loc, div);
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
div, cstZero);
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
} else
return div;
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (isa<mlir::FloatType>(dtype))
return b.create<math::FloorOp>(loc, div);
else if (!dtype.isUnsignedInteger()) {
Type defaultIntToFloatType = b.getF64Type();
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
div = b.create<arith::DivFOp>(loc, lhs, rhs);
Value floor = b.create<math::FloorOp>(loc, div);
Value convert = convertScalarToDtype(b, loc, floor, dtype);
return convert;
} else {
return div;
}
}
divTensorMode.emitError("invalid rounding mode");
return nullptr;
} }
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) { if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
Type dtype = pow.getType().cast<ValueTensorType>().getDtype(); Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
if (!isa<mlir::FloatType>(dtype)) { if (!isa<mlir::FloatType>(dtype)) {
@ -1579,12 +1600,13 @@ public:
if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp,
AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenSubTensorOp, AtenAtan2Op, AtenLerpTensorOp, AtenSigmoidOp, AtenDivScalarModeOp, AtenSubTensorOp, AtenAtan2Op,
AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp, AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenRsubScalarOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenClampTensorOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp, AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
@ -2617,25 +2639,25 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp, AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp,
AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenDivScalarModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenRemainderScalarOp, AtenFmodTensorOp, AtenRemainderTensorOp, AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
AtenQuantizePerTensorOp>(); AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>(); target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context); patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -27,8 +27,8 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <iostream>
#include <numeric> #include <numeric>
#include <type_traits>
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
@ -409,9 +409,9 @@ public:
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in StableHLO");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = cast<TensorType>(
->convertType(op.getType()) OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
.template cast<TensorType>(); op.getType()));
Type outElemTy = outType.getElementType(); Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) { if (!outElemTy.isIntOrFloat()) {
@ -432,18 +432,23 @@ public:
Value result = Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions); rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
if (!isa<AtenDivTensorModeOp>(op)) { if (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
return success(); return success();
} }
AtenDivTensorModeOp divTensorModeOp = auto tensorOp = dyn_cast<AtenDivTensorModeOp>(op.getOperation());
llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation()); auto opRoundingMode =
tensorOp
? tensorOp.getRoundingMode()
: cast<AtenDivScalarModeOp>(op.getOperation()).getRoundingMode();
std::string roundingMode; std::string roundingMode;
if (!matchPattern(divTensorModeOp.getRoundingMode(), if (!matchPattern(opRoundingMode, m_TorchConstantStr(roundingMode))) {
m_TorchConstantStr(roundingMode)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only support constant str rounding mode"); op, "only support constant str rounding mode");
}
// if trunc and int, do nothing // if trunc and int, do nothing
if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) { if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
@ -1845,6 +1850,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarModeOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp); INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp);
#undef INSERT_BINARY_MULDIV_PATTERN #undef INSERT_BINARY_MULDIV_PATTERN

View File

@ -1095,9 +1095,9 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha); rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
} }
if (isa<AtenDivTensorModeOp>(op)) { if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
// None rounding mode
if (op->getOperand(2).getType().isa<Torch::NoneType>()) { if (op->getOperand(2).getType().isa<Torch::NoneType>()) {
// None rounding mode
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs); Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType, rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
quotient); quotient);
@ -1858,6 +1858,16 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
}); });
} }
//===----------------------------------------------------------------------===//
// AtenDivScalarModeOp
//===----------------------------------------------------------------------===//
void AtenDivScalarModeOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenDivScalarModeOp op, PatternRewriter &rewriter) {
return rewrite0DBinaryTensorOp(op, rewriter);
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenNumelOp // AtenNumelOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -8496,6 +8496,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.div.Scalar_mode\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<str>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.floor_divide\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.floor_divide\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -11166,6 +11170,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %2 : !torch.int\n" " return %2 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar_mode\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<str>) -> !torch.int {\n"
" %str = torch.constant.str \"trunc\"\n"
" %int6 = torch.constant.int 6\n"
" %true = torch.constant.bool true\n"
" %false = torch.constant.bool false\n"
" %str_0 = torch.constant.str \"floor\"\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
" %4 = torch.aten.eq.str %3, %str_0 : !torch.str, !torch.str -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" %3 = func.call @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0, %arg1) : (!torch.tuple<int, int>, !torch.number) -> !torch.int\n"
" torch.prim.If.yield %3 : !torch.int\n"
" } else {\n"
" %3:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %4 = torch.prim.ListConstruct %none, %3#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
" %5 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %6 = torch.prim.ListConstruct %5, %3#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%4, %6) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%7) : (!torch.int) -> !torch.bool\n"
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %12 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %14 = torch.aten.ne.int %7, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %14 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" }\n"
" %10 = torch.prim.If %9 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %12 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %14 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
" %15 = torch.aten.eq.str %14, %str : !torch.str, !torch.str -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" }\n"
" %11 = torch.prim.If %10 -> (!torch.int) {\n"
" torch.prim.If.yield %7 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" torch.prim.If.yield %11 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@ -11688,24 +11769,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n" " return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

View File

@ -5800,7 +5800,7 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
// PyTorch aten.floorDivide is a misnomer because it actually rounds // PyTorch aten.floorDivide is a misnomer because it actually rounds
// the quotient towards zero instead of taking its floor. // the quotient towards zero instead of taking its floor.
Value cstStrFloor = Value cstStrFloor =
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc"); rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>( rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
op, op.getType(), op.getSelf(), op.getOther(), op, op.getType(), op.getSelf(), op.getOther(),
/*roundingMode=*/cstStrFloor); /*roundingMode=*/cstStrFloor);
@ -5809,6 +5809,22 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
}; };
} // namespace } // namespace
namespace {
class DecomposeAtenFloorDivideScalarOp
: public OpRewritePattern<AtenFloorDivideScalarOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFloorDivideScalarOp op,
PatternRewriter &rewriter) const override {
Value cstStrFloor =
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
rewriter.replaceOpWithNewOp<AtenDivScalarModeOp>(
op, op.getType(), op.getSelf(), op.getOther(),
/*roundingMode=*/cstStrFloor);
return success();
}
};
} // namespace
namespace { namespace {
// Decompose `aten.numpyT` op into `aten.permute` op. // Decompose `aten.numpyT` op into `aten.permute` op.
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> { class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
@ -7560,6 +7576,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);

View File

@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenClampMaxOp>(); target.addIllegalOp<AtenClampMaxOp>();
target.addIllegalOp<AtenBaddbmmOp>(); target.addIllegalOp<AtenBaddbmmOp>();
target.addIllegalOp<AtenFloorDivideOp>(); target.addIllegalOp<AtenFloorDivideOp>();
target.addIllegalOp<AtenFloorDivideScalarOp>();
target.addIllegalOp<AtenNumpyTOp>(); target.addIllegalOp<AtenNumpyTOp>();
target.addIllegalOp<AtenSelectScatterOp>(); target.addIllegalOp<AtenSelectScatterOp>();
target.addIllegalOp<AtenVarDimOp>(); target.addIllegalOp<AtenVarDimOp>();

View File

@ -244,12 +244,20 @@ TORCHDYNAMO_XFAIL_SET = {
"ElementwiseSubScalarIntModule_basic", "ElementwiseSubScalarIntModule_basic",
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
"ElementwiseDivRoundingModeFloorModule_basic", "ElementwiseAtenFloorDivideScalarNegativeModule_basic",
"ElementwiseDivRoundingModeTruncModule_basic", "ElementwiseAtenFloorDivideScalarModule_basic",
"ElementwiseDivRoundingModeFloorStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorModule_basic",
"ElementwiseDivRoundingModeTruncStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseDivRoundingModeFloorIntStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
"ElementwiseDivRoundingModeTruncIntStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorModule_basic",
"ElementwiseDivScalarRoundingModeTruncModule_basic",
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
@ -799,10 +807,14 @@ STABLEHLO_PASS_SET = {
"ElementwiseCloneContiguousModule_basic", "ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic", "ElementwiseCloneModule_basic",
"ElementwiseCosModule_basic", "ElementwiseCosModule_basic",
"ElementwiseDivRoundingModeFloorStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
"ElementwiseDivRoundingModeTruncStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
"ElementwiseDivRoundingModeFloorIntStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncIntStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
"ElementwiseErfModule_basic", "ElementwiseErfModule_basic",
"ElementwiseExpModule_basic", "ElementwiseExpModule_basic",
"ElementwiseFloorIntModule_basic", "ElementwiseFloorIntModule_basic",
@ -1354,6 +1366,8 @@ TOSA_PASS_SET = {
"ElementwiseDivScalarModule_basic", "ElementwiseDivScalarModule_basic",
"ElementwiseDivTensorIntegerModule_basic", "ElementwiseDivTensorIntegerModule_basic",
"ElementwiseDivTensorUnsignedIntegerModule_basic", "ElementwiseDivTensorUnsignedIntegerModule_basic",
"ElementwiseDivScalarIntegerModule_basic",
"ElementwiseDivScalarUnsignedIntegerModule_basic",
"ElementwiseEluModule_basic", "ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic", "ElementwiseEluNonDefaultModule_basic",
"ElementwiseEqBoolScalarModule_basic", "ElementwiseEqBoolScalarModule_basic",
@ -2450,10 +2464,11 @@ ONNX_XFAIL_SET = {
"ElementwiseAsinIntModule_basic", "ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanTensorIntModule_basic",
"ElementwiseCosIntModule_basic", "ElementwiseCosIntModule_basic",
"ElementwiseDivRoundingModeFloorIntStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncIntStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivRoundingModeTruncModule_basic", "ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseDivRoundingModeTruncStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseErfIntModule_basic", "ElementwiseErfIntModule_basic",
"ElementwiseExpIntModule_basic", "ElementwiseExpIntModule_basic",
"ElementwiseLogIntModule_basic", "ElementwiseLogIntModule_basic",
@ -2484,7 +2499,12 @@ ONNX_XFAIL_SET = {
"TensorsStackPromoteDTypeModule_basic", "TensorsStackPromoteDTypeModule_basic",
# Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1" # Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1"
"AtenLinalgCrossDynamic_basic" "AtenLinalgCrossDynamic_basic",
# Failure - value not close to golden value (op is incorrectly truncating)
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
} }
ONNX_CRASHING_SET = { ONNX_CRASHING_SET = {

View File

@ -1209,6 +1209,9 @@ def atendivTensor〡shape(self: List[int], other: List[int]) -> List[int]:
def atendivTensor_mode〡shape(self: List[int], other: List[int], rounding_mode: Optional[str]) -> List[int]: def atendivTensor_mode〡shape(self: List[int], other: List[int], rounding_mode: Optional[str]) -> List[int]:
return upstream_shape_functions.broadcast(self, other) return upstream_shape_functions.broadcast(self, other)
def atendivScalar_mode〡shape(self: List[int], other: float, rounding_mode: Optional[str]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenfloor_divide〡shape(self: List[int], other: List[int]) -> List[int]: def atenfloor_divide〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other) return upstream_shape_functions.broadcast(self, other)
@ -3152,6 +3155,45 @@ def atendivTensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_ran
else: else:
return torch.float32 return torch.float32
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
def atenfloor_divideScalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_complex_dtype(self_dtype)
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode=None, error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode=None, error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="floor", error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="floor", error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="trunc", error_types={torch.complex64, torch.complex128}) +
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="trunc", error_types={torch.complex64, torch.complex128}))
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode=None) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode=None) +
_check_tensors_with_the_same_dtype(error_types={torch.complex64, torch.complex128}, num_of_tensors=1, other=0.0, rounding_mode="floor") +
_check_tensors_with_the_same_dtype(error_types={torch.complex64, torch.complex128}, num_of_tensors=1, other=0, rounding_mode="floor") +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="trunc") +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="trunc"))
def atendivScalar_mode〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], rounding_mode: Optional[str]) -> int:
if rounding_mode is not None and rounding_mode == "floor":
return atenfloor_divideScalar〡dtype(self_rank_dtype, other)
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [None, self_rank]
dtypes = [get_dtype_of_scalar(other), self_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
if is_complex_dtype(promoted_dtype) or \
(is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32) or \
(rounding_mode is not None and rounding_mode == "trunc"):
return promoted_dtype
else:
return torch.float32
@check_dtype_function( @check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
# Different width # Different width
@ -3631,15 +3673,6 @@ def atenfmodTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dt
dtypes = [self_dtype, other_dtype] dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes) return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
def atenfloor_divideScalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_complex_dtype(self_dtype)
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
return promote_dtypes(ranks, dtypes)
def atenpowScalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int: def atenpowScalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int:
exponent_rank, exponent_dtype = exponent_rank_dtype exponent_rank, exponent_dtype = exponent_rank_dtype

View File

@ -342,6 +342,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# Elementwise tensor compute ops that don't have the standard mutating # Elementwise tensor compute ops that don't have the standard mutating
# variants. # variants.
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True) emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)

View File

@ -2793,7 +2793,129 @@ def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
class ElementwiseDivScalarRoundingModeTruncModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
])
def forward(self, a):
return torch.div(a, 0.5, rounding_mode="trunc")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncModule())
def ElementwiseDivScalarRoundingModeTruncModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4))
class ElementwiseDivScalarRoundingModeFloorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.div(a, 0.5, rounding_mode="floor")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorModule())
def ElementwiseDivScalarRoundingModeFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class ElementwiseDivScalarRoundingModeTruncStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4], torch.float32, True),
])
def forward(self, a):
return torch.div(a, 0.5, rounding_mode="trunc")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncStaticModule())
def ElementwiseDivScalarRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4))
class ElementwiseDivScalarRoundingModeFloorStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.float32, True),
])
def forward(self, a):
return torch.div(a, 0.5, rounding_mode="floor")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorStaticModule())
def ElementwiseDivScalarRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class ElementwiseDivScalarRoundingModeTruncIntStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int32, True),
])
def forward(self, a):
return torch.div(a, 3, rounding_mode="trunc")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncIntStaticModule())
def ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32))
class ElementwiseDivScalarRoundingModeFloorIntStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int32, True),
])
def forward(self, a):
return torch.div(a, 3, rounding_mode="floor")
@register_test_case(
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorIntStaticModule())
def ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32))
# ==============================================================================
class ElementwiseDivTensorRoundingModeTruncModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -2809,12 +2931,12 @@ class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
@register_test_case( @register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeTruncModule()) module_factory=lambda: ElementwiseDivTensorRoundingModeTruncModule())
def ElementwiseDivRoundingModeTruncModule_basic(module, tu: TestUtils): def ElementwiseDivTensorRoundingModeTruncModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(4).type(torch.float64)) module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
class ElementwiseDivRoundingModeFloorModule(torch.nn.Module): class ElementwiseDivTensorRoundingModeFloorModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -2830,11 +2952,11 @@ class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
@register_test_case( @register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeFloorModule()) module_factory=lambda: ElementwiseDivTensorRoundingModeFloorModule())
def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils): def ElementwiseDivTensorRoundingModeFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64)) module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module): class ElementwiseDivTensorRoundingModeTruncStaticModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -2850,12 +2972,12 @@ class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
@register_test_case( @register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeTruncStaticModule()) module_factory=lambda: ElementwiseDivTensorRoundingModeTruncStaticModule())
def ElementwiseDivRoundingModeTruncStaticModule_basic(module, tu: TestUtils): def ElementwiseDivTensorRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(4).type(torch.float64)) module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module): class ElementwiseDivTensorRoundingModeFloorStaticModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -2871,11 +2993,11 @@ class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
@register_test_case( @register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeFloorStaticModule()) module_factory=lambda: ElementwiseDivTensorRoundingModeFloorStaticModule())
def ElementwiseDivRoundingModeFloorStaticModule_basic(module, tu: TestUtils): def ElementwiseDivTensorRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64)) module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module): class ElementwiseDivTensorRoundingModeTruncIntStaticModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -2891,12 +3013,12 @@ class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
@register_test_case( @register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeTruncIntStaticModule()) module_factory=lambda: ElementwiseDivTensorRoundingModeTruncIntStaticModule())
def ElementwiseDivRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils): def ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64)) module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module): class ElementwiseDivTensorRoundingModeFloorIntStaticModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -2912,8 +3034,8 @@ class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
@register_test_case( @register_test_case(
module_factory=lambda: ElementwiseDivRoundingModeFloorIntStaticModule()) module_factory=lambda: ElementwiseDivTensorRoundingModeFloorIntStaticModule())
def ElementwiseDivRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils): def ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64)) module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
@ -4194,7 +4316,48 @@ def ElementwiseAtenLogicalNotOpPromoteModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseAtenFloorDivideModule(torch.nn.Module): class ElementwiseAtenFloorDivideScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.floor_divide(x, 0.14)
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideScalarModule())
def ElementwiseAtenFloorDivideScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3))
class ElementwiseAtenFloorDivideScalarNegativeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.floor_divide(x, 0.14)
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideScalarNegativeModule())
def ElementwiseAtenFloorDivideScalarNegativeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, low=-10.0, high=10.0))
# ==============================================================================
class ElementwiseAtenFloorDivideTensorNegativeModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -4209,8 +4372,28 @@ class ElementwiseAtenFloorDivideModule(torch.nn.Module):
return torch.ops.aten.floor_divide(x, y) return torch.ops.aten.floor_divide(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideModule()) @register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorNegativeModule())
def ElementwiseAtenFloorDivideModule_basic(module, tu: TestUtils): def ElementwiseAtenFloorDivideTensorNegativeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, low= -1, high=0), tu.rand(4, 3, low= 0, high=1))
class ElementwiseAtenFloorDivideTensorPositiveModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, x, y):
return torch.ops.aten.floor_divide(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorPositiveModule())
def ElementwiseAtenFloorDivideTensorPositiveModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3), tu.rand(4, 3)) module.forward(tu.rand(4, 3), tu.rand(4, 3))