diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 48c38b077..10f6ecb35 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -6778,6 +6778,108 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for aten.as_strided +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenAsStridedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // To lower aten.as_strided to TOSA, we will first reshape the input tensor to + // an 1-D tensor, then calculate the indices of result elements based on the + // output size, stride and storage offset. With the reshaped 1-D tensor and + // the indices, we can apply Gather to extract the required elements into a + // new tensor and then reshape it back to the desired output shape. + auto self = adaptor.getSelf(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + auto selfElemTy = selfType.getElementType(); + auto selfShape = selfType.getShape(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + // Get output size + SmallVector outputSize; + if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outputSize))) + return rewriter.notifyMatchFailure( + op, "Only a constant list form of output size is supported"); + + // Get stride + SmallVector stride; + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride))) + return rewriter.notifyMatchFailure( + op, "Only a constant list form of stride is supported"); + + // Get storage offset + int64_t offset; + if (!matchPattern(op.getStorageOffset(), m_TorchConstantInt(&offset))) + offset = 0; + + // Reshape input tensor into an 1-D tensor + int64_t selfNumElems = std::accumulate(selfShape.begin(), selfShape.end(), 1, + std::multiplies()); + + auto self1D = rewriter.create( + op->getLoc(), RankedTensorType::get({selfNumElems}, selfElemTy), self, + rewriter.getDenseI64ArrayAttr({selfNumElems})); + + // Calculate the target elements indices + SmallVector targetIndicesVec; + int64_t outputRank = outputSize.size(); + int64_t outputNumElems = std::accumulate(outputSize.begin(), outputSize.end(), + 1, std::multiplies()); + + for (int64_t i = 0; i < outputNumElems; i++) { + // Index formula: + // index[i] = coord_i_0 * stride[0] + coord_i_1 * stride[1] + ... + + // coord_i_n * stride[n] + int32_t index = offset; + int64_t coordFinder = i; + for (int64_t dim = 0; dim < outputRank; dim++) { + int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1]; + index += indexCoord * stride[outputRank - dim - 1]; + coordFinder /= outputSize[outputRank - dim - 1]; + } + targetIndicesVec.push_back(index); + } + + auto targetIndices = + tosa::getConstTensor(rewriter, op, targetIndicesVec, + makeShapeTorchCompatible({outputNumElems})) + .value(); + + // Convert PyTorch-style indices and dim into TensorFlow-style indices + auto targetIndicesTf = tosa::convertTorchIndexToTfIndices( + rewriter, op, self1D.getResult(), targetIndices, 0); + if (!targetIndicesTf) + return rewriter.notifyMatchFailure(op, + "Convert PyTorch-style indices and dim " + "to TensorFlow-style indices failed"); + + // Gather the target elements from 1-D input tensor + // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve the + // target elements + auto gatherOp = tosa::convertGatherNdOp( + rewriter, op, + RankedTensorType::get(makeShapeTorchCompatible({outputNumElems}), + resultElemTy), + self1D.getResult(), targetIndicesTf.value()); + + if (!gatherOp) + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + + auto result = rewriter.create( + op->getLoc(), resultType, gatherOp.value(), + rewriter.getDenseI64ArrayAttr(outputSize)); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -7096,6 +7198,7 @@ public: INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); INSERT_ATENOP_PATTERN(AtenUniformOp); INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); + INSERT_ATENOP_PATTERN(AtenAsStridedOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 90479cf7f..377154586 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1723,6 +1723,9 @@ TOSA_CRASHING_SET = { } FX_IMPORTER_TOSA_CRASHING_SET = { + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", "GridSamplerBasic1_basic", "GridSamplerBasic2_basic", "GridSamplerBasic3_basic", @@ -1744,6 +1747,13 @@ FX_IMPORTER_TOSA_CRASHING_SET = { # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModule_basic", + "ElementwiseAddBoolModule_basic", + "Exp2StaticModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "DropoutTrainStaticShapeModule_basic", "ElementwiseAtenLogicalAndOpModule_basic", @@ -1937,10 +1947,6 @@ TOSA_PASS_SET = { "ElementwiseTruncIntModule_basic", "ElementwiseSgnModule_basic", "ElementwiseSignIntModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", - "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", "AddCDivModule_basic", "AddCDiv_Module_basic", "AddCMulModule_basic", @@ -2292,7 +2298,6 @@ TOSA_PASS_SET = { "RreluWithNoiseBackwardTrainStaticModule_basic", "RepeatModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", - "ResNet18StaticModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeAsModule_basic", @@ -2416,6 +2421,9 @@ MAKE_FX_TOSA_PASS_SET = ( TOSA_PASS_SET | { ### Tests additionally passing in make_fx_tosa + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "ResNet18StaticModule_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", "ScaledDotProductAttentionBoolMaskModule_basic", "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", @@ -2464,9 +2472,7 @@ MAKE_FX_TOSA_PASS_SET = ( "MaxPool1dEmptyStrideStaticModule_basic", "MaxPool1dStaticCeilModeTrueModule_basic", "MaxPool1dStaticModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dStaticEvenMultiple_basic", "CosineSimilarityModule_basic", "NativeGroupNormBackwardModule_basic", "ReduceFrobeniusNormKeepDimModule_basic", @@ -2474,8 +2480,6 @@ MAKE_FX_TOSA_PASS_SET = ( "SliceWholeTensorModule_basic", "TensorFloatModule_basic", "TensorIntModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "RepeatInterleaveSelfIntModule_basic", "TorchPrimLoopForLikeTensorArgModule_basic", "ViewSizeDimFollowedByCollapsedOnesModule_basic", @@ -2492,13 +2496,6 @@ MAKE_FX_TOSA_PASS_SET = ( } ) - { ### Test failing in make_fx_tosa but not in tosa - "ChunkListUnpackUneven_Module_basic", - "ChunkListUnpack_Module_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorLastSmallerModule_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitWithSizesListUnpackModule_basic", # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", # Unimplemented operator 'aten._index_put_impl_.hacked_twin' @@ -3367,6 +3364,8 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | { } FX_IMPORTER_TOSA_XFAIL_SET = { + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleZerodDimBug_basic", "AdaptiveMaxPool1dDimOneStatic_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", @@ -3387,16 +3386,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "SliceOutOfUpperBoundIndexModule_basic", "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceStartEqEndModule_basic", - "ChunkListUnpackDynamic_Module_basic", - "ChunkListUnpackUnevenDynamic_Module_basic", - "ChunkListUnpackUneven_Module_basic", - "ChunkListUnpack_Module_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorLastSmallerModule_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitWithSizesListUnpackModule_basic", - "SplitWithSizes_Module_basic", "ElementwiseCreateComplexModule_basic", "AtenPolarDoubleModule_basic", "AtenPolarFloatModule_basic", @@ -3572,7 +3561,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseAcosModule_basic", "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAsinModule_basic", "ElementwiseAsinhIntModule_basic", @@ -3608,7 +3596,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseLog1pModule_basic", "ElementwiseLog2IntModule_basic", "ElementwiseLogIntModule_basic", - "ElementwiseLogSigmoidModule_basic", "ElementwiseLogitModule_basic", "ElementwiseMishModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", @@ -3731,8 +3718,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "NormScalarModule_basic", "NormScalarOptDimKeepDimComplexModule_basic", "NormalFunctionalModule_basic", - "NumToTensorFloatModule_basic", - "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", "OnesLikeModule_falsePinMemory", @@ -3790,7 +3775,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ReplicationPad2dModule_right0", "ReplicationPad2dModule_top0", "RollModule_basic", - "RsubInt0d_NumToTensor_Module_basic", "RsubIntModule_noalpha_basic", "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", @@ -3873,6 +3857,55 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ViewSizeFromOtherTensor_basic", "VisionTransformerModule_basic", "ZerosLikeModule_falsePinMemory", + # count_include_pad and divisor_override check in TOSA AvgPool2d + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", + "MobilenetV3Module_basic", + # Unexpected failures due to new PyTorch version update + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IouOfModule_basic", + "MaxPool1dEmptyStrideStaticModule_basic", + "MaxPool1dStaticCeilModeTrueModule_basic", + "MaxPool1dStaticModule_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "Meshgrid_basic", + "OneHotModule_basic", + "ReduceFrobeniusNormKeepDimModule_basic", + "ReduceFrobeniusNormModule_basic", + "RepeatInterleaveSelfIntModule_basic", + "ScaledDotProductAttentionBoolMaskModule_basic", + "ScaledDotProductAttentionDifferentCausalModule_basic", + "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + "ScaledDotProductAttentionMaskModule_basic", + "ScaledDotProductAttentionSameCausalModule_basic", + "ScaledDotProductAttentionSameDynamicModule_basic", + "ScaledDotProductAttentionSameModule_basic", } ONNX_TOSA_CRASHING_SET = { @@ -3885,6 +3918,7 @@ ONNX_TOSA_CRASHING_SET = { } ONNX_TOSA_XFAIL_SET = { + "Exp2StaticModule_basic", "ElementwiseRreluWithNoiseEvalModule_basic", "ElementwiseRreluWithNoiseEvalStaticModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 2cf2486e7..ed679e852 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2275,3 +2275,41 @@ func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch %0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64> return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.as_strided$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<5x5xf32>) -> tensor<25xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 2, 3, 4, 4, 5, 6]> : tensor<9xi32>}> : () -> tensor<9xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<9xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<25xf32>) -> tensor<1x25x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<9x1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<9x1xi32>) -> tensor<1x9xi32> +// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x9x1xf32>) -> tensor<9xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<9xf32>) -> tensor<3x3xf32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32> +// CHECK: } +func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> { + %none = torch.constant.none + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list, !torch.list, !torch.none -> !torch.vtensor<[3,3],f32> + return %2 : !torch.vtensor<[3,3],f32> +}