From 3f79a2982ad2f3b847b73999d1e415de964fba89 Mon Sep 17 00:00:00 2001
From: Justin Ngo
Date: Fri, 20 Sep 2024 14:33:55 -0700
Subject: [PATCH] [TOSA] Extend Torch to TOSA legalization coverage (#3718)

- Add Torch to TOSA legalization for the following ops:
  + aten.logical_not
  + aten.logical_xor
  + aten.cos
  + aten.sin
  + aten.pow.Scalar
  + aten.pow.Tensor_Tensor
  + aten.erf
  + aten.bitwise_and.Scalar
  + aten.bitwise_left_shift.Tensor
  + aten.bitwise_right_shift.Tensor
  + aten.le.Tensor
  + aten.le.Scalar
- Update e2e tests in xfail_sets
- Update basic.mlir with newly legalized ops

Signed-off-by: Justin Ngo
Change-Id: I4aa5790073ef2e5ec0e9b374da42887242f8dabc
Signed-off-by: Justin Ngo
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 205 ++++++++++++---------
 projects/pt1/e2e_testing/xfail_sets.py     |  82 ++++-----
 test/Conversion/TorchToTosa/basic.mlir     | 187 +++++++++++++++++++
 3 files changed, 333 insertions(+), 141 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 5ecfd62a0..2a6b1612c 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -105,9 +105,18 @@ public:
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
 
-    auto binaryOp =
-        tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
-    rewriter.replaceOp(op, binaryOp.getResult());
+    Value binaryOp;
+
+    // TOSA ArithmeticRightShiftOp has a round parameter.
+    if constexpr (std::is_same<TosaOpT, tosa::ArithmeticRightShiftOp>()) {
+      binaryOp = rewriter.create<TosaOpT>(op->getLoc(), outTy, lhs, rhs,
+                                          /*round=*/false);
+    } else {
+      binaryOp =
+          tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
+    }
+
+    rewriter.replaceOp(op, binaryOp);
     return success();
   }
 };
@@ -353,6 +362,7 @@ public:
     // For bitwise operators, only integer datatype legalization is supported
     constexpr bool isBitwiseOp =
         std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
+        std::is_same<AtenOpT, AtenBitwiseAndScalarOp>() ||
         std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
         std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
     if (isa<mlir::FloatType>(lhsElemTy) && isBitwiseOp) {
@@ -372,7 +382,9 @@ public:
     auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
     // There is no Lesser operator in TOSA.
     constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
-                                 std::is_same<AtenOpT, AtenLtScalarOp>());
+                                 std::is_same<AtenOpT, AtenLtScalarOp>() ||
+                                 std::is_same<AtenOpT, AtenLeTensorOp>() ||
+                                 std::is_same<AtenOpT, AtenLeScalarOp>());
 
     // Promote lhs and rhs dtypes for bitwise operators.
     TensorType resultTy = cast<TensorType>(
@@ -688,39 +700,30 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-template <>
-LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
-    AtenTanhOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<TensorType>(self.getType());
-  if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
-    rewriter.replaceOpWithNewOp<tosa::TanhOp>(
-        op, getTypeConverter()->convertType(op.getType()), self);
-    return success();
-  }
-  // Sigmoid legalization in TOSA for quantized element-type uses specialized
-  // tosa.table construct.
-  return rewriter.notifyMatchFailure(
-      op, "Only floating-point datatype legalization currently supported");
-}
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value self = adaptor.getSelf();
+    auto selfTy = cast<TensorType>(self.getType());
+
+    if (!selfTy)
+      return rewriter.notifyMatchFailure(op, "Only Tensor types supported");
+
+    if (!isa<mlir::FloatType>(selfTy.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype legalization currently supported");
+
+    rewriter.replaceOpWithNewOp<TosaOpT>(
+        op, this->getTypeConverter()->convertType(op.getType()), self);
 
-template <>
-LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
-    AtenSigmoidOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<TensorType>(self.getType());
-  if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
-    rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
-        op, getTypeConverter()->convertType(op.getType()), self);
     return success();
   }
-  // Sigmoid legalization in TOSA for quantized element-type uses
-  // specialized tosa.table construct.
-  return rewriter.notifyMatchFailure(
-      op, "Only floating-point datatype legalization currently supported");
-}
+};
 
 template <>
 LogicalResult ConvertAtenOp::matchAndRewrite(
@@ -1205,39 +1208,63 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
   }
 };
 
-template <>
-LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
-    AtenPowTensorScalarOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
+template <typename AtenOpT>
+class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<RankedTensorType>(self.getType());
+    auto outType =
+        cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
 
-  if (!selfTy)
-    return rewriter.notifyMatchFailure(
-        op, "Only ranked tensor types supported in TOSA Pow");
+    Value selfTensor;
+    if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
+      Value selfScalar = op.getSelf();
+      if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor,
+                                         outType.getElementType(), {})))
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA PowScalar operation");
+    } else {
+      selfTensor = adaptor.getSelf();
+      auto selfTy = cast<RankedTensorType>(selfTensor.getType());
 
-  if (!isa<mlir::FloatType>(selfTy.getElementType()))
-    return rewriter.notifyMatchFailure(
-        op, "Only floating-point datatype legalization supported");
+      if (!selfTy)
+        return rewriter.notifyMatchFailure(
+            op, "Only ranked tensor types supported in TOSA Pow");
 
-  auto outType =
-      cast<TensorType>(getTypeConverter()->convertType(op.getType()));
+      if (!isa<mlir::FloatType>(selfTy.getElementType()))
+        return rewriter.notifyMatchFailure(
+            op, "Only floating-point datatype legalization supported");
+    }
 
-  Value expTensor;
-  Value expScalar = op.getExponent();
-  if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
-                                     outType.getElementType(), {})))
-    return rewriter.notifyMatchFailure(
-        op, "Currently only scalar constants are supported for "
-            "conversion in TOSA Pow operation");
+    Value expTensor;
+    if constexpr (std::is_same<AtenOpT, AtenPowTensorScalarOp>()) {
+      Value expScalar = op.getExponent();
+      if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
+                                         outType.getElementType(), {})))
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA Pow operation");
+    } else {
+      expTensor = adaptor.getExponent();
+      auto expTy = cast<RankedTensorType>(expTensor.getType());
 
-  auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
-                                                        self, expTensor);
-  rewriter.replaceOp(op, powOp.getResult());
+      if (!expTy)
+        return rewriter.notifyMatchFailure(
+            op, "Only ranked tensor types supported in TOSA Pow");
+    }
 
-  return success();
-}
+    auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
+        rewriter, op, outType, selfTensor, expTensor);
+    rewriter.replaceOp(op, powOp.getResult());
+
+    return success();
+  }
+};
 
 // Perform the basic n-dim matmul operation encompassing the handling of
 // broadcasting and dynamic shape propagation.
@@ -4243,32 +4270,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
-template <>
-LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
-    AtenLeTensorOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
-  // Not a tensor type.
-  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
-  if (!selfType)
-    return rewriter.notifyMatchFailure(
-        op, "Only tensor types input are currently supported");
-  auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
-  if (!otherType)
-    return rewriter.notifyMatchFailure(
-        op, "Only tensor types condition are currently supported");
-
-  auto outType = getTypeConverter()->convertType(op.getType());
-
-  auto greaterOp = rewriter.create<tosa::GreaterOp>(
-      op.getLoc(), outType, adaptor.getSelf(), adaptor.getOther());
-
-  rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, outType,
-                                                  greaterOp.getOutput());
-
-  return success();
-}
-
 template <>
 LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
     AtenIscloseOp op, OpAdaptor adaptor,
@@ -5815,6 +5816,9 @@ public:
     INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
     INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
    INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
+    INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp)
+    INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp)
+    INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp)
 #undef INSERT_UNARY_PATTERN
 
 #define INSERT_BINARY_PATTERN(AtenOp, TosaOp)                                  \
@@ -5823,6 +5827,11 @@ public:
     INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp)
     INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp)
     INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
