mirror of https://github.com/llvm/torch-mlir
[TOSA] Add legalization for empty, scatter, slice_scatter, diag_embed (#3792)
- Add Torch to TOSA legalization for the following ops: + aten.empty.memory_format + aten.scatter.src + aten.slice_scatter + aten.diag_embed - Update xfail_sets.py with new e2e results - Update basic.mlir with new LIT tests Change-Id: I817ecf207bcfcf97ca54f30c10c76c4f0f4145ae Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3798/head
parent
895f490cf5
commit
45bb17ebfe
|
@ -4360,6 +4360,221 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.scatter.src
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
|
||||||
|
AtenScatterSrcOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
// Not a tensor type.
|
||||||
|
auto input = adaptor.getSelf();
|
||||||
|
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||||
|
if (!inputType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only RankedTensorType inputs are currently supported");
|
||||||
|
|
||||||
|
auto inputShape = inputType.getShape();
|
||||||
|
auto paramsRank = inputType.getRank();
|
||||||
|
|
||||||
|
auto index = adaptor.getIndex();
|
||||||
|
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||||
|
if (!indexType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only RankedTensorType indices are currently supported");
|
||||||
|
|
||||||
|
// Check `index` and `input` param should have the same rank
|
||||||
|
if (indexType.getRank() != paramsRank)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Params index and input should have the same rank");
|
||||||
|
|
||||||
|
auto indexShape = indexType.getShape();
|
||||||
|
|
||||||
|
auto src = adaptor.getSrc();
|
||||||
|
auto srcType = dyn_cast<RankedTensorType>(src.getType());
|
||||||
|
if (!srcType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only RankedTensorType sources are currently supported");
|
||||||
|
|
||||||
|
// Check `src` and `input` param should have the same rank
|
||||||
|
if (srcType.getRank() != paramsRank)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Src and input should have the same rank");
|
||||||
|
|
||||||
|
auto srcShape = srcType.getShape();
|
||||||
|
|
||||||
|
// Dynamic shape check
|
||||||
|
if (!inputType.hasStaticShape() || !indexType.hasStaticShape() ||
|
||||||
|
!srcType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Support for dynamic shape not implemented");
|
||||||
|
|
||||||
|
// index i64 to i32 for tosa compatitable
|
||||||
|
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||||
|
index = rewriter.create<tosa::CastOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get positive dim
|
||||||
|
int64_t dim{0};
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Dim value should be a constant int");
|
||||||
|
|
||||||
|
dim = toPositiveDim(dim, paramsRank);
|
||||||
|
if (!isValidDim(dim, paramsRank))
|
||||||
|
return rewriter.notifyMatchFailure(op, "Dim is invalid");
|
||||||
|
|
||||||
|
// It is also required that index.size(d) <= src.size(d) for all dimensions d,
|
||||||
|
// and that index.size(d) <= self.size(d) for all dimensions d != dim
|
||||||
|
for (int64_t d = 0; d < paramsRank; d++) {
|
||||||
|
if (d != dim) {
|
||||||
|
if (indexShape[d] > srcShape[d] || indexShape[d] > inputShape[d])
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Index size should be smaller or equal to src or input size "
|
||||||
|
"for all dimensions d != dim");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the output type
|
||||||
|
auto outType = getTypeConverter()->convertType(op.getType());
|
||||||
|
|
||||||
|
// convert PyTorch style index and dim into TensorFlows tyle indices
|
||||||
|
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
|
||||||
|
auto indicesTf =
|
||||||
|
tosa::convertTorchIndexToTfIndices(rewriter, op, input, index, dim);
|
||||||
|
if (!indicesTf)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Convert PyTorch index and dim to TensorFlow indices failed");
|
||||||
|
|
||||||
|
// Perform the TensorFlow ScatterNd algorithm with TensorFlow style indices as
|
||||||
|
// input.
|
||||||
|
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
|
||||||
|
indicesTf.value(), src);
|
||||||
|
|
||||||
|
if (!result)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {result.value()});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.slice_scatter
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
||||||
|
AtenSliceScatterOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
// Not a tensor type.
|
||||||
|
auto input = adaptor.getSelf();
|
||||||
|
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||||
|
if (!inputType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only RankedTensorType inputs are currently supported");
|
||||||
|
|
||||||
|
auto inputShape = inputType.getShape();
|
||||||
|
auto paramsRank = inputType.getRank();
|
||||||
|
|
||||||
|
auto src = adaptor.getSrc();
|
||||||
|
auto srcType = dyn_cast<RankedTensorType>(src.getType());
|
||||||
|
if (!srcType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only RankedTensorType sources are currently supported");
|
||||||
|
|
||||||
|
// Check `src` and `input` param should have the same rank
|
||||||
|
if (srcType.getRank() != paramsRank)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Src and input should have the same rank");
|
||||||
|
|
||||||
|
auto srcShape = srcType.getShape();
|
||||||
|
|
||||||
|
// Dynamic shape check
|
||||||
|
if (!inputType.hasStaticShape() || !srcType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Support for dynamic shape not implemented");
|
||||||
|
|
||||||
|
// Get positive dim
|
||||||
|
int64_t dim{0};
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Dim value should be a constant int");
|
||||||
|
|
||||||
|
dim = toPositiveDim(dim, paramsRank);
|
||||||
|
if (!isValidDim(dim, paramsRank))
|
||||||
|
return rewriter.notifyMatchFailure(op, "Dim is invalid");
|
||||||
|
|
||||||
|
// Get start, end, and step params
|
||||||
|
// If start and end params are not specified, assign them to 0 and
|
||||||
|
// inputShape[dim], respectively.
|
||||||
|
int64_t start{0};
|
||||||
|
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Start value should be a constant int");
|
||||||
|
if (start < 0)
|
||||||
|
start += inputShape[dim];
|
||||||
|
|
||||||
|
int64_t end{inputShape[dim]};
|
||||||
|
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"End value should be a constant int");
|
||||||
|
if (end < 0)
|
||||||
|
end += inputShape[dim];
|
||||||
|
|
||||||
|
if (end > inputShape[dim])
|
||||||
|
end = inputShape[dim];
|
||||||
|
|
||||||
|
if (start >= end)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Start value greater than end value not supported");
|
||||||
|
|
||||||
|
int64_t step{1};
|
||||||
|
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Step value should be a constant int");
|
||||||
|
|
||||||
|
// Create PyTorch style scatter index based on start, end, and step values
|
||||||
|
int64_t outerRepeat{1}, innerRepeat{1};
|
||||||
|
for (int64_t i = 0; i < dim; i++)
|
||||||
|
outerRepeat *= srcShape[i];
|
||||||
|
|
||||||
|
for (int64_t i = dim + 1; i < paramsRank; i++)
|
||||||
|
innerRepeat *= srcShape[i];
|
||||||
|
|
||||||
|
SmallVector<int32_t> indexVec;
|
||||||
|
for (int64_t i = 0; i < outerRepeat; i++) {
|
||||||
|
for (int32_t indexVal = start; indexVal < end; indexVal += step) {
|
||||||
|
for (int64_t j = 0; j < innerRepeat; j++) {
|
||||||
|
indexVec.push_back(indexVal);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Value index =
|
||||||
|
tosa::getConstTensor<int32_t>(rewriter, op, indexVec, srcShape).value();
|
||||||
|
|
||||||
|
// Get the output type
|
||||||
|
auto outType = getTypeConverter()->convertType(op.getType());
|
||||||
|
|
||||||
|
// convert PyTorch style index and dim into TensorFlows tyle indices
|
||||||
|
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
|
||||||
|
auto indicesTf =
|
||||||
|
tosa::convertTorchIndexToTfIndices(rewriter, op, input, index, dim);
|
||||||
|
if (!indicesTf)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Convert PyTorch index and dim to TensorFlow indices failed");
|
||||||
|
|
||||||
|
// Perform the TensorFlow ScatterNd algorithm with TensorFlow style indices as
|
||||||
|
// input.
|
||||||
|
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
|
||||||
|
indicesTf.value(), src);
|
||||||
|
|
||||||
|
if (!result)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {result.value()});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenAbsOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenAbsOp>::matchAndRewrite(
|
||||||
AtenAbsOp op, OpAdaptor adaptor,
|
AtenAbsOp op, OpAdaptor adaptor,
|
||||||
|
@ -6099,6 +6314,10 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
|
||||||
dim2 = toPositiveDim(dim2, selfRank);
|
dim2 = toPositiveDim(dim2, selfRank);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (dim1 == dim2)
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Values dim1 and dim2 cannot be equal");
|
||||||
|
|
||||||
auto selfShape = makeShapeTorchCompatible(selfType.getShape());
|
auto selfShape = makeShapeTorchCompatible(selfType.getShape());
|
||||||
int64_t h = selfShape[dim1];
|
int64_t h = selfShape[dim1];
|
||||||
int64_t w = selfShape[dim2];
|
int64_t w = selfShape[dim2];
|
||||||
|
@ -6122,13 +6341,13 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
|
||||||
SmallVector<int32_t> transposedDims;
|
SmallVector<int32_t> transposedDims;
|
||||||
transposedInputShape.clear();
|
transposedInputShape.clear();
|
||||||
|
|
||||||
for (int64_t i = 0; i < selfRank; ++i) {
|
for (int32_t i = 0; i < selfRank; ++i) {
|
||||||
if (i == dim1 || i == dim2)
|
if (i == dim1 || i == dim2)
|
||||||
continue;
|
continue;
|
||||||
transposedDims.push_back(i);
|
transposedDims.push_back(i);
|
||||||
}
|
}
|
||||||
transposedDims.push_back(dim1);
|
transposedDims.push_back(static_cast<int32_t>(dim1));
|
||||||
transposedDims.push_back(dim2);
|
transposedDims.push_back(static_cast<int32_t>(dim2));
|
||||||
|
|
||||||
auto transposedDimsConst = tosa::getConstTensor<int32_t>(
|
auto transposedDimsConst = tosa::getConstTensor<int32_t>(
|
||||||
rewriter, op,
|
rewriter, op,
|
||||||
|
@ -6213,6 +6432,193 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.diag_embed
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenDiagEmbedOp>::matchAndRewrite(
|
||||||
|
AtenDiagEmbedOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
// To perform diag_embed, we will apply scatter with a newly created diagonal
|
||||||
|
// index tensor over a constant zero tensor.
|
||||||
|
// To make it simpler, we will only scatter using the diagonal with respect
|
||||||
|
// to the two innermost dimensions, then permute the output tensor to the
|
||||||
|
// correct order of dimensions.
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
// Not a ranked tensor type
|
||||||
|
auto selfType = dyn_cast<RankedTensorType>(self.getType());
|
||||||
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only ranked tensor types are supported");
|
||||||
|
|
||||||
|
auto selfRank = selfType.getRank();
|
||||||
|
int64_t outRank = selfRank + 1;
|
||||||
|
|
||||||
|
auto selfShape = makeShapeTorchCompatible(selfType.getShape());
|
||||||
|
int64_t diagSize = selfShape[selfRank - 1];
|
||||||
|
|
||||||
|
if (!selfType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Currently only static shapes are supported");
|
||||||
|
|
||||||
|
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||||
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
|
typeConverter->convertType(op->getResult(0).getType()));
|
||||||
|
if (!resultType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
|
||||||
|
|
||||||
|
auto selfElemTy = selfType.getElementType();
|
||||||
|
auto resultElemTy = resultType.getElementType();
|
||||||
|
|
||||||
|
int64_t offset{0};
|
||||||
|
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Offset value should be a constant int");
|
||||||
|
|
||||||
|
// dim1 default is -2
|
||||||
|
int64_t dim1{outRank - 2};
|
||||||
|
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Dim1 value should be a constant int");
|
||||||
|
dim1 = toPositiveDim(dim1, outRank);
|
||||||
|
|
||||||
|
// dim2 default is -1
|
||||||
|
int64_t dim2{outRank - 1};
|
||||||
|
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Dim2 value should be a constant int");
|
||||||
|
dim2 = toPositiveDim(dim2, outRank);
|
||||||
|
|
||||||
|
if (dim1 == dim2)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Dim1 and dim2 cannot be equal");
|
||||||
|
|
||||||
|
// If offset is smaller than 0, we will swap dim1 and dim2 and convert offset
|
||||||
|
// to a positive value
|
||||||
|
if (offset < 0) {
|
||||||
|
std::swap(dim1, dim2);
|
||||||
|
offset = std::abs(offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the diagonal index tensor
|
||||||
|
int64_t repeat = 1;
|
||||||
|
for (int64_t i = 0; i < selfRank - 1; i++)
|
||||||
|
repeat *= selfShape[i];
|
||||||
|
|
||||||
|
SmallVector<int32_t> indexVec;
|
||||||
|
for (int32_t i = 0; i < repeat; i++) {
|
||||||
|
for (int32_t j = offset; j < diagSize + offset; j++)
|
||||||
|
indexVec.push_back(j);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> indexShape = llvm::to_vector(selfShape);
|
||||||
|
indexShape.push_back(1);
|
||||||
|
|
||||||
|
auto index = tosa::getConstTensor<int32_t>(rewriter, op,
|
||||||
|
/*vec=*/indexVec,
|
||||||
|
/*shape=*/indexShape)
|
||||||
|
.value();
|
||||||
|
|
||||||
|
// Reshape the input tensor to be the same shape as the new index tensor to
|
||||||
|
// act as the src for scattering
|
||||||
|
auto scatterSrc = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(makeShapeTorchCompatible(indexShape), selfElemTy),
|
||||||
|
self, rewriter.getDenseI64ArrayAttr(indexShape));
|
||||||
|
|
||||||
|
// Create a const zero tensor to scatter the input onto
|
||||||
|
SmallVector<int64_t> zeroShape;
|
||||||
|
for (int64_t i = 0; i < selfRank - 1; i++)
|
||||||
|
zeroShape.push_back(selfShape[i]);
|
||||||
|
zeroShape.push_back(diagSize + offset);
|
||||||
|
zeroShape.push_back(diagSize + offset);
|
||||||
|
|
||||||
|
int64_t numElemOfZeroTensor = 1;
|
||||||
|
for (int64_t &d : zeroShape)
|
||||||
|
numElemOfZeroTensor *= d;
|
||||||
|
|
||||||
|
Value zero =
|
||||||
|
TypeSwitch<Type, Value>(selfElemTy)
|
||||||
|
.Case<mlir::FloatType>([&](auto) {
|
||||||
|
return tosa::getConstTensor<float>(
|
||||||
|
rewriter, op, SmallVector<float>(numElemOfZeroTensor, 0),
|
||||||
|
zeroShape)
|
||||||
|
.value();
|
||||||
|
})
|
||||||
|
.Case<mlir::IntegerType>([&](auto intType) {
|
||||||
|
switch (intType.getWidth()) {
|
||||||
|
case 1:
|
||||||
|
return tosa::getConstTensor<bool>(
|
||||||
|
rewriter, op,
|
||||||
|
SmallVector<bool>(numElemOfZeroTensor, 0), zeroShape)
|
||||||
|
.value();
|
||||||
|
case 32:
|
||||||
|
return tosa::getConstTensor<int32_t>(
|
||||||
|
rewriter, op,
|
||||||
|
SmallVector<int32_t>(numElemOfZeroTensor, 0),
|
||||||
|
zeroShape)
|
||||||
|
.value();
|
||||||
|
case 64:
|
||||||
|
return tosa::getConstTensor<int64_t>(
|
||||||
|
rewriter, op,
|
||||||
|
SmallVector<int64_t>(numElemOfZeroTensor, 0),
|
||||||
|
zeroShape)
|
||||||
|
.value();
|
||||||
|
}
|
||||||
|
llvm_unreachable("Invalid integer width");
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert PyTorch index and dim to TensorFlow-style indices
|
||||||
|
auto indicesTf = tosa::convertTorchIndexToTfIndices(rewriter, op, zero, index,
|
||||||
|
outRank - 1);
|
||||||
|
if (!indicesTf)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Convert PyTorch index and dim to TensorFlow indices failed");
|
||||||
|
|
||||||
|
// Perform the TensorFlow ScatterNd algorithm with TensorFlow-style indices as
|
||||||
|
// input
|
||||||
|
auto diagonalTensor = tosa::convertScatterNdOp(
|
||||||
|
rewriter, op,
|
||||||
|
RankedTensorType::get(makeShapeTorchCompatible(zeroShape), resultElemTy),
|
||||||
|
zero, indicesTf.value(), scatterSrc.getResult());
|
||||||
|
if (!diagonalTensor)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed");
|
||||||
|
|
||||||
|
// Create the final dims order to permute the scattered tensor
|
||||||
|
SmallVector<int32_t> permutedDims(outRank, 0);
|
||||||
|
int32_t currentDim = 0;
|
||||||
|
int32_t i = 0;
|
||||||
|
|
||||||
|
while (i < outRank) {
|
||||||
|
if (i == dim1) {
|
||||||
|
permutedDims[i] = outRank - 2;
|
||||||
|
i++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i == dim2) {
|
||||||
|
permutedDims[i] = outRank - 1;
|
||||||
|
i++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
permutedDims[i] = currentDim;
|
||||||
|
currentDim++;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto permutedDimsConst =
|
||||||
|
tosa::getConstTensor<int32_t>(rewriter, op,
|
||||||
|
/*vec=*/permutedDims,
|
||||||
|
/*shape=*/{static_cast<int32_t>(outRank)});
|
||||||
|
|
||||||
|
auto result = rewriter.create<tosa::TransposeOp>(op->getLoc(), resultType,
|
||||||
|
diagonalTensor.value(),
|
||||||
|
permutedDimsConst.value());
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result.getResult());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -6442,6 +6848,7 @@ public:
|
||||||
context);
|
context);
|
||||||
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
|
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
|
||||||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||||
|
INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0);
|
||||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||||
|
|
||||||
#define INSERT_FILL_PATTERN(AtenOp) \
|
#define INSERT_FILL_PATTERN(AtenOp) \
|
||||||
|
@ -6524,6 +6931,9 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenRoundOp);
|
INSERT_ATENOP_PATTERN(AtenRoundOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||||
|
|
|
@ -1650,12 +1650,18 @@ STABLEHLO_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
TOSA_CRASHING_SET = {
|
TOSA_CRASHING_SET = {
|
||||||
|
"ArangeStartOutDtypeModule_basic",
|
||||||
|
"ArangeStartOutModule_basic",
|
||||||
|
"ScatterSrcStaticModule_basic",
|
||||||
# Runtime op verification: Out of bounds access
|
# Runtime op verification: Out of bounds access
|
||||||
"IndexTensorNegativeIndexModule_basic",
|
"IndexTensorNegativeIndexModule_basic",
|
||||||
"ReduceAllDimEmpty_basic",
|
"ReduceAllDimEmpty_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_CRASHING_SET = {
|
FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
|
"ScatterSrcModule_basic",
|
||||||
|
"ScatterSrcStaticModule_basic",
|
||||||
|
"HBC_basic",
|
||||||
"IndexTensorNegativeIndexModule_basic",
|
"IndexTensorNegativeIndexModule_basic",
|
||||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||||
"InterpolateDynamicModule_sizes_bilinear",
|
"InterpolateDynamicModule_sizes_bilinear",
|
||||||
|
@ -1671,6 +1677,26 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"EmptyModule_contiguous",
|
||||||
|
"EmptyModule_defaultDtype",
|
||||||
|
"EmptyModule_falsePinMemory",
|
||||||
|
"EmptyModule_float",
|
||||||
|
"EmptyModule_int",
|
||||||
|
"EmptyModule_uint8",
|
||||||
|
"EmptyStridedModule_basic",
|
||||||
|
"NewEmptyModuleDefaultDtype_basic",
|
||||||
|
"NewEmptyModuleFalsePinMemory_basic",
|
||||||
|
"NewEmptyModuleFloat2D_basic",
|
||||||
|
"NewEmptyModuleFloat3D_basic",
|
||||||
|
"NewEmptyModuleInt2D_basic",
|
||||||
|
"NewEmptyModuleInt3D_basic",
|
||||||
|
"NewEmptyModuleLayoutIntDtype_basic",
|
||||||
|
"NewEmptyModuleNonDefaultFloatDtype_basic",
|
||||||
|
"NewEmptyModuleNonDefaultIntDtype_basic",
|
||||||
|
"NewEmptyStridedModuleDefaultDtype_basic",
|
||||||
|
"SelectScattertStaticModule_basic",
|
||||||
|
"SliceScatterStaticModule_basic",
|
||||||
|
"TensorAlloc1dStaticModule_basic",
|
||||||
"AtenRoundFloatHalfToEvenModule_basic",
|
"AtenRoundFloatHalfToEvenModule_basic",
|
||||||
"AtenRoundFloatModule_basic",
|
"AtenRoundFloatModule_basic",
|
||||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||||
|
@ -3248,6 +3274,12 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
|
"ViewDtypeStaticModule_basic",
|
||||||
|
"Unfold_Module_Dynamic_basic",
|
||||||
|
"Unfold_Module_Rank_4",
|
||||||
|
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||||
|
"Unfold_Module_Rank_Zero_basic",
|
||||||
|
"Unfold_Module_basic",
|
||||||
"ArangeZeroElementOutputModule_basic",
|
"ArangeZeroElementOutputModule_basic",
|
||||||
"NumpyTRank0Module_basic",
|
"NumpyTRank0Module_basic",
|
||||||
"Permute0RankModule_basic",
|
"Permute0RankModule_basic",
|
||||||
|
@ -3338,12 +3370,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"AtenComplexImagModule_basic",
|
"AtenComplexImagModule_basic",
|
||||||
"AtenComplexRealModule_basic",
|
"AtenComplexRealModule_basic",
|
||||||
"AtenComplexViewModule_basic",
|
"AtenComplexViewModule_basic",
|
||||||
"AtenDiagEmbedDefaultDiag_basic",
|
|
||||||
"AtenDiagEmbedDimDiag_basic",
|
|
||||||
"AtenDiagEmbedNegOffsetDiag_basic",
|
|
||||||
"AtenDiagEmbedNonDefault4DDiag_basic",
|
|
||||||
"AtenDiagEmbedOffsetDiag_basic",
|
|
||||||
"AtenDiagEmbedRevDimDiag_basic",
|
|
||||||
"AtenEmbeddingBagStaticModule_basic",
|
"AtenEmbeddingBagStaticModule_basic",
|
||||||
"AtenEmbeddingBagSumExample_basic",
|
"AtenEmbeddingBagSumExample_basic",
|
||||||
"AtenEyeMModuleInt2D_basic",
|
"AtenEyeMModuleInt2D_basic",
|
||||||
|
@ -3513,31 +3539,13 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||||
"ElementwiseUnaryIntModule_basic",
|
"ElementwiseUnaryIntModule_basic",
|
||||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
"ElementwiseWhereScalarOtherStaticModule_basic",
|
||||||
"EmptyLikeMemoryFormatModule_basic",
|
|
||||||
"EmptyLikeModule_defaultDtype",
|
|
||||||
"EmptyLikeModule_falsePinMemory",
|
|
||||||
"EmptyLikeModule_float",
|
|
||||||
"EmptyLikeModule_int",
|
|
||||||
"EmptyModule_contiguous",
|
|
||||||
"EmptyModule_defaultDtype",
|
|
||||||
"EmptyModule_falsePinMemory",
|
|
||||||
"EmptyModule_float",
|
|
||||||
"EmptyModule_int",
|
|
||||||
"EmptyModule_uint8",
|
|
||||||
"EmptyStridedModule_basic",
|
|
||||||
"EmptyStridedSizeIntStrideModule_basic",
|
|
||||||
"EqIntModule_basic",
|
"EqIntModule_basic",
|
||||||
"ExpandModule_basic",
|
"ExpandModule_basic",
|
||||||
"ExponentialModule_basic",
|
"ExponentialModule_basic",
|
||||||
"FloatImplicitModule_basic",
|
"FloatImplicitModule_basic",
|
||||||
"FullLikeModuleInt2D_basic",
|
"FullLikeModuleInt2D_basic",
|
||||||
"FullLikeModuleInt3D_basic",
|
"FullLikeModuleInt3D_basic",
|
||||||
"FullModuleDefaultDtype_basic",
|
|
||||||
"FullModuleFalsePinMemory_basic",
|
|
||||||
"FullModuleFloat2D_basic",
|
|
||||||
"FullModuleFloat3D_basic",
|
|
||||||
"FullModuleInt2D_basic",
|
"FullModuleInt2D_basic",
|
||||||
"FullModuleInt3D_basic",
|
|
||||||
"GeFloatIntModule_basic",
|
"GeFloatIntModule_basic",
|
||||||
"GeFloatModule_basic",
|
"GeFloatModule_basic",
|
||||||
"GeIntModule_basic",
|
"GeIntModule_basic",
|
||||||
|
@ -3547,7 +3555,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"GridSamplerBasic4_basic",
|
"GridSamplerBasic4_basic",
|
||||||
"GtFloatIntModule_basic",
|
"GtFloatIntModule_basic",
|
||||||
"GtIntModule_basic",
|
"GtIntModule_basic",
|
||||||
"HBC_basic",
|
|
||||||
"IndexPut1DFloatAccumulateModule_basic",
|
"IndexPut1DFloatAccumulateModule_basic",
|
||||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||||
"IndexPut1DIntAccumulateModule_basic",
|
"IndexPut1DIntAccumulateModule_basic",
|
||||||
|
@ -3599,7 +3606,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"LinalgVectorNormComplexModule_basic",
|
"LinalgVectorNormComplexModule_basic",
|
||||||
"LinspaceDtypeModule_basic",
|
"LinspaceDtypeModule_basic",
|
||||||
"LinspaceEmptyModule_basic",
|
"LinspaceEmptyModule_basic",
|
||||||
"LinspaceOneSizeModule_basic",
|
|
||||||
"MaskedFillTensorFloatValueModule_basic",
|
"MaskedFillTensorFloatValueModule_basic",
|
||||||
"MatmulBroadcastBatchDim_basic",
|
"MatmulBroadcastBatchDim_basic",
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
|
@ -3653,16 +3659,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"NativeGroupNormBackwardModule_basic",
|
"NativeGroupNormBackwardModule_basic",
|
||||||
"NeFloatIntModule_basic",
|
"NeFloatIntModule_basic",
|
||||||
"NeIntModule_basic",
|
"NeIntModule_basic",
|
||||||
"NewEmptyModuleDefaultDtype_basic",
|
|
||||||
"NewEmptyModuleFalsePinMemory_basic",
|
|
||||||
"NewEmptyModuleFloat2D_basic",
|
|
||||||
"NewEmptyModuleFloat3D_basic",
|
|
||||||
"NewEmptyModuleInt2D_basic",
|
|
||||||
"NewEmptyModuleInt3D_basic",
|
|
||||||
"NewEmptyModuleLayoutIntDtype_basic",
|
|
||||||
"NewEmptyModuleNonDefaultFloatDtype_basic",
|
|
||||||
"NewEmptyModuleNonDefaultIntDtype_basic",
|
|
||||||
"NewEmptyStridedModuleDefaultDtype_basic",
|
|
||||||
"NewFullModuleInt2D_basic",
|
"NewFullModuleInt2D_basic",
|
||||||
"NewFullModuleInt3D_basic",
|
"NewFullModuleInt3D_basic",
|
||||||
"NllLossModuleBackward1DMeanWeight_basic",
|
"NllLossModuleBackward1DMeanWeight_basic",
|
||||||
|
@ -3671,13 +3667,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"NllLossModuleBackward1DSum_basic",
|
"NllLossModuleBackward1DSum_basic",
|
||||||
"NllLossModuleBackward1DWeight_basic",
|
"NllLossModuleBackward1DWeight_basic",
|
||||||
"NllLossModuleBackward1D_basic",
|
"NllLossModuleBackward1D_basic",
|
||||||
"NllLossModuleBackwardMeanWeight_basic",
|
|
||||||
"NllLossModuleBackwardMean_basic",
|
|
||||||
"NllLossModuleBackwardSumWeight_basic",
|
|
||||||
"NllLossModuleBackwardSum_basic",
|
|
||||||
"NllLossModuleBackwardWeight_basic",
|
|
||||||
"NllLossModuleBackward_basic",
|
|
||||||
"NllLossModuleBackward_ignore_index",
|
|
||||||
"NormScalarComplexModule_basic",
|
"NormScalarComplexModule_basic",
|
||||||
"NormScalarModule_basic",
|
"NormScalarModule_basic",
|
||||||
"NormScalarOptDimKeepDimComplexModule_basic",
|
"NormScalarOptDimKeepDimComplexModule_basic",
|
||||||
|
@ -3777,26 +3766,14 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ScatterSrcStaticModule_basic",
|
"ScatterSrcStaticModule_basic",
|
||||||
"ScatterValueFloatModule_basic",
|
"ScatterValueFloatModule_basic",
|
||||||
"ScatterValueIntModule_basic",
|
"ScatterValueIntModule_basic",
|
||||||
"SelectScattertModule_basic",
|
|
||||||
"SelectScattertStaticModule_basic",
|
|
||||||
"SignAndLogarithmOfDeterminantModule_F32",
|
"SignAndLogarithmOfDeterminantModule_F32",
|
||||||
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
||||||
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
||||||
"SliceStaticComplexInputModule_basic",
|
"SliceStaticComplexInputModule_basic",
|
||||||
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
|
||||||
"SliceCopyNegative_Module_basic",
|
|
||||||
"SliceCopyNonZeroDim_Module_basic",
|
|
||||||
"SliceCopyStartGreaterThanDimSize_Module_basic",
|
"SliceCopyStartGreaterThanDimSize_Module_basic",
|
||||||
"SliceCopy_Module_basic",
|
|
||||||
"SliceEndSleStartModule_basic",
|
"SliceEndSleStartModule_basic",
|
||||||
"SliceOutOfLowerBoundEndIndexModule_basic",
|
"SliceOutOfLowerBoundEndIndexModule_basic",
|
||||||
"SliceOutOfLowerBoundStartIndexModule_basic",
|
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||||
"SliceScatterModule_basic",
|
|
||||||
"SliceScatterNegativeDimModule_basic",
|
|
||||||
"SliceScatterNegativeEndModule_basic",
|
|
||||||
"SliceScatterStaticModule_basic",
|
|
||||||
"SliceScatterStepVariationModule_basic",
|
|
||||||
"SliceScatterZeroDimModule_basic",
|
|
||||||
"SliceSizeTwoStepModule_basic",
|
"SliceSizeTwoStepModule_basic",
|
||||||
"SoftplusModule_basic",
|
"SoftplusModule_basic",
|
||||||
"SortIntListReverse_basic",
|
"SortIntListReverse_basic",
|
||||||
|
@ -3864,6 +3841,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_CRASHING_SET = {
|
ONNX_TOSA_CRASHING_SET = {
|
||||||
|
"ScatterSrcStaticModule_basic",
|
||||||
"StdCorrectionEmptyDimModule_basic",
|
"StdCorrectionEmptyDimModule_basic",
|
||||||
"StdDimEmptyDimModule_basic",
|
"StdDimEmptyDimModule_basic",
|
||||||
"VarCorrectionEmptyDimModule_basic",
|
"VarCorrectionEmptyDimModule_basic",
|
||||||
|
@ -3872,6 +3850,11 @@ ONNX_TOSA_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_XFAIL_SET = {
|
ONNX_TOSA_XFAIL_SET = {
|
||||||
|
"Unfold_Module_Dynamic_basic",
|
||||||
|
"Unfold_Module_Rank_4",
|
||||||
|
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||||
|
"Unfold_Module_Rank_Zero_basic",
|
||||||
|
"ViewDtypeStaticModule_basic",
|
||||||
"ArangeZeroElementOutputModule_basic",
|
"ArangeZeroElementOutputModule_basic",
|
||||||
"LinspaceEmptyModule_basic",
|
"LinspaceEmptyModule_basic",
|
||||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||||
|
|
|
@ -1998,3 +1998,136 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
|
||||||
%0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
%0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
|
||||||
|
// CHECK: %[[VAL_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch.constant.device "cpu"
|
||||||
|
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<3x4xi32>) -> tensor<3x4xi64>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi64>}> : () -> tensor<3x4xi64>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.cast %[[VAL_9]] : (tensor<3x4xi64>) -> tensor<3x4xi64>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64>
|
||||||
|
// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],si64>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%none = torch.constant.none
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%int4 = torch.constant.int 4
|
||||||
|
%0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%cpu = torch.constant.device "cpu"
|
||||||
|
%1 = torch.aten.empty.memory_format %0, %int4, %none, %cpu, %false, %none : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[3,4],si64>
|
||||||
|
%2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[3,4],si64>, !torch.int -> !torch.vtensor<[3,4],si64>
|
||||||
|
return %2 : !torch.vtensor<[3,4],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.scatter.src$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,8,6],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4,3],si64>,
|
||||||
|
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> {
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[3,4,3],f32> -> tensor<3x4x3xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4,3],si64> -> tensor<2x4x3xi64>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,8,6],f32> -> tensor<10x8x6xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_4]] : (tensor<2x4x3xi64>) -> tensor<2x4x3xi32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 2, 4, 3, 1>} : (tensor<2x4x3xi32>) -> tensor<2x4x3x1xi32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]]], {{\[\[}}[1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]], {{\[\[}}[0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_8]], %[[VAL_10]] {axis = 3 : i32} : (tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>) -> tensor<2x4x3x3xi32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 36, 1>} : (tensor<3x4x3xf32>) -> tensor<1x36x1xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 480, 1>} : (tensor<10x8x6xf32>) -> tensor<1x480x1xf32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 24, 3>} : (tensor<2x4x3x3xi32>) -> tensor<24x3xi32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[48, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<24x3xi32>, tensor<3xi32>) -> tensor<24x3xi32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<24x3xi32>) -> tensor<24x1xi32>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = tosa.scatter %[[VAL_13]], %[[VAL_18]], %[[VAL_12]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32>) -> tensor<1x480x1xf32>
|
||||||
|
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 10, 8, 6>} : (tensor<1x480x1xf32>) -> tensor<10x8x6xf32>
|
||||||
|
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<10x8x6xf32> -> !torch.vtensor<[10,8,6],f32>
|
||||||
|
// CHECK: return %[[VAL_21]] : !torch.vtensor<[10,8,6],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.scatter.src$basic(%arg0: !torch.vtensor<[10,8,6],f32>, %arg1: !torch.vtensor<[2,4,3],si64>, %arg2: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.aten.scatter.src %arg0, %int1, %arg1, %arg2 : !torch.vtensor<[10,8,6],f32>, !torch.int, !torch.vtensor<[2,4,3],si64>, !torch.vtensor<[3,4,3],f32> -> !torch.vtensor<[10,8,6],f32>
|
||||||
|
return %0 : !torch.vtensor<[10,8,6],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.slice_scatter$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,8],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[6,1],f32> -> tensor<6x1xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,8],f32> -> tensor<6x8xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<6x1xi32>}> : () -> tensor<6x1xi32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array<i64: 6, 1, 1>} : (tensor<6x1xi32>) -> tensor<6x1x1xi32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]], {{\[\[}}4]], {{\[\[}}5]]]> : tensor<6x1x1xi32>}> : () -> tensor<6x1x1xi32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_8]], %[[VAL_7]] {axis = 2 : i32} : (tensor<6x1x1xi32>, tensor<6x1x1xi32>) -> tensor<6x1x2xi32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 6, 1>} : (tensor<6x1xf32>) -> tensor<1x6x1xf32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 48, 1>} : (tensor<6x8xf32>) -> tensor<1x48x1xf32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 6, 2>} : (tensor<6x1x2xi32>) -> tensor<6x2xi32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[8, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<6x2xi32>, tensor<2xi32>) -> tensor<6x2xi32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<6x2xi32>) -> tensor<6x1xi32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 1, 6>} : (tensor<6x1xi32>) -> tensor<1x6xi32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = tosa.scatter %[[VAL_11]], %[[VAL_16]], %[[VAL_10]] : (tensor<1x48x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x48x1xf32>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array<i64: 6, 8>} : (tensor<1x48x1xf32>) -> tensor<6x8xf32>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<6x8xf32> -> !torch.vtensor<[6,8],f32>
|
||||||
|
// CHECK: return %[[VAL_19]] : !torch.vtensor<[6,8],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg1: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.aten.slice_scatter %arg0, %arg1, %int1, %int0, %int1, %int1 : !torch.vtensor<[6,8],f32>, !torch.vtensor<[6,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[6,8],f32>
|
||||||
|
return %0 : !torch.vtensor<[6,8],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.diag_embed$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int -2
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int -1
|
||||||
|
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]], {{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]]> : tensor<2x3x4x1xi32>}> : () -> tensor<2x3x4x1xi32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 2, 3, 4, 1>} : (tensor<2x3x4xf32>) -> tensor<2x3x4x1xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3x4x4xf32>}> : () -> tensor<2x3x4x4xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 2, 3, 4, 1, 1>} : (tensor<2x3x4x1xi32>) -> tensor<2x3x4x1x1xi32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]]], {{\[\[}}{{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_11]], %[[VAL_8]] {axis = 4 : i32} : (tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>) -> tensor<2x3x4x1x4xi32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array<i64: 1, 24, 1>} : (tensor<2x3x4x1xf32>) -> tensor<1x24x1xf32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 1, 96, 1>} : (tensor<2x3x4x4xf32>) -> tensor<1x96x1xf32>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 24, 4>} : (tensor<2x3x4x1x4xi32>) -> tensor<24x4xi32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[48, 16, 4, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_15]], %[[VAL_16]] {shift = 0 : i8} : (tensor<24x4xi32>, tensor<4xi32>) -> tensor<24x4xi32>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<24x4xi32>) -> tensor<24x1xi32>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
|
||||||
|
// CHECK: %[[VAL_20:.*]] = tosa.scatter %[[VAL_14]], %[[VAL_19]], %[[VAL_13]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32>
|
||||||
|
// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array<i64: 2, 3, 4, 4>} : (tensor<1x96x1xf32>) -> tensor<2x3x4x4xf32>
|
||||||
|
// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||||
|
// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_21]], %[[VAL_22]] : (tensor<2x3x4x4xf32>, tensor<4xi32>) -> tensor<2x3x4x4xf32>
|
||||||
|
// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32>
|
||||||
|
// CHECK: return %[[VAL_24]] : !torch.vtensor<[2,3,4,4],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int-2 = torch.constant.int -2
|
||||||
|
%int-1 = torch.constant.int -1
|
||||||
|
%0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,3,4,4],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue