mirror of https://github.com/llvm/torch-mlir

[TOSA] Add legalization for aten.as_strided (#3848)

- Add Torch to TOSA legalization for aten.as_strided op
- Update xfail_sets with the following:
  + New aten.as_strided results
  + Changes from this commit: 7f9f99c6f8
  + Failed tests from new PyTorch version update
- Add new LIT test to basic.mlir

Change-Id: I6f471ea116ca47f2bf9537b62950fce75a2c624f
Signed-off-by: Justin Ngo <justin.ngo@arm.com>

Branch: pull/3853/head
Parent: 6aa46967b6
Commit: 4c1518d365
@@ -6778,6 +6778,108 @@ LogicalResult ConvertAtenOp<AtenThresholdBackwardOp>::matchAndRewrite(
  return success();
}

// Legalization for aten.as_strided
template <>
LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
    AtenAsStridedOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // To lower aten.as_strided to TOSA, we first reshape the input tensor to a
  // 1-D tensor, then calculate the indices of the result elements based on
  // the output size, stride, and storage offset. With the reshaped 1-D tensor
  // and the indices, we can apply Gather to extract the required elements
  // into a new tensor and then reshape it back to the desired output shape.
  auto self = adaptor.getSelf();

  // Not a tensor type
  auto selfType = dyn_cast<TensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
  auto selfElemTy = selfType.getElementType();
  auto selfShape = selfType.getShape();

  auto resultType =
      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
  auto resultElemTy = resultType.getElementType();

  // Get output size
  SmallVector<int64_t> outputSize;
  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outputSize)))
    return rewriter.notifyMatchFailure(
        op, "Only a constant list form of output size is supported");

  // Get stride
  SmallVector<int64_t> stride;
  if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))
    return rewriter.notifyMatchFailure(
        op, "Only a constant list form of stride is supported");

  // Get storage offset
  int64_t offset;
  if (!matchPattern(op.getStorageOffset(), m_TorchConstantInt(&offset)))
    offset = 0;

  // Reshape input tensor into a 1-D tensor
  int64_t selfNumElems = std::accumulate(selfShape.begin(), selfShape.end(), 1,
                                         std::multiplies<int64_t>());

  auto self1D = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(), RankedTensorType::get({selfNumElems}, selfElemTy), self,
      rewriter.getDenseI64ArrayAttr({selfNumElems}));

  // Calculate the target element indices
  SmallVector<int32_t> targetIndicesVec;
  int64_t outputRank = outputSize.size();
  int64_t outputNumElems = std::accumulate(
      outputSize.begin(), outputSize.end(), 1, std::multiplies<int64_t>());

  for (int64_t i = 0; i < outputNumElems; i++) {
    // Index formula:
    // index[i] = offset + coord_i_0 * stride[0] + coord_i_1 * stride[1]
    //            + ... + coord_i_n * stride[n]
    int32_t index = offset;
    int64_t coordFinder = i;
    for (int64_t dim = 0; dim < outputRank; dim++) {
      int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1];
      index += indexCoord * stride[outputRank - dim - 1];
      coordFinder /= outputSize[outputRank - dim - 1];
    }
    targetIndicesVec.push_back(index);
  }

  auto targetIndices =
      tosa::getConstTensor<int32_t>(rewriter, op, targetIndicesVec,
                                    makeShapeTorchCompatible({outputNumElems}))
          .value();

  // Convert PyTorch-style indices and dim into TensorFlow-style indices
  auto targetIndicesTf = tosa::convertTorchIndexToTfIndices(
      rewriter, op, self1D.getResult(), targetIndices, 0);
  if (!targetIndicesTf)
    return rewriter.notifyMatchFailure(op,
                                       "Convert PyTorch-style indices and dim "
                                       "to TensorFlow-style indices failed");

  // Gather the target elements from the 1-D input tensor: apply a TensorFlow
  // GatherNdOp with the TensorFlow-style indices to retrieve the target
  // elements
  auto gatherOp = tosa::convertGatherNdOp(
      rewriter, op,
      RankedTensorType::get(makeShapeTorchCompatible({outputNumElems}),
                            resultElemTy),
      self1D.getResult(), targetIndicesTf.value());

  if (!gatherOp)
    return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");

  auto result = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(), resultType, gatherOp.value(),
      rewriter.getDenseI64ArrayAttr(outputSize));

  rewriter.replaceOp(op, {result.getResult()});

  return success();
}

} // namespace

// -----------------------------------------------------------------------------
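For reference, the index computation in the legalization above can be exercised outside the rewriter. Below is a minimal standalone C++ sketch, not part of this commit; asStridedIndices is a hypothetical helper name, and plain std::vector stands in for SmallVector and the MLIR machinery. With the values from the LIT test further down (output size {3, 3}, stride {2, 1}, storage offset 0 on a 5x5 input), it reproduces the constant index tensor [0, 1, 2, 2, 3, 4, 4, 5, 6] that the legalization materializes via tosa.const.

// Standalone sketch of the as_strided index computation (illustrative only).
#include <cstdint>
#include <iostream>
#include <vector>

// index[i] = offset + sum over dims of coord(i, dim) * stride[dim]
std::vector<int32_t> asStridedIndices(const std::vector<int64_t> &outputSize,
                                      const std::vector<int64_t> &stride,
                                      int64_t offset) {
  int64_t outputRank = outputSize.size();
  int64_t outputNumElems = 1;
  for (int64_t s : outputSize)
    outputNumElems *= s;

  std::vector<int32_t> indices;
  for (int64_t i = 0; i < outputNumElems; i++) {
    int32_t index = offset;
    int64_t coordFinder = i;
    // Peel coordinates off the flattened output index, innermost dimension
    // first, exactly as the loop in the legalization does.
    for (int64_t dim = 0; dim < outputRank; dim++) {
      int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1];
      index += indexCoord * stride[outputRank - dim - 1];
      coordFinder /= outputSize[outputRank - dim - 1];
    }
    indices.push_back(index);
  }
  return indices;
}

int main() {
  // Values from the LIT test: prints 0 1 2 2 3 4 4 5 6, matching the
  // tosa.const index tensor in the expected output.
  for (int32_t idx : asStridedIndices({3, 3}, {2, 1}, 0))
    std::cout << idx << ' ';
  std::cout << '\n';
}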
@@ -7096,6 +7198,7 @@ public:
    INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
    INSERT_ATENOP_PATTERN(AtenUniformOp);
    INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
    INSERT_ATENOP_PATTERN(AtenAsStridedOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
@@ -1723,6 +1723,9 @@ TOSA_CRASHING_SET = {
}

FX_IMPORTER_TOSA_CRASHING_SET = {
    "Aten_TrilinearModuleSumAllDims_basic",
    "Aten_TrilinearModuleSumdims_basic",
    "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
    "GridSamplerBasic1_basic",
    "GridSamplerBasic2_basic",
    "GridSamplerBasic3_basic",

@@ -1744,6 +1747,13 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
    "Aten_TrilinearModuleSumAllDims_basic",
    "Aten_TrilinearModuleSumdims_basic",
    "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
    "Aten_TrilinearModuleVaryingRanks_basic",
    "Aten_TrilinearModule_basic",
    "ElementwiseAddBoolModule_basic",
    "Exp2StaticModule_basic",
    "CosineSimilarityStaticBroadcastModule_basic",
    "DropoutTrainStaticShapeModule_basic",
    "ElementwiseAtenLogicalAndOpModule_basic",

@@ -1937,10 +1947,6 @@ TOSA_PASS_SET = {
    "ElementwiseTruncIntModule_basic",
    "ElementwiseSgnModule_basic",
    "ElementwiseSignIntModule_basic",
    "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
    "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
    "AddCDivModule_basic",
    "AddCDiv_Module_basic",
    "AddCMulModule_basic",

@@ -2292,7 +2298,6 @@ TOSA_PASS_SET = {
    "RreluWithNoiseBackwardTrainStaticModule_basic",
    "RepeatModule_basic",
    "RepeatInterleaveSelfIntNoDimModule_basic",
    "ResNet18StaticModule_basic",
    "ReshapeAliasCollapseModule_basic",
    "ReshapeAliasExpandModule_basic",
    "ReshapeAsModule_basic",

@@ -2416,6 +2421,9 @@ MAKE_FX_TOSA_PASS_SET = (
    TOSA_PASS_SET
    | {
        ### Tests additionally passing in make_fx_tosa
        "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
        "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
        "ResNet18StaticModule_basic",
        "AdaptiveAvgPool1dStaticLargerOutput_basic",
        "ScaledDotProductAttentionBoolMaskModule_basic",
        "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",

@@ -2464,9 +2472,7 @@ MAKE_FX_TOSA_PASS_SET = (
        "MaxPool1dEmptyStrideStaticModule_basic",
        "MaxPool1dStaticCeilModeTrueModule_basic",
        "MaxPool1dStaticModule_basic",
        "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
        "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
        "AdaptiveAvgPool1dStaticEvenMultiple_basic",
        "CosineSimilarityModule_basic",
        "NativeGroupNormBackwardModule_basic",
        "ReduceFrobeniusNormKeepDimModule_basic",

@@ -2474,8 +2480,6 @@ MAKE_FX_TOSA_PASS_SET = (
        "SliceWholeTensorModule_basic",
        "TensorFloatModule_basic",
        "TensorIntModule_basic",
        "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
        "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
        "RepeatInterleaveSelfIntModule_basic",
        "TorchPrimLoopForLikeTensorArgModule_basic",
        "ViewSizeDimFollowedByCollapsedOnesModule_basic",

@@ -2492,13 +2496,6 @@ MAKE_FX_TOSA_PASS_SET = (
    }
) - {
    ### Test failing in make_fx_tosa but not in tosa
    "ChunkListUnpackUneven_Module_basic",
    "ChunkListUnpack_Module_basic",
    "SplitTensorGetItem_Module_basic",
    "SplitTensorLastSmallerModule_basic",
    "SplitTensorListUnpackModule_basic",
    "SplitTensorNegativeDimModule_basic",
    "SplitWithSizesListUnpackModule_basic",
    # Dynamic shape, has extra unsupported broadcast ops
    "Matmul_3d",
    # Unimplemented operator 'aten._index_put_impl_.hacked_twin'

@@ -3367,6 +3364,8 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
}

FX_IMPORTER_TOSA_XFAIL_SET = {
    "Aten_TrilinearModuleVaryingRanks_basic",
    "Aten_TrilinearModuleZerodDimBug_basic",
    "AdaptiveMaxPool1dDimOneStatic_basic",
    "ElementwiseRreluWithNoiseTrainModule_basic",
    "ElementwiseRreluWithNoiseTrainStaticModule_basic",

@@ -3387,16 +3386,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "SliceOutOfUpperBoundIndexModule_basic",
    "SliceOutOfUpperBoundIndexStaticModule_basic",
    "SliceStartEqEndModule_basic",
    "ChunkListUnpackDynamic_Module_basic",
    "ChunkListUnpackUnevenDynamic_Module_basic",
    "ChunkListUnpackUneven_Module_basic",
    "ChunkListUnpack_Module_basic",
    "SplitTensorGetItem_Module_basic",
    "SplitTensorLastSmallerModule_basic",
    "SplitTensorListUnpackModule_basic",
    "SplitTensorNegativeDimModule_basic",
    "SplitWithSizesListUnpackModule_basic",
    "SplitWithSizes_Module_basic",
    "ElementwiseCreateComplexModule_basic",
    "AtenPolarDoubleModule_basic",
    "AtenPolarFloatModule_basic",

@@ -3572,7 +3561,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "ElementwiseAcosModule_basic",
    "ElementwiseAcoshIntModule_basic",
    "ElementwiseAcoshModule_basic",
    "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
    "ElementwiseAsinIntModule_basic",
    "ElementwiseAsinModule_basic",
    "ElementwiseAsinhIntModule_basic",

@@ -3608,7 +3596,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "ElementwiseLog1pModule_basic",
    "ElementwiseLog2IntModule_basic",
    "ElementwiseLogIntModule_basic",
    "ElementwiseLogSigmoidModule_basic",
    "ElementwiseLogitModule_basic",
    "ElementwiseMishModule_basic",
    "ElementwiseMulTensorComplexDiffModule_basic",

@@ -3731,8 +3718,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "NormScalarModule_basic",
    "NormScalarOptDimKeepDimComplexModule_basic",
    "NormalFunctionalModule_basic",
    "NumToTensorFloatModule_basic",
    "NumToTensorIntModule_basic",
    "NumelModule_basic",
    "NumelZeroRankModule_basic",
    "OnesLikeModule_falsePinMemory",

@@ -3790,7 +3775,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "ReplicationPad2dModule_right0",
    "ReplicationPad2dModule_top0",
    "RollModule_basic",
    "RsubInt0d_NumToTensor_Module_basic",
    "RsubIntModule_noalpha_basic",
    "ScalarConstantTupleModule_basic",
    "ScalarImplicitFloatModule_basic",

@@ -3873,6 +3857,55 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "ViewSizeFromOtherTensor_basic",
    "VisionTransformerModule_basic",
    "ZerosLikeModule_falsePinMemory",
    # count_include_pad and divisor_override check in TOSA AvgPool2d
    "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
    "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
    "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
    "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
    "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
    "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
    "ResNet18Module_basic",
    "ResNet18StaticModule_basic",
    "MobilenetV3Module_basic",
    # Unexpected failures due to new PyTorch version update
    "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
    "AdaptiveAvgPool1dGeneralDynamic_basic",
    "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool1dStaticEvenMultiple_basic",
    "AdaptiveAvgPool1dStaticLargerOutput_basic",
    "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
    "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool2dDynamicNoBatch_basic",
    "AdaptiveAvgPool2dDynamic_basic",
    "CrossEntropyLossModule_basic",
    "CrossEntropyLossNoReductionModule_basic",
    "ElementwiseRreluTrainModule_basic",
    "ElementwiseRreluTrainStaticModule_basic",
    "IndexPutImpl1DFloatNonAccumulateModule_basic",
    "IndexPutImpl1DIntNonAccumulateModule_basic",
    "IndexPutImpl2DFloatNonAccumulateModule_basic",
    "IndexPutImpl3DFloatNonAccumulateModule_basic",
    "IouOfModule_basic",
    "MaxPool1dEmptyStrideStaticModule_basic",
    "MaxPool1dStaticCeilModeTrueModule_basic",
    "MaxPool1dStaticModule_basic",
    "MeshgridIndexingIJ_basic",
    "MeshgridIndexingXY_basic",
    "Meshgrid_basic",
    "OneHotModule_basic",
    "ReduceFrobeniusNormKeepDimModule_basic",
    "ReduceFrobeniusNormModule_basic",
    "RepeatInterleaveSelfIntModule_basic",
    "ScaledDotProductAttentionBoolMaskModule_basic",
    "ScaledDotProductAttentionDifferentCausalModule_basic",
    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
    "ScaledDotProductAttentionDifferentModule_basic",
    "ScaledDotProductAttentionMaskModule_basic",
    "ScaledDotProductAttentionSameCausalModule_basic",
    "ScaledDotProductAttentionSameDynamicModule_basic",
    "ScaledDotProductAttentionSameModule_basic",
}

ONNX_TOSA_CRASHING_SET = {

@@ -3885,6 +3918,7 @@ ONNX_TOSA_CRASHING_SET = {
}

ONNX_TOSA_XFAIL_SET = {
    "Exp2StaticModule_basic",
    "ElementwiseRreluWithNoiseEvalModule_basic",
    "ElementwiseRreluWithNoiseEvalStaticModule_basic",
    "ElementwiseRreluWithNoiseTrainModule_basic",
@@ -2275,3 +2275,41 @@ func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch
  %0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64>
  return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.as_strided$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 25>} : (tensor<5x5xf32>) -> tensor<25xf32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 2, 3, 4, 4, 5, 6]> : tensor<9xi32>}> : () -> tensor<9xi32>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 9, 1>} : (tensor<9xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 25, 1>} : (tensor<25xf32>) -> tensor<1x25x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 9, 1>} : (tensor<9x1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 9>} : (tensor<9x1xi32>) -> tensor<1x9xi32>
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 9>} : (tensor<1x9x1xf32>) -> tensor<9xf32>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 3, 3>} : (tensor<9xf32>) -> tensor<3x3xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32>
// CHECK: }
func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
  %none = torch.constant.none
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.none -> !torch.vtensor<[3,3],f32>
  return %2 : !torch.vtensor<[3,3],f32>
}
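The constant index tensor in %[[VAL_9]] follows directly from the index formula in the legalization: with storage offset 0, stride {2, 1}, and output size {3, 3}, element (r, c) of the strided view reads flattened storage index 2*r + c, which for r, c in {0, 1, 2} gives exactly [0, 1, 2, 2, 3, 4, 4, 5, 6].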