+    INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
+    INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
+                          tosa::LogicalLeftShiftOp)
+    INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
+                          tosa::ArithmeticRightShiftOp)
 #undef INSERT_BINARY_PATTERN
 
 #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp)                           \
@@ -5843,11 +5852,14 @@ public:
     INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
 #undef INSERT_BINARY_COMPARE_PATTERN
@@ -5987,16 +5999,30 @@ public:
     INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
 #undef INSERT_MASKED_FILL_PATTERN
 
+#define INSERT_POW_OP_PATTERN(AtenOp)                                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
+    INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
+    INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
+    INSERT_POW_OP_PATTERN(AtenPowScalarOp);
+#undef INSERT_POW_OP_PATTERN
+
+#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp)                  \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
+                                                                context);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
+#undef INSERT_ACTIVATION_FUNCTION_OP_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
-    INSERT_ATENOP_PATTERN(AtenTanhOp);
     INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
-    INSERT_ATENOP_PATTERN(AtenSigmoidOp);
     INSERT_ATENOP_PATTERN(AtenReluOp);
     INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
-    INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
     INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
     INSERT_ATENOP_PATTERN(AtenConvolutionOp);
     INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
@@ -6023,7 +6049,6 @@ public:
     INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
     INSERT_ATENOP_PATTERN(AtenAbsOp);
     INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
