Refactor conversion to Linalg

pull/3857/head
giacs-epic 2024-11-20 14:40:34 +00:00
parent 740de4f260
commit 535d9c1712
2 changed files with 150 additions and 116 deletions

View File

@ -1376,57 +1376,38 @@ public:
namespace {
/// From
/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
///
/// Creates coefficients based on DFT definition, see
/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform.
Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType,
bool isRealPart) {
Value getDFTMatmulCoeff(OpBuilder b, Location loc,
RankedTensorType matrixType) {
ComplexType complexTy = llvm::cast<ComplexType>(matrixType.getElementType());
mlir::FloatType floatType =
llvm::cast<mlir::FloatType>(complexTy.getElementType());
// scale = 2 * pi / N
double scale = 2 * M_PI / matrixType.getDimSize(0);
SmallVector<Attribute> values;
assert(matrixType.getRank() == 2 && "expected 2D matrix");
SmallVector<std::complex<APFloat>> values;
for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
double v = scale * i * j;
v = isRealPart ? cos(v) : -sin(v);
values.push_back(b.getF32FloatAttr(v));
double realV = cos(v);
double imagV = -sin(v);
bool unused;
APFloat real(realV);
real.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
&unused);
APFloat imag(imagV);
imag.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
&unused);
values.push_back(std::complex<APFloat>(real, imag));
}
}
return b.create<arith::ConstantOp>(
loc, matrixType, DenseFPElementsAttr::get(matrixType, values));
}
/// From
/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
Value createLinalgMatmulOnTensors(OpBuilder b, Location loc,
RankedTensorType resultType, Value lhs,
Value rhs) {
Value zero = b.create<arith::ConstantOp>(
loc, b.getZeroAttr(resultType.getElementType()));
Value emptyTensor = b.create<mlir::tensor::EmptyOp>(
loc, resultType.getShape(), resultType.getElementType(),
/*dyn_size=*/ValueRange{});
Value zeroTensor =
b.create<linalg::FillOp>(loc, zero, emptyTensor).getResult(0);
switch (llvm::cast<RankedTensorType>(lhs.getType()).getRank()) {
case 1:
return b
.create<linalg::VecmatOp>(loc, TypeRange{resultType},
ValueRange{lhs, rhs}, ValueRange{zeroTensor})
.getResult(0);
case 2:
return b
.create<linalg::MatmulOp>(loc, TypeRange{resultType},
ValueRange{lhs, rhs}, ValueRange{zeroTensor})
.getResult(0);
default:
assert(false && "unhandled matmul type");
return Value();
}
loc, matrixType, DenseElementsAttr::get(matrixType, values));
}
struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
@ -1461,69 +1442,120 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
return rewriter.notifyMatchFailure(
op, "unsupported: only ranked tensors are supported");
}
if (!inputType.hasStaticShape() || inputType.getRank() > 2) {
return rewriter.notifyMatchFailure(
op, "unsupported: only static 1D or 2D FFT is supported");
}
const ArrayRef<int64_t> inputShape = inputType.getShape();
dim += dim < 0 ? inputShape.size() : 0;
const int64_t fftLength = inputShape[dim];
if (fftLength == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op, "unsupported: FFT signal length must be static");
}
const int64_t rank = inputType.getRank();
const int64_t lastDim = rank - 1;
const int64_t outputFftDim = fftLength / 2 + 1;
const bool needTranspose = dim != lastDim;
RankedTensorType newResultType = llvm::cast<RankedTensorType>(
getTypeConverter()->convertType(op.getType()));
llvm::SmallVector<int64_t> componentShape(newResultType.getShape());
// Transpose if FFT dimension is not the last one
llvm::SmallVector<int64_t> perms = llvm::to_vector(llvm::seq(rank));
std::swap(perms[dim], perms[lastDim]);
if (needTranspose) {
self = transposeValue(loc, self, perms, rewriter);
for (size_t i = 0; i < componentShape.size(); i++) {
componentShape[i] = newResultType.getShape()[perms[i]];
}
}
RankedTensorType matrixType = RankedTensorType::get(
{fftLength, outputFftDim}, inputType.getElementType());
RankedTensorType newResultType = llvm::cast<RankedTensorType>(
getTypeConverter()->convertType(op.getType()));
ComplexType complexElemType =
llvm::cast<ComplexType>(newResultType.getElementType());
Type elemType = complexElemType.getElementType();
RankedTensorType componentsType =
RankedTensorType::get(componentShape, inputType.getElementType());
// coeffMatrix : tensor<fftLength x outputFftDim x complex<f32>>
RankedTensorType coeffType =
RankedTensorType::get({fftLength, outputFftDim}, complexElemType);
// coeffMatrix(n,m) = cos(2 pi n m / N) - j sin(2 pi n m / N)
Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, coeffType);
Value realMatrix =
getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true);
Value real = createLinalgMatmulOnTensors(rewriter, loc, componentsType,
self, realMatrix);
// #matmul_trait = {
// indexing_maps = [
// affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, f)>,
// affine_map<(d_0, ... d_m, f, o) -> (f, o)>,
// affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, o)>
// ],
// iterator_types = ["parallel", ..., "parallel", "reduction", "parallel"]
// }
// linalg.generic #matmul_trait
// ins(%A, %B : tensor<D_0 x ... x D_m x fftLength x f32>,
// tensor<fftLength x outputFftDim x complex<f32>>)
// outs(%C : tensor<D_0 x ... x D_m x outputFftDim x complex<f32>>) {
// ^bb0(%a: f32, %b: complex<f32>, %c: complex<f32>) :
// %re = complex.re %b : f32
// %im = complex.im %b : f32
// %mulre = arith.mulf %a, %re: f32
// %mulim = arith.mulf %a, %im: f32
// %mulcplx = complex.create %mulre, %mulim : complex<f32>
// %add = complex.add %mulcplx, %c : complex<f32>
// linalg.yield %add : complex<f32>
// } -> (tensor<D_0 x ... x D_m x outputFftDim x complex<f32>>)
Value imagMatrix =
getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false);
Value imag = createLinalgMatmulOnTensors(rewriter, loc, componentsType,
self, imagMatrix);
Value lhs = self;
Value rhs = coeffMatrix;
RankedTensorType lhsType = llvm::cast<RankedTensorType>(lhs.getType());
ArrayRef<int64_t> lhsShape(lhsType.getShape());
ArrayRef<int64_t> rhsShape(coeffType.getShape());
// Pack components into a complex tensor
Type elementType = newResultType.getElementType();
auto toComplexBody = [&](OpBuilder &b, Location loc,
ValueRange payloadArgs) {
Value realElem = payloadArgs[0];
Value imagElem = payloadArgs[1];
Value complexElem =
b.create<complex::CreateOp>(loc, elementType, realElem, imagElem);
b.create<linalg::YieldOp>(loc, complexElem);
};
Value complexRes = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, {real, imag}, elementType, toComplexBody);
unsigned batchRank = lhsShape.size() - 1;
SmallVector<AffineExpr> lhsExpr;
SmallVector<AffineExpr> rhsExpr;
SmallVector<AffineExpr> outExpr;
SmallVector<utils::IteratorType> iteratorTypes(
batchRank, utils::IteratorType::parallel);
SmallVector<Value> resultShape;
for (unsigned i = 0; i < batchRank; i++) {
lhsExpr.push_back(rewriter.getAffineDimExpr(i));
outExpr.push_back(rewriter.getAffineDimExpr(i));
resultShape.push_back(getDimOp(rewriter, loc, lhs, i));
}
unsigned fIdx = batchRank, oIdx = batchRank + 1;
lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(fIdx)});
rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(fIdx),
rewriter.getAffineDimExpr(oIdx)});
outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(oIdx)});
resultShape.insert(resultShape.end(),
{getDimOp(rewriter, loc, rhs, rhsShape.size() - 1)});
Value zeroTensor =
createZeroInitTensor(rewriter, loc, resultShape, complexElemType);
auto indexingMaps = AffineMap::inferFromExprList(
{lhsExpr, rhsExpr, outExpr}, rewriter.getContext());
iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::reduction,
utils::IteratorType::parallel});
Value complexRes =
rewriter
.create<linalg::GenericOp>(
loc, zeroTensor.getType(),
/*inputs=*/ValueRange{lhs, rhs},
/*outputs=*/zeroTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value l = args[0], r = args[1], res = args[2];
Value re = b.create<complex::ReOp>(loc, elemType, r);
Value im = b.create<complex::ImOp>(loc, elemType, r);
Value mulRe = b.create<arith::MulFOp>(loc, l, re);
Value mulIm = b.create<arith::MulFOp>(loc, l, im);
Value mulCplx = b.create<complex::CreateOp>(
loc, complexElemType, mulRe, mulIm);
Value add = b.create<complex::AddOp>(loc, mulCplx, res);
b.create<linalg::YieldOp>(loc, add);
})
.getResult(0);
// Transpose back
if (needTranspose) {
complexRes = transposeValue(loc, complexRes, perms, rewriter);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, complexRes);
rewriter.replaceOp(op, complexRes);
return success();
}
};

