mirror of https://github.com/llvm/torch-mlir
[TOSA] Add legalization for aten.as_strided (#3848)
- Add Torch to TOSA legalization for aten.as_strided op
- Update xfail_sets with the following:
+ New aten.as_strided results
+ Changes from this commit:
7f9f99c6f8
+ Failed tests from new PyTorch version update
- Add new LIT test to basic.mlir
Change-Id: I6f471ea116ca47f2bf9537b62950fce75a2c624f
Signed-off-by: Justin Ngo <justin.ngo@arm.com>
pull/3853/head
parent
6aa46967b6
commit
4c1518d365
|
@ -6778,6 +6778,108 @@ LogicalResult ConvertAtenOp<AtenThresholdBackwardOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.as_strided
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
|
||||||
|
AtenAsStridedOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
// To lower aten.as_strided to TOSA, we will first reshape the input tensor to
|
||||||
|
// an 1-D tensor, then calculate the indices of result elements based on the
|
||||||
|
// output size, stride and storage offset. With the reshaped 1-D tensor and
|
||||||
|
// the indices, we can apply Gather to extract the required elements into a
|
||||||
|
// new tensor and then reshape it back to the desired output shape.
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
// Not a tensor type
|
||||||
|
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||||
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
auto selfElemTy = selfType.getElementType();
|
||||||
|
auto selfShape = selfType.getShape();
|
||||||
|
|
||||||
|
auto resultType =
|
||||||
|
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
auto resultElemTy = resultType.getElementType();
|
||||||
|
|
||||||
|
// Get output size
|
||||||
|
SmallVector<int64_t> outputSize;
|
||||||
|
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outputSize)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only a constant list form of output size is supported");
|
||||||
|
|
||||||
|
// Get stride
|
||||||
|
SmallVector<int64_t> stride;
|
||||||
|
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only a constant list form of stride is supported");
|
||||||
|
|
||||||
|
// Get storage offset
|
||||||
|
int64_t offset;
|
||||||
|
if (!matchPattern(op.getStorageOffset(), m_TorchConstantInt(&offset)))
|
||||||
|
offset = 0;
|
||||||
|
|
||||||
|
// Reshape input tensor into an 1-D tensor
|
||||||
|
int64_t selfNumElems = std::accumulate(selfShape.begin(), selfShape.end(), 1,
|
||||||
|
std::multiplies<int64_t>());
|
||||||
|
|
||||||
|
auto self1D = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(), RankedTensorType::get({selfNumElems}, selfElemTy), self,
|
||||||
|
rewriter.getDenseI64ArrayAttr({selfNumElems}));
|
||||||
|
|
||||||
|
// Calculate the target elements indices
|
||||||
|
SmallVector<int32_t> targetIndicesVec;
|
||||||
|
int64_t outputRank = outputSize.size();
|
||||||
|
int64_t outputNumElems = std::accumulate(outputSize.begin(), outputSize.end(),
|
||||||
|
1, std::multiplies<int64_t>());
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < outputNumElems; i++) {
|
||||||
|
// Index formula:
|
||||||
|
// index[i] = coord_i_0 * stride[0] + coord_i_1 * stride[1] + ... +
|
||||||
|
// coord_i_n * stride[n]
|
||||||
|
int32_t index = offset;
|
||||||
|
int64_t coordFinder = i;
|
||||||
|
for (int64_t dim = 0; dim < outputRank; dim++) {
|
||||||
|
int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1];
|
||||||
|
index += indexCoord * stride[outputRank - dim - 1];
|
||||||
|
coordFinder /= outputSize[outputRank - dim - 1];
|
||||||
|
}
|
||||||
|
targetIndicesVec.push_back(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto targetIndices =
|
||||||
|
tosa::getConstTensor<int32_t>(rewriter, op, targetIndicesVec,
|
||||||
|
makeShapeTorchCompatible({outputNumElems}))
|
||||||
|
.value();
|
||||||
|
|
||||||
|
// Convert PyTorch-style indices and dim into TensorFlow-style indices
|
||||||
|
auto targetIndicesTf = tosa::convertTorchIndexToTfIndices(
|
||||||
|
rewriter, op, self1D.getResult(), targetIndices, 0);
|
||||||
|
if (!targetIndicesTf)
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Convert PyTorch-style indices and dim "
|
||||||
|
"to TensorFlow-style indices failed");
|
||||||
|
|
||||||
|
// Gather the target elements from 1-D input tensor
|
||||||
|
// Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve the
|
||||||
|
// target elements
|
||||||
|
auto gatherOp = tosa::convertGatherNdOp(
|
||||||
|
rewriter, op,
|
||||||
|
RankedTensorType::get(makeShapeTorchCompatible({outputNumElems}),
|
||||||
|
resultElemTy),
|
||||||
|
self1D.getResult(), targetIndicesTf.value());
|
||||||
|
|
||||||
|
if (!gatherOp)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
|
||||||
|
|
||||||
|
auto result = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(), resultType, gatherOp.value(),
|
||||||
|
rewriter.getDenseI64ArrayAttr(outputSize));
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {result.getResult()});
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -7096,6 +7198,7 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
|
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
|
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||||
|
|
|
@ -1723,6 +1723,9 @@ TOSA_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_CRASHING_SET = {
|
FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
|
"Aten_TrilinearModuleSumAllDims_basic",
|
||||||
|
"Aten_TrilinearModuleSumdims_basic",
|
||||||
|
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
|
||||||
"GridSamplerBasic1_basic",
|
"GridSamplerBasic1_basic",
|
||||||
"GridSamplerBasic2_basic",
|
"GridSamplerBasic2_basic",
|
||||||
"GridSamplerBasic3_basic",
|
"GridSamplerBasic3_basic",
|
||||||
|
@ -1744,6 +1747,13 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"Aten_TrilinearModuleSumAllDims_basic",
|
||||||
|
"Aten_TrilinearModuleSumdims_basic",
|
||||||
|
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
|
||||||
|
"Aten_TrilinearModuleVaryingRanks_basic",
|
||||||
|
"Aten_TrilinearModule_basic",
|
||||||
|
"ElementwiseAddBoolModule_basic",
|
||||||
|
"Exp2StaticModule_basic",
|
||||||
"CosineSimilarityStaticBroadcastModule_basic",
|
"CosineSimilarityStaticBroadcastModule_basic",
|
||||||
"DropoutTrainStaticShapeModule_basic",
|
"DropoutTrainStaticShapeModule_basic",
|
||||||
"ElementwiseAtenLogicalAndOpModule_basic",
|
"ElementwiseAtenLogicalAndOpModule_basic",
|
||||||
|
@ -1937,10 +1947,6 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseTruncIntModule_basic",
|
"ElementwiseTruncIntModule_basic",
|
||||||
"ElementwiseSgnModule_basic",
|
"ElementwiseSgnModule_basic",
|
||||||
"ElementwiseSignIntModule_basic",
|
"ElementwiseSignIntModule_basic",
|
||||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
|
||||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
|
||||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
|
||||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
|
||||||
"AddCDivModule_basic",
|
"AddCDivModule_basic",
|
||||||
"AddCDiv_Module_basic",
|
"AddCDiv_Module_basic",
|
||||||
"AddCMulModule_basic",
|
"AddCMulModule_basic",
|
||||||
|
@ -2292,7 +2298,6 @@ TOSA_PASS_SET = {
|
||||||
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
"RreluWithNoiseBackwardTrainStaticModule_basic",
|
||||||
"RepeatModule_basic",
|
"RepeatModule_basic",
|
||||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||||
"ResNet18StaticModule_basic",
|
|
||||||
"ReshapeAliasCollapseModule_basic",
|
"ReshapeAliasCollapseModule_basic",
|
||||||
"ReshapeAliasExpandModule_basic",
|
"ReshapeAliasExpandModule_basic",
|
||||||
"ReshapeAsModule_basic",
|
"ReshapeAsModule_basic",
|
||||||
|
@ -2416,6 +2421,9 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
TOSA_PASS_SET
|
TOSA_PASS_SET
|
||||||
| {
|
| {
|
||||||
### Tests additionally passing in make_fx_tosa
|
### Tests additionally passing in make_fx_tosa
|
||||||
|
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||||
|
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||||
|
"ResNet18StaticModule_basic",
|
||||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||||
|
@ -2464,9 +2472,7 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
"MaxPool1dStaticCeilModeTrueModule_basic",
|
||||||
"MaxPool1dStaticModule_basic",
|
"MaxPool1dStaticModule_basic",
|
||||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
|
||||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
|
||||||
"CosineSimilarityModule_basic",
|
"CosineSimilarityModule_basic",
|
||||||
"NativeGroupNormBackwardModule_basic",
|
"NativeGroupNormBackwardModule_basic",
|
||||||
"ReduceFrobeniusNormKeepDimModule_basic",
|
"ReduceFrobeniusNormKeepDimModule_basic",
|
||||||
|
@ -2474,8 +2480,6 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
"SliceWholeTensorModule_basic",
|
"SliceWholeTensorModule_basic",
|
||||||
"TensorFloatModule_basic",
|
"TensorFloatModule_basic",
|
||||||
"TensorIntModule_basic",
|
"TensorIntModule_basic",
|
||||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
|
||||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
|
||||||
"RepeatInterleaveSelfIntModule_basic",
|
"RepeatInterleaveSelfIntModule_basic",
|
||||||
"TorchPrimLoopForLikeTensorArgModule_basic",
|
"TorchPrimLoopForLikeTensorArgModule_basic",
|
||||||
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
|
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
|
||||||
|
@ -2492,13 +2496,6 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
}
|
}
|
||||||
) - {
|
) - {
|
||||||
### Test failing in make_fx_tosa but not in tosa
|
### Test failing in make_fx_tosa but not in tosa
|
||||||
"ChunkListUnpackUneven_Module_basic",
|
|
||||||
"ChunkListUnpack_Module_basic",
|
|
||||||
"SplitTensorGetItem_Module_basic",
|
|
||||||
"SplitTensorLastSmallerModule_basic",
|
|
||||||
"SplitTensorListUnpackModule_basic",
|
|
||||||
"SplitTensorNegativeDimModule_basic",
|
|
||||||
"SplitWithSizesListUnpackModule_basic",
|
|
||||||
# Dynamic shape, has extra unsupported broadcast ops
|
# Dynamic shape, has extra unsupported broadcast ops
|
||||||
"Matmul_3d",
|
"Matmul_3d",
|
||||||
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
||||||
|
@ -3367,6 +3364,8 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
|
"Aten_TrilinearModuleVaryingRanks_basic",
|
||||||
|
"Aten_TrilinearModuleZerodDimBug_basic",
|
||||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||||
"ElementwiseRreluWithNoiseTrainModule_basic",
|
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||||
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||||
|
@ -3387,16 +3386,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"SliceOutOfUpperBoundIndexModule_basic",
|
"SliceOutOfUpperBoundIndexModule_basic",
|
||||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||||
"SliceStartEqEndModule_basic",
|
"SliceStartEqEndModule_basic",
|
||||||
"ChunkListUnpackDynamic_Module_basic",
|
|
||||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
|
||||||
"ChunkListUnpackUneven_Module_basic",
|
|
||||||
"ChunkListUnpack_Module_basic",
|
|
||||||
"SplitTensorGetItem_Module_basic",
|
|
||||||
"SplitTensorLastSmallerModule_basic",
|
|
||||||
"SplitTensorListUnpackModule_basic",
|
|
||||||
"SplitTensorNegativeDimModule_basic",
|
|
||||||
"SplitWithSizesListUnpackModule_basic",
|
|
||||||
"SplitWithSizes_Module_basic",
|
|
||||||
"ElementwiseCreateComplexModule_basic",
|
"ElementwiseCreateComplexModule_basic",
|
||||||
"AtenPolarDoubleModule_basic",
|
"AtenPolarDoubleModule_basic",
|
||||||
"AtenPolarFloatModule_basic",
|
"AtenPolarFloatModule_basic",
|
||||||
|
@ -3572,7 +3561,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseAcosModule_basic",
|
"ElementwiseAcosModule_basic",
|
||||||
"ElementwiseAcoshIntModule_basic",
|
"ElementwiseAcoshIntModule_basic",
|
||||||
"ElementwiseAcoshModule_basic",
|
"ElementwiseAcoshModule_basic",
|
||||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
|
||||||
"ElementwiseAsinIntModule_basic",
|
"ElementwiseAsinIntModule_basic",
|
||||||
"ElementwiseAsinModule_basic",
|
"ElementwiseAsinModule_basic",
|
||||||
"ElementwiseAsinhIntModule_basic",
|
"ElementwiseAsinhIntModule_basic",
|
||||||
|
@ -3608,7 +3596,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseLog1pModule_basic",
|
"ElementwiseLog1pModule_basic",
|
||||||
"ElementwiseLog2IntModule_basic",
|
"ElementwiseLog2IntModule_basic",
|
||||||
"ElementwiseLogIntModule_basic",
|
"ElementwiseLogIntModule_basic",
|
||||||
"ElementwiseLogSigmoidModule_basic",
|
|
||||||
"ElementwiseLogitModule_basic",
|
"ElementwiseLogitModule_basic",
|
||||||
"ElementwiseMishModule_basic",
|
"ElementwiseMishModule_basic",
|
||||||
"ElementwiseMulTensorComplexDiffModule_basic",
|
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||||
|
@ -3731,8 +3718,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"NormScalarModule_basic",
|
"NormScalarModule_basic",
|
||||||
"NormScalarOptDimKeepDimComplexModule_basic",
|
"NormScalarOptDimKeepDimComplexModule_basic",
|
||||||
"NormalFunctionalModule_basic",
|
"NormalFunctionalModule_basic",
|
||||||
"NumToTensorFloatModule_basic",
|
|
||||||
"NumToTensorIntModule_basic",
|
|
||||||
"NumelModule_basic",
|
"NumelModule_basic",
|
||||||
"NumelZeroRankModule_basic",
|
"NumelZeroRankModule_basic",
|
||||||
"OnesLikeModule_falsePinMemory",
|
"OnesLikeModule_falsePinMemory",
|
||||||
|
@ -3790,7 +3775,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ReplicationPad2dModule_right0",
|
"ReplicationPad2dModule_right0",
|
||||||
"ReplicationPad2dModule_top0",
|
"ReplicationPad2dModule_top0",
|
||||||
"RollModule_basic",
|
"RollModule_basic",
|
||||||
"RsubInt0d_NumToTensor_Module_basic",
|
|
||||||
"RsubIntModule_noalpha_basic",
|
"RsubIntModule_noalpha_basic",
|
||||||
"ScalarConstantTupleModule_basic",
|
"ScalarConstantTupleModule_basic",
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
|
@ -3873,6 +3857,55 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
"VisionTransformerModule_basic",
|
"VisionTransformerModule_basic",
|
||||||
"ZerosLikeModule_falsePinMemory",
|
"ZerosLikeModule_falsePinMemory",
|
||||||
|
# count_include_pad and divisor_override check in TOSA AvgPool2d
|
||||||
|
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||||
|
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||||
|
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
|
||||||
|
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||||
|
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||||
|
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
|
||||||
|
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||||
|
"ResNet18Module_basic",
|
||||||
|
"ResNet18StaticModule_basic",
|
||||||
|
"MobilenetV3Module_basic",
|
||||||
|
# Unexpected failures due to new PyTorch version update
|
||||||
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
|
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||||
|
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
||||||
|
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||||
|
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||||
|
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||||
|
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
||||||
|
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||||
|
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
||||||
|
"AdaptiveAvgPool2dDynamic_basic",
|
||||||
|
"CrossEntropyLossModule_basic",
|
||||||
|
"CrossEntropyLossNoReductionModule_basic",
|
||||||
|
"ElementwiseRreluTrainModule_basic",
|
||||||
|
"ElementwiseRreluTrainStaticModule_basic",
|
||||||
|
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||||
|
"IndexPutImpl2DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||||
|
"IouOfModule_basic",
|
||||||
|
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||||
|
"MaxPool1dStaticCeilModeTrueModule_basic",
|
||||||
|
"MaxPool1dStaticModule_basic",
|
||||||
|
"MeshgridIndexingIJ_basic",
|
||||||
|
"MeshgridIndexingXY_basic",
|
||||||
|
"Meshgrid_basic",
|
||||||
|
"OneHotModule_basic",
|
||||||
|
"ReduceFrobeniusNormKeepDimModule_basic",
|
||||||
|
"ReduceFrobeniusNormModule_basic",
|
||||||
|
"RepeatInterleaveSelfIntModule_basic",
|
||||||
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
"ScaledDotProductAttentionMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_CRASHING_SET = {
|
ONNX_TOSA_CRASHING_SET = {
|
||||||
|
@ -3885,6 +3918,7 @@ ONNX_TOSA_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_XFAIL_SET = {
|
ONNX_TOSA_XFAIL_SET = {
|
||||||
|
"Exp2StaticModule_basic",
|
||||||
"ElementwiseRreluWithNoiseEvalModule_basic",
|
"ElementwiseRreluWithNoiseEvalModule_basic",
|
||||||
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||||
"ElementwiseRreluWithNoiseTrainModule_basic",
|
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||||
|
|
|
@ -2275,3 +2275,41 @@ func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch
|
||||||
%0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64>
|
%0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64>
|
||||||
return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
|
return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.as_strided$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 25>} : (tensor<5x5xf32>) -> tensor<25xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 2, 3, 4, 4, 5, 6]> : tensor<9xi32>}> : () -> tensor<9xi32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 9, 1>} : (tensor<9xi32>) -> tensor<9x1xi32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 25, 1>} : (tensor<25xf32>) -> tensor<1x25x1xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 9, 1>} : (tensor<9x1xi32>) -> tensor<9x1xi32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 9>} : (tensor<9x1xi32>) -> tensor<1x9xi32>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 9>} : (tensor<1x9x1xf32>) -> tensor<9xf32>
|
||||||
|
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 3, 3>} : (tensor<9xf32>) -> tensor<3x3xf32>
|
||||||
|
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
|
||||||
|
// CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.none -> !torch.vtensor<[3,3],f32>
|
||||||
|
return %2 : !torch.vtensor<[3,3],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue