[tosa] Implement torch.linear support. (#535)

Refactor matmul into a separate base class and derive variants:
- matmul
- mm, bmm
- linear

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
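
For reference, aten.linear computes output = input x weight^T + bias, where the
weight is stored as (out_features, in_features). Below is a minimal standalone
C++ sketch of the 2-D semantics this lowering must reproduce (illustrative
only; linearRef is a hypothetical name, not part of this patch):

#include <cassert>
#include <vector>

// y[m][o] = bias[o] + sum_k x[m][k] * w[o][k]. The weight's rows are output
// features, which is why the TOSA lowering transposes the last two weight
// dims before feeding tosa.matmul.
std::vector<std::vector<float>>
linearRef(const std::vector<std::vector<float>> &x, // [M, K] input
          const std::vector<std::vector<float>> &w, // [O, K] weight
          const std::vector<float> &b) {            // [O] bias
  size_t M = x.size(), O = w.size(), K = O ? w[0].size() : 0;
  std::vector<std::vector<float>> y(M, std::vector<float>(O, 0.0f));
  for (size_t m = 0; m < M; ++m) {
    assert(x[m].size() == K && b.size() == O);
    for (size_t o = 0; o < O; ++o) {
      float acc = b[o];
      for (size_t k = 0; k < K; ++k)
        acc += x[m][k] * w[o][k];
      y[m][o] = acc;
    }
  }
  return y;
}
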
@@ -817,40 +817,37 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
  return success();
}
// Perform the basic n-dim matmul operation encompassing the handling of
// broadcasting and dynamic shape propagation.
// All PyTorch ops that leverage matrix multiplication will derive this and
// implement their specialized input processing (e.g. transpose), and output
// processing, e.g. GEMM or fully connected bias handling.
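// For example, torch.matmul of a [2, 3, 4] lhs against a [4, 5] rhs
// broadcasts the rhs and produces a [2, 3, 5] result.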
template <typename AtenOpT>
class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  // Each variant must implement corresponding parameter parsing options.
  // Maintain separate input read functions for each variant because it is not
  // necessarily true with all variants that the first two operands are the lhs
  // and rhs.
  virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter,
                                         Value &lhs, Value &rhs) const {
    return rewriter.notifyMatchFailure(
        op,
        "Unimplemented matrix multiplication variant input parsing function");
  }

  LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter, Value &lhs,
                              Value &rhs, Value &output) const {
    auto lhsTy = lhs.getType().cast<RankedTensorType>();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError("Only ranked tensor types supported in TOSA matmul");

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    auto lhsShape = lhsTy.getShape();
    auto rhsShape = rhsTy.getShape();
@@ -1248,11 +1245,8 @@ public:
    // Perform the reshape to output shape. This is always required unless both
    // inputs are rank 3 with matching batch dimensions, in which case the
    // tosa.matmul output itself is already correctly shaped.
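    // For example, a [2, 3, 4] lhs times a [2, 4, 5] rhs is already a valid
    // tosa.matmul whose [2, 3, 5] result needs no final reshape.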
    bool performOpReshape =
        !(lhsRank == 3 && rhsRank == 3 && lhsShape[0] == rhsShape[0]);

    if (performOpReshape) {
      // Since the output shape may be unknown, we construct it
@@ -1358,20 +1352,218 @@ public:
        auto transposedOpType =
            RankedTensorType::get(transposedOpShape, outputElemTy);

        output =
            rewriter
                .create<tosa::TransposeOp>(
                    op->getLoc(),
                    OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(transposedOpType),
                    reshapedOp.getResult(), transposedOpShapeConst.getValue())
                .getResult();
      } else {
        output = reshapedOp.getResult();
      }
    } else {
      output = mmOpResult;
    }

    return success();
  }
  // The default version just reads two inputs, computes output and returns it.
  // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc.
  virtual LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const {
    Value lhs, rhs;

    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
      return op.emitError("Failed to read matmul inputs");

    Value output;

    if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
      return op.emitError("Failed to perform matmul operation");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        op,
        OpConversionPattern<AtenOpT>::getTypeConverter()
            ->convertType(op.getType())
            .template cast<RankedTensorType>(),
        output);

    return success();
  }
};
// Legalizes the torch.matmul op for general n-dim matmul.
template <typename AtenOpT>
class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {
    lhs = adaptor.self();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();

    rhs = adaptor.other();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError("Only ranked tensor types supported in TOSA matmul");

    return success();
  }
};
// Implements handling of aten.mm and aten.bmm ops.
template <typename AtenOpT>
class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {
    lhs = adaptor.self();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();

    rhs = adaptor.mat2();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError("Only ranked tensor types supported in TOSA matmul");

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    if (isa<AtenMmOp>(op)) {
      // Mm takes two 2D tensors.
      if (lhsRank != 2 || rhsRank != 2)
        return op.emitError("aten.mm called but matrix rank != 2");
    } else if (isa<AtenBmmOp>(op)) {
      // Bmm takes two 3D tensors.
      if (lhsRank != 3 || rhsRank != 3)
        return op.emitError("aten.bmm called but matrix rank != 3");
    }

    return success();
  }
};
// Implements handling of aten.linear op.
template <typename AtenOpT>
class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {
    lhs = adaptor.input();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();

    rhs = adaptor.weight();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError("Only ranked tensor types supported in TOSA matmul");

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    if (lhsRank != 2 && lhsRank != 3)
      return op.emitError("aten.Linear called but input rank not 2 or 3");
    if (rhsRank != 2 && rhsRank != 3)
      return op.emitError("aten.Linear called but weight rank not 2 or 3");

    // Protection against crash due to unguarded code in TOSA->LinAlg.
    if (!lhsTy.hasStaticShape() || !rhsTy.hasStaticShape())
      return op.emitError("aten.Linear needs statically shaped input");

    return success();
  }
  // Override the default rewriter to perform RHS transpose and bias addition
  // as well.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs, rhs;

    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
      return op.emitError("Failed to read matmul inputs");

    // The aten.Linear op has a bias tensor that is added to the matmul output.
    auto bias = adaptor.bias();
    auto biasTy = bias.getType();

    // TOSA does not mandate that elementwise op tensors need to be ranked.
    if (!biasTy.template isa<Torch::NoneType>() &&
        !biasTy.template isa<TensorType>())
      return op.emitError("Only tensor types supported in GEMM to "
                          "TOSA for bias tensor");
    // RHS must have its last two dims transposed prior to matrix
    // multiplication.
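    // (aten.linear computes input times weight^T with the weight stored as
    // [out_features, in_features], hence the swap of the last two dims.)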
    auto rhsTy = rhs.getType().cast<RankedTensorType>();
    auto rhsRank = rhsTy.getRank();
    auto rhsShape = rhsTy.getShape();
    auto rhsElemTy = rhsTy.getElementType();

    // Create a non-const shape array to transpose dims.
    SmallVector<int64_t> transposedRhsShape;
    for (auto &shape : rhsShape)
      transposedRhsShape.push_back(shape);

    SmallVector<int32_t> transposedRhsDims;
    for (int32_t i = 0; i < rhsRank; i++)
      transposedRhsDims.push_back(i);

    // Swap the last two dims.
    std::swap(transposedRhsShape[rhsRank - 1], transposedRhsShape[rhsRank - 2]);
    std::swap(transposedRhsDims[rhsRank - 1], transposedRhsDims[rhsRank - 2]);
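    // For a rank-3 weight, for instance, the permutation becomes {0, 2, 1}.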
    llvm::Optional<Value> transposedRhsShapeConst =
        tosa::getConstTensor<int32_t>(
            rewriter, op,
            /*vec=*/transposedRhsDims,
            /*shape=*/{static_cast<int32_t>(transposedRhsDims.size())});

    auto transposedRhsType =
        RankedTensorType::get(transposedRhsShape, rhsElemTy);
    rhs = rewriter.create<tosa::TransposeOp>(
        op->getLoc(),
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            transposedRhsType),
        rhs, transposedRhsShapeConst.getValue());

    Value matmulOutput;
    if (failed(
            this->performMatmul(op, adaptor, rewriter, lhs, rhs, matmulOutput)))
      return op.emitError("Failed to perform matmul operation");
    Value matmulPlusBias = matmulOutput;

    if (!biasTy.template isa<Torch::NoneType>()) {
      // Bias addition broadcasts to the matmul output shape.
      matmulPlusBias =
          rewriter
              .create<tosa::AddOp>(op->getLoc(), matmulOutput.getType(),
                                   matmulOutput, bias)
              .getResult();
    }

    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        op,
        OpConversionPattern<AtenOpT>::getTypeConverter()
            ->convertType(op.getType())
            .template cast<RankedTensorType>(),
        matmulPlusBias);

    return success();
  }
};
@@ -1544,10 +1736,21 @@ public:
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
    INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
#undef INSERT_MATMUL_ATENOP_PATTERN

#define INSERT_MM_ATENOP_PATTERN(AtenOp)                                       \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
    INSERT_MM_ATENOP_PATTERN(AtenMmOp);
    INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
#undef INSERT_MM_ATENOP_PATTERN

#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp)                                   \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
    INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATENOP_PATTERN

#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);