[TOSA] Add legalization for aten.as_strided (#3848)

- Add Torch to TOSA legalization for aten.as_strided op
- Update xfail_sets with the following:
  + New aten.as_strided results
  + Changes from this commit: 7f9f99c6f8
  + Failed tests from new PyTorch version update
- Add new LIT test to basic.mlir


Change-Id: I6f471ea116ca47f2bf9537b62950fce75a2c624f

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
pull/3853/head
Justin Ngo 2024-11-04 09:57:59 -08:00 committed by GitHub
parent 6aa46967b6
commit 4c1518d365
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 206 additions and 31 deletions

View File

@ -6778,6 +6778,108 @@ LogicalResult ConvertAtenOp<AtenThresholdBackwardOp>::matchAndRewrite(
return success();
}
// Legalization for aten.as_strided
//
// Lowers aten.as_strided to TOSA by:
//   1. reshaping the input tensor to a 1-D tensor,
//   2. computing, at compile time, the flat index of every output element
//      from the constant output size, stride and storage offset,
//   3. gathering those elements from the 1-D tensor, and
//   4. reshaping the gathered values to the converted result type.
//
// The rewrite reports a match failure (and leaves the op untouched) when:
//   - the input or converted result is not a tensor type,
//   - the input shape is not fully static,
//   - size or stride is not a constant list of ints, their lengths differ,
//     or a size is negative,
//   - the storage offset is neither `none` nor a constant int,
//   - a computed index falls outside the flattened input or does not fit in
//     the 32-bit index type used for the gather indices,
//   - any of the TOSA helper conversions fails.
template <>
LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
    AtenAsStridedOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.getSelf();

  // Not a tensor type
  auto selfType = dyn_cast<TensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

  // The flattened element count below is only meaningful for static shapes;
  // a dynamic dimension would otherwise be multiplied in as a sentinel value.
  if (!selfType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "Only statically shaped input tensors are supported");

  auto selfElemTy = selfType.getElementType();
  auto selfShape = selfType.getShape();

  auto resultType =
      dyn_cast_or_null<TensorType>(typeConverter->convertType(op.getType()));
  if (!resultType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor result types are supported");
  auto resultElemTy = resultType.getElementType();

  // Get output size
  SmallVector<int64_t> outputSize;
  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outputSize)))
    return rewriter.notifyMatchFailure(
        op, "Only a constant list form of output size is supported");

  // Get stride
  SmallVector<int64_t> stride;
  if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))
    return rewriter.notifyMatchFailure(
        op, "Only a constant list form of stride is supported");

  // stride is indexed with output dimensions below, so the two lists must
  // have the same length.
  if (stride.size() != outputSize.size())
    return rewriter.notifyMatchFailure(
        op, "Output size and stride must have the same number of elements");

  for (int64_t dimSize : outputSize) {
    if (dimSize < 0)
      return rewriter.notifyMatchFailure(
          op, "Negative output sizes are not supported");
  }

  // Get storage offset. A `none` offset means 0; any other offset has to be a
  // compile-time constant, otherwise the indices cannot be precomputed.
  int64_t offset = 0;
  if (!isa<Torch::NoneType>(op.getStorageOffset().getType()) &&
      !matchPattern(op.getStorageOffset(), m_TorchConstantInt(&offset)))
    return rewriter.notifyMatchFailure(
        op, "Only a constant or none storage offset is supported");

  // Reshape input tensor into a 1-D tensor.
  // The initial value must be int64_t: std::accumulate accumulates in the
  // type of its init argument, so a plain `1` would truncate to int.
  int64_t selfNumElems =
      std::accumulate(selfShape.begin(), selfShape.end(), int64_t{1},
                      std::multiplies<int64_t>());

  auto self1D = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(), RankedTensorType::get({selfNumElems}, selfElemTy), self,
      rewriter.getDenseI64ArrayAttr({selfNumElems}));

  // Calculate the target elements indices
  int64_t outputRank = outputSize.size();
  int64_t outputNumElems =
      std::accumulate(outputSize.begin(), outputSize.end(), int64_t{1},
                      std::multiplies<int64_t>());

  SmallVector<int32_t> targetIndicesVec;
  targetIndicesVec.reserve(outputNumElems);

  for (int64_t i = 0; i < outputNumElems; i++) {
    // Index formula:
    //   index[i] = offset + coord_i_0 * stride[0] + coord_i_1 * stride[1] +
    //              ... + coord_i_n * stride[n]
    // The coordinates are recovered from the row-major linear position i,
    // innermost dimension first.
    int64_t index = offset;
    int64_t coordFinder = i;
    for (int64_t dim = outputRank - 1; dim >= 0; dim--) {
      int64_t indexCoord = coordFinder % outputSize[dim];
      index += indexCoord * stride[dim];
      coordFinder /= outputSize[dim];
    }

    // Every gathered element must exist in the flattened input.
    if (index < 0 || index >= selfNumElems)
      return rewriter.notifyMatchFailure(
          op, "Computed element index is out of bounds of the input tensor");

    // The gather indices are materialized as i32; reject values that would
    // be changed by the narrowing conversion.
    if (static_cast<int64_t>(static_cast<int32_t>(index)) != index)
      return rewriter.notifyMatchFailure(
          op, "Computed element index does not fit in a 32-bit integer");

    targetIndicesVec.push_back(static_cast<int32_t>(index));
  }

  auto targetIndices =
      tosa::getConstTensor<int32_t>(rewriter, op, targetIndicesVec,
                                    makeShapeTorchCompatible({outputNumElems}));
  if (!targetIndices)
    return rewriter.notifyMatchFailure(
        op, "Failed to create constant tensor for target indices");

  // Convert PyTorch-style indices and dim into TensorFlow-style indices
  auto targetIndicesTf = tosa::convertTorchIndexToTfIndices(
      rewriter, op, self1D.getResult(), targetIndices.value(), 0);
  if (!targetIndicesTf)
    return rewriter.notifyMatchFailure(op,
                                       "Convert PyTorch-style indices and dim "
                                       "to TensorFlow-style indices failed");

  // Gather the target elements from 1-D input tensor
  // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve the
  // target elements
  auto gatherOp = tosa::convertGatherNdOp(
      rewriter, op,
      RankedTensorType::get(makeShapeTorchCompatible({outputNumElems}),
                            resultElemTy),
      self1D.getResult(), targetIndicesTf.value());
  if (!gatherOp)
    return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");

  // Reshape the flat gathered values to the requested output shape.
  auto result = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(), resultType, gatherOp.value(),
      rewriter.getDenseI64ArrayAttr(outputSize));

  rewriter.replaceOp(op, {result.getResult()});

  return success();
}
} // namespace
// -----------------------------------------------------------------------------
@ -7096,6 +7198,7 @@ public:
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

