diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 0dbea2b5c..5ecfd62a0 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -463,6 +463,119 @@ public:
   }
 };
 
+// Function to perform division with trunc rounding mode (rounding result
+// towards zero) for float type inputs.
+// This function takes the division result of lhs and rhs as a parameter
+// rather than the original lhs and rhs tensors.
+Value truncFloatDivWithDivResult(PatternRewriter &rewriter, Operation *op,
+                                 TensorType outType, Value divResult) {
+  // To implement trunc mode for float inputs, multiply the floored abs
+  // of the tensor with the elementwise signedness of the tensor:
+  // div_result = lhs / rhs
+  // trunc_val = floor(abs(div_result)) * sign(div_result)
+  auto zero =
+      tosa::getConstTensor<float>(rewriter, op, 0, {}, outType.getElementType())
+          .value();
+
+  auto one =
+      tosa::getConstTensor<float>(rewriter, op, 1, {}, outType.getElementType())
+          .value();
+
+  auto minusOne = tosa::getConstTensor<float>(rewriter, op, -1, {},
+                                              outType.getElementType())
+                      .value();
+
+  auto cond = rewriter.create<tosa::GreaterEqualOp>(
+      op->getLoc(),
+      RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)),
+      divResult, zero);
+
+  auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), outType, cond,
+                                                  one, minusOne);
+
+  auto absDivResult =
+      rewriter.create<tosa::AbsOp>(op->getLoc(), outType, divResult);
+
+  auto flooredAbsDivResult =
+      rewriter.create<tosa::FloorOp>(op->getLoc(), outType, absDivResult);
+
+  Value result =
+      tosa::createMulOpAndCast(rewriter, op, outType, flooredAbsDivResult,
+                               selectOp, /*shift=*/0)
+          .getResult();
+
+  return result;
+}
+
+// Function to perform division with trunc rounding mode (rounding result
+// towards zero) for float type inputs.
+Value truncFloatDiv(PatternRewriter &rewriter, Operation *op,
+                    TensorType outType, Value lhs, Value rhs) {
+  rhs = tosa::promoteType(rewriter, rhs, outType);
+
+  auto rhsRcp =
+      rewriter.create<tosa::ReciprocalOp>(op->getLoc(), rhs.getType(), rhs);
+
+  auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp,
+                                            /*shift=*/0);
+
+  return truncFloatDivWithDivResult(rewriter, op, outType, divResult);
+}
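The two helpers above express trunc rounding purely with ops TOSA provides (abs, floor, select), using the identity trunc(x) = floor(|x|) * sign(x), with sign(0) treated as +1 by the select. A minimal Python sketch of that identity, illustrative only and not part of the patch:

```python
import math

def trunc_div(lhs: float, rhs: float) -> float:
    # Mirrors truncFloatDivWithDivResult: trunc(d) == floor(abs(d)) * sign(d),
    # where sign(d) is +1 for d >= 0 and -1 otherwise (matching the tosa.select).
    d = lhs / rhs
    sign = 1.0 if d >= 0 else -1.0
    return math.floor(abs(d)) * sign

assert trunc_div(7.0, 2.0) == 3.0    # 3.5 rounds toward zero
assert trunc_div(-7.0, 2.0) == -3.0  # -3.5 rounds toward zero, not down to -4.0
```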
+// Function to perform division with floor rounding mode (rounding result
+// down) for integer type inputs.
+Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType,
+                  Value lhs, Value rhs) {
+  // To implement floor mode for int inputs, utilize tosa::IntDivOp (trunc div
+  // result) with the following formula elementwise:
+  // floor_val = trunc_val - ((trunc_val * rhs != lhs)
+  //                          && (sign(lhs) != sign(rhs)))
+
+  // TOSA IntDiv requires inputs to be i32
+  auto i32Type =
+      RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32));
+  lhs = tosa::promoteType(rewriter, lhs, i32Type);
+  rhs = tosa::promoteType(rewriter, rhs, i32Type);
+
+  auto intDivOp =
+      rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type, lhs, rhs);
+
+  auto zero = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
+
+  auto one = tosa::getConstTensor<int32_t>(rewriter, op, 1, {}).value();
+
+  auto boolType =
+      RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1));
+
+  auto lhsMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs,
+                                                /*shift=*/0);
+
+  auto lhsRhsDifferentSign =
+      rewriter.create<tosa::GreaterOp>(op->getLoc(), boolType, zero, lhsMulRhs);
+
+  auto truncMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type,
+                                                  intDivOp, rhs, /*shift=*/0);
+
+  auto truncMulRhsEqualLhs =
+      rewriter.create<tosa::EqualOp>(op->getLoc(), boolType, truncMulRhs, lhs);
+
+  auto truncMulRhsNotEqualLhs = rewriter.create<tosa::LogicalNotOp>(
+      op->getLoc(), boolType, truncMulRhsEqualLhs);
+
+  auto truncMinusOne =
+      rewriter.create<tosa::SubOp>(op->getLoc(), i32Type, intDivOp, one);
+
+  auto cond = rewriter.create<tosa::LogicalAndOp>(
+      op->getLoc(), boolType, lhsRhsDifferentSign, truncMulRhsNotEqualLhs);
+
+  auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), i32Type, cond,
+                                                  truncMinusOne, intDivOp);
+
+  Value result = tosa::promoteType(rewriter, selectOp, outType);
+
+  return result;
+}
+
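floorIntDiv corrects the truncating tosa.int_div result downward by one exactly when the division is inexact and the operand signs differ. A Python model of that correction, illustrative only and not part of the patch:

```python
import math

def floor_int_div(lhs: int, rhs: int) -> int:
    # trunc division, like tosa.int_div
    trunc_val = math.trunc(lhs / rhs)
    # floor_val = trunc_val - ((trunc_val * rhs != lhs) && (sign(lhs) != sign(rhs)))
    inexact = trunc_val * rhs != lhs
    signs_differ = (lhs * rhs) < 0
    return trunc_val - 1 if (inexact and signs_differ) else trunc_val

assert floor_int_div(7, 2) == 3
assert floor_int_div(-7, 2) == -4  # trunc gives -3, corrected down to -4
assert floor_int_div(-6, 2) == -3  # exact division, no correction
```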
 template <typename AtenOpT>
 class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -498,25 +611,64 @@ public:
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
 
-    // auto result;
+    // Get rounding mode for aten.div.Tensor_mode
+    std::string roundMode;
+    if constexpr (std::is_same<AtenOpT, AtenDivTensorModeOp>() ||
+                  std::is_same<AtenOpT, AtenDivScalarModeOp>()) {
+      if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundMode)))
+        return rewriter.notifyMatchFailure(
+            op, "Non-const rounding mode parameter unsupported");
+    }
+
     Value result;
     if (isa<mlir::FloatType>(outType.getElementType())) {
-      // The input to the reciprocal is an integer sometimes, and we may need to
-      // promote it to a floating point. Per TOSA specification, the input types
-      // can only be floating point for tosa::ReciprocalOp.
-      Value rhsCasted = tosa::promoteType(rewriter, rhsTensor, outType);
-      auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
-          op->getLoc(), rhsCasted.getType(), rhsCasted);
+      // The input to the reciprocal is an integer sometimes, and we may need
+      // to promote it to a floating point. Per TOSA specification, the input
+      // types can only be floating point for tosa::ReciprocalOp.
+      rhsTensor = tosa::promoteType(rewriter, rhsTensor, outType);
+      auto rhsRcp = rewriter.create<tosa::ReciprocalOp>(
+          op->getLoc(), rhsTensor.getType(), rhsTensor);
 
-      result = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
-                                        rcpOp.getResult(), /*shift=*/0)
-                   .getResult();
+      auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
+                                                rhsRcp, /*shift=*/0);
+
+      // Round result based on rounding mode
+      if (roundMode.compare("floor") == 0) {
+        // "floor": rounds the results of the division down. Equivalent to
+        // floor division in Python (the // operator).
+        auto floorOp =
+            rewriter.create<tosa::FloorOp>(op->getLoc(), outType, divResult);
+
+        result = floorOp.getResult();
+      } else if (roundMode.compare("trunc") == 0) {
+        // "trunc": rounds the results of the division towards zero. Equivalent
+        // to C-style integer division.
+        result = truncFloatDivWithDivResult(rewriter, op, outType, divResult);
+      } else {
+        // None: No rounding mode
+        result = divResult.getResult();
+      }
     } else {
-      // The output type can be different than the input types (e.g. dividing an
-      // int tensor results in a floating point tensor).
-      result = tosa::createBinaryOpAndCast<tosa::IntDivOp>(
-                   rewriter, op, outType, lhs, rhsTensor)
-                   .getResult();
+      if (roundMode.compare("floor") == 0) {
+        // "floor": rounds the results of the division down. Equivalent to
+        // floor division in Python (the // operator).
+        result = floorIntDiv(rewriter, op, outType, lhs, rhsTensor);
+      } else {
+        // "trunc": rounds the results of the division towards zero. Equivalent
+        // to C-style integer division.
+        // None: no rounding mode.
+
+        // TOSA IntDiv requires inputs to be i32
+        auto i32Type = RankedTensorType::get(outType.getShape(),
+                                             rewriter.getIntegerType(32));
+        lhs = tosa::promoteType(rewriter, lhs, i32Type);
+        rhsTensor = tosa::promoteType(rewriter, rhsTensor, i32Type);
+
+        auto intDivOp = rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type,
+                                                        lhs, rhsTensor);
+
+        result = tosa::promoteType(rewriter, intDivOp, outType);
+      }
     }
 
     rewriter.replaceOp(op, {result});
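For reference, the three cases ConvertAtenDivOp now distinguishes match PyTorch's documented behavior for aten.div; a quick check, assuming a Python environment with torch installed (not part of the patch):

```python
import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])

print(torch.div(a, b))                         # tensor([ 3.5000, -3.5000]), no rounding mode
print(torch.div(a, b, rounding_mode="trunc"))  # tensor([ 3., -3.]), toward zero
print(torch.div(a, b, rounding_mode="floor"))  # tensor([ 3., -4.]), toward -inf
```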
@@ -4524,56 +4676,94 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
-template <>
-LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
-    AtenRemainderScalarOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
+template <typename AtenOpT>
+class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<RankedTensorType>(self.getType());
+    Value self = adaptor.getSelf();
+    auto selfTy = cast<RankedTensorType>(self.getType());
 
-  if (!selfTy)
-    return rewriter.notifyMatchFailure(
-        op, "Only ranked tensor types supported in TOSA Remainder");
+    if (!selfTy)
+      return rewriter.notifyMatchFailure(
+          op, "Only ranked tensor types supported in TOSA Remainder/Fmod");
 
-  auto outType =
-      cast<TensorType>(getTypeConverter()->convertType(op.getType()));
+    auto outType =
+        cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
 
-  Type outElemTy = outType.getElementType();
-  if (!outElemTy.isIntOrFloat())
-    return rewriter.notifyMatchFailure(
-        op, "Only floating-point or integer datatype legalization supported");
+    Type outElemTy = outType.getElementType();
+    if (!outElemTy.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point or integer datatype legalization supported");
 
-  Value otherTensor;
-  Value other = op.getOther();
-  if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
-                                     outElemTy, {})))
-    return rewriter.notifyMatchFailure(
-        op, "Currently only scalar constants are supported for "
-            "conversion in TOSA Remainder operation");
+    Value otherTensor;
+    if constexpr (std::is_same<AtenOpT, AtenRemainderScalarOp>() ||
+                  std::is_same<AtenOpT, AtenFmodScalarOp>()) {
+      Value other = op.getOther();
+      if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
+                                         outElemTy, {})))
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA Remainder/Fmod operation");
+    } else {
+      otherTensor = adaptor.getOther();
+      auto otherTy = cast<RankedTensorType>(otherTensor.getType());
 
-  if (selfTy.getElementType() != outElemTy)
-    self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);
+      if (!otherTy)
+        return rewriter.notifyMatchFailure(
+            op, "Only ranked tensor types supported in TOSA Remainder/Fmod");
+    }
 
-  auto divTensor = self;
-  if (isa<mlir::FloatType>(outElemTy)) {
-    auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
-        op.getLoc(), otherTensor.getType(), otherTensor);
-    divTensor = rewriter.create<tosa::MulOp>(
-        op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
-    divTensor = rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
-  } else {
-    divTensor = rewriter.create<tosa::IntDivOp>(op.getLoc(), outType, self,
-                                                otherTensor);
+    constexpr bool isRemainderOp =
+        std::is_same<AtenOpT, AtenRemainderScalarOp>() ||
+        std::is_same<AtenOpT, AtenRemainderTensorOp>() ||
+        std::is_same<AtenOpT, AtenRemainderIntOp>();
+
+    if (selfTy.getElementType() != outElemTy)
+      self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);
+
+    Value divTensor;
+    if (isRemainderOp) {
+      // torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
+      if (isa<mlir::FloatType>(outElemTy)) {
+        auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
+            op.getLoc(), otherTensor.getType(), otherTensor);
+        divTensor = rewriter.create<tosa::MulOp>(
+            op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
+        divTensor =
+            rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
+      } else {
+        divTensor = floorIntDiv(rewriter, op, outType, self, otherTensor);
+      }
+    } else {
+      // torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
+      if (isa<mlir::FloatType>(outElemTy)) {
+        divTensor = truncFloatDiv(rewriter, op, outType, self, otherTensor);
+      } else {
+        // TOSA IntDiv requires inputs to be i32
+        auto i32Type = RankedTensorType::get(outType.getShape(),
+                                             rewriter.getIntegerType(32));
+        self = tosa::promoteType(rewriter, self, i32Type);
+        otherTensor = tosa::promoteType(rewriter, otherTensor, i32Type);
+
+        auto intDivTensor = rewriter.create<tosa::IntDivOp>(
+            op->getLoc(), i32Type, self, otherTensor);
+
+        divTensor = tosa::promoteType(rewriter, intDivTensor, outType);
+      }
+    }
+
+    auto mulTensor = rewriter.create<tosa::MulOp>(op.getLoc(), outType,
+                                                  otherTensor, divTensor,
+                                                  /*shift=*/0);
+    rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
+
+    return success();
   }
-
-  auto mulTensor =
-      rewriter.create<tosa::MulOp>(op.getLoc(), outType, otherTensor, divTensor,
-                                   /*shift=*/0);
-  rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
-
-  return success();
-}
+};
 
 template <typename AtenOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
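The class above reduces both ops to the two identities named in its comments: remainder uses floor division and fmod uses trunc division, which is why their results carry different signs for negative operands. A short Python sketch of the identities, illustrative only and not part of the patch:

```python
import math

def remainder(a: float, b: float) -> float:
    # torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
    return a - math.floor(a / b) * b

def fmod(a: float, b: float) -> float:
    # torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
    return a - math.trunc(a / b) * b

assert remainder(-7.0, 2.0) == 1.0              # result follows the sign of b
assert fmod(-7.0, 2.0) == -1.0                  # result follows the sign of a
assert fmod(-7.0, 2.0) == math.fmod(-7.0, 2.0)  # matches the C-style fmod
```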
@@ -5649,6 +5839,7 @@ public:
     patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
   INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
   INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
+  INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
   INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
   INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
   INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
@@ -5673,8 +5864,19 @@ public:
     patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
   INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
   INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
+  INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
+  INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
 #undef INSERT_BINARY_DIV_PATTERN
 
+#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp)                              \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
+  INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
+  INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
+  INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
+  INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
+#undef INSERT_REMAINDER_FMOD_OP_PATTERN
+
 #define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc)             \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>(   \
@@ -5828,7 +6030,6 @@ public:
     INSERT_ATENOP_PATTERN(AtenCopyOp);
     INSERT_ATENOP_PATTERN(AtenToDtypeOp);
     INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
-    INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
     INSERT_ATENOP_PATTERN(AtenCatOp);
     INSERT_ATENOP_PATTERN(AtenSqrtOp);
     INSERT_ATENOP_PATTERN(AtenIscloseOp);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 8230f5e5a..b45dbda05 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1668,6 +1668,40 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ElementwiseAtenFloorDivideBroadcastModule_basic",
+    "ElementwiseAtenFloorDivideScalarModule_basic",
+    "ElementwiseAtenFloorDivideScalarNegativeModule_basic",
+    "ElementwiseAtenFloorDivideTensorNegativeModule_basic",
+    "ElementwiseAtenFloorDivideTensorPositiveModule_basic",
+    "ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
+    "ElementwiseDivScalarRoundingModeFloorModule_basic",
+    "ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
+    "ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
+    "ElementwiseDivScalarRoundingModeTruncModule_basic",
+    "ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
+    "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
+    "ElementwiseDivTensorRoundingModeFloorModule_basic",
+    "ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
+    "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
+    "ElementwiseDivTensorRoundingModeTruncModule_basic",
+    "ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
+    "ElementwiseGeFloatTensorModule_basic",
+    "ElementwiseGeIntTensorModule_basic",
+    "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
+    "ElementwiseRemainderScalarModule_Bool_basic",
+    "ElementwiseRemainderScalarModule_Int_NegativeDividend_basic",
+    "ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic",
+    "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
+    "ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic",
+    "ElementwiseRemainderTensorModule_Float_basic",
+    "ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic",
+    "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
+    "ElementwiseRemainderTensorModule_Int_Float_basic",
+    "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
+    "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
+    "ElementwiseRemainderTensorModule_Int_basic",
+    "TriuBroadcastModule_basic",
+    "TriuModule_basic",
     "ArgminIntModule_basic",
     "ArgminIntModule_multiple_mins",
     "ArgminModule_basic",
@@ -2210,6 +2244,7 @@ MAKE_FX_TOSA_PASS_SET = (
     TOSA_PASS_SET
     | {
         ### Tests additionally passing in make_fx_tosa
+        "AdaptiveAvgPool1dStaticLargerOutput_basic",
         "ArgminIntModule_basic",
         "ArgminIntModule_multiple_mins",
         "ArgminModule_basic",
@@ -2318,7 +2353,6 @@
         "ViewNoChange1dModule_basic",
         "ViewNoChange2dModule_basic",
         "ViewNoChange3dModule_basic",
-        "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
     }
 
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):
@@ -3137,7 +3171,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "Rot90MultipleRotationsModule_basic",
     "Rot90NegativeEvenRotationsModule_basic",
     "Rot90NegativeOddRotationsModule_basic",
-    "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AtenIntMM_basic", "AtenKthvalueDynamicDimsModule_basic", "AtenKthvalueFloat64DynamicDimsModule_basic", @@ -3153,15 +3186,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "EinsumStaticDiagonalDimensionModule_basic", "ElementwiseFloatTensorGtIntTensorModule_basic", "ElementwiseIntTensorLtFloatTensorModule_basic", - "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", - "ElementwiseRemainderScalarModule_Int_NegativeDividend_basic", - "ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic", - "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic", - "ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic", - "ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic", - "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic", - "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", - "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", "ElementwiseRreluEvalModule_basic", "ElementwiseRreluEvalStaticModule_basic", "ElementwiseRreluTrainModule_basic", @@ -3194,11 +3218,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "TriuIndicesNegativeOffsetModule_basic", "TypeConversionUint8ToF32Module_basic", "WeightNormInterfaceModule_basic", - "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", - "AdaptiveAvgPool1dGeneralDynamic_basic", - "AdaptiveAvgPool1dStaticLargerOutput_basic", - "AdaptiveAvgPool2dDynamicNoBatch_basic", - "AdaptiveAvgPool2dDynamic_basic", "AdaptiveAvgPool3dDynamicNoBatch_basic", "AdaptiveAvgPool3dDynamic_basic", "AdaptiveMaxPool1dDynamicNoBatch_basic", @@ -3370,11 +3389,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", - "ElementwiseAtenFloorDivideBroadcastModule_basic", - "ElementwiseAtenFloorDivideScalarModule_basic", - "ElementwiseAtenFloorDivideScalarNegativeModule_basic", - "ElementwiseAtenFloorDivideTensorNegativeModule_basic", - "ElementwiseAtenFloorDivideTensorPositiveModule_basic", "ElementwiseAtenLogicalAndOpModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", @@ -3402,25 +3416,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", - "ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic", - "ElementwiseDivScalarRoundingModeFloorModule_basic", - "ElementwiseDivScalarRoundingModeFloorStaticModule_basic", - "ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic", - "ElementwiseDivScalarRoundingModeTruncModule_basic", - "ElementwiseDivScalarRoundingModeTruncStaticModule_basic", - "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", - "ElementwiseDivTensorRoundingModeFloorModule_basic", - "ElementwiseDivTensorRoundingModeFloorStaticModule_basic", - "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic", - "ElementwiseDivTensorRoundingModeTruncModule_basic", - "ElementwiseDivTensorRoundingModeTruncStaticModule_basic", "ElementwiseErfIntModule_basic", "ElementwiseErfModule_basic", "ElementwiseExpIntModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", - "ElementwiseGeFloatTensorModule_basic", - "ElementwiseGeIntTensorModule_basic", "ElementwiseGeluApproximateTanhModule_basic", "ElementwiseHardshrinkModule_basic", "ElementwiseHardshrinkStaticModule_basic", @@ -3448,10 +3448,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", 
"ElementwiseReciprocalIntModule_basic", - "ElementwiseRemainderScalarModule_Bool_basic", - "ElementwiseRemainderTensorModule_Float_basic", - "ElementwiseRemainderTensorModule_Int_Float_basic", - "ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRsqrtIntModule_basic", "ElementwiseSigmoidIntModule_basic", "ElementwiseSinIntModule_basic", @@ -3850,6 +3846,7 @@ ONNX_TOSA_CRASHING_SET = { } ONNX_TOSA_XFAIL_SET = { + "ScaledDotProductAttentionDifferentCausalModule_basic", "HstackBasicComplexModule_basic", "HstackBasicFloatModule_basic", "HstackBasicIntFloatModule_basic", @@ -3890,8 +3887,6 @@ ONNX_TOSA_XFAIL_SET = { "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic", "ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic", - "ElementwiseRemainderScalarModule_Int_NegativeDividend_basic", - "ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic", "ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic", "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", @@ -4223,11 +4218,6 @@ ONNX_TOSA_XFAIL_SET = { "ElementwiseFloatTensorGtIntTensorModule_basic", "ElementwiseFmodTensor_Int_Float_basic", "ElementwiseFmodTensor_Int_basic", - "ElementwiseGeFloatIntScalarModule_basic", - "ElementwiseGeFloatScalarModule_basic", - "ElementwiseGeFloatTensorModule_basic", - "ElementwiseGeIntScalarModule_basic", - "ElementwiseGeIntTensorModule_basic", "ElementwiseGeMixedIntScalarModule_basic", "ElementwiseGtMixed2ScalarModule_basic", "ElementwiseIntTensorLtFloatScalarModule_basic", @@ -4259,7 +4249,6 @@ ONNX_TOSA_XFAIL_SET = { "ElementwiseRelu6Module_basic", "ElementwiseRemainderScalarModule_Bool_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", - "ElementwiseRemainderScalarModule_Int_basic", "ElementwiseRemainderTensorModule_Int_Float_basic", "ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRsqrtIntModule_basic", @@ -4682,7 +4671,6 @@ ONNX_TOSA_XFAIL_SET = { "ScalarImplicitIntModule_basic", # REMOVE WHEN ENABLE_GQA IS ADDED "ScaledDotProductAttentionBoolMaskModule_basic", - "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", "ScaledDotProductAttentionSameCausalModule_basic", "ScaledDotProductAttentionSameDynamicModule_basic", "ScatterReduceFloatMaxModule", @@ -4819,8 +4807,6 @@ ONNX_TOSA_XFAIL_SET = { "TraceSignedIntModule_basic", "TraceUnsignedIntModule_basic", "TraceUnsignedIntModule_empty", - "TriuBroadcastModule_basic", - "TriuModule_basic", "TupleModule_basic", "TypeAsDifferentModule_basic", "TypeConversionF32ToF64Module_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index c8a3d371f..53128a669 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -213,10 +213,10 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK-LABEL: func.func @torch.aten.div$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : 
(tensor, tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -1470,3 +1470,205 @@ func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch %0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1> , !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1> return %0 : !torch.vtensor<[3,2,1],i1> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_trunc( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc" +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_7]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_11:.*]] = tosa.select %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.abs %[[VAL_6]] : (tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.floor %[[VAL_12]] : (tensor) -> tensor +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_13]], %[[VAL_11]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$float_trunc(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { + %str = torch.constant.str "trunc" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32> + return %0 : !torch.vtensor<[?, ?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_trunc( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc" +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) 
-> tensor +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$int_trunc(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> { + %str = torch.constant.str "trunc" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64> + return %0 : !torch.vtensor<[?, ?],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_floor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor" +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$float_floor(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { + %str = torch.constant.str "floor" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32> + return %0 : !torch.vtensor<[?, ?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_floor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor" +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_11:.*]] = tosa.greater %[[VAL_8]], %[[VAL_10]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.equal %[[VAL_12]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_14:.*]] = tosa.logical_not %[[VAL_13]] : (tensor) -> tensor +// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_7]], %[[VAL_9]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_16:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_14]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], 
%[[VAL_15]], %[[VAL_7]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_18:.*]] = tosa.cast %[[VAL_17]] : (tensor) -> tensor +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[?,?],si64> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$int_floor(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> { + %str = torch.constant.str "floor" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64> + return %0 : !torch.vtensor<[?, ?],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "" +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$float_basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { + %str = torch.constant.str "" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32> + return %0 : !torch.vtensor<[?, ?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "" +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$int_basic(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> { + %str = torch.constant.str "" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64> + return %0 : !torch.vtensor<[?, ?],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.ge.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : 
!torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.ge.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.ge.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.remainder.Tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32> +// CHECK: } +func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { + %0 = torch.aten.remainder.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32> + return %0 : !torch.vtensor<[2, 4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fmod.Tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_6]] : (tensor<2x4xf32>, tensor) -> tensor<2x4xi1> +// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : (tensor<2x4xi1>, tensor, tensor) -> tensor<2x4xf32> +// CHECK: %[[VAL_11:.*]] = tosa.abs %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_12:.*]] = tosa.floor %[[VAL_11]] : 
(tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK:           %[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_10]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK:           %[[VAL_14:.*]] = tosa.mul %[[VAL_2]], %[[VAL_13]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK:           %[[VAL_15:.*]] = tosa.sub %[[VAL_3]], %[[VAL_14]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK:           %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
+// CHECK:           return %[[VAL_16]] : !torch.vtensor<[2,4],f32>
+// CHECK:         }
+func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
+  %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32>
+  return %0 : !torch.vtensor<[2, 4],f32>
+}
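As a cross-check on the op sequence the fmod test above expects (reciprocal, mul, floor of abs times sign, mul, sub), a small Python sketch compares it against math.fmod. This is only an illustration, under the assumption that reciprocal-then-multiply is exact enough for the sample values:

```python
import math

def fmod_like_lowering(a: float, b: float) -> float:
    div = a * (1.0 / b)                   # tosa.reciprocal + tosa.mul
    sign = 1.0 if div >= 0 else -1.0      # tosa.greater_equal + tosa.select
    trunc = math.floor(abs(div)) * sign   # tosa.abs + tosa.floor + tosa.mul
    return a - b * trunc                  # tosa.mul + tosa.sub

for a, b in [(7.5, 2.0), (-7.5, 2.0), (7.5, -2.0)]:
    assert math.isclose(fmod_like_lowering(a, b), math.fmod(a, b))
```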