mirror of https://github.com/llvm/torch-mlir
Address review feedback
parent
048dc5518e
commit
740de4f260
|
@ -20,6 +20,8 @@ namespace Torch {
|
|||
|
||||
int64_t toPositiveDim(int64_t dim, int64_t inputRank);
|
||||
bool isValidDim(int64_t dim, int64_t inputRank);
|
||||
Value toIntListConstruct(PatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<int64_t> cstInput, Torch::IntType intType);
|
||||
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
|
||||
/// Returns the index indicated by `v` for a list of given `length`.
|
||||
/// If the index is negative, it is adjusted to `length` + `v`.
|
||||
|
|
|
@ -1391,11 +1391,7 @@ Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType,
|
|||
for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
|
||||
for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
|
||||
double v = scale * i * j;
|
||||
if (isRealPart) {
|
||||
v = cos(v);
|
||||
} else {
|
||||
v = -sin(v);
|
||||
}
|
||||
v = isRealPart ? cos(v) : -sin(v);
|
||||
values.push_back(b.getF32FloatAttr(v));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9024,8 +9024,10 @@ namespace {
|
|||
|
||||
/// Creates coefficients based on DFT definition, see
|
||||
/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform.
|
||||
/// Even indices of the second dimension are for the real components of the
|
||||
/// output. Odd indices for the imaginary components.
|
||||
Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
|
||||
ValueTensorType matrixType, bool isRealPart) {
|
||||
ValueTensorType matrixType) {
|
||||
// scale = 2 * pi / N
|
||||
double scale = 2 * M_PI / matrixType.getSizes()[0];
|
||||
|
||||
|
@ -9033,12 +9035,9 @@ Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
|
|||
assert(matrixType.getSizes().size() == 2 && "expected 2D matrix");
|
||||
for (auto i : llvm::seq<unsigned>(0, matrixType.getSizes()[0])) {
|
||||
for (auto j : llvm::seq<unsigned>(0, matrixType.getSizes()[1])) {
|
||||
double v = scale * i * j;
|
||||
if (isRealPart) {
|
||||
v = cos(v);
|
||||
} else {
|
||||
v = -sin(v);
|
||||
}
|
||||
const bool isImagPart = j % 2;
|
||||
double v = scale * i * (j / 2);
|
||||
v = isImagPart ? -sin(v) : cos(v);
|
||||
values.push_back(rewriter.getF32FloatAttr(v));
|
||||
}
|
||||
}
|
||||
|
@ -9049,29 +9048,6 @@ Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
|
|||
ArrayRef<Attribute>(values)));
|
||||
}
|
||||
|
||||
Value createBatchMatmul(PatternRewriter &rewriter, Location loc, Value lhs,
|
||||
Value rhs) {
|
||||
|
||||
BaseTensorType lhsType = cast<BaseTensorType>(lhs.getType());
|
||||
assert(lhsType && lhsType.hasSizes());
|
||||
const ArrayRef<int64_t> lhsShape = lhsType.getSizes();
|
||||
assert(lhsShape.size() >= 2);
|
||||
BaseTensorType rhsType = cast<BaseTensorType>(rhs.getType());
|
||||
assert(rhsType && rhsType.hasSizes());
|
||||
const ArrayRef<int64_t> rhsShape = rhsType.getSizes();
|
||||
assert(rhsShape.size() >= 2);
|
||||
assert(rhsShape[rhsShape.size() - 2] == lhsShape[lhsShape.size() - 1]);
|
||||
|
||||
SmallVector<int64_t> resShape(lhsShape);
|
||||
resShape[resShape.size() - 1] = rhsShape[rhsShape.size() - 1];
|
||||
|
||||
Type dtype = lhsType.getOptionalDtype();
|
||||
|
||||
ValueTensorType resType =
|
||||
ValueTensorType::get(rewriter.getContext(), resShape, dtype);
|
||||
return rewriter.create<AtenMatmulOp>(loc, resType, lhs, rhs);
|
||||
}
|
||||
|
||||
class DecomposeAtenFftRfftOp final : public OpRewritePattern<AtenFftRfftOp> {
|
||||
|
||||
using OpRewritePattern<AtenFftRfftOp>::OpRewritePattern;
|
||||
|
@ -9133,66 +9109,51 @@ class DecomposeAtenFftRfftOp final : public OpRewritePattern<AtenFftRfftOp> {
|
|||
return success();
|
||||
};
|
||||
|
||||
SmallVector<int64_t> lhsShape(inputShape);
|
||||
// Transpose if FFT dimension is not the last one
|
||||
if (needTranspose) {
|
||||
if (failed(transposeValue(rewriter, loc, self, dim, lastDim, self))) {
|
||||
return failure();
|
||||
}
|
||||
std::swap(lhsShape[dim], lhsShape[lastDim]);
|
||||
}
|
||||
// self : (D_0 x ... x D_m x fftLength)
|
||||
|
||||
// lhs = unsqueeze(self, -2) : (D x 1 x fftLength), D = [D_1, D_2, ...]
|
||||
Value unsqueezeDim =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-2));
|
||||
auto unsqueezed = unsqueezeTensor(rewriter, op, self, unsqueezeDim);
|
||||
if (failed(unsqueezed))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"cannot generate unsqueezed tensor");
|
||||
Value lhs = *unsqueezed;
|
||||
Type dtype = inputType.getOptionalDtype();
|
||||
|
||||
Value real, complex;
|
||||
// coeff : (fftLength x outputFftDim*2)
|
||||
ValueTensorType matrixType = ValueTensorType::get(
|
||||
op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim * 2},
|
||||
dtype);
|
||||
Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType);
|
||||
|
||||
for (const bool isRealPart : {true, false}) {
|
||||
// X = matmul(self, coeff) : (D_0 x ... x D_m x outputFftDim*2)
|
||||
SmallVector<int64_t> matmulShape(lhsShape.begin(), lhsShape.end() - 1);
|
||||
matmulShape.push_back(outputFftDim * 2);
|
||||
ValueTensorType matmulType =
|
||||
ValueTensorType::get(op.getContext(), matmulShape, dtype);
|
||||
Value flatRes =
|
||||
rewriter.create<AtenMatmulOp>(loc, matmulType, self, coeffMatrix);
|
||||
|
||||
// coeff : (fftLength x outputFftDim)
|
||||
ValueTensorType matrixType = ValueTensorType::get(
|
||||
op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim},
|
||||
dtype);
|
||||
Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType,
|
||||
/*isRealPart=*/isRealPart);
|
||||
|
||||
// X = matmul(lhs, coeff) : (D x 1 x outputFftDim)
|
||||
Value matmulRes = createBatchMatmul(rewriter, loc, lhs, coeffMatrix);
|
||||
|
||||
// Y = squeeze(X, -2) : (D x outputFftDim)
|
||||
auto squeezed = squeezeTensor(rewriter, op, loc, -2, matmulRes);
|
||||
if (failed(squeezed))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"cannot generate squeezed tensor");
|
||||
|
||||
if (isRealPart) {
|
||||
real = *squeezed;
|
||||
} else {
|
||||
complex = *squeezed;
|
||||
}
|
||||
}
|
||||
|
||||
// Pack components into a complex tensor
|
||||
BaseTensorType realType = cast<BaseTensorType>(real.getType());
|
||||
SmallVector<int64_t> stackSizes(realType.getSizes());
|
||||
stackSizes.push_back(2);
|
||||
Value sequence = rewriter.create<PrimListConstructOp>(
|
||||
loc, ListType::get(op.getContext(), realType),
|
||||
ValueRange{real, complex});
|
||||
Type stackType = realType.getWithSizesAndDtype(stackSizes, dtype);
|
||||
// Y = unflatten(X, -1, [outputFftDim, 2])
|
||||
// : (D_0 x ... x D_m x outputFftDim x 2)
|
||||
// Z = view_as_complex(Y) : complex(D_0 x ... x D_m x outputFftDim)
|
||||
SmallVector<int64_t> complexResShape(matmulShape);
|
||||
complexResShape.back() = outputFftDim;
|
||||
SmallVector<int64_t> unflattenedResShape(complexResShape);
|
||||
unflattenedResShape.push_back(2);
|
||||
Type unflattenedResType =
|
||||
ValueTensorType::get(op.getContext(), unflattenedResShape, dtype);
|
||||
Value cstMinusOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
||||
Value stack =
|
||||
rewriter.create<AtenStackOp>(loc, stackType, sequence, cstMinusOne);
|
||||
Type complexResType = ValueTensorType::get(
|
||||
op.getContext(), realType.getSizes(), ComplexType::get(dtype));
|
||||
Value complexRes =
|
||||
rewriter.create<AtenViewAsComplexOp>(loc, complexResType, stack);
|
||||
Value unflattenSizes = toIntListConstruct(
|
||||
rewriter, loc, {outputFftDim, 2}, IntType::get(rewriter.getContext()));
|
||||
Value unflattenedRes = rewriter.create<AtenUnflattenIntOp>(
|
||||
loc, unflattenedResType, flatRes, /*dim=*/cstMinusOne, unflattenSizes);
|
||||
Type complexResType = ValueTensorType::get(op.getContext(), complexResShape,
|
||||
ComplexType::get(dtype));
|
||||
Value complexRes = rewriter.create<AtenViewAsComplexOp>(loc, complexResType,
|
||||
unflattenedRes);
|
||||
|
||||
// Transpose back
|
||||
if (needTranspose) {
|
||||
|
|
|
@ -36,6 +36,18 @@ Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) {
|
|||
return dim;
|
||||
}
|
||||
|
||||
Value Torch::toIntListConstruct(PatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<int64_t> cstInput,
|
||||
Torch::IntType intType) {
|
||||
SmallVector<Value> cstValues;
|
||||
for (int64_t i : cstInput) {
|
||||
cstValues.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
return rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(intType), cstValues);
|
||||
}
|
||||
|
||||
bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
|
||||
auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
|
||||
if (!listConstruct)
|
||||
|
|
|
@ -175,26 +175,16 @@ func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.v
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
|
||||
// CHECK: %[[INTM1:.*]] = torch.constant.int -1
|
||||
// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32>
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32>
|
||||
// CHECK: %[[INTM2:.*]] = torch.constant.int -2
|
||||
// CHECK: %[[VAR2:.*]] = torch.aten.unsqueeze %[[ARG0:.*]], %[[INTM2:.*]] : !torch.vtensor<[16,9],f32>, !torch.int -> !torch.vtensor<[16,1,9],f32>
|
||||
// CHECK: %[[VAR3:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR1:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32>
|
||||
// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAR4:.*]] = torch.aten.squeeze.dim %[[VAR3:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32>
|
||||
// CHECK: %[[VAR5:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR0:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32>
|
||||
// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAR6:.*]] = torch.aten.squeeze.dim %[[VAR5:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32>
|
||||
// CHECK: %[[VAR7:.*]] = torch.aten.unsqueeze %[[VAR4:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32>
|
||||
// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR6:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32>
|
||||
// CHECK: %[[VAR9:.*]] = torch.prim.ListConstruct %[[VAR7:.*]], %[[VAR8:.*]] : (!torch.vtensor<[16,5,1],f32>, !torch.vtensor<[16,5,1],f32>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[VAR10:.*]] = torch.aten.cat %[[VAR9:.*]], %[[INTM1:.*]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[16,5,2],f32>
|
||||
// CHECK: %[[VAR11:.*]] = torch.aten.view_as_complex %[[VAR10:.*]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex<f32>>
|
||||
// CHECK: return %[[VAR11:.*]] : !torch.vtensor<[16,5],complex<f32>>
|
||||
// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
|
||||
// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5
|
||||
// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16
|
||||
// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x10xf32>) : !torch.vtensor<[9,10],f32>
|
||||
// CHECK: %[[VAR1:.*]] = torch.aten.mm %arg0, %[[VAR0]] : !torch.vtensor<[16,9],f32>, !torch.vtensor<[9,10],f32> -> !torch.vtensor<[16,10],f32>
|
||||
// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT16]], %[[INT5]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[16,10],f32>, !torch.list<int> -> !torch.vtensor<[16,5,2],f32>
|
||||
// CHECK: %[[VAR4:.*]] = torch.aten.view_as_complex %[[VAR3]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex<f32>>
|
||||
// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex<f32>>
|
||||
func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
|
||||
%int-1 = torch.constant.int -1
|
||||
%none = torch.constant.none
|
||||
|
@ -205,29 +195,20 @@ func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) ->
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
|
||||
// CHECK: %[[INTM1:.*]] = torch.constant.int -1
|
||||
// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32>
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32>
|
||||
// CHECK: %[[INTM2:.*]] = torch.constant.int -2
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAR2:.*]] = torch.aten.transpose.int %[[ARG0:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
|
||||
// CHECK: %[[VAR3:.*]] = torch.aten.unsqueeze %[[VAR2:.*]], %[[INTM2:.*]] : !torch.vtensor<[23,36],f32>, !torch.int -> !torch.vtensor<[23,1,36],f32>
|
||||
// CHECK: %[[VAR4:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR1:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32>
|
||||
// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAR5:.*]] = torch.aten.squeeze.dim %[[VAR4:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32>
|
||||
// CHECK: %[[VAR6:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR0:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32>
|
||||
// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
|
||||
// CHECK: %[[VAR7:.*]] = torch.aten.squeeze.dim %[[VAR6:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32>
|
||||
// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR5:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32>
|
||||
// CHECK: %[[VAR9:.*]] = torch.aten.unsqueeze %[[VAR7:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32>
|
||||
// CHECK: %[[VAR10:.*]] = torch.prim.ListConstruct %[[VAR8:.*]], %[[VAR9:.*]] : (!torch.vtensor<[23,19,1],f32>, !torch.vtensor<[23,19,1],f32>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[VAR11:.*]] = torch.aten.cat %[[VAR10:.*]], %[[INTM1:.*]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[23,19,2],f32>
|
||||
// CHECK: %[[VAR12:.*]] = torch.aten.view_as_complex %[[VAR11:.*]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex<f32>>
|
||||
// CHECK: %[[VAR13:.*]] = torch.aten.transpose.int %[[VAR12:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[23,19],complex<f32>>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex<f32>>
|
||||
// CHECK: return %[[VAR13:.*]] : !torch.vtensor<[19,23],complex<f32>>
|
||||
// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
|
||||
// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[INT19:.*]] = torch.constant.int 19
|
||||
// CHECK-DAG: %[[INT23:.*]] = torch.constant.int 23
|
||||
// CHECK-DAG: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x38xf32>) : !torch.vtensor<[36,38],f32>
|
||||
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAR1:.*]] = torch.aten.transpose.int %arg0, %[[INT0]], %[[INT1]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
|
||||
// CHECK: %[[VAR2:.*]] = torch.aten.mm %[[VAR1]], %[[VAR0]] : !torch.vtensor<[23,36],f32>, !torch.vtensor<[36,38],f32> -> !torch.vtensor<[23,38],f32>
|
||||
// CHECK: %[[VAR3:.*]] = torch.prim.ListConstruct %[[INT23]], %[[INT19]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAR4:.*]] = torch.aten.view %[[VAR2]], %[[VAR3]] : !torch.vtensor<[23,38],f32>, !torch.list<int> -> !torch.vtensor<[23,19,2],f32>
|
||||
// CHECK: %[[VAR5:.*]] = torch.aten.view_as_complex %[[VAR4]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex<f32>>
|
||||
// CHECK: %[[VAR6:.*]] = torch.aten.transpose.int %[[VAR5]], %[[INT0]], %[[INT1]] : !torch.vtensor<[23,19],complex<f32>>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex<f32>>
|
||||
// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex<f32>>
|
||||
func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
|
||||
%int0 = torch.constant.int 0
|
||||
%none = torch.constant.none
|
||||
|
|
Loading…
Reference in New Issue