View File

@ -1723,6 +1723,9 @@ TOSA_CRASHING_SET = {
}
FX_IMPORTER_TOSA_CRASHING_SET = {
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"GridSamplerBasic1_basic",
"GridSamplerBasic2_basic",
"GridSamplerBasic3_basic",
@ -1744,6 +1747,13 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModule_basic",
"ElementwiseAddBoolModule_basic",
"Exp2StaticModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"DropoutTrainStaticShapeModule_basic",
"ElementwiseAtenLogicalAndOpModule_basic",
@ -1937,10 +1947,6 @@ TOSA_PASS_SET = {
"ElementwiseTruncIntModule_basic",
"ElementwiseSgnModule_basic",
"ElementwiseSignIntModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AddCDivModule_basic",
"AddCDiv_Module_basic",
"AddCMulModule_basic",
@ -2292,7 +2298,6 @@ TOSA_PASS_SET = {
"RreluWithNoiseBackwardTrainStaticModule_basic",
"RepeatModule_basic",
"RepeatInterleaveSelfIntNoDimModule_basic",
"ResNet18StaticModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeAsModule_basic",
@ -2416,6 +2421,9 @@ MAKE_FX_TOSA_PASS_SET = (
TOSA_PASS_SET
| {
### Tests additionally passing in make_fx_tosa
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ResNet18StaticModule_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
@ -2464,9 +2472,7 @@ MAKE_FX_TOSA_PASS_SET = (
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"CosineSimilarityModule_basic",
"NativeGroupNormBackwardModule_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
@ -2474,8 +2480,6 @@ MAKE_FX_TOSA_PASS_SET = (
"SliceWholeTensorModule_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"RepeatInterleaveSelfIntModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
@ -2492,13 +2496,6 @@ MAKE_FX_TOSA_PASS_SET = (
}
) - {
### Test failing in make_fx_tosa but not in tosa
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
@ -3367,6 +3364,8 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
}
FX_IMPORTER_TOSA_XFAIL_SET = {
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
@ -3387,16 +3386,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"SliceOutOfUpperBoundIndexModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceStartEqEndModule_basic",
"ChunkListUnpackDynamic_Module_basic",
"ChunkListUnpackUnevenDynamic_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"ElementwiseCreateComplexModule_basic",
"AtenPolarDoubleModule_basic",
"AtenPolarFloatModule_basic",
@ -3572,7 +3561,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseAcosModule_basic",
"ElementwiseAcoshIntModule_basic",
"ElementwiseAcoshModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAsinIntModule_basic",
"ElementwiseAsinModule_basic",
"ElementwiseAsinhIntModule_basic",
@ -3608,7 +3596,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseLog1pModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseLogSigmoidModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMishModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
@ -3731,8 +3718,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"NormScalarModule_basic",
"NormScalarOptDimKeepDimComplexModule_basic",
"NormalFunctionalModule_basic",
"NumToTensorFloatModule_basic",
"NumToTensorIntModule_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_falsePinMemory",
@ -3790,7 +3775,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ReplicationPad2dModule_right0",
"ReplicationPad2dModule_top0",
"RollModule_basic",
"RsubInt0d_NumToTensor_Module_basic",
"RsubIntModule_noalpha_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
@ -3873,6 +3857,55 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic",
"ZerosLikeModule_falsePinMemory",
# count_include_pad and divisor_override check in TOSA AvgPool2d
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
"MobilenetV3Module_basic",
# Unexpected failures due to new PyTorch version update
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dDynamicNoBatch_basic",
"AdaptiveAvgPool2dDynamic_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
"ElementwiseRreluTrainModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IouOfModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
"OneHotModule_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceFrobeniusNormModule_basic",
"RepeatInterleaveSelfIntModule_basic",
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
}
ONNX_TOSA_CRASHING_SET = {
@ -3885,6 +3918,7 @@ ONNX_TOSA_CRASHING_SET = {
}
ONNX_TOSA_XFAIL_SET = {
"Exp2StaticModule_basic",
"ElementwiseRreluWithNoiseEvalModule_basic",
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",

View File

@ -2275,3 +2275,41 @@ func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch
%0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64>
return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
}
// -----
// Checks the TOSA legalization of torch.aten.as_strided on a 5x5 f32 input
// with output size [3, 3], stride [2, 1] and a `none` storage offset.
// The input is flattened to 25 elements and the expected flat gather indices
// are row * 2 + col * 1, i.e. [0, 1, 2, 2, 3, 4, 4, 5, 6]; the 9 gathered
// values are then reshaped to 3x3.
// CHECK-LABEL: func.func @torch.aten.as_strided$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 25>} : (tensor<5x5xf32>) -> tensor<25xf32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 2, 3, 4, 4, 5, 6]> : tensor<9xi32>}> : () -> tensor<9xi32>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 9, 1>} : (tensor<9xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 25, 1>} : (tensor<25xf32>) -> tensor<1x25x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 9, 1>} : (tensor<9x1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 9>} : (tensor<9x1xi32>) -> tensor<1x9xi32>
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 9>} : (tensor<1x9x1xf32>) -> tensor<9xf32>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 3, 3>} : (tensor<9xf32>) -> tensor<3x3xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32>
// CHECK: }
// Test input: %0 is the output size list [3, 3], %1 is the stride list
// [2, 1], and %none is passed as the storage offset.
func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
%none = torch.constant.none
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.none -> !torch.vtensor<[3,3],f32>
return %2 : !torch.vtensor<[3,3],f32>
}