From 5eab669c4ab0c3aab3dab5b95d0172ab0a8395b8 Mon Sep 17 00:00:00 2001
From: Justin Ngo
Date: Mon, 30 Sep 2024 08:24:31 -0700
Subject: [PATCH] [TOSA] Add legalization for aten.diagonal (#3740)

- Add lowering from Torch to TOSA for aten.diagonal
- Clean up some code
- Update xfail_sets.py with the new e2e results
- Update basic_mlir with the new op mlir test

Signed-off-by: Justin Ngo
Change-Id: I99bed685455752d09ed96edd837c4dfbee152701
Signed-off-by: Justin Ngo
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 240 ++++++++++++++++++---
 projects/pt1/e2e_testing/xfail_sets.py     |  18 +-
 test/Conversion/TorchToTosa/basic.mlir     |  26 +++
 3 files changed, 238 insertions(+), 46 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 2a6b1612c..302752465 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -891,8 +891,6 @@ public:
     if (!result)
       return failure();
 
-    // TBD - support dtype casting.
-
     rewriter.replaceOp(op, {result.value()});
 
     return success();
@@ -5647,8 +5645,7 @@ ConvertAtenOp::matchAndRewrite(
   return success();
 }
-// Template to create support tril mask tensor for aten.tril
-// legalization
+// Template to create supporting tril mask tensor for aten.tril
 template <typename T>
 Value createTrilMask(PatternRewriter &rewriter, Operation *op,
                      ArrayRef<int64_t> shape, int64_t h, int64_t w,
                      int64_t diagonal) {
@@ -5671,28 +5668,6 @@ Value createTrilMask(PatternRewriter &rewriter, Operation *op,
   return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
 }
 
-// Function to get tril mask tensor based on input type
-// for aten.tril legalization
-Value getTrilMask(PatternRewriter &rewriter, Operation *op,
-                  ArrayRef<int64_t> shape, int64_t h, int64_t w,
-                  int64_t diagonal, Type type) {
-  return TypeSwitch<Type, Value>(type)
-      .Case<mlir::FloatType>([&](auto) {
-        return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
-      })
-      .Case<mlir::IntegerType>([&](auto intType) {
-        switch (intType.getWidth()) {
-        case 1:
-          return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
-        case 32:
-          return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
-        case 64:
-          return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
-        }
-        llvm_unreachable("Invalid integer width");
-      });
-}
-
 // Legalization for aten.tril
 template <>
 LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
@@ -5740,14 +5715,31 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");
 
   // Define shape for mask tensor based on rank
-  SmallVector<int64_t> constShape;
+  SmallVector<int64_t> maskShape;
   for (auto i = 0; i < selfRank - 2; i++)
-    constShape.push_back(1);
-  constShape.push_back(h);
-  constShape.push_back(w);
+    maskShape.push_back(1);
+  maskShape.push_back(h);
+  maskShape.push_back(w);
 
-  Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
-                               resultType.getElementType());
+  Value trilMask = TypeSwitch<Type, Value>(resultType.getElementType())
+                       .Case<mlir::FloatType>([&](auto) {
+                         return createTrilMask<float>(rewriter, op, maskShape,
+                                                      h, w, diagonal);
+                       })
+                       .Case<mlir::IntegerType>([&](auto intType) {
+                         switch (intType.getWidth()) {
+                         case 1:
+                           return createTrilMask<bool>(rewriter, op, maskShape,
+                                                       h, w, diagonal);
+                         case 32:
+                           return createTrilMask<int32_t>(
+                               rewriter, op, maskShape, h, w, diagonal);
+                         case 64:
+                           return createTrilMask<int64_t>(
+                               rewriter, op, maskShape, h, w, diagonal);
+                         }
+                         llvm_unreachable("Invalid integer width");
+                       });
 
   rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
                                            /*shift=*/0);
@@ -5755,6 +5747,189 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
   return success();
 }
 
+// Template to create supporting diagonal mask tensor for aten.diagonal
+template <typename T>
+Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
+                         ArrayRef<int64_t> shape, int64_t h, int64_t w,
+                         int64_t offset) {
+  SmallVector<T> vec;
+
+  for (int64_t i = 0; i < h; i++) {
+    for (int64_t j = 0; j < w; j++) {
+      // Positive offset value moves above the main diagonal, while negative
+      // offset value moves below the main diagonal.
+      if (i + offset == j) {
+        vec.push_back(static_cast<T>(1));
+      } else {
+        vec.push_back(static_cast<T>(0));
+      }
+    }
+  }
+
+  return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
+}
+
+// Legalization for aten.diagonal
+template <>
+LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
+    AtenDiagonalOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
+
+  // Not a ranked tensor type
+  auto selfType = dyn_cast<RankedTensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types are supported");
+
+  // Rank below 2 not accepted
+  auto selfRank = selfType.getRank();
+  if (selfRank <= 1)
+    return rewriter.notifyMatchFailure(
+        op, "Rank 0 and 1 are not accepted as they cause underflow");
+
+  if (!selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Currently only static shapes are supported");
+
+  const TypeConverter *typeConverter = this->getTypeConverter();
+  RankedTensorType resultType = cast<RankedTensorType>(
+      typeConverter->convertType(op->getResult(0).getType()));
+  if (!resultType)
+    return rewriter.notifyMatchFailure(op, "Result type cannot be empty");
+
+  auto selfElemTy = selfType.getElementType();
+  auto resultElemTy = resultType.getElementType();
+
+  int64_t offset, dim1, dim2;
+  if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
+    offset = 0;
+
+  if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) {
+    dim1 = 0;
+  } else {
+    dim1 = toPositiveDim(dim1, selfRank);
+  }
+
+  if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) {
+    dim2 = 1;
+  } else {
+    dim2 = toPositiveDim(dim2, selfRank);
+  }
+
+  auto selfShape = makeShapeTorchCompatible(selfType.getShape());
+  int64_t h = selfShape[dim1];
+  int64_t w = selfShape[dim2];
+
+  // Overflowing offset not supported
+  if ((offset < 0 && std::abs(offset) >= h) || (offset >= 0 && offset >= w))
+    return rewriter.notifyMatchFailure(
+        op, "Offset greater than or equal to shape is not supported");
+
+  int64_t targetDim1 = selfRank - 2;
+  int64_t targetDim2 = selfRank - 1;
+
+  Value selfTransposed = self;
+  SmallVector<int64_t> transposedInputShape = selfShape;
+  RankedTensorType transposedInputType = selfType;
+
+  // If (dim1, dim2) != (rank - 2, rank - 1), transpose the input tensor
+  // so that dim1 and dim2 become rank - 2 and rank - 1. We do this so that
+  // we can consistently create the diagonal mask tensor.
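+  // For example, with a rank-4 input and (dim1, dim2) == (1, 0), the
+  // permutation computed below is [2, 3, 1, 0], so a [3,4,5,6] input is
+  // rearranged to [5,6,4,3] before the mask is applied (this matches the
+  // tosa.const perms in the new basic.mlir test).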
+  if (!(dim1 == targetDim1 && dim2 == targetDim2)) {
+    SmallVector<int32_t> transposedDims;
+    transposedInputShape.clear();
+
+    for (int64_t i = 0; i < selfRank; ++i) {
+      if (i == dim1 || i == dim2)
+        continue;
+      transposedDims.push_back(i);
+    }
+    transposedDims.push_back(dim1);
+    transposedDims.push_back(dim2);
+
+    auto transposedDimsConst = tosa::getConstTensor<int32_t>(
+        rewriter, op,
+        /*vec=*/transposedDims,
+        /*shape=*/{static_cast<int32_t>(selfRank)});
+
+    for (auto &dim : transposedDims)
+      transposedInputShape.push_back(selfShape[dim]);
+
+    transposedInputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
+
+    selfTransposed = rewriter.create<tosa::TransposeOp>(
+        op->getLoc(), transposedInputType, self, transposedDimsConst.value());
+  }
+
+  // Define shape for mask tensor based on rank
+  SmallVector<int64_t> maskShape;
+  for (auto i = 0; i < selfRank - 2; i++)
+    maskShape.push_back(1);
+  maskShape.push_back(h);
+  maskShape.push_back(w);
+
+  Value diagonalMask =
+      TypeSwitch<Type, Value>(resultElemTy)
+          .Case<mlir::FloatType>([&](auto) {
+            return createDiagonalMask<float>(rewriter, op, maskShape, h, w,
+                                             offset);
+          })
+          .Case<mlir::IntegerType>([&](auto intType) {
+            switch (intType.getWidth()) {
+            case 1:
+              return createDiagonalMask<bool>(rewriter, op, maskShape, h, w,
+                                              offset);
+            case 32:
+              return createDiagonalMask<int32_t>(rewriter, op, maskShape, h, w,
+                                                 offset);
+            case 64:
+              return createDiagonalMask<int64_t>(rewriter, op, maskShape, h, w,
+                                                 offset);
+            }
+            llvm_unreachable("Invalid integer width");
+          });
+
+  Value diagonalTensor = rewriter.create<tosa::MulOp>(
+      op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
+      /*shift=*/0);
+
+  auto resultShape = makeShapeTorchCompatible(resultType.getShape());
+  auto targetReduceDim = resultShape[resultType.getRank() - 1];
+
+  // If transposedInputShape[targetDim1] (or h) is greater than the innermost
+  // dim of the result, we won't get the correct shape when we reduce sum along
+  // the innermost dim to get the result. Therefore, we have to slice the
+  // transposed tensor so that transposedInputShape[targetDim1] ==
+  // targetReduceDim.
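+  // For example, in the new basic.mlir test h == 4 while the innermost result
+  // dim is 2, so the masked [5,6,4,3] tensor is sliced down to [5,6,2,3]
+  // (starting at row |offset| == 2) before the reduce_sum below.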
+  if (h > targetReduceDim) {
+    transposedInputShape[targetDim1] = targetReduceDim;
+    transposedInputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
+    SmallVector<int64_t> startSlice(selfRank, 0);
+    SmallVector<int64_t> sizeSlice =
+        llvm::to_vector(makeShapeTorchCompatible(transposedInputShape));
+    if (offset < 0)
+      startSlice[targetDim1] = std::abs(offset);
+    diagonalTensor = rewriter.create<tosa::SliceOp>(
+        op->getLoc(), transposedInputType, diagonalTensor,
+        rewriter.getDenseI64ArrayAttr(startSlice),
+        rewriter.getDenseI64ArrayAttr(sizeSlice));
+  }
+
+  // Apply Reduce Sum to get the result
+  auto reduceDimType = RankedTensorType::get({1}, rewriter.getI64Type());
+  auto reduceDimAttr =
+      DenseIntElementsAttr::get(reduceDimType, llvm::ArrayRef({targetDim2}));
+  auto result =
+      mlir::tosa::convertReduceSumOp(rewriter, op, resultType, diagonalTensor,
+                                     reduceDimAttr, /*keep_dims=*/false);
+
+  rewriter.replaceOp(op, result.value());
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -6060,6 +6235,7 @@ public:
     INSERT_ATENOP_PATTERN(AtenIscloseOp);
     INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
     INSERT_ATENOP_PATTERN(AtenTrilOp);
+    INSERT_ATENOP_PATTERN(AtenDiagonalOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 53f1b3647..2852611fe 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1663,6 +1663,8 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "DiagonalWithStaticShapeModule_basic",
+    "EinsumStaticDiagonalDimensionModule_basic",
     "ElementwiseAtenFloorDivideBroadcastModule_basic",
     "ElementwiseAtenFloorDivideScalarModule_basic",
     "ElementwiseAtenFloorDivideScalarNegativeModule_basic",
@@ -3190,6 +3192,7 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "AdaptiveMaxPool1dDimOneStatic_basic",
     "AtenPolarDoubleModule_basic",
     "AtenPolarFloatModule_basic",
     "HstackBasicComplexModule_basic",
@@ -3213,7 +3216,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "Conv_Transpose2dStaticModule_basic",
     "Conv_Transpose3dModule_basic",
     "Conv_Transpose3dStaticModule_basic",
-    "EinsumStaticDiagonalDimensionModule_basic",
     "ElementwiseFloatTensorGtIntTensorModule_basic",
     "ElementwiseIntTensorLtFloatTensorModule_basic",
     "ElementwiseRreluEvalModule_basic",
@@ -3384,14 +3386,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "DeterminantBatchedModule_F32",
     "DeterminantDynamicModule_F32",
     "DeterminantModule_F32",
-    "DiagonalModule_basic",
-    "DiagonalModule_nonsquare",
-    "DiagonalModule_transposed",
-    "DiagonalModule_with_dims",
-    "DiagonalModule_with_dims_and_offset",
-    "DiagonalModule_with_negative_dims",
-    "DiagonalModule_with_offset",
-    "DiagonalWithStaticShapeModule_basic",
     "DivFloatModule_basic",
     "DivIntModule_basic",
     "DropoutTrainModule_basic",
@@ -3805,11 +3799,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ToCopyWithDTypeModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
-    "TraceModule_basic",
     "TraceModule_empty",
-    "TraceModule_nonsquare",
-    "TraceSignedIntModule_basic",
-    "TraceUnsignedIntModule_basic",
     "TraceUnsignedIntModule_empty",
     "TypeConversionI1ToF64Module_basic",
     "TypeConversionI1ToI32Module_basic",
@@ -3845,6 +3835,7 @@ ONNX_TOSA_CRASHING_SET = {
 }
 
 ONNX_TOSA_XFAIL_SET = {
"AdaptiveMaxPool1dDimOneStatic_basic", "ScaledDotProductAttentionDifferentCausalModule_basic", "HstackBasicComplexModule_basic", "HstackBasicFloatModule_basic", @@ -3874,7 +3865,6 @@ ONNX_TOSA_XFAIL_SET = { "Conv_Transpose2dStaticModule_basic", "Conv_Transpose3dModule_basic", "Conv_Transpose3dStaticModule_basic", - "EinsumStaticDiagonalDimensionModule_basic", "EinsumStaticModule_basic", "ElementwiseFmaxModule_basic", "ElementwiseFminModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 4e2920708..9957f5207 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1859,3 +1859,29 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?, %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32> return %0: !torch.vtensor<[?,?],si32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.diagonal$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],si32> -> tensor<3x4x5x6xi32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.int -2 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32> +// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[5,6,2],si32> +// CHECK: } +func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> { + %dim1 = torch.constant.int 1 + %dim2 = torch.constant.int 0 + %offset = torch.constant.int -2 + %0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32> + return %0 : !torch.vtensor<[5,6,2],si32> +}