//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

// These legalizations are for unary ops that support only floating-point
// datatypes. There is no supported quantized integer mode for these.
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.self();
    auto selfTy = self.getType().cast<TensorType>();

    if (!selfTy)
      return op.emitError("Only Tensor types supported in TOSA");

    if (selfTy.getElementType().isa<mlir::FloatType>()) {
      rewriter.replaceOpWithNewOp<TosaOpT>(
          op,
          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
              op.getType()),
          self);
      return success();
    } else {
      return op.emitError(
          "Only floating-point datatype legalization supported");
    }
  }
};

// These unary op legalizations are identical for floating-point
// or quantized types.
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TosaOpT>(
        op,
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()),
        adaptor.self());
    return success();
  }
};

// These binary op legalizations are specific to add/sub which have an
// alpha multiplier.
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.self();
    auto lhsTy = lhs.getType().cast<TensorType>();
    Value rhs = adaptor.other();
    auto rhsTy = rhs.getType().cast<TensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError("Only Tensor types supported in TOSA");

    auto lhsElemTy = lhsTy.getElementType();
    auto rhsElemTy = rhsTy.getElementType();

    if (lhsElemTy != rhsElemTy)
      return op.emitError("Add/Sub: input datatypes mismatched");

    // FIXME: Handle alpha.
    // Needs extraction of floating point constant.
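    // A minimal interim approach (untested sketch, not the implemented
    // legalization): match the alpha constant and bail out unless it is 1,
    // reusing the m_TorchConstantInt matcher used elsewhere in this file:
    //   int64_t alpha;
    //   if (!matchPattern(op.alpha(), m_TorchConstantInt(&alpha)) ||
    //       alpha != 1)
    //     return rewriter.notifyMatchFailure(
    //         op, "only alpha == 1 is currently supported");
    // A floating-point alpha would additionally need a float-constant
    // matcher and a tosa.mul of `rhs` by alpha before the add/sub.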
    if (lhsElemTy.isa<mlir::FloatType>()) {
      rewriter.replaceOpWithNewOp<TosaOpT>(
          op,
          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
              op.getType()),
          lhs, rhs);
      return success();
    } else {
      return op.emitError(
          "Only floating-point datatype legalization supported");
    }
  }
};

// This defines a template to construct ops whose legalizations are
// specialized.
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

template <>
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
    AtenTanhOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.self();
  auto selfTy = self.getType().cast<TensorType>();
  if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
    rewriter.replaceOpWithNewOp<tosa::TanhOp>(
        op, getTypeConverter()->convertType(op.getType()), self);
    return success();
  } else {
    // Tanh legalization in TOSA for quantized element-type uses a
    // specialized tosa.table construct.
    return op.emitError(
        "Only floating-point datatype legalization currently supported");
  }
}

template <>
LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
    AtenSigmoidOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.self();
  auto selfTy = self.getType().cast<TensorType>();
  if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
    rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
        op, getTypeConverter()->convertType(op.getType()), self);
    return success();
  } else {
    // Sigmoid legalization in TOSA for quantized element-type uses a
    // specialized tosa.table construct.
    return op.emitError(
        "Only floating-point datatype legalization currently supported");
  }
}

template <>
LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
    AtenReluOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.self();
  auto selfTy = self.getType().cast<TensorType>();

  // Maps to tosa.clamp which has both int and fp limits.
  int64_t clampMin = 0;
  Value clampIn = self;
  if (selfTy) {
    // Rescale the clampIn for quantized types. TBD
    if (!selfTy.getElementType().isa<mlir::FloatType>()) {
      return op.emitError(
          "Only floating-point datatype legalization currently supported");
    }
    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, getTypeConverter()->convertType(op.getType()), clampIn,
        rewriter.getI64IntegerAttr(clampMin),
        rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
        rewriter.getF32FloatAttr(0.0f),
        rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
    return success();
  } else {
    return op.emitError("Only Tensor types supported in TOSA");
  }
}

template <>
LogicalResult ConvertAtenOp<AtenMulTensorOp>::matchAndRewrite(
    AtenMulTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.self();
  auto lhsTy = lhs.getType().cast<TensorType>();
  Value rhs = adaptor.other();
  auto rhsTy = rhs.getType().cast<TensorType>();

  if (!lhsTy || !rhsTy)
    return op.emitError("Only Tensor types supported in TOSA");

  auto lhsElemTy = lhsTy.getElementType();
  auto rhsElemTy = rhsTy.getElementType();

  if (lhsElemTy != rhsElemTy)
    return op.emitError("Mul: input datatypes mismatched");

  if (lhsElemTy.isa<mlir::FloatType>()) {
    rewriter.replaceOpWithNewOp<tosa::MulOp>(
        op, getTypeConverter()->convertType(op.getType()), lhs, rhs,
        /*shift=*/0);
    return success();
  } else {
    // Quantized multiplication may need to rescale inputs.
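    // A quantized path would look roughly like this (sketch only, assuming
    // the rescale helpers declared in TosaLegalizeUtils.h): rescale both
    // operands to an i32 intermediate, multiply, then rescale the product
    // back to the output type's scale and zero-point with tosa.rescale.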
    return op.emitError(
        "Only floating-point datatype legalization currently supported");
  }
}

template <>
LogicalResult ConvertAtenOp<AtenDivTensorOp>::matchAndRewrite(
    AtenDivTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value lhs = adaptor.self();
  auto lhsTy = lhs.getType().cast<TensorType>();
  Value rhs = adaptor.other();
  auto rhsTy = rhs.getType().cast<TensorType>();

  if (!lhsTy || !rhsTy)
    return op.emitError("Only Tensor types supported in TOSA");

  auto lhsElemTy = lhsTy.getElementType();
  auto rhsElemTy = rhsTy.getElementType();

  if (lhsElemTy != rhsElemTy)
    return op.emitError("Div: input datatypes mismatched");

  if (lhsElemTy.isa<mlir::FloatType>()) {
    auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
        op->getLoc(), getTypeConverter()->convertType(op.getType()), rhs);
    rewriter.replaceOpWithNewOp<tosa::MulOp>(
        op, getTypeConverter()->convertType(op.getType()), lhs,
        rcpOp.getResult(), /*shift=*/0);
  } else {
    rewriter.replaceOpWithNewOp<tosa::DivOp>(
        op, getTypeConverter()->convertType(op.getType()), lhs, rhs);
  }
  return success();
}

using ReductionConvFunc = llvm::Optional<Value> (*)(PatternRewriter &,
                                                    Operation *,
                                                    RankedTensorType, Value,
                                                    ElementsAttr, bool);

// All reduction ops share a common legalization form that invokes the
// appropriate conversion function in TosaLegalizeCommon.cpp.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  // Each variant must implement corresponding parameter parsing options.
  virtual LogicalResult readReduceDimsAndKeepDims(
      AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
      ElementsAttr &reduceDimsAttr, bool &keepDims) const {
    return rewriter.notifyMatchFailure(
        op, "Unimplemented reduce_dims and keep_dims parsing function");
  }

  // Common rewriter for all reduction ops; calls the specific implementation
  // of readReduceDimsAndKeepDims() needed for the op variant.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.self();
    auto selfTy = self.getType().cast<TensorType>();

    if (!selfTy)
      return op.emitError("Only Tensor types supported in TOSA");

    auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(op.getType())
                        .template cast<RankedTensorType>();
    if (!outputTy)
      return op.emitError(
          "Only ranked tensor type outputs permitted for reduce ops");

    ElementsAttr reduceDimsAttr;
    bool keepDims;

    if (failed(readReduceDimsAndKeepDims(op, adaptor, rewriter,
                                         reduceDimsAttr, keepDims)))
      return failure();

    llvm::Optional<Value> result =
        ConversionFuncT(rewriter, op, outputTy, self, reduceDimsAttr,
                        keepDims);

    if (!result)
      return failure();

    // TBD - support dtype casting.

    rewriter.replaceOp(op, {result.getValue()});

    return success();
  }
};
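// As an illustration (not emitted verbatim by this file), a sum over dim 1
// of a tensor<2x3xf32> with keepdim is expected to come out of
// convertReduceSumOp as roughly:
//   %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64}
//          : (tensor<2x3xf32>) -> tensor<2x1xf32>
// with a trailing tosa.reshape to drop the unit dim when keepdim is false.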
// This reduction op legalization template handles op variants that have
// explicit reduce_dims dimensions (provided as a list) and keep_dims
// parameters.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenMultipleDimsReductionOp
    : public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
  using ConvertAtenReductionOp<AtenOpT,
                               ConversionFuncT>::ConvertAtenReductionOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
                                          ConversionPatternRewriter &rewriter,
                                          ElementsAttr &reduceDimsAttr,
                                          bool &keepDims) const override {
    SmallVector<int64_t, 4> reduceDims;
    if (!matchPattern(op.dim(), m_TorchConstantIntList(reduceDims)))
      return rewriter.notifyMatchFailure(op,
                                         "non-const dim parameter unsupported");
    int64_t N = reduceDims.size();
    auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
    reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
                                               llvm::makeArrayRef(reduceDims));

    keepDims = false;
    if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims)))
      return rewriter.notifyMatchFailure(
          op, "non-const keepdim parameter unsupported");

    return success();
  }
};

// This reduction op legalization template handles op variants that reduce in
// only one explicit dim which is provided as a number (rather than a list),
// and a keep_dims parameter.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenOneDimReductionOp
    : public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
  using ConvertAtenReductionOp<AtenOpT,
                               ConversionFuncT>::ConvertAtenReductionOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
                                          ConversionPatternRewriter &rewriter,
                                          ElementsAttr &reduceDimsAttr,
                                          bool &keepDims) const override {
    int64_t reduceDim;
    if (!matchPattern(op.dim(), m_TorchConstantInt(&reduceDim)))
      return rewriter.notifyMatchFailure(op,
                                         "non-const dim parameter unsupported");
    auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
    reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
                                               llvm::makeArrayRef({reduceDim}));

    keepDims = false;
    if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDims)))
      return rewriter.notifyMatchFailure(
          op, "non-const keepdim parameter unsupported");

    return success();
  }
};
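// Note the contrast with ConvertAtenMultipleDimsReductionOp above: there the
// `dim` operand is an int list matched with m_TorchConstantIntList, while ops
// such as aten.any.dim carry a single integer `dim` matched with
// m_TorchConstantInt.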
// This reduction op legalization template handles op variants that reduce
// all dims and do not keep dims.
template <typename AtenOpT, ReductionConvFunc ConversionFuncT>
class ConvertAtenAllDimsReductionOp
    : public ConvertAtenReductionOp<AtenOpT, ConversionFuncT> {
public:
  using ConvertAtenReductionOp<AtenOpT,
                               ConversionFuncT>::ConvertAtenReductionOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readReduceDimsAndKeepDims(AtenOpT op, OpAdaptor adaptor,
                                          ConversionPatternRewriter &rewriter,
                                          ElementsAttr &reduceDimsAttr,
                                          bool &keepDims) const override {
    auto self = adaptor.self();
    auto selfTy = self.getType().template cast<RankedTensorType>();

    // Select all dims to reduce.
    SmallVector<int64_t, 4> reduceDims;
    for (int64_t i = 0; i < selfTy.getRank(); i++)
      reduceDims.push_back(i);
    int64_t N = selfTy.getRank();
    auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
    reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
                                               llvm::makeArrayRef(reduceDims));
    keepDims = false;

    return success();
  }
};

} // namespace

// -----------------------------------------------------------------------------
// TorchToTosa Pass
// -----------------------------------------------------------------------------

namespace {
class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tosa::TosaDialect>();
    TorchConversion::getBackendTypeConversionDependentDialects(registry);
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ConversionTarget target(*context);
    target.addLegalDialect<tosa::TosaDialect>();

    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });
    TorchConversion::setupBackendTypeConversion(target, typeConverter);

    RewritePatternSet patterns(context);

#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp)                            \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, TosaOp>>(typeConverter,        \
                                                         context);
    INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp)
    INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp)
#undef INSERT_UNARY_FPONLY_PATTERN

#define INSERT_UNARY_PATTERN(AtenOp, TosaOp)                                   \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
    INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
    INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
    INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
#undef INSERT_UNARY_PATTERN

#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp)                           \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
    INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
    INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
#undef INSERT_BINARY_ADDSUB_PATTERN

#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc)              \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>(    \
      typeConverter, context);
    INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp,
                                      mlir::tosa::convertReduceMeanOp)
    INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp,
                                      mlir::tosa::convertReduceSumOp)
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN

#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc)             \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOneDimReductionOp<AtenOp, ConversionFunc>>(          \
      typeConverter, context);
    INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
                                       mlir::tosa::convertReduceAnyOp)
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN

#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc)            \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenAllDimsReductionOp<AtenOp, ConversionFunc>>(         \
      typeConverter, context);
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp,
                                        mlir::tosa::convertReduceAllOp)
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp,
                                        mlir::tosa::convertReduceAnyOp)
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
                                        mlir::tosa::convertReduceSumOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN
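    // Each INSERT_*_PATTERN macro both marks the aten op illegal on the
    // conversion target and registers its conversion pattern, so any
    // instance that survives pattern application causes
    // applyPartialConversion below to fail.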
#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
    INSERT_ATENOP_PATTERN(AtenTanhOp);
    INSERT_ATENOP_PATTERN(AtenSigmoidOp);
    INSERT_ATENOP_PATTERN(AtenReluOp);
    INSERT_ATENOP_PATTERN(AtenMulTensorOp);
    INSERT_ATENOP_PATTERN(AtenDivTensorOp);
#undef INSERT_ATENOP_PATTERN

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::createConvertTorchToTosaPass() {
  return std::make_unique<ConvertTorchToTosa>();
}
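// Example invocation (illustrative; assumes this pass is registered as
// "convert-torch-to-tosa" in the conversion Passes.td):
//   torch-mlir-opt -convert-torch-to-tosa input.mlir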