Add AtenClampOp conversion pattern to MHLO (#1356)

武家伟 2022-09-16 15:09:21 +08:00 committed by GitHub
parent e749831434
commit b316918947
2 changed files with 104 additions and 10 deletions

@@ -22,6 +22,9 @@ EAGER_MODE_XFAIL_SET = {
}
MHLO_PASS_SET = {
    "ElementwiseClampModule_basic",
    "ElementwiseClampMinModule_basic",
    "ElementwiseClampMaxModule_basic",
    "BmmModule_basic",
    "BroadcastToModule_basic",
    "ElementwiseExpModule_basic",

@@ -41,6 +41,60 @@ bool skipMultiplyAlpha(Value alphaValue) {
  return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1.0));
}

static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
                                           PatternRewriter &rewriter) {
  auto constType = RankedTensorType::get({}, elementType);
  if (elementType.isa<mlir::FloatType>()) {
    auto constAttr = SplatElementsAttr::get(
        constType,
        APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
                        /*negative=*/false));
    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
        .getResult();
  }
  if (elementType.isa<mlir::IntegerType>()) {
    auto integerType = elementType.cast<mlir::IntegerType>();
    DenseElementsAttr constAttr;
    if (integerType.isUnsigned()) {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getMaxValue(integerType.getWidth()));
    } else {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getSignedMaxValue(integerType.getWidth()));
    }
    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
        .getResult();
  }
  return failure();
}

static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
                                           PatternRewriter &rewriter) {
  auto constType = RankedTensorType::get({}, elementType);
  if (elementType.isa<mlir::FloatType>()) {
    auto constAttr = SplatElementsAttr::get(
        constType,
        APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
                        /*negative=*/true));
    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
        .getResult();
  }
  if (elementType.isa<mlir::IntegerType>()) {
    auto integerType = elementType.cast<mlir::IntegerType>();
    DenseElementsAttr constAttr;
    if (integerType.isUnsigned()) {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getMinValue(integerType.getWidth()));
    } else {
      constAttr = SplatElementsAttr::get(
          constType, APInt::getSignedMinValue(integerType.getWidth()));
    }
    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
        .getResult();
  }
  return failure();
}
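// Usage sketch (illustrative only, not part of the patch; `op`,
// `inputElemType`, and `rewriter` are assumed to come from an enclosing
// conversion pattern). On success the helper yields a 0-d mhlo.constant
// holding the dtype's extreme value, e.g. +inf for floats or the
// signed/unsigned maximum for integers:
//
//   FailureOr<Value> maxVal = getMaxValueOfDtype(op, inputElemType, rewriter);
//   if (failed(maxVal))
//     return rewriter.notifyMatchFailure(op, "unsupported element type");
//   Value upperBound = *maxVal;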
// These legalizations are for unary ops with only floating-point datatypes.
// There is no supported quantized integer mode for these.
namespace {
@@ -942,33 +996,69 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
// AtenNumelOp
template <>
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
    AtenNumelOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.self();
  auto selfTy = self.getType().dyn_cast<RankedTensorType>();
  size_t rank = selfTy.getRank();

  Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
  auto loc = op->getLoc();
  Value numel = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(intType, 1));
  for (size_t d = 0; d < rank; ++d) {
    Value dimSize = rewriter.create<arith::IndexCastOp>(
        loc, intType, rewriter.create<tensor::DimOp>(loc, self, d));
    numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
  }

  auto outTy = getTypeConverter()->convertType(op.getType());
  if (outTy != numel.getType()) {
    rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, outTy, numel);
  } else {
    rewriter.replaceOp(op, numel);
  }
  return success();
}
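// Scalar analogue of the arithmetic emitted above (a sketch; `shape` is a
// hypothetical std::vector<int64_t> of the tensor's dimension sizes): the
// element count is just the running product of the dimensions, which the
// pattern builds at runtime with arith.muli over tensor.dim results.
//
//   int64_t numel = 1;
//   for (int64_t dimSize : shape)
//     numel *= dimSize;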
// AtenClampOp
template <>
LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
    AtenClampOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputType = input.getType().cast<RankedTensorType>();
  auto inputElemType = inputType.getElementType();
  Value minValue = adaptor.min();
  Value maxValue = adaptor.max();
  if (failed(checkNotNone(rewriter, op, minValue)) &&
      failed(checkNotNone(rewriter, op, maxValue))) {
    return rewriter.notifyMatchFailure(
        op, "this op should be folded when both `min` and `max` are none");
  } else if (failed(checkNotNone(rewriter, op, minValue))) {
    maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
    auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
    if (failed(minInfo)) {
      return rewriter.notifyMatchFailure(
          op, "failed to generate min value of dtype");
    }
    minValue = *minInfo;
  } else if (failed(checkNotNone(rewriter, op, maxValue))) {
    minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
    auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
    if (failed(maxInfo)) {
      return rewriter.notifyMatchFailure(
          op, "failed to generate max value of dtype");
    }
    maxValue = *maxInfo;
  } else {
    minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
    maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
  }
  rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue);
  return success();
}
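// For intuition, a sketch of what this pattern produces for a one-sided
// clamp such as `torch.clamp(x, max=c)` on an f32 input (illustrative, not
// actual printer output): the missing `min` bound is materialized as the
// dtype's lowest value rather than being left unconstrained.
//
//   %min = mhlo.constant dense<0xFF800000> : tensor<f32>  // -inf
//   %max = <0-d tensor for `c`, built by mhlo::scalarToMhloTensor>
//   %out = mhlo.clamp %min, %x, %max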
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToMhloOptions &options) {
@@ -1047,6 +1137,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
  INSERT_ATENOP_PATTERN(AtenErfOp);
  INSERT_ATENOP_PATTERN(AtenCatOp);
  INSERT_ATENOP_PATTERN(AtenClampOp);
  INSERT_ATENOP_PATTERN(AtenBatchNormOp);
  INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);