diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 776721814..e5f4fea4f 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4360,6 +4360,221 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.scatter.src
+template <>
+LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
+    AtenScatterSrcOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto input = adaptor.getSelf();
+  auto inputType = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputType)
+    return rewriter.notifyMatchFailure(
+        op, "Only RankedTensorType inputs are currently supported");
+
+  auto inputShape = inputType.getShape();
+  auto paramsRank = inputType.getRank();
+
+  auto index = adaptor.getIndex();
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
+  if (!indexType)
+    return rewriter.notifyMatchFailure(
+        op, "Only RankedTensorType indices are currently supported");
+
+  // Check that `index` and `input` have the same rank
+  if (indexType.getRank() != paramsRank)
+    return rewriter.notifyMatchFailure(
+        op, "Params index and input should have the same rank");
+
+  auto indexShape = indexType.getShape();
+
+  auto src = adaptor.getSrc();
+  auto srcType = dyn_cast<RankedTensorType>(src.getType());
+  if (!srcType)
+    return rewriter.notifyMatchFailure(
+        op, "Only RankedTensorType sources are currently supported");
+
+  // Check that `src` and `input` have the same rank
+  if (srcType.getRank() != paramsRank)
+    return rewriter.notifyMatchFailure(
+        op, "Src and input should have the same rank");
+
+  auto srcShape = srcType.getShape();
+
+  // Dynamic shape check
+  if (!inputType.hasStaticShape() || !indexType.hasStaticShape() ||
+      !srcType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Support for dynamic shape not implemented");
+
+  // Cast index from i64 to i32 for TOSA compatibility
+  if (indexType.getElementType() != rewriter.getIntegerType(32)) {
+    index = rewriter.create<tosa::CastOp>(
+        op->getLoc(),
+        RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
+  }
+
+  // Get positive dim
+  int64_t dim{0};
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Dim value should be a constant int");
+
+  dim = toPositiveDim(dim, paramsRank);
+  if (!isValidDim(dim, paramsRank))
+    return rewriter.notifyMatchFailure(op, "Dim is invalid");
+
+  // It is also required that index.size(d) <= src.size(d) for all dimensions d,
+  // and that index.size(d) <= self.size(d) for all dimensions d != dim
+  for (int64_t d = 0; d < paramsRank; d++) {
+    if (d != dim) {
+      if (indexShape[d] > srcShape[d] || indexShape[d] > inputShape[d])
+        return rewriter.notifyMatchFailure(
+            op, "Index size should be smaller or equal to src or input size "
+                "for all dimensions d != dim");
+    }
+  }
+
+  // Get the output type
+  auto outType = getTypeConverter()->convertType(op.getType());
+
+  // Convert PyTorch-style index and dim into TensorFlow-style indices
+  // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
+  auto indicesTf =
+      tosa::convertTorchIndexToTfIndices(rewriter, op, input, index, dim);
+  if (!indicesTf)
+    return rewriter.notifyMatchFailure(
+        op, "Convert PyTorch index and dim to TensorFlow indices failed");
+
+  // Perform the TensorFlow ScatterNd algorithm with TensorFlow-style indices as
+  // input.
+  auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
+                                         indicesTf.value(), src);
+
+  if (!result)
+    return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
+
+  rewriter.replaceOp(op, {result.value()});
+  return success();
+}
+
+// Legalization for aten.slice_scatter
+template <>
+LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
+    AtenSliceScatterOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto input = adaptor.getSelf();
+  auto inputType = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputType)
+    return rewriter.notifyMatchFailure(
+        op, "Only RankedTensorType inputs are currently supported");
+
+  auto inputShape = inputType.getShape();
+  auto paramsRank = inputType.getRank();
+
+  auto src = adaptor.getSrc();
+  auto srcType = dyn_cast<RankedTensorType>(src.getType());
+  if (!srcType)
+    return rewriter.notifyMatchFailure(
+        op, "Only RankedTensorType sources are currently supported");
+
+  // Check that `src` and `input` have the same rank
+  if (srcType.getRank() != paramsRank)
+    return rewriter.notifyMatchFailure(
+        op, "Src and input should have the same rank");
+
+  auto srcShape = srcType.getShape();
+
+  // Dynamic shape check
+  if (!inputType.hasStaticShape() || !srcType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Support for dynamic shape not implemented");
+
+  // Get positive dim
+  int64_t dim{0};
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Dim value should be a constant int");
+
+  dim = toPositiveDim(dim, paramsRank);
+  if (!isValidDim(dim, paramsRank))
+    return rewriter.notifyMatchFailure(op, "Dim is invalid");
+
+  // Get start, end, and step params
+  // If start and end params are not specified, assign them to 0 and
+  // inputShape[dim], respectively.
+  int64_t start{0};
+  if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Start value should be a constant int");
+  if (start < 0)
+    start += inputShape[dim];
+
+  int64_t end{inputShape[dim]};
+  if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
+    return rewriter.notifyMatchFailure(op,
+                                       "End value should be a constant int");
+  if (end < 0)
+    end += inputShape[dim];
+
+  if (end > inputShape[dim])
+    end = inputShape[dim];
+
+  if (start >= end)
+    return rewriter.notifyMatchFailure(
+        op, "Start value greater than end value not supported");
+
+  int64_t step{1};
+  if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Step value should be a constant int");
+
+  // Create PyTorch-style scatter index based on start, end, and step values
+  int64_t outerRepeat{1}, innerRepeat{1};
+  for (int64_t i = 0; i < dim; i++)
+    outerRepeat *= srcShape[i];
+
+  for (int64_t i = dim + 1; i < paramsRank; i++)
+    innerRepeat *= srcShape[i];
+
+  SmallVector<int32_t> indexVec;
+  for (int64_t i = 0; i < outerRepeat; i++) {
+    for (int32_t indexVal = start; indexVal < end; indexVal += step) {
+      for (int64_t j = 0; j < innerRepeat; j++) {
+        indexVec.push_back(indexVal);
+      }
+    }
+  }
+
+  Value index =
+      tosa::getConstTensor<int32_t>(rewriter, op, indexVec, srcShape).value();
+
+  // Get the output type
+  auto outType = getTypeConverter()->convertType(op.getType());
+
+  // Convert PyTorch-style index and dim into TensorFlow-style indices
+  // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
+  auto indicesTf =
+      tosa::convertTorchIndexToTfIndices(rewriter, op, input, index, dim);
+  if (!indicesTf)
+    return rewriter.notifyMatchFailure(
+        op, "Convert PyTorch index and dim to TensorFlow indices failed");
+
+  // Perform the TensorFlow ScatterNd algorithm with TensorFlow-style indices as
+  // input.
+  auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
+                                         indicesTf.value(), src);
+
+  if (!result)
+    return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
+
+  rewriter.replaceOp(op, {result.value()});
+  return success();
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenAbsOp>::matchAndRewrite(
     AtenAbsOp op, OpAdaptor adaptor,
@@ -6099,6 +6314,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     dim2 = toPositiveDim(dim2, selfRank);
   }
 
+  if (dim1 == dim2)
+    return rewriter.notifyMatchFailure(op,
+                                       "Values dim1 and dim2 cannot be equal");
+
   auto selfShape = makeShapeTorchCompatible(selfType.getShape());
   int64_t h = selfShape[dim1];
   int64_t w = selfShape[dim2];
@@ -6122,13 +6341,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   SmallVector<int32_t> transposedDims;
   transposedInputShape.clear();
 
-  for (int64_t i = 0; i < selfRank; ++i) {
+  for (int32_t i = 0; i < selfRank; ++i) {
     if (i == dim1 || i == dim2)
       continue;
     transposedDims.push_back(i);
   }
-  transposedDims.push_back(dim1);
-  transposedDims.push_back(dim2);
+  transposedDims.push_back(static_cast<int32_t>(dim1));
+  transposedDims.push_back(static_cast<int32_t>(dim2));
 
   auto transposedDimsConst = tosa::getConstTensor<int32_t>(
       rewriter, op,
@@ -6213,6 +6432,193 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.diag_embed
+template <>
+LogicalResult ConvertAtenOp<AtenDiagEmbedOp>::matchAndRewrite(
+    AtenDiagEmbedOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // To perform diag_embed, we will apply scatter with a newly created diagonal
+  // index tensor over a constant zero tensor.
+  // To make it simpler, we will only scatter using the diagonal with respect
+  // to the two innermost dimensions, then permute the output tensor to the
+  // correct order of dimensions.
+  auto self = adaptor.getSelf();
+
+  // Not a ranked tensor type
+  auto selfType = dyn_cast<RankedTensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types are supported");
+
+  auto selfRank = selfType.getRank();
+  int64_t outRank = selfRank + 1;
+
+  auto selfShape = makeShapeTorchCompatible(selfType.getShape());
+  int64_t diagSize = selfShape[selfRank - 1];
+
+  if (!selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Currently only static shapes are supported");
+
+  const TypeConverter *typeConverter = this->getTypeConverter();
+  RankedTensorType resultType = cast<RankedTensorType>(
+      typeConverter->convertType(op->getResult(0).getType()));
+  if (!resultType)
+    return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
+
+  auto selfElemTy = selfType.getElementType();
+  auto resultElemTy = resultType.getElementType();
+
+  int64_t offset{0};
+  if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Offset value should be a constant int");
+
+  // dim1 default is -2
+  int64_t dim1{outRank - 2};
+  if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Dim1 value should be a constant int");
+  dim1 = toPositiveDim(dim1, outRank);
+
+  // dim2 default is -1
+  int64_t dim2{outRank - 1};
+  if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2)))
+    return rewriter.notifyMatchFailure(op,
+                                       "Dim2 value should be a constant int");
+  dim2 = toPositiveDim(dim2, outRank);
+
+  if (dim1 == dim2)
+    return rewriter.notifyMatchFailure(op, "Dim1 and dim2 cannot be equal");
+
+  // If offset is smaller than 0, we will swap dim1 and dim2 and convert offset
+  // to a positive value
+  if (offset < 0) {
+    std::swap(dim1, dim2);
+    offset = std::abs(offset);
+  }
+
+  // Create the diagonal index tensor
+  int64_t repeat = 1;
+  for (int64_t i = 0; i < selfRank - 1; i++)
+    repeat *= selfShape[i];
+
+  SmallVector<int32_t> indexVec;
+  for (int32_t i = 0; i < repeat; i++) {
+    for (int32_t j = offset; j < diagSize + offset; j++)
+      indexVec.push_back(j);
+  }
+
+  SmallVector<int64_t> indexShape = llvm::to_vector(selfShape);
+  indexShape.push_back(1);
+
+  auto index = tosa::getConstTensor<int32_t>(rewriter, op,
+                                             /*vec=*/indexVec,
+                                             /*shape=*/indexShape)
+                   .value();
+
+  // Reshape the input tensor to be the same shape as the new index tensor to
+  // act as the src for scattering
+  auto scatterSrc = rewriter.create<tosa::ReshapeOp>(
+      op->getLoc(),
+      RankedTensorType::get(makeShapeTorchCompatible(indexShape), selfElemTy),
+      self, rewriter.getDenseI64ArrayAttr(indexShape));
+
+  // Create a const zero tensor to scatter the input onto
+  SmallVector<int64_t> zeroShape;
+  for (int64_t i = 0; i < selfRank - 1; i++)
+    zeroShape.push_back(selfShape[i]);
+  zeroShape.push_back(diagSize + offset);
+  zeroShape.push_back(diagSize + offset);
+
+  int64_t numElemOfZeroTensor = 1;
+  for (int64_t &d : zeroShape)
+    numElemOfZeroTensor *= d;
+
+  Value zero =
+      TypeSwitch<Type, Value>(selfElemTy)
+          .Case<mlir::FloatType>([&](auto) {
+            return tosa::getConstTensor<float>(
+                       rewriter, op, SmallVector<float>(numElemOfZeroTensor, 0),
+                       zeroShape)
+                .value();
+          })
+          .Case<mlir::IntegerType>([&](auto intType) {
+            switch (intType.getWidth()) {
+            case 1:
+              return tosa::getConstTensor<bool>(
+                         rewriter, op,
+                         SmallVector<bool>(numElemOfZeroTensor, 0), zeroShape)
+                  .value();
+            case 32:
+              return tosa::getConstTensor<int32_t>(
+                         rewriter, op,
+                         SmallVector<int32_t>(numElemOfZeroTensor, 0),
+                         zeroShape)
+                  .value();
+            case 64:
+              return tosa::getConstTensor<int64_t>(
+                         rewriter, op,
+                         SmallVector<int64_t>(numElemOfZeroTensor, 0),
+                         zeroShape)
+                  .value();
+            }
+            llvm_unreachable("Invalid integer width");
+          });
+
+  // Convert PyTorch index and dim to TensorFlow-style indices
+  auto indicesTf = tosa::convertTorchIndexToTfIndices(rewriter, op, zero, index,
+                                                      outRank - 1);
+  if (!indicesTf)
+    return rewriter.notifyMatchFailure(
+        op, "Convert PyTorch index and dim to TensorFlow indices failed");
+
+  // Perform the TensorFlow ScatterNd algorithm with TensorFlow-style indices as
+  // input
+  auto diagonalTensor = tosa::convertScatterNdOp(
+      rewriter, op,
+      RankedTensorType::get(makeShapeTorchCompatible(zeroShape), resultElemTy),
+      zero, indicesTf.value(), scatterSrc.getResult());
+  if (!diagonalTensor)
+    return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
+
+  // Create the final dims order to permute the scattered tensor
+  SmallVector<int32_t> permutedDims(outRank, 0);
+  int32_t currentDim = 0;
+  int32_t i = 0;
+
+  while (i < outRank) {
+    if (i == dim1) {
+      permutedDims[i] = outRank - 2;
+      i++;
+      continue;
+    }
+
+    if (i == dim2) {
+      permutedDims[i] = outRank - 1;
+      i++;
+      continue;
+    }
+
+    permutedDims[i] = currentDim;
+    currentDim++;
+    i++;
+  }
+
+  auto permutedDimsConst =
+      tosa::getConstTensor<int32_t>(rewriter, op,
+                                    /*vec=*/permutedDims,
+                                    /*shape=*/{static_cast<int32_t>(outRank)});
+
+  auto result = rewriter.create<tosa::TransposeOp>(op->getLoc(), resultType,
+                                                   diagonalTensor.value(),
+                                                   permutedDimsConst.value());
+
+  rewriter.replaceOp(op, result.getResult());
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -6442,6 +6848,7 @@ public:
                                                              context);
     INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
     INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
+    INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0);
 #undef INSERT_CONSTANT_FILL_PATTERN
 
 #define INSERT_FILL_PATTERN(AtenOp)                                            \
@@ -6524,6 +6931,9 @@ public:
     INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
     INSERT_ATENOP_PATTERN(AtenFlipOp);
     INSERT_ATENOP_PATTERN(AtenRoundOp);
+    INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
+    INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
+    INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 326a7afe8..e7512fc89 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1650,12 +1650,18 @@ STABLEHLO_CRASHING_SET = {
 }
 
 TOSA_CRASHING_SET = {
+    "ArangeStartOutDtypeModule_basic",
+    "ArangeStartOutModule_basic",
+    "ScatterSrcStaticModule_basic",
     # Runtime op verification: Out of bounds access
     "IndexTensorNegativeIndexModule_basic",
     "ReduceAllDimEmpty_basic",
 }
 
 FX_IMPORTER_TOSA_CRASHING_SET = {
+    "ScatterSrcModule_basic",
+    "ScatterSrcStaticModule_basic",
+    "HBC_basic",
     "IndexTensorNegativeIndexModule_basic",
     "InterpolateDynamicModule_scales_recompute_bilinear",
     "InterpolateDynamicModule_sizes_bilinear",
@@ -1671,6 +1677,26 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
 
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
TOSA_PASS_SET = { + "EmptyModule_contiguous", + "EmptyModule_defaultDtype", + "EmptyModule_falsePinMemory", + "EmptyModule_float", + "EmptyModule_int", + "EmptyModule_uint8", + "EmptyStridedModule_basic", + "NewEmptyModuleDefaultDtype_basic", + "NewEmptyModuleFalsePinMemory_basic", + "NewEmptyModuleFloat2D_basic", + "NewEmptyModuleFloat3D_basic", + "NewEmptyModuleInt2D_basic", + "NewEmptyModuleInt3D_basic", + "NewEmptyModuleLayoutIntDtype_basic", + "NewEmptyModuleNonDefaultFloatDtype_basic", + "NewEmptyModuleNonDefaultIntDtype_basic", + "NewEmptyStridedModuleDefaultDtype_basic", + "SelectScattertStaticModule_basic", + "SliceScatterStaticModule_basic", + "TensorAlloc1dStaticModule_basic", "AtenRoundFloatHalfToEvenModule_basic", "AtenRoundFloatModule_basic", "FakeQuantizePerTensorAffineCachemaskModule_basic", @@ -3248,6 +3274,12 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | { } FX_IMPORTER_TOSA_XFAIL_SET = { + "ViewDtypeStaticModule_basic", + "Unfold_Module_Dynamic_basic", + "Unfold_Module_Rank_4", + "Unfold_Module_Rank_Zero_Size_Zero_basic", + "Unfold_Module_Rank_Zero_basic", + "Unfold_Module_basic", "ArangeZeroElementOutputModule_basic", "NumpyTRank0Module_basic", "Permute0RankModule_basic", @@ -3338,12 +3370,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "AtenComplexImagModule_basic", "AtenComplexRealModule_basic", "AtenComplexViewModule_basic", - "AtenDiagEmbedDefaultDiag_basic", - "AtenDiagEmbedDimDiag_basic", - "AtenDiagEmbedNegOffsetDiag_basic", - "AtenDiagEmbedNonDefault4DDiag_basic", - "AtenDiagEmbedOffsetDiag_basic", - "AtenDiagEmbedRevDimDiag_basic", "AtenEmbeddingBagStaticModule_basic", "AtenEmbeddingBagSumExample_basic", "AtenEyeMModuleInt2D_basic", @@ -3513,31 +3539,13 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic", - "EmptyLikeMemoryFormatModule_basic", - "EmptyLikeModule_defaultDtype", - "EmptyLikeModule_falsePinMemory", - "EmptyLikeModule_float", - "EmptyLikeModule_int", - "EmptyModule_contiguous", - "EmptyModule_defaultDtype", - "EmptyModule_falsePinMemory", - "EmptyModule_float", - "EmptyModule_int", - "EmptyModule_uint8", - "EmptyStridedModule_basic", - "EmptyStridedSizeIntStrideModule_basic", "EqIntModule_basic", "ExpandModule_basic", "ExponentialModule_basic", "FloatImplicitModule_basic", "FullLikeModuleInt2D_basic", "FullLikeModuleInt3D_basic", - "FullModuleDefaultDtype_basic", - "FullModuleFalsePinMemory_basic", - "FullModuleFloat2D_basic", - "FullModuleFloat3D_basic", "FullModuleInt2D_basic", - "FullModuleInt3D_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GeIntModule_basic", @@ -3547,7 +3555,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "GridSamplerBasic4_basic", "GtFloatIntModule_basic", "GtIntModule_basic", - "HBC_basic", "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic", @@ -3599,7 +3606,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "LinalgVectorNormComplexModule_basic", "LinspaceDtypeModule_basic", "LinspaceEmptyModule_basic", - "LinspaceOneSizeModule_basic", "MaskedFillTensorFloatValueModule_basic", "MatmulBroadcastBatchDim_basic", "MatmulStaticBroadcast_basic", @@ -3653,16 +3659,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "NativeGroupNormBackwardModule_basic", "NeFloatIntModule_basic", "NeIntModule_basic", - "NewEmptyModuleDefaultDtype_basic", - "NewEmptyModuleFalsePinMemory_basic", - "NewEmptyModuleFloat2D_basic", - "NewEmptyModuleFloat3D_basic", - "NewEmptyModuleInt2D_basic", - 
"NewEmptyModuleInt3D_basic", - "NewEmptyModuleLayoutIntDtype_basic", - "NewEmptyModuleNonDefaultFloatDtype_basic", - "NewEmptyModuleNonDefaultIntDtype_basic", - "NewEmptyStridedModuleDefaultDtype_basic", "NewFullModuleInt2D_basic", "NewFullModuleInt3D_basic", "NllLossModuleBackward1DMeanWeight_basic", @@ -3671,13 +3667,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "NllLossModuleBackward1DSum_basic", "NllLossModuleBackward1DWeight_basic", "NllLossModuleBackward1D_basic", - "NllLossModuleBackwardMeanWeight_basic", - "NllLossModuleBackwardMean_basic", - "NllLossModuleBackwardSumWeight_basic", - "NllLossModuleBackwardSum_basic", - "NllLossModuleBackwardWeight_basic", - "NllLossModuleBackward_basic", - "NllLossModuleBackward_ignore_index", "NormScalarComplexModule_basic", "NormScalarModule_basic", "NormScalarOptDimKeepDimComplexModule_basic", @@ -3777,26 +3766,14 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "ScatterSrcStaticModule_basic", "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", - "SelectScattertModule_basic", - "SelectScattertStaticModule_basic", "SignAndLogarithmOfDeterminantModule_F32", "SignAndLogarithmOfDeterminantBatchedModule_F32", "SignAndLogarithmOfDeterminantDynamicModule_F32", "SliceStaticComplexInputModule_basic", - "SliceCopyEndGreaterThanDimSize_Module_basic", - "SliceCopyNegative_Module_basic", - "SliceCopyNonZeroDim_Module_basic", "SliceCopyStartGreaterThanDimSize_Module_basic", - "SliceCopy_Module_basic", "SliceEndSleStartModule_basic", "SliceOutOfLowerBoundEndIndexModule_basic", "SliceOutOfLowerBoundStartIndexModule_basic", - "SliceScatterModule_basic", - "SliceScatterNegativeDimModule_basic", - "SliceScatterNegativeEndModule_basic", - "SliceScatterStaticModule_basic", - "SliceScatterStepVariationModule_basic", - "SliceScatterZeroDimModule_basic", "SliceSizeTwoStepModule_basic", "SoftplusModule_basic", "SortIntListReverse_basic", @@ -3864,6 +3841,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = { } ONNX_TOSA_CRASHING_SET = { + "ScatterSrcStaticModule_basic", "StdCorrectionEmptyDimModule_basic", "StdDimEmptyDimModule_basic", "VarCorrectionEmptyDimModule_basic", @@ -3872,6 +3850,11 @@ ONNX_TOSA_CRASHING_SET = { } ONNX_TOSA_XFAIL_SET = { + "Unfold_Module_Dynamic_basic", + "Unfold_Module_Rank_4", + "Unfold_Module_Rank_Zero_Size_Zero_basic", + "Unfold_Module_Rank_Zero_basic", + "ViewDtypeStaticModule_basic", "ArangeZeroElementOutputModule_basic", "LinspaceEmptyModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index e569fed7f..e412bb390 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1998,3 +1998,136 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso %0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> { +// CHECK: %[[VAL_0:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_1:.*]] = torch.constant.bool false +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = torch.constant.device "cpu" +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32> +// CHECK: 
%[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<3x4xi32>) -> tensor<3x4xi64> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi64>}> : () -> tensor<3x4xi64> +// CHECK: %[[VAL_10:.*]] = tosa.cast %[[VAL_9]] : (tensor<3x4xi64>) -> tensor<3x4xi64> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],si64> +// CHECK: } +func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> { + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %none = torch.constant.none + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list + %cpu = torch.constant.device "cpu" + %1 = torch.aten.empty.memory_format %0, %int4, %none, %cpu, %false, %none : !torch.list, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[3,4],si64> + %2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[3,4],si64>, !torch.int -> !torch.vtensor<[3,4],si64> + return %2 : !torch.vtensor<[3,4],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.scatter.src$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,8,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4,3],si64>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[3,4,3],f32> -> tensor<3x4x3xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4,3],si64> -> tensor<2x4x3xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,8,6],f32> -> tensor<10x8x6xf32> +// CHECK: %[[VAL_6:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_4]] : (tensor<2x4x3xi64>) -> tensor<2x4x3xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<2x4x3xi32>) -> tensor<2x4x3x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]]], {{\[\[}}[1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]], {{\[\[}}[0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_8]], %[[VAL_10]] {axis = 3 : i32} : (tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>) -> tensor<2x4x3x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3x4x3xf32>) -> tensor<1x36x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<10x8x6xf32>) -> tensor<1x480x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<2x4x3x3xi32>) -> tensor<24x3xi32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[48, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<24x3xi32>, tensor<3xi32>) -> tensor<24x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<24x3xi32>) -> 
tensor<24x1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_19:.*]] = tosa.scatter %[[VAL_13]], %[[VAL_18]], %[[VAL_12]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32>) -> tensor<1x480x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x480x1xf32>) -> tensor<10x8x6xf32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<10x8x6xf32> -> !torch.vtensor<[10,8,6],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[10,8,6],f32> +// CHECK: } +func.func @torch.aten.scatter.src$basic(%arg0: !torch.vtensor<[10,8,6],f32>, %arg1: !torch.vtensor<[2,4,3],si64>, %arg2: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.scatter.src %arg0, %int1, %arg1, %arg2 : !torch.vtensor<[10,8,6],f32>, !torch.int, !torch.vtensor<[2,4,3],si64>, !torch.vtensor<[3,4,3],f32> -> !torch.vtensor<[10,8,6],f32> + return %0 : !torch.vtensor<[10,8,6],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.slice_scatter$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,8],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[6,1],f32> -> tensor<6x1xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,8],f32> -> tensor<6x8xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<6x1xi32>}> : () -> tensor<6x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<6x1xi32>) -> tensor<6x1x1xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]], {{\[\[}}4]], {{\[\[}}5]]]> : tensor<6x1x1xi32>}> : () -> tensor<6x1x1xi32> +// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_8]], %[[VAL_7]] {axis = 2 : i32} : (tensor<6x1x1xi32>, tensor<6x1x1xi32>) -> tensor<6x1x2xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<6x8xf32>) -> tensor<1x48x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<6x1x2xi32>) -> tensor<6x2xi32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[8, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<6x2xi32>, tensor<2xi32>) -> tensor<6x2xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<6x2xi32>) -> tensor<6x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<6x1xi32>) -> tensor<1x6xi32> +// CHECK: %[[VAL_17:.*]] = tosa.scatter %[[VAL_11]], %[[VAL_16]], %[[VAL_10]] : (tensor<1x48x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x48x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<1x48x1xf32>) -> tensor<6x8xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<6x8xf32> -> !torch.vtensor<[6,8],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[6,8],f32> +// CHECK: } +func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg1: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> { + 
%int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.slice_scatter %arg0, %arg1, %int1, %int0, %int1, %int1 : !torch.vtensor<[6,8],f32>, !torch.vtensor<[6,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[6,8],f32> + return %0 : !torch.vtensor<[6,8],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.diag_embed$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int -2 +// CHECK: %[[VAL_4:.*]] = torch.constant.int -1 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]], {{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]]> : tensor<2x3x4x1xi32>}> : () -> tensor<2x3x4x1xi32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<2x3x4xf32>) -> tensor<2x3x4x1xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3x4x4xf32>}> : () -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2x3x4x1xi32>) -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]]], {{\[\[}}{{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_11]], %[[VAL_8]] {axis = 4 : i32} : (tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>) -> tensor<2x3x4x1x4xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<2x3x4x1xf32>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<2x3x4x4xf32>) -> tensor<1x96x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<2x3x4x1x4xi32>) -> tensor<24x4xi32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[48, 16, 4, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_15]], %[[VAL_16]] {shift = 0 : i8} : (tensor<24x4xi32>, 
tensor<4xi32>) -> tensor<24x4xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<24x4xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_20:.*]] = tosa.scatter %[[VAL_14]], %[[VAL_19]], %[[VAL_13]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x96x1xf32>) -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_21]], %[[VAL_22]] : (tensor<2x3x4x4xf32>, tensor<4xi32>) -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32> +// CHECK: return %[[VAL_24]] : !torch.vtensor<[2,3,4,4],f32> +// CHECK: } +func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> { + %int0 = torch.constant.int 0 + %int-2 = torch.constant.int -2 + %int-1 = torch.constant.int -1 + %0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32> + return %0 : !torch.vtensor<[2,3,4,4],f32> +}
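For reference, a minimal PyTorch-level sketch of the three ops this patch legalizes, using the same shapes as the new lit tests above (illustrative only; the tensor values and variable names are assumptions, not taken from the patch):

import torch

# aten.scatter.src: write values from `src` into `inp` along `dim` at the
# positions given by `idx` (shapes mirror torch.aten.scatter.src$basic).
inp = torch.zeros(10, 8, 6)
idx = torch.randint(0, 8, (2, 4, 3))
src = torch.rand(3, 4, 3)
out = torch.scatter(inp, 1, idx, src)  # -> torch.Size([10, 8, 6])

# aten.slice_scatter: embed `update` into the slice base[:, 0:1:1]
# (mirrors torch.aten.slice_scatter$basic).
base = torch.rand(6, 8)
update = torch.rand(6, 1)
out = torch.slice_scatter(base, update, dim=1, start=0, end=1, step=1)  # -> torch.Size([6, 8])

# aten.diag_embed: place the innermost dimension on the diagonal of the two
# new innermost dimensions (mirrors torch.aten.diag_embed$basic).
x = torch.rand(2, 3, 4)
out = torch.diag_embed(x, offset=0, dim1=-2, dim2=-1)  # -> torch.Size([2, 3, 4, 4])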
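The slice_scatter legalization builds its PyTorch-style scatter index from the constant start/end/step operands using the outerRepeat/innerRepeat loops added above; a small Python sketch of that construction (the function name is illustrative, not part of the patch):

def build_slice_scatter_index(src_shape, dim, start, end, step):
    # Product of the dimensions before and after `dim`, as in the C++ loops.
    outer = 1
    for d in range(dim):
        outer *= src_shape[d]
    inner = 1
    for d in range(dim + 1, len(src_shape)):
        inner *= src_shape[d]
    index = []
    for _ in range(outer):
        for v in range(start, end, step):
            index.extend([v] * inner)
    return index  # laid out in src_shape order before conversion to TF-style indices

# For the 6x8 lit test above (src 6x1, dim=1, start=0, end=1, step=1) this
# yields [0, 0, 0, 0, 0, 0], matching the dense<0> : tensor<6x1xi32> constant.
print(build_slice_scatter_index((6, 1), 1, 0, 1, 1))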