//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { // These legalizations are for unary ops with only for floating point datatypes. // There is no supported quantized integer mode for these. template class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.self(); auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("Only Tensor types supported in TOSA"); if (selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), self); return success(); } else { return op.emitError( "Only floating-point datatype legalization supported"); } } }; // These unary op legalizations are identical for floating-point // or quantized types template class ConvertAtenUnaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), adaptor.self()); return success(); } }; // These binary op legalizations are identical for floating-point // or quantized types template class ConvertAtenBinaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.self(); auto lhsTy = lhs.getType().cast(); Value rhs = adaptor.other(); auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) return op.emitError("Only Tensor types supported in TOSA"); auto lhsElemTy = lhsTy.getElementType(); auto rhsElemTy = rhsTy.getElementType(); if (lhsElemTy != rhsElemTy) return op.emitError("Add: input datatypes mismatched"); rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), lhs, rhs); return success(); } }; // These binary op legalizations are specific to add/sub which have an // alpha multiplier. template class ConvertAtenAddSubOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.self(); auto lhsTy = lhs.getType().cast(); Value rhs = adaptor.other(); auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) return op.emitError("Only Tensor types supported in TOSA"); auto lhsElemTy = lhsTy.getElementType(); auto rhsElemTy = rhsTy.getElementType(); if (lhsElemTy != rhsElemTy) return op.emitError("Add: input datatypes mismatched"); // FIXME: Handle alpha. // Needs extraction of floating point constant. if (lhsElemTy.isa()) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), lhs, rhs); return success(); } else { return op.emitError( "Only floating-point datatype legalization supported"); } } }; // This defines a template to construct ops whose legalizations are // specialized. template class ConvertAtenOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenTanhOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.self(); auto selfTy = self.getType().cast(); if (selfTy && selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } else { // Sigmoid legalization in TOSA for quantized element-type uses // specialized tosa.table construct. return op.emitError( "Only floating-point datatype legalization currently supported"); } } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSigmoidOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.self(); auto selfTy = self.getType().cast(); if (selfTy && selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } else { // Sigmoid legalization in TOSA for quantized element-type uses // specialized tosa.table construct. return op.emitError( "Only floating-point datatype legalization currently supported"); } } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.self(); auto selfTy = self.getType().cast(); // Maps to tosa.clamp which has both int and fp limits. int64_t clampMin = 0; Value clampIn = self; if (selfTy) { // Rescale the clampIn for quantized types. TBD if (!selfTy.getElementType().isa()) { return op.emitError( "Only floating-point datatype legalization currently supported"); } rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), clampIn, rewriter.getI64IntegerAttr(clampMin), rewriter.getI64IntegerAttr(std::numeric_limits::max()), rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(std::numeric_limits::max())); return success(); } else { return op.emitError("Only Tensor types supported in TOSA"); } } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenMulTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.self(); auto lhsTy = lhs.getType().cast(); Value rhs = adaptor.other(); auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) return op.emitError("Only Tensor types supported in TOSA"); auto lhsElemTy = lhsTy.getElementType(); auto rhsElemTy = rhsTy.getElementType(); if (lhsElemTy != rhsElemTy) return op.emitError("Add: input datatypes mismatched"); if (lhsElemTy.isa()) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), lhs, rhs, /*shift=*/0); return success(); } else { // Quantized multiplication may need to rescale inputs. return op.emitError( "Only floating-point datatype legalization currently supported"); } } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenDivTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.self(); auto lhsTy = lhs.getType().cast(); Value rhs = adaptor.other(); auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) return op.emitError("Only Tensor types supported in TOSA"); auto lhsElemTy = lhsTy.getElementType(); auto rhsElemTy = rhsTy.getElementType(); if (lhsElemTy != rhsElemTy) return op.emitError("Add: input datatypes mismatched"); if (lhsElemTy.isa()) { auto rcpOp = rewriter.create( op->getLoc(), getTypeConverter()->convertType(op.getType()), rhs); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), lhs, rcpOp.getResult(), /*shift=*/0); } else { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), lhs, rhs); } return success(); } using ReductionConvFunc = llvm::Optional (*)(PatternRewriter &, Operation *, RankedTensorType, Value, ElementsAttr, bool); // They all constitute a common form invoking the appropriate // converion function in TosaLegalizeCommon.cpp template class ConvertAtenReductionOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; // Each variant must implement corresponding parameter parsing options virtual LogicalResult readReduceDimsAndKeepDims( AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, ElementsAttr &reduceDimsAttr, bool &keepDims) const { return rewriter.notifyMatchFailure( op, "Unimplemented reduce_dims and keep_dims parsing function"); } // Common rewriter for all reduction ops, calls the specific implementation of // readReduceDimsAndKeepDims() needed for the op variant. LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.self(); auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("Only Tensor types supported in TOSA"); auto outputTy = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); if (!outputTy) return op.emitError( "Only ranked tensor type outputs permitted for reduce_mean"); ElementsAttr reduceDimsAttr; bool keepDims; if (failed(readReduceDimsAndKeepDims(op, adaptor, rewriter, reduceDimsAttr, keepDims))) return failure(); llvm::Optional result = ConversionFuncT(rewriter, op, outputTy, self, reduceDimsAttr, keepDims); if (!result) return failure(); // TBD - support dtype casting. rewriter.replaceOp(op, {result.getValue()}); return success(); } }; // This reduction op legalization template handles op variants that have // explicit reduce_dims dimensions (provided as a list) and keep_dims // parameters. template class ConvertAtenMultipleDimsReductionOp : public ConvertAtenReductionOp { using ConvertAtenReductionOp::ConvertAtenReductionOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, ElementsAttr &reduceDimsAttr, bool &keepDims) const override { SmallVector reduceDims; if (!matchPattern(op.dim(), m_TorchConstantIntList(reduceDims))) return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); int64_t N = reduceDims.size(); auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type()); reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, llvm::makeArrayRef(reduceDims)); keepDims = false; if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims))) return rewriter.notifyMatchFailure( op, "non-const keepdim parameter unsupported"); return success(); } }; // This reduction op legalization template handles op variants that reduce in // only one explicit dim which is provided as a number (rather than a list), and // a keep_dims parameter. template class ConvertAtenOneDimReductionOp : public ConvertAtenReductionOp { using ConvertAtenReductionOp::ConvertAtenReductionOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, ElementsAttr &reduceDimsAttr, bool &keepDims) const override { int64_t reduceDim; if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim))) return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type()); reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, llvm::makeArrayRef({reduceDim})); keepDims = false; if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims))) return rewriter.notifyMatchFailure( op, "non-const keepdim parameter unsupported"); return success(); } }; // This reduction op legalization template handles op variants that reduce all // dims does not keep dims. template class ConvertAtenAllDimsReductionOp : public ConvertAtenReductionOp { public: using ConvertAtenReductionOp::ConvertAtenReductionOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, ElementsAttr &reduceDimsAttr, bool &keepDims) const override { auto self = adaptor.self(); auto selfTy = self.getType().template cast(); // Select all dims to reduce SmallVector reduceDims; for (int64_t i = 0; i < selfTy.getRank(); i++) reduceDims.push_back(i); int64_t N = selfTy.getRank(); auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type()); reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, llvm::makeArrayRef(reduceDims)); keepDims = false; return success(); } }; template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenArgmaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.self(); auto selfTy = self.getType().template cast(); if (!selfTy) return op.emitError("Only ranked tensor types supported in TOSA argmax"); int64_t reduceDim; if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim))) { // NoneType indicates reduce on all dims reduceDim = -1; } bool keepDim = false; if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) return rewriter.notifyMatchFailure( op, "non-const keepdim parameter unsupported"); auto resultTy = getTypeConverter() ->convertType(op.getResult().getType()) .cast(); auto outputETy = resultTy.getElementType(); // Create a single instance of tosa.argmax. // Multiple dims require chained construct. auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value { auto inputTy = input.getType().cast(); auto inputShape = inputTy.getShape(); SmallVector outputShapeArr = {}; int32_t i = 0; for (auto &dim : inputShape) { if (i++ != reduceDim) { outputShapeArr.push_back(dim); } else { if (keepDim) outputShapeArr.push_back(1); } } // Tosa argmax output is i32, while Torch backend mandates i64. auto outputReduceTy = RankedTensorType::get( ArrayRef(outputShapeArr), rewriter.getI32Type()); auto reduceDimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim); return rewriter .create(op->getLoc(), getTypeConverter()->convertType(outputReduceTy), input, reduceDimAttr) .getResult(); }; // Convert the final index to i64 for backend finalization, However, i64 // is not a defined type for tosa.cast, so using arith.extsi instead. auto castToInt64 = [&](Value result) -> LogicalResult { auto resTy = result.getType().cast(); if (!resTy) return op.emitError("Argmax: Result is not a shaped type"); auto resShape = resTy.getShape(); auto outTy = RankedTensorType::get(resShape, outputETy); // rewriter.getI64Type()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(outTy), result); return success(); }; if (reduceDim == -1) { // reducing on all dims Value input = self; for (int dim = 0; dim < selfTy.getRank(); dim++) { // progressively reduce each 0-th dim input = buildArgmax(0, input); } return castToInt64(input); } else { return castToInt64(buildArgmax(reduceDim, self)); } return success(); } template class ConvertAtenSqueezeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; // Each variant must implement corresponding parameter parsing options virtual LogicalResult generateSqueezedShape(AtenOpT op, RankedTensorType selfTy, ConversionPatternRewriter &rewriter, SmallVector &squeezedShape) const { return rewriter.notifyMatchFailure( op, "Unimplemented dim/dim-list parsing function"); } // Common rewriter for all squeeze ops, calls the specific implementation of // generateSqueezedShape() needed for the op variant. LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.self(); auto selfTy = self.getType().template cast(); if (!selfTy) return op.emitError("Only ranked tensor types supported in TOSA argmax"); SmallVector newOutputShape; if (failed(generateSqueezedShape(op, selfTy, rewriter, newOutputShape))) return op.emitError("Squeeze could not compute new shape"); auto resultTy = OpConversionPattern::getTypeConverter() ->convertType(op.getResult().getType()) .template cast(); auto resultElemTy = resultTy.getElementType(); auto newOutputTy = RankedTensorType::get(newOutputShape, resultElemTy); auto reshapeOp = rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newOutputTy), self, rewriter.getI64ArrayAttr(newOutputShape)); rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( newOutputTy), reshapeOp); return success(); } }; template class ConvertAtenSqueezeOneDimOp : public ConvertAtenSqueezeOp { using ConvertAtenSqueezeOp::ConvertAtenSqueezeOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult generateSqueezedShape(AtenOpT op, RankedTensorType selfTy, ConversionPatternRewriter &rewriter, SmallVector &squeezedShape) const override { int64_t squeezeDim; if (!matchPattern(op.dim(), m_TorchConstantInt(&squeezeDim))) return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); // Handle negative dim if (squeezeDim < 0) squeezeDim = squeezeDim + selfTy.getRank(); auto selfShape = selfTy.getShape(); // Only dims statically known to have size=1 are reduced. // Dynamic dims are treated as unknowns and will not be squeezed // even if dim parameter says it should be. uint32_t dimNum = 0; for (auto &dim : selfShape) { if (dim != 1 || squeezeDim != dimNum) squeezedShape.push_back(dim); dimNum++; } return success(); } }; template class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp { using ConvertAtenSqueezeOp::ConvertAtenSqueezeOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult generateSqueezedShape(AtenOpT op, RankedTensorType selfTy, ConversionPatternRewriter &rewriter, SmallVector &squeezedShape) const override { auto selfShape = selfTy.getShape(); // Dims that may dynamically resolve to 1 are not reduced here. Only // compile-time resolvable dims are handled here. for (auto &dim : selfShape) { if (dim != 1) squeezedShape.push_back(dim); } return success(); } }; // FIXME(AG): This will eventually go into a Tosa*Utils file // Convert an fp32 scalar into tosa fp32 tensor. static LogicalResult tosaF32TensorFromTorchFloat(ConversionPatternRewriter &rewriter, Operation *op, Value torchScalarValue, Value &tosaTensor) { double scalarValue; if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue))) return failure(); // Construct a tosa.const tosaTensor = mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowTensorScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.self(); auto selfTy = self.getType().template cast(); if (!selfTy) return op.emitError("Only ranked tensor types supported in TOSA Pow"); if (!selfTy.getElementType().isa()) return op.emitError("Only floating-point datatype legalization supported"); Value expTensor; Value expScalar = op.exponent(); if (failed(tosaF32TensorFromTorchFloat(rewriter, op.getOperation(), expScalar, expTensor))) return op.emitError("Currently only scalar constants are supported for " "conversion in TOSA Pow operation"); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self, expTensor); return success(); } // Perform torch matmul, mm and bmm template class ConvertAtenMatMulOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.self(); auto lhsTy = lhs.getType().cast(); // Aten matmul, mm and bmm call operand2 by different names. Value rhs = adaptor.getOperands()[1]; auto rhsTy = rhs.getType().cast(); if (!lhsTy || !rhsTy) return op.emitError("Only ranked tensor types supported in TOSA matmul"); auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); // Mm takes two 2D tensors if (isa(op)) { assert(lhsRank == 2 && rhsRank == 2 && "aten.mm called but matrix rank != 2"); } // Bmm takes two 2D tensors if (isa(op)) { assert(lhsRank == 3 && rhsRank == 3 && "aten.bmm called but matrix rank != 2"); } auto lhsShape = lhsTy.getShape(); auto rhsShape = rhsTy.getShape(); auto lhsElemTy = lhsTy.getElementType(); auto rhsElemTy = rhsTy.getElementType(); if (lhsElemTy != rhsElemTy) return op.emitError("Matmul: input datatypes mismatched"); // Legalization constructs may offer input shapes but expect output shapes // to be inferred, e.g. // func @forward(%arg0: !torch.vtensor<[14,19],f32>, // %arg1: !torch.vtensor<[19,28],f32>) -> // !torch.vtensor<[?,?],f32> // This is tricky with matmul, since TOSA matmul is on 3D inputs. // This means the need to reshape potentially both inputs and outputs, // and reshape to unknown shape is undefined. auto maxInputRank = lhsRank > rhsRank ? lhsRank : rhsRank; // If performing dot product on vectors, the RHS is synthetically transposed if (maxInputRank == 1) maxInputRank++; // Obtaining the rank broadcasted shapes of tensors makes it easier to // construct the input and output reshaping logic. auto getRankBroadcastedShape = [&](Value tensor, bool isRHS) -> SmallVector { auto tensorTy = tensor.getType().cast(); auto tensorShape = tensorTy.getShape(); auto tensorRank = tensorTy.getRank(); SmallVector bcastedShape; auto bcastDims = maxInputRank - tensorRank; if (isRHS && (tensorRank == 1) && bcastDims) { // RHS with rank1 is special. It be synthetically transposed to dim[:-2] for (int32_t i = 0; i < bcastDims - 1; i++) bcastedShape.push_back(1); bcastedShape.push_back(tensorShape[0]); bcastedShape.push_back(1); } else { if (bcastDims > 0) { // rank broadcast for (uint32_t i = 0; i < bcastDims; i++) bcastedShape.push_back(1); } for (auto &dim : tensorShape) bcastedShape.push_back(dim); } return bcastedShape; }; // Step: Rank broadcast the two inputs. auto lhsBroadcastedShape = getRankBroadcastedShape(lhs, false); auto lhsBroadcastedTy = RankedTensorType::get(lhsBroadcastedShape, lhsElemTy); auto rhsBroadcastedShape = getRankBroadcastedShape(rhs, true); auto rhsBroadcastedTy = RankedTensorType::get(rhsBroadcastedShape, rhsElemTy); auto rankBroadcastedLhs = lhsRank == maxInputRank ? lhs : rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( lhsBroadcastedTy), lhs, rewriter.getI64ArrayAttr(lhsBroadcastedShape)); auto rankBroadcastedRhs = rhsRank == maxInputRank ? rhs : rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( rhsBroadcastedTy), rhs, rewriter.getI64ArrayAttr(rhsBroadcastedShape)); // TOSA matmul is performed on two 3D inputs and generates a 3D output. // Lower ranked tensors are dim-1 reshaped up to 3D auto reshapeUpTo3DTensor = [&](Value tensor) -> Value { auto tensorTy = tensor.getType().cast(); auto rank = tensorTy.getRank(); assert(rank <= 3 && "reshapeUpTo3D tensor must receive rank <= 3"); if (rank == 3) return tensor; auto shape = tensorTy.getShape(); SmallVector newShape({1, 1, 1}); if (rank == 2) { // batchsize = 1 newShape[1] = shape[0]; newShape[2] = shape[1]; } else { // rank 1 newShape[2] = shape[0]; } auto newType = RankedTensorType::get(newShape, tensorTy.getElementType()); return rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newType), tensor, rewriter.getI64ArrayAttr(newShape)); }; // Where broadcasting is required in one or more batch dims, the following // is done. // Where all batch dims are involved in broadcasting: // Given A: 3x1x5x6 and B: 1x4x6x7 // 1. Reshape A to 1x15x6 (squeeze all batchdims into dim1) // 2. Transpose B to 6x1x4x7, Reshape to 1x6x28 // 3. tosa.Matmul 1x15x6 1x6x28 = 1x15x28 // 4. Reshape out to 3x5x4x7, Transpose to 3x4x5x7 // Where there are batch dimensions that are broadcast and not, the // treatment is to have dim0 correspond to product of all non-broadcast // dimsizes: // Given A: 4x8x16x32 B: 8x32x17 // 1. Reshape A to 8x64x32 (squeeze all unbroadcasted dims into dim0, // broadcasted dims into dim1) // 2. No transpose or reshape of B as its batchdims are not broadcast to. // 3. tosa.Matmul 8x64x32 8x32x17 = 8x64x17 // 4. Reshape to 8x4x16x17, Transpose to 4x8x16x17 // Check if we need to perform the broadcast on batch dim // Not needed if max rank < 3, or if maxrank == 3 and dim[0] matches auto needsBatchDimBroadcast = [&]() -> bool { if (maxInputRank < 3) { return false; } else { if (maxInputRank == 3 && lhsBroadcastedShape[0] == rhsBroadcastedShape[0]) { return false; } return true; } }; auto performBatchDimBroadcast = needsBatchDimBroadcast(); // Inputs to the tosa.matmul Value matmulLhs, matmulRhs; using TensorShape_t = struct { int64_t dim; int64_t shape; }; // Transpose needs to done if transposeDims are not non-monotonically // increasing. E.g. [0, 1, 2, 3]: No transpose [1, 0, 2, 3]: Transpose dim0 // and dim1 The order need not be sequential, since one or more dims may // have been removed due to broadcasting. auto isTransposeRequired = [](SmallVector transposedDims) -> bool { int32_t lastDim = -1; for (auto &dim : transposedDims) { if (lastDim > dim) return true; lastDim = dim; } return false; }; SmallVector commonElems, lhsSqueezedElems, rhsSqueezedElems; if (!performBatchDimBroadcast) { // Simple with no broadcasting artifacts. Just reshape up to 3D matmulLhs = reshapeUpTo3DTensor(rankBroadcastedLhs); matmulRhs = reshapeUpTo3DTensor(rankBroadcastedRhs); } else { // In this case, either or both input matrices involve broadcasting on // their batch dimensions. For example: // 4x5x6, 1x6x7 -> 4x5x7 // 4x1x5x6, 1x3x6x7 -> 4x3x5x7 // Though maxInputRank is necessarily >=3 here, individual matrices may be // lower rank. // E.g. 3x4x5x6, 6 -> 3x4x5 // These are the accumulated products of the shape of each dim: // 1. common dimensions: upper dimensions (dims other than two rightmost) // whose shapes are the same for both LHS and RHS. // 2. LHS squeezed dimensions: all dimensions of LHS that involve // broadcasting in either direction, plus the LHS[-2] shape // 3. RHS squeezed dimensions: all dimensions of RHS that involve // broadcasting in either direction, plus the RHS[-1] shape int64_t commonValue = 1, lhsSqueezedValue = 1, rhsSqueezedValue = 1; // For both LHS and RHS, the dimensions are separated into the common, // squeezed and remaining dim. E.g. given // LHS = 3x4x5x6 // RHS = 1x4x6x7 // common = {{dim=1, shape=4}} // lhs squeezed = {{dim=0, shape=3}, // {dim=2, shape=5}} // rhs squeezed = {{dim=0, shape=1}, // {dim=2, shape=7}} // The matmul dim is LHS[-1] and RHS[-2], i.e. 6. // Once this is obtained, LHS and RHS are expressed as: // LHS = {common, lhs_squeezed, matmul_dim} // RHS = {common, matmul_dim, rhs_squeezed} // The matmul is then performed to obtain output: // matmul_out = {common, lhs_squeezed, rhs_squeezed} // Finally, we reshape to 'unsqueeze' the LHS and RHS parts and transpose // them back to their correct positions. SmallVector transposedLhsShape; SmallVector transposedLhsDims; // Step: generate the common dim/shape information for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) { bool isDynamicDim = lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]); if (isDynamicDim || lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) { commonValue *= lhsBroadcastedShape[dim]; commonElems.push_back({dim, lhsBroadcastedShape[dim]}); } } // Step: generate the LHS squeezed dim/shape information. bool hasDynamicDims = false; for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) { bool isDynamicDim = lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]); hasDynamicDims |= isDynamicDim; if (!isDynamicDim && lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) { lhsSqueezedValue *= lhsBroadcastedShape[dim]; lhsSqueezedElems.push_back({dim, lhsBroadcastedShape[dim]}); } } // including LHS[-2] lhsSqueezedElems.push_back( {maxInputRank - 2, lhsBroadcastedShape[maxInputRank - 2]}); lhsSqueezedValue *= lhsBroadcastedShape[maxInputRank - 2]; // Step: Create the tosa.transpose array. If this array has a // non-monotonic series of dims, perform transpose. // First the common_elems for (uint32_t i = 0; i < commonElems.size(); i++) { transposedLhsShape.push_back(commonElems[i].shape); transposedLhsDims.push_back(commonElems[i].dim); } // then the lhs_squeezed elems for (uint32_t i = 0; i < lhsSqueezedElems.size(); i++) { transposedLhsShape.push_back(lhsSqueezedElems[i].shape); transposedLhsDims.push_back(lhsSqueezedElems[i].dim); } // then the final dim transposedLhsDims.push_back(maxInputRank - 1); transposedLhsShape.push_back(lhsBroadcastedShape[maxInputRank - 1]); bool lhsNeedsTranspose = isTransposeRequired(transposedLhsDims); auto lhsReshapeInput = rankBroadcastedLhs; if (lhsNeedsTranspose) { auto transposedLhsType = RankedTensorType::get(transposedLhsShape, rhsElemTy); llvm::Optional transposedLhsDimsConst = tosa::getConstTensor( rewriter, op, /*vec=*/transposedLhsDims, /*shape=*/{static_cast(transposedLhsDims.size())}); lhsReshapeInput = rewriter .create( op->getLoc(), OpConversionPattern::getTypeConverter() ->convertType(transposedLhsType), rankBroadcastedLhs, transposedLhsDimsConst.getValue()) .getResult(); } // LHS = {common, lhs_squeezed, matmul_dim} SmallVector newLhsShape( {1, 1, lhsBroadcastedShape[maxInputRank - 1]}); newLhsShape[0] = commonValue; newLhsShape[1] = hasDynamicDims ? ShapedType::kDynamicSize : lhsSqueezedValue; auto newLhsType = RankedTensorType::get(newLhsShape, lhsElemTy); matmulLhs = rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newLhsType), lhsReshapeInput, rewriter.getI64ArrayAttr(newLhsShape)); SmallVector transposedRhsShape; SmallVector transposedRhsDims; // Step: Create the RHS transpose sequence // RHS = {common, matmul_dim, rhs_squeezed} // first the common_dims for (uint32_t i = 0; i < commonElems.size(); i++) { transposedRhsShape.push_back(commonElems[i].shape); transposedRhsDims.push_back(commonElems[i].dim); } // The matmul_dim of RHS transposedRhsDims.push_back(maxInputRank - 2); transposedRhsShape.push_back(rhsBroadcastedShape[maxInputRank - 2]); // finally all the rhs_squeeze dims hasDynamicDims = false; for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) { bool isDynamicDim = rhsBroadcastedTy.isDynamic(rhsBroadcastedShape[dim]); hasDynamicDims |= isDynamicDim; if (!isDynamicDim && rhsBroadcastedShape[dim] != lhsBroadcastedShape[dim]) { rhsSqueezedElems.push_back({dim, rhsBroadcastedShape[dim]}); rhsSqueezedValue *= rhsBroadcastedShape[dim]; } } rhsSqueezedElems.push_back( {maxInputRank - 1, rhsBroadcastedShape[maxInputRank - 1]}); rhsSqueezedValue *= rhsBroadcastedShape[maxInputRank - 1]; for (uint32_t i = 0; i < rhsSqueezedElems.size(); i++) { transposedRhsShape.push_back(rhsSqueezedElems[i].shape); transposedRhsDims.push_back(rhsSqueezedElems[i].dim); } auto transposedRhsType = RankedTensorType::get(transposedRhsShape, rhsElemTy); if (hasDynamicDims) rhsSqueezedValue = ShapedType::kDynamicSize; SmallVector newRhsShape({commonValue, rhsBroadcastedShape[maxInputRank - 2], rhsSqueezedValue}); auto newRhsType = RankedTensorType::get(newRhsShape, rhsElemTy); bool rhsNeedsTranspose = isTransposeRequired(transposedRhsDims); auto transposedRhsValue = rankBroadcastedRhs; if (rhsNeedsTranspose) { llvm::Optional transposedRhsDimsConst = tosa::getConstTensor( rewriter, op, /*vec=*/transposedRhsDims, /*shape=*/{static_cast(transposedRhsDims.size())}); transposedRhsValue = rewriter .create( op->getLoc(), OpConversionPattern::getTypeConverter() ->convertType(transposedRhsType), rankBroadcastedRhs, transposedRhsDimsConst.getValue()) .getResult(); } // reshape matmulRhs = rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newRhsType), transposedRhsValue, rewriter.getI64ArrayAttr(newRhsShape)); } auto matmulLhsShape = matmulLhs.getType().template cast().getShape(); auto matmulRhsShape = matmulRhs.getType().template cast().getShape(); // The reshape/transpose should ensure the tosa.matmul always has same // batch size for either matrix. If if shapes are dynamic, they'll be // appropriately handled. assert(matmulLhsShape[0] == matmulRhsShape[0] && "tosa.matmul needs same batchsize on LHS and RHS"); SmallVector matmulOutputShape( {matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]}); Type outputElemTy; if (lhsElemTy.isa()) { outputElemTy = lhsElemTy; } else { // qint8 emits i32 matmul output outputElemTy = rewriter.getIntegerType(32); } auto mmOutputTy = RankedTensorType::get(matmulOutputShape, outputElemTy); auto mmOpResult = rewriter .create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( mmOutputTy), matmulLhs, matmulRhs) .getResult(); // Perform the reshape to output shape. This is always required unless both // inputs are rank=3, in which case the tosa.matmul output itself is // correctly shaped. bool performOpReshape = !(lhsRank == 3 && rhsRank == 3); auto outputTy = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); if (performOpReshape) { // Since the output shape may be unknown, we construct it // independently and reshape. Otherwise reshape may be expressed for // an unknown to-be-inferred output shape. The final tensor.cast // reshapes the known shape to the desired output shape. auto computeOpShape = [&](SmallVector &reshapedOpShape, SmallVector &transposedOpDims, SmallVector &transposedOpShapes) { if (maxInputRank == 1) return; if (maxInputRank == 2) { if (lhsRank == 2) reshapedOpShape.push_back(lhsShape[0]); if (rhsRank == 2) reshapedOpShape.push_back(rhsShape[1]); return; } // Step: Construct the output transpose/reshape information // First the common_dims for (uint32_t i = 0; i < commonElems.size(); i++) { reshapedOpShape.push_back(commonElems[i].shape); transposedOpDims.push_back(commonElems[i].dim); } // Then the LHS squeezed dims for (uint32_t i = 0; i < lhsSqueezedElems.size() - 1; i++) { // Only dims that don't broadcast - broadcasting ones come from the // other input. if (lhsSqueezedElems[i].shape != 1) { reshapedOpShape.push_back(lhsSqueezedElems[i].shape); transposedOpDims.push_back(lhsSqueezedElems[i].dim); } } // The last squeezed dim is lhs[-2] which needs to be // checked separately for broadcasting if (lhsRank > 1) { reshapedOpShape.push_back(lhsBroadcastedShape[maxInputRank - 2]); transposedOpDims.push_back(maxInputRank - 2); } // then the RHS squeezed dims except rhs[-1] which is handled like // lhs[-2] for (uint32_t i = 0; i < rhsSqueezedElems.size() - 1; i++) { if (rhsSqueezedElems[i].shape != 1) { reshapedOpShape.push_back(rhsSqueezedElems[i].shape); transposedOpDims.push_back(rhsSqueezedElems[i].dim); } } // rhs[-1] if (rhsRank > 1) { reshapedOpShape.push_back(rhsBroadcastedShape[maxInputRank - 1]); transposedOpDims.push_back(maxInputRank - 1); } // Final transposed output shape construction for (uint32_t i = 0; i < maxInputRank - 2; i++) { if (lhsBroadcastedTy.isDynamicDim(i)) { transposedOpShapes.push_back(ShapedType::kDynamicSize); } else { if (lhsBroadcastedShape[i] == rhsBroadcastedShape[i]) { transposedOpShapes.push_back(lhsBroadcastedShape[i]); } else { transposedOpShapes.push_back(lhsBroadcastedShape[i] == 1 ? rhsBroadcastedShape[i] : lhsBroadcastedShape[i]); } } } if (lhsRank > 1) transposedOpShapes.push_back(lhsBroadcastedShape[maxInputRank - 2]); if (rhsRank > 1) transposedOpShapes.push_back(rhsBroadcastedShape[maxInputRank - 1]); return; }; SmallVector reshapedOpShape, transposedOpShape; SmallVector transposedOpDims; computeOpShape(reshapedOpShape, transposedOpDims, transposedOpShape); bool opNeedsTranspose = isTransposeRequired(transposedOpDims); // Perform reshape auto reshapedOpType = RankedTensorType::get(reshapedOpShape, outputElemTy); auto reshapedOp = rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( reshapedOpType), mmOpResult, rewriter.getI64ArrayAttr(reshapedOpShape)); if (opNeedsTranspose) { llvm::Optional transposedOpShapeConst = tosa::getConstTensor( rewriter, op, /*vec=*/transposedOpDims, /*shape=*/{static_cast(transposedOpDims.size())}); auto transposedOpType = RankedTensorType::get(transposedOpShape, outputElemTy); auto transposedOp = rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( transposedOpType), reshapedOp.getResult(), transposedOpShapeConst.getValue()); rewriter.replaceOpWithNewOp(op, outputTy, transposedOp); } else { rewriter.replaceOpWithNewOp(op, outputTy, reshapedOp); } } else { rewriter.replaceOpWithNewOp(op, outputTy, mmOpResult); } return success(); } }; } // namespace // ----------------------------------------------------------------------------- // TorchToTosa Pass // ----------------------------------------------------------------------------- namespace { class ConvertTorchToTosa : public ConvertTorchToTosaBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ context); INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp) INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp) #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) #undef INSERT_UNARY_PATTERN #define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) #undef INSERT_BINARY_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) #undef INSERT_BINARY_ADDSUB_PATTERN #define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ target.addIllegalOp(); \ patterns.add>( \ typeConverter, context); INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, mlir::tosa::convertReduceMeanOp) INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, mlir::tosa::convertReduceSumOp) #undef INSERT_NDIMS_REDUCTION_OP_PATTERN #define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ target.addIllegalOp(); \ patterns.add>( \ typeConverter, context); INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, mlir::tosa::convertReduceAnyOp) #undef INSERT_ONEDIM_REDUCTION_OP_PATTERN #define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ target.addIllegalOp(); \ patterns.add>( \ typeConverter, context); INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp) INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp) INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp) #undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN #define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) #undef INSERT_SQUEEZE_OP_PATTERN #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); INSERT_MATMUL_ATENOP_PATTERN(AtenMmOp); INSERT_MATMUL_ATENOP_PATTERN(AtenBmmOp); #undef INSERT_MATMUL_ATEMOP_PATTERN #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_ATENOP_PATTERN(AtenTanhOp); INSERT_ATENOP_PATTERN(AtenSigmoidOp); INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenMulTensorOp); INSERT_ATENOP_PATTERN(AtenDivTensorOp); INSERT_ATENOP_PATTERN(AtenArgmaxOp); INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); #undef INSERT_ATENOP_PATTERN if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::torch::createConvertTorchToTosaPass() { return std::make_unique(); }