-    INSERT_ATENOP_PATTERN(AtenLeTensorOp);
     INSERT_ATENOP_PATTERN(AtenClampOp);
     INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
     INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index b45dbda05..bdb4d7f47 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1702,6 +1702,35 @@ TOSA_PASS_SET = {
     "ElementwiseRemainderTensorModule_Int_basic",
     "TriuBroadcastModule_basic",
     "TriuModule_basic",
+    "AtenHannWindowPeriodicFalseModule_basic",
+    "AtenHannWindowPeriodicTrueModule_basic",
+    "ElementwiseAndScalarModule_basic",
+    "ElementwiseAndScalarStaticShapeModule_basic",
+    "ElementwiseAtenLogicalNotOpModule_basic",
+    "ElementwiseAtenLogicalXorOpModule_basic",
+    "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic",
+    "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
+    "ElementwiseBitwiseAndScalarInt32Module_basic",
+    "ElementwiseBitwiseAndScalarInt64Module_basic",
+    "ElementwiseBitwiseLeftShiftInt32Module_basic",
+    "ElementwiseBitwiseLeftShiftInt64Module_basic",
+    "ElementwiseBitwiseLeftShiftInt8Module_basic",
+    "ElementwiseBitwiseRightShiftInt32Module_basic",
+    "ElementwiseBitwiseRightShiftInt64Module_basic",
+    "ElementwiseBitwiseRightShiftInt8Module_basic",
+    "ElementwiseCosModule_basic",
+    "ElementwiseErfModule_basic",
+    "ElementwiseLeFloatIntScalarModule_basic",
+    "ElementwiseLeFloatScalarModule_basic",
+    "ElementwiseLeFloatTensorNanModule_basic",
+    "ElementwiseLeIntScalarModule_basic",
+    "ElementwiseLeMixedIntScalarModule_basic",
+    "ElementwisePowScalarModule_basic",
+    "ElementwisePowTensorBroadcastModule_basic",
+    "ElementwisePowTensorBroadcastStaticModule_basic",
+    "ElementwisePowTensorModule_basic",
+    "ElementwisePowTensorStaticModule_basic",
+    "ElementwiseSinModule_basic",
     "ArgminIntModule_basic",
"ArgminIntModule_multiple_mins", "ArgminModule_basic", @@ -2245,6 +2274,8 @@ MAKE_FX_TOSA_PASS_SET = ( | { ### Tests additionally passing in make_fx_tosa "AdaptiveAvgPool1dStaticLargerOutput_basic", + "ScaledDotProductAttentionBoolMaskModule_basic", + "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", "ArgminIntModule_basic", "ArgminIntModule_multiple_mins", "ArgminModule_basic", @@ -3200,10 +3231,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "MultinomialModule_basic", "RenormModuleFloat16_basic", # REMOVE WHEN ENABLE_GQA IS ADDED - "ScaledDotProductAttentionBoolMaskModule_basic", - "ScaledDotProductAttentionDifferentCausalModule_basic", - "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", - "ScaledDotProductAttentionSameCausalModule_basic", "ScatterAddStaticModule_basic", "TensorsConcatComplex128FloatModule_basic", "TensorsConcatComplex128IntModule_basic", @@ -3254,8 +3281,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "AtenEyeMModuleInt2D_basic", "AtenEyeModuleInt2D_basic", "AtenFloatScalarModule_basic", - "AtenHannWindowPeriodicTrueModule_basic", - "AtenHannWindowPeriodicFalseModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpModule_basic", @@ -3373,8 +3398,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAndScalarModule_basic", - "ElementwiseAndScalarStaticShapeModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAsinModule_basic", "ElementwiseAsinhIntModule_basic", @@ -3392,44 +3415,23 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseAtenLogicalAndOpModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", - "ElementwiseAtenLogicalNotOpModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", - "ElementwiseAtenLogicalXorOpModule_basic", - "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", - "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", - "ElementwiseBitwiseAndScalarInt32Module_basic", - "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", - "ElementwiseBitwiseLeftShiftInt32Module_basic", - "ElementwiseBitwiseLeftShiftInt64Module_basic", - "ElementwiseBitwiseLeftShiftInt8Module_basic", - "ElementwiseBitwiseRightShiftInt32Module_basic", - "ElementwiseBitwiseRightShiftInt64Module_basic", - "ElementwiseBitwiseRightShiftInt8Module_basic", "ElementwiseClampMinTensorFloatModule_basic", "ElementwiseClampMinTensorIntModule_basic", "ElementwiseClampTensorFloatModule_basic", "ElementwiseClampTensorIntModule_basic", "ElementwiseCosIntModule_basic", - "ElementwiseCosModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", - "ElementwiseErfModule_basic", "ElementwiseExpIntModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", "ElementwiseGeluApproximateTanhModule_basic", - "ElementwiseHardshrinkModule_basic", - "ElementwiseHardshrinkStaticModule_basic", "ElementwiseIntTensorLtFloatScalarModule_basic", - "ElementwiseLeFloatIntScalarModule_basic", - "ElementwiseLeFloatScalarModule_basic", - "ElementwiseLeFloatTensorNanModule_basic", - "ElementwiseLeIntScalarModule_basic", - "ElementwiseLeMixedIntScalarModule_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog10Module_basic", 
"ElementwiseLog1pModule_basic", @@ -3440,18 +3442,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseMishModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseMulTensorComplexModule_basic", - "ElementwisePowScalarModule_basic", - "ElementwisePowTensorBroadcastModule_basic", - "ElementwisePowTensorBroadcastStaticModule_basic", - "ElementwisePowTensorModule_basic", - "ElementwisePowTensorStaticModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", "ElementwiseRsqrtIntModule_basic", "ElementwiseSigmoidIntModule_basic", "ElementwiseSinIntModule_basic", - "ElementwiseSinModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", "ElementwiseTanIntModule_basic", @@ -4169,9 +4165,7 @@ ONNX_TOSA_XFAIL_SET = { "ElementwiseAtenLogicalOrOpNegativeModule_basic", "ElementwiseAtenLogicalOrOpRandomFloatModule_basic", "ElementwiseAtenLogicalOrOpRandomModule_basic", - "ElementwiseAtenLogicalXorOpModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", - "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseBitwiseAndModule_basic", "ElementwiseBitwiseLeftShiftInt32Module_basic", "ElementwiseBitwiseLeftShiftInt64Module_basic", @@ -4190,7 +4184,6 @@ ONNX_TOSA_XFAIL_SET = { "ElementwiseClampModule_basic", "ElementwiseClampTensorInt8Module_basic", "ElementwiseCosIntModule_basic", - "ElementwiseCosModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", @@ -4210,7 +4203,6 @@ ONNX_TOSA_XFAIL_SET = { "ElementwiseEqBoolScalarModule_basic", "ElementwiseEqDiffWidthScalarModule_basic", "ElementwiseErfIntModule_basic", - "ElementwiseErfModule_basic", "ElementwiseExpIntModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", @@ -4222,7 +4214,6 @@ ONNX_TOSA_XFAIL_SET = { "ElementwiseGtMixed2ScalarModule_basic", "ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseIsinfModule_basic", - "ElementwiseLeFloatTensorNanModule_basic", "ElementwiseLeMixedIntScalarModule_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog2IntModule_basic", @@ -4237,12 +4228,6 @@ ONNX_TOSA_XFAIL_SET = { "ElementwiseNanToNumModule_Basic", "ElementwiseOrTensorModule_basic", "ElementwiseOrTensorStaticShapeModule_basic", - "ElementwisePowModule_basic", - "ElementwisePowScalarModule_basic", - "ElementwisePowTensorBroadcastModule_basic", - "ElementwisePowTensorBroadcastStaticModule_basic", - "ElementwisePowTensorModule_basic", - "ElementwisePowTensorStaticModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", @@ -4255,7 +4240,6 @@ ONNX_TOSA_XFAIL_SET = { "ElementwiseSgnModule_basic", "ElementwiseSigmoidIntModule_basic", "ElementwiseSinIntModule_basic", - "ElementwiseSinModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", "ElementwiseSqrtIntModule_basic", @@ -4414,8 +4398,6 @@ ONNX_TOSA_XFAIL_SET = { "LinalgNormKeepDimComplexModule_basic", "LinalgNormModule_basic", "LinalgVectorNormComplexModule_basic", - "LinalgVectorNormKeepDimModule_basic", - "LinalgVectorNormModule_basic", "LogSoftmaxBackwardModule_basic", "LogSoftmaxIntModule_basic", "MaskedFillTensorFloatValueModule_basic", @@ -4503,8 +4485,6 @@ ONNX_TOSA_XFAIL_SET = { "NativeGroupNormBackwardModule_basic", "NativeGroupNormModule_basic", "NativeLayerNormDynamicModule_basic", - 
"NativeLayerNormModule4D_basic", - "NativeLayerNormModule_basic", "NeFloatIntModule_basic", "NeIntModule_basic", "NewEmptyStridedModuleDefaultDtype_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 53128a669..4e2920708 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1672,3 +1672,190 @@ func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !tor %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32> return %0 : !torch.vtensor<[2, 4],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.logical_not( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1> +// CHECK: %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<4x5xi1>) -> tensor<4x5xi1> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[4,5],i1> +// CHECK: } +func.func @torch.aten.logical_not(%arg0: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> { + %0 = torch.aten.logical_not %arg0 : !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1> + return %0 : !torch.vtensor<[4,5],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cos( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.cos(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.cos %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.sin( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.sin(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.sin %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.pow.Scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_1]] : (tensor, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// 
CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.pow.Scalar(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { + %float2.000000e00 = torch.constant.float 2.000000e+00 + %0 = torch.aten.pow.Scalar %float2.000000e00, %arg0 : !torch.float, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.pow.Tensor_Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.erf$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.erf %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.erf %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bitwise_and.Scalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> +// CHECK: } +func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.bitwise_and.Scalar %arg0, %int2 : !torch.vtensor<[?,?],si32>, !torch.int -> !torch.vtensor<[?,?],si32> + return %0 : !torch.vtensor<[?,?],si32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.le.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: 
%[[VAL_4:.*]] = tosa.greater_equal %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.le.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.le.Scalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_1]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { + %int2 = torch.constant.int 2 + %0 = torch.aten.le.Scalar %arg0, %int2 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.logical_xor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.logical_xor %[[VAL_3]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.logical_xor$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.logical_left_shift %[[VAL_3]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> +// CHECK: } +func.func @torch.aten.bitwise_left_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { + %0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32> + return %0: 
!torch.vtensor<[?,?],si32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bitwise_right_shift.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.arithmetic_right_shift %[[VAL_3]], %[[VAL_2]] {round = false} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> +// CHECK: } +func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { + %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32> + return %0: !torch.vtensor<[?,?],si32> +}