From 740de4f2608370dc935ab9082b779ec48d1345ca Mon Sep 17 00:00:00 2001
From: giacs-epic <179146510+giacs-epic@users.noreply.github.com>
Date: Tue, 19 Nov 2024 13:46:25 +0000
Subject: [PATCH] Address review feedback

---
 .../torch-mlir/Dialect/Torch/Utils/Utils.h    |   2 +
 lib/Conversion/TorchToLinalg/Linear.cpp       |   6 +-
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 115 ++++++------------
 lib/Dialect/Torch/Utils/Utils.cpp             |  12 ++
 test/Dialect/Torch/decompose-complex-ops.mlir |  67 ++++------
 5 files changed, 77 insertions(+), 125 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index cf31c8f97..dd13c40fa 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -20,6 +20,8 @@ namespace Torch {
 int64_t toPositiveDim(int64_t dim, int64_t inputRank);
 bool isValidDim(int64_t dim, int64_t inputRank);
+Value toIntListConstruct(PatternRewriter &rewriter, Location loc,
+                         ArrayRef<int64_t> cstInput, Torch::IntType intType);
 bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
 /// Returns the index indicated by `v` for a list of given `length`.
 /// If the index is negative, it is adjusted to `length` + `v`.
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index f6dd45df3..850e292fb 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1391,11 +1391,7 @@ Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType,
   for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
     for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
       double v = scale * i * j;
-      if (isRealPart) {
-        v = cos(v);
-      } else {
-        v = -sin(v);
-      }
+      v = isRealPart ? cos(v) : -sin(v);
       values.push_back(b.getF32FloatAttr(v));
     }
   }
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 888b13f1c..64d9c769f 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -9024,8 +9024,10 @@ namespace {
 
 /// Creates coefficients based on DFT definition, see
 /// https://en.wikipedia.org/wiki/Discrete_Fourier_transform.
+/// Even indices of the second dimension are for the real components of the
+/// output. Odd indices for the imaginary components.
 Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
-                        ValueTensorType matrixType, bool isRealPart) {
+                        ValueTensorType matrixType) {
   // scale = 2 * pi / N
   double scale = 2 * M_PI / matrixType.getSizes()[0];
 
@@ -9033,12 +9035,9 @@ Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
   assert(matrixType.getSizes().size() == 2 && "expected 2D matrix");
   for (auto i : llvm::seq<unsigned>(0, matrixType.getSizes()[0])) {
     for (auto j : llvm::seq<unsigned>(0, matrixType.getSizes()[1])) {
-      double v = scale * i * j;
-      if (isRealPart) {
-        v = cos(v);
-      } else {
-        v = -sin(v);
-      }
+      const bool isImagPart = j % 2;
+      double v = scale * i * (j / 2);
+      v = isImagPart ? -sin(v) : cos(v);
       values.push_back(rewriter.getF32FloatAttr(v));
     }
   }
@@ -9049,29 +9048,6 @@ Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
                              ArrayRef<Attribute>(values)));
 }
 
-Value createBatchMatmul(PatternRewriter &rewriter, Location loc, Value lhs,
-                        Value rhs) {
-
-  BaseTensorType lhsType = cast<BaseTensorType>(lhs.getType());
-  assert(lhsType && lhsType.hasSizes());
-  const ArrayRef<int64_t> lhsShape = lhsType.getSizes();
-  assert(lhsShape.size() >= 2);
-  BaseTensorType rhsType = cast<BaseTensorType>(rhs.getType());
-  assert(rhsType && rhsType.hasSizes());
-  const ArrayRef<int64_t> rhsShape = rhsType.getSizes();
-  assert(rhsShape.size() >= 2);
-  assert(rhsShape[rhsShape.size() - 2] == lhsShape[lhsShape.size() - 1]);
-
-  SmallVector<int64_t> resShape(lhsShape);
-  resShape[resShape.size() - 1] = rhsShape[rhsShape.size() - 1];
-
-  Type dtype = lhsType.getOptionalDtype();
-
-  ValueTensorType resType =
-      ValueTensorType::get(rewriter.getContext(), resShape, dtype);
-  return rewriter.create<AtenMatmulOp>(loc, resType, lhs, rhs);
-}
-
 class DecomposeAtenFftRfftOp final : public OpRewritePattern<AtenFftRfftOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -9133,66 +9109,51 @@ class DecomposeAtenFftRfftOp final : public OpRewritePattern<AtenFftRfftOp> {
       return success();
     };
+    SmallVector<int64_t> lhsShape(inputShape);
     // Transpose if FFT dimension is not the last one
     if (needTranspose) {
       if (failed(transposeValue(rewriter, loc, self, dim, lastDim, self))) {
         return failure();
       }
+      std::swap(lhsShape[dim], lhsShape[lastDim]);
     }
+
     // self : (D_0 x ... x D_m x fftLength)
-    // lhs = unsqueeze(self, -2) : (D x 1 x fftLength), D = [D_1, D_2, ...]
-    Value unsqueezeDim =
-        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-2));
-    auto unsqueezed = unsqueezeTensor(rewriter, op, self, unsqueezeDim);
-    if (failed(unsqueezed))
-      return rewriter.notifyMatchFailure(op,
-                                         "cannot generate unsqueezed tensor");
-    Value lhs = *unsqueezed;
 
     Type dtype = inputType.getOptionalDtype();
 
-    Value real, complex;
+    // coeff : (fftLength x outputFftDim*2)
+    ValueTensorType matrixType = ValueTensorType::get(
+        op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim * 2},
+        dtype);
+    Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType);
 
-    for (const bool isRealPart : {true, false}) {
+    // X = matmul(self, coeff) : (D_0 x ... x D_m x outputFftDim*2)
+    SmallVector<int64_t> matmulShape(lhsShape.begin(), lhsShape.end() - 1);
+    matmulShape.push_back(outputFftDim * 2);
+    ValueTensorType matmulType =
+        ValueTensorType::get(op.getContext(), matmulShape, dtype);
+    Value flatRes =
+        rewriter.create<AtenMatmulOp>(loc, matmulType, self, coeffMatrix);
 
-      // coeff : (fftLength x outputFftDim)
-      ValueTensorType matrixType = ValueTensorType::get(
-          op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim},
-          dtype);
-      Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType,
-                                            /*isRealPart=*/isRealPart);
-
-      // X = matmul(lhs, coeff) : (D x 1 x outputFftDim)
-      Value matmulRes = createBatchMatmul(rewriter, loc, lhs, coeffMatrix);
-
-      // Y = squeeze(X, -2) : (D x outputFftDim)
-      auto squeezed = squeezeTensor(rewriter, op, loc, -2, matmulRes);
-      if (failed(squeezed))
-        return rewriter.notifyMatchFailure(op,
-                                           "cannot generate squeezed tensor");
-
-      if (isRealPart) {
-        real = *squeezed;
-      } else {
-        complex = *squeezed;
-      }
-    }
-
-    // Pack components into a complex tensor
-    BaseTensorType realType = cast<BaseTensorType>(real.getType());
-    SmallVector<int64_t> stackSizes(realType.getSizes());
-    stackSizes.push_back(2);
-    Value sequence = rewriter.create<PrimListConstructOp>(
-        loc, ListType::get(op.getContext(), realType),
-        ValueRange{real, complex});
-    Type stackType = realType.getWithSizesAndDtype(stackSizes, dtype);
+    // Y = unflatten(X, -1, [outputFftDim, 2])
+    //         : (D_0 x ... x D_m x outputFftDim x 2)
+    // Z = view_as_complex(Y) : complex(D_0 x ... x D_m x outputFftDim)
+    SmallVector<int64_t> complexResShape(matmulShape);
+    complexResShape.back() = outputFftDim;
+    SmallVector<int64_t> unflattenedResShape(complexResShape);
+    unflattenedResShape.push_back(2);
+    Type unflattenedResType =
+        ValueTensorType::get(op.getContext(), unflattenedResShape, dtype);
     Value cstMinusOne =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
-    Value stack =
-        rewriter.create<AtenStackOp>(loc, stackType, sequence, cstMinusOne);
-    Type complexResType = ValueTensorType::get(
-        op.getContext(), realType.getSizes(), ComplexType::get(dtype));
-    Value complexRes =
-        rewriter.create<AtenViewAsComplexOp>(loc, complexResType, stack);
+    Value unflattenSizes = toIntListConstruct(
+        rewriter, loc, {outputFftDim, 2}, IntType::get(rewriter.getContext()));
+    Value unflattenedRes = rewriter.create<AtenUnflattenIntOp>(
+        loc, unflattenedResType, flatRes, /*dim=*/cstMinusOne, unflattenSizes);
+    Type complexResType = ValueTensorType::get(op.getContext(), complexResShape,
                                               ComplexType::get(dtype));
+    Value complexRes = rewriter.create<AtenViewAsComplexOp>(loc, complexResType,
                                                            unflattenedRes);
 
     // Transpose back
     if (needTranspose) {
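
Side note (outside the patch): a quick sanity check of the interleaved
coefficient layout that the new getDFTMatmulCoeff produces. With
scale = 2*pi/N, column 2k holds cos(scale*i*k) and column 2k+1 holds
-sin(scale*i*k), so a row vector times this matrix yields Re X[k] and Im X[k]
interleaved, where X[k] = sum_n x[n]*exp(-2*pi*i*k*n/N) is the one-sided DFT
that torch.fft.rfft computes. The standalone program below (illustrative
only; double precision instead of the pass's f32) rebuilds the matrix the
same way and compares it against a direct evaluation:

  // dft_check.cpp: verify the interleaved cos/-sin matrix against a direct DFT.
  #include <cassert>
  #include <cmath>
  #include <complex>
  #include <cstdio>
  #include <vector>

  int main() {
    const int N = 9;           // fftLength, as in the 2d_last_dim test
    const int out = N / 2 + 1; // outputFftDim for a onesided rfft
    const double scale = 2.0 * M_PI / N;

    // Same construction as getDFTMatmulCoeff: even j -> real, odd j -> imag.
    std::vector<std::vector<double>> coeff(N, std::vector<double>(2 * out));
    for (int i = 0; i < N; ++i)
      for (int j = 0; j < 2 * out; ++j) {
        const bool isImagPart = j % 2;
        const double v = scale * i * (j / 2);
        coeff[i][j] = isImagPart ? -std::sin(v) : std::cos(v);
      }

    const std::vector<double> x = {1, 2, 3, 4, 5, 6, 7, 8, 9};
    for (int k = 0; k < out; ++k) {
      // Row vector times the two interleaved columns belonging to bin k.
      double re = 0, im = 0;
      for (int n = 0; n < N; ++n) {
        re += x[n] * coeff[n][2 * k];
        im += x[n] * coeff[n][2 * k + 1];
      }
      // Reference: X[k] = sum_n x[n] * exp(-2*pi*i*k*n/N).
      std::complex<double> ref = 0;
      for (int n = 0; n < N; ++n)
        ref += x[n] * std::exp(std::complex<double>(0.0, -scale * k * n));
      assert(std::abs(ref.real() - re) < 1e-9);
      assert(std::abs(ref.imag() - im) < 1e-9);
    }
    std::puts("interleaved DFT matrix matches the onesided DFT");
  }
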
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 664bbb2d5..77840d206 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -36,6 +36,18 @@ Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) {
   return dim;
 }
 
+Value Torch::toIntListConstruct(PatternRewriter &rewriter, Location loc,
+                                ArrayRef<int64_t> cstInput,
+                                Torch::IntType intType) {
+  SmallVector<Value> cstValues;
+  for (int64_t i : cstInput) {
+    cstValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(i)));
+  }
+  return rewriter.create<Torch::PrimListConstructOp>(
+      loc, Torch::ListType::get(intType), cstValues);
+}
+
 bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
   auto listConstruct = v.getDefiningOp<Torch::PrimListConstructOp>();
   if (!listConstruct)
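
Side note (outside the patch): a sketch of how the new helper is typically
called, mirroring its use in DecomposeAtenFftRfftOp. It assumes a surrounding
matchAndRewrite body where rewriter and loc are in scope, and the values
{5, 2} are illustrative (they match outputFftDim = 5 in the first test case):

  // Builds the [outputFftDim, 2] size list fed to AtenUnflattenIntOp.
  // The helper emits one torch.constant.int per element plus a
  // prim.ListConstruct, i.e. in IR:
  //   %int5 = torch.constant.int 5
  //   %int2 = torch.constant.int 2
  //   %0 = torch.prim.ListConstruct %int5, %int2
  //          : (!torch.int, !torch.int) -> !torch.list<int>
  Value sizes = Torch::toIntListConstruct(
      rewriter, loc, /*cstInput=*/{5, 2},
      Torch::IntType::get(rewriter.getContext()));
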
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index ac6ddc758..bf37a4847 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -175,26 +175,16 @@ func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.v
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
-// CHECK: %[[INTM1:.*]] = torch.constant.int -1
-// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32>
-// CHECK: %[[TRUE:.*]] = torch.constant.bool true
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32>
-// CHECK: %[[INTM2:.*]] = torch.constant.int -2
-// CHECK: %[[VAR2:.*]] = torch.aten.unsqueeze %[[ARG0:.*]], %[[INTM2:.*]] : !torch.vtensor<[16,9],f32>, !torch.int -> !torch.vtensor<[16,1,9],f32>
-// CHECK: %[[VAR3:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR1:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32>
-// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-// CHECK: %[[VAR4:.*]] = torch.aten.squeeze.dim %[[VAR3:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32>
-// CHECK: %[[VAR5:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR0:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32>
-// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-// CHECK: %[[VAR6:.*]] = torch.aten.squeeze.dim %[[VAR5:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32>
-// CHECK: %[[VAR7:.*]] = torch.aten.unsqueeze %[[VAR4:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32>
-// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR6:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32>
-// CHECK: %[[VAR9:.*]] = torch.prim.ListConstruct %[[VAR7:.*]], %[[VAR8:.*]] : (!torch.vtensor<[16,5,1],f32>, !torch.vtensor<[16,5,1],f32>) -> !torch.list<vtensor>
-// CHECK: %[[VAR10:.*]] = torch.aten.cat %[[VAR9:.*]], %[[INTM1:.*]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[16,5,2],f32>
-// CHECK: %[[VAR11:.*]] = torch.aten.view_as_complex %[[VAR10:.*]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex<f32>>
-// CHECK: return %[[VAR11:.*]] : !torch.vtensor<[16,5],complex<f32>>
+// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
+// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5
+// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16
+// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x10xf32>) : !torch.vtensor<[9,10],f32>
+// CHECK: %[[VAR1:.*]] = torch.aten.mm %arg0, %[[VAR0]] : !torch.vtensor<[16,9],f32>, !torch.vtensor<[9,10],f32> -> !torch.vtensor<[16,10],f32>
+// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT16]], %[[INT5]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[16,10],f32>, !torch.list<int> -> !torch.vtensor<[16,5,2],f32>
+// CHECK: %[[VAR4:.*]] = torch.aten.view_as_complex %[[VAR3]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex<f32>>
+// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex<f32>>
 func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
   %int-1 = torch.constant.int -1
   %none = torch.constant.none
@@ -205,29 +195,20 @@ func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
-// CHECK: %[[INTM1:.*]] = torch.constant.int -1
-// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32>
-// CHECK: %[[TRUE:.*]] = torch.constant.bool true
-// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32>
-// CHECK: %[[INTM2:.*]] = torch.constant.int -2
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[VAR2:.*]] = torch.aten.transpose.int %[[ARG0:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
-// CHECK: %[[VAR3:.*]] = torch.aten.unsqueeze %[[VAR2:.*]], %[[INTM2:.*]] : !torch.vtensor<[23,36],f32>, !torch.int -> !torch.vtensor<[23,1,36],f32>
-// CHECK: %[[VAR4:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR1:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32>
-// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-// CHECK: %[[VAR5:.*]] = torch.aten.squeeze.dim %[[VAR4:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32>
-// CHECK: %[[VAR6:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR0:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32>
-// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-// CHECK: %[[VAR7:.*]] = torch.aten.squeeze.dim %[[VAR6:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32>
-// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR5:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32>
-// CHECK: %[[VAR9:.*]] = torch.aten.unsqueeze %[[VAR7:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32>
-// CHECK: %[[VAR10:.*]] = torch.prim.ListConstruct %[[VAR8:.*]], %[[VAR9:.*]] : (!torch.vtensor<[23,19,1],f32>, !torch.vtensor<[23,19,1],f32>) -> !torch.list<vtensor>
-// CHECK: %[[VAR11:.*]] = torch.aten.cat %[[VAR10:.*]], %[[INTM1:.*]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[23,19,2],f32>
-// CHECK: %[[VAR12:.*]] = torch.aten.view_as_complex %[[VAR11:.*]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex<f32>>
-// CHECK: %[[VAR13:.*]] = torch.aten.transpose.int %[[VAR12:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[23,19],complex<f32>>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex<f32>>
-// CHECK: return %[[VAR13:.*]] : !torch.vtensor<[19,23],complex<f32>>
+// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
+// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[INT19:.*]] = torch.constant.int 19
+// CHECK-DAG: %[[INT23:.*]] = torch.constant.int 23
+// CHECK-DAG: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x38xf32>) : !torch.vtensor<[36,38],f32>
+// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
+// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[VAR1:.*]] = torch.aten.transpose.int %arg0, %[[INT0]], %[[INT1]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
+// CHECK: %[[VAR2:.*]] = torch.aten.mm %[[VAR1]], %[[VAR0]] : !torch.vtensor<[23,36],f32>, !torch.vtensor<[36,38],f32> -> !torch.vtensor<[23,38],f32>
+// CHECK: %[[VAR3:.*]] = torch.prim.ListConstruct %[[INT23]], %[[INT19]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR4:.*]] = torch.aten.view %[[VAR2]], %[[VAR3]] : !torch.vtensor<[23,38],f32>, !torch.list<int> -> !torch.vtensor<[23,19,2],f32>
+// CHECK: %[[VAR5:.*]] = torch.aten.view_as_complex %[[VAR4]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex<f32>>
+// CHECK: %[[VAR6:.*]] = torch.aten.transpose.int %[[VAR5]], %[[INT0]], %[[INT1]] : !torch.vtensor<[23,19],complex<f32>>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex<f32>>
+// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex<f32>>
 func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
   %int0 = torch.constant.int 0
   %none = torch.constant.none
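
Side note (outside the patch): the shape arithmetic behind the updated CHECK
lines. For the last-dim test, fftLength = 9 gives outputFftDim = 9/2 + 1 = 5,
so the coefficient literal is 9x10 (real and imaginary parts interleaved);
torch.aten.mm maps (16,9) x (9,10) -> (16,10); the view to (16,5,2) separates
the two components of each frequency bin; and view_as_complex folds the
trailing pair into complex<f32>, giving (16,5). The first-dim test is the
same after transposing (36,23) -> (23,36): outputFftDim = 36/2 + 1 = 19,
coeff 36x38, mm -> (23,38), view -> (23,19,2), view_as_complex -> (23,19),
and the final transpose restores (19,23).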