mirror of https://github.com/llvm/torch-mlir
[tosa] Implement torch.linear support. (#535)
Refactor matmul into separate class and derive variants:
- matmul
- mm, bmm
- linear

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
parent ad4b9e0369
commit cadea678e5
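For context on the change below: torch.linear applies an affine map to its
input, roughly output = matmul(input, transpose(weight)) + bias (the bias being
optional), which is why the new ConvertAtenLinearOp transposes the last two
dims of the weight before reusing the shared matmul lowering and then adds the
bias to the matmul result.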
@@ -817,40 +817,37 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
  return success();
}

// Perform the basic n-dim matmul operation encompassing the handling of
// broadcasting and dynamic shape propagation.
// All PyTorch ops that leverage matrix multiplication will derive this and
// implement their specialized input processing (e.g. transpose), and output
// processing, e.g. GEMM or fully connected bias handling.
template <typename AtenOpT>
class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  // Each variant must implement corresponding parameter parsing options.
  // Maintain separate input read functions for each variant because it is not
  // necessarily true with all variants that the first two operands are the lhs
  // and rhs.
  virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter,
                                         Value &lhs, Value &rhs) const {
    return rewriter.notifyMatchFailure(
        op,
        "Unimplemented matrix multiplication variant input parsing function");
  }
  LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter, Value &lhs,
                              Value &rhs, Value &output) const {

    auto lhsTy = lhs.getType().cast<RankedTensorType>();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError("Only ranked tensor types supported in TOSA matmul");

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    auto lhsShape = lhsTy.getShape();
    auto rhsShape = rhsTy.getShape();
@@ -1248,11 +1245,8 @@
    // Perform the reshape to output shape. This is always required unless both
    // inputs are rank=3, in which case the tosa.matmul output itself is
    // correctly shaped.
    bool performOpReshape =
        !(lhsRank == 3 && rhsRank == 3 && lhsShape[0] == rhsShape[0]);
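    // For example, rank-3 inputs such as [2,3,4] x [2,4,5] already match
    // tosa.matmul's [N,H,C] x [N,C,W] -> [N,H,W] form and its result can be
    // used as-is; other ranks are first brought into a rank-3 form and the
    // result is reshaped back to the aten output shape afterwards.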

    if (performOpReshape) {
      // Since the output shape may be unknown, we construct it
@@ -1358,20 +1352,218 @@

        auto transposedOpType =
            RankedTensorType::get(transposedOpShape, outputElemTy);
        output =
            rewriter
                .create<tosa::TransposeOp>(
                    op->getLoc(),
                    OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(transposedOpType),
                    reshapedOp.getResult(), transposedOpShapeConst.getValue())
                .getResult();

      } else {
        output = reshapedOp.getResult();
      }
    } else {
      output = mmOpResult;
    }

    return success();
  }
  // The default version just reads two inputs, computes output and returns it.
  // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc.
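  // (ConvertAtenLinearOp below overrides this rewriter to also transpose the
  // weight and add the bias to the matmul output.)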
  virtual LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const {

    Value lhs, rhs;

    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
      return op.emitError("Failed to read matmul inputs");

    Value output;

    if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
      return op.emitError("Failed to perform matmul operation");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        op,
        OpConversionPattern<AtenOpT>::getTypeConverter()
            ->convertType(op.getType())
            .template cast<RankedTensorType>(),
        output);

    return success();
  }
};

// Legalizes the torch.matmul op for general n-dim matmul.
template <typename AtenOpT>
class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {
    lhs = adaptor.self();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();

    rhs = adaptor.other();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError("Only ranked tensor types supported in TOSA matmul");

    return success();
  }
};

// Implements handling of aten.mm and aten.bmm ops.
template <typename AtenOpT>
class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {

    lhs = adaptor.self();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();

    rhs = adaptor.mat2();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError("Only ranked tensor types supported in TOSA matmul");

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    if (isa<AtenMmOp>(op)) {
      // Mm takes two 2D tensors.
      if (lhsRank != 2 || rhsRank != 2)
        return op.emitError("aten.mm called but matrix rank != 2");
    } else if (isa<AtenBmmOp>(op)) {
      // Bmm takes two 3D tensors.
      if (lhsRank != 3 || rhsRank != 3)
        return op.emitError("aten.bmm called but matrix rank != 3");
    }

    return success();
  }
};

// Implements handling of aten.linear op.
template <typename AtenOpT>
class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {

    lhs = adaptor.input();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();

    rhs = adaptor.weight();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError("Only ranked tensor types supported in TOSA matmul");

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    if (lhsRank != 2 && lhsRank != 3)
      return op.emitError("aten.Linear called but input rank not 2 or 3");
    if (rhsRank != 2 && rhsRank != 3)
      return op.emitError("aten.Linear called but weight rank not 2 or 3");

    // Protection against crash due to unguarded code in TOSA->LinAlg.
    if (!lhsTy.hasStaticShape() || !rhsTy.hasStaticShape())
      return op.emitError("aten.Linear needs statically shaped input");

    return success();
  }
  // Override the default rewriter to perform RHS transpose and bias addition as
  // well.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Value lhs, rhs;

    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
      return op.emitError("Failed to read matmul inputs");

    // The aten.Linear op has a bias tensor that is added to the matmul output.
    auto bias = adaptor.bias();
    auto biasTy = bias.getType();

    // TOSA does not mandate that elementwise op tensors need to be ranked.
    if (!biasTy.template isa<Torch::NoneType>() &&
        !biasTy.template isa<TensorType>())
      return op.emitError("Only tensor types supported in GEMM to "
                          "TOSA for bias tensor");

    // RHS must have its last two dims transposed prior to matrix
    // multiplication.
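    // A torch.linear weight is stored as [out_features, in_features]; swapping
    // its last two dims gives the [in_features, out_features] layout that the
    // matmul path expects.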
    auto rhsTy = rhs.getType().cast<RankedTensorType>();
    auto rhsRank = rhsTy.getRank();
    auto rhsShape = rhsTy.getShape();
    auto rhsElemTy = rhsTy.getElementType();

    // Create a non-const shape array to transpose dims.
    SmallVector<int64_t> transposedRhsShape;
    for (auto &shape : rhsShape)
      transposedRhsShape.push_back(shape);
    SmallVector<int32_t> transposedRhsDims;
    for (int32_t i = 0; i < rhsRank; i++)
      transposedRhsDims.push_back(i);

    // Swap the last two dims.
    std::swap(transposedRhsShape[rhsRank - 1], transposedRhsShape[rhsRank - 2]);
    std::swap(transposedRhsDims[rhsRank - 1], transposedRhsDims[rhsRank - 2]);

    llvm::Optional<Value> transposedRhsShapeConst =
        tosa::getConstTensor<int32_t>(
            rewriter, op,
            /*vec=*/transposedRhsDims,
            /*shape=*/{static_cast<int32_t>(transposedRhsDims.size())});

    auto transposedRhsType =
        RankedTensorType::get(transposedRhsShape, rhsElemTy);
    rhs = rewriter.create<tosa::TransposeOp>(
        op->getLoc(),
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            transposedRhsType),
        rhs, transposedRhsShapeConst.getValue());

    Value matmulOutput;
    if (failed(
            this->performMatmul(op, adaptor, rewriter, lhs, rhs, matmulOutput)))
      return op.emitError("Failed to perform matmul operation");

    Value matmulPlusBias = matmulOutput;
    if (!biasTy.template isa<Torch::NoneType>()) {
      // Bias addition broadcasts to the matmul output shape.
      matmulPlusBias =
          rewriter
              .create<tosa::AddOp>(op->getLoc(), matmulOutput.getType(),
                                   matmulOutput, bias)
              .getResult();
    }

    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        op,
        OpConversionPattern<AtenOpT>::getTypeConverter()
            ->convertType(op.getType())
            .template cast<RankedTensorType>(),
        matmulPlusBias);

    return success();
  }
};
@@ -1544,10 +1736,21 @@
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
    INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
#undef INSERT_MATMUL_ATENOP_PATTERN

#define INSERT_MM_ATENOP_PATTERN(AtenOp)                                       \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
    INSERT_MM_ATENOP_PATTERN(AtenMmOp);
    INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
#undef INSERT_MM_ATENOP_PATTERN

#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp)                                   \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
    INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATENOP_PATTERN

#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);