2022-03-11 01:54:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
|
|
|
|
|
|
|
#include "PopulatePatterns.h"
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
|
|
#include "mlir/IR/Matchers.h"
|
2023-12-02 08:38:21 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
2022-08-25 00:19:35 +08:00
|
|
|
#include <algorithm>
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
|
|
|
namespace {
|
2024-01-25 06:02:50 +08:00
|
|
|
|
|
|
|
// If `value` was produced by an Aten_MakePerTensorQuantizedTensorOp, extract
// its per-tensor zero point into `zeropoint`; otherwise `zeropoint` is left
// untouched (callers test it for null to detect quantized operands).
static void getZeroPoint(Value value, Value &zeropoint) {
  auto makeOp = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
  if (!makeOp)
    return;
  zeropoint = makeOp.getZeroPoint();
}
|
|
|
|
|
2024-04-16 07:06:47 +08:00
|
|
|
// for uint8 types, we shift down by 128 so that we can faithfully
|
|
|
|
// represent the quantization with signed i8 types.
|
|
|
|
static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
|
|
|
|
Value &zp, bool isUnsignedType, int64_t numBits) {
|
|
|
|
if (!isUnsignedType)
|
|
|
|
return;
|
|
|
|
int64_t minSI = -(1 << (numBits - 1));
|
2024-05-01 00:23:09 +08:00
|
|
|
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
|
2024-05-31 14:45:13 +08:00
|
|
|
loc, minSI, cast<mlir::IntegerType>(zp.getType()).getWidth());
|
2024-04-16 07:06:47 +08:00
|
|
|
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
|
|
|
|
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
|
|
|
|
arg = torch_to_linalg::createElementwiseLinalgGeneric(
|
|
|
|
rewriter, loc, ValueRange{arg},
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<TensorType>(arg.getType()).getElementType(),
|
2024-04-16 07:06:47 +08:00
|
|
|
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
|
|
|
Value result =
|
|
|
|
rewriter.create<arith::AddIOp>(loc, payloadArgs[0], minSIValue);
|
|
|
|
b.create<linalg::YieldOp>(loc, result);
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
// Materialize a dimension permutation of `value` as a linalg.transpose.
// `perms[i]` names the input dimension that becomes output dimension i.
// Dynamic extents are queried with tensor.dim so the destination tensor can
// be created with matching runtime sizes.
static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
                            PatternRewriter &rewriter) {
  auto valueTy = cast<RankedTensorType>(value.getType());
  auto srcShape = valueTy.getShape();

  // Permuted static shape, plus the dynamic extents in output order as
  // required by tensor.empty.
  llvm::SmallVector<int64_t> permutedShape;
  llvm::SmallVector<Value> dynamicSizes;
  for (int64_t srcDim : perms) {
    int64_t extent = srcShape[srcDim];
    permutedShape.push_back(extent);
    if (ShapedType::isDynamic(extent)) {
      dynamicSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, value, srcDim));
    }
  }

  auto resultTy =
      RankedTensorType::get(permutedShape, valueTy.getElementType());
  Value dest = rewriter.create<tensor::EmptyOp>(loc, resultTy, dynamicSizes);
  return rewriter.create<linalg::TransposeOp>(loc, value, dest, perms)
      ->getResult(0);
}
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
// Lowers `aten.mm` (2-D matrix multiply) to `linalg.matmul`, or to
// `linalg.quantized_matmul` when both operands carry per-tensor zero points.
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMmOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value lhs = adaptor.getSelf();
    Value rhs = adaptor.getMat2();

    // A user can write an erroneous program where `aten.mm` is in fact called
    // with operands of invalid rank or dtype. We cannot convert to linalg in
    // this case or we will get a verifier error, which corresponds to breaking
    // of *internal* compiler invariants, and for a user manifests as a compiler
    // crash in the worst case (such as we try to canonicalize/fold/print the
    // invalid op before the verifier gets to see it -- also release builds of a
    // mature compiler usually have the verifier turned off for compile time
    // reasons).
    //
    // The compiler cannot crash even if the user wrote an erroneous program!
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    RankedTensorType lhsType = cast<RankedTensorType>(lhs.getType());
    RankedTensorType rhsType = cast<RankedTensorType>(rhs.getType());

    if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
      return rewriter.notifyMatchFailure(
          op, "expected both operands to aten.mm to be rank 2");
    }

    // Torch-level (pre-conversion) types are needed to inspect quantization
    // attributes and signedness, which the converted builtin types erase.
    ValueTensorType lhsTorchType =
        cast<ValueTensorType>(op.getSelf().getType());
    ValueTensorType rhsTorchType =
        cast<ValueTensorType>(op.getMat2().getType());

    // Null zero points mean the corresponding operand is not quantized.
    Value lhsZeroPoint, rhsZeroPoint;
    getZeroPoint(op.getSelf(), lhsZeroPoint);
    getZeroPoint(op.getMat2(), rhsZeroPoint);

    // Either both operands are quantized or neither is.
    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
      return rewriter.notifyMatchFailure(
          op, "unsupported: aten.mm with mixed quantization");
    }

    if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
      if (!lhsZeroPoint) {
        return rewriter.notifyMatchFailure(
            op, "unsupported: aten.mm with different input element types");
      }
      // Allows quantized types to mismatch since they will be cast to the same
      // type.
    }

    bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
    bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);

    Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
    Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);

    // Emit a runtime check that the contracting dimensions agree, unless the
    // pipeline asserts strict symbolic shapes (then shapes are trusted).
    if (!isAssumingStrictSymbolicShapes(rewriter)) {
      Value lhsDim1 = rewriter.create<tensor::DimOp>(loc, lhs, 1);
      Value rhsDim0 = rewriter.create<tensor::DimOp>(loc, rhs, 0);
      Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
      rewriter.create<cf::AssertOp>(
          loc, contractingDimEqual,
          rewriter.getStringAttr(
              "mismatching contracting dimension for torch.aten.mm"));
    }

    TensorType resultType =
        cast<TensorType>(getTypeConverter()->convertType(op.getType()));
    Type elementType = resultType.getElementType();
    // Accumulate in a (possibly wider) accumulator type and cast down to the
    // result element type afterwards.
    auto accumulatorDType =
        getDefaultAccType(rewriter, lhsType.getElementType());
    if (accumulatorDType != resultType.getElementType()) {
      elementType = accumulatorDType;
    }
    Value zeroFill = createZeroInitTensor(
        rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);

    Value matmul;
    if (lhsZeroPoint) {
      // Quantized path: materialize zero points into builtin types and
      // truncate them to i32, as expected by linalg.quantized_matmul.
      lhsZeroPoint = typeConverter->materializeTargetConversion(
          rewriter, loc,
          getTypeConverter()->convertType(lhsZeroPoint.getType()),
          lhsZeroPoint);
      rhsZeroPoint = typeConverter->materializeTargetConversion(
          rewriter, loc,
          getTypeConverter()->convertType(rhsZeroPoint.getType()),
          rhsZeroPoint);
      lhsZeroPoint = rewriter.create<arith::TruncIOp>(
          loc, rewriter.getI32Type(), lhsZeroPoint);
      rhsZeroPoint = rewriter.create<arith::TruncIOp>(
          loc, rewriter.getI32Type(), rhsZeroPoint);

      // change uint8 quantization -> int8 quantization
      int64_t numBits =
          cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
      signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
      numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
      signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);

      matmul =
          rewriter
              .create<linalg::QuantizedMatmulOp>(
                  loc, zeroFill.getType(),
                  ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill)
              .getResult(0);
    } else if (isUnsigned) {
      // Unsigned non-quantized inputs: widen with an unsigned cast so the
      // accumulation treats the operands as unsigned.
      auto matmulOp = rewriter.create<linalg::MatmulOp>(
          loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill);
      matmulOp.setCast(linalg::TypeFn::cast_unsigned);
      matmul = matmulOp->getResult(0);
    } else {
      matmul = rewriter
                   .create<linalg::MatmulOp>(loc, zeroFill.getType(),
                                             ValueRange{lhs, rhs}, zeroFill)
                   .getResult(0);
    }

    // Cast the accumulator back down to the advertised result element type.
    if (accumulatorDType != resultType.getElementType()) {
      matmul = torch_to_linalg::convertTensorToElementType(
          rewriter, loc, matmul, resultType.getElementType());
    }
    // When constructed with just dynamic sizes, EmptyOp will have a result
    // type which has all `?`'s for dimensions, which might not be the result
    // type of `op`. The constraints on later linalg ops means that the result
    // of the MatmulOp will have this type too. So cast it to the desired type
    // so that in the end we have the original result type.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, matmul);

    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-05-03 04:01:15 +08:00
