mirror of https://github.com/llvm/torch-mlir
[TOSA] Add div rounding mode, remainder, fmod, and ge.Tensor ops support (#3717)
- Add legalization for aten.div rounding mode: + trunc: rounds division results towards zero + floor: rounds division results down - Add legalization for aten.remainder.Scalar and aten.fmod ops - Add legalization for aten.ge.Tensor op - Update e2e tests in xfail_sets.py - Update basic.mlir with new legalized ops Signed-off-by: Justin Ngo <justin.ngo@arm.com> Change-Id: Icedd23205254fb893ce6f3de08956772b83b4320 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3718/head
parent
5ce48dfacd
commit
abaff58c6d
|
@ -463,6 +463,119 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Function to perform division with trunc rounding mode (rounding result
|
||||
// towards zero) for float type inputs.
|
||||
// This function takes in the division result between lhs and rhs rather
|
||||
// than takes in the original lhs and rhs tensors as parameters.
|
||||
Value truncFloatDivWithDivResult(PatternRewriter &rewriter, Operation *op,
|
||||
TensorType outType, Value divResult) {
|
||||
// To implement trunc mode for float inputs, multiply the floored abs
|
||||
// of the tensor with the elementwise signedness of the tensor.
|
||||
// div_result = lhs / rhs
|
||||
// trunc_val = floor(abs(div_result)) * sign(div_result)
|
||||
auto zero =
|
||||
tosa::getConstTensor<float>(rewriter, op, 0, {}, outType.getElementType())
|
||||
.value();
|
||||
|
||||
auto one =
|
||||
tosa::getConstTensor<float>(rewriter, op, 1, {}, outType.getElementType())
|
||||
.value();
|
||||
|
||||
auto minusOne = tosa::getConstTensor<float>(rewriter, op, -1, {},
|
||||
outType.getElementType())
|
||||
.value();
|
||||
|
||||
auto cond = rewriter.create<tosa::GreaterEqualOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)),
|
||||
divResult, zero);
|
||||
|
||||
auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), outType, cond,
|
||||
one, minusOne);
|
||||
|
||||
auto absDivResult =
|
||||
rewriter.create<tosa::AbsOp>(op->getLoc(), outType, divResult);
|
||||
|
||||
auto flooredAbsDivResult =
|
||||
rewriter.create<tosa::FloorOp>(op->getLoc(), outType, absDivResult);
|
||||
|
||||
Value result =
|
||||
tosa::createMulOpAndCast(rewriter, op, outType, flooredAbsDivResult,
|
||||
selectOp, /*shift=*/0)
|
||||
.getResult();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Function to perform division with trunc rounding mode (rounding result
|
||||
// towards zero) for float type inputs
|
||||
Value truncFloatDiv(PatternRewriter &rewriter, Operation *op,
|
||||
TensorType outType, Value lhs, Value rhs) {
|
||||
rhs = tosa::promoteType(rewriter, rhs, outType);
|
||||
|
||||
auto rhsRcp =
|
||||
rewriter.create<tosa::ReciprocalOp>(op->getLoc(), rhs.getType(), rhs);
|
||||
|
||||
auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp,
|
||||
/*shift=*/0);
|
||||
|
||||
return truncFloatDivWithDivResult(rewriter, op, outType, divResult);
|
||||
}
|
||||
|
||||
// Function to perform division with floor rounding mode (rounding result
|
||||
// down) for integer type inputs.
|
||||
Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType,
|
||||
Value lhs, Value rhs) {
|
||||
// To implement floor mode int input, utilize tosa::IntDivOp (trunc div
|
||||
// result) with the following formula elementwise:
|
||||
// floor_val = trunc_val - ((trunc_val * rhs != lhs)
|
||||
// && (sign(lhs) != sign(rhs)))
|
||||
|
||||
// TOSA IntDiv requires inputs to be i32
|
||||
auto i32Type =
|
||||
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32));
|
||||
lhs = tosa::promoteType(rewriter, lhs, i32Type);
|
||||
rhs = tosa::promoteType(rewriter, rhs, i32Type);
|
||||
|
||||
auto intDivOp =
|
||||
rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type, lhs, rhs);
|
||||
|
||||
auto zero = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
|
||||
|
||||
auto one = tosa::getConstTensor<int32_t>(rewriter, op, 1, {}).value();
|
||||
|
||||
auto boolType =
|
||||
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1));
|
||||
|
||||
auto lhsMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs,
|
||||
/*shift=*/0);
|
||||
|
||||
auto lhsRhsDifferentSign =
|
||||
rewriter.create<tosa::GreaterOp>(op->getLoc(), boolType, zero, lhsMulRhs);
|
||||
|
||||
auto truncMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type,
|
||||
intDivOp, rhs, /*shift=*/0);
|
||||
|
||||
auto truncMulRhsEqualLhs =
|
||||
rewriter.create<tosa::EqualOp>(op->getLoc(), boolType, truncMulRhs, lhs);
|
||||
|
||||
auto truncMulRhsNotEqualLhs = rewriter.create<tosa::LogicalNotOp>(
|
||||
op->getLoc(), boolType, truncMulRhsEqualLhs);
|
||||
|
||||
auto truncMinusOne =
|
||||
rewriter.create<tosa::SubOp>(op->getLoc(), i32Type, intDivOp, one);
|
||||
|
||||
auto cond = rewriter.create<tosa::LogicalAndOp>(
|
||||
op->getLoc(), boolType, lhsRhsDifferentSign, truncMulRhsNotEqualLhs);
|
||||
|
||||
auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), i32Type, cond,
|
||||
truncMinusOne, intDivOp);
|
||||
|
||||
Value result = tosa::promoteType(rewriter, selectOp, outType);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
|
@ -498,25 +611,64 @@ public:
|
|||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
// auto result;
|
||||
// Get rounding mode for aten.div.Tensor_mode
|
||||
std::string roundMode;
|
||||
if constexpr (std::is_same<AtenOpT, AtenDivTensorModeOp>() ||
|
||||
std::is_same<AtenOpT, AtenDivScalarModeOp>()) {
|
||||
if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundMode)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Non-const rounding mode parameter unsupported");
|
||||
}
|
||||
|
||||
Value result;
|
||||
if (isa<mlir::FloatType>(outType.getElementType())) {
|
||||
// The input to the reciprocal is an integer sometimes, and we may need to
|
||||
// promote it to a floating point. Per TOSA specification, the input types
|
||||
// can only be floating point for tosa::ReciprocalOp.
|
||||
Value rhsCasted = tosa::promoteType(rewriter, rhsTensor, outType);
|
||||
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
|
||||
op->getLoc(), rhsCasted.getType(), rhsCasted);
|
||||
// The input to the reciprocal is an integer sometimes, and we may need
|
||||
// to promote it to a floating point. Per TOSA specification, the input
|
||||
// types can only be floating point for tosa::ReciprocalOp.
|
||||
rhsTensor = tosa::promoteType(rewriter, rhsTensor, outType);
|
||||
auto rhsRcp = rewriter.create<tosa::ReciprocalOp>(
|
||||
op->getLoc(), rhsTensor.getType(), rhsTensor);
|
||||
|
||||
result = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
|
||||
rcpOp.getResult(), /*shift=*/0)
|
||||
.getResult();
|
||||
auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
|
||||
rhsRcp, /*shift=*/0);
|
||||
|
||||
// Round result based on rounding mode
|
||||
if (roundMode.compare("floor") == 0) {
|
||||
// "floor": rounds the results of the division down. Equivalent to
|
||||
// floor division in Python (the // operator).
|
||||
auto floorOp =
|
||||
rewriter.create<tosa::FloorOp>(op->getLoc(), outType, divResult);
|
||||
|
||||
result = floorOp.getResult();
|
||||
} else if (roundMode.compare("trunc") == 0) {
|
||||
// "trunc": rounds the results of the division towards zero. Equivalent
|
||||
// to C-style integer division.
|
||||
result = truncFloatDivWithDivResult(rewriter, op, outType, divResult);
|
||||
} else {
|
||||
// The output type can be different than the input types (e.g. dividing an
|
||||
// int tensor results in a floating point tensor).
|
||||
result = tosa::createBinaryOpAndCast<tosa::IntDivOp>(
|
||||
rewriter, op, outType, lhs, rhsTensor)
|
||||
.getResult();
|
||||
// None: No rounding mode
|
||||
result = divResult.getResult();
|
||||
}
|
||||
} else {
|
||||
if (roundMode.compare("floor") == 0) {
|
||||
// "floor": rounds the results of the division down. Equivalent to floor
|
||||
// division in Python (the // operator).
|
||||
result = floorIntDiv(rewriter, op, outType, lhs, rhsTensor);
|
||||
} else {
|
||||
// "trunc": rounds the results of the division towards zero. Equivalent
|
||||
// to C-style integer division.
|
||||
// None: no rounding mode.
|
||||
|
||||
// TOSA IntDiv requires inputs to be i32
|
||||
auto i32Type = RankedTensorType::get(outType.getShape(),
|
||||
rewriter.getIntegerType(32));
|
||||
lhs = tosa::promoteType(rewriter, lhs, i32Type);
|
||||
rhsTensor = tosa::promoteType(rewriter, rhsTensor, i32Type);
|
||||
|
||||
auto intDivOp = rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type,
|
||||
lhs, rhsTensor);
|
||||
|
||||
result = tosa::promoteType(rewriter, intDivOp, outType);
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
|
@ -4524,20 +4676,24 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
|
||||
AtenRemainderScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Remainder");
|
||||
op, "Only ranked tensor types supported in TOSA Remainder/Fmod");
|
||||
|
||||
auto outType =
|
||||
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
|
@ -4545,35 +4701,69 @@ LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
|
|||
op, "Only floating-point or integer datatype legalization supported");
|
||||
|
||||
Value otherTensor;
|
||||
if constexpr (std::is_same<AtenOpT, AtenRemainderScalarOp>()) {
|
||||
Value other = op.getOther();
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
|
||||
outElemTy, {})))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Remainder operation");
|
||||
"conversion in TOSA Remainder/Fmod operation");
|
||||
} else {
|
||||
otherTensor = adaptor.getOther();
|
||||
auto otherTy = cast<RankedTensorType>(otherTensor.getType());
|
||||
|
||||
if (!otherTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Remainder/Fmod");
|
||||
}
|
||||
|
||||
constexpr bool isRemainderOp =
|
||||
std::is_same<AtenOpT, AtenRemainderScalarOp>() ||
|
||||
std::is_same<AtenOpT, AtenRemainderTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenRemainderIntOp>();
|
||||
|
||||
if (selfTy.getElementType() != outElemTy)
|
||||
self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);
|
||||
|
||||
auto divTensor = self;
|
||||
Value divTensor;
|
||||
if (isRemainderOp) {
|
||||
// torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
|
||||
if (isa<mlir::FloatType>(outElemTy)) {
|
||||
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
|
||||
op.getLoc(), otherTensor.getType(), otherTensor);
|
||||
divTensor = rewriter.create<tosa::MulOp>(
|
||||
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
|
||||
divTensor = rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
|
||||
divTensor =
|
||||
rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
|
||||
} else {
|
||||
divTensor = rewriter.create<tosa::IntDivOp>(op.getLoc(), outType, self,
|
||||
otherTensor);
|
||||
divTensor = floorIntDiv(rewriter, op, outType, self, otherTensor);
|
||||
}
|
||||
} else {
|
||||
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
|
||||
if (isa<mlir::FloatType>(outElemTy)) {
|
||||
divTensor = truncFloatDiv(rewriter, op, outType, self, otherTensor);
|
||||
} else {
|
||||
// TOSA IntDiv requires inputs to be i32
|
||||
auto i32Type = RankedTensorType::get(outType.getShape(),
|
||||
rewriter.getIntegerType(32));
|
||||
self = tosa::promoteType(rewriter, self, i32Type);
|
||||
otherTensor = tosa::promoteType(rewriter, otherTensor, i32Type);
|
||||
|
||||
auto intDivTensor = rewriter.create<tosa::IntDivOp>(
|
||||
op->getLoc(), i32Type, self, otherTensor);
|
||||
|
||||
divTensor = tosa::promoteType(rewriter, intDivTensor, outType);
|
||||
}
|
||||
}
|
||||
|
||||
auto mulTensor =
|
||||
rewriter.create<tosa::MulOp>(op.getLoc(), outType, otherTensor, divTensor,
|
||||
auto mulTensor = rewriter.create<tosa::MulOp>(op.getLoc(), outType,
|
||||
otherTensor, divTensor,
|
||||
/*shift=*/0);
|
||||
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
||||
|
@ -5649,6 +5839,7 @@ public:
|
|||
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
|
||||
|
@ -5673,8 +5864,19 @@ public:
|
|||
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
|
||||
#undef INSERT_BINARY_DIV_PATTERN
|
||||
|
||||
#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
|
||||
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
|
||||
#undef INSERT_REMAINDER_FMOD_OP_PATTERN
|
||||
|
||||
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
|
||||
|
@ -5828,7 +6030,6 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenCopyOp);
|
||||
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenCatOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIscloseOp);
|
||||
|
|
|
@ -1668,6 +1668,40 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"ElementwiseAtenFloorDivideBroadcastModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
||||
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseGeFloatTensorModule_basic",
|
||||
"ElementwiseGeIntTensorModule_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_basic",
|
||||
"TriuBroadcastModule_basic",
|
||||
"TriuModule_basic",
|
||||
"ArgminIntModule_basic",
|
||||
"ArgminIntModule_multiple_mins",
|
||||
"ArgminModule_basic",
|
||||
|
@ -2210,6 +2244,7 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
TOSA_PASS_SET
|
||||
| {
|
||||
### Tests additionally passing in make_fx_tosa
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
"ArgminIntModule_basic",
|
||||
"ArgminIntModule_multiple_mins",
|
||||
"ArgminModule_basic",
|
||||
|
@ -2318,7 +2353,6 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
"ViewNoChange1dModule_basic",
|
||||
"ViewNoChange2dModule_basic",
|
||||
"ViewNoChange3dModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
||||
|
@ -3137,7 +3171,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"Rot90MultipleRotationsModule_basic",
|
||||
"Rot90NegativeEvenRotationsModule_basic",
|
||||
"Rot90NegativeOddRotationsModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AtenIntMM_basic",
|
||||
"AtenKthvalueDynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||
|
@ -3153,15 +3186,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"EinsumStaticDiagonalDimensionModule_basic",
|
||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||
"ElementwiseIntTensorLtFloatTensorModule_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRreluEvalModule_basic",
|
||||
"ElementwiseRreluEvalStaticModule_basic",
|
||||
"ElementwiseRreluTrainModule_basic",
|
||||
|
@ -3194,11 +3218,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"TriuIndicesNegativeOffsetModule_basic",
|
||||
"TypeConversionUint8ToF32Module_basic",
|
||||
"WeightNormInterfaceModule_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
||||
"AdaptiveAvgPool2dDynamic_basic",
|
||||
"AdaptiveAvgPool3dDynamicNoBatch_basic",
|
||||
"AdaptiveAvgPool3dDynamic_basic",
|
||||
"AdaptiveMaxPool1dDynamicNoBatch_basic",
|
||||
|
@ -3370,11 +3389,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAtanTensorIntModule_basic",
|
||||
"ElementwiseAtanhIntModule_basic",
|
||||
"ElementwiseAtanhModule_basic",
|
||||
"ElementwiseAtenFloorDivideBroadcastModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarModule_basic",
|
||||
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
|
||||
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
|
||||
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
|
||||
|
@ -3402,25 +3416,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseCoshModule_basic",
|
||||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
"ElementwiseDequantizePerTensorModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
|
||||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseErfModule_basic",
|
||||
"ElementwiseExpIntModule_basic",
|
||||
"ElementwiseExpm1IntModule_basic",
|
||||
"ElementwiseExpm1Module_basic",
|
||||
"ElementwiseGeFloatTensorModule_basic",
|
||||
"ElementwiseGeIntTensorModule_basic",
|
||||
"ElementwiseGeluApproximateTanhModule_basic",
|
||||
"ElementwiseHardshrinkModule_basic",
|
||||
"ElementwiseHardshrinkStaticModule_basic",
|
||||
|
@ -3448,10 +3448,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"ElementwiseReciprocalIntModule_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_basic",
|
||||
"ElementwiseRsqrtIntModule_basic",
|
||||
"ElementwiseSigmoidIntModule_basic",
|
||||
"ElementwiseSinIntModule_basic",
|
||||
|
@ -3850,6 +3846,7 @@ ONNX_TOSA_CRASHING_SET = {
|
|||
}
|
||||
|
||||
ONNX_TOSA_XFAIL_SET = {
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"HstackBasicComplexModule_basic",
|
||||
"HstackBasicFloatModule_basic",
|
||||
"HstackBasicIntFloatModule_basic",
|
||||
|
@ -3890,8 +3887,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||
|
@ -4223,11 +4218,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||
"ElementwiseFmodTensor_Int_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
"ElementwiseGeFloatIntScalarModule_basic",
|
||||
"ElementwiseGeFloatScalarModule_basic",
|
||||
"ElementwiseGeFloatTensorModule_basic",
|
||||
"ElementwiseGeIntScalarModule_basic",
|
||||
"ElementwiseGeIntTensorModule_basic",
|
||||
"ElementwiseGeMixedIntScalarModule_basic",
|
||||
"ElementwiseGtMixed2ScalarModule_basic",
|
||||
"ElementwiseIntTensorLtFloatScalarModule_basic",
|
||||
|
@ -4259,7 +4249,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseRelu6Module_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_basic",
|
||||
"ElementwiseRsqrtIntModule_basic",
|
||||
|
@ -4682,7 +4671,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ScalarImplicitIntModule_basic",
|
||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||
"ScatterReduceFloatMaxModule",
|
||||
|
@ -4819,8 +4807,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"TraceSignedIntModule_basic",
|
||||
"TraceUnsignedIntModule_basic",
|
||||
"TraceUnsignedIntModule_empty",
|
||||
"TriuBroadcastModule_basic",
|
||||
"TriuModule_basic",
|
||||
"TupleModule_basic",
|
||||
"TypeAsDifferentModule_basic",
|
||||
"TypeConversionF32ToF64Module_basic",
|
||||
|
|
|
@ -213,10 +213,10 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
|
|||
// CHECK-LABEL: func.func @torch.aten.div$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
|
@ -1470,3 +1470,205 @@ func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch
|
|||
%0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1> , !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1>
|
||||
return %0 : !torch.vtensor<[3,2,1],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_trunc(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc"
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_10:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.select %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : (tensor<?x?xi1>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_12:.*]] = tosa.abs %[[VAL_6]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.floor %[[VAL_12]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_13]], %[[VAL_11]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_15]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.div.Tensor_mode$float_trunc(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%str = torch.constant.str "trunc"
|
||||
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32>
|
||||
return %0 : !torch.vtensor<[?, ?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_trunc(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc"
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||
// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.div.Tensor_mode$int_trunc(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> {
|
||||
%str = torch.constant.str "trunc"
|
||||
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64>
|
||||
return %0 : !torch.vtensor<[?, ?],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_floor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor"
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.div.Tensor_mode$float_floor(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%str = torch.constant.str "floor"
|
||||
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32>
|
||||
return %0 : !torch.vtensor<[?, ?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_floor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor"
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
|
||||
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
|
||||
// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.greater %[[VAL_8]], %[[VAL_10]] : (tensor<i32>, tensor<?x?xi32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.equal %[[VAL_12]], %[[VAL_5]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_14:.*]] = tosa.logical_not %[[VAL_13]] : (tensor<?x?xi1>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_7]], %[[VAL_9]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_16:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_14]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], %[[VAL_15]], %[[VAL_7]] : (tensor<?x?xi1>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_18:.*]] = tosa.cast %[[VAL_17]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||
// CHECK: return %[[VAL_19]] : !torch.vtensor<[?,?],si64>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.div.Tensor_mode$int_floor(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> {
|
||||
%str = torch.constant.str "floor"
|
||||
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64>
|
||||
return %0 : !torch.vtensor<[?, ?],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.str ""
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.div.Tensor_mode$float_basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%str = torch.constant.str ""
|
||||
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32>
|
||||
return %0 : !torch.vtensor<[?, ?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.str ""
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||
// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.div.Tensor_mode$int_basic(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> {
|
||||
%str = torch.constant.str ""
|
||||
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64>
|
||||
return %0 : !torch.vtensor<[?, ?],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ge.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.ge.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.ge.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.remainder.Tensor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
||||
// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
|
||||
%0 = torch.aten.remainder.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
|
||||
return %0 : !torch.vtensor<[2, 4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.fmod.Tensor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_6]] : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xi1>
|
||||
// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : (tensor<2x4xi1>, tensor<f32>, tensor<f32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.abs %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_12:.*]] = tosa.floor %[[VAL_11]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_10]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_2]], %[[VAL_13]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_3]], %[[VAL_14]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
||||
// CHECK: return %[[VAL_16]] : !torch.vtensor<[2,4],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
|
||||
%0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
|
||||
return %0 : !torch.vtensor<[2, 4],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue