[TOSA] Add legalization for empty, scatter, slice_scatter, diag_embed (#3792)

- Add Torch to TOSA legalization for the following ops (PyTorch semantics sketched below):
  + aten.empty.memory_format
  + aten.scatter.src
  + aten.slice_scatter
  + aten.diag_embed
- Update xfail_sets.py with new e2e results
- Update basic.mlir with new LIT tests
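
For reference, a minimal PyTorch sketch of the eager-mode semantics these
legalizations target (shapes follow the new LIT tests; tensor values are
illustrative only):

    import torch

    # aten.empty.memory_format: allocates an uninitialized tensor; TOSA has no
    # notion of uninitialized memory, so the legalization emits a constant fill.
    x = torch.empty(3, 4, dtype=torch.int64)

    # aten.scatter.src: writes `src` values into `self` along `dim` at `index`.
    self_t = torch.zeros(10, 8, 6)
    index = torch.randint(0, 8, (2, 4, 3))
    src = torch.rand(3, 4, 3)
    scattered = torch.scatter(self_t, 1, index, src)

    # aten.slice_scatter: embeds `src` into the slice self[:, 0:1:1].
    slice_scattered = torch.slice_scatter(torch.zeros(6, 8), torch.ones(6, 1),
                                          dim=1, start=0, end=1, step=1)

    # aten.diag_embed: places the last dim of the input on a diagonal of the two
    # innermost dims of the output; a (2,3,4) input becomes a (2,3,4,4) output.
    diag = torch.diag_embed(torch.rand(2, 3, 4))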


Change-Id: I817ecf207bcfcf97ca54f30c10c76c4f0f4145ae

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Justin Ngo 2024-10-15 08:38:02 -07:00 committed by GitHub
parent 895f490cf5
commit 45bb17ebfe
3 changed files with 584 additions and 58 deletions


@ -4360,6 +4360,221 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
return success();
}
// Legalization for aten.scatter.src
template <>
LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
AtenScatterSrcOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto input = adaptor.getSelf();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
if (!inputType)
return rewriter.notifyMatchFailure(
op, "Only RankedTensorType inputs are currently supported");
auto inputShape = inputType.getShape();
auto paramsRank = inputType.getRank();
auto index = adaptor.getIndex();
auto indexType = dyn_cast<RankedTensorType>(index.getType());
if (!indexType)
return rewriter.notifyMatchFailure(
op, "Only RankedTensorType indices are currently supported");
// Check that `index` and `input` have the same rank
if (indexType.getRank() != paramsRank)
return rewriter.notifyMatchFailure(
op, "Params index and input should have the same rank");
auto indexShape = indexType.getShape();
auto src = adaptor.getSrc();
auto srcType = dyn_cast<RankedTensorType>(src.getType());
if (!srcType)
return rewriter.notifyMatchFailure(
op, "Only RankedTensorType sources are currently supported");
// Check that `src` and `input` have the same rank
if (srcType.getRank() != paramsRank)
return rewriter.notifyMatchFailure(
op, "Src and input should have the same rank");
auto srcShape = srcType.getShape();
// Dynamic shape check
if (!inputType.hasStaticShape() || !indexType.hasStaticShape() ||
!srcType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Support for dynamic shape not implemented");
// Cast index from i64 to i32 for TOSA compatibility
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
}
// Get positive dim
int64_t dim{0};
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op,
"Dim value should be a constant int");
dim = toPositiveDim(dim, paramsRank);
if (!isValidDim(dim, paramsRank))
return rewriter.notifyMatchFailure(op, "Dim is invalid");
// It is also required that index.size(d) <= src.size(d) for all dimensions d,
// and that index.size(d) <= self.size(d) for all dimensions d != dim
for (int64_t d = 0; d < paramsRank; d++) {
if (d != dim) {
if (indexShape[d] > srcShape[d] || indexShape[d] > inputShape[d])
return rewriter.notifyMatchFailure(
op, "Index size should be smaller or equal to src or input size "
"for all dimensions d != dim");
}
}
// Get the output type
auto outType = getTypeConverter()->convertType(op.getType());
// Convert PyTorch-style index and dim into TensorFlow-style indices
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
auto indicesTf =
tosa::convertTorchIndexToTfIndices(rewriter, op, input, index, dim);
if (!indicesTf)
return rewriter.notifyMatchFailure(
op, "Convert PyTorch index and dim to TensorFlow indices failed");
// Perform the TensorFlow ScatterNd algorithm with TensorFlow style indices as
// input.
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
indicesTf.value(), src);
if (!result)
return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
rewriter.replaceOp(op, {result.value()});
return success();
}
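
To make the index conversion step above concrete, here is a small NumPy sketch of what a TensorFlow/ScatterND-style index looks like for a PyTorch (dim, index) pair; this illustrates the concept only and is not the actual convertTorchIndexToTfIndices implementation (the helper name below is made up for the example):

    import numpy as np

    def torch_index_to_tf_indices(index, dim):
        # Expand a PyTorch-style index of shape S into coordinates of shape
        # S + [rank]: each element keeps its own position in every dimension
        # except `dim`, where the stored index value is substituted.
        rank = index.ndim
        coords = np.stack(np.meshgrid(*[np.arange(s) for s in index.shape],
                                      indexing="ij"), axis=-1)
        coords[..., dim] = index
        return coords

    # A [2,4,3] index with dim=1 becomes a [2,4,3,3] coordinate tensor,
    # matching the scatter.src LIT test further below.
    idx = np.random.randint(0, 8, size=(2, 4, 3))
    assert torch_index_to_tf_indices(idx, dim=1).shape == (2, 4, 3, 3)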
// Legalization for aten.slice_scatter
template <>
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
AtenSliceScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto input = adaptor.getSelf();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
if (!inputType)
return rewriter.notifyMatchFailure(
op, "Only RankedTensorType inputs are currently supported");
auto inputShape = inputType.getShape();
auto paramsRank = inputType.getRank();
auto src = adaptor.getSrc();
auto srcType = dyn_cast<RankedTensorType>(src.getType());
if (!srcType)
return rewriter.notifyMatchFailure(
op, "Only RankedTensorType sources are currently supported");
// Check that `src` and `input` have the same rank
if (srcType.getRank() != paramsRank)
return rewriter.notifyMatchFailure(
op, "Src and input should have the same rank");
auto srcShape = srcType.getShape();
// Dynamic shape check
if (!inputType.hasStaticShape() || !srcType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Support for dynamic shape not implemented");
// Get positive dim
int64_t dim{0};
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op,
"Dim value should be a constant int");
dim = toPositiveDim(dim, paramsRank);
if (!isValidDim(dim, paramsRank))
return rewriter.notifyMatchFailure(op, "Dim is invalid");
// Get start, end, and step params
// If start and end params are not specified, assign them to 0 and
// inputShape[dim], respectively.
int64_t start{0};
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(op,
"Start value should be a constant int");
if (start < 0)
start += inputShape[dim];
int64_t end{inputShape[dim]};
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
return rewriter.notifyMatchFailure(op,
"End value should be a constant int");
if (end < 0)
end += inputShape[dim];
if (end > inputShape[dim])
end = inputShape[dim];
if (start >= end)
return rewriter.notifyMatchFailure(
op, "Start value greater than end value not supported");
int64_t step{1};
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
return rewriter.notifyMatchFailure(op,
"Step value should be a constant int");
// Create PyTorch style scatter index based on start, end, and step values
int64_t outerRepeat{1}, innerRepeat{1};
for (int64_t i = 0; i < dim; i++)
outerRepeat *= srcShape[i];
for (int64_t i = dim + 1; i < paramsRank; i++)
innerRepeat *= srcShape[i];
SmallVector<int32_t> indexVec;
for (int64_t i = 0; i < outerRepeat; i++) {
for (int32_t indexVal = start; indexVal < end; indexVal += step) {
for (int64_t j = 0; j < innerRepeat; j++) {
indexVec.push_back(indexVal);
}
}
}
Value index =
tosa::getConstTensor<int32_t>(rewriter, op, indexVec, srcShape).value();
// Get the output type
auto outType = getTypeConverter()->convertType(op.getType());
// Convert PyTorch-style index and dim into TensorFlow-style indices
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
auto indicesTf =
tosa::convertTorchIndexToTfIndices(rewriter, op, input, index, dim);
if (!indicesTf)
return rewriter.notifyMatchFailure(
op, "Convert PyTorch index and dim to TensorFlow indices failed");
// Perform the TensorFlow ScatterNd algorithm with TensorFlow style indices as
// input.
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
indicesTf.value(), src);
if (!result)
return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
rewriter.replaceOp(op, {result.value()});
return success();
}
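
As a sanity check on the index construction above, the same outer/inner repeat logic in a few lines of Python, evaluated for the shapes used in the slice_scatter LIT test (6x8 input, 6x1 src, dim=1, start=0, end=1, step=1); illustrative only:

    def build_slice_scatter_index(src_shape, dim, start, end, step):
        # Mirrors the loops above: one slice position per (outer, inner) element.
        outer = 1
        for s in src_shape[:dim]:
            outer *= s
        inner = 1
        for s in src_shape[dim + 1:]:
            inner *= s
        index_vec = []
        for _ in range(outer):
            for idx in range(start, end, step):
                index_vec.extend([idx] * inner)
        return index_vec

    # Produces [0, 0, 0, 0, 0, 0], which is reshaped to 6x1 and scattered into
    # column 0 of the 6x8 input (the dense<0> : tensor<6x1xi32> constant in the
    # LIT test).
    print(build_slice_scatter_index((6, 1), dim=1, start=0, end=1, step=1))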
template <>
LogicalResult ConvertAtenOp<AtenAbsOp>::matchAndRewrite(
AtenAbsOp op, OpAdaptor adaptor,
@ -6099,6 +6314,10 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
dim2 = toPositiveDim(dim2, selfRank);
}
if (dim1 == dim2)
return rewriter.notifyMatchFailure(op,
"Values dim1 and dim2 cannot be equal");
auto selfShape = makeShapeTorchCompatible(selfType.getShape());
int64_t h = selfShape[dim1];
int64_t w = selfShape[dim2];
@ -6122,13 +6341,13 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
SmallVector<int32_t> transposedDims;
transposedInputShape.clear();
for (int32_t i = 0; i < selfRank; ++i) {
if (i == dim1 || i == dim2)
continue;
transposedDims.push_back(i);
}
transposedDims.push_back(static_cast<int32_t>(dim1));
transposedDims.push_back(static_cast<int32_t>(dim2));
auto transposedDimsConst = tosa::getConstTensor<int32_t>(
rewriter, op,
@ -6213,6 +6432,193 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
return success();
}
// Legalization for aten.diag_embed
template <>
LogicalResult ConvertAtenOp<AtenDiagEmbedOp>::matchAndRewrite(
AtenDiagEmbedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// To perform diag_embed, we will apply scatter with a newly created diagonal
// index tensor over a constant zero tensor.
// To make it simpler, we will only scatter using the diagonal with respect
// to the two innermost dimensions, then permute the output tensor to the
// correct order of dimensions.
auto self = adaptor.getSelf();
// Not a ranked tensor type
auto selfType = dyn_cast<RankedTensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types are supported");
auto selfRank = selfType.getRank();
int64_t outRank = selfRank + 1;
auto selfShape = makeShapeTorchCompatible(selfType.getShape());
int64_t diagSize = selfShape[selfRank - 1];
if (!selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Currently only static shapes are supported");
const TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType = cast<RankedTensorType>(
typeConverter->convertType(op->getResult(0).getType()));
if (!resultType)
return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
auto selfElemTy = selfType.getElementType();
auto resultElemTy = resultType.getElementType();
int64_t offset{0};
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
return rewriter.notifyMatchFailure(op,
"Offset value should be a constant int");
// dim1 default is -2
int64_t dim1{outRank - 2};
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
return rewriter.notifyMatchFailure(op,
"Dim1 value should be a constant int");
dim1 = toPositiveDim(dim1, outRank);
// dim2 default is -1
int64_t dim2{outRank - 1};
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2)))
return rewriter.notifyMatchFailure(op,
"Dim2 value should be a constant int");
dim2 = toPositiveDim(dim2, outRank);
if (dim1 == dim2)
return rewriter.notifyMatchFailure(op, "Dim1 and dim2 cannot be equal");
// If offset is smaller than 0, we will swap dim1 and dim2 and convert offset
// to a positive value
if (offset < 0) {
std::swap(dim1, dim2);
offset = std::abs(offset);
}
// Create the diagonal index tensor
int64_t repeat = 1;
for (int64_t i = 0; i < selfRank - 1; i++)
repeat *= selfShape[i];
SmallVector<int32_t> indexVec;
for (int32_t i = 0; i < repeat; i++) {
for (int32_t j = offset; j < diagSize + offset; j++)
indexVec.push_back(j);
}
SmallVector<int64_t> indexShape = llvm::to_vector(selfShape);
indexShape.push_back(1);
auto index = tosa::getConstTensor<int32_t>(rewriter, op,
/*vec=*/indexVec,
/*shape=*/indexShape)
.value();
// Reshape the input tensor to be the same shape as the new index tensor to
// act as the src for scattering
auto scatterSrc = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(makeShapeTorchCompatible(indexShape), selfElemTy),
self, rewriter.getDenseI64ArrayAttr(indexShape));
// Create a const zero tensor to scatter the input onto
SmallVector<int64_t> zeroShape;
for (int64_t i = 0; i < selfRank - 1; i++)
zeroShape.push_back(selfShape[i]);
zeroShape.push_back(diagSize + offset);
zeroShape.push_back(diagSize + offset);
int64_t numElemOfZeroTensor = 1;
for (int64_t &d : zeroShape)
numElemOfZeroTensor *= d;
Value zero =
TypeSwitch<Type, Value>(selfElemTy)
.Case<mlir::FloatType>([&](auto) {
return tosa::getConstTensor<float>(
rewriter, op, SmallVector<float>(numElemOfZeroTensor, 0),
zeroShape)
.value();
})
.Case<mlir::IntegerType>([&](auto intType) {
switch (intType.getWidth()) {
case 1:
return tosa::getConstTensor<bool>(
rewriter, op,
SmallVector<bool>(numElemOfZeroTensor, 0), zeroShape)
.value();
case 32:
return tosa::getConstTensor<int32_t>(
rewriter, op,
SmallVector<int32_t>(numElemOfZeroTensor, 0),
zeroShape)
.value();
case 64:
return tosa::getConstTensor<int64_t>(
rewriter, op,
SmallVector<int64_t>(numElemOfZeroTensor, 0),
zeroShape)
.value();
}
llvm_unreachable("Invalid integer width");
});
// Convert PyTorch index and dim to TensorFlow-style indices
auto indicesTf = tosa::convertTorchIndexToTfIndices(rewriter, op, zero, index,
outRank - 1);
if (!indicesTf)
return rewriter.notifyMatchFailure(
op, "Convert PyTorch index and dim to TensorFlow indices failed");
// Perform the TensorFlow ScatterNd algorithm with TensorFlow-style indices as
// input
auto diagonalTensor = tosa::convertScatterNdOp(
rewriter, op,
RankedTensorType::get(makeShapeTorchCompatible(zeroShape), resultElemTy),
zero, indicesTf.value(), scatterSrc.getResult());
if (!diagonalTensor)
return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
// Create the final dims order to permute the scattered tensor
SmallVector<int32_t> permutedDims(outRank, 0);
int32_t currentDim = 0;
int32_t i = 0;
while (i < outRank) {
if (i == dim1) {
permutedDims[i] = outRank - 2;
i++;
continue;
}
if (i == dim2) {
permutedDims[i] = outRank - 1;
i++;
continue;
}
permutedDims[i] = currentDim;
currentDim++;
i++;
}
auto permutedDimsConst =
tosa::getConstTensor<int32_t>(rewriter, op,
/*vec=*/permutedDims,
/*shape=*/{static_cast<int32_t>(outRank)});
auto result = rewriter.create<tosa::TransposeOp>(op->getLoc(), resultType,
diagonalTensor.value(),
permutedDimsConst.value());
rewriter.replaceOp(op, result.getResult());
return success();
}
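
The final transpose is the only subtle step: the scatter always builds the diagonal in the two innermost dimensions, and the permutation routes them to dim1 and dim2 while keeping the remaining batch dimensions in order. A short Python sketch of that dim-ordering logic (an illustration, not the torch-mlir code itself):

    def diag_embed_permutation(out_rank, dim1, dim2):
        # The scattered tensor has its diagonal in dims (out_rank-2, out_rank-1);
        # send those to dim1/dim2 and fill the rest with batch dims in order.
        perm = [0] * out_rank
        current = 0
        for i in range(out_rank):
            if i == dim1:
                perm[i] = out_rank - 2
            elif i == dim2:
                perm[i] = out_rank - 1
            else:
                perm[i] = current
                current += 1
        return perm

    # Default placement (dim1=-2, dim2=-1) on a rank-4 output is the identity,
    # which is why the diag_embed LIT test below shows a [0, 1, 2, 3] transpose
    # constant.
    assert diag_embed_permutation(4, dim1=2, dim2=3) == [0, 1, 2, 3]
    assert diag_embed_permutation(4, dim1=1, dim2=3) == [0, 2, 1, 3]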
} // namespace
// -----------------------------------------------------------------------------
@ -6442,6 +6848,7 @@ public:
context);
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_FILL_PATTERN(AtenOp) \
@ -6524,6 +6931,9 @@ public:
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenRoundOp);
INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \


@ -1650,12 +1650,18 @@ STABLEHLO_CRASHING_SET = {
}
TOSA_CRASHING_SET = {
"ArangeStartOutDtypeModule_basic",
"ArangeStartOutModule_basic",
"ScatterSrcStaticModule_basic",
# Runtime op verification: Out of bounds access
"IndexTensorNegativeIndexModule_basic",
"ReduceAllDimEmpty_basic",
}
FX_IMPORTER_TOSA_CRASHING_SET = {
"ScatterSrcModule_basic",
"ScatterSrcStaticModule_basic",
"HBC_basic",
"IndexTensorNegativeIndexModule_basic", "IndexTensorNegativeIndexModule_basic",
"InterpolateDynamicModule_scales_recompute_bilinear", "InterpolateDynamicModule_scales_recompute_bilinear",
"InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_bilinear",
@ -1671,6 +1677,26 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
"EmptyModule_float",
"EmptyModule_int",
"EmptyModule_uint8",
"EmptyStridedModule_basic",
"NewEmptyModuleDefaultDtype_basic",
"NewEmptyModuleFalsePinMemory_basic",
"NewEmptyModuleFloat2D_basic",
"NewEmptyModuleFloat3D_basic",
"NewEmptyModuleInt2D_basic",
"NewEmptyModuleInt3D_basic",
"NewEmptyModuleLayoutIntDtype_basic",
"NewEmptyModuleNonDefaultFloatDtype_basic",
"NewEmptyModuleNonDefaultIntDtype_basic",
"NewEmptyStridedModuleDefaultDtype_basic",
"SelectScattertStaticModule_basic",
"SliceScatterStaticModule_basic",
"TensorAlloc1dStaticModule_basic",
"AtenRoundFloatHalfToEvenModule_basic", "AtenRoundFloatHalfToEvenModule_basic",
"AtenRoundFloatModule_basic", "AtenRoundFloatModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic", "FakeQuantizePerTensorAffineCachemaskModule_basic",
@ -3248,6 +3274,12 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
}
FX_IMPORTER_TOSA_XFAIL_SET = {
"ViewDtypeStaticModule_basic",
"Unfold_Module_Dynamic_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_basic",
"ArangeZeroElementOutputModule_basic", "ArangeZeroElementOutputModule_basic",
"NumpyTRank0Module_basic", "NumpyTRank0Module_basic",
"Permute0RankModule_basic", "Permute0RankModule_basic",
@ -3338,12 +3370,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenComplexImagModule_basic", "AtenComplexImagModule_basic",
"AtenComplexRealModule_basic", "AtenComplexRealModule_basic",
"AtenComplexViewModule_basic", "AtenComplexViewModule_basic",
"AtenDiagEmbedDefaultDiag_basic",
"AtenDiagEmbedDimDiag_basic",
"AtenDiagEmbedNegOffsetDiag_basic",
"AtenDiagEmbedNonDefault4DDiag_basic",
"AtenDiagEmbedOffsetDiag_basic",
"AtenDiagEmbedRevDimDiag_basic",
"AtenEmbeddingBagStaticModule_basic", "AtenEmbeddingBagStaticModule_basic",
"AtenEmbeddingBagSumExample_basic", "AtenEmbeddingBagSumExample_basic",
"AtenEyeMModuleInt2D_basic", "AtenEyeMModuleInt2D_basic",
@ -3513,31 +3539,13 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic", "ElementwiseUnaryIntModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic",
"EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory",
"EmptyLikeModule_float",
"EmptyLikeModule_int",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
"EmptyModule_float",
"EmptyModule_int",
"EmptyModule_uint8",
"EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",
"EqIntModule_basic", "EqIntModule_basic",
"ExpandModule_basic", "ExpandModule_basic",
"ExponentialModule_basic", "ExponentialModule_basic",
"FloatImplicitModule_basic", "FloatImplicitModule_basic",
"FullLikeModuleInt2D_basic", "FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic", "FullLikeModuleInt3D_basic",
"FullModuleDefaultDtype_basic",
"FullModuleFalsePinMemory_basic",
"FullModuleFloat2D_basic",
"FullModuleFloat3D_basic",
"FullModuleInt2D_basic", "FullModuleInt2D_basic",
"FullModuleInt3D_basic",
"GeFloatIntModule_basic", "GeFloatIntModule_basic",
"GeFloatModule_basic", "GeFloatModule_basic",
"GeIntModule_basic", "GeIntModule_basic",
@ -3547,7 +3555,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"GridSamplerBasic4_basic", "GridSamplerBasic4_basic",
"GtFloatIntModule_basic", "GtFloatIntModule_basic",
"GtIntModule_basic", "GtIntModule_basic",
"HBC_basic",
"IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic",
@ -3599,7 +3606,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"LinalgVectorNormComplexModule_basic", "LinalgVectorNormComplexModule_basic",
"LinspaceDtypeModule_basic", "LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic", "LinspaceEmptyModule_basic",
"LinspaceOneSizeModule_basic",
"MaskedFillTensorFloatValueModule_basic", "MaskedFillTensorFloatValueModule_basic",
"MatmulBroadcastBatchDim_basic", "MatmulBroadcastBatchDim_basic",
"MatmulStaticBroadcast_basic", "MatmulStaticBroadcast_basic",
@ -3653,16 +3659,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"NativeGroupNormBackwardModule_basic", "NativeGroupNormBackwardModule_basic",
"NeFloatIntModule_basic", "NeFloatIntModule_basic",
"NeIntModule_basic", "NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic",
"NewEmptyModuleFalsePinMemory_basic",
"NewEmptyModuleFloat2D_basic",
"NewEmptyModuleFloat3D_basic",
"NewEmptyModuleInt2D_basic",
"NewEmptyModuleInt3D_basic",
"NewEmptyModuleLayoutIntDtype_basic",
"NewEmptyModuleNonDefaultFloatDtype_basic",
"NewEmptyModuleNonDefaultIntDtype_basic",
"NewEmptyStridedModuleDefaultDtype_basic",
"NewFullModuleInt2D_basic", "NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic", "NewFullModuleInt3D_basic",
"NllLossModuleBackward1DMeanWeight_basic", "NllLossModuleBackward1DMeanWeight_basic",
@ -3671,13 +3667,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"NllLossModuleBackward1DSum_basic", "NllLossModuleBackward1DSum_basic",
"NllLossModuleBackward1DWeight_basic", "NllLossModuleBackward1DWeight_basic",
"NllLossModuleBackward1D_basic", "NllLossModuleBackward1D_basic",
"NllLossModuleBackwardMeanWeight_basic",
"NllLossModuleBackwardMean_basic",
"NllLossModuleBackwardSumWeight_basic",
"NllLossModuleBackwardSum_basic",
"NllLossModuleBackwardWeight_basic",
"NllLossModuleBackward_basic",
"NllLossModuleBackward_ignore_index",
"NormScalarComplexModule_basic", "NormScalarComplexModule_basic",
"NormScalarModule_basic", "NormScalarModule_basic",
"NormScalarOptDimKeepDimComplexModule_basic", "NormScalarOptDimKeepDimComplexModule_basic",
@ -3777,26 +3766,14 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ScatterSrcStaticModule_basic", "ScatterSrcStaticModule_basic",
"ScatterValueFloatModule_basic", "ScatterValueFloatModule_basic",
"ScatterValueIntModule_basic", "ScatterValueIntModule_basic",
"SelectScattertModule_basic",
"SelectScattertStaticModule_basic",
"SignAndLogarithmOfDeterminantModule_F32", "SignAndLogarithmOfDeterminantModule_F32",
"SignAndLogarithmOfDeterminantBatchedModule_F32", "SignAndLogarithmOfDeterminantBatchedModule_F32",
"SignAndLogarithmOfDeterminantDynamicModule_F32", "SignAndLogarithmOfDeterminantDynamicModule_F32",
"SliceStaticComplexInputModule_basic", "SliceStaticComplexInputModule_basic",
"SliceCopyEndGreaterThanDimSize_Module_basic",
"SliceCopyNegative_Module_basic",
"SliceCopyNonZeroDim_Module_basic",
"SliceCopyStartGreaterThanDimSize_Module_basic", "SliceCopyStartGreaterThanDimSize_Module_basic",
"SliceCopy_Module_basic",
"SliceEndSleStartModule_basic", "SliceEndSleStartModule_basic",
"SliceOutOfLowerBoundEndIndexModule_basic", "SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic", "SliceOutOfLowerBoundStartIndexModule_basic",
"SliceScatterModule_basic",
"SliceScatterNegativeDimModule_basic",
"SliceScatterNegativeEndModule_basic",
"SliceScatterStaticModule_basic",
"SliceScatterStepVariationModule_basic",
"SliceScatterZeroDimModule_basic",
"SliceSizeTwoStepModule_basic", "SliceSizeTwoStepModule_basic",
"SoftplusModule_basic", "SoftplusModule_basic",
"SortIntListReverse_basic", "SortIntListReverse_basic",
@ -3864,6 +3841,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
}
ONNX_TOSA_CRASHING_SET = {
"ScatterSrcStaticModule_basic",
"StdCorrectionEmptyDimModule_basic", "StdCorrectionEmptyDimModule_basic",
"StdDimEmptyDimModule_basic", "StdDimEmptyDimModule_basic",
"VarCorrectionEmptyDimModule_basic", "VarCorrectionEmptyDimModule_basic",
@ -3872,6 +3850,11 @@ ONNX_TOSA_CRASHING_SET = {
}
ONNX_TOSA_XFAIL_SET = {
"Unfold_Module_Dynamic_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Rank_Zero_basic",
"ViewDtypeStaticModule_basic",
"ArangeZeroElementOutputModule_basic", "ArangeZeroElementOutputModule_basic",
"LinspaceEmptyModule_basic", "LinspaceEmptyModule_basic",
"RepeatInterleaveSelfIntNoDimModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic",


@ -1998,3 +1998,136 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
%0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 0
// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.constant.int 4
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = torch.constant.device "cpu"
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32>
// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<3x4xi32>) -> tensor<3x4xi64>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi64>}> : () -> tensor<3x4xi64>
// CHECK: %[[VAL_10:.*]] = tosa.cast %[[VAL_9]] : (tensor<3x4xi64>) -> tensor<3x4xi64>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],si64>
// CHECK: }
func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%none = torch.constant.none
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%cpu = torch.constant.device "cpu"
%1 = torch.aten.empty.memory_format %0, %int4, %none, %cpu, %false, %none : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[3,4],si64>
%2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[3,4],si64>, !torch.int -> !torch.vtensor<[3,4],si64>
return %2 : !torch.vtensor<[3,4],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.scatter.src$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,8,6],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4,3],si64>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[3,4,3],f32> -> tensor<3x4x3xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4,3],si64> -> tensor<2x4x3xi64>
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,8,6],f32> -> tensor<10x8x6xf32>
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_4]] : (tensor<2x4x3xi64>) -> tensor<2x4x3xi32>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 2, 4, 3, 1>} : (tensor<2x4x3xi32>) -> tensor<2x4x3x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]]], {{\[\[}}[1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]], {{\[\[}}[0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_8]], %[[VAL_10]] {axis = 3 : i32} : (tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>) -> tensor<2x4x3x3xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 36, 1>} : (tensor<3x4x3xf32>) -> tensor<1x36x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 480, 1>} : (tensor<10x8x6xf32>) -> tensor<1x480x1xf32>
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 24, 3>} : (tensor<2x4x3x3xi32>) -> tensor<24x3xi32>
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[48, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<24x3xi32>, tensor<3xi32>) -> tensor<24x3xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<24x3xi32>) -> tensor<24x1xi32>
// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
// CHECK: %[[VAL_19:.*]] = tosa.scatter %[[VAL_13]], %[[VAL_18]], %[[VAL_12]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32>) -> tensor<1x480x1xf32>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 10, 8, 6>} : (tensor<1x480x1xf32>) -> tensor<10x8x6xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<10x8x6xf32> -> !torch.vtensor<[10,8,6],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[10,8,6],f32>
// CHECK: }
func.func @torch.aten.scatter.src$basic(%arg0: !torch.vtensor<[10,8,6],f32>, %arg1: !torch.vtensor<[2,4,3],si64>, %arg2: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.scatter.src %arg0, %int1, %arg1, %arg2 : !torch.vtensor<[10,8,6],f32>, !torch.int, !torch.vtensor<[2,4,3],si64>, !torch.vtensor<[3,4,3],f32> -> !torch.vtensor<[10,8,6],f32>
return %0 : !torch.vtensor<[10,8,6],f32>
}
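
The [48, 6, 1] constant in the CHECK lines above is the row-major stride vector of the 10x8x6 input: convertScatterNdOp flattens each (i, j, k) coordinate into a position in the 1x480x1 view by a dot product with those strides. A quick NumPy illustration of that flattening (illustrative only):

    import numpy as np

    shape = (10, 8, 6)
    strides = (8 * 6, 6, 1)            # the [48, 6, 1] constant above
    coord = (2, 5, 4)                  # an arbitrary (i, j, k) coordinate
    flat = sum(c * s for c, s in zip(coord, strides))  # 2*48 + 5*6 + 4 = 130
    x = np.arange(np.prod(shape)).reshape(shape)
    assert x.reshape(-1)[flat] == x[2, 5, 4]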
// -----
// CHECK-LABEL: func.func @torch.aten.slice_scatter$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,8],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[6,1],f32> -> tensor<6x1xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,8],f32> -> tensor<6x8xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<6x1xi32>}> : () -> tensor<6x1xi32>
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array<i64: 6, 1, 1>} : (tensor<6x1xi32>) -> tensor<6x1x1xi32>
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]], {{\[\[}}4]], {{\[\[}}5]]]> : tensor<6x1x1xi32>}> : () -> tensor<6x1x1xi32>
// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_8]], %[[VAL_7]] {axis = 2 : i32} : (tensor<6x1x1xi32>, tensor<6x1x1xi32>) -> tensor<6x1x2xi32>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 6, 1>} : (tensor<6x1xf32>) -> tensor<1x6x1xf32>
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 48, 1>} : (tensor<6x8xf32>) -> tensor<1x48x1xf32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 6, 2>} : (tensor<6x1x2xi32>) -> tensor<6x2xi32>
// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[8, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<6x2xi32>, tensor<2xi32>) -> tensor<6x2xi32>
// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<6x2xi32>) -> tensor<6x1xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 1, 6>} : (tensor<6x1xi32>) -> tensor<1x6xi32>
// CHECK: %[[VAL_17:.*]] = tosa.scatter %[[VAL_11]], %[[VAL_16]], %[[VAL_10]] : (tensor<1x48x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x48x1xf32>
// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array<i64: 6, 8>} : (tensor<1x48x1xf32>) -> tensor<6x8xf32>
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<6x8xf32> -> !torch.vtensor<[6,8],f32>
// CHECK: return %[[VAL_19]] : !torch.vtensor<[6,8],f32>
// CHECK: }
func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg1: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.slice_scatter %arg0, %arg1, %int1, %int0, %int1, %int1 : !torch.vtensor<[6,8],f32>, !torch.vtensor<[6,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[6,8],f32>
return %0 : !torch.vtensor<[6,8],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.diag_embed$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int -2
// CHECK: %[[VAL_4:.*]] = torch.constant.int -1
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]], {{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]]> : tensor<2x3x4x1xi32>}> : () -> tensor<2x3x4x1xi32>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 2, 3, 4, 1>} : (tensor<2x3x4xf32>) -> tensor<2x3x4x1xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3x4x4xf32>}> : () -> tensor<2x3x4x4xf32>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 2, 3, 4, 1, 1>} : (tensor<2x3x4x1xi32>) -> tensor<2x3x4x1x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]]], {{\[\[}}{{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32>
// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_11]], %[[VAL_8]] {axis = 4 : i32} : (tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>) -> tensor<2x3x4x1x4xi32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array<i64: 1, 24, 1>} : (tensor<2x3x4x1xf32>) -> tensor<1x24x1xf32>
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 1, 96, 1>} : (tensor<2x3x4x4xf32>) -> tensor<1x96x1xf32>
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 24, 4>} : (tensor<2x3x4x1x4xi32>) -> tensor<24x4xi32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[48, 16, 4, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_15]], %[[VAL_16]] {shift = 0 : i8} : (tensor<24x4xi32>, tensor<4xi32>) -> tensor<24x4xi32>
// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<24x4xi32>) -> tensor<24x1xi32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
// CHECK: %[[VAL_20:.*]] = tosa.scatter %[[VAL_14]], %[[VAL_19]], %[[VAL_13]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32>
// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array<i64: 2, 3, 4, 4>} : (tensor<1x96x1xf32>) -> tensor<2x3x4x4xf32>
// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_21]], %[[VAL_22]] : (tensor<2x3x4x4xf32>, tensor<4xi32>) -> tensor<2x3x4x4xf32>
// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32>
// CHECK: return %[[VAL_24]] : !torch.vtensor<[2,3,4,4],f32>
// CHECK: }
func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> {
%int0 = torch.constant.int 0
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32>
return %0 : !torch.vtensor<[2,3,4,4],f32>
}