From cadea678e5e2cfc62a1625af6e50cac34144e710 Mon Sep 17 00:00:00 2001
From: Suraj Sudhir <16977902+sjarus@users.noreply.github.com>
Date: Tue, 25 Jan 2022 08:48:58 -0800
Subject: [PATCH] [tosa] Implement torch.linear support. (#535)

Refactor matmul into a separate base class and derive variants from it:
- matmul
- mm, bmm
- linear

A sketch of how further variants can derive from the base class follows
the patch.

Signed-off-by: Suraj Sudhir
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 281 ++++++++++++++++++---
 1 file changed, 242 insertions(+), 39 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 17ce945e9..4c1f35cc2 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -817,40 +817,37 @@ LogicalResult ConvertAtenOp<AtenOpT>::matchAndRewrite(
   return success();
 }
 
-// Perform torch matmul, mm and bmm
+// Perform the basic n-dim matmul operation encompassing the handling of
+// broadcasting and dynamic shape propagation.
+// All PyTorch ops that leverage matrix multiplication derive from this class
+// and implement their specialized input processing (e.g. transpose) and
+// output processing (e.g. GEMM or fully-connected bias handling).
 template <typename AtenOpT>
-class ConvertAtenMatMulOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
 public:
   using OpConversionPattern<AtenOpT>::OpConversionPattern;
   using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value lhs = adaptor.self();
+  // Each variant must implement its own input parsing because it is not
+  // necessarily true for all variants that the first two operands supply
+  // the lhs and rhs; keeping a separate read function per variant lets each
+  // op map its operands onto the generic matmul slots.
+  virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                         ConversionPatternRewriter &rewriter,
+                                         Value &lhs, Value &rhs) const {
+    return rewriter.notifyMatchFailure(
+        op,
+        "Unimplemented matrix multiplication variant input parsing function");
+  }
+  LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter, Value &lhs,
+                              Value &rhs, Value &output) const {
+
     auto lhsTy = lhs.getType().cast<RankedTensorType>();
-
-    // Aten matmul, mm and bmm call operand2 by different names.
-    Value rhs = adaptor.getOperands()[1];
     auto rhsTy = rhs.getType().cast<RankedTensorType>();
 
-    if (!lhsTy || !rhsTy)
-      return op.emitError("Only ranked tensor types supported in TOSA matmul");
-
     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
 
-    // Mm takes two 2D tensors
-    if (isa<AtenMmOp>(op)) {
-      assert(lhsRank == 2 && rhsRank == 2 &&
-             "aten.mm called but matrix rank != 2");
-    }
-
-    // Bmm takes two 2D tensors
-    if (isa<AtenBmmOp>(op)) {
-      assert(lhsRank == 3 && rhsRank == 3 &&
-             "aten.bmm called but matrix rank != 2");
-    }
-
     auto lhsShape = lhsTy.getShape();
     auto rhsShape = rhsTy.getShape();
 
@@ -1248,11 +1245,8 @@ public:
     // Perform the reshape to output shape. This is always required unless both
     // inputs are rank=3, in which case the tosa.matmul output itself is
    // correctly shaped.
-    bool performOpReshape = !(lhsRank == 3 && rhsRank == 3);
-
-    auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
-                        ->convertType(op.getType())
-                        .template cast<RankedTensorType>();
+    bool performOpReshape =
+        !(lhsRank == 3 && rhsRank == 3 && lhsShape[0] == rhsShape[0]);
 
     if (performOpReshape) {
       // Since the output shape may be unknown, we construct it
@@ -1358,20 +1352,218 @@ public:
         auto transposedOpType =
             RankedTensorType::get(transposedOpShape, outputElemTy);
 
-        auto transposedOp = rewriter.create<tosa::TransposeOp>(
-            op->getLoc(),
-            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-                transposedOpType),
-            reshapedOp.getResult(), transposedOpShapeConst.getValue());
+        output =
+            rewriter
+                .create<tosa::TransposeOp>(
+                    op->getLoc(),
+                    OpConversionPattern<AtenOpT>::getTypeConverter()
+                        ->convertType(transposedOpType),
+                    reshapedOp.getResult(), transposedOpShapeConst.getValue())
+                .getResult();
 
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, transposedOp);
       } else {
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, reshapedOp);
+        output = reshapedOp.getResult();
       }
     } else {
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, mmOpResult);
+      output = mmOpResult;
     }
+
+    return success();
+  }
+
+  // The default version just reads the two inputs, computes the output and
+  // returns it. Other variants may add a bias, apply GEMM-style alpha/beta
+  // scaling, etc.
+  virtual LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Value lhs, rhs;
+
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("Failed to read matmul inputs");
+
+    Value output;
+
+    if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
+      return op.emitError("Failed to perform matmul operation");
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()
+            ->convertType(op.getType())
+            .template cast<RankedTensorType>(),
+        output);
+
+    return success();
+  }
+};
+
+// Legalizes the torch.matmul op for general n-dim matmul.
+template <typename AtenOpT>
+class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
+
+    rhs = adaptor.other();
+    auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    return success();
+  }
+};
+
+// Implements handling of the aten.mm and aten.bmm ops.
+template <typename AtenOpT>
+class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
+
+    rhs = adaptor.mat2();
+    auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (isa<AtenMmOp>(op)) {
+      // Mm takes two 2D tensors.
+      if (lhsRank != 2 || rhsRank != 2)
+        return op.emitError("aten.mm called but matrix rank != 2");
+    } else if (isa<AtenBmmOp>(op)) {
+      // Bmm takes two 3D tensors.
+      if (lhsRank != 3 || rhsRank != 3)
+        return op.emitError("aten.bmm called but matrix rank != 3");
+    }
+
+    return success();
+  }
+};
+
+// Implements handling of the aten.linear op.
+template <typename AtenOpT>
+class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+
+    lhs = adaptor.input();
+    auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
+
+    rhs = adaptor.weight();
+    auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (lhsRank != 2 && lhsRank != 3)
+      return op.emitError("aten.Linear called but input rank not 2 or 3");
+    if (rhsRank != 2 && rhsRank != 3)
+      return op.emitError("aten.Linear called but weight rank not 2 or 3");
+
+    // Protection against crash due to unguarded code in TOSA->LinAlg.
+    if (!lhsTy.hasStaticShape() || !rhsTy.hasStaticShape())
+      return op.emitError("aten.Linear needs statically shaped input");
+
+    return success();
+  }
+  // Override the default rewriter to perform the RHS transpose and the bias
+  // addition as well.
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Value lhs, rhs;
+
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("Failed to read matmul inputs");
+
+    // The aten.Linear op has a bias tensor that is added to the matmul
+    // output.
+    auto bias = adaptor.bias();
+    auto biasTy = bias.getType();
+
+    // TOSA does not mandate that elementwise op tensors need to be ranked.
+    if (!biasTy.template isa<Torch::NoneType>() &&
+        !biasTy.template isa<TensorType>())
+      return op.emitError("Only tensor types supported in GEMM to "
+                          "TOSA for bias tensor");
+
+    // RHS must have its last two dims transposed prior to matrix
+    // multiplication.
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto rhsRank = rhsTy.getRank();
+    auto rhsShape = rhsTy.getShape();
+    auto rhsElemTy = rhsTy.getElementType();
+
+    // Create a non-const shape array to transpose dims.
+    SmallVector<int64_t> transposedRhsShape;
+    for (auto &shape : rhsShape)
+      transposedRhsShape.push_back(shape);
+    SmallVector<int32_t> transposedRhsDims;
+    for (int32_t i = 0; i < rhsRank; i++)
+      transposedRhsDims.push_back(i);
+
+    // Swap the last two dims.
+    std::swap(transposedRhsShape[rhsRank - 1], transposedRhsShape[rhsRank - 2]);
+    std::swap(transposedRhsDims[rhsRank - 1], transposedRhsDims[rhsRank - 2]);
+
+    llvm::Optional<Value> transposedRhsShapeConst =
+        tosa::getConstTensor<int32_t>(
+            rewriter, op,
+            /*vec=*/transposedRhsDims,
+            /*shape=*/{static_cast<int32_t>(transposedRhsDims.size())});
+
+    auto transposedRhsType =
+        RankedTensorType::get(transposedRhsShape, rhsElemTy);
+    rhs = rewriter.create<tosa::TransposeOp>(
+        op->getLoc(),
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            transposedRhsType),
+        rhs, transposedRhsShapeConst.getValue());
+
+    Value matmulOutput;
+    if (failed(
+            this->performMatmul(op, adaptor, rewriter, lhs, rhs, matmulOutput)))
+      return op.emitError("Failed to perform matmul operation");
+
+    Value matmulPlusBias = matmulOutput;
+    if (!biasTy.template isa<Torch::NoneType>()) {
+      // Bias addition broadcasts to the matmul output shape.
+      matmulPlusBias =
+          rewriter
+              .create<tosa::AddOp>(op->getLoc(), matmulOutput.getType(),
+                                   matmulOutput, bias)
+              .getResult();
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()
+            ->convertType(op.getType())
+            .template cast<RankedTensorType>(),
+        matmulPlusBias);
+
     return success();
   }
 };
@@ -1544,10 +1736,21 @@ public:
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
     INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
-    INSERT_MATMUL_ATENOP_PATTERN(AtenMmOp);
-    INSERT_MATMUL_ATENOP_PATTERN(AtenBmmOp);
 #undef INSERT_MATMUL_ATEMOP_PATTERN
 
+#define INSERT_MM_ATENOP_PATTERN(AtenOp)                                       \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
+    INSERT_MM_ATENOP_PATTERN(AtenMmOp);
+    INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
+#undef INSERT_MM_ATENOP_PATTERN
+
+#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp)                                   \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
+    INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
+#undef INSERT_LINEAR_ATENOP_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
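
Note: the sketch below is illustrative only and not part of the patch. It
shows how a further matmul-backed variant could derive from the new
ConvertAtenMatmulBaseOp. "AtenFooOp" and its adaptor accessors a()/b() are
hypothetical placeholders; only readMatMulInputs needs to be supplied,
since broadcasting, reshapes and dynamic-shape propagation live in the base
class.

// Hypothetical variant (placeholder names): maps this op's operands onto
// the generic lhs/rhs slots and leaves everything else to
// ConvertAtenMatmulBaseOp. Assumes the includes and namespaces of
// TorchToTosa.cpp above.
template <typename AtenOpT>
class ConvertAtenFooOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {
    lhs = adaptor.a(); // placeholder accessor for the lhs operand
    rhs = adaptor.b(); // placeholder accessor for the rhs operand
    if (!lhs.getType().dyn_cast<RankedTensorType>() ||
        !rhs.getType().dyn_cast<RankedTensorType>())
      return op.emitError("Only ranked tensor types supported in TOSA matmul");
    return success();
  }
  // Variants needing extra processing (transposes, bias, scaling) also
  // override matchAndRewrite, as ConvertAtenLinearOp does above.
};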
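Registering such a variant would mirror the INSERT_MM_ATENOP_PATTERN and
INSERT_LINEAR_ATENOP_PATTERN macros added by this patch; again, AtenFooOp
is a placeholder, not a real op:

// Hypothetical registration, following the patch's macro pattern.
#define INSERT_FOO_ATENOP_PATTERN(AtenOp)                                      \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenFooOp<AtenOp>>(typeConverter, context);
    INSERT_FOO_ATENOP_PATTERN(AtenFooOp);
#undef INSERT_FOO_ATENOP_PATTERN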