Address review feedback

pull/3857/head
giacs-epic 2024-11-19 13:46:25 +00:00
parent 048dc5518e
commit 740de4f260
5 changed files with 77 additions and 125 deletions


@@ -20,6 +20,8 @@ namespace Torch {
 int64_t toPositiveDim(int64_t dim, int64_t inputRank);
 bool isValidDim(int64_t dim, int64_t inputRank);
+Value toIntListConstruct(PatternRewriter &rewriter, Location loc,
+                         ArrayRef<int64_t> cstInput, Torch::IntType intType);
 bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
 /// Returns the index indicated by `v` for a list of given `length`.
 /// If the index is negative, it is adjusted to `length` + `v`.


@@ -1391,11 +1391,7 @@ Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType,
   for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
     for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
       double v = scale * i * j;
-      if (isRealPart) {
-        v = cos(v);
-      } else {
-        v = -sin(v);
-      }
+      v = isRealPart ? cos(v) : -sin(v);
       values.push_back(b.getF32FloatAttr(v));
     }
   }
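
For reference, the two branches are just the real and imaginary components of the DFT kernel: with theta = scale * i * j, Euler's formula gives

```latex
e^{-i\theta} = \cos\theta - i\,\sin\theta,
```

so the real-part matrix stores cos(v) and the imaginary-part matrix stores -sin(v); the ternary is a pure simplification with no behavior change.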


@@ -9024,8 +9024,10 @@ namespace {
 /// Creates coefficients based on DFT definition, see
 /// https://en.wikipedia.org/wiki/Discrete_Fourier_transform.
+/// Even indices of the second dimension are for the real components of the
+/// output. Odd indices for the imaginary components.
 Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
-                        ValueTensorType matrixType, bool isRealPart) {
+                        ValueTensorType matrixType) {
   // scale = 2 * pi / N
   double scale = 2 * M_PI / matrixType.getSizes()[0];
@@ -9033,12 +9035,9 @@ Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
   assert(matrixType.getSizes().size() == 2 && "expected 2D matrix");
   for (auto i : llvm::seq<unsigned>(0, matrixType.getSizes()[0])) {
     for (auto j : llvm::seq<unsigned>(0, matrixType.getSizes()[1])) {
-      double v = scale * i * j;
-      if (isRealPart) {
-        v = cos(v);
-      } else {
-        v = -sin(v);
-      }
+      const bool isImagPart = j % 2;
+      double v = scale * i * (j / 2);
+      v = isImagPart ? -sin(v) : cos(v);
       values.push_back(rewriter.getF32FloatAttr(v));
     }
   }
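
The interleaving changes only where those factors land: column 2k carries the real part and column 2k+1 the imaginary part (hence the `j / 2` and `j % 2`), so for an N-point real input x a single product xC packs the whole rDFT:

```latex
C_{n,2k} = \cos\!\Big(\frac{2\pi nk}{N}\Big), \quad
C_{n,2k+1} = -\sin\!\Big(\frac{2\pi nk}{N}\Big)
\quad\Longrightarrow\quad
(xC)_{2k} + i\,(xC)_{2k+1} = X_k = \sum_{n=0}^{N-1} x_n\, e^{-2\pi i\,nk/N}.
```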
@@ -9049,29 +9048,6 @@ Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc,
                                      ArrayRef<Attribute>(values)));
 }
-
-Value createBatchMatmul(PatternRewriter &rewriter, Location loc, Value lhs,
-                        Value rhs) {
-  BaseTensorType lhsType = cast<BaseTensorType>(lhs.getType());
-  assert(lhsType && lhsType.hasSizes());
-  const ArrayRef<int64_t> lhsShape = lhsType.getSizes();
-  assert(lhsShape.size() >= 2);
-
-  BaseTensorType rhsType = cast<BaseTensorType>(rhs.getType());
-  assert(rhsType && rhsType.hasSizes());
-  const ArrayRef<int64_t> rhsShape = rhsType.getSizes();
-  assert(rhsShape.size() >= 2);
-
-  assert(rhsShape[rhsShape.size() - 2] == lhsShape[lhsShape.size() - 1]);
-  SmallVector<int64_t> resShape(lhsShape);
-  resShape[resShape.size() - 1] = rhsShape[rhsShape.size() - 1];
-
-  Type dtype = lhsType.getOptionalDtype();
-  ValueTensorType resType =
-      ValueTensorType::get(rewriter.getContext(), resShape, dtype);
-  return rewriter.create<AtenMatmulOp>(loc, resType, lhs, rhs);
-}
 
 class DecomposeAtenFftRfftOp final : public OpRewritePattern<AtenFftRfftOp> {
   using OpRewritePattern<AtenFftRfftOp>::OpRewritePattern;
@@ -9133,66 +9109,51 @@ class DecomposeAtenFftRfftOp final : public OpRewritePattern<AtenFftRfftOp> {
       return success();
     };
 
     SmallVector<int64_t> lhsShape(inputShape);
     // Transpose if FFT dimension is not the last one
     if (needTranspose) {
       if (failed(transposeValue(rewriter, loc, self, dim, lastDim, self))) {
         return failure();
       }
       std::swap(lhsShape[dim], lhsShape[lastDim]);
     }
     // self : (D_0 x ... x D_m x fftLength)
-    // lhs = unsqueeze(self, -2) : (D x 1 x fftLength), D = [D_1, D_2, ...]
-    Value unsqueezeDim =
-        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-2));
-    auto unsqueezed = unsqueezeTensor(rewriter, op, self, unsqueezeDim);
-    if (failed(unsqueezed))
-      return rewriter.notifyMatchFailure(op,
-                                         "cannot generate unsqueezed tensor");
-    Value lhs = *unsqueezed;
     Type dtype = inputType.getOptionalDtype();
-    Value real, complex;
+    // coeff : (fftLength x outputFftDim*2)
+    ValueTensorType matrixType = ValueTensorType::get(
+        op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim * 2},
+        dtype);
+    Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType);
-    for (const bool isRealPart : {true, false}) {
+    // X = matmul(self, coeff) : (D_0 x ... x D_m x outputFftDim*2)
+    SmallVector<int64_t> matmulShape(lhsShape.begin(), lhsShape.end() - 1);
+    matmulShape.push_back(outputFftDim * 2);
+    ValueTensorType matmulType =
+        ValueTensorType::get(op.getContext(), matmulShape, dtype);
+    Value flatRes =
+        rewriter.create<AtenMatmulOp>(loc, matmulType, self, coeffMatrix);
-      // coeff : (fftLength x outputFftDim)
-      ValueTensorType matrixType = ValueTensorType::get(
-          op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim},
-          dtype);
-      Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType,
-                                            /*isRealPart=*/isRealPart);
-      // X = matmul(lhs, coeff) : (D x 1 x outputFftDim)
-      Value matmulRes = createBatchMatmul(rewriter, loc, lhs, coeffMatrix);
-      // Y = squeeze(X, -2) : (D x outputFftDim)
-      auto squeezed = squeezeTensor(rewriter, op, loc, -2, matmulRes);
-      if (failed(squeezed))
-        return rewriter.notifyMatchFailure(op,
-                                           "cannot generate squeezed tensor");
-      if (isRealPart) {
-        real = *squeezed;
-      } else {
-        complex = *squeezed;
-      }
-    }
-    // Pack components into a complex tensor
-    BaseTensorType realType = cast<BaseTensorType>(real.getType());
-    SmallVector<int64_t> stackSizes(realType.getSizes());
-    stackSizes.push_back(2);
-    Value sequence = rewriter.create<PrimListConstructOp>(
-        loc, ListType::get(op.getContext(), realType),
-        ValueRange{real, complex});
-    Type stackType = realType.getWithSizesAndDtype(stackSizes, dtype);
+    // Y = unflatten(X, -1, [outputFftDim, 2])
+    //   : (D_0 x ... x D_m x outputFftDim x 2)
+    // Z = view_as_complex(Y) : complex(D_0 x ... x D_m x outputFftDim)
+    SmallVector<int64_t> complexResShape(matmulShape);
+    complexResShape.back() = outputFftDim;
+    SmallVector<int64_t> unflattenedResShape(complexResShape);
+    unflattenedResShape.push_back(2);
+    Type unflattenedResType =
+        ValueTensorType::get(op.getContext(), unflattenedResShape, dtype);
     Value cstMinusOne =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
-    Value stack =
-        rewriter.create<AtenStackOp>(loc, stackType, sequence, cstMinusOne);
-    Type complexResType = ValueTensorType::get(
-        op.getContext(), realType.getSizes(), ComplexType::get(dtype));
-    Value complexRes =
-        rewriter.create<AtenViewAsComplexOp>(loc, complexResType, stack);
+    Value unflattenSizes = toIntListConstruct(
+        rewriter, loc, {outputFftDim, 2}, IntType::get(rewriter.getContext()));
+    Value unflattenedRes = rewriter.create<AtenUnflattenIntOp>(
+        loc, unflattenedResType, flatRes, /*dim=*/cstMinusOne, unflattenSizes);
+    Type complexResType = ValueTensorType::get(op.getContext(), complexResShape,
+                                               ComplexType::get(dtype));
+    Value complexRes = rewriter.create<AtenViewAsComplexOp>(loc, complexResType,
+                                                            unflattenedRes);
     // Transpose back
     if (needTranspose) {
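
The net effect of the rewritten pattern (one flat matmul with the interleaved coefficient matrix, then an unflatten plus view_as_complex) can be sanity-checked numerically. Below is a minimal standalone C++ sketch, not torch-mlir code and with ad hoc names, assuming only the indexing scheme of getDFTMatmulCoeff above:

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstdio>
#include <vector>

int main() {
  const int N = 9;         // fftLength
  const int K = N / 2 + 1; // outputFftDim for rfft
  const double pi = std::acos(-1.0);
  const double scale = 2.0 * pi / N;

  // Interleaved coefficients: column 2k is cos, column 2k+1 is -sin,
  // mirroring the isImagPart / (j / 2) indexing in getDFTMatmulCoeff.
  std::vector<double> coeff(N * 2 * K);
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < 2 * K; ++j) {
      const bool isImagPart = (j % 2) != 0;
      double v = scale * i * (j / 2);
      coeff[i * 2 * K + j] = isImagPart ? -std::sin(v) : std::cos(v);
    }

  // An arbitrary real input row.
  std::vector<double> x(N);
  for (int n = 0; n < N; ++n)
    x[n] = 0.5 * n - 1.0;

  // Flat matmul gives 2*K reals; pairing adjacent entries plays the role
  // of unflatten(-1, [K, 2]) followed by view_as_complex.
  for (int k = 0; k < K; ++k) {
    double re = 0.0, im = 0.0;
    for (int n = 0; n < N; ++n) {
      re += x[n] * coeff[n * 2 * K + 2 * k];
      im += x[n] * coeff[n * 2 * K + 2 * k + 1];
    }
    // Reference: X_k = sum_n x_n * exp(-2*pi*i*n*k/N).
    std::complex<double> ref(0.0, 0.0);
    for (int n = 0; n < N; ++n)
      ref += x[n] * std::exp(std::complex<double>(0.0, -scale * n * k));
    assert(std::abs(re - ref.real()) < 1e-9);
    assert(std::abs(im - ref.imag()) < 1e-9);
    std::printf("X_%d = %+.4f %+.4fi\n", k, re, im);
  }
  return 0;
}
```

This also makes clear why the unsqueeze/squeeze pair and the per-part loop could be dropped: one matrix with twice the columns replaces two batched matmuls and the stack.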


@@ -36,6 +36,18 @@ Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) {
   return dim;
 }
 
+Value Torch::toIntListConstruct(PatternRewriter &rewriter, Location loc,
+                                ArrayRef<int64_t> cstInput,
+                                Torch::IntType intType) {
+  SmallVector<Value> cstValues;
+  for (int64_t i : cstInput) {
+    cstValues.push_back(rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(i)));
+  }
+  return rewriter.create<Torch::PrimListConstructOp>(
+      loc, Torch::ListType::get(intType), cstValues);
+}
+
 bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
   auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
   if (!listConstruct)
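
As a usage sketch (a hypothetical call site; `rewriter` and `loc` are assumed to be in scope), the new helper replaces the per-element constant-plus-ListConstruct boilerplate, as in the AtenFftRfftOp decomposition above:

```cpp
// Hypothetical: build the [outputFftDim, 2] size list for unflatten.
// toIntListConstruct materializes each int64_t as a Torch::ConstantIntOp
// and wraps the results in a Torch::PrimListConstructOp.
Value unflattenSizes = Torch::toIntListConstruct(
    rewriter, loc, /*cstInput=*/{outputFftDim, 2},
    Torch::IntType::get(rewriter.getContext()));
// unflattenSizes is a !torch.list<int> with two constant elements.
```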


@@ -175,26 +175,16 @@ func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.v
 // -----
 // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
-// CHECK: %[[INTM1:.*]] = torch.constant.int -1
-// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32>
-// CHECK: %[[TRUE:.*]] = torch.constant.bool true
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x5xf32>) : !torch.vtensor<[9,5],f32>
-// CHECK: %[[INTM2:.*]] = torch.constant.int -2
-// CHECK: %[[VAR2:.*]] = torch.aten.unsqueeze %[[ARG0:.*]], %[[INTM2:.*]] : !torch.vtensor<[16,9],f32>, !torch.int -> !torch.vtensor<[16,1,9],f32>
-// CHECK: %[[VAR3:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR1:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32>
-// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-// CHECK: %[[VAR4:.*]] = torch.aten.squeeze.dim %[[VAR3:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32>
-// CHECK: %[[VAR5:.*]] = torch.aten.matmul %[[VAR2:.*]], %[[VAR0:.*]] : !torch.vtensor<[16,1,9],f32>, !torch.vtensor<[9,5],f32> -> !torch.vtensor<[16,1,5],f32>
-// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-// CHECK: %[[VAR6:.*]] = torch.aten.squeeze.dim %[[VAR5:.*]], %[[INT1:.*]] : !torch.vtensor<[16,1,5],f32>, !torch.int -> !torch.vtensor<[16,5],f32>
-// CHECK: %[[VAR7:.*]] = torch.aten.unsqueeze %[[VAR4:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32>
-// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR6:.*]], %[[INTM1:.*]] : !torch.vtensor<[16,5],f32>, !torch.int -> !torch.vtensor<[16,5,1],f32>
-// CHECK: %[[VAR9:.*]] = torch.prim.ListConstruct %[[VAR7:.*]], %[[VAR8:.*]] : (!torch.vtensor<[16,5,1],f32>, !torch.vtensor<[16,5,1],f32>) -> !torch.list<vtensor>
-// CHECK: %[[VAR10:.*]] = torch.aten.cat %[[VAR9:.*]], %[[INTM1:.*]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[16,5,2],f32>
-// CHECK: %[[VAR11:.*]] = torch.aten.view_as_complex %[[VAR10:.*]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex<f32>>
-// CHECK: return %[[VAR11:.*]] : !torch.vtensor<[16,5],complex<f32>>
+// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
+// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5
+// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16
+// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x10xf32>) : !torch.vtensor<[9,10],f32>
+// CHECK: %[[VAR1:.*]] = torch.aten.mm %arg0, %[[VAR0]] : !torch.vtensor<[16,9],f32>, !torch.vtensor<[9,10],f32> -> !torch.vtensor<[16,10],f32>
+// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT16]], %[[INT5]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[16,10],f32>, !torch.list<int> -> !torch.vtensor<[16,5,2],f32>
+// CHECK: %[[VAR4:.*]] = torch.aten.view_as_complex %[[VAR3]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex<f32>>
+// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex<f32>>
 func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex<f32>> {
   %int-1 = torch.constant.int -1
   %none = torch.constant.none
@@ -205,29 +195,20 @@ func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) ->
 // -----
 // CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
-// CHECK: %[[INTM1:.*]] = torch.constant.int -1
-// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32>
-// CHECK: %[[TRUE:.*]] = torch.constant.bool true
-// CHECK: %[[VAR1:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x19xf32>) : !torch.vtensor<[36,19],f32>
-// CHECK: %[[INTM2:.*]] = torch.constant.int -2
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[VAR2:.*]] = torch.aten.transpose.int %[[ARG0:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
-// CHECK: %[[VAR3:.*]] = torch.aten.unsqueeze %[[VAR2:.*]], %[[INTM2:.*]] : !torch.vtensor<[23,36],f32>, !torch.int -> !torch.vtensor<[23,1,36],f32>
-// CHECK: %[[VAR4:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR1:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32>
-// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-// CHECK: %[[VAR5:.*]] = torch.aten.squeeze.dim %[[VAR4:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32>
-// CHECK: %[[VAR6:.*]] = torch.aten.matmul %[[VAR3:.*]], %[[VAR0:.*]] : !torch.vtensor<[23,1,36],f32>, !torch.vtensor<[36,19],f32> -> !torch.vtensor<[23,1,19],f32>
-// CHECK: torch.runtime.assert %[[TRUE:.*]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-// CHECK: %[[VAR7:.*]] = torch.aten.squeeze.dim %[[VAR6:.*]], %[[INT1:.*]] : !torch.vtensor<[23,1,19],f32>, !torch.int -> !torch.vtensor<[23,19],f32>
-// CHECK: %[[VAR8:.*]] = torch.aten.unsqueeze %[[VAR5:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32>
-// CHECK: %[[VAR9:.*]] = torch.aten.unsqueeze %[[VAR7:.*]], %[[INTM1:.*]] : !torch.vtensor<[23,19],f32>, !torch.int -> !torch.vtensor<[23,19,1],f32>
-// CHECK: %[[VAR10:.*]] = torch.prim.ListConstruct %[[VAR8:.*]], %[[VAR9:.*]] : (!torch.vtensor<[23,19,1],f32>, !torch.vtensor<[23,19,1],f32>) -> !torch.list<vtensor>
-// CHECK: %[[VAR11:.*]] = torch.aten.cat %[[VAR10:.*]], %[[INTM1:.*]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[23,19,2],f32>
-// CHECK: %[[VAR12:.*]] = torch.aten.view_as_complex %[[VAR11:.*]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex<f32>>
-// CHECK: %[[VAR13:.*]] = torch.aten.transpose.int %[[VAR12:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[23,19],complex<f32>>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex<f32>>
-// CHECK: return %[[VAR13:.*]] : !torch.vtensor<[19,23],complex<f32>>
+// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
+// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[INT19:.*]] = torch.constant.int 19
+// CHECK-DAG: %[[INT23:.*]] = torch.constant.int 23
+// CHECK-DAG: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x38xf32>) : !torch.vtensor<[36,38],f32>
+// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
+// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[VAR1:.*]] = torch.aten.transpose.int %arg0, %[[INT0]], %[[INT1]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
+// CHECK: %[[VAR2:.*]] = torch.aten.mm %[[VAR1]], %[[VAR0]] : !torch.vtensor<[23,36],f32>, !torch.vtensor<[36,38],f32> -> !torch.vtensor<[23,38],f32>
+// CHECK: %[[VAR3:.*]] = torch.prim.ListConstruct %[[INT23]], %[[INT19]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR4:.*]] = torch.aten.view %[[VAR2]], %[[VAR3]] : !torch.vtensor<[23,38],f32>, !torch.list<int> -> !torch.vtensor<[23,19,2],f32>
+// CHECK: %[[VAR5:.*]] = torch.aten.view_as_complex %[[VAR4]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex<f32>>
+// CHECK: %[[VAR6:.*]] = torch.aten.transpose.int %[[VAR5]], %[[INT0]], %[[INT1]] : !torch.vtensor<[23,19],complex<f32>>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex<f32>>
+// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex<f32>>
 func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex<f32>> {
   %int0 = torch.constant.int 0
   %none = torch.constant.none