View File

@ -1,25 +1,28 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim(
// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32>
// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32>
// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<16x5xf32>
// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[REAL_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<16x5xf32>
// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[IMAG_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<16x5xcomplex<f32>>
// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<16x5xf32>, tensor<16x5xf32>) outs(%[[EMPTY_2:.*]] : tensor<16x5xcomplex<f32>>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_2:.*]]: f32, %[[OUT:.*]]: complex<f32>):
// CHECK: %[[ELEM_COMPLEX:.*]] = complex.create %[[IN:.*]], %[[IN_2:.*]] : complex<f32>
// CHECK: linalg.yield %[[ELEM_COMPLEX:.*]] : complex<f32>
// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xcomplex<f32>>
// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32>
// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<16x5xcomplex<f32>>
// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR1]] : tensor<16x5xcomplex<f32>>) -> tensor<16x5xcomplex<f32>>
// CHECK: %[[VAR3:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[VAR0]], %[[CST_0]] : tensor<16x9xf32>, tensor<9x5xcomplex<f32>>) outs(%[[VAR2]] : tensor<16x5xcomplex<f32>>) {
// CHECK: ^bb0(%in: f32, %in_1: complex<f32>, %out: complex<f32>):
// CHECK: %[[VAR5:.*]] = complex.re %in_1 : complex<f32>
// CHECK: %[[VAR6:.*]] = complex.im %in_1 : complex<f32>
// CHECK: %[[VAR7:.*]] = arith.mulf %in, %[[VAR5]] : f32
// CHECK: %[[VAR8:.*]] = arith.mulf %in, %[[VAR6]] : f32
// CHECK: %[[VAR9:.*]] = complex.create %[[VAR7]], %[[VAR8]] : complex<f32>
// CHECK: %[[VAR10:.*]] = complex.add %[[VAR9]], %out : complex<f32>
// CHECK: linalg.yield %[[VAR10]] : complex<f32>
// CHECK: } -> tensor<16x5xcomplex<f32>>
// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[COMPLEX:.*]] : tensor<16x5xcomplex<f32>> -> !torch.vtensor<[16,5],complex<f32>>
// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[16,5],complex<f32>>
// CHECK: %[[VAR4:.*]] = torch_c.from_builtin_tensor %[[VAR3]] : tensor<16x5xcomplex<f32>> -> !torch.vtensor<[16,5],complex<f32>>
// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex<f32>>
func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
%int-1 = torch.constant.int -1
@ -31,29 +34,28 @@ func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) ->
// -----
// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim(
// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32>
// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32>
// CHECK-DAG: %[[EMPTY_0:.*]] = tensor.empty() : tensor<23x36xf32>
// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[INPUT:.*]] : tensor<36x23xf32>) outs(%[[EMPTY_0:.*]] : tensor<23x36xf32>) permutation = [1, 0]
// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<23x19xf32>
// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[REAL_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_0:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<23x19xf32>
// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_2:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[IMAG_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
// CHECK: %[[EMPTY_3:.*]] = tensor.empty() : tensor<23x19xcomplex<f32>>
// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<23x19xf32>, tensor<23x19xf32>) outs(%[[EMPTY_3:.*]] : tensor<23x19xcomplex<f32>>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_3:.*]]: f32, %[[OUT:.*]]: complex<f32>):
// CHECK: %[[EMPTY_02:.*]] = complex.create %[[IN:.*]], %[[IN_3:.*]] : complex<f32>
// CHECK: linalg.yield %[[EMPTY_02:.*]] : complex<f32>
// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xcomplex<f32>>
// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32>
// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<23x36xf32>
// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[VAR0]] : tensor<36x23xf32>) outs(%[[VAR1]] : tensor<23x36xf32>) permutation = [1, 0]
// CHECK-DAG: %[[VAR2:.*]] = tensor.empty() : tensor<23x19xcomplex<f32>>
// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR2]] : tensor<23x19xcomplex<f32>>) -> tensor<23x19xcomplex<f32>>
// CHECK: %[[VAR4:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[TRANSPOSED]], %[[CST_0]] : tensor<23x36xf32>, tensor<36x19xcomplex<f32>>) outs(%[[VAR3]] : tensor<23x19xcomplex<f32>>) {
// CHECK: ^bb0(%in: f32, %in_2: complex<f32>, %out: complex<f32>):
// CHECK: %[[VAR7:.*]] = complex.re %in_2 : complex<f32>
// CHECK: %[[VAR8:.*]] = complex.im %in_2 : complex<f32>
// CHECK: %[[VAR9:.*]] = arith.mulf %in, %[[VAR7]] : f32
// CHECK: %[[VAR10:.*]] = arith.mulf %in, %[[VAR8]] : f32
// CHECK: %[[VAR11:.*]] = complex.create %[[VAR9]], %[[VAR10]] : complex<f32>
// CHECK: %[[VAR12:.*]] = complex.add %[[VAR11]], %out : complex<f32>
// CHECK: linalg.yield %[[VAR12]] : complex<f32>
// CHECK: } -> tensor<23x19xcomplex<f32>>
// CHECK: %[[EMPTY_4:.*]] = tensor.empty() : tensor<19x23xcomplex<f32>>
// CHECK: %[[TRANSPOSED_2:.*]] = linalg.transpose ins(%[[COMPLEX:.*]] : tensor<23x19xcomplex<f32>>) outs(%[[EMPTY_4:.*]] : tensor<19x23xcomplex<f32>>) permutation = [1, 0]
// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_2:.*]] : tensor<19x23xcomplex<f32>> -> !torch.vtensor<[19,23],complex<f32>>
// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[19,23],complex<f32>>
// CHECK-DAG: %[[VAR5:.*]] = tensor.empty() : tensor<19x23xcomplex<f32>>
// CHECK: %[[TRANSPOSED_1:.*]] = linalg.transpose ins(%[[VAR4]] : tensor<23x19xcomplex<f32>>) outs(%[[VAR5]] : tensor<19x23xcomplex<f32>>) permutation = [1, 0]
// CHECK: %[[VAR6:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_1]] : tensor<19x23xcomplex<f32>> -> !torch.vtensor<[19,23],complex<f32>>
// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex<f32>>
func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
%int0 = torch.constant.int 0
%none = torch.constant.none