[TOSA] Add div rounding mode, remainder, fmod, and ge.Tensor ops support (#3717)

- Add legalization for aten.div rounding mode (a small worked example
  follows this list):
  + trunc: rounds division results towards zero
  + floor: rounds division results down
- Add legalization for aten.remainder.Scalar and aten.fmod ops
- Add legalization for aten.ge.Tensor op
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new legalized ops
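
For example, dividing -7 by 2 gives -3.5 before rounding:
  + trunc rounds towards zero: -3 (C-style integer division)
  + floor rounds down: -4 (floor division in Python, the // operator)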

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: Icedd23205254fb893ce6f3de08956772b83b4320

Justin Ngo 2024-09-20 13:34:09 -07:00 committed by GitHub
parent 5ce48dfacd
commit abaff58c6d
3 changed files with 501 additions and 112 deletions


@@ -463,6 +463,119 @@ public:
  }
};
// Function to perform division with trunc rounding mode (rounding result
// towards zero) for float type inputs.
// This function takes in the division result between lhs and rhs rather
// than the original lhs and rhs tensors as parameters.
Value truncFloatDivWithDivResult(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value divResult) {
// To implement trunc mode for float inputs, multiply the floored abs
// of the tensor with the elementwise signedness of the tensor.
// div_result = lhs / rhs
// trunc_val = floor(abs(div_result)) * sign(div_result)
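// For example, for div_result = -3.5: floor(abs(-3.5)) = 3 and
// sign(-3.5) = -1, so trunc_val = 3 * -1 = -3 (rounded towards zero).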
auto zero =
tosa::getConstTensor<float>(rewriter, op, 0, {}, outType.getElementType())
.value();
auto one =
tosa::getConstTensor<float>(rewriter, op, 1, {}, outType.getElementType())
.value();
auto minusOne = tosa::getConstTensor<float>(rewriter, op, -1, {},
outType.getElementType())
.value();
auto cond = rewriter.create<tosa::GreaterEqualOp>(
op->getLoc(),
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)),
divResult, zero);
auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), outType, cond,
one, minusOne);
auto absDivResult =
rewriter.create<tosa::AbsOp>(op->getLoc(), outType, divResult);
auto flooredAbsDivResult =
rewriter.create<tosa::FloorOp>(op->getLoc(), outType, absDivResult);
Value result =
tosa::createMulOpAndCast(rewriter, op, outType, flooredAbsDivResult,
selectOp, /*shift=*/0)
.getResult();
return result;
}
// Function to perform division with trunc rounding mode (rounding result
// towards zero) for float type inputs
Value truncFloatDiv(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs) {
rhs = tosa::promoteType(rewriter, rhs, outType);
auto rhsRcp =
rewriter.create<tosa::ReciprocalOp>(op->getLoc(), rhs.getType(), rhs);
auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp,
/*shift=*/0);
return truncFloatDivWithDivResult(rewriter, op, outType, divResult);
}
// Function to perform division with floor rounding mode (rounding result
// down) for integer type inputs.
Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType,
Value lhs, Value rhs) {
// To implement floor mode for int inputs, utilize tosa::IntDivOp (trunc div
// result) with the following formula elementwise:
// floor_val = trunc_val - ((trunc_val * rhs != lhs)
// && (sign(lhs) != sign(rhs)))
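// For example, lhs = -7, rhs = 2: trunc_val = -3, and since -3 * 2 != -7
// and the signs of lhs and rhs differ, floor_val = -3 - 1 = -4 (matching
// Python's -7 // 2).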
// TOSA IntDiv requires inputs to be i32
auto i32Type =
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32));
lhs = tosa::promoteType(rewriter, lhs, i32Type);
rhs = tosa::promoteType(rewriter, rhs, i32Type);
auto intDivOp =
rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type, lhs, rhs);
auto zero = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
auto one = tosa::getConstTensor<int32_t>(rewriter, op, 1, {}).value();
auto boolType =
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1));
auto lhsMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs,
/*shift=*/0);
auto lhsRhsDifferentSign =
rewriter.create<tosa::GreaterOp>(op->getLoc(), boolType, zero, lhsMulRhs);
auto truncMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type,
intDivOp, rhs, /*shift=*/0);
auto truncMulRhsEqualLhs =
rewriter.create<tosa::EqualOp>(op->getLoc(), boolType, truncMulRhs, lhs);
auto truncMulRhsNotEqualLhs = rewriter.create<tosa::LogicalNotOp>(
op->getLoc(), boolType, truncMulRhsEqualLhs);
auto truncMinusOne =
rewriter.create<tosa::SubOp>(op->getLoc(), i32Type, intDivOp, one);
auto cond = rewriter.create<tosa::LogicalAndOp>(
op->getLoc(), boolType, lhsRhsDifferentSign, truncMulRhsNotEqualLhs);
auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), i32Type, cond,
truncMinusOne, intDivOp);
Value result = tosa::promoteType(rewriter, selectOp, outType);
return result;
}
template <typename AtenOpT>
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
public:
@@ -498,25 +611,64 @@ public:
            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
                op.getType()));

    // Get rounding mode for aten.div.Tensor_mode
    std::string roundMode;
    if constexpr (std::is_same<AtenOpT, AtenDivTensorModeOp>() ||
                  std::is_same<AtenOpT, AtenDivScalarModeOp>()) {
      if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundMode)))
        return rewriter.notifyMatchFailure(
            op, "Non-const rounding mode parameter unsupported");
    }

    Value result;
    if (isa<mlir::FloatType>(outType.getElementType())) {
      // The input to the reciprocal is an integer sometimes, and we may need
      // to promote it to a floating point. Per TOSA specification, the input
      // types can only be floating point for tosa::ReciprocalOp.
      rhsTensor = tosa::promoteType(rewriter, rhsTensor, outType);
      auto rhsRcp = rewriter.create<tosa::ReciprocalOp>(
          op->getLoc(), rhsTensor.getType(), rhsTensor);

      auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
                                                rhsRcp, /*shift=*/0);

      // Round result based on rounding mode
      if (roundMode.compare("floor") == 0) {
        // "floor": rounds the results of the division down. Equivalent to
        // floor division in Python (the // operator).
        auto floorOp =
            rewriter.create<tosa::FloorOp>(op->getLoc(), outType, divResult);

        result = floorOp.getResult();
      } else if (roundMode.compare("trunc") == 0) {
        // "trunc": rounds the results of the division towards zero. Equivalent
        // to C-style integer division.
        result = truncFloatDivWithDivResult(rewriter, op, outType, divResult);
      } else {
        // None: No rounding mode
        result = divResult.getResult();
      }
    } else {
      if (roundMode.compare("floor") == 0) {
        // "floor": rounds the results of the division down. Equivalent to floor
        // division in Python (the // operator).
        result = floorIntDiv(rewriter, op, outType, lhs, rhsTensor);
      } else {
        // "trunc": rounds the results of the division towards zero. Equivalent
        // to C-style integer division.
        // None: no rounding mode.

        // TOSA IntDiv requires inputs to be i32
        auto i32Type = RankedTensorType::get(outType.getShape(),
                                             rewriter.getIntegerType(32));
        lhs = tosa::promoteType(rewriter, lhs, i32Type);
        rhsTensor = tosa::promoteType(rewriter, rhsTensor, i32Type);

        auto intDivOp = rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type,
                                                        lhs, rhsTensor);

        result = tosa::promoteType(rewriter, intDivOp, outType);
      }
    }

    rewriter.replaceOp(op, {result});
@@ -4524,20 +4676,24 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
  return success();
}
template <typename AtenOpT>
class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.getSelf();
    auto selfTy = cast<RankedTensorType>(self.getType());

    if (!selfTy)
      return rewriter.notifyMatchFailure(
          op, "Only ranked tensor types supported in TOSA Remainder/Fmod");

    auto outType =
        cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));

    Type outElemTy = outType.getElementType();
    if (!outElemTy.isIntOrFloat())
@@ -4545,35 +4701,69 @@ LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
          op, "Only floating-point or integer datatype legalization supported");

    Value otherTensor;
    if constexpr (std::is_same<AtenOpT, AtenRemainderScalarOp>()) {
      Value other = op.getOther();
      if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
                                         outElemTy, {})))
        return rewriter.notifyMatchFailure(
            op, "Currently only scalar constants are supported for "
                "conversion in TOSA Remainder/Fmod operation");
    } else {
      otherTensor = adaptor.getOther();
      auto otherTy = cast<RankedTensorType>(otherTensor.getType());

      if (!otherTy)
        return rewriter.notifyMatchFailure(
            op, "Only ranked tensor types supported in TOSA Remainder/Fmod");
    }

    constexpr bool isRemainderOp =
        std::is_same<AtenOpT, AtenRemainderScalarOp>() ||
        std::is_same<AtenOpT, AtenRemainderTensorOp>() ||
        std::is_same<AtenOpT, AtenRemainderIntOp>();

    if (selfTy.getElementType() != outElemTy)
      self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);

    Value divTensor;
    if (isRemainderOp) {
// torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
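// e.g. a = -7, b = 2: floor division gives -4, so the result is
// -7 - (-4) * 2 = 1 (the remainder takes the sign of the divisor).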
      if (isa<mlir::FloatType>(outElemTy)) {
        auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
            op.getLoc(), otherTensor.getType(), otherTensor);
        divTensor = rewriter.create<tosa::MulOp>(
            op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
        divTensor =
            rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
      } else {
        divTensor = floorIntDiv(rewriter, op, outType, self, otherTensor);
      }
    } else {
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
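// e.g. a = -7, b = 2: trunc division gives -3, so the result is
// -7 - (-3) * 2 = -1 (fmod takes the sign of the dividend).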
if (isa<mlir::FloatType>(outElemTy)) {
divTensor = truncFloatDiv(rewriter, op, outType, self, otherTensor);
} else {
// TOSA IntDiv requires inputs to be i32
auto i32Type = RankedTensorType::get(outType.getShape(),
rewriter.getIntegerType(32));
self = tosa::promoteType(rewriter, self, i32Type);
otherTensor = tosa::promoteType(rewriter, otherTensor, i32Type);
auto intDivTensor = rewriter.create<tosa::IntDivOp>(
op->getLoc(), i32Type, self, otherTensor);
divTensor = tosa::promoteType(rewriter, intDivTensor, outType);
}
    }

    auto mulTensor = rewriter.create<tosa::MulOp>(op.getLoc(), outType,
                                                  otherTensor, divTensor,
                                                  /*shift=*/0);
    rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);

    return success();
  }
};

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
@@ -5649,6 +5839,7 @@ public:
  patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
  INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
  INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
  INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
  INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
  INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
  INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
@@ -5673,8 +5864,19 @@ public:
  patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
  INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
  INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
  INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
  INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
#undef INSERT_BINARY_DIV_PATTERN
#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
  target.addIllegalOp<AtenOp>(); \
  patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
  INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
  INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
  INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
  INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
#undef INSERT_REMAINDER_FMOD_OP_PATTERN
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
  target.addIllegalOp<AtenOp>(); \
  patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
@@ -5828,7 +6030,6 @@ public:
  INSERT_ATENOP_PATTERN(AtenCopyOp);
  INSERT_ATENOP_PATTERN(AtenToDtypeOp);
  INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
  INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
  INSERT_ATENOP_PATTERN(AtenCatOp);
  INSERT_ATENOP_PATTERN(AtenSqrtOp);
  INSERT_ATENOP_PATTERN(AtenIscloseOp);


@@ -1668,6 +1668,40 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideScalarModule_basic",
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorModule_basic",
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorModule_basic",
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
"ElementwiseGeFloatTensorModule_basic",
"ElementwiseGeIntTensorModule_basic",
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Bool_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Float_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_Float_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_basic",
"TriuBroadcastModule_basic",
"TriuModule_basic",
"ArgminIntModule_basic", "ArgminIntModule_basic",
"ArgminIntModule_multiple_mins", "ArgminIntModule_multiple_mins",
"ArgminModule_basic", "ArgminModule_basic",
@@ -2210,6 +2244,7 @@ MAKE_FX_TOSA_PASS_SET = (
    TOSA_PASS_SET
    | {
        ### Tests additionally passing in make_fx_tosa
        "AdaptiveAvgPool1dStaticLargerOutput_basic",
        "ArgminIntModule_basic",
        "ArgminIntModule_multiple_mins",
        "ArgminModule_basic",
@@ -2318,7 +2353,6 @@ MAKE_FX_TOSA_PASS_SET = (
        "ViewNoChange1dModule_basic",
        "ViewNoChange2dModule_basic",
        "ViewNoChange3dModule_basic",
        "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
    }

if torch_version_for_comparison() < version.parse("2.5.0.dev"):
@@ -3137,7 +3171,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "Rot90MultipleRotationsModule_basic",
    "Rot90NegativeEvenRotationsModule_basic",
    "Rot90NegativeOddRotationsModule_basic",
    "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
    "AtenIntMM_basic",
    "AtenKthvalueDynamicDimsModule_basic",
    "AtenKthvalueFloat64DynamicDimsModule_basic",
@@ -3153,15 +3186,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "EinsumStaticDiagonalDimensionModule_basic",
    "ElementwiseFloatTensorGtIntTensorModule_basic",
    "ElementwiseIntTensorLtFloatTensorModule_basic",
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
"ElementwiseRreluEvalModule_basic", "ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic", "ElementwiseRreluEvalStaticModule_basic",
"ElementwiseRreluTrainModule_basic", "ElementwiseRreluTrainModule_basic",
@ -3194,11 +3218,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"TriuIndicesNegativeOffsetModule_basic", "TriuIndicesNegativeOffsetModule_basic",
"TypeConversionUint8ToF32Module_basic", "TypeConversionUint8ToF32Module_basic",
"WeightNormInterfaceModule_basic", "WeightNormInterfaceModule_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool2dDynamicNoBatch_basic",
"AdaptiveAvgPool2dDynamic_basic",
"AdaptiveAvgPool3dDynamicNoBatch_basic", "AdaptiveAvgPool3dDynamicNoBatch_basic",
"AdaptiveAvgPool3dDynamic_basic", "AdaptiveAvgPool3dDynamic_basic",
"AdaptiveMaxPool1dDynamicNoBatch_basic", "AdaptiveMaxPool1dDynamicNoBatch_basic",
@ -3370,11 +3389,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanTensorIntModule_basic",
"ElementwiseAtanhIntModule_basic", "ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic", "ElementwiseAtanhModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideScalarModule_basic",
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseAtenFloorDivideTensorPositiveModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic", "ElementwiseAtenLogicalAndOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
@ -3402,25 +3416,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseCoshModule_basic", "ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic", "ElementwiseDequantizePerTensorModule_basic",
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorModule_basic",
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorModule_basic",
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
"ElementwiseErfIntModule_basic", "ElementwiseErfIntModule_basic",
"ElementwiseErfModule_basic", "ElementwiseErfModule_basic",
"ElementwiseExpIntModule_basic", "ElementwiseExpIntModule_basic",
"ElementwiseExpm1IntModule_basic", "ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic", "ElementwiseExpm1Module_basic",
"ElementwiseGeFloatTensorModule_basic",
"ElementwiseGeIntTensorModule_basic",
"ElementwiseGeluApproximateTanhModule_basic", "ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseHardshrinkModule_basic", "ElementwiseHardshrinkModule_basic",
"ElementwiseHardshrinkStaticModule_basic", "ElementwiseHardshrinkStaticModule_basic",
@ -3448,10 +3448,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic", "ElementwiseReciprocalIntModule_basic",
"ElementwiseRemainderScalarModule_Bool_basic",
"ElementwiseRemainderTensorModule_Float_basic",
"ElementwiseRemainderTensorModule_Int_Float_basic",
"ElementwiseRemainderTensorModule_Int_basic",
"ElementwiseRsqrtIntModule_basic", "ElementwiseRsqrtIntModule_basic",
"ElementwiseSigmoidIntModule_basic", "ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic", "ElementwiseSinIntModule_basic",
@ -3850,6 +3846,7 @@ ONNX_TOSA_CRASHING_SET = {
} }
ONNX_TOSA_XFAIL_SET = { ONNX_TOSA_XFAIL_SET = {
"ScaledDotProductAttentionDifferentCausalModule_basic",
"HstackBasicComplexModule_basic", "HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic", "HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic", "HstackBasicIntFloatModule_basic",
@ -3890,8 +3887,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic", "ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic", "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
@ -4223,11 +4218,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseFloatTensorGtIntTensorModule_basic", "ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseFmodTensor_Int_Float_basic", "ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic", "ElementwiseFmodTensor_Int_basic",
"ElementwiseGeFloatIntScalarModule_basic",
"ElementwiseGeFloatScalarModule_basic",
"ElementwiseGeFloatTensorModule_basic",
"ElementwiseGeIntScalarModule_basic",
"ElementwiseGeIntTensorModule_basic",
"ElementwiseGeMixedIntScalarModule_basic", "ElementwiseGeMixedIntScalarModule_basic",
"ElementwiseGtMixed2ScalarModule_basic", "ElementwiseGtMixed2ScalarModule_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseIntTensorLtFloatScalarModule_basic",
@ -4259,7 +4249,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseRelu6Module_basic", "ElementwiseRelu6Module_basic",
"ElementwiseRemainderScalarModule_Bool_basic", "ElementwiseRemainderScalarModule_Bool_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"ElementwiseRemainderTensorModule_Int_Float_basic", "ElementwiseRemainderTensorModule_Int_Float_basic",
"ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRemainderTensorModule_Int_basic",
"ElementwiseRsqrtIntModule_basic", "ElementwiseRsqrtIntModule_basic",
@ -4682,7 +4671,6 @@ ONNX_TOSA_XFAIL_SET = {
"ScalarImplicitIntModule_basic", "ScalarImplicitIntModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED # REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic", "ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic", "ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic", "ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterReduceFloatMaxModule", "ScatterReduceFloatMaxModule",
@ -4819,8 +4807,6 @@ ONNX_TOSA_XFAIL_SET = {
"TraceSignedIntModule_basic", "TraceSignedIntModule_basic",
"TraceUnsignedIntModule_basic", "TraceUnsignedIntModule_basic",
"TraceUnsignedIntModule_empty", "TraceUnsignedIntModule_empty",
"TriuBroadcastModule_basic",
"TriuModule_basic",
"TupleModule_basic", "TupleModule_basic",
"TypeAsDifferentModule_basic", "TypeAsDifferentModule_basic",
"TypeConversionF32ToF64Module_basic", "TypeConversionF32ToF64Module_basic",


@@ -213,10 +213,10 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
// CHECK-LABEL: func.func @torch.aten.div$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
// CHECK: }
@@ -1470,3 +1470,205 @@ func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch
  %0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1> , !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1>
  return %0 : !torch.vtensor<[3,2,1],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_trunc(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc"
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_10:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_11:.*]] = tosa.select %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : (tensor<?x?xi1>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_12:.*]] = tosa.abs %[[VAL_6]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_13:.*]] = tosa.floor %[[VAL_12]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_13]], %[[VAL_11]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_15]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.div.Tensor_mode$float_trunc(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%str = torch.constant.str "trunc"
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_trunc(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc"
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64>
// CHECK: }
func.func @torch.aten.div.Tensor_mode$int_trunc(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> {
%str = torch.constant.str "trunc"
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64>
return %0 : !torch.vtensor<[?, ?],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_floor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor"
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.div.Tensor_mode$float_floor(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%str = torch.constant.str "floor"
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_floor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor"
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_11:.*]] = tosa.greater %[[VAL_8]], %[[VAL_10]] : (tensor<i32>, tensor<?x?xi32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_13:.*]] = tosa.equal %[[VAL_12]], %[[VAL_5]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_14:.*]] = tosa.logical_not %[[VAL_13]] : (tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_7]], %[[VAL_9]] : (tensor<?x?xi32>, tensor<i32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_16:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_14]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], %[[VAL_15]], %[[VAL_7]] : (tensor<?x?xi1>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_18:.*]] = tosa.cast %[[VAL_17]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[VAL_19]] : !torch.vtensor<[?,?],si64>
// CHECK: }
func.func @torch.aten.div.Tensor_mode$int_floor(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> {
%str = torch.constant.str "floor"
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64>
return %0 : !torch.vtensor<[?, ?],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.str ""
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.div.Tensor_mode$float_basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%str = torch.constant.str ""
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[VAL_4:.*]] = torch.constant.str ""
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<?x?xi64>) -> tensor<?x?xi32>
// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64>
// CHECK: }
func.func @torch.aten.div.Tensor_mode$int_basic(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> {
%str = torch.constant.str ""
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64>
return %0 : !torch.vtensor<[?, ?],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.ge.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func.func @torch.aten.ge.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.ge.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.remainder.Tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32>
// CHECK: }
func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
%0 = torch.aten.remainder.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
return %0 : !torch.vtensor<[2, 4],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.fmod.Tensor(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_6]] : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xi1>
// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : (tensor<2x4xi1>, tensor<f32>, tensor<f32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_11:.*]] = tosa.abs %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_12:.*]] = tosa.floor %[[VAL_11]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_10]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_2]], %[[VAL_13]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_3]], %[[VAL_14]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[VAL_16]] : !torch.vtensor<[2,4],f32>
// CHECK: }
func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
%0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
return %0 : !torch.vtensor<[2, 4],f32>
}