//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( elementTy.cast().getFloatSemantics(), /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, constAttr); } } if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getLargest( elementTy.cast().getFloatSemantics(), /*negative=*/true)}); return rewriter.create(op->getLoc(), constType, constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, constAttr); } } op->emitError("unimplemented lowering in " "createInitialValueForReduceOp"); return nullptr; } // Util for converting AtenArgmaxOp and AtenMaxDimOp static llvm::Optional getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, ArrayRef inputShapeVec, int64_t dim) { auto inputTy = input.getType().template cast(); if (!inputTy) { return llvm::None; } if (!inputTy.getElementType().isIntOrFloat()) { return llvm::None; } auto inputShape = inputTy.getShape(); auto inputElemTy = inputTy.getElementType(); Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter); if (!initValue) return llvm::None; Value initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).getValue(); DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( RankedTensorType::get({}, rewriter.getI64Type()), dim); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); auto indexTensor = rewriter.create( op->getLoc(), RankedTensorType::get(inputShape, rewriter.getI64Type()), inputShapeTensor, static_cast(dim)); auto mhloReduceOp = rewriter.create( op->getLoc(), ValueRange{input, indexTensor}, ValueRange{ initValue, initIndex, }, dimensions); Block &block = mhloReduceOp.body().emplaceBlock(); // Add block arguments auto blockValArgumentType = RankedTensorType::get({}, inputTy.getElementType()); auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type()); auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); auto *firstValArg = block.args_begin(); auto *firstIdxArg = std::next(firstValArg); auto *secondValArg = std::next(firstIdxArg); auto *secondIdxArg = std::next(secondValArg); mhlo::ComparisonTypeAttr compareTypeAttr; if (inputTy.getElementType().isa()) { compareTypeAttr = mhlo::ComparisonTypeAttr::get( rewriter.getContext(), mhlo::ComparisonType::FLOAT); } else if (inputTy.getElementType().isa()) { compareTypeAttr = mhlo::ComparisonTypeAttr::get( rewriter.getContext(), mhlo::ComparisonType::SIGNED); } mhlo::ComparisonDirectionAttr compareGeDirectionAttr = mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), mhlo::ComparisonDirection::GE); mhlo::ComparisonDirectionAttr compareEqDirectionAttr = mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), mhlo::ComparisonDirection::EQ); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value compareGeResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); Value retValResult = rewriter.create( op->getLoc(), compareGeResult, *firstValArg, *secondValArg); // get smaller index value if compared nums are equal. Value compareEqResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); Value idxWithGeVal = rewriter.create( op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); rewriter.create( op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); } return mhloReduceOp.getResults(); } namespace { template class ConvertAtenReductionOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace // AtenArgmaxOp namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenArgmaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputTy = input.getType().template cast(); if (!inputTy) { return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported! if (inputElemTy.isa() && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " "AtenArgmaxOp to MHLO"); } int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); } dim = toPositiveDim(dim, inputTy.getRank()); if (!isValidDim(dim, inputTy.getRank())) { return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); } bool keepDim = false; if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue(); if (keepDim) { auto outShapeVec = inputShapeVec; outShapeVec[dim] = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1)); auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); rewriter.replaceOpWithNewOp( op, typeConverter->convertType(op.getType()), mhloReduceResults[1], outShapeTensor); return success(); } rewriter.replaceOp(op, mhloReduceResults[1]); return success(); } } // namespace // AtenMaxDimOp namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenMaxDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputTy = input.getType().template dyn_cast(); if (!inputTy) { return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "Only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported if (inputElemTy.isa() && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " "AtenMaxDimOp to MHLO"); } RankedTensorType valResultType = getTypeConverter() ->convertType(op.getResult(0).getType()) .template cast(); RankedTensorType idxResultType = getTypeConverter() ->convertType(op.getResult(1).getType()) .template cast(); Type idxElementType = idxResultType.getElementType(); if (!idxElementType.isa()) { return op.emitError("Aten.max.dim needs integer-like result"); } int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); } dim = toPositiveDim(dim, inputTy.getRank()); if (!isValidDim(dim, inputTy.getRank())) { return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); } bool keepDim = false; if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue(); if (keepDim) { auto outShapeVec = inputShapeVec; outShapeVec[dim] = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1)); auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); auto mhloReduceValueResult = rewriter.create( op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor); auto mhloReduceIndexResult = rewriter.create( op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor); rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult}); return success(); } rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]}); return success(); } } // namespace // AtenSumOp namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenSumOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputTy = input.getType().dyn_cast(); if (!inputTy) { return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } auto dtype = adaptor.dtype(); if (!dtype.getType().isa()) { auto dstElemTy = getTypeConverter() ->convertType(op.getType()) .template dyn_cast() .getElementType(); input = rewriter.create(op->getLoc(), input, dstElemTy); inputTy = input.getType().dyn_cast(); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported if (inputElemTy.isa() && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " "AtenSumOp to MHLO"); } SmallVector dims; for (int64_t i = 0; i < inputTy.getRank(); i++) { dims.push_back(i); } Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); if (!initValue) return failure(); auto mhloReduceOp = rewriter.create( op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); Block &block = mhloReduceOp.body().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); auto *firstArgument = block.args_begin(); auto secondArgument = block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value addResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); rewriter.create(op->getLoc(), addResult); } rewriter.replaceOp(op, mhloReduceOp.getResults()); return success(); } } // namespace // AtenMaxOp namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenMaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputTy = input.getType().dyn_cast(); if (!inputTy) { return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported if (inputElemTy.isa() && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " "AtenMaxOp to MHLO"); } SmallVector dims; for (int64_t i = 0; i < inputTy.getRank(); i++) { dims.push_back(i); } Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); if (!initValue) return failure(); auto mhloReduceOp = rewriter.create( op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); Block &block = mhloReduceOp.body().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); auto *firstArgument = block.args_begin(); auto secondArgument = block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value maxResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); rewriter.create(op->getLoc(), maxResult); } rewriter.replaceOp(op, mhloReduceOp.getResults()); return success(); } } // namespace // AtenSumDimIntListOp namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenSumDimIntListOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); auto inputTy = input.getType().dyn_cast(); if (!inputTy) { return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } auto dtype = adaptor.dtype(); if (!dtype.getType().isa()) { auto dstElemTy = getTypeConverter() ->convertType(op.getType()) .template dyn_cast() .getElementType(); input = rewriter.create(op->getLoc(), input, dstElemTy); inputTy = input.getType().dyn_cast(); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "Only floating-point or integer datatype legalization supported"); } // Currently, (u)int8 dtype is not supported if (inputElemTy.isa() && inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " "AtenSumDimIntListOp to MHLO"); } SmallVector inputDims; SmallVector dims; if (!matchPattern(op.dim(), m_TorchConstantIntList(inputDims))) { return rewriter.notifyMatchFailure(op, "non-int dim list unsupported"); } for (auto d : inputDims) { d = toPositiveDim(d, inputTy.getRank()); // Drop invalid dims if (isValidDim(d, inputTy.getRank())) { dims.push_back(d); } } bool keepDim = false; if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); if (!initValue) return failure(); auto mhloReduceOp = rewriter.create( op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); Region ®ion = mhloReduceOp.body(); Block &block = region.emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); auto *firstArgument = block.args_begin(); auto secondArgument = block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value addResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); rewriter.create(op->getLoc(), addResult); } if (keepDim) { auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1)); for (int64_t i : dims) { outShapeVec[i] = one; } auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), mhloReduceOp.getResult(0), outShapeTensor); return success(); } rewriter.replaceOp(op, mhloReduceOp.getResults()); return success(); } } // namespace void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp); #undef INSERT_ATEN_REDUCTION_OP_PATTERN }