2022-03-11 01:54:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
|
|
|
#include "PopulatePatterns.h"
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
|
|
#include "mlir/IR/Matchers.h"
|
2023-12-02 08:38:21 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
2022-08-25 00:19:35 +08:00
|
|
|
#include <algorithm>
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenMmOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op->getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getMat2();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
// A user can write an errorneous program where `aten.mm` is in fact called
|
|
|
|
// with operands of invalid rank or dtype. We cannot convert to linalg in
|
|
|
|
// this case or we will get a verifier error, which corresponds to breaking
|
|
|
|
// of *internal* compiler invariants, and for a user manifests as a compiler
|
|
|
|
// crash in the worst case (such as we try to canonicalize/fold/print the
|
|
|
|
// invalid op before the verifier gets to see it -- also release builds of a
|
|
|
|
// mature compiler usually have the verifier turned off for compile time
|
|
|
|
// reasons).
|
|
|
|
//
|
|
|
|
// The compiler cannot crash even if the user wrote an erroneous program!
|
|
|
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
|
|
|
return failure();
|
2023-12-07 13:13:53 +08:00
|
|
|
|
|
|
|
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
|
|
|
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
|
|
|
|
|
|
|
if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected both operands to aten.mm to be rank 2");
|
|
|
|
}
|
|
|
|
|
2023-12-07 13:13:53 +08:00
|
|
|
ValueTensorType lhsTorchType =
|
|
|
|
op.getSelf().getType().cast<ValueTensorType>();
|
|
|
|
ValueTensorType rhsTorchType =
|
|
|
|
op.getMat2().getType().cast<ValueTensorType>();
|
|
|
|
if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unsupported: aten.mm with different input element types");
|
|
|
|
}
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
|
|
|
|
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
|
2023-09-30 07:45:48 +08:00
|
|
|
|
|
|
|
if (!isAssumingStrictSymbolicShapes(rewriter)) {
|
|
|
|
Value lhsDim1 = rewriter.create<tensor::DimOp>(loc, lhs, 1);
|
|
|
|
Value rhsDim0 = rewriter.create<tensor::DimOp>(loc, rhs, 0);
|
|
|
|
Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
|
|
|
|
loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
|
|
|
|
rewriter.create<cf::AssertOp>(
|
|
|
|
loc, contractingDimEqual,
|
|
|
|
rewriter.getStringAttr(
|
|
|
|
"mismatching contracting dimension for torch.aten.mm"));
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
|
|
|
Type elementType = newResultType.cast<TensorType>().getElementType();
|
2023-12-07 13:13:53 +08:00
|
|
|
Value zeroFill = createZeroInitTensor(
|
|
|
|
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
|
|
|
|
|
|
|
|
Value matmul;
|
|
|
|
auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
|
|
|
|
if (intType && intType.isUnsigned()) {
|
|
|
|
matmul = rewriter
|
|
|
|
.create<linalg::MatmulUnsignedOp>(
|
|
|
|
loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
|
|
|
|
.getResult(0);
|
|
|
|
} else {
|
|
|
|
matmul = rewriter
|
|
|
|
.create<linalg::MatmulOp>(loc, zeroFill.getType(),
|
|
|
|
ValueRange{lhs, rhs}, zeroFill)
|
|
|
|
.getResult(0);
|
|
|
|
}
|
2022-10-18 12:22:53 +08:00
|
|
|
// When constructed with just dynamic sizes, EmptyOp will have a result
|
2022-03-11 01:54:13 +08:00
|
|
|
// type which has all `?`'s for dimensions, which might not be the result
|
|
|
|
// type of `op`. The constraints on later linalg ops means that the result
|
|
|
|
// of the MatmulOp will have this type too. So cast it to the desired type
|
|
|
|
// so that in the end we have the original result type.
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-05-03 04:01:15 +08:00
|
|
|
namespace {
|
|
|
|
class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenFlipOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
MLIRContext *context = op.getContext();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
|
|
|
auto selfRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
2022-05-03 04:01:15 +08:00
|
|
|
Type elementType =
|
2022-12-08 04:20:41 +08:00
|
|
|
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
|
2022-05-03 04:01:15 +08:00
|
|
|
Value c1 =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
|
|
|
|
|
|
|
SmallVector<int64_t> axis;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
|
2022-05-03 04:01:15 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only constant dim lists supported");
|
2023-06-15 10:27:34 +08:00
|
|
|
for (unsigned i = 0, e = axis.size(); i < e; i++) {
|
|
|
|
axis[i] = toPositiveDim(axis[i], selfRank);
|
|
|
|
if (!isValidDim(axis[i], selfRank)) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "axis is statically invalid");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-05-03 04:01:15 +08:00
|
|
|
// Only used to calculate flipped values, i.e. those on the flip axes. Other
|
|
|
|
// dims won't be used.
|
|
|
|
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
|
|
|
|
for (auto flipDim : axis)
|
|
|
|
dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
|
|
|
|
|
|
|
|
Value initTensor = createZeroInitTensor(
|
|
|
|
rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);
|
|
|
|
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(
|
|
|
|
selfRank, utils::IteratorType::parallel);
|
2022-05-03 04:01:15 +08:00
|
|
|
SmallVector<AffineMap> indexingMaps(
|
|
|
|
2, AffineMap::getMultiDimIdentityMap(selfRank, context));
|
|
|
|
Value flipped =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
|
|
|
loc, self.getType(), self, initTensor, indexingMaps,
|
|
|
|
iteratorTypes,
|
|
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
SmallVector<Value> indices;
|
|
|
|
for (auto i = 0; i < selfRank; i++)
|
|
|
|
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
|
|
|
for (auto flipDim : axis) {
|
|
|
|
indices[flipDim] = b.create<arith::SubIOp>(
|
|
|
|
loc, dims[flipDim], indices[flipDim]);
|
|
|
|
}
|
|
|
|
Value res = b.create<tensor::ExtractOp>(loc, self, indices)
|
|
|
|
.getResult();
|
|
|
|
b.create<linalg::YieldOp>(loc, res);
|
|
|
|
})
|
|
|
|
.getResult(0);
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
namespace {
|
|
|
|
class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenMatmulOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op->getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getOther();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
|
|
|
return failure();
|
2022-06-16 23:45:10 +08:00
|
|
|
auto lhsType = lhs.getType().cast<RankedTensorType>();
|
|
|
|
auto rhsType = rhs.getType().cast<RankedTensorType>();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-06-16 23:45:10 +08:00
|
|
|
// Get the rank of both matrix.
|
|
|
|
unsigned lhsRank = lhsType.getRank();
|
|
|
|
unsigned rhsRank = rhsType.getRank();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
2022-06-16 23:45:10 +08:00
|
|
|
auto resultType = newResultType.cast<RankedTensorType>();
|
|
|
|
Type elementType = resultType.getElementType();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
// The different cases of torch_matmul op is mentioned here:
|
|
|
|
// https://pytorch.org/docs/stable/generated/torch.matmul.html
|
|
|
|
|
|
|
|
// First Case: Dot Product.
|
|
|
|
if (lhsRank == 1 && rhsRank == 1) {
|
|
|
|
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
|
|
|
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
|
|
|
|
|
|
|
|
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
|
|
|
|
|
|
|
|
Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
|
|
|
|
Value dotProd =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::DotOp>(loc, zeroTensor.getType(),
|
|
|
|
ValueRange{lhs, rhs}, zeroTensor)
|
|
|
|
.getResult(0);
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Second Case: Vec-Mat Multiplication.
|
|
|
|
if (lhsRank == 1 && rhsRank == 2) {
|
|
|
|
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
|
|
|
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
|
|
|
|
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
|
|
|
|
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
|
|
|
|
|
|
|
|
Value zeroTensor =
|
|
|
|
createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
|
|
|
|
Value matmul =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::VecmatOp>(loc, zeroTensor.getType(),
|
|
|
|
ValueRange{lhs, rhs}, zeroTensor)
|
|
|
|
.getResult(0);
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Third Case: Matrix-Vec Multiplication.
|
|
|
|
if (lhsRank == 2 && rhsRank == 1) {
|
|
|
|
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
|
|
|
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
|
|
|
|
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
|
|
|
|
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
|
|
|
|
|
|
|
|
Value zeroTensor =
|
|
|
|
createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
|
|
|
|
Value matmul =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::MatvecOp>(loc, zeroTensor.getType(),
|
|
|
|
ValueRange{lhs, rhs}, zeroTensor)
|
|
|
|
.getResult(0);
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Fourth Case: Batch-Matrix Multiplication.
|
2022-06-16 23:45:10 +08:00
|
|
|
// TODO: Handle batch matrix multiplication when one of the matrix is unity
|
|
|
|
// rank and the other has batch dimension.
|
|
|
|
if (lhsRank > 1 && rhsRank > 1) {
|
|
|
|
unsigned maxRank = std::max(lhsRank, rhsRank);
|
|
|
|
unsigned minRank = std::min(lhsRank, rhsRank);
|
|
|
|
unsigned batchRank = maxRank - 2;
|
|
|
|
|
|
|
|
// At least one of the matrix must have rank greater than 2.
|
|
|
|
if (batchRank <= 0) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected batch dimensions");
|
|
|
|
}
|
|
|
|
|
|
|
|
// The `broadcastedBatchShape` contains batch dimensions of the resultant
|
|
|
|
// matrix.
|
|
|
|
SmallVector<Value> broadcastedBatchShape(batchRank);
|
|
|
|
Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs;
|
|
|
|
Value maxDim;
|
|
|
|
// Compute broadcasted batch dimensions if the batch dimensions of
|
|
|
|
// the matrices are broadcastable.
|
|
|
|
for (unsigned i = 1; i <= batchRank; i++) {
|
|
|
|
if (i <= minRank - 2) {
|
|
|
|
Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i);
|
|
|
|
Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i);
|
|
|
|
maxDim = rewriter.createOrFold<arith::MaxUIOp>(loc, lhsDim, rhsDim);
|
|
|
|
} else {
|
|
|
|
maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i);
|
|
|
|
}
|
|
|
|
broadcastedBatchShape[batchRank - i] = maxDim;
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-06-16 23:45:10 +08:00
|
|
|
Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2);
|
|
|
|
Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1);
|
|
|
|
Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2);
|
|
|
|
Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1);
|
|
|
|
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
|
|
|
|
|
|
|
|
// Compute broadcasted shape of both the matrices in integer format.
|
|
|
|
SmallVector<Value> lhsBroadcastToShape(broadcastedBatchShape);
|
|
|
|
lhsBroadcastToShape.push_back(lhsDim0);
|
|
|
|
lhsBroadcastToShape.push_back(lhsDim1);
|
|
|
|
SmallVector<Value> rhsBroadcastToShape(broadcastedBatchShape);
|
|
|
|
rhsBroadcastToShape.push_back(rhsDim0);
|
|
|
|
rhsBroadcastToShape.push_back(rhsDim1);
|
|
|
|
for (unsigned i = 0; i < maxRank; i++) {
|
|
|
|
lhsBroadcastToShape[i] =
|
|
|
|
castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]);
|
|
|
|
rhsBroadcastToShape[i] =
|
|
|
|
castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Broadcast the batch dimensions of both the matrices.
|
|
|
|
Value broadcastedLhs, broadcastedRhs;
|
2023-10-06 03:15:26 +08:00
|
|
|
// TODO: Improve usage of static shape information.
|
|
|
|
SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
|
|
|
|
ShapedType::kDynamic);
|
|
|
|
auto lhsBroadcastType =
|
|
|
|
RankedTensorType::get(lhsTargetShape, lhsType.getElementType());
|
2022-06-16 23:45:10 +08:00
|
|
|
if (failed(torch_to_linalg::broadcastToGivenShape(
|
2023-10-06 03:15:26 +08:00
|
|
|
op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType,
|
|
|
|
broadcastedLhs))) {
|
2022-06-16 23:45:10 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unable to perform broadcast operation");
|
|
|
|
}
|
2023-10-06 03:15:26 +08:00
|
|
|
SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
|
|
|
|
ShapedType::kDynamic);
|
|
|
|
auto rhsBroadcastType =
|
|
|
|
RankedTensorType::get(rhsTargetShape, rhsType.getElementType());
|
2022-06-16 23:45:10 +08:00
|
|
|
if (failed(torch_to_linalg::broadcastToGivenShape(
|
2023-10-06 03:15:26 +08:00
|
|
|
op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType,
|
|
|
|
broadcastedRhs))) {
|
2022-06-16 23:45:10 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unable to perform broadcast operation");
|
|
|
|
}
|
|
|
|
|
2022-05-19 00:29:04 +08:00
|
|
|
if (maxRank == 3) {
|
|
|
|
Value zeroTensor = createZeroInitTensor(
|
|
|
|
rewriter, loc,
|
|
|
|
ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
|
|
|
|
elementType);
|
|
|
|
Value matmul =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::BatchMatmulOp>(
|
|
|
|
loc, zeroTensor.getType(),
|
|
|
|
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
|
|
|
|
.getResult(0);
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-06-16 23:45:10 +08:00
|
|
|
// Check if the result of the matrix multiplication has more than one
|
|
|
|
// dynamic batch dimensions.
|
2022-11-29 20:33:31 +08:00
|
|
|
SmallVector<int64_t> batchDimsInt =
|
|
|
|
makeShapeTorchCompatible(resultType.getShape());
|
|
|
|
batchDimsInt.pop_back();
|
|
|
|
batchDimsInt.pop_back();
|
2022-06-16 23:45:10 +08:00
|
|
|
bool multipleDynamicBatchDims =
|
|
|
|
llvm::count(batchDimsInt, kUnknownSize) > 1;
|
|
|
|
|
|
|
|
// TODO: Lowering to `linalg.BatchMatmul` is only possible when there is
|
|
|
|
// at most one dynamic batch dimension due to limited support of the
|
|
|
|
// `tensor.ExpandShape` op.
|
|
|
|
if (!multipleDynamicBatchDims) {
|
|
|
|
// Collapse the batch dimensions into one dimension. The resultant rank
|
|
|
|
// will always be 3.
|
|
|
|
SmallVector<ReassociationIndices> reassociation(3);
|
|
|
|
for (unsigned i = 0, j = 0; i < maxRank; i++) {
|
|
|
|
if (i >= batchRank)
|
|
|
|
j++;
|
|
|
|
reassociation[j].push_back(i);
|
|
|
|
}
|
|
|
|
Value collapsedLhs = rewriter.create<tensor::CollapseShapeOp>(
|
|
|
|
op->getLoc(), broadcastedLhs, reassociation);
|
|
|
|
Value collapsedRhs = rewriter.create<tensor::CollapseShapeOp>(
|
|
|
|
op->getLoc(), broadcastedRhs, reassociation);
|
|
|
|
|
|
|
|
// Compute the result shape after collapsing the batch dimensions.
|
|
|
|
SmallVector<Value> collapsedResultShape;
|
|
|
|
collapsedResultShape.push_back(broadcastedBatchShape[0]);
|
|
|
|
for (unsigned i = 1; i < batchRank; i++) {
|
|
|
|
collapsedResultShape[0] = rewriter.createOrFold<arith::MulIOp>(
|
|
|
|
loc, collapsedResultShape[0], broadcastedBatchShape[i]);
|
|
|
|
}
|
|
|
|
collapsedResultShape.push_back(lhsDim0);
|
|
|
|
collapsedResultShape.push_back(rhsDim1);
|
|
|
|
SmallVector<OpFoldResult> updatedCollapseResultShape =
|
|
|
|
getAsOpFoldResult(collapsedResultShape);
|
|
|
|
|
2022-10-18 12:22:53 +08:00
|
|
|
Value initTensor = rewriter.create<tensor::EmptyOp>(
|
2022-06-16 23:45:10 +08:00
|
|
|
loc, updatedCollapseResultShape, elementType);
|
|
|
|
Value c0 = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getZeroAttr(elementType));
|
|
|
|
Value zeroTensor =
|
|
|
|
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
|
|
|
|
|
|
|
Value batchMatMul =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::BatchMatmulOp>(
|
|
|
|
loc, zeroTensor.getType(),
|
|
|
|
ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
|
|
|
|
.getResult(0);
|
|
|
|
Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
|
|
|
|
loc, resultType, batchMatMul, reassociation);
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
|
|
|
|
expandResult);
|
|
|
|
return success();
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
SmallVector<AffineExpr> lhsExpr;
|
|
|
|
SmallVector<AffineExpr> rhsExpr;
|
|
|
|
SmallVector<AffineExpr> outExpr;
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(
|
|
|
|
batchRank, utils::IteratorType::parallel);
|
2022-03-11 01:54:13 +08:00
|
|
|
for (unsigned i = 0; i < batchRank; i++) {
|
|
|
|
lhsExpr.push_back(rewriter.getAffineDimExpr(i));
|
|
|
|
rhsExpr.push_back(rewriter.getAffineDimExpr(i));
|
|
|
|
outExpr.push_back(rewriter.getAffineDimExpr(i));
|
|
|
|
}
|
|
|
|
lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
|
|
|
|
rewriter.getAffineDimExpr(batchRank + 1)});
|
|
|
|
rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
|
|
|
|
rewriter.getAffineDimExpr(batchRank + 2)});
|
|
|
|
outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
|
|
|
|
rewriter.getAffineDimExpr(batchRank + 2)});
|
|
|
|
|
2022-06-16 23:45:10 +08:00
|
|
|
SmallVector<Value> resultShape(broadcastedBatchShape);
|
|
|
|
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
|
|
|
|
Value zeroTensor =
|
2022-03-11 01:54:13 +08:00
|
|
|
createZeroInitTensor(rewriter, loc, resultShape, elementType);
|
|
|
|
auto indexingMaps =
|
|
|
|
AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
|
|
|
|
iteratorTypes.insert(iteratorTypes.end(),
|
2022-11-17 06:40:36 +08:00
|
|
|
{utils::IteratorType::parallel,
|
|
|
|
utils::IteratorType::reduction,
|
|
|
|
utils::IteratorType::parallel});
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Value finalRes =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
2022-06-16 23:45:10 +08:00
|
|
|
loc, zeroTensor.getType(),
|
|
|
|
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor,
|
2022-03-11 01:54:13 +08:00
|
|
|
/*indexingMaps=*/indexingMaps,
|
|
|
|
/*iteratorTypes=*/iteratorTypes,
|
|
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
Value l = args[0], r = args[1], res = args[2];
|
|
|
|
Value mul = b.create<arith::MulFOp>(loc, l, r);
|
|
|
|
Value add = b.create<arith::AddFOp>(loc, mul, res);
|
|
|
|
b.create<linalg::YieldOp>(loc, add);
|
|
|
|
})
|
|
|
|
.getResult(0);
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenBmmOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
|
|
|
return failure();
|
|
|
|
Location loc = op->getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getMat2();
|
2022-03-11 01:54:13 +08:00
|
|
|
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
|
|
|
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
2023-09-11 20:58:59 +08:00
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
|
|
|
Type resultElementType = newResultType.cast<RankedTensorType>().getElementType();
|
|
|
|
Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType();
|
|
|
|
Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected both operands to aten.bmm to be rank 3");
|
|
|
|
}
|
2023-09-11 20:58:59 +08:00
|
|
|
|
|
|
|
// Convert the inputs element type equivalent to the result' element type.
|
|
|
|
if (lhsElementType != rhsElementType) {
|
|
|
|
if (lhsElementType != resultElementType) {
|
|
|
|
// True if the lhs element type is not equal to the result' element type.
|
|
|
|
lhs = torch_to_linalg::convertTensorToElementType(
|
|
|
|
rewriter, loc, lhs, resultElementType);
|
|
|
|
} else {
|
|
|
|
// True if the rhs element type is not equal to the result' element type.
|
|
|
|
rhs = torch_to_linalg::convertTensorToElementType(
|
|
|
|
rewriter, loc, rhs, resultElementType);
|
|
|
|
}
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
|
|
|
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
|
|
|
|
Value lhsDim2 = getDimOp(rewriter, loc, lhs, 2);
|
|
|
|
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
|
|
|
|
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
|
|
|
|
Value rhsDim2 = getDimOp(rewriter, loc, rhs, 2);
|
|
|
|
|
|
|
|
// Check the batch numbers are equal.
|
|
|
|
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
|
|
|
|
|
|
|
|
// Check the matrixs shapes are valid for mulplication.
|
|
|
|
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);
|
|
|
|
|
|
|
|
Value initTensor0 = createZeroInitTensor(
|
2023-09-11 20:58:59 +08:00
|
|
|
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType);
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Value bmm =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::BatchMatmulOp>(loc, initTensor0.getType(),
|
|
|
|
ValueRange{lhs, rhs}, initTensor0)
|
|
|
|
.getResult(0);
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, bmm);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
2022-04-08 12:47:57 +08:00
|
|
|
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
|
2022-03-11 01:54:13 +08:00
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
2022-04-08 12:47:57 +08:00
|
|
|
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
|
2022-03-11 01:54:13 +08:00
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
MLIRContext *context = op->getContext();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getInput(); /* in form of N*C*H*W */
|
|
|
|
Value weight = adaptor.getWeight(); /* in form of F*C*H*W */
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-08-25 00:19:35 +08:00
|
|
|
bool transposed = true;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
|
2022-08-25 00:19:35 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only constant transposed supported");
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
Type elementType =
|
|
|
|
input.getType().cast<RankedTensorType>().getElementType();
|
|
|
|
if (!elementType.isa<mlir::FloatType>())
|
|
|
|
return op.emitError("unimplemented: non-floating point type");
|
2022-04-08 12:47:57 +08:00
|
|
|
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
|
2022-08-25 00:19:35 +08:00
|
|
|
size_t numSpacialDims = inRank - 2;
|
|
|
|
if (numSpacialDims != 2)
|
2022-04-08 12:47:57 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only 2D convolution currently supported");
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Type intType = IntegerType::get(context, 64);
|
|
|
|
auto castIndexToInt = [&](Value v) {
|
|
|
|
return rewriter.create<arith::IndexCastOp>(loc, intType, v);
|
|
|
|
};
|
|
|
|
|
2022-11-04 15:57:29 +08:00
|
|
|
SmallVector<Value> paddingIntValues;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getPadding(), paddingIntValues))
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2022-11-04 15:57:29 +08:00
|
|
|
op, "only support padding from a list construct");
|
|
|
|
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
|
|
|
|
paddingIntValues);
|
2022-12-08 22:15:31 +08:00
|
|
|
SmallVector<Value> outputPaddingIntValues;
|
|
|
|
if (!getListConstructElements(op.getOutputPadding(),
|
|
|
|
outputPaddingIntValues))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only support output_padding from a list construct");
|
|
|
|
outputPaddingIntValues = getTypeConvertedValues(
|
|
|
|
rewriter, loc, getTypeConverter(), outputPaddingIntValues);
|
2022-04-08 12:47:57 +08:00
|
|
|
SmallVector<int64_t> strideInts;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support constant int strides");
|
2022-04-08 12:47:57 +08:00
|
|
|
SmallVector<int64_t> dilationInts;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support constant int dilations");
|
2022-04-08 12:47:57 +08:00
|
|
|
|
2022-08-04 14:18:38 +08:00
|
|
|
Value inBatch = getDimOp(rewriter, loc, input, 0);
|
|
|
|
Value inChannels = getDimOp(rewriter, loc, input, 1);
|
2022-04-08 12:47:57 +08:00
|
|
|
SmallVector<Value> inDims;
|
|
|
|
for (size_t i = 2; i < inRank; i++)
|
|
|
|
inDims.push_back(getDimOp(rewriter, loc, input, i));
|
2022-08-04 14:18:38 +08:00
|
|
|
Value weightBatch = getDimOp(rewriter, loc, weight, 0);
|
|
|
|
Value weightChannels = getDimOp(rewriter, loc, weight, 1);
|
2022-04-08 12:47:57 +08:00
|
|
|
SmallVector<Value> weightDims;
|
|
|
|
for (size_t i = 2; i < inRank; i++)
|
|
|
|
weightDims.push_back(getDimOp(rewriter, loc, weight, i));
|
|
|
|
|
2022-08-04 14:18:38 +08:00
|
|
|
// Checks for valid group size
|
|
|
|
int64_t groupSize;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groupSize)))
|
2022-08-04 14:18:38 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only constant group size supported.");
|
2022-12-08 04:20:41 +08:00
|
|
|
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());
|
2022-08-04 14:18:38 +08:00
|
|
|
|
|
|
|
auto validate = [&](Value toValidate, std::string err) {
|
|
|
|
Value c0 =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
|
|
|
Value inputValid = rewriter.create<arith::CmpIOp>(
|
|
|
|
loc, arith::CmpIPredicate::eq, c0,
|
|
|
|
rewriter.create<arith::RemSIOp>(loc, toValidate, groups));
|
|
|
|
rewriter.create<cf::AssertOp>(loc, inputValid,
|
|
|
|
rewriter.getStringAttr(err));
|
|
|
|
};
|
|
|
|
validate(inChannels,
|
|
|
|
"invalid: groups must divide input channel size evenly.");
|
|
|
|
validate(weightBatch,
|
|
|
|
"invalid: groups must divide weight batch size evenly.");
|
2022-03-11 01:54:13 +08:00
|
|
|
SmallVector<Value> dilationIntValues =
|
|
|
|
getAsConstantIntValues(rewriter, loc, dilationInts);
|
|
|
|
SmallVector<Value> strideIntValues =
|
|
|
|
getAsConstantIntValues(rewriter, loc, strideInts);
|
|
|
|
|
2022-08-25 00:19:35 +08:00
|
|
|
// Pad the input tensor according to padding.
|
2022-08-04 14:18:38 +08:00
|
|
|
SmallVector<Value> outDims{inBatch, weightBatch};
|
2022-08-25 00:19:35 +08:00
|
|
|
Value paddedInput;
|
|
|
|
if (transposed) {
|
|
|
|
Value c0 =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
|
|
|
Value c1 =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
|
|
|
Value c2 =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(2));
|
|
|
|
|
|
|
|
// Transpose and flip weight
|
|
|
|
SmallVector<Value> weightInitDims = getTensorSizes(rewriter, loc, weight);
|
|
|
|
std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1);
|
|
|
|
outDims[1] = weightInitDims[0];
|
|
|
|
Value weightInitTensor =
|
|
|
|
createZeroInitTensor(rewriter, loc, weightInitDims, elementType);
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(
|
|
|
|
inRank, utils::IteratorType::parallel);
|
2022-10-13 05:01:24 +08:00
|
|
|
SmallVector<AffineMap> indexingMaps{
|
|
|
|
AffineMap::getMultiDimIdentityMap(inRank, context)};
|
2022-08-25 00:19:35 +08:00
|
|
|
weight = rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
2022-10-13 05:01:24 +08:00
|
|
|
loc, weightInitTensor.getType(), ValueRange{},
|
2022-08-25 00:19:35 +08:00
|
|
|
weightInitTensor, indexingMaps, iteratorTypes,
|
|
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
SmallVector<Value> indices;
|
|
|
|
for (size_t i = 0; i < inRank; i++)
|
|
|
|
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
|
|
|
std::iter_swap(indices.begin(), indices.begin() + 1);
|
|
|
|
// Flip only the spatial dimensions (from 2 to inRank)
|
|
|
|
for (size_t flipDim = 2; flipDim < inRank; flipDim++) {
|
|
|
|
indices[flipDim] = b.create<arith::SubIOp>(
|
|
|
|
loc,
|
|
|
|
b.create<arith::SubIOp>(
|
|
|
|
loc, weightInitDims[flipDim], c1),
|
|
|
|
indices[flipDim]);
|
|
|
|
}
|
|
|
|
Value res =
|
|
|
|
b.create<tensor::ExtractOp>(loc, weight, indices)
|
|
|
|
.getResult();
|
|
|
|
b.create<linalg::YieldOp>(loc, res);
|
|
|
|
})
|
|
|
|
.getResult(0);
|
|
|
|
|
|
|
|
// Calculate padded input size, allocate tensor
|
|
|
|
SmallVector<Value> outerSizes{inBatch, inChannels};
|
|
|
|
SmallVector<Value> innerSizes{inBatch, inChannels};
|
|
|
|
SmallVector<Value> offsets{c0, c0};
|
|
|
|
for (size_t i = 0; i < numSpacialDims; i++) {
|
|
|
|
Value innerSize = rewriter.create<arith::SubIOp>(loc, inDims[i], c1);
|
|
|
|
innerSize = rewriter.create<arith::MulIOp>(
|
|
|
|
loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
|
|
|
|
innerSize = rewriter.create<arith::AddIOp>(loc, innerSize, c1);
|
|
|
|
|
|
|
|
Value offset = rewriter.create<arith::SubIOp>(loc, weightDims[i], c1);
|
|
|
|
offset = rewriter.create<arith::MulIOp>(
|
|
|
|
loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
|
|
|
|
offset = rewriter.create<arith::SubIOp>(
|
|
|
|
loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));
|
|
|
|
|
|
|
|
Value outerSize = rewriter.create<arith::MulIOp>(loc, offset, c2);
|
|
|
|
outerSize = rewriter.create<arith::AddIOp>(loc, outerSize, innerSize);
|
2022-12-08 22:15:31 +08:00
|
|
|
outerSize = rewriter.create<arith::AddIOp>(
|
|
|
|
loc, outerSize,
|
|
|
|
castIntToIndex(rewriter, loc, outputPaddingIntValues[i]));
|
2022-08-25 00:19:35 +08:00
|
|
|
|
|
|
|
outerSizes.push_back(outerSize);
|
|
|
|
offsets.push_back(offset);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Allocate padded input tensor
|
|
|
|
Value initTensor =
|
|
|
|
createZeroInitTensor(rewriter, loc, outerSizes, elementType);
|
|
|
|
|
|
|
|
// Insert input into allocated tensor
|
|
|
|
SmallVector<Value> strideIndexValues{c1, c1};
|
|
|
|
for (auto stride : strideIntValues)
|
|
|
|
strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));
|
|
|
|
SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);
|
|
|
|
|
|
|
|
paddedInput = rewriter.create<tensor::InsertSliceOp>(
|
|
|
|
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
|
|
|
|
initTensor, offsets, insertSizes, strideIndexValues);
|
|
|
|
|
|
|
|
// Calculate output dims
|
|
|
|
for (size_t i = 0; i < numSpacialDims; i++)
|
|
|
|
outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps(
|
|
|
|
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
|
2022-12-08 22:15:31 +08:00
|
|
|
castIndexToInt(weightDims[i]), strideIntValues[i],
|
|
|
|
outputPaddingIntValues[i]));
|
2022-08-25 00:19:35 +08:00
|
|
|
|
|
|
|
// Set stride to 1
|
|
|
|
strideInts.clear();
|
|
|
|
strideInts.append(numSpacialDims, 1);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
// Pad input
|
2022-11-04 15:57:29 +08:00
|
|
|
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
|
|
|
|
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2);
|
2022-08-25 00:19:35 +08:00
|
|
|
|
|
|
|
// Calculate output dims
|
|
|
|
for (size_t i = 0; i < numSpacialDims; i++)
|
|
|
|
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
|
|
|
|
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
|
|
|
|
castIndexToInt(weightDims[i]), strideIntValues[i]));
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-10-18 12:22:53 +08:00
|
|
|
Value initTensor = rewriter.create<tensor::EmptyOp>(
|
|
|
|
loc, getAsOpFoldResult(outDims), elementType);
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value bias = adaptor.getBias();
|
2022-08-04 14:18:38 +08:00
|
|
|
Value outputTensor;
|
2022-03-11 01:54:13 +08:00
|
|
|
if (bias.getType().isa<Torch::NoneType>()) {
|
|
|
|
Value c0float = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, FloatAttr::get(elementType, 0.0));
|
2022-08-04 14:18:38 +08:00
|
|
|
outputTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
|
|
|
|
.getResult(0);
|
2022-03-11 01:54:13 +08:00
|
|
|
} else {
|
|
|
|
auto biasType = bias.getType().cast<RankedTensorType>();
|
|
|
|
if (biasType.getRank() != 1)
|
|
|
|
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
|
|
|
|
if (elementType != biasType.getElementType())
|
|
|
|
return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
|
|
|
|
|
|
|
|
auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
|
|
|
|
SmallVector<AffineMap> indexingMaps = {
|
|
|
|
// bias is used to initialize the channels - dimension 1 of output
|
|
|
|
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
|
|
|
|
rewriter.getAffineDimExpr(1), context),
|
|
|
|
rewriter.getMultiDimIdentityMap(resultRank)};
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(
|
|
|
|
resultRank, utils::IteratorType::parallel);
|
2022-08-04 14:18:38 +08:00
|
|
|
outputTensor = rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
|
|
|
loc, initTensor.getType(), bias, initTensor,
|
|
|
|
indexingMaps, iteratorTypes,
|
|
|
|
[](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
b.create<linalg::YieldOp>(loc, args[0]);
|
|
|
|
})
|
|
|
|
.getResult(0);
|
2022-03-11 01:54:13 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
|
|
|
|
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
|
2022-04-08 12:47:57 +08:00
|
|
|
|
2022-08-04 14:18:38 +08:00
|
|
|
Value inputStride =
|
|
|
|
rewriter.create<arith::FloorDivSIOp>(loc, inChannels, groups);
|
|
|
|
Value weightStride =
|
|
|
|
rewriter.create<arith::FloorDivSIOp>(loc, weightBatch, groups);
|
|
|
|
|
|
|
|
SmallVector<Value> zeroOffsets(inRank, rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIndexAttr(0)));
|
|
|
|
SmallVector<Value> unitStrides(inRank, rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIndexAttr(1)));
|
|
|
|
SmallVector<Value> outDimSlice(outDims);
|
|
|
|
outDimSlice[1] = weightStride;
|
|
|
|
SmallVector<Value> inputSliceSizes{inBatch, inputStride};
|
|
|
|
inputSliceSizes.append(inDims);
|
|
|
|
SmallVector<Value> weightSliceSizes{weightStride, weightChannels};
|
|
|
|
weightSliceSizes.append(weightDims);
|
|
|
|
|
|
|
|
Value conv;
|
|
|
|
if (groupSize == 1) {
|
|
|
|
// TODO: add 1D and 3D case
|
|
|
|
conv =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::Conv2DNchwFchwOp>(
|
|
|
|
loc, outputTensor.getType(), ValueRange{paddedInput, weight},
|
|
|
|
outputTensor, stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
|
|
|
} else {
|
|
|
|
// Special depthwise case
|
2022-11-29 20:33:31 +08:00
|
|
|
auto inShape = makeShapeTorchCompatible(
|
|
|
|
input.getType().cast<RankedTensorType>().getShape());
|
|
|
|
auto weightShape = makeShapeTorchCompatible(
|
|
|
|
weight.getType().cast<RankedTensorType>().getShape());
|
2022-08-04 14:18:38 +08:00
|
|
|
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
|
|
|
|
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
|
|
|
|
// Collapse weight shape
|
|
|
|
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
|
|
|
|
SmallVector<int64_t> collapsedShape{
|
|
|
|
(weightShape[0] == kUnknownSize ? kUnknownSize
|
|
|
|
: weightShape[0] * weightShape[1]),
|
|
|
|
weightShape[2], weightShape[3]};
|
2022-11-29 20:33:31 +08:00
|
|
|
Type collapsedType = RankedTensorType::get(
|
|
|
|
makeShapeLLVMCompatible(collapsedShape), elementType);
|
2022-08-04 14:18:38 +08:00
|
|
|
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
|
|
|
|
loc, collapsedType, weight, collapsedDims);
|
|
|
|
|
|
|
|
conv = rewriter
|
|
|
|
.create<linalg::DepthwiseConv2DNchwChwOp>(
|
|
|
|
loc, outputTensor.getType(),
|
|
|
|
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
|
|
|
stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
|
|
|
|
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Grouped case, use the grouped conv linalg op
|
|
|
|
auto expandGroups = [&](Value tensor, size_t dim) {
|
|
|
|
auto inType = tensor.getType().cast<RankedTensorType>();
|
2022-11-29 20:33:31 +08:00
|
|
|
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
2022-08-04 14:18:38 +08:00
|
|
|
|
|
|
|
SmallVector<int64_t> outShape;
|
|
|
|
for (auto i = 0; i < (long)inShape.size(); i++) {
|
|
|
|
if (i == 1) {
|
|
|
|
outShape.push_back(groupSize);
|
|
|
|
}
|
|
|
|
if (i == (long)dim) {
|
|
|
|
outShape.push_back(inShape[i] == kUnknownSize
|
|
|
|
? kUnknownSize
|
|
|
|
: inShape[i] / groupSize);
|
|
|
|
} else {
|
|
|
|
outShape.push_back(inShape[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<ReassociationIndices> indices;
|
|
|
|
for (auto i = 0; i <= (long)inShape.size(); i++) {
|
|
|
|
if (i == (long)dim) {
|
|
|
|
indices.push_back({i, ++i});
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
indices.push_back({i});
|
|
|
|
}
|
|
|
|
|
2022-11-29 20:33:31 +08:00
|
|
|
auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
|
2022-08-04 14:18:38 +08:00
|
|
|
return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
|
|
|
|
indices);
|
|
|
|
};
|
|
|
|
|
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where
1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W
Now this has been fixed in
https://github.com/llvm/llvm-project/pull/73855 which broke the
torch-mlir lowering to that Op.
This patch switches lowering in torch-mlir to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that
is compatible with PyTorch's memory layout.
Fix https://github.com/llvm/torch-mlir/issues/2622
2023-12-08 21:18:23 +08:00
|
|
|
// expand F,C,H,W -> G,F/G,C,H,W
|
2022-08-04 14:18:38 +08:00
|
|
|
auto expandWeight = [&](Value tensor) {
|
|
|
|
auto inType = tensor.getType().cast<RankedTensorType>();
|
2022-11-29 20:33:31 +08:00
|
|
|
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
2022-08-04 14:18:38 +08:00
|
|
|
|
|
|
|
SmallVector<int64_t> outShape{
|
|
|
|
groupSize, (inShape[0] == kUnknownSize ? kUnknownSize
|
|
|
|
: inShape[0] / groupSize)};
|
|
|
|
outShape.append(inShape.begin() + 1, inShape.end());
|
|
|
|
|
|
|
|
SmallVector<ReassociationIndices> indices{{0, 1}};
|
|
|
|
for (auto i = 2; i <= (long)inShape.size(); i++)
|
|
|
|
indices.push_back({i});
|
|
|
|
|
2022-11-29 20:33:31 +08:00
|
|
|
auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
|
2022-08-04 14:18:38 +08:00
|
|
|
return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
|
|
|
|
indices);
|
|
|
|
};
|
|
|
|
|
|
|
|
Value paddedInputExpanded = expandGroups(paddedInput, 1);
|
|
|
|
Value weightExpanded = expandWeight(weight);
|
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where
1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W
Now this has been fixed in
https://github.com/llvm/llvm-project/pull/73855 which broke the
torch-mlir lowering to that Op.
This patch switches lowering in torch-mlir to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that
is compatible with PyTorch's memory layout.
Fix https://github.com/llvm/torch-mlir/issues/2622
2023-12-08 21:18:23 +08:00
|
|
|
auto expandOutputTensor = expandGroups(outputTensor, 1);
|
2022-08-04 14:18:38 +08:00
|
|
|
|
|
|
|
// TODO: add 1D and 3D case
|
|
|
|
conv = rewriter
|
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where
1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W
Now this has been fixed in
https://github.com/llvm/llvm-project/pull/73855 which broke the
torch-mlir lowering to that Op.
This patch switches lowering in torch-mlir to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that
is compatible with PyTorch's memory layout.
Fix https://github.com/llvm/torch-mlir/issues/2622
2023-12-08 21:18:23 +08:00
|
|
|
.create<linalg::Conv2DNgchwGfchwOp>(
|
|
|
|
loc, expandOutputTensor.getResultType(),
|
2022-08-04 14:18:38 +08:00
|
|
|
ValueRange{paddedInputExpanded, weightExpanded},
|
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where
1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W
Now this has been fixed in
https://github.com/llvm/llvm-project/pull/73855 which broke the
torch-mlir lowering to that Op.
This patch switches lowering in torch-mlir to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that
is compatible with PyTorch's memory layout.
Fix https://github.com/llvm/torch-mlir/issues/2622
2023-12-08 21:18:23 +08:00
|
|
|
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
|
2022-08-04 14:18:38 +08:00
|
|
|
.getResult(0);
|
|
|
|
|
|
|
|
conv = rewriter.create<tensor::CollapseShapeOp>(
|
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where
1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W
Now this has been fixed in
https://github.com/llvm/llvm-project/pull/73855 which broke the
torch-mlir lowering to that Op.
This patch switches lowering in torch-mlir to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that
is compatible with PyTorch's memory layout.
Fix https://github.com/llvm/torch-mlir/issues/2622
2023-12-08 21:18:23 +08:00
|
|
|
loc, outputTensor.getType(), conv,
|
|
|
|
expandOutputTensor.getReassociationIndices());
|
2022-08-04 14:18:38 +08:00
|
|
|
}
|
2022-04-08 12:47:57 +08:00
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
2022-04-08 12:47:57 +08:00
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
|
2022-03-11 01:54:13 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  // Each Torch op listed here is made illegal so the dialect-conversion
  // driver is forced to rewrite it; the matching conversion pattern for each
  // op is registered just below, in the same order as before.
  MLIRContext *ctx = patterns.getContext();

  target.addIllegalOp<AtenMmOp, AtenFlipOp, AtenMatmulOp, AtenBmmOp,
                      AtenConvolutionOp>();

  patterns.add<ConvertAtenMmOp, ConvertAtenFlipOp, ConvertAtenMatmulOp,
               ConvertAtenBmmOp, ConvertAtenConvolutionOp>(typeConverter, ctx);
}
|