[MHLO] Init MHLO pooling-like op conversion (#1141)

* [MHLO] Init MHLO pooling-like op conversion and remove 'op' suffix in filenames

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>

See RFC #999
pull/1156/head
武家伟 2022-08-04 12:34:22 +08:00 committed by GitHub
parent f0a24f59f6
commit d030591df9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 798 additions and 19 deletions

View File

@ -22,7 +22,6 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include <iostream>
#include <numeric>
@ -618,9 +617,8 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
namespace {
template <>
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
AtenGeluOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
AtenGeluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value input = adaptor.self();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
@ -641,7 +639,6 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
}
} // namespace
// AtenErfOp
namespace {
template <>

View File

@ -1,11 +1,12 @@
add_mlir_conversion_library(TorchMLIRTorchToMhlo
TorchToMhlo.cpp
MhloLegalizeUtils.cpp
BasicOp.cpp
GatherOp.cpp
Basic.cpp
Gather.cpp
Linear.cpp
ViewLikeOps.cpp
ReductionOp.cpp
ViewLike.cpp
Reduction.cpp
Pooling.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo

View File

@ -0,0 +1,557 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <iostream>
#include <numeric>
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Creates the rank-0 mhlo.constant that seeds the reduce_window emitted for a
// pooling op.
//
//  * Average pooling ops (AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp) get zero.
//  * Max pooling ops (AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp) get the most
//    negative finite float, or the signed minimum integer.
//
// Only float element types and integer element types whose width is not 8 bits
// are handled. For every other combination an error is emitted on `op` and a
// null Value is returned.
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
                                                PatternRewriter &rewriter) {
  auto scalarTy = RankedTensorType::get({}, elementTy);

  const bool isAvgPool = isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp>(op);
  const bool isMaxPool = isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op);
  const bool isFloat = elementTy.isa<mlir::FloatType>();
  // The bit-width query is only reached once the type is known to be an
  // integer.
  const bool isNonI8Int = elementTy.isa<mlir::IntegerType>() &&
                          elementTy.getIntOrFloatBitWidth() != 8;

  // Materializes `value` as a scalar mhlo.constant at the location of `op`.
  auto emitConstant = [&](DenseElementsAttr value) -> Value {
    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), scalarTy, value);
  };

  // Avg pooling
  if (isAvgPool && isFloat) {
    return emitConstant(DenseElementsAttr::get(
        scalarTy, {APFloat::getZero(
                      elementTy.cast<mlir::FloatType>().getFloatSemantics(),
                      /*negative=*/false)}));
  }
  if (isAvgPool && isNonI8Int) {
    return emitConstant(DenseElementsAttr::get(
        scalarTy, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}));
  }

  // Max pooling
  if (isMaxPool && isFloat) {
    return emitConstant(DenseElementsAttr::get(
        scalarTy, {APFloat::getLargest(
                      elementTy.cast<mlir::FloatType>().getFloatSemantics(),
                      /*negative=*/true)}));
  }
  if (isMaxPool && isNonI8Int) {
    return emitConstant(DenseElementsAttr::get(
        scalarTy,
        {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}));
  }

  op->emitError("unimplemented lowering in AtenPoolingOp");
  return nullptr;
}
namespace {
// Conversion pattern shared by the pooling-like aten ops handled in this file.
// Only the declaration of matchAndRewrite lives here; its behavior is supplied
// per op through explicit specializations for each AtenOpT.
template <typename AtenOpT>
class ConvertAtenPoolingOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
// Adaptor type giving access to the already-converted operands of AtenOpT.
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
// AtenMaxPool2dOp
namespace {
// Lowers aten.max_pool2d to a single-input mhlo.reduce_window whose reduction
// body is mhlo.maximum.
//
// The match fails when kernel_size, stride, padding, dilation or ceil_mode is
// not a compile-time constant, or when one of the integer lists does not have
// exactly two entries (after an empty stride has been defaulted to the kernel
// size). Inputs of rank <= 2 are reported as an error, as are element types
// for which no initial value can be built.
template <>
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
    AtenMaxPool2dOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputTy = input.getType().cast<RankedTensorType>();
  auto inputElemTy = inputTy.getElementType();
  auto inputRank = inputTy.getRank();
  auto outTy =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();

