mirror of https://github.com/llvm/torch-mlir
[TOSA] Add legalization for aten.diagonal (#3740)
- Add lowering from Torch to TOSA for aten.diagonal
- Clean up some code
- Update xfail_sets.py with the new e2e results
- Update basic.mlir with the new op mlir test

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I99bed685455752d09ed96edd837c4dfbee152701
Signed-off-by: Justin Ngo <justin.ngo@arm.com>

parent 5f74de5ba0
commit 5eab669c4a
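For reference before the diff (illustrative Python, not part of the commit): torch.diagonal extracts the elements a[i, i + offset] from the plane selected by (dim1, dim2) and appends the diagonal as the last dimension of the result, which is the behavior the new TOSA pattern targets.

# Illustrative only: eager semantics of aten.diagonal on a 2-D tensor.
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
print(torch.diagonal(x))             # tensor([ 0.,  5., 10.])  main diagonal
print(torch.diagonal(x, offset=1))   # tensor([ 1.,  6., 11.])  above the main diagonal
print(torch.diagonal(x, offset=-1))  # tensor([4., 9.])         below the main diagonal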
lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -891,8 +891,6 @@ public:
    if (!result)
      return failure();

    // TBD - support dtype casting.

    rewriter.replaceOp(op, {result.value()});

    return success();
@@ -5647,8 +5645,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
  return success();
}

// Template to create support tril mask tensor for aten.tril
// legalization
// Template to create supporting tril mask tensor for aten.tril
template <typename T>
Value createTrilMask(PatternRewriter &rewriter, Operation *op,
                     ArrayRef<int64_t> shape, int64_t h, int64_t w,
@@ -5671,28 +5668,6 @@ Value createTrilMask(PatternRewriter &rewriter, Operation *op,
  return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
}

// Function to get tril mask tensor based on input type
// for aten.tril legalization
Value getTrilMask(PatternRewriter &rewriter, Operation *op,
                  ArrayRef<int64_t> shape, int64_t h, int64_t w,
                  int64_t diagonal, Type type) {
  return TypeSwitch<Type, Value>(type)
      .Case<mlir::FloatType>([&](auto) {
        return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
      })
      .Case<mlir::IntegerType>([&](auto intType) {
        switch (intType.getWidth()) {
        case 1:
          return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
        case 32:
          return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
        case 64:
          return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
        }
        llvm_unreachable("Invalid integer width");
      });
}

// Legalization for aten.tril
template <>
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
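As a rough illustration (an assumption, since createTrilMask's loop body is unchanged and therefore not shown in this diff): the helper is expected to materialize a 0/1 constant with torch.tril semantics, i.e. ones where column - row <= diagonal. A NumPy sketch of such a mask; tril_mask is a hypothetical name used only here:

# Sketch of a tril mask under torch.tril semantics (illustrative only).
import numpy as np

def tril_mask(h, w, diagonal):
    # 1 where torch.tril would keep the element, 0 elsewhere.
    return np.array([[1 if j - i <= diagonal else 0 for j in range(w)]
                     for i in range(h)], dtype=np.int32)

print(tril_mask(3, 4, diagonal=0))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]]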
@@ -5740,14 +5715,31 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
    return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");

  // Define shape for mask tensor based on rank
  SmallVector<int64_t> constShape;
  SmallVector<int64_t> maskShape;
  for (auto i = 0; i < selfRank - 2; i++)
    constShape.push_back(1);
  constShape.push_back(h);
  constShape.push_back(w);
    maskShape.push_back(1);
  maskShape.push_back(h);
  maskShape.push_back(w);

  Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
                               resultType.getElementType());
  Value trilMask = TypeSwitch<Type, Value>(resultType.getElementType())
                       .Case<mlir::FloatType>([&](auto) {
                         return createTrilMask<float>(rewriter, op, maskShape,
                                                      h, w, diagonal);
                       })
                       .Case<mlir::IntegerType>([&](auto intType) {
                         switch (intType.getWidth()) {
                         case 1:
                           return createTrilMask<bool>(rewriter, op, maskShape,
                                                       h, w, diagonal);
                         case 32:
                           return createTrilMask<int32_t>(
                               rewriter, op, maskShape, h, w, diagonal);
                         case 64:
                           return createTrilMask<int64_t>(
                               rewriter, op, maskShape, h, w, diagonal);
                         }
                         llvm_unreachable("Invalid integer width");
                       });

  rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
                                           /*shift=*/0);
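The refactor above keeps the same decomposition, tril(x, k) == x * mask, and only inlines the mask construction via TypeSwitch instead of the removed getTrilMask helper. A quick illustrative check of the identity the final tosa.mul relies on (not code from the patch):

# Elementwise multiply with a tril mask reproduces torch.tril.
import torch

x = torch.arange(2 * 4 * 5).reshape(2, 4, 5)
diagonal = 1
mask = torch.tril(torch.ones(4, 5, dtype=x.dtype), diagonal=diagonal)  # broadcasts over the batch dim
assert torch.equal(torch.tril(x, diagonal=diagonal), x * mask)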
@@ -5755,6 +5747,189 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
  return success();
}

// Template to create supporting diagonal mask tensor for aten.diagonal
template <typename T>
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
                         ArrayRef<int64_t> shape, int64_t h, int64_t w,
                         int64_t offset) {
  SmallVector<T> vec;

  for (int64_t i = 0; i < h; i++) {
    for (int64_t j = 0; j < w; j++) {
      // A positive offset value moves the selected diagonal above the main
      // diagonal, while a negative offset value moves it below.
      if (i + offset == j) {
        vec.push_back(static_cast<T>(1));
      } else {
        vec.push_back(static_cast<T>(0));
      }
    }
  }

  return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
}
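To make the offset handling concrete, here is an illustrative NumPy version of the same selection rule (ones where i + offset == j), evaluated with the h=4, w=3, offset=-2 configuration exercised by the new basic.mlir test further down; diagonal_mask is a hypothetical name used only here:

# Mirror of createDiagonalMask's rule: mask[i, j] = 1 iff i + offset == j.
import numpy as np

def diagonal_mask(h, w, offset):
    return np.array([[1 if i + offset == j else 0 for j in range(w)]
                     for i in range(h)], dtype=np.int32)

print(diagonal_mask(4, 3, offset=-2))
# [[0 0 0]
#  [0 0 0]
#  [1 0 0]
#  [0 1 0]]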
// Legalization for aten.diagonal
template <>
LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
    AtenDiagonalOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.getSelf();

  // Not a ranked tensor type
  auto selfType = dyn_cast<RankedTensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only ranked tensor types are supported");

  // Rank below 2 not accepted
  auto selfRank = selfType.getRank();
  if (selfRank <= 1)
    return rewriter.notifyMatchFailure(
        op, "Rank 0 and 1 are not accepted as they cause underflow");

  if (!selfType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "Currently only static shapes are supported");

  const TypeConverter *typeConverter = this->getTypeConverter();
  RankedTensorType resultType = cast<RankedTensorType>(
      typeConverter->convertType(op->getResult(0).getType()));
  if (!resultType)
    return rewriter.notifyMatchFailure(op, "Result type cannot be empty");

  auto selfElemTy = selfType.getElementType();
  auto resultElemTy = resultType.getElementType();

  int64_t offset, dim1, dim2;
  if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
    offset = 0;

  if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) {
    dim1 = 0;
  } else {
    dim1 = toPositiveDim(dim1, selfRank);
  }

  if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) {
    dim2 = 1;
  } else {
    dim2 = toPositiveDim(dim2, selfRank);
  }

  auto selfShape = makeShapeTorchCompatible(selfType.getShape());
  int64_t h = selfShape[dim1];
  int64_t w = selfShape[dim2];

  // Overflowing offset not supported
  if ((offset < 0 && std::abs(offset) >= h) || (offset >= 0 && offset >= w))
    return rewriter.notifyMatchFailure(
        op, "Offset greater than or equal to shape not supported");

  int64_t targetDim1 = selfRank - 2;
  int64_t targetDim2 = selfRank - 1;

  Value selfTransposed = self;
  SmallVector<int64_t> transposedInputShape = selfShape;
  RankedTensorType transposedInputType = selfType;

  // If (dim1, dim2) != (rank - 2, rank - 1), transpose the input tensor
  // so that dim1 and dim2 become rank - 2 and rank - 1. We do this so that
  // we can consistently create the diagonal mask tensor.
  if (!(dim1 == targetDim1 && dim2 == targetDim2)) {
    SmallVector<int32_t> transposedDims;
    transposedInputShape.clear();

    for (int64_t i = 0; i < selfRank; ++i) {
      if (i == dim1 || i == dim2)
        continue;
      transposedDims.push_back(i);
    }
    transposedDims.push_back(dim1);
    transposedDims.push_back(dim2);

    auto transposedDimsConst = tosa::getConstTensor<int32_t>(
        rewriter, op,
        /*vec=*/transposedDims,
        /*shape=*/{static_cast<int32_t>(selfRank)});

    for (auto &dim : transposedDims)
      transposedInputShape.push_back(selfShape[dim]);

    transposedInputType = RankedTensorType::get(
        makeShapeLLVMCompatible(transposedInputShape), selfElemTy);

    selfTransposed = rewriter.create<tosa::TransposeOp>(
        op->getLoc(), transposedInputType, self, transposedDimsConst.value());
  }
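For the test case added below (input of shape [3,4,5,6] with dim1=1, dim2=0), the loop above yields the permutation [2, 3, 1, 0], i.e. the two diagonal dimensions are moved to the last two positions. A small illustrative sketch (diagonal_permutation is a hypothetical name):

# Reproduces the permutation computed by the loop above (illustration only).
def diagonal_permutation(rank, dim1, dim2):
    perm = [i for i in range(rank) if i not in (dim1, dim2)]
    return perm + [dim1, dim2]

print(diagonal_permutation(4, dim1=1, dim2=0))  # [2, 3, 1, 0]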
  // Define shape for mask tensor based on rank
  SmallVector<int64_t> maskShape;
  for (auto i = 0; i < selfRank - 2; i++)
    maskShape.push_back(1);
  maskShape.push_back(h);
  maskShape.push_back(w);

  Value diagonalMask =
      TypeSwitch<Type, Value>(resultElemTy)
          .Case<mlir::FloatType>([&](auto) {
            return createDiagonalMask<float>(rewriter, op, maskShape, h, w,
                                             offset);
          })
          .Case<mlir::IntegerType>([&](auto intType) {
            switch (intType.getWidth()) {
            case 1:
              return createDiagonalMask<bool>(rewriter, op, maskShape, h, w,
                                              offset);
            case 32:
              return createDiagonalMask<int32_t>(rewriter, op, maskShape, h, w,
                                                 offset);
            case 64:
              return createDiagonalMask<int64_t>(rewriter, op, maskShape, h, w,
                                                 offset);
            }
            llvm_unreachable("Invalid integer width");
          });

  Value diagonalTensor = rewriter.create<tosa::MulOp>(
      op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
      /*shift=*/0);

  auto resultShape = makeShapeTorchCompatible(resultType.getShape());
  auto targetReduceDim = resultShape[resultType.getRank() - 1];

  // If transposedInputShape[targetDim1] (or h) is greater than the innermost
  // dim of the result, we won't get the correct shape when we reduce sum along
  // the innermost dim to get the result. Therefore, we have to slice the
  // transposed tensor so that transposedInputShape[targetDim1] ==
  // targetReduceDim.
  if (h > targetReduceDim) {
    transposedInputShape[targetDim1] = targetReduceDim;
    transposedInputType = RankedTensorType::get(
        makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
    SmallVector<int64_t> startSlice(selfRank, 0);
    SmallVector<int64_t> sizeSlice =
        llvm::to_vector(makeShapeTorchCompatible(transposedInputShape));
    if (offset < 0)
      startSlice[targetDim1] = std::abs(offset);
    diagonalTensor = rewriter.create<tosa::SliceOp>(
        op->getLoc(), transposedInputType, diagonalTensor,
        rewriter.getDenseI64ArrayAttr(startSlice),
        rewriter.getDenseI64ArrayAttr(sizeSlice));
  }

  // Apply Reduce Sum to get the result
  auto reduceDimType = RankedTensorType::get({1}, rewriter.getI64Type());
  auto reduceDimAttr =
      DenseIntElementsAttr::get(reduceDimType, llvm::ArrayRef({targetDim2}));
  auto result =
      mlir::tosa::convertReduceSumOp(rewriter, op, resultType, diagonalTensor,
                                     reduceDimAttr, /*keep_dims=*/false);

  rewriter.replaceOp(op, result.value());

  return success();
}
} // namespace

// -----------------------------------------------------------------------------
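End to end, the new pattern is transpose, mask multiply, optional slice, then reduce-sum over the innermost dimension. Below is an illustrative NumPy reference of that pipeline, checked against np.diagonal for a static case; diagonal_via_mask is a hypothetical helper written only to mirror the lowering, and it assumes the offset is within bounds, as the pattern requires.

# NumPy reference of the decomposition: transpose -> mask multiply ->
# slice (when the diagonal is shorter than h) -> reduce over the last dim.
import numpy as np

def diagonal_via_mask(x, offset=0, dim1=0, dim2=1):
    rank = x.ndim
    dim1, dim2 = dim1 % rank, dim2 % rank
    perm = [i for i in range(rank) if i not in (dim1, dim2)] + [dim1, dim2]
    xt = np.transpose(x, perm)                       # tosa.transpose
    h, w = xt.shape[-2], xt.shape[-1]
    mask = np.array([[1 if i + offset == j else 0 for j in range(w)]
                     for i in range(h)], dtype=x.dtype)
    masked = xt * mask                               # tosa.mul
    diag_len = min(h + offset, w) if offset < 0 else min(h, w - offset)
    start = -offset if offset < 0 else 0
    masked = masked[..., start:start + diag_len, :]  # tosa.slice
    return masked.sum(axis=-1)                       # tosa.reduce_sum + reshape

x = np.arange(3 * 4 * 5 * 6, dtype=np.int64).reshape(3, 4, 5, 6)
expected = np.diagonal(x, offset=-2, axis1=1, axis2=0)
assert np.array_equal(diagonal_via_mask(x, offset=-2, dim1=1, dim2=0), expected)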
@@ -6060,6 +6235,7 @@ public:
    INSERT_ATENOP_PATTERN(AtenIscloseOp);
    INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
    INSERT_ATENOP_PATTERN(AtenTrilOp);
    INSERT_ATENOP_PATTERN(AtenDiagonalOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
projects/pt1/e2e_testing/xfail_sets.py
@@ -1663,6 +1663,8 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
    "DiagonalWithStaticShapeModule_basic",
    "EinsumStaticDiagonalDimensionModule_basic",
    "ElementwiseAtenFloorDivideBroadcastModule_basic",
    "ElementwiseAtenFloorDivideScalarModule_basic",
    "ElementwiseAtenFloorDivideScalarNegativeModule_basic",

@@ -3190,6 +3192,7 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
}

FX_IMPORTER_TOSA_XFAIL_SET = {
    "AdaptiveMaxPool1dDimOneStatic_basic",
    "AtenPolarDoubleModule_basic",
    "AtenPolarFloatModule_basic",
    "HstackBasicComplexModule_basic",

@@ -3213,7 +3216,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "Conv_Transpose2dStaticModule_basic",
    "Conv_Transpose3dModule_basic",
    "Conv_Transpose3dStaticModule_basic",
    "EinsumStaticDiagonalDimensionModule_basic",
    "ElementwiseFloatTensorGtIntTensorModule_basic",
    "ElementwiseIntTensorLtFloatTensorModule_basic",
    "ElementwiseRreluEvalModule_basic",

@@ -3384,14 +3386,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "DeterminantBatchedModule_F32",
    "DeterminantDynamicModule_F32",
    "DeterminantModule_F32",
    "DiagonalModule_basic",
    "DiagonalModule_nonsquare",
    "DiagonalModule_transposed",
    "DiagonalModule_with_dims",
    "DiagonalModule_with_dims_and_offset",
    "DiagonalModule_with_negative_dims",
    "DiagonalModule_with_offset",
    "DiagonalWithStaticShapeModule_basic",
    "DivFloatModule_basic",
    "DivIntModule_basic",
    "DropoutTrainModule_basic",

@@ -3805,11 +3799,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "ToCopyWithDTypeModule_basic",
    "TorchPrimLoopForLikeModule_basic",
    "TorchPrimLoopWhileLikeModule_basic",
    "TraceModule_basic",
    "TraceModule_empty",
    "TraceModule_nonsquare",
    "TraceSignedIntModule_basic",
    "TraceUnsignedIntModule_basic",
    "TraceUnsignedIntModule_empty",
    "TypeConversionI1ToF64Module_basic",
    "TypeConversionI1ToI32Module_basic",

@@ -3845,6 +3835,7 @@ ONNX_TOSA_CRASHING_SET = {
}

ONNX_TOSA_XFAIL_SET = {
    "AdaptiveMaxPool1dDimOneStatic_basic",
    "ScaledDotProductAttentionDifferentCausalModule_basic",
    "HstackBasicComplexModule_basic",
    "HstackBasicFloatModule_basic",

@@ -3874,7 +3865,6 @@ ONNX_TOSA_XFAIL_SET = {
    "Conv_Transpose2dStaticModule_basic",
    "Conv_Transpose3dModule_basic",
    "Conv_Transpose3dStaticModule_basic",
    "EinsumStaticDiagonalDimensionModule_basic",
    "EinsumStaticModule_basic",
    "ElementwiseFmaxModule_basic",
    "ElementwiseFminModule_basic",
test/Conversion/TorchToTosa/basic.mlir
@@ -1859,3 +1859,29 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,
  %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
  return %0: !torch.vtensor<[?,?],si32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.diagonal$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],si32> -> tensor<3x4x5x6xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
// CHECK: %[[VAL_4:.*]] = torch.constant.int -2
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32>
// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32>
// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 5, 6, 2, 3>, start = array<i64: 0, 0, 2, 0>} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32>
// CHECK: %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array<i64: 5, 6, 2>} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32>
// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32>
// CHECK: return %[[VAL_12]] : !torch.vtensor<[5,6,2],si32>
// CHECK: }
func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> {
  %dim1 = torch.constant.int 1
  %dim2 = torch.constant.int 0
  %offset = torch.constant.int -2
  %0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32>
  return %0 : !torch.vtensor<[5,6,2],si32>
}
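The FileCheck expectations above line up with eager PyTorch for the same operands; a quick illustrative sanity check of the result type (and of why the slice starts at 2 along the transposed dim1 axis):

# Illustrative check of the test's expected shape against eager PyTorch.
import torch

x = torch.arange(3 * 4 * 5 * 6, dtype=torch.int32).reshape(3, 4, 5, 6)
d = torch.diagonal(x, offset=-2, dim1=1, dim2=0)
print(d.shape)  # torch.Size([5, 6, 2]) -> !torch.vtensor<[5,6,2],si32>
# With offset=-2, only rows 2 and 3 of the (dim1, dim2) plane contribute,
# matching the tosa.slice start of 2 along that axis in the CHECK lines.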