|
|
|
namespace {
|
|
|
|
// Lowers `aten.flip` to a linalg.generic that gathers each output element
// from the mirrored index (size - 1 - i) along every flipped dimension.
class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenFlipOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Location loc = op->getLoc();
    MLIRContext *context = op.getContext();
    Value self = adaptor.getSelf();
    auto selfRank =
        cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
    Type elementType =
        cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
    Value c1 =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

    // The flip dimensions must be compile-time constants.
    SmallVector<int64_t> axis;
    if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
      return rewriter.notifyMatchFailure(op,
                                         "only constant dim lists supported");
    // Normalize negative dims and reject out-of-range ones.
    for (unsigned i = 0, e = axis.size(); i < e; i++) {
      axis[i] = toPositiveDim(axis[i], selfRank);
      if (!isValidDim(axis[i], selfRank)) {
        return rewriter.notifyMatchFailure(op, "axis is statically invalid");
      }
    }

    // Only used to calculate flipped values, i.e. those on the flip axes. Other
    // dims won't be used. Each flipped entry holds `size - 1`, the mirror
    // pivot for that dimension.
    SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
    for (auto flipDim : axis)
      dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);

    Value initTensor = createZeroInitTensor(
        rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);

    // All loops are parallel; both input and output use the identity map
    // because the mirroring is done via explicit tensor.extract indexing.
    SmallVector<utils::IteratorType> iteratorTypes(
        selfRank, utils::IteratorType::parallel);
    SmallVector<AffineMap> indexingMaps(
        2, AffineMap::getMultiDimIdentityMap(selfRank, context));
    Value flipped =
        rewriter
            .create<linalg::GenericOp>(
                loc, self.getType(), self, initTensor, indexingMaps,
                iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  // Start from the current iteration indices, then mirror the
                  // index along every flipped dimension: (size-1) - i.
                  SmallVector<Value> indices;
                  for (auto i = 0; i < selfRank; i++)
                    indices.push_back(b.create<linalg::IndexOp>(loc, i));
                  for (auto flipDim : axis) {
                    indices[flipDim] = b.create<arith::SubIOp>(
                        loc, dims[flipDim], indices[flipDim]);
                  }
                  Value res = b.create<tensor::ExtractOp>(loc, self, indices)
                                  .getResult();
                  b.create<linalg::YieldOp>(loc, res);
                })
            .getResult(0);

    // Flip preserves shape, so the result type equals the input type.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);

    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
