mirror of https://github.com/llvm/torch-mlir
[tosa] Implement torch.linear support. (#535)
Refactor matmul into a separate base class and derive variants:
- matmul
- mm, bmm
- linear

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
parent ad4b9e0369
commit cadea678e5
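Before the diff, a tiny self-contained C++ sketch of the design this commit introduces. The names here (MatmulLoweringBase, readInputs, lower, and the *Variant classes) are hypothetical stand-ins, not the torch-mlir API; the real classes derive from MLIR's OpConversionPattern as shown below. The point is the shape of the refactor: a base class owns the shared matmul lowering, and each op variant overrides only how it reads its operands, because matmul takes (self, other), mm/bmm take (self, mat2), and linear takes (input, weight).

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Inputs {
  std::string lhs, rhs;
};

// Base class: owns the shared lowering step; variants only say where
// lhs/rhs come from. (Toy distillation; all names are hypothetical.)
class MatmulLoweringBase {
public:
  virtual ~MatmulLoweringBase() = default;
  virtual Inputs readInputs() const = 0; // the per-variant hook
  std::string lower() const {            // the shared matmul step
    Inputs in = readInputs();
    return "tosa.matmul(" + in.lhs + ", " + in.rhs + ")";
  }
};

class MatmulVariant : public MatmulLoweringBase {
  Inputs readInputs() const override { return {"self", "other"}; }
};

class MmVariant : public MatmulLoweringBase {
  Inputs readInputs() const override { return {"self", "mat2"}; }
};

class LinearVariant : public MatmulLoweringBase {
  Inputs readInputs() const override {
    // linear multiplies by the transposed weight and later adds a bias.
    return {"input", "transpose(weight)"};
  }
};

int main() {
  std::vector<std::unique_ptr<MatmulLoweringBase>> variants;
  variants.push_back(std::make_unique<MatmulVariant>());
  variants.push_back(std::make_unique<MmVariant>());
  variants.push_back(std::make_unique<LinearVariant>());
  for (const auto &v : variants)
    std::cout << v->lower() << "\n";
}

This mirrors the diff's readMatMulInputs/performMatmul split: input parsing is the only part that varies per op, so it is the only virtual method the variants must provide.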
@@ -817,40 +817,37 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
   return success();
 }
 
-// Perform torch matmul, mm and bmm
+// Perform the basic n-dim matmul operation encompassing the handling of
+// broadcasting and dynamic shape propagation.
+// All PyTorch ops that leverage matrix multiplication will derive this and
+// implement their specialized input processing (e.g. transpose), and output
+// processing, e.g. GEMM or fully connected bias handling.
 template <typename AtenOpT>
-class ConvertAtenMatMulOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
 public:
   using OpConversionPattern<AtenOpT>::OpConversionPattern;
   using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value lhs = adaptor.self();
+  // Each variant must implement corresponding parameter parsing options.
+  // Maintain separate input read functions for each variant because it is not
+  // necessarily true with all variants that the first two operands are the lhs
+  // and rhs.
+  virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                         ConversionPatternRewriter &rewriter,
+                                         Value &lhs, Value &rhs) const {
+    return rewriter.notifyMatchFailure(
+        op,
+        "Unimplemented matrix multiplication variant input parsing function");
+  }
+
+  LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter, Value &lhs,
+                              Value &rhs, Value &output) const {
 
     auto lhsTy = lhs.getType().cast<RankedTensorType>();
 
-    // Aten matmul, mm and bmm call operand2 by different names.
-    Value rhs = adaptor.getOperands()[1];
     auto rhsTy = rhs.getType().cast<RankedTensorType>();
 
-    if (!lhsTy || !rhsTy)
-      return op.emitError("Only ranked tensor types supported in TOSA matmul");
-
     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
 
-    // Mm takes two 2D tensors
-    if (isa<AtenMmOp>(op)) {
-      assert(lhsRank == 2 && rhsRank == 2 &&
-             "aten.mm called but matrix rank != 2");
-    }
-
-    // Bmm takes two 2D tensors
-    if (isa<AtenBmmOp>(op)) {
-      assert(lhsRank == 3 && rhsRank == 3 &&
-             "aten.bmm called but matrix rank != 2");
-    }
-
     auto lhsShape = lhsTy.getShape();
     auto rhsShape = rhsTy.getShape();
 
@@ -1248,11 +1245,8 @@ public:
     // Perform the reshape to output shape. This is always required unless both
     // inputs are rank=3, in which case the tosa.matmul output itself is
     // correctly shaped.
-    bool performOpReshape = !(lhsRank == 3 && rhsRank == 3);
-
-    auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
-                        ->convertType(op.getType())
-                        .template cast<RankedTensorType>();
+    bool performOpReshape =
+        !(lhsRank == 3 && rhsRank == 3 && lhsShape[0] == rhsShape[0]);
 
     if (performOpReshape) {
       // Since the output shape may be unknown, we construct it
@@ -1358,20 +1352,218 @@ public:
 
         auto transposedOpType =
             RankedTensorType::get(transposedOpShape, outputElemTy);
-        auto transposedOp = rewriter.create<tosa::TransposeOp>(
-            op->getLoc(),
-            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-                transposedOpType),
-            reshapedOp.getResult(), transposedOpShapeConst.getValue());
-
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, transposedOp);
-      } else {
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, reshapedOp);
-      }
-    } else {
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, mmOpResult);
-    }
+        output =
+            rewriter
+                .create<tosa::TransposeOp>(
+                    op->getLoc(),
+                    OpConversionPattern<AtenOpT>::getTypeConverter()
+                        ->convertType(transposedOpType),
+                    reshapedOp.getResult(), transposedOpShapeConst.getValue())
+                .getResult();
+      } else {
+        output = reshapedOp.getResult();
+      }
+    } else {
+      output = mmOpResult;
+    }
+
+    return success();
+  }
+
+  // The default version just reads two inputs, computes output and returns it.
+  // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc.
+  virtual LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+
+    Value lhs, rhs;
+
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("Failed to read matmul inputs");
+
+    Value output;
+
+    if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
+      return op.emitError("Failed to perform matmul operation");
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()
+            ->convertType(op.getType())
+            .template cast<RankedTensorType>(),
+        output);
+
+    return success();
+  }
+};
+
+// Legalizes the torch.matmul op for general n-dim matmul.
+template <typename AtenOpT>
+class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.other();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    return success();
+  }
+};
+
+// Implements handling of aten.mm and aten.bmm ops.
+template <typename AtenOpT>
+class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.mat2();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (isa<AtenMmOp>(op)) {
+      // Mm takes two 2D tensors.
+      if (lhsRank != 2 || rhsRank != 2)
+        return op.emitError("aten.mm called but matrix rank != 2");
+    } else if (isa<AtenBmmOp>(op)) {
+      // Bmm takes two 3D tensors.
+      if (lhsRank != 3 || rhsRank != 3)
+        return op.emitError("aten.bmm called but matrix rank != 3");
+    }
+
+    return success();
+  }
+};
+
+// Implements handling of aten.linear op.
+template <typename AtenOpT>
+class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+
+    lhs = adaptor.input();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.weight();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (lhsRank != 2 && lhsRank != 3)
+      return op.emitError("aten.Linear called but input rank not 2 or 3");
+    if (rhsRank != 2 && rhsRank != 3)
+      return op.emitError("aten.Linear called but weight rank not 2 or 3");
+
+    // Protection against crash due to unguarded code in TOSA->LinAlg.
+    if (!lhsTy.hasStaticShape() || !rhsTy.hasStaticShape())
+      return op.emitError("aten.Linear needs statically shaped input");
+
+    return success();
+  }
+
+  // Override the default rewriter to perform RHS transpose and bias addition
+  // as well.
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Value lhs, rhs;
+
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("Failed to read matmul inputs");
+
+    // The aten.Linear op has a bias tensor that is added to the matmul output.
+    auto bias = adaptor.bias();
+    auto biasTy = bias.getType();
+
+    // TOSA does not mandate that elementwise op tensors need to be ranked.
+    if (!biasTy.template isa<Torch::NoneType>() &&
+        !biasTy.template isa<TensorType>())
+      return op.emitError("Only tensor types supported in GEMM to "
+                          "TOSA for bias tensor");
+
+    // RHS must have its last two dims transposed prior to matrix
+    // multiplication.
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto rhsRank = rhsTy.getRank();
+    auto rhsShape = rhsTy.getShape();
+    auto rhsElemTy = rhsTy.getElementType();
+
+    // Create a non-const shape array to transpose dims.
+    SmallVector<int64_t> transposedRhsShape;
+    for (auto &shape : rhsShape)
+      transposedRhsShape.push_back(shape);
+    SmallVector<int32_t> transposedRhsDims;
+    for (int32_t i = 0; i < rhsRank; i++)
+      transposedRhsDims.push_back(i);
+
+    // Swap the last two dims.
+    std::swap(transposedRhsShape[rhsRank - 1], transposedRhsShape[rhsRank - 2]);
+    std::swap(transposedRhsDims[rhsRank - 1], transposedRhsDims[rhsRank - 2]);
+
+    llvm::Optional<Value> transposedRhsShapeConst =
+        tosa::getConstTensor<int32_t>(
+            rewriter, op,
+            /*vec=*/transposedRhsDims,
+            /*shape=*/{static_cast<int32_t>(transposedRhsDims.size())});
+
+    auto transposedRhsType =
+        RankedTensorType::get(transposedRhsShape, rhsElemTy);
+    rhs = rewriter.create<tosa::TransposeOp>(
+        op->getLoc(),
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            transposedRhsType),
+        rhs, transposedRhsShapeConst.getValue());
+
+    Value matmulOutput;
+    if (failed(
+            this->performMatmul(op, adaptor, rewriter, lhs, rhs, matmulOutput)))
+      return op.emitError("Failed to perform matmul operation");
+
+    Value matmulPlusBias = matmulOutput;
+    if (!biasTy.template isa<Torch::NoneType>()) {
+      // Bias addition broadcasts to the matmul output shape.
+      matmulPlusBias =
+          rewriter
+              .create<tosa::AddOp>(op->getLoc(), matmulOutput.getType(),
+                                   matmulOutput, bias)
+              .getResult();
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()
+            ->convertType(op.getType())
+            .template cast<RankedTensorType>(),
+        matmulPlusBias);
+
     return success();
   }
 };
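ConvertAtenLinearOp above lowers aten.linear as a transpose of the weight's last two dims, the shared matmul, and a broadcast bias add. To check the arithmetic those three steps implement, here is a scalar C++ reference of out = input * transpose(weight) + bias; the Matrix alias and the transpose/matmul/linear helpers are local illustrations for intuition, not anything from the torch-mlir or TOSA code.

#include <cstddef>
#include <iostream>
#include <vector>

// Reference linear: out[i][j] = sum_k in[i][k] * weight[j][k] + bias[j],
// done in the same three steps the pattern emits.
using Matrix = std::vector<std::vector<float>>;

Matrix transpose(const Matrix &m) {
  Matrix t(m[0].size(), std::vector<float>(m.size()));
  for (size_t i = 0; i < m.size(); ++i)
    for (size_t j = 0; j < m[0].size(); ++j)
      t[j][i] = m[i][j];
  return t;
}

Matrix matmul(const Matrix &a, const Matrix &b) {
  Matrix c(a.size(), std::vector<float>(b[0].size(), 0.0f));
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t k = 0; k < b.size(); ++k)
      for (size_t j = 0; j < b[0].size(); ++j)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

Matrix linear(const Matrix &in, const Matrix &weight,
              const std::vector<float> &bias) {
  // Steps 1 and 2: transpose the weight, then matrix-multiply.
  Matrix out = matmul(in, transpose(weight));
  // Step 3: broadcast the bias across every output row.
  for (auto &row : out)
    for (size_t j = 0; j < row.size(); ++j)
      row[j] += bias[j];
  return out;
}

int main() {
  Matrix in = {{1, 2, 3}};           // 1x3 input
  Matrix w = {{1, 0, 0}, {0, 1, 0}}; // 2x3 weight (out_features x in_features)
  std::vector<float> b = {10, 20};
  for (float v : linear(in, w, b)[0])
    std::cout << v << " ";
  std::cout << "\n";
}

With this input, weight, and bias the program prints 11 22, which matches torch.nn.functional.linear; the weight is stored as out_features x in_features, which is exactly why the pattern must transpose it before calling the shared performMatmul.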
@@ -1544,10 +1736,21 @@ public:
   target.addIllegalOp<AtenOp>();                                             \
   patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
   INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
-  INSERT_MATMUL_ATENOP_PATTERN(AtenMmOp);
-  INSERT_MATMUL_ATENOP_PATTERN(AtenBmmOp);
 #undef INSERT_MATMUL_ATEMOP_PATTERN
+
+#define INSERT_MM_ATENOP_PATTERN(AtenOp)                                       \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
+  INSERT_MM_ATENOP_PATTERN(AtenMmOp);
+  INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
+#undef INSERT_MM_ATEMOP_PATTERN
+
+#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp)                                   \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
+  INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
+#undef INSERT_LINEAR_ATEMOP_PATTERN
 
 #define INSERT_ATENOP_PATTERN(AtenOp)                                         \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);