mirror of https://github.com/llvm/torch-mlir

Refactor conversion to Linalg

commit 535d9c1712 (parent 740de4f260)
@@ -1376,57 +1376,38 @@ public:
 namespace {
 
-/// From
-/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
-///
 /// Creates coefficients based on DFT definition, see
 /// https://en.wikipedia.org/wiki/Discrete_Fourier_transform.
-Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType,
-                        bool isRealPart) {
+Value getDFTMatmulCoeff(OpBuilder b, Location loc,
+                        RankedTensorType matrixType) {
+
+  ComplexType complexTy = llvm::cast<ComplexType>(matrixType.getElementType());
+  mlir::FloatType floatType =
+      llvm::cast<mlir::FloatType>(complexTy.getElementType());
+
   // scale = 2 * pi / N
   double scale = 2 * M_PI / matrixType.getDimSize(0);
 
-  SmallVector<Attribute> values;
-  assert(matrixType.getRank() == 2 && "expected 2D matrix");
+  SmallVector<std::complex<APFloat>> values;
   for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
     for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
       double v = scale * i * j;
-      v = isRealPart ? cos(v) : -sin(v);
-      values.push_back(b.getF32FloatAttr(v));
+      double realV = cos(v);
+      double imagV = -sin(v);
+
+      bool unused;
+      APFloat real(realV);
+      real.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                   &unused);
+      APFloat imag(imagV);
+      imag.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                   &unused);
+
+      values.push_back(std::complex<APFloat>(real, imag));
     }
   }
   return b.create<arith::ConstantOp>(
-      loc, matrixType, DenseFPElementsAttr::get(matrixType, values));
+      loc, matrixType, DenseElementsAttr::get(matrixType, values));
 }
-
-/// From
-/// iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
-Value createLinalgMatmulOnTensors(OpBuilder b, Location loc,
-                                  RankedTensorType resultType, Value lhs,
-                                  Value rhs) {
-  Value zero = b.create<arith::ConstantOp>(
-      loc, b.getZeroAttr(resultType.getElementType()));
-  Value emptyTensor = b.create<mlir::tensor::EmptyOp>(
-      loc, resultType.getShape(), resultType.getElementType(),
-      /*dyn_size=*/ValueRange{});
-  Value zeroTensor =
-      b.create<linalg::FillOp>(loc, zero, emptyTensor).getResult(0);
-
-  switch (llvm::cast<RankedTensorType>(lhs.getType()).getRank()) {
-  case 1:
-    return b
-        .create<linalg::VecmatOp>(loc, TypeRange{resultType},
-                                  ValueRange{lhs, rhs}, ValueRange{zeroTensor})
-        .getResult(0);
-  case 2:
-    return b
-        .create<linalg::MatmulOp>(loc, TypeRange{resultType},
-                                  ValueRange{lhs, rhs}, ValueRange{zeroTensor})
-        .getResult(0);
-  default:
-    assert(false && "unhandled matmul type");
-    return Value();
-  }
-}
 
 struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
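Note (not part of the diff): written out as math, the constant that getDFTMatmulCoeff now materializes is the standard real-input DFT matrix, with N = fftLength and only the first floor(N/2) + 1 columns kept:

\[
W_{nk} \;=\; \cos\!\left(\tfrac{2\pi nk}{N}\right) - j\,\sin\!\left(\tfrac{2\pi nk}{N}\right) \;=\; e^{-2\pi j nk/N},
\qquad
X_k \;=\; \sum_{n=0}^{N-1} x_n\, W_{nk}, \quad k = 0,\dots,\lfloor N/2 \rfloor .
\]

The remaining bins are redundant for real input, since X_{N-k} = conj(X_k); that is why outputFftDim = fftLength / 2 + 1 below.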
@@ -1461,69 +1442,120 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
       return rewriter.notifyMatchFailure(
           op, "unsupported: only ranked tensors are supported");
     }
-    if (!inputType.hasStaticShape() || inputType.getRank() > 2) {
-      return rewriter.notifyMatchFailure(
-          op, "unsupported: only static 1D or 2D FFT is supported");
-    }
 
     const ArrayRef<int64_t> inputShape = inputType.getShape();
     dim += dim < 0 ? inputShape.size() : 0;
 
     const int64_t fftLength = inputShape[dim];
+    if (fftLength == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: FFT signal length must be static");
+    }
     const int64_t rank = inputType.getRank();
     const int64_t lastDim = rank - 1;
     const int64_t outputFftDim = fftLength / 2 + 1;
     const bool needTranspose = dim != lastDim;
 
-    RankedTensorType newResultType = llvm::cast<RankedTensorType>(
-        getTypeConverter()->convertType(op.getType()));
-    llvm::SmallVector<int64_t> componentShape(newResultType.getShape());
-
     // Transpose if FFT dimension is not the last one
     llvm::SmallVector<int64_t> perms = llvm::to_vector(llvm::seq(rank));
     std::swap(perms[dim], perms[lastDim]);
     if (needTranspose) {
       self = transposeValue(loc, self, perms, rewriter);
-      for (size_t i = 0; i < componentShape.size(); i++) {
-        componentShape[i] = newResultType.getShape()[perms[i]];
-      }
     }
 
-    RankedTensorType matrixType = RankedTensorType::get(
-        {fftLength, outputFftDim}, inputType.getElementType());
+    RankedTensorType newResultType = llvm::cast<RankedTensorType>(
+        getTypeConverter()->convertType(op.getType()));
+    ComplexType complexElemType =
+        llvm::cast<ComplexType>(newResultType.getElementType());
+    Type elemType = complexElemType.getElementType();
 
-    RankedTensorType componentsType =
-        RankedTensorType::get(componentShape, inputType.getElementType());
+    // coeffMatrix : tensor<fftLength x outputFftDim x complex<f32>>
+    RankedTensorType coeffType =
+        RankedTensorType::get({fftLength, outputFftDim}, complexElemType);
+    // coeffMatrix(n,m) = cos(2 pi n m / N) - j sin(2 pi n m / N)
+    Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, coeffType);
 
-    Value realMatrix =
-        getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true);
-    Value real = createLinalgMatmulOnTensors(rewriter, loc, componentsType,
-                                             self, realMatrix);
+    // #matmul_trait = {
+    //   indexing_maps = [
+    //     affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, f)>,
+    //     affine_map<(d_0, ... d_m, f, o) -> (f, o)>,
+    //     affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, o)>
+    //   ],
+    //   iterator_types = ["parallel", ..., "parallel", "reduction", "parallel"]
+    // }
+    // linalg.generic #matmul_trait
+    //   ins(%A, %B : tensor<D_0 x ... x D_m x fftLength x f32>,
+    //                tensor<fftLength x outputFftDim x complex<f32>>)
+    //   outs(%C : tensor<D_0 x ... x D_m x outputFftDim x complex<f32>>) {
+    //   ^bb0(%a: f32, %b: complex<f32>, %c: complex<f32>) :
+    //     %re = complex.re %b : f32
+    //     %im = complex.im %b : f32
+    //     %mulre = arith.mulf %a, %re: f32
+    //     %mulim = arith.mulf %a, %im: f32
+    //     %mulcplx = complex.create %mulre, %mulim : complex<f32>
+    //     %add = complex.add %c, %mulcplx: complex<f32>
+    //     linalg.yield %add : complex<f32>
+    // } -> (tensor<D_0 x ... x D_m x outputFftDim x complex<f32>>)
 
-    Value imagMatrix =
-        getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false);
-    Value imag = createLinalgMatmulOnTensors(rewriter, loc, componentsType,
-                                             self, imagMatrix);
+    Value lhs = self;
+    Value rhs = coeffMatrix;
+    RankedTensorType lhsType = llvm::cast<RankedTensorType>(lhs.getType());
+    ArrayRef<int64_t> lhsShape(lhsType.getShape());
+    ArrayRef<int64_t> rhsShape(coeffType.getShape());
 
-    // Pack components into a complex tensor
-    Type elementType = newResultType.getElementType();
-    auto toComplexBody = [&](OpBuilder &b, Location loc,
-                             ValueRange payloadArgs) {
-      Value realElem = payloadArgs[0];
-      Value imagElem = payloadArgs[1];
-      Value complexElem =
-          b.create<complex::CreateOp>(loc, elementType, realElem, imagElem);
-      b.create<linalg::YieldOp>(loc, complexElem);
-    };
-    Value complexRes = torch_to_linalg::createElementwiseLinalgGeneric(
-        rewriter, loc, {real, imag}, elementType, toComplexBody);
+    unsigned batchRank = lhsShape.size() - 1;
+
+    SmallVector<AffineExpr> lhsExpr;
+    SmallVector<AffineExpr> rhsExpr;
+    SmallVector<AffineExpr> outExpr;
+    SmallVector<utils::IteratorType> iteratorTypes(
+        batchRank, utils::IteratorType::parallel);
+    SmallVector<Value> resultShape;
+    for (unsigned i = 0; i < batchRank; i++) {
+      lhsExpr.push_back(rewriter.getAffineDimExpr(i));
+      outExpr.push_back(rewriter.getAffineDimExpr(i));
+      resultShape.push_back(getDimOp(rewriter, loc, lhs, i));
+    }
+    unsigned fIdx = batchRank, oIdx = batchRank + 1;
+    lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(fIdx)});
+    rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(fIdx),
+                                   rewriter.getAffineDimExpr(oIdx)});
+    outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(oIdx)});
+    resultShape.insert(resultShape.end(),
+                       {getDimOp(rewriter, loc, rhs, rhsShape.size() - 1)});
+
+    Value zeroTensor =
+        createZeroInitTensor(rewriter, loc, resultShape, complexElemType);
+    auto indexingMaps = AffineMap::inferFromExprList(
+        {lhsExpr, rhsExpr, outExpr}, rewriter.getContext());
+    iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::reduction,
+                                               utils::IteratorType::parallel});
+
+    Value complexRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, zeroTensor.getType(),
+                /*inputs=*/ValueRange{lhs, rhs},
+                /*outputs=*/zeroTensor, indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value l = args[0], r = args[1], res = args[2];
+                  Value re = b.create<complex::ReOp>(loc, elemType, r);
+                  Value im = b.create<complex::ImOp>(loc, elemType, r);
+                  Value mulRe = b.create<arith::MulFOp>(loc, l, re);
+                  Value mulIm = b.create<arith::MulFOp>(loc, l, im);
+                  Value mulCplx = b.create<complex::CreateOp>(
+                      loc, complexElemType, mulRe, mulIm);
+                  Value add = b.create<complex::AddOp>(loc, mulCplx, res);
+                  b.create<linalg::YieldOp>(loc, add);
+                })
+            .getResult(0);
 
     // Transpose back
     if (needTranspose) {
       complexRes = transposeValue(loc, complexRes, perms, rewriter);
     }
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, complexRes);
+    rewriter.replaceOp(op, complexRes);
     return success();
   }
 };
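Note (not part of the diff): below is a minimal standalone C++ sketch of the computation the new linalg.generic encodes, one contraction of the real input against the complex coefficient matrix, with complex accumulation. The signal x and the sizes N and M are illustrative only.

// Standalone sanity sketch (not from the patch): the coefficient-matmul rfft
// that the rewritten pattern emits, written in plain C++.
#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

int main() {
  const int N = 9;          // fftLength
  const int M = N / 2 + 1;  // outputFftDim
  const double scale = 2.0 * 3.14159265358979323846 / N;  // 2*pi/N

  // Arbitrary real input signal x[n].
  std::vector<double> x(N);
  for (int n = 0; n < N; ++n)
    x[n] = std::sin(0.3 * n) + 0.1 * n;

  // X[k] = sum_n x[n] * (cos(scale*n*k) - j*sin(scale*n*k)): the same
  // reduction the linalg.generic performs against the constant W matrix.
  for (int k = 0; k < M; ++k) {
    std::complex<double> acc(0.0, 0.0);
    for (int n = 0; n < N; ++n) {
      const double v = scale * n * k;
      acc += x[n] * std::complex<double>(std::cos(v), -std::sin(v));
    }
    std::printf("X[%d] = % .6f %+.6fj\n", k, acc.real(), acc.imag());
  }
  return 0;
}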
@@ -1,25 +1,28 @@
 // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
 
+// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
 // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim(
-// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
-// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32>
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xf32>
-// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32>
-// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<16x5xf32>
-// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
-// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[REAL_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_0:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
-// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<16x5xf32>
-// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
-// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[INPUT:.*]], %[[IMAG_COEFF:.*]] : tensor<16x9xf32>, tensor<9x5xf32>) outs(%[[ZEROES_1:.*]] : tensor<16x5xf32>) -> tensor<16x5xf32>
-// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<16x5xcomplex<f32>>
-// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<16x5xf32>, tensor<16x5xf32>) outs(%[[EMPTY_2:.*]] : tensor<16x5xcomplex<f32>>) {
-// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_2:.*]]: f32, %[[OUT:.*]]: complex<f32>):
-// CHECK: %[[ELEM_COMPLEX:.*]] = complex.create %[[IN:.*]], %[[IN_2:.*]] : complex<f32>
-// CHECK: linalg.yield %[[ELEM_COMPLEX:.*]] : complex<f32>
+// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xcomplex<f32>>
+// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32>
+// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<16x5xcomplex<f32>>
+// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR1]] : tensor<16x5xcomplex<f32>>) -> tensor<16x5xcomplex<f32>>
+// CHECK: %[[VAR3:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[VAR0]], %[[CST_0]] : tensor<16x9xf32>, tensor<9x5xcomplex<f32>>) outs(%[[VAR2]] : tensor<16x5xcomplex<f32>>) {
+// CHECK: ^bb0(%in: f32, %in_1: complex<f32>, %out: complex<f32>):
+// CHECK: %[[VAR5:.*]] = complex.re %in_1 : complex<f32>
+// CHECK: %[[VAR6:.*]] = complex.im %in_1 : complex<f32>
+// CHECK: %[[VAR7:.*]] = arith.mulf %in, %[[VAR5]] : f32
+// CHECK: %[[VAR8:.*]] = arith.mulf %in, %[[VAR6]] : f32
+// CHECK: %[[VAR9:.*]] = complex.create %[[VAR7]], %[[VAR8]] : complex<f32>
+// CHECK: %[[VAR10:.*]] = complex.add %[[VAR9]], %out : complex<f32>
+// CHECK: linalg.yield %[[VAR10]] : complex<f32>
 // CHECK: } -> tensor<16x5xcomplex<f32>>
-// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[COMPLEX:.*]] : tensor<16x5xcomplex<f32>> -> !torch.vtensor<[16,5],complex<f32>>
-// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[16,5],complex<f32>>
+// CHECK: %[[VAR4:.*]] = torch_c.from_builtin_tensor %[[VAR3]] : tensor<16x5xcomplex<f32>> -> !torch.vtensor<[16,5],complex<f32>>
+// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex<f32>>
 
 func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
   %int-1 = torch.constant.int -1
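Note (not part of the diff): read with the #map lines above, d0 is the batch row, d1 the reduced signal index, and d2 the output frequency, so for this 16x9 input the generic computes

\[
X[d_0, d_2] \;=\; \sum_{d_1 = 0}^{8} A[d_0, d_1]\cdot W[d_1, d_2],
\qquad A \in \mathbb{R}^{16\times 9},\; W \in \mathbb{C}^{9\times 5},\; 5 = \lfloor 9/2 \rfloor + 1 .
\]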
@@ -31,29 +34,28 @@ func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) ->
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim(
-// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
-// CHECK-DAG: %[[IMAG_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32>
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[REAL_COEFF:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xf32>
-// CHECK-DAG: %[[INPUT:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR:.*]] : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32>
-// CHECK-DAG: %[[EMPTY_0:.*]] = tensor.empty() : tensor<23x36xf32>
-// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[INPUT:.*]] : tensor<36x23xf32>) outs(%[[EMPTY_0:.*]] : tensor<23x36xf32>) permutation = [1, 0]
-// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<23x19xf32>
-// CHECK: %[[ZEROES_0:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
-// CHECK: %[[REAL_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[REAL_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_0:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
-// CHECK: %[[EMPTY_2:.*]] = tensor.empty() : tensor<23x19xf32>
-// CHECK: %[[ZEROES_1:.*]] = linalg.fill ins(%[[C0:.*]] : f32) outs(%[[EMPTY_2:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
-// CHECK: %[[IMAG_COMP:.*]] = linalg.matmul ins(%[[TRANSPOSED:.*]], %[[IMAG_COEFF:.*]] : tensor<23x36xf32>, tensor<36x19xf32>) outs(%[[ZEROES_1:.*]] : tensor<23x19xf32>) -> tensor<23x19xf32>
-// CHECK: %[[EMPTY_3:.*]] = tensor.empty() : tensor<23x19xcomplex<f32>>
-// CHECK: %[[COMPLEX:.*]] = linalg.generic {{.*}} ins(%[[REAL_COMP:.*]], %[[IMAG_COMP:.*]] : tensor<23x19xf32>, tensor<23x19xf32>) outs(%[[EMPTY_3:.*]] : tensor<23x19xcomplex<f32>>) {
-// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_3:.*]]: f32, %[[OUT:.*]]: complex<f32>):
-// CHECK: %[[EMPTY_02:.*]] = complex.create %[[IN:.*]], %[[IN_3:.*]] : complex<f32>
-// CHECK: linalg.yield %[[EMPTY_02:.*]] : complex<f32>
+// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xcomplex<f32>>
+// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32>
+// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<23x36xf32>
+// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[VAR0]] : tensor<36x23xf32>) outs(%[[VAR1]] : tensor<23x36xf32>) permutation = [1, 0]
+// CHECK-DAG: %[[VAR2:.*]] = tensor.empty() : tensor<23x19xcomplex<f32>>
+// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR2]] : tensor<23x19xcomplex<f32>>) -> tensor<23x19xcomplex<f32>>
+// CHECK: %[[VAR4:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[TRANSPOSED]], %[[CST_0]] : tensor<23x36xf32>, tensor<36x19xcomplex<f32>>) outs(%[[VAR3]] : tensor<23x19xcomplex<f32>>) {
+// CHECK: ^bb0(%in: f32, %in_2: complex<f32>, %out: complex<f32>):
+// CHECK: %[[VAR7:.*]] = complex.re %in_2 : complex<f32>
+// CHECK: %[[VAR8:.*]] = complex.im %in_2 : complex<f32>
+// CHECK: %[[VAR9:.*]] = arith.mulf %in, %[[VAR7]] : f32
+// CHECK: %[[VAR10:.*]] = arith.mulf %in, %[[VAR8]] : f32
+// CHECK: %[[VAR11:.*]] = complex.create %[[VAR9]], %[[VAR10]] : complex<f32>
+// CHECK: %[[VAR12:.*]] = complex.add %[[VAR11]], %out : complex<f32>
+// CHECK: linalg.yield %[[VAR12]] : complex<f32>
 // CHECK: } -> tensor<23x19xcomplex<f32>>
-// CHECK: %[[EMPTY_4:.*]] = tensor.empty() : tensor<19x23xcomplex<f32>>
-// CHECK: %[[TRANSPOSED_2:.*]] = linalg.transpose ins(%[[COMPLEX:.*]] : tensor<23x19xcomplex<f32>>) outs(%[[EMPTY_4:.*]] : tensor<19x23xcomplex<f32>>) permutation = [1, 0]
-// CHECK: %[[OUTPUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_2:.*]] : tensor<19x23xcomplex<f32>> -> !torch.vtensor<[19,23],complex<f32>>
-// CHECK: return %[[OUTPUT_VTENSOR:.*]] : !torch.vtensor<[19,23],complex<f32>>
+// CHECK-DAG: %[[VAR5:.*]] = tensor.empty() : tensor<19x23xcomplex<f32>>
+// CHECK: %[[TRANSPOSED_1:.*]] = linalg.transpose ins(%[[VAR4]] : tensor<23x19xcomplex<f32>>) outs(%[[VAR5]] : tensor<19x23xcomplex<f32>>) permutation = [1, 0]
+// CHECK: %[[VAR6:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_1]] : tensor<19x23xcomplex<f32>> -> !torch.vtensor<[19,23],complex<f32>>
+// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex<f32>>
 func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
   %int0 = torch.constant.int 0
   %none = torch.constant.none
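Note (not part of the diff): this first-dim case exercises the needTranspose path, which moves the FFT dimension last, applies the same contraction, and transposes back; in effect

\[
\operatorname{rfft}(A,\ \mathrm{dim}=0) \;=\; \big(\operatorname{rfft}(A^{\mathsf T},\ \mathrm{dim}=-1)\big)^{\mathsf T},
\qquad A \in \mathbb{R}^{36\times 23},\ \text{result} \in \mathbb{C}^{19\times 23},\ 19 = \lfloor 36/2 \rfloor + 1 .
\]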