namespace {
|
|
|
|
// Lowers `aten.matmul`, dispatching on the operand ranks to the matching
// linalg op: dot / vecmat / matvec / matmul / batch_matmul, with quantized
// variants when both operands carry per-tensor zero points, and a generic
// broadcasting fallback for batched inputs with multiple dynamic batch dims.
class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMatmulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value lhs = adaptor.getSelf();
    Value rhs = adaptor.getOther();

    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
      return failure();
    }
    auto lhsType = cast<RankedTensorType>(lhs.getType());
    auto rhsType = cast<RankedTensorType>(rhs.getType());

    // Torch-level types retain quantization/signedness info that the
    // converted builtin tensor types erase.
    auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
    auto rhsTorchType = cast<ValueTensorType>(op.getOther().getType());

    // Get the rank of both matrix.
    unsigned lhsRank = lhsType.getRank();
    unsigned rhsRank = rhsType.getRank();

    // Null zero points mean the corresponding operand is not quantized.
    Value lhsZeroPoint, rhsZeroPoint;
    getZeroPoint(op.getSelf(), lhsZeroPoint);
    getZeroPoint(op.getOther(), rhsZeroPoint);

    // Either both operands are quantized or neither is.
    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
      return rewriter.notifyMatchFailure(
          op, "unsupported: aten.matmul with mixed quantization");
    }

    bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
    bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);

    if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
      // Allows quantized types to mismatch
      return rewriter.notifyMatchFailure(
          op, "unsupported: aten.matmul with different input element types");
    }

    Type newResultType = getTypeConverter()->convertType(op.getType());
    auto resultType = cast<RankedTensorType>(newResultType);
    Type elementType = resultType.getElementType();

    if (lhsZeroPoint) {
      // get each zero point ready to pass to a quantized_matmul
      lhsZeroPoint = typeConverter->materializeTargetConversion(
          rewriter, loc,
          getTypeConverter()->convertType(lhsZeroPoint.getType()),
          lhsZeroPoint);
      rhsZeroPoint = typeConverter->materializeTargetConversion(
          rewriter, loc,
          getTypeConverter()->convertType(rhsZeroPoint.getType()),
          rhsZeroPoint);
      lhsZeroPoint = rewriter.create<arith::TruncIOp>(
          loc, rewriter.getI32Type(), lhsZeroPoint);
      rhsZeroPoint = rewriter.create<arith::TruncIOp>(
          loc, rewriter.getI32Type(), rhsZeroPoint);

      // change uint8 quantization -> int8 quantization
      int64_t numBits =
          cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
      signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
      numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
      signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);

      // for quantized vec-vec, vec-mat, and mat-vec cases, lower to
      // expand/collapse + quantized_matmul
      bool lhsVec = (lhsRank == 1 && rhsRank <= 2);
      bool rhsVec = (lhsRank <= 2 && rhsRank == 1);

      if (lhsVec || rhsVec) {
        // Single reassociation group {0, 1}: used both to unsqueeze a vector
        // into a 1xN / Nx1 matrix and to squeeze the result back.
        SmallVector<ReassociationIndices> reassociation(1);
        reassociation[0].push_back(0);
        reassociation[0].push_back(1);

        if (lhsVec) {
          // unsqueeze lhs to a matrix
          int64_t lhsDim = lhsType.getShape()[0];
          auto lhsUnsqueezeType = RankedTensorType::get(
              ArrayRef<int64_t>{1, lhsDim}, lhsType.getElementType());
          lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType,
                                                       lhs, reassociation);
        }
        if (rhsVec) {
          // unsqueeze rhs to a matrix
          int64_t rhsDim = rhsType.getShape()[0];
          auto rhsUnsqueezeType = RankedTensorType::get(
              ArrayRef<int64_t>{rhsDim, 1}, rhsType.getElementType());
          rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType,
                                                       rhs, reassociation);
        }
        // get quantized_matmul and squeeze result
        Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
        Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
        Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
        Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
        checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

        Value zeroTensor = createZeroInitTensor(
            rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
        Value matmul = rewriter
                           .create<linalg::QuantizedMatmulOp>(
                               loc, zeroTensor.getType(),
                               ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
                               zeroTensor)
                           .getResult(0);
        int64_t resultRank = resultType.getRank();
        if (resultRank == 0) {
          // in vec-vec case, need to collapse result to a scalar
          reassociation.clear();
        }
        matmul = rewriter.create<tensor::CollapseShapeOp>(
            loc, resultType, matmul, reassociation);
        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
        return success();
      }
      // the remaining quantized cases (Mat-Mat and broadcast -> BMM) are
      // covered in the relevant section below
    }

    // The different cases of torch_matmul op is mentioned here:
    // https://pytorch.org/docs/stable/generated/torch.matmul.html

    // First Case: Dot Product.
    if (lhsRank == 1 && rhsRank == 1) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);

      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);

      Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
      Value dotProd =
          rewriter
              .create<linalg::DotOp>(loc, zeroTensor.getType(),
                                     ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
      return success();
    }

    // Second Case: Vec-Mat Multiplication.
    if (lhsRank == 1 && rhsRank == 2) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
      Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);

      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
      Value matmul =
          rewriter
              .create<linalg::VecmatOp>(loc, zeroTensor.getType(),
                                        ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }

    // Third Case: Matrix-Vec Multiplication.
    if (lhsRank == 2 && rhsRank == 1) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
      Value matmul =
          rewriter
              .create<linalg::MatvecOp>(loc, zeroTensor.getType(),
                                        ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }

    // Fourth Case: Mat-Mat Multiplication.
    if (lhsRank == 2 && rhsRank == 2) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
      Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

      Value zeroTensor = createZeroInitTensor(
          rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
      Value matmul;
      if (lhsZeroPoint) {
        // Quantized Mat-Mat: zero points were materialized/truncated above.
        matmul = rewriter
                     .create<linalg::QuantizedMatmulOp>(
                         loc, zeroTensor.getType(),
                         ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
                         zeroTensor)
                     .getResult(0);
      } else {
        matmul = rewriter
                     .create<linalg::MatmulOp>(loc, zeroTensor.getType(),
                                               ValueRange{lhs, rhs}, zeroTensor)
                     .getResult(0);
      }
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }

    // Fifth Case: Batch-Matrix Multiplication.
    // TODO: Handle batch matrix multiplication when one of the matrix is unity
    // rank and the other has batch dimension.
    if (lhsRank > 1 && rhsRank > 1) {
      unsigned maxRank = std::max(lhsRank, rhsRank);
      unsigned minRank = std::min(lhsRank, rhsRank);
      unsigned batchRank = maxRank - 2;

      // At least one of the matrix must have rank greater than 2.
      // NOTE(review): batchRank is unsigned, so `<= 0` is effectively `== 0`.
      if (batchRank <= 0) {
        return rewriter.notifyMatchFailure(op, "expected batch dimensions");
      }

      // The `broadcastedBatchShape` contains batch dimensions of the resultant
      // matrix.
      SmallVector<Value> broadcastedBatchShape(batchRank);
      Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs;
      Value maxDim;
      // Compute broadcasted batch dimensions if the batch dimensions of
      // the matrices are broadcastable. Iterates from the innermost batch
      // dim outwards; dims present in only one operand come from it directly.
      for (unsigned i = 1; i <= batchRank; i++) {
        if (i <= minRank - 2) {
          Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i);
          Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i);
          maxDim = rewriter.createOrFold<arith::MaxUIOp>(loc, lhsDim, rhsDim);
        } else {
          maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i);
        }
        broadcastedBatchShape[batchRank - i] = maxDim;
      }

      Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2);
      Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2);
      Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1);
      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

      // Compute broadcasted shape of both the matrices in integer format.
      SmallVector<Value> lhsBroadcastToShape(broadcastedBatchShape);
      lhsBroadcastToShape.push_back(lhsDim0);
      lhsBroadcastToShape.push_back(lhsDim1);
      SmallVector<Value> rhsBroadcastToShape(broadcastedBatchShape);
      rhsBroadcastToShape.push_back(rhsDim0);
      rhsBroadcastToShape.push_back(rhsDim1);
      for (unsigned i = 0; i < maxRank; i++) {
        lhsBroadcastToShape[i] =
            castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]);
        rhsBroadcastToShape[i] =
            castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]);
      }

      // Broadcast the batch dimensions of both the matrices.
      Value broadcastedLhs, broadcastedRhs;
      // TODO: Improve usage of static shape information.
      SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
                                          ShapedType::kDynamic);
      auto lhsBroadcastType = RankedTensorType::get(
          lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
      if (failed(torch_to_linalg::broadcastToGivenShape(
              op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType,
              broadcastedLhs))) {
        return rewriter.notifyMatchFailure(
            op, "unable to perform broadcast operation");
      }
      SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
                                          ShapedType::kDynamic);
      auto rhsBroadcastType = RankedTensorType::get(
          rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
      if (failed(torch_to_linalg::broadcastToGivenShape(
              op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType,
              broadcastedRhs))) {
        return rewriter.notifyMatchFailure(
            op, "unable to perform broadcast operation");
      }

      // Rank-3 inputs map directly onto (quantized_)batch_matmul.
      if (maxRank == 3) {
        Value zeroTensor = createZeroInitTensor(
            rewriter, loc,
            ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
            elementType);
        Value matmul;
        if (lhsZeroPoint) {
          matmul = rewriter
                       .create<linalg::QuantizedBatchMatmulOp>(
                           loc, zeroTensor.getType(),
                           ValueRange{broadcastedLhs, broadcastedRhs,
                                      lhsZeroPoint, rhsZeroPoint},
                           zeroTensor)
                       .getResult(0);
          rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
                                                      matmul);
          return success();
        }
        matmul = rewriter
                     .create<linalg::BatchMatmulOp>(
                         loc, zeroTensor.getType(),
                         ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
                     .getResult(0);
        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
        return success();
      }

      // Check if the result of the matrix multiplication has more than one
      // dynamic batch dimensions.
      SmallVector<int64_t> batchDimsInt =
          makeShapeTorchCompatible(resultType.getShape());
      batchDimsInt.pop_back();
      batchDimsInt.pop_back();
      bool multipleDynamicBatchDims =
          llvm::count(batchDimsInt, kUnknownSize) > 1;

      // TODO: Lowering to `linalg.BatchMatmul` is only possible when there is
      // at most one dynamic batch dimension due to limited support of the
      // `tensor.ExpandShape` op.
      if (!multipleDynamicBatchDims) {
        // Collapse the batch dimensions into one dimension. The resultant rank
        // will always be 3.
        SmallVector<ReassociationIndices> reassociation(3);
        for (unsigned i = 0, j = 0; i < maxRank; i++) {
          if (i >= batchRank)
            j++;
          reassociation[j].push_back(i);
        }
        Value collapsedLhs = rewriter.create<tensor::CollapseShapeOp>(
            op->getLoc(), broadcastedLhs, reassociation);
        Value collapsedRhs = rewriter.create<tensor::CollapseShapeOp>(
            op->getLoc(), broadcastedRhs, reassociation);

        // Compute the result shape after collapsing the batch dimensions:
        // the product of all batch extents, then the two matrix extents.
        SmallVector<Value> collapsedResultShape;
        collapsedResultShape.push_back(broadcastedBatchShape[0]);
        for (unsigned i = 1; i < batchRank; i++) {
          collapsedResultShape[0] = rewriter.createOrFold<arith::MulIOp>(
              loc, collapsedResultShape[0], broadcastedBatchShape[i]);
        }
        collapsedResultShape.push_back(lhsDim0);
        collapsedResultShape.push_back(rhsDim1);
        SmallVector<OpFoldResult> updatedCollapseResultShape =
            getAsOpFoldResult(collapsedResultShape);

        Value initTensor = rewriter.create<tensor::EmptyOp>(
            loc, updatedCollapseResultShape, elementType);
        Value c0 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getZeroAttr(elementType));
        Value zeroTensor =
            rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
        Value batchMatMul;

        if (lhsZeroPoint) {
          batchMatMul = rewriter
                            .create<linalg::QuantizedBatchMatmulOp>(
                                loc, zeroTensor.getType(),
                                ValueRange{collapsedLhs, collapsedRhs,
                                           lhsZeroPoint, rhsZeroPoint},
                                zeroTensor)
                            .getResult(0);
        } else {
          batchMatMul =
              rewriter
                  .create<linalg::BatchMatmulOp>(
                      loc, zeroTensor.getType(),
                      ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
                  .getResult(0);
        }
        // Expand the collapsed batch dimension back to the full batch shape.
        Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
            loc, resultType, batchMatMul, reassociation);
        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
                                                    expandResult);
        return success();
      }

      // Fallback: multiple dynamic batch dims. Emit a linalg.generic that
      // computes batched matmul directly: indexing maps are
      //   lhs: (b..., m, k), rhs: (b..., k, n), out: (b..., m, n).
      SmallVector<AffineExpr> lhsExpr;
      SmallVector<AffineExpr> rhsExpr;
      SmallVector<AffineExpr> outExpr;
      SmallVector<utils::IteratorType> iteratorTypes(
          batchRank, utils::IteratorType::parallel);
      for (unsigned i = 0; i < batchRank; i++) {
        lhsExpr.push_back(rewriter.getAffineDimExpr(i));
        rhsExpr.push_back(rewriter.getAffineDimExpr(i));
        outExpr.push_back(rewriter.getAffineDimExpr(i));
      }
      lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                     rewriter.getAffineDimExpr(batchRank + 1)});
      rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
                                     rewriter.getAffineDimExpr(batchRank + 2)});
      outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                     rewriter.getAffineDimExpr(batchRank + 2)});

      SmallVector<Value> resultShape(broadcastedBatchShape);
      resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, resultShape, elementType);
      auto indexingMaps = AffineMap::inferFromExprList(
          {lhsExpr, rhsExpr, outExpr}, rewriter.getContext());
      iteratorTypes.insert(iteratorTypes.end(),
                           {utils::IteratorType::parallel,
                            utils::IteratorType::reduction,
                            utils::IteratorType::parallel});

      Value finalRes =
          rewriter
              .create<linalg::GenericOp>(
                  loc, zeroTensor.getType(),
                  ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor,
                  /*indexingMaps=*/indexingMaps,
                  /*iteratorTypes=*/iteratorTypes,
                  [&](OpBuilder &b, Location loc, ValueRange args) {
                    // NOTE(review): payload uses float mul/add, so this
                    // fallback path assumes a floating-point element type --
                    // integer inputs would need MulIOp/AddIOp; confirm
                    // against the paths that can reach here.
                    Value l = args[0], r = args[1], res = args[2];
                    Value mul = b.create<arith::MulFOp>(loc, l, r);
                    Value add = b.create<arith::AddFOp>(loc, mul, res);
                    b.create<linalg::YieldOp>(loc, add);
                  })
              .getResult(0);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
      return success();
    }
    return failure();
  }
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenBmmOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
|
|
|
return failure();
|
|
|
|
Location loc = op->getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getMat2();
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType lhsType = cast<RankedTensorType>(lhs.getType());
|
|
|
|
RankedTensorType rhsType = cast<RankedTensorType>(rhs.getType());
|
2023-09-11 20:58:59 +08:00
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
2024-01-30 01:59:33 +08:00
|
|
|
Type resultElementType =
|
2024-04-11 21:47:35 +08:00
|
|
|
cast<RankedTensorType>(newResultType).getElementType();
|
|
|
|
Type lhsElementType = cast<RankedTensorType>(lhsType).getElementType();
|
|
|
|
Type rhsElementType = cast<RankedTensorType>(rhsType).getElementType();
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected both operands to aten.bmm to be rank 3");
|
|
|
|
}
|
2023-09-11 20:58:59 +08:00
|
|
|
|
|
|
|
// Convert the inputs element type equivalent to the result' element type.
|
|
|
|
if (lhsElementType != rhsElementType) {
|
|
|
|
if (lhsElementType != resultElementType) {
|
2024-01-30 01:59:33 +08:00
|
|
|
// True if the lhs element type is not equal to the result' element
|
|
|
|
// type.
|
|
|
|
lhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, lhs,
|
|
|
|
resultElementType);
|
2023-09-11 20:58:59 +08:00
|
|
|
} else {
|
2024-01-30 01:59:33 +08:00
|
|
|
// True if the rhs element type is not equal to the result' element
|
|
|
|
// type.
|
|
|
|
rhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, rhs,
|
|
|
|
resultElementType);
|
2023-09-11 20:58:59 +08:00
|
|
|
}
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
|
|
|
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
|
|
|
|
Value lhsDim2 = getDimOp(rewriter, loc, lhs, 2);
|
|
|
|
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
|
|
|
|
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
|
|
|
|
Value rhsDim2 = getDimOp(rewriter, loc, rhs, 2);
|
|
|
|
|
|
|
|
// Check the batch numbers are equal.
|
|
|
|
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
|
|
|
|
|
|
|
|
// Check the matrixs shapes are valid for mulplication.
|
|
|
|
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);
|
|
|
|
|
|
|
|
Value initTensor0 = createZeroInitTensor(
|
2024-01-30 01:59:33 +08:00
|
|
|
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2},
|
|
|
|
resultElementType);
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Value bmm =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::BatchMatmulOp>(loc, initTensor0.getType(),
|
|
|
|
ValueRange{lhs, rhs}, initTensor0)
|
|
|
|
.getResult(0);
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, bmm);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
2022-04-08 12:47:57 +08:00
|
|
|
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
|
2022-03-11 01:54:13 +08:00
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
2022-04-08 12:47:57 +08:00
|
|
|
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
|
2022-03-11 01:54:13 +08:00
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
MLIRContext *context = op->getContext();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getInput(); /* in form of N*C*H*W */
|
2024-07-30 03:25:07 +08:00
|
|
|
Value weight = adaptor.getWeight(); /* in form of F*C/G*H*W */
|
2024-01-31 05:46:47 +08:00
|
|
|
Value bias = adaptor.getBias();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto resultTy = cast<ValueTensorType>(op.getType());
|
2024-01-31 05:46:47 +08:00
|
|
|
|
|
|
|
Value inputZp, weightZp;
|
2024-05-01 00:23:09 +08:00
|
|
|
bool inputUnsigned = false;
|
|
|
|
bool weightUnsigned = false;
|
2024-01-31 05:46:47 +08:00
|
|
|
if (auto make = op.getInput()
|
|
|
|
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
|
|
|
|
input = make.getSelf();
|
|
|
|
inputZp = make.getZeroPoint();
|
|
|
|
input = typeConverter->materializeTargetConversion(
|
|
|
|
rewriter, loc, typeConverter->convertType(input.getType()), input);
|
|
|
|
inputZp = typeConverter->materializeTargetConversion(
|
|
|
|
rewriter, loc, typeConverter->convertType(inputZp.getType()),
|
|
|
|
inputZp);
|
2024-06-21 04:54:20 +08:00
|
|
|
inputZp =
|
|
|
|
rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), inputZp);
|
2024-05-01 00:23:09 +08:00
|
|
|
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
|
|
|
|
inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
|
2024-01-31 05:46:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if (auto make = op.getWeight()
|
|
|
|
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
|
|
|
|
weight = make.getSelf();
|
|
|
|
weightZp = make.getZeroPoint();
|
|
|
|
|
|
|
|
weight = typeConverter->materializeTargetConversion(
|
|
|
|
rewriter, loc, typeConverter->convertType(weight.getType()), weight);
|
|
|
|
weightZp = typeConverter->materializeTargetConversion(
|
|
|
|
rewriter, loc, typeConverter->convertType(weightZp.getType()),
|
|
|
|
weightZp);
|
2024-06-21 04:54:20 +08:00
|
|
|
weightZp = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(),
|
|
|
|
weightZp);
|
2024-05-01 00:23:09 +08:00
|
|
|
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
|
|
|
|
weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
|
2024-01-31 05:46:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "lhs and rhs of convolution must either be both int or fp");
|
|
|
|
}
|
|
|
|
|
2024-06-04 00:27:44 +08:00
|
|
|
if (inputZp && !isa<Torch::NoneType>(bias.getType())) {
|
2024-04-28 05:00:56 +08:00
|
|
|
auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
|
2024-01-31 05:46:47 +08:00
|
|
|
if (!biasDTy.isInteger(32)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "quantized result ty should be i32 accumulator");
|
|
|
|
}
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-08-25 00:19:35 +08:00
|
|
|
bool transposed = true;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
|
2022-08-25 00:19:35 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only constant transposed supported");
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputDTy = cast<RankedTensorType>(input.getType()).getElementType();
|
|
|
|
auto weightDTy = cast<RankedTensorType>(weight.getType()).getElementType();
|
2024-01-31 05:46:47 +08:00
|
|
|
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
|
|
|
|
2024-04-11 21:47:35 +08:00
|
|
|
if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) ||
|
|
|
|
!isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
|
|
|
|
!isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
|
2024-01-31 05:46:47 +08:00
|
|
|
return op.emitError("unimplemented: non-fp not-int type");
|
2024-04-28 05:00:56 +08:00
|
|
|
size_t inRank = cast<RankedTensorType>(input.getType()).getRank();
|
2024-01-31 05:46:47 +08:00
|
|
|
size_t numSpatialDims = inRank - 2;
|
|
|
|
if (numSpatialDims < 1 || numSpatialDims > 3)
|
2022-04-08 12:47:57 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2024-01-24 13:30:03 +08:00
|
|
|
op, "unimplemented: only 1d-3d convolution currently supported");
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
Type intType = IntegerType::get(context, 64);
|
|
|
|
auto castIndexToInt = [&](Value v) {
|
TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475)
Before this PR, a statically shaped aten.convolution would generate
dynamically shaped linalg IR, and even `-canonicalize` would not be able
to fold it back into static shapes. This PR ensure that shape
calculations are folded on construction to directly generate statically
shaped linalg IR.
We achieve that by ensuring that `arith` ops involved in computing
shapes are created via `createOrFold`, so that later uses of
`getAsOpFoldResult` see constants instead of those ops.
For example
```
module {
func.func @forward(%arg0: !torch.vtensor<[32,336,112,112],f32>,
%arg1: !torch.vtensor<[336,168,3,3],f32>,
%arg2: !torch.vtensor<[336],f32>)
-> !torch.vtensor<[32,336,56,56],f32> {
%false = torch.constant.bool false
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct : () -> !torch.list<int>
%3 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %0, %0, %false, %2, %int2
: !torch.vtensor<[32,336,112,112],f32>, !torch.vtensor<[336,168,3,3],f32>, !torch.vtensor<[336],f32>, !torch.list<int>,
!torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int
-> !torch.vtensor<[32,336,56,56],f32>
return %3 : !torch.vtensor<[32,336,56,56],f32>
}
}
```
would result in
```
[...]
%padded = tensor.pad %2 low[%14, %15, %16, %17] high[%14, %15, %16, %17] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
tensor.yield %cst : f32
} : tensor<32x336x112x112xf32> to tensor<?x?x?x?xf32>
[...]
%45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%expanded, %expanded_37 : tensor<?x2x?x?x?xf32>, tensor<2x168x168x3x3xf32>)
outs(%expanded_44 : tensor<32x2x168x?x?xf32>) -> tensor<32x2x168x?x?xf32>
[...]
```
and with this PR all shapes are static.
2024-06-27 14:43:10 +08:00
|
|
|
return rewriter.createOrFold<arith::IndexCastOp>(loc, intType, v);
|
2022-03-11 01:54:13 +08:00
|
|
|
};
|
|
|
|
|
2022-11-04 15:57:29 +08:00
|
|
|
SmallVector<Value> paddingIntValues;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getPadding(), paddingIntValues))
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2022-11-04 15:57:29 +08:00
|
|
|
op, "only support padding from a list construct");
|
|
|
|
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
|
|
|
|
paddingIntValues);
|
2022-12-08 22:15:31 +08:00
|
|
|
SmallVector<Value> outputPaddingIntValues;
|
|
|
|
if (!getListConstructElements(op.getOutputPadding(),
|
|
|
|
outputPaddingIntValues))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only support output_padding from a list construct");
|
|
|
|
outputPaddingIntValues = getTypeConvertedValues(
|
|
|
|
rewriter, loc, getTypeConverter(), outputPaddingIntValues);
|
2022-04-08 12:47:57 +08:00
|
|
|
SmallVector<int64_t> strideInts;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support constant int strides");
|
2022-04-08 12:47:57 +08:00
|
|
|
SmallVector<int64_t> dilationInts;
|
2024-01-30 01:59:33 +08:00
|
|
|
if (!matchPattern(op.getDilation(),
|
|
|
|
m_TorchListOfConstantInts(dilationInts)))
|
2022-03-11 01:54:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support constant int dilations");
|
2022-04-08 12:47:57 +08:00
|
|
|
|
2022-08-04 14:18:38 +08:00
|
|
|
Value inBatch = getDimOp(rewriter, loc, input, 0);
|
|
|
|
Value inChannels = getDimOp(rewriter, loc, input, 1);
|
2022-04-08 12:47:57 +08:00
|
|
|
SmallVector<Value> inDims;
|
|
|
|
for (size_t i = 2; i < inRank; i++)
|
|
|
|
inDims.push_back(getDimOp(rewriter, loc, input, i));
|
2022-08-04 14:18:38 +08:00
|
|
|
Value weightBatch = getDimOp(rewriter, loc, weight, 0);
|
|
|
|
Value weightChannels = getDimOp(rewriter, loc, weight, 1);
|
2022-04-08 12:47:57 +08:00
|
|
|
SmallVector<Value> weightDims;
|
|
|
|
for (size_t i = 2; i < inRank; i++)
|
|
|
|
weightDims.push_back(getDimOp(rewriter, loc, weight, i));
|
|
|
|
|
2022-08-04 14:18:38 +08:00
|
|
|
// Checks for valid group size
|
2024-07-30 03:25:07 +08:00
|
|
|
int64_t numGroups;
|
|
|
|
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
|
2022-08-04 14:18:38 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only constant group size supported.");
|
2022-12-08 04:20:41 +08:00
|
|
|
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());
|
2022-08-04 14:18:38 +08:00
|
|
|
|
|
|
|
auto validate = [&](Value toValidate, std::string err) {
|
|
|
|
Value c0 =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
|
|
|
Value inputValid = rewriter.create<arith::CmpIOp>(
|
|
|
|
loc, arith::CmpIPredicate::eq, c0,
|
|
|
|
rewriter.create<arith::RemSIOp>(loc, toValidate, groups));
|
|
|
|
rewriter.create<cf::AssertOp>(loc, inputValid,
|
|
|
|
rewriter.getStringAttr(err));
|
|
|
|
};
|
|
|
|
validate(inChannels,
|
|
|
|
"invalid: groups must divide input channel size evenly.");
|
|
|
|
validate(weightBatch,
|
|
|
|
"invalid: groups must divide weight batch size evenly.");
|
2022-03-11 01:54:13 +08:00
|
|
|
SmallVector<Value> dilationIntValues =
|
|
|
|
getAsConstantIntValues(rewriter, loc, dilationInts);
|
|
|
|
SmallVector<Value> strideIntValues =
|
|
|
|
getAsConstantIntValues(rewriter, loc, strideInts);
|
|
|
|
|
2024-05-01 00:23:09 +08:00
|
|
|
// convert any uint8 quantization to int8 quantization
|
|
|
|
if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
|
|
|
|
int64_t width = integerType.getWidth();
|
|
|
|
signShift(rewriter, loc, input, inputZp, inputUnsigned, width);
|
|
|
|
}
|
|
|
|
if (auto integerType = dyn_cast<mlir::IntegerType>(weightDTy)) {
|
|
|
|
int64_t width = integerType.getWidth();
|
|
|
|
signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
|
|
|
|
}
|
2022-08-25 00:19:35 +08:00
|
|
|
// Pad the input tensor according to padding.
|
2022-08-04 14:18:38 +08:00
|
|
|
SmallVector<Value> outDims{inBatch, weightBatch};
|
2022-08-25 00:19:35 +08:00
|
|
|
Value paddedInput;
|
2024-05-01 00:23:09 +08:00
|
|
|
Value pad = inputZp;
|
|
|
|
if (!pad) {
|
|
|
|
if (isa<mlir::FloatType>(inputDTy))
|
|
|
|
pad = rewriter.create<arith::ConstantOp>(
|
|
|
|
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
|
|
|
|
if (isa<mlir::IntegerType>(inputDTy))
|
|
|
|
pad = rewriter.create<arith::ConstantOp>(
|
|
|
|
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
|
|
|
|
}
|
|
|
|
if (pad.getType() != inputDTy) {
|
|
|
|
if (isa<mlir::FloatType>(inputDTy))
|
|
|
|
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
|
2024-01-31 05:46:47 +08:00
|
|
|
|
2024-05-01 00:23:09 +08:00
|
|
|
if (isa<mlir::IntegerType>(inputDTy))
|
|
|
|
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
|
|
|
|
}
|
|
|
|
if (transposed) {
|
2022-08-25 00:19:35 +08:00
|
|
|
Value c0 =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
|
|
|
Value c1 =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
|
|
|
Value c2 =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(2));
|
|
|
|
|
|
|
|
// Transpose and flip weight
|
|
|
|
SmallVector<Value> weightInitDims = getTensorSizes(rewriter, loc, weight);
|
|
|
|
std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1);
|
|
|
|
outDims[1] = weightInitDims[0];
|
|
|
|
Value weightInitTensor =
|
2024-01-31 05:46:47 +08:00
|
|
|
createZeroInitTensor(rewriter, loc, weightInitDims, weightDTy);
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(
|
|
|
|
inRank, utils::IteratorType::parallel);
|
2022-10-13 05:01:24 +08:00
|
|
|
SmallVector<AffineMap> indexingMaps{
|
|
|
|
AffineMap::getMultiDimIdentityMap(inRank, context)};
|
2022-08-25 00:19:35 +08:00
|
|
|
weight = rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
2022-10-13 05:01:24 +08:00
|
|
|
loc, weightInitTensor.getType(), ValueRange{},
|
2022-08-25 00:19:35 +08:00
|
|
|
weightInitTensor, indexingMaps, iteratorTypes,
|
|
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
SmallVector<Value> indices;
|
|
|
|
for (size_t i = 0; i < inRank; i++)
|
|
|
|
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
|
|
|
std::iter_swap(indices.begin(), indices.begin() + 1);
|
|
|
|
// Flip only the spatial dimensions (from 2 to inRank)
|
|
|
|
for (size_t flipDim = 2; flipDim < inRank; flipDim++) {
|
|
|
|
indices[flipDim] = b.create<arith::SubIOp>(
|
|
|
|
loc,
|
|
|
|
b.create<arith::SubIOp>(
|
|
|
|
loc, weightInitDims[flipDim], c1),
|
|
|
|
indices[flipDim]);
|
|
|
|
}
|
|
|
|
Value res =
|
|
|
|
b.create<tensor::ExtractOp>(loc, weight, indices)
|
|
|
|
.getResult();
|
|
|
|
b.create<linalg::YieldOp>(loc, res);
|
|
|
|
})
|
|
|
|
.getResult(0);
|
|
|
|
|
|
|
|
// Calculate padded input size, allocate tensor
|
|
|
|
SmallVector<Value> outerSizes{inBatch, inChannels};
|
|
|
|
SmallVector<Value> innerSizes{inBatch, inChannels};
|
|
|
|
SmallVector<Value> offsets{c0, c0};
|
2024-01-31 05:46:47 +08:00
|
|
|
for (size_t i = 0; i < numSpatialDims; i++) {
|
2022-08-25 00:19:35 +08:00
|
|
|
Value innerSize = rewriter.create<arith::SubIOp>(loc, inDims[i], c1);
|
|
|
|
innerSize = rewriter.create<arith::MulIOp>(
|
|
|
|
loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
|
|
|
|
innerSize = rewriter.create<arith::AddIOp>(loc, innerSize, c1);
|
|
|
|
|
|
|
|
Value offset = rewriter.create<arith::SubIOp>(loc, weightDims[i], c1);
|
|
|
|
offset = rewriter.create<arith::MulIOp>(
|
|
|
|
loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
|
|
|
|
offset = rewriter.create<arith::SubIOp>(
|
|
|
|
loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));
|
|
|
|
|
|
|
|
Value outerSize = rewriter.create<arith::MulIOp>(loc, offset, c2);
|
|
|
|
outerSize = rewriter.create<arith::AddIOp>(loc, outerSize, innerSize);
|
2022-12-08 22:15:31 +08:00
|
|
|
outerSize = rewriter.create<arith::AddIOp>(
|
|
|
|
loc, outerSize,
|
|
|
|
castIntToIndex(rewriter, loc, outputPaddingIntValues[i]));
|
2022-08-25 00:19:35 +08:00
|
|
|
|
|
|
|
outerSizes.push_back(outerSize);
|
|
|
|
offsets.push_back(offset);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Allocate padded input tensor
|
|
|
|
Value initTensor =
|
2024-05-01 00:23:09 +08:00
|
|
|
createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
|
2022-08-25 00:19:35 +08:00
|
|
|
|
|
|
|
// Insert input into allocated tensor
|
|
|
|
SmallVector<Value> strideIndexValues{c1, c1};
|
|
|
|
for (auto stride : strideIntValues)
|
|
|
|
strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));
|
|
|
|
SmallVector<Value> insertSizes = getTensorSizes(rewriter, loc, input);
|
|
|
|
|
|
|
|
paddedInput = rewriter.create<tensor::InsertSliceOp>(
|
|
|
|
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
|
|
|
|
initTensor, offsets, insertSizes, strideIndexValues);
|
|
|
|
|
|
|
|
// Calculate output dims
|
2024-01-31 05:46:47 +08:00
|
|
|
for (size_t i = 0; i < numSpatialDims; i++)
|
2022-08-25 00:19:35 +08:00
|
|
|
outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps(
|
|
|
|
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
|
2022-12-08 22:15:31 +08:00
|
|
|
castIndexToInt(weightDims[i]), strideIntValues[i],
|
|
|
|
outputPaddingIntValues[i]));
|
2022-08-25 00:19:35 +08:00
|
|
|
|
|
|
|
// Set stride to 1
|
|
|
|
strideInts.clear();
|
2024-01-31 05:46:47 +08:00
|
|
|
strideInts.append(numSpatialDims, 1);
|
2022-08-25 00:19:35 +08:00
|
|
|
} else {
|
|
|
|
// Pad input
|
2022-11-04 15:57:29 +08:00
|
|
|
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
|
2024-01-31 05:46:47 +08:00
|
|
|
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
|
2022-08-25 00:19:35 +08:00
|
|
|
|
|
|
|
// Calculate output dims
|
2024-01-31 05:46:47 +08:00
|
|
|
for (size_t i = 0; i < numSpatialDims; i++)
|
2022-08-25 00:19:35 +08:00
|
|
|
outDims.push_back(torch_to_linalg::getOutputDimForConvOps(
|
|
|
|
rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i],
|
|
|
|
castIndexToInt(weightDims[i]), strideIntValues[i]));
|
|
|
|
}
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2024-06-21 04:54:20 +08:00
|
|
|
Type accumulatorDType = getDefaultAccType(rewriter, inputDTy);
|
2022-10-18 12:22:53 +08:00
|
|
|
Value initTensor = rewriter.create<tensor::EmptyOp>(
|
2024-03-15 07:40:40 +08:00
|
|
|
loc, getAsOpFoldResult(outDims), accumulatorDType);
|
2022-03-11 01:54:13 +08:00
|
|
|
|
2022-08-04 14:18:38 +08:00
|
|
|
Value outputTensor;
|
2024-05-31 14:45:13 +08:00
|
|
|
if (accumulatorDType != resultDTy && !isa<Torch::NoneType>(bias.getType()))
|
2024-03-15 07:40:40 +08:00
|
|
|
bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
|
|
|
|
accumulatorDType);
|
2024-05-31 14:45:13 +08:00
|
|
|
if (isa<Torch::NoneType>(bias.getType())) {
|
2024-01-31 05:46:47 +08:00
|
|
|
Value c0;
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(accumulatorDType)) {
|
2024-03-15 07:40:40 +08:00
|
|
|
c0 = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, FloatAttr::get(accumulatorDType, 0.0));
|
2024-04-11 21:47:35 +08:00
|
|
|
} else if (isa<mlir::IntegerType>(accumulatorDType)) {
|
2024-03-15 07:40:40 +08:00
|
|
|
c0 = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, IntegerAttr::get(accumulatorDType, 0));
|
2024-01-31 05:46:47 +08:00
|
|
|
}
|
2024-02-29 13:52:03 +08:00
|
|
|
outputTensor =
|
|
|
|
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
2024-01-31 05:46:47 +08:00
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
} else {
|
2024-04-28 05:00:56 +08:00
|
|
|
auto biasType = cast<RankedTensorType>(bias.getType());
|
2022-03-11 01:54:13 +08:00
|
|
|
if (biasType.getRank() != 1)
|
|
|
|
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
auto resultRank = cast<RankedTensorType>(initTensor.getType()).getRank();
|
2022-03-11 01:54:13 +08:00
|
|
|
SmallVector<AffineMap> indexingMaps = {
|
|
|
|
// bias is used to initialize the channels - dimension 1 of output
|
|
|
|
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
|
|
|
|
rewriter.getAffineDimExpr(1), context),
|
|
|
|
rewriter.getMultiDimIdentityMap(resultRank)};
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(
|
|
|
|
resultRank, utils::IteratorType::parallel);
|
2022-08-04 14:18:38 +08:00
|
|
|
outputTensor = rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
|
|
|
loc, initTensor.getType(), bias, initTensor,
|
|
|
|
indexingMaps, iteratorTypes,
|
|
|
|
[](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
b.create<linalg::YieldOp>(loc, args[0]);
|
|
|
|
})
|
|
|
|
.getResult(0);
|
2022-03-11 01:54:13 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
|
|
|
|
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
|
2022-04-08 12:47:57 +08:00
|
|
|
|
2022-08-04 14:18:38 +08:00
|
|
|
Value inputStride =
|
|
|
|
rewriter.create<arith::FloorDivSIOp>(loc, inChannels, groups);
|
|
|
|
Value weightStride =
|
|
|
|
rewriter.create<arith::FloorDivSIOp>(loc, weightBatch, groups);
|
|
|
|
|
|
|
|
SmallVector<Value> zeroOffsets(inRank, rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIndexAttr(0)));
|
|
|
|
SmallVector<Value> unitStrides(inRank, rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIndexAttr(1)));
|
|
|
|
SmallVector<Value> outDimSlice(outDims);
|
|
|
|
outDimSlice[1] = weightStride;
|
|
|
|
SmallVector<Value> inputSliceSizes{inBatch, inputStride};
|
|
|
|
inputSliceSizes.append(inDims);
|
|
|
|
SmallVector<Value> weightSliceSizes{weightStride, weightChannels};
|
|
|
|
weightSliceSizes.append(weightDims);
|
|
|
|
|
|
|
|
Value conv;
|
2024-01-31 05:46:47 +08:00
|
|
|
// the code so far is able to respect all numSpatialDims
|
2024-07-30 03:25:07 +08:00
|
|
|
// the code below this point is numSpatialDims specific and numGroups
|
2024-01-30 01:59:33 +08:00
|
|
|
// specific
|
|
|
|
// TODO: factor out the above code into a helper function, and then separate
|
|
|
|
// convolution into:
|
2024-01-24 13:30:03 +08:00
|
|
|
// - grouped 1d-3d
|
2024-01-31 05:46:47 +08:00
|
|
|
// - grouped 1d-3d (quantized)
|
2024-01-24 13:30:03 +08:00
|
|
|
// - ungrouped 1d-3d
|
2024-07-30 03:25:07 +08:00
|
|
|
if (numGroups == 1 && !inputZp) {
|
2024-01-31 05:46:47 +08:00
|
|
|
switch (numSpatialDims) {
|
2024-01-24 13:30:03 +08:00
|
|
|
case 1:
|
|
|
|
conv = rewriter
|
|
|
|
.create<linalg::Conv1DNcwFcwOp>(
|
|
|
|
loc, outputTensor.getType(),
|
|
|
|
ValueRange{paddedInput, weight}, outputTensor,
|
|
|
|
stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
|
|
|
break;
|
|
|
|
case 2:
|
2024-01-30 01:59:33 +08:00
|
|
|
conv = rewriter
|
|
|
|
.create<linalg::Conv2DNchwFchwOp>(
|
|
|
|
loc, outputTensor.getType(),
|
|
|
|
ValueRange{paddedInput, weight}, outputTensor,
|
|
|
|
stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
2024-01-24 13:30:03 +08:00
|
|
|
break;
|
|
|
|
case 3:
|
2024-01-30 01:59:33 +08:00
|
|
|
conv = rewriter
|
|
|
|
.create<linalg::Conv3DNcdhwFcdhwOp>(
|
|
|
|
loc, outputTensor.getType(),
|
|
|
|
ValueRange{paddedInput, weight}, outputTensor,
|
|
|
|
stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
2024-01-24 13:30:03 +08:00
|
|
|
break;
|
|
|
|
default:
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only 1D, 2D, and 3D convolution supported");
|
|
|
|
};
|
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
2024-03-15 07:40:40 +08:00
|
|
|
if (accumulatorDType != resultDTy) {
|
|
|
|
Type resultElementType =
|
2024-04-11 21:47:35 +08:00
|
|
|
cast<RankedTensorType>(newResultType).getElementType();
|
2024-03-15 07:40:40 +08:00
|
|
|
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
|
|
|
resultElementType);
|
|
|
|
}
|
2024-01-24 13:30:03 +08:00
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
|
|
|
|
return success();
|
2024-01-31 05:46:47 +08:00
|
|
|
}
|
|
|
|
|
2024-07-30 03:25:07 +08:00
|
|
|
if (numGroups == 1 && inputZp) {
|
2024-01-31 05:46:47 +08:00
|
|
|
// The quantized version uses a different channel ordering so we need to
|
|
|
|
// permute the tensors in order to use the existing path. We should
|
|
|
|
// eventually directly support this channel ordering.
|
|
|
|
llvm::SmallVector<int64_t> inPerms, weightPerms;
|
|
|
|
inPerms.push_back(0); // N stays at the front for input.
|
|
|
|
// Then we expect the spatial dimensions
|
|
|
|
for (size_t i = 0; i < numSpatialDims; ++i) {
|
|
|
|
inPerms.push_back(i + 2);
|
|
|
|
weightPerms.push_back(i + 2);
|
|
|
|
}
|
|
|
|
inPerms.push_back(1);
|
|
|
|
weightPerms.append({1, 0});
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
|
|
|
|
weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
|
|
|
|
outputTensor =
|
|
|
|
transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
|
|
|
|
|
|
|
|
switch (numSpatialDims) {
|
|
|
|
case 2:
|
2022-08-04 14:18:38 +08:00
|
|
|
conv = rewriter
|
2024-01-31 05:46:47 +08:00
|
|
|
.create<linalg::Conv2DNhwcHwcfQOp>(
|
2024-01-30 01:59:33 +08:00
|
|
|
loc, outputTensor.getType(),
|
2024-01-31 05:46:47 +08:00
|
|
|
ValueRange{paddedInput, weight, inputZp, weightZp},
|
|
|
|
outputTensor, stridesAttr, dilationAttr)
|
2024-01-30 01:59:33 +08:00
|
|
|
.getResult(0);
|
2024-01-31 05:46:47 +08:00
|
|
|
break;
|
|
|
|
case 3:
|
|
|
|
conv = rewriter
|
|
|
|
.create<linalg::Conv3DNdhwcDhwcfQOp>(
|
|
|
|
loc, outputTensor.getType(),
|
|
|
|
ValueRange{paddedInput, weight, inputZp, weightZp},
|
|
|
|
outputTensor, stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
|
|
|
break;
|
|
|
|
default:
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only 1D, 2D, and 3D convolution supported");
|
|
|
|
};
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
llvm::SmallVector<int64_t> outPerms;
|
|
|
|
outPerms.push_back(0);
|
|
|
|
outPerms.push_back(inPerms.size() - 1);
|
|
|
|
for (size_t i = 0; i < numSpatialDims; ++i) {
|
|
|
|
outPerms.push_back(i + 1);
|
2022-08-04 14:18:38 +08:00
|
|
|
}
|
2024-01-31 05:46:47 +08:00
|
|
|
conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
2024-03-15 07:40:40 +08:00
|
|
|
if (accumulatorDType != resultDTy) {
|
|
|
|
Type resultElementType =
|
2024-04-11 21:47:35 +08:00
|
|
|
cast<RankedTensorType>(newResultType).getElementType();
|
2024-03-15 07:40:40 +08:00
|
|
|
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
|
|
|
resultElementType);
|
|
|
|
}
|
2024-01-31 05:46:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
if (numSpatialDims != 2)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only 2D grouped convolution supported");
|
|
|
|
|
2024-07-30 03:25:07 +08:00
|
|
|
// Special depthwise case: Cin = Cout = groups.
|
|
|
|
// Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple
|
|
|
|
// of groups) to be depthwise in their documentation, but the linalg ops
|
|
|
|
// apparently disagree.
|
2024-01-31 05:46:47 +08:00
|
|
|
auto inShape = makeShapeTorchCompatible(
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(input.getType()).getShape());
|
2024-01-31 05:46:47 +08:00
|
|
|
auto weightShape = makeShapeTorchCompatible(
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(weight.getType()).getShape());
|
2024-07-30 03:25:07 +08:00
|
|
|
if (inShape[1] == numGroups && weightShape[0] == numGroups &&
|
|
|
|
weightShape[1] == 1) {
|
|
|
|
// Collapse weight shape (C/G == 1)
|
2024-01-31 05:46:47 +08:00
|
|
|
SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
|
2024-07-30 03:25:07 +08:00
|
|
|
SmallVector<int64_t> collapsedShape{weightShape[0] * weightShape[1],
|
|
|
|
weightShape[2], weightShape[3]};
|
2024-01-31 05:46:47 +08:00
|
|
|
Type collapsedType = RankedTensorType::get(
|
|
|
|
makeShapeLLVMCompatible(collapsedShape), weightDTy);
|
|
|
|
Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
|
|
|
|
loc, collapsedType, weight, collapsedDims);
|
2024-07-30 03:25:07 +08:00
|
|
|
if (!inputZp) {
|
|
|
|
conv = rewriter
|
|
|
|
.create<linalg::DepthwiseConv2DNchwChwOp>(
|
|
|
|
loc, outputTensor.getType(),
|
|
|
|
ValueRange{paddedInput, collapsedWeight}, outputTensor,
|
|
|
|
stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
|
|
|
} else {
|
|
|
|
// currently, the only named depthwise qconv op is nhwc_hwc
|
|
|
|
// input: nchw -> nhwc; weight (collapsed): chw -> hwc
|
|
|
|
// linalg conv result nhwc -> nchw
|
|
|
|
// inPerms = [0, 2, 3, 1]
|
|
|
|
// weightPerms = [1, 2, 0]
|
|
|
|
// resultPerms = [0, 3, 1, 2]
|
|
|
|
llvm::SmallVector<int64_t> inPerms, weightPerms, resultPerms;
|
|
|
|
inPerms.push_back(0);
|
|
|
|
resultPerms.append({0, static_cast<int64_t>(numSpatialDims + 1)});
|
|
|
|
for (size_t i = 0; i < numSpatialDims; ++i) {
|
|
|
|
inPerms.push_back(i + 2);
|
|
|
|
weightPerms.push_back(i + 1);
|
|
|
|
resultPerms.push_back(i + 1);
|
|
|
|
}
|
|
|
|
inPerms.push_back(1);
|
|
|
|
weightPerms.push_back(0);
|
|
|
|
|
|
|
|
paddedInput =
|
|
|
|
transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
|
|
|
|
collapsedWeight =
|
|
|
|
transposeValue(op.getLoc(), collapsedWeight, weightPerms, rewriter);
|
|
|
|
outputTensor =
|
|
|
|
transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);
|
|
|
|
|
|
|
|
conv =
|
|
|
|
rewriter
|
|
|
|
.create<linalg::DepthwiseConv2DNhwcHwcQOp>(
|
|
|
|
loc, outputTensor.getType(),
|
|
|
|
ValueRange{paddedInput, collapsedWeight, inputZp, weightZp},
|
|
|
|
outputTensor, stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
|
|
|
// convert output nhwc -> nchw
|
|
|
|
conv = transposeValue(op.getLoc(), conv, resultPerms, rewriter);
|
|
|
|
}
|
2024-01-31 05:46:47 +08:00
|
|
|
|
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
2024-03-15 07:40:40 +08:00
|
|
|
if (accumulatorDType != resultDTy) {
|
|
|
|
Type resultElementType =
|
2024-04-11 21:47:35 +08:00
|
|
|
cast<RankedTensorType>(newResultType).getElementType();
|
2024-03-15 07:40:40 +08:00
|
|
|
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
|
|
|
resultElementType);
|
|
|
|
}
|
2024-01-31 05:46:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Grouped case, use the grouped conv linalg op
|
|
|
|
auto expandGroups = [&](Value tensor, size_t dim) {
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inType = cast<RankedTensorType>(tensor.getType());
|
2024-01-31 05:46:47 +08:00
|
|
|
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
|
|
|
|
|
|
|
SmallVector<int64_t> outShape;
|
|
|
|
for (auto i = 0; i < (long)inShape.size(); i++) {
|
|
|
|
if (i == 1) {
|
2024-07-30 03:25:07 +08:00
|
|
|
outShape.push_back(numGroups);
|
2024-01-31 05:46:47 +08:00
|
|
|
}
|
|
|
|
if (i == (long)dim) {
|
|
|
|
outShape.push_back(inShape[i] == kUnknownSize
|
|
|
|
? kUnknownSize
|
2024-07-30 03:25:07 +08:00
|
|
|
: inShape[i] / numGroups);
|
2024-01-31 05:46:47 +08:00
|
|
|
} else {
|
|
|
|
outShape.push_back(inShape[i]);
|
2022-08-04 14:18:38 +08:00
|
|
|
}
|
2024-01-31 05:46:47 +08:00
|
|
|
}
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
SmallVector<ReassociationIndices> indices;
|
|
|
|
for (auto i = 0; i <= (long)inShape.size(); i++) {
|
|
|
|
if (i == (long)dim) {
|
|
|
|
indices.push_back({i, ++i});
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
indices.push_back({i});
|
|
|
|
}
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
|
|
|
|
return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
|
|
|
|
indices);
|
|
|
|
};
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
// expand F,C,H,W -> G,F/G,C,H,W
|
|
|
|
auto expandWeight = [&](Value tensor) {
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inType = cast<RankedTensorType>(tensor.getType());
|
2024-01-31 05:46:47 +08:00
|
|
|
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
SmallVector<int64_t> outShape{
|
2024-07-30 03:25:07 +08:00
|
|
|
numGroups,
|
|
|
|
(inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / numGroups)};
|
2024-01-31 05:46:47 +08:00
|
|
|
outShape.append(inShape.begin() + 1, inShape.end());
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
SmallVector<ReassociationIndices> indices{{0, 1}};
|
|
|
|
for (auto i = 2; i <= (long)inShape.size(); i++)
|
|
|
|
indices.push_back({i});
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
|
|
|
|
return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
|
|
|
|
indices);
|
|
|
|
};
|
2022-08-04 14:18:38 +08:00
|
|
|
|
2024-01-31 05:46:47 +08:00
|
|
|
Value paddedInputExpanded = expandGroups(paddedInput, 1);
|
|
|
|
Value weightExpanded = expandWeight(weight);
|
|
|
|
auto expandOutputTensor = expandGroups(outputTensor, 1);
|
|
|
|
|
|
|
|
// TODO: add 1D and 3D case
|
2024-06-04 00:27:44 +08:00
|
|
|
if (!inputZp) {
|
|
|
|
conv = rewriter
|
|
|
|
.create<linalg::Conv2DNgchwGfchwOp>(
|
|
|
|
loc, expandOutputTensor.getResultType(),
|
|
|
|
ValueRange{paddedInputExpanded, weightExpanded},
|
|
|
|
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
|
|
|
} else {
|
|
|
|
conv = rewriter
|
|
|
|
.create<linalg::Conv2DNgchwGfchwQOp>(
|
|
|
|
loc, expandOutputTensor.getResultType(),
|
|
|
|
ValueRange{paddedInputExpanded, weightExpanded, inputZp,
|
|
|
|
weightZp},
|
|
|
|
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
|
|
|
|
.getResult(0);
|
|
|
|
}
|
2024-01-31 05:46:47 +08:00
|
|
|
conv = rewriter.create<tensor::CollapseShapeOp>(
|
|
|
|
loc, outputTensor.getType(), conv,
|
|
|
|
expandOutputTensor.getReassociationIndices());
|
|
|
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
2024-03-15 07:40:40 +08:00
|
|
|
if (accumulatorDType != resultDTy) {
|
|
|
|
Type resultElementType =
|
2024-04-11 21:47:35 +08:00
|
|
|
cast<RankedTensorType>(newResultType).getElementType();
|
2024-03-15 07:40:40 +08:00
|
|
|
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
|
|
|
|
resultElementType);
|
|
|
|
}
|
2024-01-31 05:46:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
|
|
|
|
return success();
|
2022-03-11 01:54:13 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  // Mark the linear-algebra-flavored aten ops handled by this file as illegal
  // so the conversion driver must rewrite them, then register the
  // corresponding lowering pattern for each one.
  target.addIllegalOp<AtenMmOp, AtenFlipOp, AtenMatmulOp, AtenBmmOp,
                      AtenConvolutionOp>();
  patterns
      .add<ConvertAtenMmOp, ConvertAtenFlipOp, ConvertAtenMatmulOp,
           ConvertAtenBmmOp, ConvertAtenConvolutionOp>(typeConverter, context);
}
|