mirror of https://github.com/llvm/torch-mlir
Added 2 Ops: Floor divide scalar and Floor divide scalar mode (#3156)
- Added linalg lowering for `AtenFloorDivideScalarOp` - Needed `AtenDivScalarModeOp` for the decomp. - Added linalg lowering for `AtenDivScalarModeOp` - Moved linalg payload logic to `createDivModePayload()` since the logic was nearly identical for both `AtenDivScalarModeOp` and `AtenDivTensorModeOp`. Just a template function - Added `AtenDivScalarModeOp` lowering for stablehlo Pytorch's [`torch.floor_divide()`](https://pytorch.org/docs/stable/generated/torch.floor_divide.html) in a previous version (for a reason unknown to me) preformed a truncation instead of "floor". The already implemented op `AtenFloorDivideTensorOp` was done before this change. However, this wasn't caught because our testcases only tested positive floor division. I changed this to floor as well as adding a few test cases.pull/3167/head
parent
83cba8c696
commit
5708ee7ec9
|
@ -3397,6 +3397,56 @@ def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenDivScalarModeOp : Torch_Op<"aten.div.Scalar_mode", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchScalarType:$other,
|
||||||
|
AnyTorchOptionalStringType:$rounding_mode
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenDivScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||||
|
}
|
||||||
|
void AtenDivScalarModeOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenDiv_ScalarModeOp : Torch_Op<"aten.div_.Scalar_mode", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::div_.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_NonValueTensorType:$self,
|
||||||
|
AnyTorchScalarType:$other,
|
||||||
|
AnyTorchOptionalStringType:$rounding_mode
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalNonValueTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenDiv_ScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||||
|
}
|
||||||
|
void AtenDiv_ScalarModeOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
|
def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "llvm/ADT/APSInt.h"
|
#include "llvm/ADT/APSInt.h"
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -213,6 +214,78 @@ createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename OpT>
|
||||||
|
Value createDivModePayload(OpBuilder &b, Location loc,
|
||||||
|
const TypeConverter *converter,
|
||||||
|
ValueRange payloadArgs, OpT op,
|
||||||
|
ArrayRef<Value> operands) {
|
||||||
|
static_assert(std::is_same_v<OpT, AtenDivTensorModeOp> ||
|
||||||
|
std::is_same_v<OpT, AtenDivScalarModeOp>,
|
||||||
|
"template type must be a tensor/scalar div mode");
|
||||||
|
typename OpT::Adaptor adaptor(operands);
|
||||||
|
Type dtype = cast<RankedTensorType>(converter->convertType(op.getType()))
|
||||||
|
.getElementType();
|
||||||
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
|
Value rhs = convertScalarToDtype(
|
||||||
|
b, loc,
|
||||||
|
std::is_same_v<OpT, AtenDivScalarModeOp> ? operands[1] : payloadArgs[1],
|
||||||
|
dtype);
|
||||||
|
|
||||||
|
Value quotient;
|
||||||
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
|
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||||
|
} else if (dtype.isUnsignedInteger()) {
|
||||||
|
quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
|
||||||
|
} else {
|
||||||
|
assert(dtype.isInteger() &&
|
||||||
|
"dtype should be an integer (signless or signed)");
|
||||||
|
quotient = b.create<arith::DivSIOp>(loc, lhs, rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<Torch::NoneType>(op.getRoundingMode().getType()))
|
||||||
|
return quotient;
|
||||||
|
|
||||||
|
std::string roundingMode;
|
||||||
|
if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundingMode))) {
|
||||||
|
op.emitError("only support constant str rounding mode");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
assert((roundingMode == "trunc" || roundingMode == "floor") &&
|
||||||
|
"unsupported rounding mode");
|
||||||
|
if (roundingMode == "trunc") {
|
||||||
|
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||||
|
// to C-style integer division.
|
||||||
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
|
// nothing to do for integers
|
||||||
|
return quotient;
|
||||||
|
}
|
||||||
|
|
||||||
|
// float
|
||||||
|
Value ceil = b.create<math::CeilOp>(loc, quotient);
|
||||||
|
Value floor = b.create<math::FloorOp>(loc, quotient);
|
||||||
|
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||||
|
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||||
|
quotient, cstZero);
|
||||||
|
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
|
||||||
|
}
|
||||||
|
if (roundingMode == "floor") {
|
||||||
|
// "floor" - rounds the results of the division down. Equivalent to
|
||||||
|
// floor division in Python (the // operator)
|
||||||
|
if (isa<mlir::FloatType>(dtype))
|
||||||
|
return b.create<math::FloorOp>(loc, quotient);
|
||||||
|
if (!dtype.isUnsignedInteger()) {
|
||||||
|
Type defaultIntToFloatType = b.getF64Type();
|
||||||
|
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
|
||||||
|
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
|
||||||
|
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||||
|
Value floor = b.create<math::FloorOp>(loc, quotient);
|
||||||
|
Value convert = convertScalarToDtype(b, loc, floor, dtype);
|
||||||
|
return convert;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return quotient;
|
||||||
|
}
|
||||||
|
|
||||||
static Value createLinalgPayloadCalculationForElementwiseOp(
|
static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
OpBuilder &b, Location loc, const TypeConverter *converter,
|
OpBuilder &b, Location loc, const TypeConverter *converter,
|
||||||
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
||||||
|
@ -769,66 +842,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
div.emitError("unimplemented: non-floating point and non-integer dtype");
|
div.emitError("unimplemented: non-floating point and non-integer dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
|
if (auto divScalarMode = dyn_cast<AtenDivScalarModeOp>(op)) {
|
||||||
AtenDivTensorModeOp::Adaptor adaptor(operands);
|
return createDivModePayload(b, loc, converter, payloadArgs, divScalarMode,
|
||||||
Type dtype = converter->convertType(divTensorMode.getType())
|
operands);
|
||||||
.cast<RankedTensorType>()
|
}
|
||||||
.getElementType();
|
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
return createDivModePayload(b, loc, converter, payloadArgs, divTensorMode,
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
operands);
|
||||||
Value div;
|
|
||||||
if (isa<mlir::FloatType>(dtype))
|
|
||||||
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
|
||||||
else {
|
|
||||||
if (dtype.isUnsignedInteger())
|
|
||||||
div = b.create<arith::DivUIOp>(loc, lhs, rhs);
|
|
||||||
else
|
|
||||||
div = b.create<arith::DivSIOp>(loc, lhs, rhs);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (divTensorMode.getRoundingMode().getType().isa<Torch::NoneType>())
|
|
||||||
return div;
|
|
||||||
|
|
||||||
std::string roundingMode;
|
|
||||||
if (!matchPattern(divTensorMode.getRoundingMode(),
|
|
||||||
m_TorchConstantStr(roundingMode))) {
|
|
||||||
divTensorMode.emitError("only support constant str rounding mode");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
if (roundingMode == "trunc") {
|
|
||||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
|
||||||
// to C-style integer division.
|
|
||||||
if (isa<mlir::FloatType>(dtype)) {
|
|
||||||
Value ceil = b.create<math::CeilOp>(loc, div);
|
|
||||||
Value floor = b.create<math::FloorOp>(loc, div);
|
|
||||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
|
||||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
|
||||||
div, cstZero);
|
|
||||||
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
|
|
||||||
} else
|
|
||||||
return div;
|
|
||||||
}
|
|
||||||
if (roundingMode == "floor") {
|
|
||||||
// "floor" - rounds the results of the division down. Equivalent to
|
|
||||||
// floor division in Python (the // operator)
|
|
||||||
if (isa<mlir::FloatType>(dtype))
|
|
||||||
return b.create<math::FloorOp>(loc, div);
|
|
||||||
else if (!dtype.isUnsignedInteger()) {
|
|
||||||
Type defaultIntToFloatType = b.getF64Type();
|
|
||||||
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
|
|
||||||
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
|
|
||||||
div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
|
||||||
Value floor = b.create<math::FloorOp>(loc, div);
|
|
||||||
Value convert = convertScalarToDtype(b, loc, floor, dtype);
|
|
||||||
return convert;
|
|
||||||
} else {
|
|
||||||
return div;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
divTensorMode.emitError("invalid rounding mode");
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
|
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
|
||||||
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
|
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
|
||||||
if (!isa<mlir::FloatType>(dtype)) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
|
@ -1579,12 +1600,13 @@ public:
|
||||||
if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp,
|
if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp,
|
||||||
AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
|
AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
|
||||||
AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
||||||
AtenSubTensorOp, AtenAtan2Op, AtenLerpTensorOp, AtenSigmoidOp,
|
AtenDivScalarModeOp, AtenSubTensorOp, AtenAtan2Op,
|
||||||
AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp,
|
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
|
||||||
AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenRsubScalarOp,
|
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||||
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
|
AtenClampTensorOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
|
||||||
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp,
|
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
|
||||||
AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
||||||
|
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
||||||
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
|
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
|
||||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||||
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||||
|
@ -2617,25 +2639,25 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp,
|
AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp,
|
||||||
AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
||||||
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
|
AtenDivScalarModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
|
||||||
AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp,
|
AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||||
AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
|
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
|
||||||
AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp,
|
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
|
||||||
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
|
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
|
||||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
||||||
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||||
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
|
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
|
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||||
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
|
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
|
||||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
|
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
|
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
|
||||||
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp,
|
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
|
||||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
|
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
|
||||||
AtenRemainderScalarOp, AtenFmodTensorOp, AtenRemainderTensorOp,
|
AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
|
||||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||||
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
|
AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
|
||||||
AtenQuantizePerTensorOp>();
|
AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||||
|
|
|
@ -27,8 +27,8 @@
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
#include <iostream>
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -409,9 +409,9 @@ public:
|
||||||
if (!lhsType)
|
if (!lhsType)
|
||||||
return op.emitError("only Tensor types supported in StableHLO");
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
|
||||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
auto outType = cast<TensorType>(
|
||||||
->convertType(op.getType())
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
.template cast<TensorType>();
|
op.getType()));
|
||||||
|
|
||||||
Type outElemTy = outType.getElementType();
|
Type outElemTy = outType.getElementType();
|
||||||
if (!outElemTy.isIntOrFloat()) {
|
if (!outElemTy.isIntOrFloat()) {
|
||||||
|
@ -432,18 +432,23 @@ public:
|
||||||
Value result =
|
Value result =
|
||||||
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
||||||
|
|
||||||
if (!isa<AtenDivTensorModeOp>(op)) {
|
if (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
|
||||||
|
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
|
||||||
rewriter.replaceOp(op, result);
|
rewriter.replaceOp(op, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
AtenDivTensorModeOp divTensorModeOp =
|
auto tensorOp = dyn_cast<AtenDivTensorModeOp>(op.getOperation());
|
||||||
llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation());
|
auto opRoundingMode =
|
||||||
|
tensorOp
|
||||||
|
? tensorOp.getRoundingMode()
|
||||||
|
: cast<AtenDivScalarModeOp>(op.getOperation()).getRoundingMode();
|
||||||
|
|
||||||
std::string roundingMode;
|
std::string roundingMode;
|
||||||
if (!matchPattern(divTensorModeOp.getRoundingMode(),
|
if (!matchPattern(opRoundingMode, m_TorchConstantStr(roundingMode))) {
|
||||||
m_TorchConstantStr(roundingMode)))
|
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support constant str rounding mode");
|
op, "only support constant str rounding mode");
|
||||||
|
}
|
||||||
|
|
||||||
// if trunc and int, do nothing
|
// if trunc and int, do nothing
|
||||||
if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
|
if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
|
||||||
|
@ -1845,6 +1850,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
|
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
|
||||||
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
|
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
|
||||||
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
|
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
|
||||||
|
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarModeOp, chlo::BroadcastDivOp);
|
||||||
INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp);
|
INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp);
|
||||||
#undef INSERT_BINARY_MULDIV_PATTERN
|
#undef INSERT_BINARY_MULDIV_PATTERN
|
||||||
|
|
||||||
|
|
|
@ -1095,9 +1095,9 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
||||||
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
|
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<AtenDivTensorModeOp>(op)) {
|
if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
|
||||||
// None rounding mode
|
|
||||||
if (op->getOperand(2).getType().isa<Torch::NoneType>()) {
|
if (op->getOperand(2).getType().isa<Torch::NoneType>()) {
|
||||||
|
// None rounding mode
|
||||||
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
|
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
|
||||||
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
|
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
|
||||||
quotient);
|
quotient);
|
||||||
|
@ -1858,6 +1858,16 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenDivScalarModeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
void AtenDivScalarModeOp::getCanonicalizationPatterns(
|
||||||
|
RewritePatternSet &patterns, MLIRContext *context) {
|
||||||
|
patterns.add(+[](AtenDivScalarModeOp op, PatternRewriter &rewriter) {
|
||||||
|
return rewrite0DBinaryTensorOp(op, rewriter);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenNumelOp
|
// AtenNumelOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -8496,6 +8496,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.div.Scalar_mode\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<str>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.floor_divide\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.floor_divide\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -11166,6 +11170,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %2 : !torch.int\n"
|
" return %2 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||||
|
" torch.prim.If %2 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||||
|
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||||
|
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %6 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar_mode\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<str>) -> !torch.int {\n"
|
||||||
|
" %str = torch.constant.str \"trunc\"\n"
|
||||||
|
" %int6 = torch.constant.int 6\n"
|
||||||
|
" %true = torch.constant.bool true\n"
|
||||||
|
" %false = torch.constant.bool false\n"
|
||||||
|
" %str_0 = torch.constant.str \"floor\"\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
|
||||||
|
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
||||||
|
" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
|
||||||
|
" %4 = torch.aten.eq.str %3, %str_0 : !torch.str, !torch.str -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %false : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||||
|
" %3 = func.call @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0, %arg1) : (!torch.tuple<int, int>, !torch.number) -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %3 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %3:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %4 = torch.prim.ListConstruct %none, %3#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
|
||||||
|
" %5 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||||
|
" %6 = torch.prim.ListConstruct %5, %3#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%4, %6) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%7) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
|
||||||
|
" torch.prim.If.yield %true : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %12 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
|
||||||
|
" %14 = torch.aten.ne.int %7, %int6 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %14 : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %false : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If.yield %13 : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" %10 = torch.prim.If %9 -> (!torch.bool) {\n"
|
||||||
|
" torch.prim.If.yield %true : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %12 = torch.aten.__isnot__ %arg2, %none : !torch.optional<str>, !torch.none -> !torch.bool\n"
|
||||||
|
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
|
||||||
|
" %14 = torch.prim.unchecked_cast %arg2 : !torch.optional<str> -> !torch.str\n"
|
||||||
|
" %15 = torch.aten.eq.str %14, %str : !torch.str, !torch.str -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %false : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If.yield %13 : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" %11 = torch.prim.If %10 -> (!torch.int) {\n"
|
||||||
|
" torch.prim.If.yield %7 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If.yield %11 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" return %2 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
@ -11688,24 +11769,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
" return %4 : !torch.int\n"
|
" return %4 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
|
||||||
" %none = torch.constant.none\n"
|
|
||||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
|
||||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
|
||||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
|
||||||
" torch.prim.If %2 -> () {\n"
|
|
||||||
" torch.prim.If.yield\n"
|
|
||||||
" } else {\n"
|
|
||||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
|
||||||
" torch.prim.If.yield\n"
|
|
||||||
" }\n"
|
|
||||||
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
|
||||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
|
||||||
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
|
||||||
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
|
||||||
" return %6 : !torch.int\n"
|
|
||||||
" }\n"
|
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
|
|
@ -5800,7 +5800,7 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
|
||||||
// PyTorch aten.floorDivide is a misnomer because it actually rounds
|
// PyTorch aten.floorDivide is a misnomer because it actually rounds
|
||||||
// the quotient towards zero instead of taking its floor.
|
// the quotient towards zero instead of taking its floor.
|
||||||
Value cstStrFloor =
|
Value cstStrFloor =
|
||||||
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
|
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
|
||||||
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
|
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
|
||||||
op, op.getType(), op.getSelf(), op.getOther(),
|
op, op.getType(), op.getSelf(), op.getOther(),
|
||||||
/*roundingMode=*/cstStrFloor);
|
/*roundingMode=*/cstStrFloor);
|
||||||
|
@ -5809,6 +5809,22 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenFloorDivideScalarOp
|
||||||
|
: public OpRewritePattern<AtenFloorDivideScalarOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenFloorDivideScalarOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Value cstStrFloor =
|
||||||
|
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
|
||||||
|
rewriter.replaceOpWithNewOp<AtenDivScalarModeOp>(
|
||||||
|
op, op.getType(), op.getSelf(), op.getOther(),
|
||||||
|
/*roundingMode=*/cstStrFloor);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Decompose `aten.numpyT` op into `aten.permute` op.
|
// Decompose `aten.numpyT` op into `aten.permute` op.
|
||||||
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
||||||
|
@ -7560,6 +7576,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
||||||
|
|
|
@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenClampMaxOp>();
|
target.addIllegalOp<AtenClampMaxOp>();
|
||||||
target.addIllegalOp<AtenBaddbmmOp>();
|
target.addIllegalOp<AtenBaddbmmOp>();
|
||||||
target.addIllegalOp<AtenFloorDivideOp>();
|
target.addIllegalOp<AtenFloorDivideOp>();
|
||||||
|
target.addIllegalOp<AtenFloorDivideScalarOp>();
|
||||||
target.addIllegalOp<AtenNumpyTOp>();
|
target.addIllegalOp<AtenNumpyTOp>();
|
||||||
target.addIllegalOp<AtenSelectScatterOp>();
|
target.addIllegalOp<AtenSelectScatterOp>();
|
||||||
target.addIllegalOp<AtenVarDimOp>();
|
target.addIllegalOp<AtenVarDimOp>();
|
||||||
|
|
|
@ -244,12 +244,20 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
"ElementwiseSubScalarIntModule_basic",
|
"ElementwiseSubScalarIntModule_basic",
|
||||||
|
|
||||||
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
||||||
"ElementwiseDivRoundingModeFloorModule_basic",
|
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncModule_basic",
|
"ElementwiseAtenFloorDivideScalarModule_basic",
|
||||||
"ElementwiseDivRoundingModeFloorStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
||||||
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||||
|
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||||
|
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeFloorModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
|
@ -799,10 +807,14 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseCloneContiguousModule_basic",
|
"ElementwiseCloneContiguousModule_basic",
|
||||||
"ElementwiseCloneModule_basic",
|
"ElementwiseCloneModule_basic",
|
||||||
"ElementwiseCosModule_basic",
|
"ElementwiseCosModule_basic",
|
||||||
"ElementwiseDivRoundingModeFloorStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||||
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||||
|
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||||
"ElementwiseErfModule_basic",
|
"ElementwiseErfModule_basic",
|
||||||
"ElementwiseExpModule_basic",
|
"ElementwiseExpModule_basic",
|
||||||
"ElementwiseFloorIntModule_basic",
|
"ElementwiseFloorIntModule_basic",
|
||||||
|
@ -1354,6 +1366,8 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseDivScalarModule_basic",
|
"ElementwiseDivScalarModule_basic",
|
||||||
"ElementwiseDivTensorIntegerModule_basic",
|
"ElementwiseDivTensorIntegerModule_basic",
|
||||||
"ElementwiseDivTensorUnsignedIntegerModule_basic",
|
"ElementwiseDivTensorUnsignedIntegerModule_basic",
|
||||||
|
"ElementwiseDivScalarIntegerModule_basic",
|
||||||
|
"ElementwiseDivScalarUnsignedIntegerModule_basic",
|
||||||
"ElementwiseEluModule_basic",
|
"ElementwiseEluModule_basic",
|
||||||
"ElementwiseEluNonDefaultModule_basic",
|
"ElementwiseEluNonDefaultModule_basic",
|
||||||
"ElementwiseEqBoolScalarModule_basic",
|
"ElementwiseEqBoolScalarModule_basic",
|
||||||
|
@ -2450,10 +2464,11 @@ ONNX_XFAIL_SET = {
|
||||||
"ElementwiseAsinIntModule_basic",
|
"ElementwiseAsinIntModule_basic",
|
||||||
"ElementwiseAtanTensorIntModule_basic",
|
"ElementwiseAtanTensorIntModule_basic",
|
||||||
"ElementwiseCosIntModule_basic",
|
"ElementwiseCosIntModule_basic",
|
||||||
"ElementwiseDivRoundingModeFloorIntStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncIntStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
||||||
"ElementwiseDivRoundingModeTruncStaticModule_basic",
|
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||||
|
"ElementwiseErfIntModule_basic",
|
||||||
"ElementwiseErfIntModule_basic",
|
"ElementwiseErfIntModule_basic",
|
||||||
"ElementwiseExpIntModule_basic",
|
"ElementwiseExpIntModule_basic",
|
||||||
"ElementwiseLogIntModule_basic",
|
"ElementwiseLogIntModule_basic",
|
||||||
|
@ -2484,7 +2499,12 @@ ONNX_XFAIL_SET = {
|
||||||
"TensorsStackPromoteDTypeModule_basic",
|
"TensorsStackPromoteDTypeModule_basic",
|
||||||
|
|
||||||
# Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1"
|
# Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1"
|
||||||
"AtenLinalgCrossDynamic_basic"
|
"AtenLinalgCrossDynamic_basic",
|
||||||
|
|
||||||
|
# Failure - value not close to golden value (op is incorrectly truncating)
|
||||||
|
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
||||||
|
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_CRASHING_SET = {
|
ONNX_CRASHING_SET = {
|
||||||
|
|
|
@ -1209,6 +1209,9 @@ def aten〇div〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||||
def aten〇div〇Tensor_mode〡shape(self: List[int], other: List[int], rounding_mode: Optional[str]) -> List[int]:
|
def aten〇div〇Tensor_mode〡shape(self: List[int], other: List[int], rounding_mode: Optional[str]) -> List[int]:
|
||||||
return upstream_shape_functions.broadcast(self, other)
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
|
def aten〇div〇Scalar_mode〡shape(self: List[int], other: float, rounding_mode: Optional[str]) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]:
|
def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.broadcast(self, other)
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
|
@ -3152,6 +3155,45 @@ def aten〇div〇Tensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_ran
|
||||||
else:
|
else:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
@check_dtype_function(
|
||||||
|
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
|
||||||
|
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
|
||||||
|
def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
assert not is_complex_dtype(self_dtype)
|
||||||
|
ranks: List[Optional[int]] = [self_rank, None]
|
||||||
|
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||||
|
return promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
|
@check_dtype_function(
|
||||||
|
|
||||||
|
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode=None, error_types={torch.complex64, torch.complex128}) +
|
||||||
|
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode=None, error_types={torch.complex64, torch.complex128}) +
|
||||||
|
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="floor", error_types={torch.complex64, torch.complex128}) +
|
||||||
|
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="floor", error_types={torch.complex64, torch.complex128}) +
|
||||||
|
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="trunc", error_types={torch.complex64, torch.complex128}) +
|
||||||
|
# _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="trunc", error_types={torch.complex64, torch.complex128}))
|
||||||
|
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode=None) +
|
||||||
|
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode=None) +
|
||||||
|
_check_tensors_with_the_same_dtype(error_types={torch.complex64, torch.complex128}, num_of_tensors=1, other=0.0, rounding_mode="floor") +
|
||||||
|
_check_tensors_with_the_same_dtype(error_types={torch.complex64, torch.complex128}, num_of_tensors=1, other=0, rounding_mode="floor") +
|
||||||
|
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0, rounding_mode="trunc") +
|
||||||
|
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0, rounding_mode="trunc"))
|
||||||
|
def aten〇div〇Scalar_mode〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], rounding_mode: Optional[str]) -> int:
|
||||||
|
if rounding_mode is not None and rounding_mode == "floor":
|
||||||
|
return aten〇floor_divide〇Scalar〡dtype(self_rank_dtype, other)
|
||||||
|
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
ranks: List[Optional[int]] = [None, self_rank]
|
||||||
|
dtypes = [get_dtype_of_scalar(other), self_dtype]
|
||||||
|
promoted_dtype = promote_dtypes(ranks, dtypes)
|
||||||
|
if is_complex_dtype(promoted_dtype) or \
|
||||||
|
(is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32) or \
|
||||||
|
(rounding_mode is not None and rounding_mode == "trunc"):
|
||||||
|
return promoted_dtype
|
||||||
|
else:
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
|
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
|
||||||
# Different width
|
# Different width
|
||||||
|
@ -3631,15 +3673,6 @@ def aten〇fmod〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dt
|
||||||
dtypes = [self_dtype, other_dtype]
|
dtypes = [self_dtype, other_dtype]
|
||||||
return promote_dtypes(ranks, dtypes)
|
return promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
@check_dtype_function(
|
|
||||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
|
|
||||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
|
|
||||||
def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
|
||||||
self_rank, self_dtype = self_rank_dtype
|
|
||||||
assert not is_complex_dtype(self_dtype)
|
|
||||||
ranks: List[Optional[int]] = [self_rank, None]
|
|
||||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
|
||||||
return promote_dtypes(ranks, dtypes)
|
|
||||||
|
|
||||||
def aten〇pow〇Scalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇pow〇Scalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int:
|
||||||
exponent_rank, exponent_dtype = exponent_rank_dtype
|
exponent_rank, exponent_dtype = exponent_rank_dtype
|
||||||
|
|
|
@ -342,6 +342,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
# Elementwise tensor compute ops that don't have the standard mutating
|
# Elementwise tensor compute ops that don't have the standard mutating
|
||||||
# variants.
|
# variants.
|
||||||
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
|
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
|
||||||
|
emit_with_mutating_variants("aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)", has_canonicalizer=True)
|
||||||
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||||
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||||
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
|
||||||
|
|
|
@ -2793,7 +2793,129 @@ def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
|
|
||||||
|
class ElementwiseDivScalarRoundingModeTruncModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.div(a, 0.5, rounding_mode="trunc")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncModule())
|
||||||
|
def ElementwiseDivScalarRoundingModeTruncModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseDivScalarRoundingModeFloorModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.div(a, 0.5, rounding_mode="floor")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorModule())
|
||||||
|
def ElementwiseDivScalarRoundingModeFloorModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
class ElementwiseDivScalarRoundingModeTruncStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([4], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.div(a, 0.5, rounding_mode="trunc")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncStaticModule())
|
||||||
|
def ElementwiseDivScalarRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseDivScalarRoundingModeFloorStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.div(a, 0.5, rounding_mode="floor")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorStaticModule())
|
||||||
|
def ElementwiseDivScalarRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
class ElementwiseDivScalarRoundingModeTruncIntStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.div(a, 3, rounding_mode="trunc")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivScalarRoundingModeTruncIntStaticModule())
|
||||||
|
def ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseDivScalarRoundingModeFloorIntStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.div(a, 3, rounding_mode="floor")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivScalarRoundingModeFloorIntStaticModule())
|
||||||
|
def ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseDivTensorRoundingModeTruncModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -2809,12 +2931,12 @@ class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(
|
||||||
module_factory=lambda: ElementwiseDivRoundingModeTruncModule())
|
module_factory=lambda: ElementwiseDivTensorRoundingModeTruncModule())
|
||||||
def ElementwiseDivRoundingModeTruncModule_basic(module, tu: TestUtils):
|
def ElementwiseDivTensorRoundingModeTruncModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
|
class ElementwiseDivTensorRoundingModeFloorModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -2830,11 +2952,11 @@ class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(
|
||||||
module_factory=lambda: ElementwiseDivRoundingModeFloorModule())
|
module_factory=lambda: ElementwiseDivTensorRoundingModeFloorModule())
|
||||||
def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils):
|
def ElementwiseDivTensorRoundingModeFloorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
||||||
|
|
||||||
class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
|
class ElementwiseDivTensorRoundingModeTruncStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -2850,12 +2972,12 @@ class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(
|
||||||
module_factory=lambda: ElementwiseDivRoundingModeTruncStaticModule())
|
module_factory=lambda: ElementwiseDivTensorRoundingModeTruncStaticModule())
|
||||||
def ElementwiseDivRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
|
def ElementwiseDivTensorRoundingModeTruncStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
|
class ElementwiseDivTensorRoundingModeFloorStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -2871,11 +2993,11 @@ class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(
|
||||||
module_factory=lambda: ElementwiseDivRoundingModeFloorStaticModule())
|
module_factory=lambda: ElementwiseDivTensorRoundingModeFloorStaticModule())
|
||||||
def ElementwiseDivRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
|
def ElementwiseDivTensorRoundingModeFloorStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
||||||
|
|
||||||
class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
|
class ElementwiseDivTensorRoundingModeTruncIntStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -2891,12 +3013,12 @@ class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(
|
||||||
module_factory=lambda: ElementwiseDivRoundingModeTruncIntStaticModule())
|
module_factory=lambda: ElementwiseDivTensorRoundingModeTruncIntStaticModule())
|
||||||
def ElementwiseDivRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
|
def ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
|
class ElementwiseDivTensorRoundingModeFloorIntStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -2912,8 +3034,8 @@ class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(
|
||||||
module_factory=lambda: ElementwiseDivRoundingModeFloorIntStaticModule())
|
module_factory=lambda: ElementwiseDivTensorRoundingModeFloorIntStaticModule())
|
||||||
def ElementwiseDivRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
|
def ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64))
|
||||||
|
|
||||||
|
|
||||||
|
@ -4194,7 +4316,48 @@ def ElementwiseAtenLogicalNotOpPromoteModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseAtenFloorDivideModule(torch.nn.Module):
|
class ElementwiseAtenFloorDivideScalarModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.floor_divide(x, 0.14)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideScalarModule())
|
||||||
|
def ElementwiseAtenFloorDivideScalarModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 3))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAtenFloorDivideScalarNegativeModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.floor_divide(x, 0.14)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideScalarNegativeModule())
|
||||||
|
def ElementwiseAtenFloorDivideScalarNegativeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 3, low=-10.0, high=10.0))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAtenFloorDivideTensorNegativeModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -4209,8 +4372,28 @@ class ElementwiseAtenFloorDivideModule(torch.nn.Module):
|
||||||
return torch.ops.aten.floor_divide(x, y)
|
return torch.ops.aten.floor_divide(x, y)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideModule())
|
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorNegativeModule())
|
||||||
def ElementwiseAtenFloorDivideModule_basic(module, tu: TestUtils):
|
def ElementwiseAtenFloorDivideTensorNegativeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 3, low= -1, high=0), tu.rand(4, 3, low= 0, high=1))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAtenFloorDivideTensorPositiveModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.floor_divide(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideTensorPositiveModule())
|
||||||
|
def ElementwiseAtenFloorDivideTensorPositiveModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 3), tu.rand(4, 3))
|
module.forward(tu.rand(4, 3), tu.rand(4, 3))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue