From b316918947f4c206fd6c3abcecb5abcb27efc73b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=AD=A6=E5=AE=B6=E4=BC=9F?= <73166454+Vremold@users.noreply.github.com>
Date: Fri, 16 Sep 2022 15:09:21 +0800
Subject: [PATCH] Add AtenClampOp conversion pattern to MHLO (#1356)

Add AtenClampOp conversion pattern to MHLO
---
 e2e_testing/xfail_sets.py            |   3 +
 lib/Conversion/TorchToMhlo/Basic.cpp | 111 ++++++++++++++++++++++++---
 2 files changed, 104 insertions(+), 10 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 67a23a1b5..026727af8 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -22,6 +22,9 @@ EAGER_MODE_XFAIL_SET = {
 }
 
 MHLO_PASS_SET = {
+    "ElementwiseClampModule_basic",
+    "ElementwiseClampMinModule_basic",
+    "ElementwiseClampMaxModule_basic",
     "BmmModule_basic",
     "BroadcastToModule_basic",
     "ElementwiseExpModule_basic",
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index 5d4d95c26..37ce61f42 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -41,6 +41,60 @@ bool skipMultiplyAlpha(Value alphaValue) {
   return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1.0));
 }
 
+static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
+                                           PatternRewriter &rewriter) {
+  auto constType = RankedTensorType::get({}, elementType);
+  if (elementType.isa<mlir::FloatType>()) {
+    auto constAttr = SplatElementsAttr::get(
+        constType,
+        APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
+                        /*negative=*/false));
+    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
+        .getResult();
+  }
+  if (elementType.isa<mlir::IntegerType>()) {
+    auto integerType = elementType.cast<mlir::IntegerType>();
+    DenseElementsAttr constAttr;
+    if (integerType.isUnsigned()) {
+      constAttr = SplatElementsAttr::get(
+          constType, APInt::getMaxValue(integerType.getWidth()));
+    } else {
+      constAttr = SplatElementsAttr::get(
+          constType, APInt::getSignedMaxValue(integerType.getWidth()));
+    }
+    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
+        .getResult();
+  }
+  return failure();
+}
+
+static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
+                                           PatternRewriter &rewriter) {
+  auto constType = RankedTensorType::get({}, elementType);
+  if (elementType.isa<mlir::FloatType>()) {
+    auto constAttr = SplatElementsAttr::get(
+        constType,
+        APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
+                        /*negative=*/true));
+    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
+        .getResult();
+  }
+  if (elementType.isa<mlir::IntegerType>()) {
+    auto integerType = elementType.cast<mlir::IntegerType>();
+    DenseElementsAttr constAttr;
+    if (integerType.isUnsigned()) {
+      constAttr = SplatElementsAttr::get(
+          constType, APInt::getMinValue(integerType.getWidth()));
+    } else {
+      constAttr = SplatElementsAttr::get(
+          constType, APInt::getSignedMinValue(integerType.getWidth()));
+    }
+    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
+        .getResult();
+  }
+  return failure();
+}
+
 // These legalizations are for unary ops with only for floating point datatypes.
 // There is no supported quantized integer mode for these.
 namespace {
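Aside, for illustration only (not part of the diff): the two helpers above
materialize a dtype's extreme values as 0-d mhlo.constant tensors, using
+/-infinity for float element types (APFloat::getInf) and the unsigned/signed
integer extremes (APInt) otherwise. A minimal standalone sketch of the same
LLVM ADT calls on concrete widths; the main() harness is ours, not the patch's:

    #include "llvm/ADT/APFloat.h"
    #include "llvm/ADT/APInt.h"
    #include <cassert>

    int main() {
      // Float dtypes: getMaxValueOfDtype / getMinValueOfDtype splat
      // +inf / -inf in the element type's float semantics.
      llvm::APFloat maxF32 =
          llvm::APFloat::getInf(llvm::APFloat::IEEEsingle(), /*Negative=*/false);
      llvm::APFloat minF32 =
          llvm::APFloat::getInf(llvm::APFloat::IEEEsingle(), /*Negative=*/true);
      assert(maxF32.isInfinity() && !maxF32.isNegative());
      assert(minF32.isInfinity() && minF32.isNegative());

      // Integer dtypes: signedness picks the extreme, e.g. at width 8.
      assert(llvm::APInt::getSignedMaxValue(8).getSExtValue() == 127);
      assert(llvm::APInt::getSignedMinValue(8).getSExtValue() == -128);
      assert(llvm::APInt::getMaxValue(8).getZExtValue() == 255);
      assert(llvm::APInt::getMinValue(8).getZExtValue() == 0);
      return 0;
    }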
@@ -942,33 +996,69 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 // AtenNumelOp
 template <>
 LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
-    AtenNumelOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter) const {
+    AtenNumelOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto self = adaptor.self();
   auto selfTy = self.getType().dyn_cast<RankedTensorType>();
   size_t rank = selfTy.getRank();
 
   Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
   auto loc = op->getLoc();
-  Value numel =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
-  for (size_t d = 0 ; d < rank; ++ d) {
-    Value dimSize = rewriter.create<arith::IndexCastOp>(
+  Value numel = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+  for (size_t d = 0; d < rank; ++d) {
+    Value dimSize = rewriter.create<arith::IndexCastOp>(
         loc, intType, rewriter.create<tensor::DimOp>(loc, self, d));
-    numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
+    numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
   }
 
   auto outTy = getTypeConverter()->convertType(op.getType());
   if (outTy != numel.getType()) {
-    rewriter.replaceOpWithNewOp<arith::ExtSIOp>(
-        op, outTy, numel);
+    rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, outTy, numel);
   } else {
     rewriter.replaceOp(op, numel);
   }
   return success();
 }
 
+// AtenClampOp
+template <>
+LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
+    AtenClampOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.self();
+  auto inputType = input.getType().cast<RankedTensorType>();
+  auto inputElemType = inputType.getElementType();
+  Value minValue = adaptor.min();
+  Value maxValue = adaptor.max();
+  if (failed(checkNotNone(rewriter, op, minValue)) &&
+      failed(checkNotNone(rewriter, op, maxValue))) {
+    return rewriter.notifyMatchFailure(
+        op, "this op should be folded as its `min` and `max` both are none");
+  } else if (failed(checkNotNone(rewriter, op, minValue))) {
+    maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
+    auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
+    if (failed(minInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to generate min value of dtype");
+    }
+    minValue = *minInfo;
+  } else if (failed(checkNotNone(rewriter, op, maxValue))) {
+    minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
+    auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
+    if (failed(maxInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to generate max value of dtype");
+    }
+    maxValue = *maxInfo;
+  } else {
+    minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
+    maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
+  }
+  rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue);
+  return success();
+}
+
 void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToMhloOptions &options) {
@@ -1047,6 +1137,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 
   INSERT_ATENOP_PATTERN(AtenErfOp);
   INSERT_ATENOP_PATTERN(AtenCatOp);
+  INSERT_ATENOP_PATTERN(AtenClampOp);
 
   INSERT_ATENOP_PATTERN(AtenBatchNormOp);
   INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
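Aside, for illustration only (not part of the diff): mhlo::ClampOp always takes
both bounds, so the pattern above fills a missing `min`/`max` with the dtype
extreme and reports a match failure (leaving the no-op clamp to the folder) when
both are None. The same three-way dispatch restated over scalars, using only the
C++ standard library; clampDispatch is a hypothetical name, not from the patch:

    #include <algorithm>
    #include <cassert>
    #include <limits>
    #include <optional>

    // Scalar restatement of ConvertAtenOp<AtenClampOp>::matchAndRewrite.
    template <typename T>
    std::optional<T> clampDispatch(T x, std::optional<T> min, std::optional<T> max) {
      if (!min && !max)
        return std::nullopt; // match failure: the no-op clamp should be folded away
      // A missing bound becomes the dtype extreme, mirroring
      // getMinValueOfDtype / getMaxValueOfDtype (+/-inf for floats,
      // integer min/max otherwise).
      T lo = min ? *min
                 : (std::numeric_limits<T>::has_infinity
                        ? -std::numeric_limits<T>::infinity()
                        : std::numeric_limits<T>::min());
      T hi = max ? *max
                 : (std::numeric_limits<T>::has_infinity
                        ? std::numeric_limits<T>::infinity()
                        : std::numeric_limits<T>::max());
      return std::clamp(x, lo, hi); // one mhlo::ClampOp(minValue, input, maxValue)
    }

    int main() {
      assert(*clampDispatch<float>(5.0f, std::nullopt, 3.0f) == 3.0f); // max only
      assert(*clampDispatch<int>(-9, -2, std::nullopt) == -2);         // min only
      assert(*clampDispatch<int>(7, 0, 4) == 4);                       // both bounds
      assert(!clampDispatch<float>(1.0f, std::nullopt, std::nullopt)); // both None
      return 0;
    }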