mirror of https://github.com/llvm/torch-mlir
Add AtenClampOp conversion pattern to MHLO (#1356)
Add AtenClampOp conversion pattern to MHLOpull/1378/head snapshot-20220916.598
parent
e749831434
commit
b316918947
|
@ -22,6 +22,9 @@ EAGER_MODE_XFAIL_SET = {
|
|||
}
|
||||
|
||||
MHLO_PASS_SET = {
|
||||
"ElementwiseClampModule_basic",
|
||||
"ElementwiseClampMinModule_basic",
|
||||
"ElementwiseClampMaxModule_basic",
|
||||
"BmmModule_basic",
|
||||
"BroadcastToModule_basic",
|
||||
"ElementwiseExpModule_basic",
|
||||
|
|
|
@ -41,6 +41,60 @@ bool skipMultiplyAlpha(Value alphaValue) {
|
|||
return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1.0));
|
||||
}
|
||||
|
||||
static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
|
||||
PatternRewriter &rewriter) {
|
||||
auto constType = RankedTensorType::get({}, elementType);
|
||||
if (elementType.isa<mlir::FloatType>()) {
|
||||
auto constAttr = SplatElementsAttr::get(
|
||||
constType,
|
||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/false));
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
if (elementType.isa<mlir::IntegerType>()) {
|
||||
auto integerType = elementType.cast<mlir::IntegerType>();
|
||||
DenseElementsAttr constAttr;
|
||||
if (integerType.isUnsigned()) {
|
||||
constAttr = SplatElementsAttr::get(
|
||||
constType, APInt::getMaxValue(integerType.getWidth()));
|
||||
} else {
|
||||
constAttr = SplatElementsAttr::get(
|
||||
constType, APInt::getSignedMaxValue(integerType.getWidth()));
|
||||
}
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
|
||||
PatternRewriter &rewriter) {
|
||||
auto constType = RankedTensorType::get({}, elementType);
|
||||
if (elementType.isa<mlir::FloatType>()) {
|
||||
auto constAttr = SplatElementsAttr::get(
|
||||
constType,
|
||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/true));
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
if (elementType.isa<mlir::IntegerType>()) {
|
||||
auto integerType = elementType.cast<mlir::IntegerType>();
|
||||
DenseElementsAttr constAttr;
|
||||
if (integerType.isUnsigned()) {
|
||||
constAttr = SplatElementsAttr::get(
|
||||
constType, APInt::getMinValue(integerType.getWidth()));
|
||||
} else {
|
||||
constAttr = SplatElementsAttr::get(
|
||||
constType, APInt::getSignedMinValue(integerType.getWidth()));
|
||||
}
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
// These legalizations are for unary ops with only for floating point datatypes.
|
||||
// There is no supported quantized integer mode for these.
|
||||
namespace {
|
||||
|
@ -942,33 +996,69 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
// AtenNumelOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
|
||||
AtenNumelOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
AtenNumelOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.self();
|
||||
auto selfTy = self.getType().dyn_cast<RankedTensorType>();
|
||||
size_t rank = selfTy.getRank();
|
||||
|
||||
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
||||
auto loc = op->getLoc();
|
||||
Value numel =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
|
||||
for (size_t d = 0 ; d < rank; ++ d) {
|
||||
Value dimSize = rewriter.create<arith::IndexCastOp>(
|
||||
Value numel = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(intType, 1));
|
||||
for (size_t d = 0; d < rank; ++d) {
|
||||
Value dimSize = rewriter.create<arith::IndexCastOp>(
|
||||
loc, intType, rewriter.create<tensor::DimOp>(loc, self, d));
|
||||
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
|
||||
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
|
||||
}
|
||||
|
||||
auto outTy = getTypeConverter()->convertType(op.getType());
|
||||
if (outTy != numel.getType()) {
|
||||
rewriter.replaceOpWithNewOp<arith::ExtSIOp>(
|
||||
op, outTy, numel);
|
||||
rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, outTy, numel);
|
||||
} else {
|
||||
rewriter.replaceOp(op, numel);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// AtenClampOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
||||
AtenClampOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.self();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputElemType = inputType.getElementType();
|
||||
Value minValue = adaptor.min();
|
||||
Value maxValue = adaptor.max();
|
||||
if (failed(checkNotNone(rewriter, op, minValue)) &&
|
||||
failed(checkNotNone(rewriter, op, maxValue))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "this op should be folded as its `min` and `max` both are none");
|
||||
} else if (failed(checkNotNone(rewriter, op, minValue))) {
|
||||
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
|
||||
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
|
||||
if (failed(minInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to generate min value of dtype");
|
||||
}
|
||||
minValue = *minInfo;
|
||||
} else if (failed(checkNotNone(rewriter, op, maxValue))) {
|
||||
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
|
||||
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
|
||||
if (failed(maxInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to generate max value of dtype");
|
||||
}
|
||||
maxValue = *maxInfo;
|
||||
} else {
|
||||
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
|
||||
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
|
@ -1047,6 +1137,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenErfOp);
|
||||
|
||||
INSERT_ATENOP_PATTERN(AtenCatOp);
|
||||
INSERT_ATENOP_PATTERN(AtenClampOp);
|
||||
|
||||
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
||||
|
|
Loading…
Reference in New Issue