  if (inputRank <= 2) {
    return op.emitError(
        "max_pooling2d only supports inputs with rank higher than 2");
  }

  SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
  // NOTE(review): ceilMode is matched below but not otherwise used by this
  // lowering -- confirm how ceil_mode=true is expected to be handled.
  bool ceilMode = false;

  if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const int kernel size unsupported!");
  }
  if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) {
    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
  }
  if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int padding unsupported!");
  }
  if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int dilation unsupported!");
  }
  if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const bool ceil_mode unsupported!");
  }

  // In PyTorch an unspecified stride defaults to the kernel size; treat an
  // empty stride list that way instead of silently using a stride of 1.
  if (stride.empty())
    stride = kernelSize;

  // Everything below indexes these lists as (H, W) pairs and copies them into
  // the last two slots of rank-sized vectors, so any other length would read
  // or write out of bounds.
  if (kernelSize.size() != 2 || stride.size() != 2 || padding.size() != 2 ||
      dilation.size() != 2) {
    return rewriter.notifyMatchFailure(
        op, "expected kernel_size, stride, padding and dilation to have "
            "exactly 2 elements");
  }

  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
  // input
  SmallVector<int64_t> mhloStride(inputRank, 1);
  SmallVector<int64_t> mhloDilation(inputRank, 1);
  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
  std::copy(dilation.begin(), dilation.end(),
            mhloDilation.begin() + inputRank - 2);
  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
  std::copy(kernelSize.begin(), kernelSize.end(),
            mhloKernelSize.begin() + inputRank - 2);

  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
  // The helper has already emitted an error when it returns a null Value.
  if (!initVal)
    return failure();

  // mhloPadding is the flattened [rank x 2] (low, high) table; only the two
  // trailing spatial dimensions are padded, symmetrically.
  mhloPadding[mhloPadding.size() - 4] = padding[0];
  mhloPadding[mhloPadding.size() - 3] = padding[0];
  mhloPadding[mhloPadding.size() - 2] = padding[1];
  mhloPadding[mhloPadding.size() - 1] = padding[1];

  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
                            rewriter.getI64Type()),
      mhloKernelSize);
  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
                            rewriter.getI64Type()),
      mhloStride);
  // Left null on purpose: no base dilation attribute is attached.
  DenseIntElementsAttr baseDilations;
  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
                            rewriter.getI64Type()),
      mhloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      mhloPadding);

  auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
      baseDilations, windowDilations, pad);

  // Reduction body: two scalar tensors in, their maximum out.
  Block &block = reduceWindowOp.body().emplaceBlock();
  auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
  block.addArgument(blockArgumentTy, op->getLoc());
  block.addArgument(blockArgumentTy, op->getLoc());
  auto *firstArg = block.args_begin();
  // With exactly two arguments the reverse-begin iterator is the second one.
  auto secondArg = block.args_rbegin();

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
    Value result =
        rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
    rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);
  }

  rewriter.replaceOp(op, reduceWindowOp.getResults());
  return success();
}
} // namespace
// AtenMaxPool2dWithIndicesOp
namespace {
// Lowers aten.max_pool2d_with_indices to a two-input mhlo.reduce_window that
// reduces (value, index) pairs.
//
// The index operand is an i64 iota over the flattened last two input
// dimensions, reshaped back to the input shape. The reduction body keeps the
// larger value; when the two values compare equal it keeps the smaller index.
//
// The match fails when kernel_size, stride, padding, dilation or ceil_mode is
// not a compile-time constant, when one of the integer lists does not have
// exactly two entries (after an empty stride has been defaulted to the kernel
// size), or when the input dimension sizes cannot be obtained. Inputs of
// rank <= 2 are reported as an error, as are element types for which no
// initial value can be built.
template <>
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
    AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputTy = input.getType().cast<RankedTensorType>();
  auto inputElemTy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();
  auto inputRank = inputTy.getRank();
  auto outValTy =
      getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>();
  auto outIdxTy =
      getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();

  if (inputRank <= 2) {
    return op.emitError(
        "max_pooling2d only supports inputs with rank higher than 2");
  }

  SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
  // NOTE(review): ceilMode is matched below but not otherwise used by this
  // lowering -- confirm how ceil_mode=true is expected to be handled.
  bool ceilMode = false;

  if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const int kernel size unsupported!");
  }
  if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) {
    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
  }
  if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int padding unsupported!");
  }
  if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int dilation unsupported!");
  }
  if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const bool ceil_mode unsupported!");
  }

  // In PyTorch an unspecified stride defaults to the kernel size; treat an
  // empty stride list that way instead of silently using a stride of 1.
  if (stride.empty())
    stride = kernelSize;

  // Everything below indexes these lists as (H, W) pairs and copies them into
  // the last two slots of rank-sized vectors, so any other length would read
  // or write out of bounds.
  if (kernelSize.size() != 2 || stride.size() != 2 || padding.size() != 2 ||
      dilation.size() != 2) {
    return rewriter.notifyMatchFailure(
        op, "expected kernel_size, stride, padding and dilation to have "
            "exactly 2 elements");
  }

  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
  // input
  SmallVector<int64_t> mhloStride(inputRank, 1);
  SmallVector<int64_t> mhloDilation(inputRank, 1);
  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
  std::copy(dilation.begin(), dilation.end(),
            mhloDilation.begin() + inputRank - 2);
  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
  std::copy(kernelSize.begin(), kernelSize.end(),
            mhloKernelSize.begin() + inputRank - 2);

  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
  // The helper has already emitted an error when it returns a null Value.
  if (!initVal)
    return failure();

  // mhloPadding is the flattened [rank x 2] (low, high) table; only the two
  // trailing spatial dimensions are padded, symmetrically.
  mhloPadding[mhloPadding.size() - 4] = padding[0];
  mhloPadding[mhloPadding.size() - 3] = padding[0];
  mhloPadding[mhloPadding.size() - 2] = padding[1];
  mhloPadding[mhloPadding.size() - 1] = padding[1];

  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
                            rewriter.getI64Type()),
      mhloKernelSize);
  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
                            rewriter.getI64Type()),
      mhloStride);
  // Left null on purpose: no base dilation attribute is attached.
  DenseIntElementsAttr baseDilations;
  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
                            rewriter.getI64Type()),
      mhloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      mhloPadding);

  auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
  if (failed(inputShapeInfo)) {
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");
  }
  auto inputShapeVec = *inputShapeInfo;
  auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
      op->getLoc(), inputShapeVec);

  // Runtime shape of the index seed: the leading dimensions are kept and the
  // last two are collapsed into one of size H * W.
  SmallVector<Value> initIndexShapeVec;
  for (int64_t i = 0; i < inputRank - 2; i++)
    initIndexShapeVec.push_back(inputShapeVec[i]);
  initIndexShapeVec.push_back(rewriter.create<mlir::arith::MulIOp>(
      op->getLoc(), inputShapeVec[inputRank - 1],
      inputShapeVec[inputRank - 2]));
  auto initIndexShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
      op->getLoc(), initIndexShapeVec);

  // Static counterpart of the shape above; the collapsed dimension is dynamic
  // as soon as either of the two trailing dimensions is.
  SmallVector<int64_t> initIndexShapeForType(inputShape.begin(),
                                             inputShape.end() - 2);
  if (inputShape[inputRank - 1] == ShapedType::kDynamicSize ||
      inputShape[inputRank - 2] == ShapedType::kDynamicSize) {
    initIndexShapeForType.push_back(ShapedType::kDynamicSize);
  } else {
    initIndexShapeForType.push_back(inputShape[inputRank - 1] *
                                    inputShape[inputRank - 2]);
  }

  // Iota along the collapsed dimension, then reshaped to the input shape.
  auto initIndexTensor =
      rewriter
          .create<mhlo::DynamicIotaOp>(
              op->getLoc(),
              RankedTensorType::get(initIndexShapeForType,
                                    rewriter.getI64Type()),
              initIndexShapeTensor, static_cast<uint64_t>(inputRank - 2))
          .getResult();
  auto indexTensor =
      rewriter
          .create<mhlo::DynamicReshapeOp>(
              op->getLoc(),
              RankedTensorType::get(inputShape, rewriter.getI64Type()),
              initIndexTensor, inputShapeTensor)
          .getResult();

  Value initIdx =
      mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();

  auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
      op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
      mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
      windowDimensions, windowStrides, baseDilations, windowDilations, pad);

  Block &block = reduceWindowOp.body().emplaceBlock();

  // Add bb argument
  auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
  auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
  auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
  block.addArgument(blockValArgumentType, op->getLoc());
  block.addArgument(blockIdxArgumentType, op->getLoc());
  block.addArgument(blockValArgumentType, op->getLoc());
  block.addArgument(blockIdxArgumentType, op->getLoc());
  auto *firstValArg = block.args_begin();
  auto *firstIdxArg = std::next(firstValArg);
  auto *secondValArg = std::next(firstIdxArg);
  auto *secondIdxArg = std::next(secondValArg);

  // The element type is float or integer here: any other type already failed
  // when the initial value was built.
  mhlo::ComparisonTypeAttr compareTypeAttr;
  if (inputTy.getElementType().isa<mlir::FloatType>()) {
    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
        rewriter.getContext(), mhlo::ComparisonType::FLOAT);
  } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
        rewriter.getContext(), mhlo::ComparisonType::SIGNED);
  }
  mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
                                         mhlo::ComparisonDirection::GE);
  mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
                                         mhlo::ComparisonDirection::EQ);

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);

    Value compareGeResult = rewriter.create<mhlo::CompareOp>(
        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
        compareGeDirectionAttr, compareTypeAttr);
    Value retValResult = rewriter.create<mhlo::SelectOp>(
        op->getLoc(), compareGeResult, *firstValArg, *secondValArg);

    // Get smaller index if compared values are equal.
    Value compareEqResult = rewriter.create<mhlo::CompareOp>(
        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
        compareEqDirectionAttr, compareTypeAttr);
    Value minIdx =
        rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
    Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
        op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
    Value retIdxResult = rewriter.create<mhlo::SelectOp>(
        op->getLoc(), compareEqResult, minIdx, idxWithGeVal);

    rewriter.create<mhlo::ReturnOp>(
        op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
  }

  rewriter.replaceOp(op, reduceWindowOp.getResults());
  return success();
}
} // namespace
// AtenAvgPool2dOp
namespace {
// Lowers aten.avg_pool2d to a summing mhlo.reduce_window followed by a
// division.
//
//  * count_include_pad == true: the window sum is divided by the constant
//    kernelSize[0] * kernelSize[1].
//  * count_include_pad == false: a second reduce_window sums a tensor of ones
//    with the same window configuration, and the two results are divided
//    element-wise.
//
// The match fails when kernel_size, stride, padding, ceil_mode or
// count_include_pad is not a compile-time constant, when divisor_override is
// not None, when one of the integer lists does not have exactly two entries
// (after an empty stride has been defaulted to the kernel size), or when the
// input dimension sizes cannot be obtained. Inputs of rank <= 2 are reported
// as an error, as are element types for which no initial value can be built.
template <>
LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
    AtenAvgPool2dOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputTy = input.getType().cast<RankedTensorType>();
  auto inputElemTy = inputTy.getElementType();
  auto inputRank = inputTy.getRank();
  auto outTy =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
  auto outShape = outTy.getShape();

  if (inputRank <= 2) {
    return op.emitError(
        "avg_pooling2d only supports inputs with rank higher than 2");
  }
  SmallVector<int64_t, 2> padding, kernelSize, stride;
  // NOTE(review): ceilMode is matched below but not otherwise used by this
  // lowering -- confirm how ceil_mode=true is expected to be handled.
  bool ceilMode = false;
  bool countIncludePad = true;

  if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const int kernel size unsupported!");
  }
  if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) {
    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
  }
  if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int padding unsupported!");
  }
  if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const bool ceil_mode unsupported!");
  }
  if (!(matchPattern(op.count_include_pad(),
                     m_TorchConstantBool(&countIncludePad)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const bool count_include_pad unsupported!");
  }
  if (succeeded(checkNotNone(rewriter, op, op.divisor_override()))) {
    return rewriter.notifyMatchFailure(
        op, "only None divisor_override supported for now!");
  }

  // In PyTorch an unspecified stride defaults to the kernel size; treat an
  // empty stride list that way instead of silently using a stride of 1.
  if (stride.empty())
    stride = kernelSize;

  // Everything below indexes these lists as (H, W) pairs and copies them into
  // the last two slots of rank-sized vectors, so any other length would read
  // or write out of bounds.
  if (kernelSize.size() != 2 || stride.size() != 2 || padding.size() != 2) {
    return rewriter.notifyMatchFailure(
        op, "expected kernel_size, stride and padding to have exactly 2 "
            "elements");
  }

  // prepend 1 to kernelSize and stride until they are of same rank as input;
  // avg_pool2d has no dilation, so the window dilation stays all ones.
  SmallVector<int64_t> mhloStride(inputRank, 1);
  SmallVector<int64_t> mhloDilation(inputRank, 1);
  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);

  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
  std::copy(kernelSize.begin(), kernelSize.end(),
            mhloKernelSize.begin() + inputRank - 2);
  // mhloPadding is the flattened [rank x 2] (low, high) table; only the two
  // trailing spatial dimensions are padded, symmetrically.
  mhloPadding[mhloPadding.size() - 4] = padding[0];
  mhloPadding[mhloPadding.size() - 3] = padding[0];
  mhloPadding[mhloPadding.size() - 2] = padding[1];
  mhloPadding[mhloPadding.size() - 1] = padding[1];

  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
  // The helper has already emitted an error when it returns a null Value.
  if (!initVal)
    return failure();

  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
                            rewriter.getI64Type()),
      mhloKernelSize);
  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
                            rewriter.getI64Type()),
      mhloStride);
  // Left null on purpose: no base dilation attribute is attached.
  DenseIntElementsAttr baseDilations;
  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
                            rewriter.getI64Type()),
      mhloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      mhloPadding);

  auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
      baseDilations, windowDilations, pad);

  Block &sumBlock = reduceWindowSum.body().emplaceBlock();

  // Add bb argument
  auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
  sumBlock.addArgument(blockArgumentType, op->getLoc());
  sumBlock.addArgument(blockArgumentType, op->getLoc());
  auto *firstArg = sumBlock.args_begin();
  // With exactly two arguments the reverse-begin iterator is the second one.
  auto secondArg = sumBlock.args_rbegin();

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&sumBlock);

    Value sumResult =
        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
  }

  // Use kernel size as the divisor
  if (countIncludePad) {
    Value divisor = mhlo::getConstTensor<int64_t>(
                        rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
                        .getValue();
    divisor = mhlo::promoteType(rewriter, divisor, outTy);
    DenseIntElementsAttr bcastDimensions;
    rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
        op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
    return success();
  }

  // Use another mhlo.ReduceWindowOp to get the divisor
  Value windowSizeConst =
      mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).getValue();
  windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
  // Check the result before dereferencing it, as the max_pool2d_with_indices
  // lowering does.
  auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
  if (failed(inputShapeInfo)) {
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");
  }
  auto inputShapeVec = *inputShapeInfo;
  auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
      op->getLoc(), inputShapeVec);

  // Broadcast the scalar 1 to the input shape so that summing it over each
  // window yields the per-output divisor.
  windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
      op->getLoc(),
      RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
      windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));

  Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
  if (!zero)
    return failure();
  auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>(
      op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
      windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
      windowDilations, pad);

  Block &sizeBlock = reduceWindowSize.body().emplaceBlock();

  // Add bb argument
  blockArgumentType = RankedTensorType::get({}, inputElemTy);
  sizeBlock.addArgument(blockArgumentType, op->getLoc());
  sizeBlock.addArgument(blockArgumentType, op->getLoc());
  firstArg = sizeBlock.args_begin();
  secondArg = sizeBlock.args_rbegin();

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&sizeBlock);

    Value sumResult =
        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
  }

  rewriter.replaceOpWithNewOp<mhlo::DivOp>(
      op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
  return success();
}
} // namespace
// Marks the pooling ops handled in this file as illegal on `target` and adds
// their ConvertAtenPoolingOp patterns to `patterns`.
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *ctx = patterns.getContext();

  target.addIllegalOp<AtenMaxPool2dOp, AtenAvgPool2dOp,
                      AtenMaxPool2dWithIndicesOp>();

  patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dOp>,
               ConvertAtenPoolingOp<AtenAvgPool2dOp>,
               ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
                                                                 ctx);
}

