mirror of https://github.com/llvm/torch-mlir
Added 2 Ops: Floor divide scalar and Floor divide scalar mode (#3156)
- Added linalg lowering for `AtenFloorDivideScalarOp` - Needed `AtenDivScalarModeOp` for the decomp. - Added linalg lowering for `AtenDivScalarModeOp` - Moved linalg payload logic to `createDivModePayload()` since the logic was nearly identical for both `AtenDivScalarModeOp` and `AtenDivTensorModeOp`. Just a template function - Added `AtenDivScalarModeOp` lowering for stablehlo Pytorch's [`torch.floor_divide()`](https://pytorch.org/docs/stable/generated/torch.floor_divide.html) in a previous version (for a reason unknown to me) preformed a truncation instead of "floor". The already implemented op `AtenFloorDivideTensorOp` was done before this change. However, this wasn't caught because our testcases only tested positive floor division. I changed this to floor as well as adding a few test cases.pull/3167/head
parent
83cba8c696
commit
5708ee7ec9
|
@ -3397,6 +3397,56 @@ def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenDivScalarModeOp : Torch_Op<"aten.div.Scalar_mode", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$other,
|
||||
AnyTorchOptionalStringType:$rounding_mode
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenDivScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenDivScalarModeOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenDiv_ScalarModeOp : Torch_Op<"aten.div_.Scalar_mode", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::div_.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_NonValueTensorType:$self,
|
||||
AnyTorchScalarType:$other,
|
||||
AnyTorchOptionalStringType:$rounding_mode
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalNonValueTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenDiv_ScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenDiv_ScalarModeOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/ADT/APSInt.h"
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -213,6 +214,78 @@ createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename OpT>
|
||||
Value createDivModePayload(OpBuilder &b, Location loc,
|
||||
const TypeConverter *converter,
|
||||
ValueRange payloadArgs, OpT op,
|
||||
ArrayRef<Value> operands) {
|
||||
static_assert(std::is_same_v<OpT, AtenDivTensorModeOp> ||
|
||||
std::is_same_v<OpT, AtenDivScalarModeOp>,
|
||||
"template type must be a tensor/scalar div mode");
|
||||
typename OpT::Adaptor adaptor(operands);
|
||||
Type dtype = cast<RankedTensorType>(converter->convertType(op.getType()))
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(
|
||||
b, loc,
|
||||
std::is_same_v<OpT, AtenDivScalarModeOp> ? operands[1] : payloadArgs[1],
|
||||
dtype);
|
||||
|
||||
Value quotient;
|
||||
if (isa<mlir::FloatType>(dtype)) {
|
||||
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
} else if (dtype.isUnsignedInteger()) {
|
||||
quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
|
||||
} else {
|
||||
assert(dtype.isInteger() &&
|
||||
"dtype should be an integer (signless or signed)");
|
||||
quotient = b.create<arith::DivSIOp>(loc, lhs, rhs);
|
||||
}
|
||||
|
||||
if (isa<Torch::NoneType>(op.getRoundingMode().getType()))
|
||||
return quotient;
|
||||
|
||||
std::string roundingMode;
|
||||
if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundingMode))) {
|
||||
op.emitError("only support constant str rounding mode");
|
||||
return nullptr;
|
||||
}
|
||||
assert((roundingMode == "trunc" || roundingMode == "floor") &&
|
||||
"unsupported rounding mode");
|
||||
if (roundingMode == "trunc") {
|
||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||
// to C-style integer division.
|
||||
if (!isa<mlir::FloatType>(dtype)) {
|
||||
// nothing to do for integers
|
||||
return quotient;
|
||||
}
|
||||
|
||||
// float
|
||||
Value ceil = b.create<math::CeilOp>(loc, quotient);
|
||||
Value floor = b.create<math::FloorOp>(loc, quotient);
|
||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
quotient, cstZero);
|
||||
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
|
||||
}
|
||||
if (roundingMode == "floor") {
|
||||
// "floor" - rounds the results of the division down. Equivalent to
|
||||
// floor division in Python (the // operator)
|
||||
if (isa<mlir::FloatType>(dtype))
|
||||
return b.create<math::FloorOp>(loc, quotient);
|
||||
if (!dtype.isUnsignedInteger()) {
|
||||
Type defaultIntToFloatType = b.getF64Type();
|
||||
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
|
||||
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
|
||||
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
Value floor = b.create<math::FloorOp>(loc, quotient);
|
||||
Value convert = convertScalarToDtype(b, loc, floor, dtype);
|
||||
return convert;
|
||||
}
|
||||
}
|
||||
return quotient;
|
||||
}
|
||||
|
||||
static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||
OpBuilder &b, Location loc, const TypeConverter *converter,
|
||||
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
||||
|
@ -769,66 +842,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
div.emitError("unimplemented: non-floating point and non-integer dtype");
|
||||
return nullptr;
|
||||
}
|
||||
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
|
||||
AtenDivTensorModeOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(divTensorMode.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value div;
|
||||
if (isa<mlir::FloatType>(dtype))
|
||||
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
else {
|
||||
if (dtype.isUnsignedInteger())
|
||||
div = b.create<arith::DivUIOp>(loc, lhs, rhs);
|
||||
else
|
||||
div = b.create<arith::DivSIOp>(loc, lhs, rhs);
|
||||
}
|
||||
|
||||
if (divTensorMode.getRoundingMode().getType().isa<Torch::NoneType>())
|
||||
return div;
|
||||
|
||||
std::string roundingMode;
|
||||
if (!matchPattern(divTensorMode.getRoundingMode(),
|
||||
m_TorchConstantStr(roundingMode))) {
|
||||
divTensorMode.emitError("only support constant str rounding mode");
|
||||
return nullptr;
|
||||
}
|
||||
if (roundingMode == "trunc") {
|
||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||
// to C-style integer division.
|
||||
if (isa<mlir::FloatType>(dtype)) {
|
||||
Value ceil = b.create<math::CeilOp>(loc, div);
|
||||
Value floor = b.create<math::FloorOp>(loc, div);
|
||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
div, cstZero);
|
||||
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
|
||||
} else
|
||||
return div;
|
||||
}
|
||||
if (roundingMode == "floor") {
|
||||
// "floor" - rounds the results of the division down. Equivalent to
|
||||
// floor division in Python (the // operator)
|
||||
if (isa<mlir::FloatType>(dtype))
|
||||
return b.create<math::FloorOp>(loc, div);
|
||||
else if (!dtype.isUnsignedInteger()) {
|
||||
Type defaultIntToFloatType = b.getF64Type();
|
||||
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
|
||||
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
|
||||
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
Value floor = b.create<math::FloorOp>(loc, div);
|
||||
Value convert = convertScalarToDtype(b, loc, floor, dtype);
|
||||
return convert;
|
||||
} else {
|
||||
return div;
|
||||
}
|
||||
}
|
||||
divTensorMode.emitError("invalid rounding mode");
|
||||
return nullptr;
|
||||
if (auto divScalarMode = dyn_cast<AtenDivScalarModeOp>(op)) {
|
||||
return createDivModePayload(b, loc, converter, payloadArgs, divScalarMode,
|
||||
operands);
|
||||
}
|
||||
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
|
||||
return createDivModePayload(b, loc, converter, payloadArgs, divTensorMode,
|
||||
operands);
|
||||
}
|
||||
|
||||
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
|
||||
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
|
||||
if (!isa<mlir::FloatType>(dtype)) {
|
||||
|
@ -1579,12 +1600,13 @@ public:
|
|||
if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp,
|
||||
AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
|
||||
AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
||||
AtenSubTensorOp, AtenAtan2Op, AtenLerpTensorOp, AtenSigmoidOp,
|
||||
AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp,
|
||||
AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenRsubScalarOp,
|
||||
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp,
|
||||
AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
||||
AtenDivScalarModeOp, AtenSubTensorOp, AtenAtan2Op,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
|
||||
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||
AtenClampTensorOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
|
||||
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
|
||||
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
||||
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
||||
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
|
||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||
|
@ -2617,25 +2639,25 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp,
|
||||
AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
||||
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
|
||||
AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp,
|
||||
AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp,
|
||||
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
|
||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
||||
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
|
||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
|
||||
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
|
||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
|
||||
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
|
||||
AtenRemainderScalarOp, AtenFmodTensorOp, AtenRemainderTensorOp,
|
||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
||||
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
||||
AtenQuantizePerTensorOp>();
|
||||
AtenDivScalarModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
|
||||
AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
|
||||
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
|
||||
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
|
||||
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
||||
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
|
||||
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
|
||||
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
|
||||
AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
|
||||
AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||
AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
|
||||
AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -27,8 +27,8 @@
|
|||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -409,9 +409,9 @@ public:
|
|||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
auto outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
|
@ -432,18 +432,23 @@ public:
|
|||
Value result =
|
||||
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
||||
|
||||
if (!isa<AtenDivTensorModeOp>(op)) {
|
||||
if (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
|
||||
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
AtenDivTensorModeOp divTensorModeOp =
|
||||
llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation());
|
||||
auto tensorOp = dyn_cast<AtenDivTensorModeOp>(op.getOperation());
|
||||
auto opRoundingMode =
|
||||
tensorOp
|
||||
? tensorOp.getRoundingMode()
|
||||
: cast<AtenDivScalarModeOp>(op.getOperation()).getRoundingMode();
|
||||
|
||||
std::string roundingMode;
|
||||
if (!matchPattern(divTensorModeOp.getRoundingMode(),
|
||||
m_TorchConstantStr(roundingMode)))
|
||||
if (!matchPattern(opRoundingMode, m_TorchConstantStr(roundingMode))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support constant str rounding mode");
|
||||
}
|
||||
|
||||
// if trunc and int, do nothing
|
||||
if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
|
||||
|
@ -1845,6 +1850,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarModeOp, chlo::BroadcastDivOp);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp);
|
||||
#undef INSERT_BINARY_MULDIV_PATTERN
|
||||
|
||||
|
|
|
@ -1095,9 +1095,9 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
|||
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
|
||||
}
|
||||
|
||||
if (isa<AtenDivTensorModeOp>(op)) {
|
||||
// None rounding mode
|
||||
if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
|
||||
if (op->getOperand(2).getType().isa<Torch::NoneType>()) {
|
||||
// None rounding mode
|
||||
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
|
||||
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
|
||||
quotient);
|
||||
|
@ -1858,6 +1858,16 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenDivScalarModeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void AtenDivScalarModeOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](AtenDivScalarModeOp op, PatternRewriter &rewriter) {
|
||||
return rewrite0DBinaryTensorOp(op, rewriter);
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenNumelOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -8496,6 +8496,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.div.Scalar_mode\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<str>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.floor_divide\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -11166,6 +11170,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar_mode\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<str>) -> !torch.int {\n"
|
||||
" %str = torch.constant.str \"trunc\"\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %str_0 = torch.constant.str \"floor\"\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
||||
" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
|
||||
" %4 = torch.aten.eq.str %3, %str_0 : !torch.str, !torch.str -> !torch.bool\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||
" %3 = func.call @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0, %arg1) : (!torch.tuple<int, int>, !torch.number) -> !torch.int\n"
|
||||
" torch.prim.If.yield %3 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %3:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %4 = torch.prim.ListConstruct %none, %3#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %5 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %6 = torch.prim.ListConstruct %5, %3#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%4, %6) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%7) : (!torch.int) -> !torch.bool\n"
|
||||
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %12 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n"
|
||||
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
|
||||
" %14 = torch.aten.ne.int %7, %int6 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %14 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %13 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %10 = torch.prim.If %9 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %12 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
|
||||
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
|
||||
" %14 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
|
||||
" %15 = torch.aten.eq.str %14, %str : !torch.str, !torch.str -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %13 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %11 = torch.prim.If %10 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %7 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %11 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
@ -11688,24 +11769,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -5800,7 +5800,7 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
|
|||
// PyTorch aten.floorDivide is a misnomer because it actually rounds
|
||||
// the quotient towards zero instead of taking its floor.
|
||||
Value cstStrFloor =
|
||||
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
|
||||
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
|
||||
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
|
||||
op, op.getType(), op.getSelf(), op.getOther(),
|
||||
/*roundingMode=*/cstStrFloor);
|
||||
|
@ -5809,6 +5809,22 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenFloorDivideScalarOp
|
||||
: public OpRewritePattern<AtenFloorDivideScalarOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenFloorDivideScalarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value cstStrFloor =
|
||||
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
|
||||
rewriter.replaceOpWithNewOp<AtenDivScalarModeOp>(
|
||||
op, op.getType(), op.getSelf(), op.getOther(),
|
||||
/*roundingMode=*/cstStrFloor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.numpyT` op into `aten.permute` op.
|
||||
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
||||
|
@ -7560,6 +7576,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
||||
|
|
|
@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenClampMaxOp>();
|
||||
target.addIllegalOp<AtenBaddbmmOp>();
|
||||
target.addIllegalOp<AtenFloorDivideOp>();
|
||||
target.addIllegalOp<AtenFloorDivideScalarOp>();
|
||||
target.addIllegalOp<AtenNumpyTOp>();
|
||||
target.addIllegalOp<AtenSelectScatterOp>();
|
||||
target.addIllegalOp<AtenVarDimOp>();
|
||||
|
|
|
@ -244,12 +244,20 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"ElementwiseSubScalarIntModule_basic",
|
||||
|
||||
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
||||
"ElementwiseDivRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
|
@ -799,10 +807,14 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseCloneContiguousModule_basic",
|
||||
"ElementwiseCloneModule_basic",
|
||||
"ElementwiseCosModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseErfModule_basic",
|
||||
"ElementwiseExpModule_basic",
|
||||
"ElementwiseFloorIntModule_basic",
|
||||
|
@ -1354,6 +1366,8 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseDivScalarModule_basic",
|
||||
"ElementwiseDivTensorIntegerModule_basic",
|
||||
"ElementwiseDivTensorUnsignedIntegerModule_basic",
|
||||
"ElementwiseDivScalarIntegerModule_basic",
|
||||
"ElementwiseDivScalarUnsignedIntegerModule_basic",
|
||||
"ElementwiseEluModule_basic",
|
||||
"ElementwiseEluNonDefaultModule_basic",
|
||||
"ElementwiseEqBoolScalarModule_basic",
|
||||
|
@ -2450,10 +2464,11 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseAsinIntModule_basic",
|
||||
"ElementwiseAtanTensorIntModule_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseExpIntModule_basic",
|
||||
"ElementwiseLogIntModule_basic",
|
||||
|
@ -2484,7 +2499,12 @@ ONNX_XFAIL_SET = {
|
|||
"TensorsStackPromoteDTypeModule_basic",
|
||||
|
||||
# Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1"
|
||||
"AtenLinalgCrossDynamic_basic"
|
||||
"AtenLinalgCrossDynamic_basic",
|
||||
|
||||
# Failure - value not close to golden value (op is incorrectly truncating)
|
||||
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||
|
||||
}
|
||||
|
||||
ONNX_CRASHING_SET = {
|
||||
|
|
|
@ -1209,6 +1209,9 @@ def aten〇div〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
|
|||
def aten〇div〇Tensor_mode〡shape(self: List[int], other: List[int], rounding_mode: Optional[str]) -> List[int]:
|
||||
return upstream_shape_functions.broadcast(self, other)
|
||||
|
||||
def aten〇div〇Scalar_mode〡shape(self: List[int], other: float, rounding_mode: Optional[str]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.broadcast(self, other)
|
||||
|
||||
|
@ -3152,6 +3155,45 @@ def aten〇div〇Tensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_ran
|
|||
else:
|
||||
return torch.float32
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
|
||||
def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert not is_complex_dtype(self_dtype)
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function(
|
||||
|
||||
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode=None, error_types={torch.complex64, torch.complex128}) +
|
||||
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode=None, error_types={torch.complex64, torch.complex128}) +
|
||||
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="floor", error_types={torch.complex64, torch.complex128}) +
|
||||
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="floor", error_types={torch.complex64, torch.complex128}) +
|
||||
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="trunc", error_types={torch.complex64, torch.complex128}) +
|
||||
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="trunc", error_types={torch.complex64, torch.complex128}))
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode=None) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode=None) +
|
||||
_check_tensors_with_the_same_dtype(error_types={torch.complex64, torch.complex128}, num_of_tensors=1, other=0.0, rounding_mode="floor") +
|
||||
_check_tensors_with_the_same_dtype(error_types={torch.complex64, torch.complex128}, num_of_tensors=1, other=0, rounding_mode="floor") +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="trunc") +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="trunc"))
|
||||
def aten〇div〇Scalar_mode〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], rounding_mode: Optional[str]) -> int:
|
||||
if rounding_mode is not None and rounding_mode == "floor":
|
||||
return aten〇floor_divide〇Scalar〡dtype(self_rank_dtype, other)
|
||||
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [None, self_rank]
|
||||
dtypes = [get_dtype_of_scalar(other), self_dtype]
|
||||
promoted_dtype = promote_dtypes(ranks, dtypes)
|
||||
if is_complex_dtype(promoted_dtype) or \
|
||||
(is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32) or \
|
||||
(rounding_mode is not None and rounding_mode == "trunc"):
|
||||
return promoted_dtype
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
|
||||
# Different width
|
||||
|
@ -3631,15 +3673,6 @@ def aten〇fmod〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dt
|
|||
dtypes = [self_dtype, other_dtype]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
|
||||
def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert not is_complex_dtype(self_dtype)
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
def aten〇pow〇Scalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int:
|
||||
exponent_rank, exponent_dtype = exponent_rank_dtype
|
||||
|
|
|
@ -342,6 +342,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
# Elementwise tensor compute ops that don't have the standard mutating
|
||||
# variants.
|
||||
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||
|
|
|
@ -2793,7 +2793,129 @@ def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
|
||||
|
||||
class ElementwiseDivScalarRoundingModeTruncModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.div(a, 0.5, rounding_mode="trunc")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncModule())
|
||||
def ElementwiseDivScalarRoundingModeTruncModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4))
|
||||
|
||||
|
||||
class ElementwiseDivScalarRoundingModeFloorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.div(a, 0.5, rounding_mode="floor")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorModule())
|
||||
def ElementwiseDivScalarRoundingModeFloorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
class ElementwiseDivScalarRoundingModeTruncStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.div(a, 0.5, rounding_mode="trunc")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncStaticModule())
|
||||
def ElementwiseDivScalarRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4))
|
||||
|
||||
|
||||
class ElementwiseDivScalarRoundingModeFloorStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.div(a, 0.5, rounding_mode="floor")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorStaticModule())
|
||||
def ElementwiseDivScalarRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
class ElementwiseDivScalarRoundingModeTruncIntStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.div(a, 3, rounding_mode="trunc")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncIntStaticModule())
|
||||
def ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32))
|
||||
|
||||
|
||||
class ElementwiseDivScalarRoundingModeFloorIntStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.div(a, 3, rounding_mode="floor")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorIntStaticModule())
|
||||
def ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseDivTensorRoundingModeTruncModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -2809,12 +2931,12 @@ class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
|
|||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeTruncModule())
|
||||
def ElementwiseDivRoundingModeTruncModule_basic(module, tu: TestUtils):
|
||||
module_factory=lambda: ElementwiseDivTensorRoundingModeTruncModule())
|
||||
def ElementwiseDivTensorRoundingModeTruncModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||
|
||||
|
||||
class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
|
||||
class ElementwiseDivTensorRoundingModeFloorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -2830,11 +2952,11 @@ class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
|
|||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeFloorModule())
|
||||
def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils):
|
||||
module_factory=lambda: ElementwiseDivTensorRoundingModeFloorModule())
|
||||
def ElementwiseDivTensorRoundingModeFloorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
||||
|
||||
class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
|
||||
class ElementwiseDivTensorRoundingModeTruncStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -2850,12 +2972,12 @@ class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
|
|||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeTruncStaticModule())
|
||||
def ElementwiseDivRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
|
||||
module_factory=lambda: ElementwiseDivTensorRoundingModeTruncStaticModule())
|
||||
def ElementwiseDivTensorRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||
|
||||
|
||||
class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
|
||||
class ElementwiseDivTensorRoundingModeFloorStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -2871,11 +2993,11 @@ class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
|
|||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeFloorStaticModule())
|
||||
def ElementwiseDivRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
|
||||
module_factory=lambda: ElementwiseDivTensorRoundingModeFloorStaticModule())
|
||||
def ElementwiseDivTensorRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
||||
|
||||
class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
|
||||
class ElementwiseDivTensorRoundingModeTruncIntStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -2891,12 +3013,12 @@ class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
|
|||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeTruncIntStaticModule())
|
||||
def ElementwiseDivRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
|
||||
module_factory=lambda: ElementwiseDivTensorRoundingModeTruncIntStaticModule())
|
||||
def ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
||||
|
||||
|
||||
class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
|
||||
class ElementwiseDivTensorRoundingModeFloorIntStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -2912,8 +3034,8 @@ class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
|
|||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeFloorIntStaticModule())
|
||||
def ElementwiseDivRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
|
||||
module_factory=lambda: ElementwiseDivTensorRoundingModeFloorIntStaticModule())
|
||||
def ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
||||
|
||||
|
||||
|
@ -4194,7 +4316,48 @@ def ElementwiseAtenLogicalNotOpPromoteModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtenFloorDivideModule(torch.nn.Module):
|
||||
class ElementwiseAtenFloorDivideScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.floor_divide(x, 0.14)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideScalarModule())
|
||||
def ElementwiseAtenFloorDivideScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3))
|
||||
|
||||
|
||||
class ElementwiseAtenFloorDivideScalarNegativeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.floor_divide(x, 0.14)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideScalarNegativeModule())
|
||||
def ElementwiseAtenFloorDivideScalarNegativeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3, low=-10.0, high=10.0))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtenFloorDivideTensorNegativeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -4209,8 +4372,28 @@ class ElementwiseAtenFloorDivideModule(torch.nn.Module):
|
|||
return torch.ops.aten.floor_divide(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideModule())
|
||||
def ElementwiseAtenFloorDivideModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorNegativeModule())
|
||||
def ElementwiseAtenFloorDivideTensorNegativeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3, low= -1, high=0), tu.rand(4, 3, low= 0, high=1))
|
||||
|
||||
|
||||
class ElementwiseAtenFloorDivideTensorPositiveModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.floor_divide(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorPositiveModule())
|
||||
def ElementwiseAtenFloorDivideTensorPositiveModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3), tu.rand(4, 3))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue