mirror of https://github.com/llvm/torch-mlir
[TOSA] Add legalization for aten.diagonal (#3740)
- Add lowering from Torch to TOSA for aten.diagonal
- Clean up some code
- Update xfail_sets.py with the new e2e results
- Update basic.mlir with the new op MLIR test

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I99bed685455752d09ed96edd837c4dfbee152701
parent 5f74de5ba0
commit 5eab669c4a
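As a quick standalone illustration of the strategy this patch uses (not part of the commit; the helper name diagonalViaMask is invented for the example): the lowering builds a 0/1 mask that is 1 where i + offset == j, multiplies the input by it, and reduce-sums along the innermost dimension so that exactly one element per row survives. A minimal 2-D sketch in C++:

// Illustrative only -- mask-and-reduce diagonal extraction, mirroring the
// tosa.mul + tosa.reduce_sum approach in the pattern below.
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper, not part of the patch.
std::vector<int64_t> diagonalViaMask(const std::vector<std::vector<int64_t>> &m,
                                     int64_t offset) {
  int64_t h = m.size();
  int64_t w = h ? static_cast<int64_t>(m[0].size()) : 0;
  std::vector<int64_t> diag;
  for (int64_t i = 0; i < h; ++i) {
    // Row i contributes an element only if column i + offset is in range.
    if (i + offset < 0 || i + offset >= w)
      continue;
    int64_t sum = 0;
    for (int64_t j = 0; j < w; ++j)
      sum += (i + offset == j) ? m[i][j] : 0; // mask, then reduce
    diag.push_back(sum);
  }
  return diag;
}

int main() {
  std::vector<std::vector<int64_t>> m = {{1, 2, 3}, {4, 5, 6}};
  for (int64_t v : diagonalViaMask(m, /*offset=*/1))
    std::cout << v << ' '; // prints: 2 6 (same as torch.diagonal with offset=1)
  std::cout << '\n';
}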
@@ -891,8 +891,6 @@ public:
     if (!result)
       return failure();
 
-    // TBD - support dtype casting.
-
     rewriter.replaceOp(op, {result.value()});
 
     return success();
@@ -5647,8 +5645,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
   return success();
 }
 
-// Template to create support tril mask tensor for aten.tril
-// legalization
+// Template to create supporting tril mask tensor for aten.tril
 template <typename T>
 Value createTrilMask(PatternRewriter &rewriter, Operation *op,
                      ArrayRef<int64_t> shape, int64_t h, int64_t w,
@@ -5671,28 +5668,6 @@ Value createTrilMask(PatternRewriter &rewriter, Operation *op,
   return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
 }
 
-// Function to get tril mask tensor based on input type
-// for aten.tril legalization
-Value getTrilMask(PatternRewriter &rewriter, Operation *op,
-                  ArrayRef<int64_t> shape, int64_t h, int64_t w,
-                  int64_t diagonal, Type type) {
-  return TypeSwitch<Type, Value>(type)
-      .Case<mlir::FloatType>([&](auto) {
-        return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
-      })
-      .Case<mlir::IntegerType>([&](auto intType) {
-        switch (intType.getWidth()) {
-        case 1:
-          return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
-        case 32:
-          return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
-        case 64:
-          return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
-        }
-        llvm_unreachable("Invalid integer width");
-      });
-}
-
 // Legalization for aten.tril
 template <>
 LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
@@ -5740,14 +5715,31 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");
 
   // Define shape for mask tensor based on rank
-  SmallVector<int64_t> constShape;
+  SmallVector<int64_t> maskShape;
   for (auto i = 0; i < selfRank - 2; i++)
-    constShape.push_back(1);
-  constShape.push_back(h);
-  constShape.push_back(w);
+    maskShape.push_back(1);
+  maskShape.push_back(h);
+  maskShape.push_back(w);
 
-  Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
-                               resultType.getElementType());
+  Value trilMask = TypeSwitch<Type, Value>(resultType.getElementType())
+                       .Case<mlir::FloatType>([&](auto) {
+                         return createTrilMask<float>(rewriter, op, maskShape,
+                                                      h, w, diagonal);
+                       })
+                       .Case<mlir::IntegerType>([&](auto intType) {
+                         switch (intType.getWidth()) {
+                         case 1:
+                           return createTrilMask<bool>(rewriter, op, maskShape,
+                                                       h, w, diagonal);
+                         case 32:
+                           return createTrilMask<int32_t>(
+                               rewriter, op, maskShape, h, w, diagonal);
+                         case 64:
+                           return createTrilMask<int64_t>(
+                               rewriter, op, maskShape, h, w, diagonal);
+                         }
+                         llvm_unreachable("Invalid integer width");
+                       });
 
   rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
                                            /*shift=*/0);
@@ -5755,6 +5747,189 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
   return success();
 }
 
+// Template to create supporting diagonal mask tensor for aten.diagonal
+template <typename T>
+Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
+                         ArrayRef<int64_t> shape, int64_t h, int64_t w,
+                         int64_t offset) {
+  SmallVector<T> vec;
+
+  for (int64_t i = 0; i < h; i++) {
+    for (int64_t j = 0; j < w; j++) {
+      // A positive offset value moves above the main diagonal, while a
+      // negative offset value moves below the main diagonal.
+      if (i + offset == j) {
+        vec.push_back(static_cast<T>(1));
+      } else {
+        vec.push_back(static_cast<T>(0));
+      }
+    }
+  }
+
+  return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
+}
+
+// Legalization for aten.diagonal
+template <>
+LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
+    AtenDiagonalOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+
+  // Not a ranked tensor type
+  auto selfType = dyn_cast<RankedTensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types are supported");
+
+  // Rank below 2 not accepted
+  auto selfRank = selfType.getRank();
+  if (selfRank <= 1)
+    return rewriter.notifyMatchFailure(
+        op, "Rank 0 and 1 are not accepted as they cause underflow");
+
+  if (!selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Currently only static shapes are supported");
+
+  const TypeConverter *typeConverter = this->getTypeConverter();
+  RankedTensorType resultType = cast<RankedTensorType>(
+      typeConverter->convertType(op->getResult(0).getType()));
+  if (!resultType)
+    return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
+
+  auto selfElemTy = selfType.getElementType();
+  auto resultElemTy = resultType.getElementType();
+
+  int64_t offset, dim1, dim2;
+  if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
+    offset = 0;
+
+  if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) {
+    dim1 = 0;
+  } else {
+    dim1 = toPositiveDim(dim1, selfRank);
+  }
+
+  if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) {
+    dim2 = 1;
+  } else {
+    dim2 = toPositiveDim(dim2, selfRank);
+  }
+
+  auto selfShape = makeShapeTorchCompatible(selfType.getShape());
+  int64_t h = selfShape[dim1];
+  int64_t w = selfShape[dim2];
+
+  // Overflowing offset not supported
+  if ((offset < 0 && std::abs(offset) >= h) || (offset >= 0 && offset >= w))
+    return rewriter.notifyMatchFailure(
+        op, "Offset greater or equal than shape not supported");
+
+  int64_t targetDim1 = selfRank - 2;
+  int64_t targetDim2 = selfRank - 1;
+
+  Value selfTransposed = self;
+  SmallVector<int64_t> transposedInputShape = selfShape;
+  RankedTensorType transposedInputType = selfType;
+
+  // If (dim1, dim2) != (rank - 2, rank - 1), transpose the input tensor
+  // so that dim1 and dim2 become rank - 2 and rank - 1. We do this so that
+  // we can consistently create the diagonal mask tensor.
+  if (!(dim1 == targetDim1 && dim2 == targetDim2)) {
+    SmallVector<int32_t> transposedDims;
+    transposedInputShape.clear();
+
+    for (int64_t i = 0; i < selfRank; ++i) {
+      if (i == dim1 || i == dim2)
+        continue;
+      transposedDims.push_back(i);
+    }
+    transposedDims.push_back(dim1);
+    transposedDims.push_back(dim2);
+
+    auto transposedDimsConst = tosa::getConstTensor<int32_t>(
+        rewriter, op,
+        /*vec=*/transposedDims,
+        /*shape=*/{static_cast<int32_t>(selfRank)});
+
+    for (auto &dim : transposedDims)
+      transposedInputShape.push_back(selfShape[dim]);
+
+    transposedInputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
+
+    selfTransposed = rewriter.create<tosa::TransposeOp>(
+        op->getLoc(), transposedInputType, self, transposedDimsConst.value());
+  }
+
+  // Define shape for mask tensor based on rank
+  SmallVector<int64_t> maskShape;
+  for (auto i = 0; i < selfRank - 2; i++)
+    maskShape.push_back(1);
+  maskShape.push_back(h);
+  maskShape.push_back(w);
+
+  Value diagonalMask =
+      TypeSwitch<Type, Value>(resultElemTy)
+          .Case<mlir::FloatType>([&](auto) {
+            return createDiagonalMask<float>(rewriter, op, maskShape, h, w,
+                                             offset);
+          })
+          .Case<mlir::IntegerType>([&](auto intType) {
+            switch (intType.getWidth()) {
+            case 1:
+              return createDiagonalMask<bool>(rewriter, op, maskShape, h, w,
+                                              offset);
+            case 32:
+              return createDiagonalMask<int32_t>(rewriter, op, maskShape, h, w,
+                                                 offset);
+            case 64:
+              return createDiagonalMask<int64_t>(rewriter, op, maskShape, h, w,
+                                                 offset);
+            }
+            llvm_unreachable("Invalid integer width");
+          });
+
+  Value diagonalTensor = rewriter.create<tosa::MulOp>(
+      op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
+      /*shift=*/0);
+
+  auto resultShape = makeShapeTorchCompatible(resultType.getShape());
+  auto targetReduceDim = resultShape[resultType.getRank() - 1];
+
+  // If transposedInputShape[targetDim1] (or h) is greater than the innermost
+  // dim of the result, we won't get the correct shape when we reduce sum along
+  // the innermost dim to get the result. Therefore, we have to slice the
+  // transposed tensor so that transposedInputShape[targetDim1] ==
+  // targetReduceDim.
+  if (h > targetReduceDim) {
+    transposedInputShape[targetDim1] = targetReduceDim;
+    transposedInputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
+    SmallVector<int64_t> startSlice(selfRank, 0);
+    SmallVector<int64_t> sizeSlice =
+        llvm::to_vector(makeShapeTorchCompatible(transposedInputShape));
+    if (offset < 0)
+      startSlice[targetDim1] = std::abs(offset);
+    diagonalTensor = rewriter.create<tosa::SliceOp>(
+        op->getLoc(), transposedInputType, diagonalTensor,
+        rewriter.getDenseI64ArrayAttr(startSlice),
+        rewriter.getDenseI64ArrayAttr(sizeSlice));
+  }
+
+  // Apply Reduce Sum to get the result
+  auto reduceDimType = RankedTensorType::get({1}, rewriter.getI64Type());
+  auto reduceDimAttr =
+      DenseIntElementsAttr::get(reduceDimType, llvm::ArrayRef({targetDim2}));
+  auto result =
+      mlir::tosa::convertReduceSumOp(rewriter, op, resultType, diagonalTensor,
+                                     reduceDimAttr, /*keep_dims=*/false);
+
+  rewriter.replaceOp(op, result.value());
+
+  return success();
+}
 } // namespace
 
 // -----------------------------------------------------------------------------
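To make the shape bookkeeping in the pattern above concrete, here is a rough standalone sketch (illustrative only, not part of the commit) that evaluates the permutation and slice arithmetic for the shapes used in the new lit test: input 3x4x5x6 with dim1 = 1, dim2 = 0, offset = -2. The diagonal-length formula is the standard torch.diagonal one, which the result type encodes.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> selfShape = {3, 4, 5, 6};
  int64_t selfRank = selfShape.size();
  int64_t dim1 = 1, dim2 = 0, offset = -2;
  int64_t h = selfShape[dim1]; // 4
  int64_t w = selfShape[dim2]; // 3

  // Same permutation the pattern builds: non-diagonal dims first, then
  // dim1 and dim2 innermost -> [2, 3, 1, 0], i.e. 3x4x5x6 -> 5x6x4x3.
  std::vector<int64_t> perm;
  for (int64_t i = 0; i < selfRank; ++i)
    if (i != dim1 && i != dim2)
      perm.push_back(i);
  perm.push_back(dim1);
  perm.push_back(dim2);

  // Diagonal length: min(h + offset, w) for a negative offset, otherwise
  // min(h, w - offset). Here min(4 - 2, 3) = 2, so the pattern slices the
  // h dimension with start = |offset| = 2 and size = 2 before the final
  // reduce_sum, giving the 5x6x2 result.
  int64_t diagLen =
      offset < 0 ? std::min(h + offset, w) : std::min(h, w - offset);
  int64_t sliceStart = offset < 0 ? -offset : 0;

  std::cout << "perm:";
  for (int64_t p : perm)
    std::cout << ' ' << p; // 2 3 1 0
  std::cout << "\ndiag length: " << diagLen      // 2
            << ", slice start: " << sliceStart   // 2
            << '\n';
}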
@@ -6060,6 +6235,7 @@ public:
     INSERT_ATENOP_PATTERN(AtenIscloseOp);
     INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
     INSERT_ATENOP_PATTERN(AtenTrilOp);
+    INSERT_ATENOP_PATTERN(AtenDiagonalOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
@@ -1663,6 +1663,8 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "DiagonalWithStaticShapeModule_basic",
+    "EinsumStaticDiagonalDimensionModule_basic",
     "ElementwiseAtenFloorDivideBroadcastModule_basic",
     "ElementwiseAtenFloorDivideScalarModule_basic",
     "ElementwiseAtenFloorDivideScalarNegativeModule_basic",
@@ -3190,6 +3192,7 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "AdaptiveMaxPool1dDimOneStatic_basic",
     "AtenPolarDoubleModule_basic",
     "AtenPolarFloatModule_basic",
     "HstackBasicComplexModule_basic",
@@ -3213,7 +3216,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "Conv_Transpose2dStaticModule_basic",
     "Conv_Transpose3dModule_basic",
     "Conv_Transpose3dStaticModule_basic",
-    "EinsumStaticDiagonalDimensionModule_basic",
     "ElementwiseFloatTensorGtIntTensorModule_basic",
     "ElementwiseIntTensorLtFloatTensorModule_basic",
     "ElementwiseRreluEvalModule_basic",
@@ -3384,14 +3386,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "DeterminantBatchedModule_F32",
     "DeterminantDynamicModule_F32",
     "DeterminantModule_F32",
-    "DiagonalModule_basic",
-    "DiagonalModule_nonsquare",
-    "DiagonalModule_transposed",
-    "DiagonalModule_with_dims",
-    "DiagonalModule_with_dims_and_offset",
-    "DiagonalModule_with_negative_dims",
-    "DiagonalModule_with_offset",
-    "DiagonalWithStaticShapeModule_basic",
     "DivFloatModule_basic",
     "DivIntModule_basic",
     "DropoutTrainModule_basic",
@@ -3805,11 +3799,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ToCopyWithDTypeModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
-    "TraceModule_basic",
     "TraceModule_empty",
-    "TraceModule_nonsquare",
-    "TraceSignedIntModule_basic",
-    "TraceUnsignedIntModule_basic",
     "TraceUnsignedIntModule_empty",
     "TypeConversionI1ToF64Module_basic",
     "TypeConversionI1ToI32Module_basic",
@@ -3845,6 +3835,7 @@ ONNX_TOSA_CRASHING_SET = {
 }
 
 ONNX_TOSA_XFAIL_SET = {
+    "AdaptiveMaxPool1dDimOneStatic_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "HstackBasicComplexModule_basic",
     "HstackBasicFloatModule_basic",
@@ -3874,7 +3865,6 @@ ONNX_TOSA_XFAIL_SET = {
     "Conv_Transpose2dStaticModule_basic",
     "Conv_Transpose3dModule_basic",
     "Conv_Transpose3dStaticModule_basic",
-    "EinsumStaticDiagonalDimensionModule_basic",
    "EinsumStaticModule_basic",
     "ElementwiseFmaxModule_basic",
     "ElementwiseFminModule_basic",
@@ -1859,3 +1859,29 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,
   %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
   return %0: !torch.vtensor<[?,?],si32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.diagonal$basic(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],si32> -> tensor<3x4x5x6xi32>
+// CHECK:         %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK:         %[[VAL_3:.*]] = torch.constant.int 0
+// CHECK:         %[[VAL_4:.*]] = torch.constant.int -2
+// CHECK:         %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK:         %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32>
+// CHECK:         %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32>
+// CHECK:         %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32>
+// CHECK:         %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 5, 6, 2, 3>, start = array<i64: 0, 0, 2, 0>} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32>
+// CHECK:         %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32>
+// CHECK:         %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array<i64: 5, 6, 2>} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32>
+// CHECK:         %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32>
+// CHECK:         return %[[VAL_12]] : !torch.vtensor<[5,6,2],si32>
+// CHECK:       }
+func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> {
+  %dim1 = torch.constant.int 1
+  %dim2 = torch.constant.int 0
+  %offset = torch.constant.int -2
+  %0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32>
+  return %0 : !torch.vtensor<[5,6,2],si32>
+}
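For context (not part of the diff): the CHECK lines above are verified by lit/FileCheck; the test file's RUN line, which is outside this hunk, typically runs torch-mlir-opt with the convert-torch-to-tosa pass over the file and pipes the output to FileCheck.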