View File

@ -32,6 +32,10 @@ void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
} // namespace torch_to_mhlo
} // namespace torch
} // namespace mlir

View File

@ -13,7 +13,6 @@
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
@ -24,7 +23,6 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
using namespace mlir;
using namespace mlir::torch;
@ -44,8 +42,9 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect, tensor::TensorDialect,
arith::ArithmeticDialect, Torch::TorchDialect>();
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
tensor::TensorDialect, arith::ArithmeticDialect,
Torch::TorchDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
@ -55,14 +54,16 @@ public:
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
target);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
target);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter,
patterns, target);
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
target);
torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter,
patterns, target);
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
target);
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
target);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {

View File

@ -0,0 +1,218 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
// -----
// CHECK-LABEL: func.func @torch.aten.max_pool2d(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int0 = torch.constant.int 0
// CHECK: %false = torch.constant.bool false
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: "mhlo.return"(%[[VAL_10]]) : (tensor<f32>) -> ()
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
// Test input: max_pool2d on a fully dynamic rank-4 f32 tensor with
// kernel_size = [2, 2], stride = [1, 1], padding = [0, 0], dilation = [2, 1]
// and ceil_mode = false.
func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.max_pool2d$padding(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %false = torch.constant.bool false
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: "mhlo.return"(%[[VAL_10]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
// Test input: max_pool2d with non-zero padding. The same [2, 1] list (%2) is
// passed as both padding and dilation; kernel_size = [2, 2], stride = [1, 1]
// and ceil_mode = false.
func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%false = torch.constant.bool false
%0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.max_pool2d %arg0, %0, %1, %2, %2, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %3 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.max_pool2d_with_indices(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>) {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %int3 = torch.constant.int 3
// CHECK: %int2 = torch.constant.int 2
// CHECK: %false = torch.constant.bool false
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_7]] : index to i64
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : index to i64
// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_11:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i64
// CHECK: %[[VAL_13:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_10]], %[[VAL_12]] : tensor<3xi64>
// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_10]] : i64
// CHECK: %[[VAL_15:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_14]] : tensor<2xi64>
// CHECK: %[[VAL_16:.*]] = "mhlo.dynamic_iota"(%[[VAL_15]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_reshape"(%[[VAL_16]], %[[VAL_13]]) : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: %[[VAL_19:.*]]:2 = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_17]], %[[VAL_6]], %[[VAL_18]]) ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<i64>, %[[IVAL_2:.*]]: tensor<f32>, %[[IVAL_3:.*]]: tensor<i64>):
// CHECK: %[[IVAL_4:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[IVAL_5:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_0]], %[[IVAL_2]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[IVAL_6:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[IVAL_7:.*]] = mhlo.minimum %[[IVAL_1]], %[[IVAL_3]] : tensor<i64>
// CHECK: %[[IVAL_8:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_1]], %[[IVAL_3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: %[[IVAL_9:.*]] = "mhlo.select"(%[[IVAL_6]], %[[IVAL_7]], %[[IVAL_8]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: "mhlo.return"(%[[IVAL_5]], %[[IVAL_9]]) : (tensor<f32>, tensor<i64>) -> ()
// CHECK{LITERAL}: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_19]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
// CHECK: return %[[VAL_20]], %[[VAL_21]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
// Test input: torch.aten.max_pool2d_with_indices applied to a rank-3, fully
// dynamic f32 tensor. The op yields two rank-3 results: an f32 tensor (pooled
// values) and an si64 tensor (indices), both of which are returned.
func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>) -> (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>) {
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%false = torch.constant.bool false
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
// The four lists below surface in the expected mhlo.reduce_window attributes
// as window_dimensions [1, 3, 3], window_strides [1, 2, 2], all-zero padding
// and all-one window_dilations, in that order.
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// NOTE(review): %false is presumably ceil_mode, going by the ATen operand
// order (self, kernel_size, stride, padding, dilation, ceil_mode); confirm.
%result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
return %result0, %result1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool2d(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %int3 = torch.constant.int 3
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %false = torch.constant.bool false
// CHECK: %none = torch.constant.none
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
// CHECK: "mhlo.return"(%[[IVAL_2]]) : (tensor<f32>) -> ()
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : index to i64
// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i64
// CHECK: %[[IDX_3:.*]] = arith.constant 3 : index
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
// CHECK: %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[IVAL_5]]) : (tensor<f32>) -> ()
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
// Test input: torch.aten.avg_pool2d on a rank-4, fully dynamic f32 tensor with
// both boolean operands set to false. The expected output preceding this
// function sums each window with mhlo.reduce_window, runs the same windowed
// sum over an all-ones tensor broadcast to the input shape, and divides the
// first result by the second.
func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%false = torch.constant.bool false
%none = torch.constant.none
// The three lists below surface in the expected reduce_window attributes as
// window_dimensions [1, 1, 3, 3], window_strides [1, 1, 2, 2] and a padding
// of [1, 1] on each of the last two dimensions.
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// NOTE(review): the trailing operands are presumably ceil_mode,
// count_include_pad and divisor_override, per the ATen avg_pool2d signature;
// confirm.
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
return %3 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool2d$count_include_pad(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %int3 = torch.constant.int 3
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %false = torch.constant.bool false
// CHECK: %none = torch.constant.none
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
// CHECK: "mhlo.return"(%[[IVAL_2]]) : (tensor<f32>) -> ()
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<9> : tensor<i64>
// CHECK: %[[VAL_8:.*]] = mhlo.convert(%[[VAL_7]]) : (tensor<i64>) -> tensor<f32>
// CHECK: %[[VAL_9:.*]] = chlo.broadcast_divide %[[VAL_6]], %[[VAL_8]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?,?,?],f32>
// Test input: torch.aten.avg_pool2d on a rank-4, fully dynamic f32 tensor with
// the second boolean operand set to true. The expected output preceding this
// function sums each window with mhlo.reduce_window and then divides by a
// scalar constant 9 converted to f32 (9 equals the 3x3 window size), instead
// of computing a per-window element count.
func.func @torch.aten.avg_pool2d$count_include_pad(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%false = torch.constant.bool false
%true = torch.constant.bool true
%none = torch.constant.none
// The three lists below surface in the expected reduce_window attributes as
// window_dimensions [1, 1, 3, 3], window_strides [1, 1, 2, 2] and a padding
// of [1, 1] on each of the last two dimensions.
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// NOTE(review): %false is presumably ceil_mode, %true count_include_pad (as
// the function-name suffix suggests) and %none divisor_override, per the ATen
// avg_pool2d signature; confirm.
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
return %3 : !torch.vtensor<[?,?,?,?],f32>
}

View File

@ -448,11 +448,12 @@ cc_library(
srcs = [
"lib/Conversion/TorchToMhlo/TorchToMhlo.cpp",
"lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp",
"lib/Conversion/TorchToMhlo/BasicOp.cpp",
"lib/Conversion/TorchToMhlo/GatherOp.cpp",
"lib/Conversion/TorchToMhlo/Basic.cpp",
"lib/Conversion/TorchToMhlo/Gather.cpp",
"lib/Conversion/TorchToMhlo/Linear.cpp",
"lib/Conversion/TorchToMhlo/ViewLikeOps.cpp",
"lib/Conversion/TorchToMhlo/ReductionOp.cpp",
"lib/Conversion/TorchToMhlo/ViewLike.cpp",
"lib/Conversion/TorchToMhlo/Reduction.cpp",
"lib/Conversion/TorchToMhlo/Pooling.cpp",
"lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h",
"lib/Conversion/TorchToMhlo/PopulatePatterns.h",
"lib/Conversion/PassDetail.h",