diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index ae98dd5e3..aa0fe6f60 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -678,6 +678,566 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } + +// Perform torch matmul, mm and bmm +template +class ConvertAtenMatMulOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.self(); + auto lhsTy = lhs.getType().cast(); + + // Aten matmul, mm and bmm call operand2 by different names. + Value rhs = adaptor.getOperands()[1]; + auto rhsTy = rhs.getType().cast(); + + if (!lhsTy || !rhsTy) + return op.emitError("Only ranked tensor types supported in TOSA matmul"); + + auto lhsRank = lhsTy.getRank(); + auto rhsRank = rhsTy.getRank(); + + // Mm takes two 2D tensors + if (isa(op)) { + assert(lhsRank == 2 && rhsRank == 2 && + "aten.mm called but matrix rank != 2"); + } + + // Bmm takes two 2D tensors + if (isa(op)) { + assert(lhsRank == 3 && rhsRank == 3 && + "aten.bmm called but matrix rank != 2"); + } + + auto lhsShape = lhsTy.getShape(); + auto rhsShape = rhsTy.getShape(); + + auto lhsElemTy = lhsTy.getElementType(); + auto rhsElemTy = rhsTy.getElementType(); + + if (lhsElemTy != rhsElemTy) + return op.emitError("Matmul: input datatypes mismatched"); + + // Legalization constructs may offer input shapes but expect output shapes + // to be inferred, e.g. + // func @forward(%arg0: !torch.vtensor<[14,19],f32>, + // %arg1: !torch.vtensor<[19,28],f32>) -> + // !torch.vtensor<[?,?],f32> + // This is tricky with matmul, since TOSA matmul is on 3D inputs. + // This means the need to reshape potentially both inputs and outputs, + // and reshape to unknown shape is undefined. + + auto maxInputRank = lhsRank > rhsRank ? lhsRank : rhsRank; + // If performing dot product on vectors, the RHS is synthetically transposed + if (maxInputRank == 1) + maxInputRank++; + + // Obtaining the rank broadcasted shapes of tensors makes it easier to + // construct the input and output reshaping logic. + auto getRankBroadcastedShape = [&](Value tensor, + bool isRHS) -> SmallVector { + auto tensorTy = tensor.getType().cast(); + auto tensorShape = tensorTy.getShape(); + auto tensorRank = tensorTy.getRank(); + + SmallVector bcastedShape; + + auto bcastDims = maxInputRank - tensorRank; + + if (isRHS && (tensorRank == 1) && bcastDims) { + // RHS with rank1 is special. It be synthetically transposed to dim[:-2] + for (int32_t i = 0; i < bcastDims - 1; i++) + bcastedShape.push_back(1); + bcastedShape.push_back(tensorShape[0]); + bcastedShape.push_back(1); + } else { + if (bcastDims > 0) { // rank broadcast + for (uint32_t i = 0; i < bcastDims; i++) + bcastedShape.push_back(1); + } + for (auto &dim : tensorShape) + bcastedShape.push_back(dim); + } + return bcastedShape; + }; + + // Step: Rank broadcast the two inputs. + auto lhsBroadcastedShape = getRankBroadcastedShape(lhs, false); + auto lhsBroadcastedTy = + RankedTensorType::get(lhsBroadcastedShape, lhsElemTy); + auto rhsBroadcastedShape = getRankBroadcastedShape(rhs, true); + auto rhsBroadcastedTy = + RankedTensorType::get(rhsBroadcastedShape, rhsElemTy); + + auto rankBroadcastedLhs = + lhsRank == maxInputRank + ? lhs + : rewriter.create( + op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + lhsBroadcastedTy), + lhs, rewriter.getI64ArrayAttr(lhsBroadcastedShape)); + + auto rankBroadcastedRhs = + rhsRank == maxInputRank + ? rhs + : rewriter.create( + op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + rhsBroadcastedTy), + rhs, rewriter.getI64ArrayAttr(rhsBroadcastedShape)); + + // TOSA matmul is performed on two 3D inputs and generates a 3D output. + // Lower ranked tensors are dim-1 reshaped up to 3D + auto reshapeUpTo3DTensor = [&](Value tensor) -> Value { + auto tensorTy = tensor.getType().cast(); + auto rank = tensorTy.getRank(); + + assert(rank <= 3 && "reshapeUpTo3D tensor must receive rank <= 3"); + if (rank == 3) + return tensor; + + auto shape = tensorTy.getShape(); + SmallVector newShape({1, 1, 1}); + + if (rank == 2) { // batchsize = 1 + newShape[1] = shape[0]; + newShape[2] = shape[1]; + } else { // rank 1 + newShape[2] = shape[0]; + } + auto newType = RankedTensorType::get(newShape, tensorTy.getElementType()); + + return rewriter.create( + op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + newType), + tensor, rewriter.getI64ArrayAttr(newShape)); + }; + + // Where broadcasting is required in one or more batch dims, the following + // is done. + // Where all batch dims are involved in broadcasting: + // Given A: 3x1x5x6 and B: 1x4x6x7 + // 1. Reshape A to 1x15x6 (squeeze all batchdims into dim1) + // 2. Transpose B to 6x1x4x7, Reshape to 1x6x28 + // 3. tosa.Matmul 1x15x6 1x6x28 = 1x15x28 + // 4. Reshape out to 3x5x4x7, Transpose to 3x4x5x7 + // Where there are batch dimensions that are broadcast and not, the + // treatment is to have dim0 correspond to product of all non-broadcast + // dimsizes: + // Given A: 4x8x16x32 B: 8x32x17 + // 1. Reshape A to 8x64x32 (squeeze all unbroadcasted dims into dim0, + // broadcasted dims into dim1) + // 2. No transpose or reshape of B as its batchdims are not broadcast to. + // 3. tosa.Matmul 8x64x32 8x32x17 = 8x64x17 + // 4. Reshape to 8x4x16x17, Transpose to 4x8x16x17 + + // Check if we need to perform the broadcast on batch dim + // Not needed if max rank < 3, or if maxrank == 3 and dim[0] matches + auto needsBatchDimBroadcast = [&]() -> bool { + if (maxInputRank < 3) { + return false; + } else { + if (maxInputRank == 3 && + lhsBroadcastedShape[0] == rhsBroadcastedShape[0]) { + return false; + } + return true; + } + }; + + auto performBatchDimBroadcast = needsBatchDimBroadcast(); + + // Inputs to the tosa.matmul + Value matmulLhs, matmulRhs; + + using TensorShape_t = struct { + int64_t dim; + int64_t shape; + }; + + // Transpose needs to done if transposeDims are not non-monotonically + // increasing. E.g. [0, 1, 2, 3]: No transpose [1, 0, 2, 3]: Transpose dim0 + // and dim1 The order need not be sequential, since one or more dims may + // have been removed due to broadcasting. + auto isTransposeRequired = [](SmallVector transposedDims) -> bool { + int32_t lastDim = -1; + for (auto &dim : transposedDims) { + if (lastDim > dim) + return true; + lastDim = dim; + } + return false; + }; + + SmallVector commonElems, lhsSqueezedElems, rhsSqueezedElems; + + if (!performBatchDimBroadcast) { + // Simple with no broadcasting artifacts. Just reshape up to 3D + matmulLhs = reshapeUpTo3DTensor(rankBroadcastedLhs); + matmulRhs = reshapeUpTo3DTensor(rankBroadcastedRhs); + + } else { + // In this case, either or both input matrices involve broadcasting on + // their batch dimensions. For example: + // 4x5x6, 1x6x7 -> 4x5x7 + // 4x1x5x6, 1x3x6x7 -> 4x3x5x7 + // Though maxInputRank is necessarily >=3 here, individual matrices may be + // lower rank. + // E.g. 3x4x5x6, 6 -> 3x4x5 + + // These are the accumulated products of the shape of each dim: + // 1. common dimensions: upper dimensions (dims other than two rightmost) + // whose shapes are the same for both LHS and RHS. + // 2. LHS squeezed dimensions: all dimensions of LHS that involve + // broadcasting in either direction, plus the LHS[-2] shape + // 3. RHS squeezed dimensions: all dimensions of RHS that involve + // broadcasting in either direction, plus the RHS[-1] shape + int64_t commonValue = 1, lhsSqueezedValue = 1, rhsSqueezedValue = 1; + + // For both LHS and RHS, the dimensions are separated into the common, + // squeezed and remaining dim. E.g. given + // LHS = 3x4x5x6 + // RHS = 1x4x6x7 + // common = {{dim=1, shape=4}} + // lhs squeezed = {{dim=0, shape=3}, + // {dim=2, shape=5}} + // rhs squeezed = {{dim=0, shape=1}, + // {dim=2, shape=7}} + // The matmul dim is LHS[-1] and RHS[-2], i.e. 6. + // Once this is obtained, LHS and RHS are expressed as: + // LHS = {common, lhs_squeezed, matmul_dim} + // RHS = {common, matmul_dim, rhs_squeezed} + // The matmul is then performed to obtain output: + // matmul_out = {common, lhs_squeezed, rhs_squeezed} + // Finally, we reshape to 'unsqueeze' the LHS and RHS parts and transpose + // them back to their correct positions. + + SmallVector transposedLhsShape; + SmallVector transposedLhsDims; + + // Step: generate the common dim/shape information + for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) { + bool isDynamicDim = + lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]); + if (isDynamicDim || + lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) { + commonValue *= lhsBroadcastedShape[dim]; + commonElems.push_back({dim, lhsBroadcastedShape[dim]}); + } + } + + // Step: generate the LHS squeezed dim/shape information. + bool hasDynamicDims = false; + for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) { + bool isDynamicDim = + lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]); + hasDynamicDims |= isDynamicDim; + if (!isDynamicDim && + lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) { + lhsSqueezedValue *= lhsBroadcastedShape[dim]; + lhsSqueezedElems.push_back({dim, lhsBroadcastedShape[dim]}); + } + } + // including LHS[-2] + lhsSqueezedElems.push_back( + {maxInputRank - 2, lhsBroadcastedShape[maxInputRank - 2]}); + lhsSqueezedValue *= lhsBroadcastedShape[maxInputRank - 2]; + + // Step: Create the tosa.transpose array. If this array has a + // non-monotonic series of dims, perform transpose. + // First the common_elems + for (uint32_t i = 0; i < commonElems.size(); i++) { + transposedLhsShape.push_back(commonElems[i].shape); + transposedLhsDims.push_back(commonElems[i].dim); + } + // then the lhs_squeezed elems + for (uint32_t i = 0; i < lhsSqueezedElems.size(); i++) { + transposedLhsShape.push_back(lhsSqueezedElems[i].shape); + transposedLhsDims.push_back(lhsSqueezedElems[i].dim); + } + // then the final dim + transposedLhsDims.push_back(maxInputRank - 1); + transposedLhsShape.push_back(lhsBroadcastedShape[maxInputRank - 1]); + + bool lhsNeedsTranspose = isTransposeRequired(transposedLhsDims); + + auto lhsReshapeInput = rankBroadcastedLhs; + + if (lhsNeedsTranspose) { + auto transposedLhsType = + RankedTensorType::get(transposedLhsShape, rhsElemTy); + + llvm::Optional transposedLhsDimsConst = + tosa::getConstTensor( + rewriter, op, + /*vec=*/transposedLhsDims, + /*shape=*/{static_cast(transposedLhsDims.size())}); + + lhsReshapeInput = + rewriter + .create( + op->getLoc(), + OpConversionPattern::getTypeConverter() + ->convertType(transposedLhsType), + rankBroadcastedLhs, transposedLhsDimsConst.getValue()) + .getResult(); + } + + // LHS = {common, lhs_squeezed, matmul_dim} + SmallVector newLhsShape( + {1, 1, lhsBroadcastedShape[maxInputRank - 1]}); + newLhsShape[0] = commonValue; + newLhsShape[1] = + hasDynamicDims ? ShapedType::kDynamicSize : lhsSqueezedValue; + + auto newLhsType = RankedTensorType::get(newLhsShape, lhsElemTy); + + matmulLhs = rewriter.create( + op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + newLhsType), + lhsReshapeInput, rewriter.getI64ArrayAttr(newLhsShape)); + + SmallVector transposedRhsShape; + SmallVector transposedRhsDims; + + // Step: Create the RHS transpose sequence + // RHS = {common, matmul_dim, rhs_squeezed} + // first the common_dims + for (uint32_t i = 0; i < commonElems.size(); i++) { + transposedRhsShape.push_back(commonElems[i].shape); + transposedRhsDims.push_back(commonElems[i].dim); + } + // The matmul_dim of RHS + transposedRhsDims.push_back(maxInputRank - 2); + transposedRhsShape.push_back(rhsBroadcastedShape[maxInputRank - 2]); + // finally all the rhs_squeeze dims + hasDynamicDims = false; + for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) { + bool isDynamicDim = + rhsBroadcastedTy.isDynamic(rhsBroadcastedShape[dim]); + hasDynamicDims |= isDynamicDim; + if (!isDynamicDim && + rhsBroadcastedShape[dim] != lhsBroadcastedShape[dim]) { + rhsSqueezedElems.push_back({dim, rhsBroadcastedShape[dim]}); + rhsSqueezedValue *= rhsBroadcastedShape[dim]; + } + } + rhsSqueezedElems.push_back( + {maxInputRank - 1, rhsBroadcastedShape[maxInputRank - 1]}); + rhsSqueezedValue *= rhsBroadcastedShape[maxInputRank - 1]; + for (uint32_t i = 0; i < rhsSqueezedElems.size(); i++) { + transposedRhsShape.push_back(rhsSqueezedElems[i].shape); + transposedRhsDims.push_back(rhsSqueezedElems[i].dim); + } + + auto transposedRhsType = + RankedTensorType::get(transposedRhsShape, rhsElemTy); + + if (hasDynamicDims) + rhsSqueezedValue = ShapedType::kDynamicSize; + + SmallVector newRhsShape({commonValue, + rhsBroadcastedShape[maxInputRank - 2], + rhsSqueezedValue}); + auto newRhsType = RankedTensorType::get(newRhsShape, rhsElemTy); + + bool rhsNeedsTranspose = isTransposeRequired(transposedRhsDims); + + auto transposedRhsValue = rankBroadcastedRhs; + + if (rhsNeedsTranspose) { + llvm::Optional transposedRhsDimsConst = + tosa::getConstTensor( + rewriter, op, + /*vec=*/transposedRhsDims, + /*shape=*/{static_cast(transposedRhsDims.size())}); + + transposedRhsValue = + rewriter + .create( + op->getLoc(), + OpConversionPattern::getTypeConverter() + ->convertType(transposedRhsType), + rankBroadcastedRhs, transposedRhsDimsConst.getValue()) + .getResult(); + } + + // reshape + matmulRhs = rewriter.create( + op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + newRhsType), + transposedRhsValue, rewriter.getI64ArrayAttr(newRhsShape)); + } + + auto matmulLhsShape = + matmulLhs.getType().template cast().getShape(); + auto matmulRhsShape = + matmulRhs.getType().template cast().getShape(); + + // The reshape/transpose should ensure the tosa.matmul always has same + // batch size for either matrix. If if shapes are dynamic, they'll be + // appropriately handled. + assert(matmulLhsShape[0] == matmulRhsShape[0] && + "tosa.matmul needs same batchsize on LHS and RHS"); + + SmallVector matmulOutputShape( + {matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]}); + Type outputElemTy; + if (lhsElemTy.isa()) { + outputElemTy = lhsElemTy; + } else { // qint8 emits i32 matmul output + outputElemTy = rewriter.getIntegerType(32); + } + + auto mmOutputTy = RankedTensorType::get(matmulOutputShape, outputElemTy); + auto mmOpResult = + rewriter + .create( + op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + mmOutputTy), + matmulLhs, matmulRhs) + .getResult(); + + // Perform the reshape to output shape. This is always required unless both + // inputs are rank=3, in which case the tosa.matmul output itself is + // correctly shaped. + bool performOpReshape = !(lhsRank == 3 && rhsRank == 3); + + auto outputTy = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + if (performOpReshape) { + // Since the output shape may be unknown, we construct it + // independently and reshape. Otherwise reshape may be expressed for + // an unknown to-be-inferred output shape. The final tensor.cast + // reshapes the known shape to the desired output shape. + auto computeOpShape = [&](SmallVector &reshapedOpShape, + SmallVector &transposedOpDims, + SmallVector &transposedOpShapes) { + if (maxInputRank == 1) + return; + + if (maxInputRank == 2) { + if (lhsRank == 2) + reshapedOpShape.push_back(lhsShape[0]); + if (rhsRank == 2) + reshapedOpShape.push_back(rhsShape[1]); + return; + } + + // Step: Construct the output transpose/reshape information + // First the common_dims + for (uint32_t i = 0; i < commonElems.size(); i++) { + reshapedOpShape.push_back(commonElems[i].shape); + transposedOpDims.push_back(commonElems[i].dim); + } + + // Then the LHS squeezed dims + for (uint32_t i = 0; i < lhsSqueezedElems.size() - 1; i++) { + // Only dims that don't broadcast - broadcasting ones come from the + // other input. + if (lhsSqueezedElems[i].shape != 1) { + reshapedOpShape.push_back(lhsSqueezedElems[i].shape); + transposedOpDims.push_back(lhsSqueezedElems[i].dim); + } + } + // The last squeezed dim is lhs[-2] which needs to be + // checked separately for broadcasting + if (lhsRank > 1) { + reshapedOpShape.push_back(lhsBroadcastedShape[maxInputRank - 2]); + transposedOpDims.push_back(maxInputRank - 2); + } + + // then the RHS squeezed dims except rhs[-1] which is handled like + // lhs[-2] + for (uint32_t i = 0; i < rhsSqueezedElems.size() - 1; i++) { + if (rhsSqueezedElems[i].shape != 1) { + reshapedOpShape.push_back(rhsSqueezedElems[i].shape); + transposedOpDims.push_back(rhsSqueezedElems[i].dim); + } + } + // rhs[-1] + if (rhsRank > 1) { + reshapedOpShape.push_back(rhsBroadcastedShape[maxInputRank - 1]); + transposedOpDims.push_back(maxInputRank - 1); + } + + // Final transposed output shape construction + for (uint32_t i = 0; i < maxInputRank - 2; i++) { + if (lhsBroadcastedTy.isDynamicDim(i)) { + transposedOpShapes.push_back(ShapedType::kDynamicSize); + } else { + if (lhsBroadcastedShape[i] == rhsBroadcastedShape[i]) { + transposedOpShapes.push_back(lhsBroadcastedShape[i]); + } else { + transposedOpShapes.push_back(lhsBroadcastedShape[i] == 1 + ? rhsBroadcastedShape[i] + : lhsBroadcastedShape[i]); + } + } + } + if (lhsRank > 1) + transposedOpShapes.push_back(lhsBroadcastedShape[maxInputRank - 2]); + if (rhsRank > 1) + transposedOpShapes.push_back(rhsBroadcastedShape[maxInputRank - 1]); + + return; + }; + + SmallVector reshapedOpShape, transposedOpShape; + SmallVector transposedOpDims; + + computeOpShape(reshapedOpShape, transposedOpDims, transposedOpShape); + + bool opNeedsTranspose = isTransposeRequired(transposedOpDims); + + // Perform reshape + auto reshapedOpType = + RankedTensorType::get(reshapedOpShape, outputElemTy); + auto reshapedOp = rewriter.create( + op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + reshapedOpType), + mmOpResult, rewriter.getI64ArrayAttr(reshapedOpShape)); + + if (opNeedsTranspose) { + + llvm::Optional transposedOpShapeConst = + tosa::getConstTensor( + rewriter, op, + /*vec=*/transposedOpDims, + /*shape=*/{static_cast(transposedOpDims.size())}); + + auto transposedOpType = + RankedTensorType::get(transposedOpShape, outputElemTy); + auto transposedOp = rewriter.create( + op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + transposedOpType), + reshapedOp.getResult(), transposedOpShapeConst.getValue()); + + rewriter.replaceOpWithNewOp(op, outputTy, transposedOp); + } else { + rewriter.replaceOpWithNewOp(op, outputTy, reshapedOp); + } + } else { + rewriter.replaceOpWithNewOp(op, outputTy, mmOpResult); + } + + return success(); + } +}; + } // namespace // ----------------------------------------------------------------------------- @@ -774,6 +1334,14 @@ public: INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) #undef INSERT_SQUEEZE_OP_PATTERN +#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); + INSERT_MATMUL_ATENOP_PATTERN(AtenMmOp); + INSERT_MATMUL_ATENOP_PATTERN(AtenBmmOp); +#undef INSERT_MATMUL_ATEMOP_PATTERN + #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context);