mirror of https://github.com/llvm/torch-mlir
[MHLO] Init MHLO pooling-like op conversion (#1141)
* [MHLO] Init MHLO pooling-like op conversion and remove 'op' suffix in filenames

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>

See RFC #999
parent f0a24f59f6
commit d030591df9
@@ -22,7 +22,6 @@
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include <iostream>
 #include <numeric>
 
@@ -618,9 +617,8 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
 namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
-    AtenGeluOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter) const {
+    AtenGeluOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value input = adaptor.self();
   auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
@@ -641,7 +639,6 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
 }
 } // namespace
-
 
 // AtenErfOp
 namespace {
 template <>
@@ -1,11 +1,12 @@
 add_mlir_conversion_library(TorchMLIRTorchToMhlo
   TorchToMhlo.cpp
   MhloLegalizeUtils.cpp
-  BasicOp.cpp
-  GatherOp.cpp
+  Basic.cpp
+  Gather.cpp
   Linear.cpp
-  ViewLikeOps.cpp
-  ReductionOp.cpp
+  ViewLike.cpp
+  Reduction.cpp
+  Pooling.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
@@ -0,0 +1,557 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <iostream>
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
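
// The scalar constant materialized below is the identity element of the window
// reduction that consumes it: zero for the summation used by average pooling,
// and the lowest representable value (e.g. -3.40282347E+38 for f32, the signed
// minimum for integers) for max pooling.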
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
                                                PatternRewriter &rewriter) {
  auto constType = RankedTensorType::get({}, elementTy);
  // Avg pooling
  if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp>(op)) {
    if (elementTy.isa<mlir::FloatType>()) {
      auto constAttr = DenseElementsAttr::get(
          constType, {APFloat::getZero(
                         elementTy.cast<mlir::FloatType>().getFloatSemantics(),
                         /*negative=*/false)});
      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
                                               constAttr);
    } else if (elementTy.isa<mlir::IntegerType>() &&
               elementTy.getIntOrFloatBitWidth() != 8) {
      auto constAttr = DenseElementsAttr::get(
          constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
                                               constAttr);
    }
  }

  // Max pooling
  if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
    if (elementTy.isa<mlir::FloatType>()) {
      auto constAttr = DenseElementsAttr::get(
          constType, {APFloat::getLargest(
                         elementTy.cast<mlir::FloatType>().getFloatSemantics(),
                         /*negative=*/true)});
      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
                                               constAttr);
    } else if (elementTy.isa<mlir::IntegerType>() &&
               elementTy.getIntOrFloatBitWidth() != 8) {
      auto constAttr = DenseElementsAttr::get(
          constType,
          {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
                                               constAttr);
    }
  }
  op->emitError("unimplemented lowering in AtenPoolingOp");
  return nullptr;
}

namespace {
template <typename AtenOpT>
class ConvertAtenPoolingOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

// AtenMaxPool2dOp
namespace {
template <>
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
    AtenMaxPool2dOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputTy = input.getType().cast<RankedTensorType>();
  auto inputElemTy = inputTy.getElementType();

  auto inputRank = inputTy.getRank();
  auto outTy =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();

  if (inputRank <= 2) {
    return op.emitError(
        "max_pooling2d only supports inputs with rank higher than 2");
  }
  SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
  bool ceilMode = false;

  if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const int kernel size unsupported!");
  }
  if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) {
    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
  }
  if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int padding unsupported!");
  }
  if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int dilation unsupported!");
  }
  if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const bool ceil_mode unsupported!");
  }

  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
  // input
  SmallVector<int64_t> mhloStride(inputRank, 1);
  SmallVector<int64_t> mhloDilation(inputRank, 1);
  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
  std::copy(dilation.begin(), dilation.end(),
            mhloDilation.begin() + inputRank - 2);
  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
  std::copy(kernelSize.begin(), kernelSize.end(),
            mhloKernelSize.begin() + inputRank - 2);
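  // e.g. for an NCHW input (rank 4) with kernel_size = [2, 2], stride = [1, 1]
  // and dilation = [2, 1], this yields window_dimensions = [1, 1, 2, 2],
  // window_strides = [1, 1, 1, 1] and window_dilations = [1, 1, 2, 1].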

  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

  mhloPadding[mhloPadding.size() - 4] = padding[0];
  mhloPadding[mhloPadding.size() - 3] = padding[0];
  mhloPadding[mhloPadding.size() - 2] = padding[1];
  mhloPadding[mhloPadding.size() - 1] = padding[1];
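  // mhlo.reduce_window takes padding as an (inputRank x 2) table of
  // (low, high) pairs; the last two rows carry the torch padding for the two
  // spatial dims, so padding = [2, 1] becomes [[0, 0], [0, 0], [2, 2], [1, 1]].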

  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
                            rewriter.getI64Type()),
      mhloKernelSize);
  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
                            rewriter.getI64Type()),
      mhloStride);
  DenseIntElementsAttr baseDilations;
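  // Left null on purpose: without a base_dilations attribute,
  // mhlo.reduce_window behaves as if the base dilation were all ones.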
  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
                            rewriter.getI64Type()),
      mhloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      mhloPadding);
  auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
      baseDilations, windowDilations, pad);

  Block &block = reduceWindowOp.body().emplaceBlock();
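  // The reduction region receives the running maximum and the incoming window
  // element as rank-0 tensors and combines them with mhlo.maximum.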

  auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
  block.addArgument(blockArgumentTy, op->getLoc());
  block.addArgument(blockArgumentTy, op->getLoc());

  auto *firstArg = block.args_begin();
  auto secondArg = block.args_rbegin();

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
    Value result =
        rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
    rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);
  }

  rewriter.replaceOp(op, reduceWindowOp.getResults());
  return success();
}
} // namespace

// AtenMaxPool2dWithIndicesOp
namespace {
template <>
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
    AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputTy = input.getType().cast<RankedTensorType>();
  auto inputElemTy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();
  auto inputRank = inputTy.getRank();
  auto outValTy =
      getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>();
  auto outIdxTy =
      getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();

  if (inputRank <= 2) {
    return op.emitError(
        "max_pooling2d only supports inputs with rank higher than 2");
  }
  SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
  bool ceilMode = false;

  if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const int kernel size unsupported!");
  }
  if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) {
    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
  }
  if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int padding unsupported!");
  }
  if (!(matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int dilation unsupported!");
  }
  if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const bool ceil_mode unsupported!");
  }

  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
  // input
  SmallVector<int64_t> mhloStride(inputRank, 1);
  SmallVector<int64_t> mhloDilation(inputRank, 1);
  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
  std::copy(dilation.begin(), dilation.end(),
            mhloDilation.begin() + inputRank - 2);
  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
  std::copy(kernelSize.begin(), kernelSize.end(),
            mhloKernelSize.begin() + inputRank - 2);

  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

  mhloPadding[mhloPadding.size() - 4] = padding[0];
  mhloPadding[mhloPadding.size() - 3] = padding[0];
  mhloPadding[mhloPadding.size() - 2] = padding[1];
  mhloPadding[mhloPadding.size() - 1] = padding[1];

  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
                            rewriter.getI64Type()),
      mhloKernelSize);
  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
                            rewriter.getI64Type()),
      mhloStride);
  DenseIntElementsAttr baseDilations;
  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
                            rewriter.getI64Type()),
      mhloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      mhloPadding);

  auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
  if (failed(inputShapeInfo)) {
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");
  }
  auto inputShapeVec = *inputShapeInfo;
  auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
      op->getLoc(), inputShapeVec);
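
  // Build a companion tensor of indices: mhlo.dynamic_iota generates the
  // linear offsets 0 .. H*W-1 over the flattened last two (spatial) dims, and
  // mhlo.dynamic_reshape expands them back to the input shape, so every
  // element carries its flat spatial index into the reduction below.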
  SmallVector<Value> initIndexShapeVec;
  for (int64_t i = 0; i < inputRank - 2; i++)
    initIndexShapeVec.push_back(inputShapeVec[i]);
  initIndexShapeVec.push_back(rewriter.create<mlir::arith::MulIOp>(
      op->getLoc(), inputShapeVec[inputRank - 1],
      inputShapeVec[inputRank - 2]));
  auto initIndexShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
      op->getLoc(), initIndexShapeVec);

  SmallVector<int64_t> initIndexShapeForType(inputShape.begin(),
                                             inputShape.end() - 2);
  if (inputShape[inputRank - 1] == ShapedType::kDynamicSize ||
      inputShape[inputRank - 2] == ShapedType::kDynamicSize) {
    initIndexShapeForType.push_back(ShapedType::kDynamicSize);
  } else {
    initIndexShapeForType.push_back(inputShape[inputRank - 1] *
                                    inputShape[inputRank - 2]);
  }

  auto initIndexTensor =
      rewriter
          .create<mhlo::DynamicIotaOp>(
              op->getLoc(),
              RankedTensorType::get(initIndexShapeForType,
                                    rewriter.getI64Type()),
              initIndexShapeTensor, static_cast<uint64_t>(inputRank - 2))
          .getResult();

  auto indexTensor =
      rewriter
          .create<mhlo::DynamicReshapeOp>(
              op->getLoc(),
              RankedTensorType::get(inputShape, rewriter.getI64Type()),
              initIndexTensor, inputShapeTensor)
          .getResult();

  Value initIdx =
      mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();

  auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
      op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
      mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
      windowDimensions, windowStrides, baseDilations, windowDilations, pad);

  Block &block = reduceWindowOp.body().emplaceBlock();

  // Add bb argument
  auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
  auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
  auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
  block.addArgument(blockValArgumentType, op->getLoc());
  block.addArgument(blockIdxArgumentType, op->getLoc());
  block.addArgument(blockValArgumentType, op->getLoc());
  block.addArgument(blockIdxArgumentType, op->getLoc());
  auto *firstValArg = block.args_begin();
  auto *firstIdxArg = std::next(firstValArg);
  auto *secondValArg = std::next(firstIdxArg);
  auto *secondIdxArg = std::next(secondValArg);

  mhlo::ComparisonTypeAttr compareTypeAttr;
  if (inputTy.getElementType().isa<mlir::FloatType>()) {
    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
        rewriter.getContext(), mhlo::ComparisonType::FLOAT);
  } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
        rewriter.getContext(), mhlo::ComparisonType::SIGNED);
  }
  mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
                                         mhlo::ComparisonDirection::GE);
  mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
                                         mhlo::ComparisonDirection::EQ);

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);

    Value compareGeResult = rewriter.create<mhlo::CompareOp>(
        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
        compareGeDirectionAttr, compareTypeAttr);
    Value retValResult = rewriter.create<mhlo::SelectOp>(
        op->getLoc(), compareGeResult, *firstValArg, *secondValArg);

    // Get smaller index if compared values are equal.
    Value compareEqResult = rewriter.create<mhlo::CompareOp>(
        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
        compareEqDirectionAttr, compareTypeAttr);
    Value minIdx =
        rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
    Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
        op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
    Value retIdxResult = rewriter.create<mhlo::SelectOp>(
        op->getLoc(), compareEqResult, minIdx, idxWithGeVal);

    rewriter.create<mhlo::ReturnOp>(
        op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
  }

  rewriter.replaceOp(op, reduceWindowOp.getResults());
  return success();
}
} // namespace

// AtenAvgPool2dOp
namespace {
template <>
LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
    AtenAvgPool2dOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self();
  auto inputTy = input.getType().cast<RankedTensorType>();
  auto inputElemTy = inputTy.getElementType();
  auto inputRank = inputTy.getRank();
  auto outTy =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
  auto outShape = outTy.getShape();

  if (inputRank <= 2) {
    return op.emitError(
        "avg_pooling2d only supports inputs with rank higher than 2");
  }
  SmallVector<int64_t, 2> padding, kernelSize, stride;
  bool ceilMode = false;
  bool countIncludePad = true;

  if (!(matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const int kernel size unsupported!");
  }
  if (!(matchPattern(op.stride(), m_TorchConstantIntList(stride)))) {
    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
  }
  if (!(matchPattern(op.padding(), m_TorchConstantIntList(padding)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int padding unsupported!");
  }
  if (!(matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const bool ceil_mode unsupported!");
  }
  if (!(matchPattern(op.count_include_pad(),
                     m_TorchConstantBool(&countIncludePad)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const bool count_include_pad unsupported!");
  }
  if (succeeded(checkNotNone(rewriter, op, op.divisor_override()))) {
    return rewriter.notifyMatchFailure(
        op, "only None divisor_override supported for now!");
  }

  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
  // input
  SmallVector<int64_t> mhloStride(inputRank, 1);
  SmallVector<int64_t> mhloDilation(inputRank, 1);
  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);

  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
  std::copy(kernelSize.begin(), kernelSize.end(),
            mhloKernelSize.begin() + inputRank - 2);
  mhloPadding[mhloPadding.size() - 4] = padding[0];
  mhloPadding[mhloPadding.size() - 3] = padding[0];
  mhloPadding[mhloPadding.size() - 2] = padding[1];
  mhloPadding[mhloPadding.size() - 1] = padding[1];

  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
                            rewriter.getI64Type()),
      mhloKernelSize);
  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
                            rewriter.getI64Type()),
      mhloStride);
  DenseIntElementsAttr baseDilations;
  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
                            rewriter.getI64Type()),
      mhloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      mhloPadding);

  auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
      baseDilations, windowDilations, pad);

  Block &sumBlock = reduceWindowSum.body().emplaceBlock();

  // Add bb argument
  auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
  sumBlock.addArgument(blockArgumentType, op->getLoc());
  sumBlock.addArgument(blockArgumentType, op->getLoc());
  auto *firstArg = sumBlock.args_begin();
  auto secondArg = sumBlock.args_rbegin();

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&sumBlock);

    Value sumResult =
        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
  }

  // Use kernel size as the divisor
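  // (count_include_pad == true) Every window covers exactly
  // kernelSize[0] * kernelSize[1] positions once padding is counted, so a
  // single scalar divisor suffices (e.g. 9 for a 3x3 kernel).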
  if (countIncludePad) {
    Value divisor = mhlo::getConstTensor<int64_t>(
                        rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
                        .getValue();
    divisor = mhlo::promoteType(rewriter, divisor, outTy);
    DenseIntElementsAttr bcastDimensions;
    rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
        op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
    return success();
  }

  // Use another mhlo.ReduceWindowOp to get the divisor
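  // (count_include_pad == false) Sum a broadcast tensor of ones over the same
  // windows: padded positions contribute the zero init value, so each output
  // element ends up holding the number of real input elements in its window,
  // which then serves as the per-position divisor.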
  Value windowSizeConst =
      mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).getValue();
  windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
  auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
  auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
      op->getLoc(), inputShapeVec);

  windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
      op->getLoc(),
      RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
      windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));

  Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
  auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>(
      op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
      windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
      windowDilations, pad);

  Block &sizeBlock = reduceWindowSize.body().emplaceBlock();

  // Add bb argument
  blockArgumentType = RankedTensorType::get({}, inputElemTy);
  sizeBlock.addArgument(blockArgumentType, op->getLoc());
  sizeBlock.addArgument(blockArgumentType, op->getLoc());
  firstArg = sizeBlock.args_begin();
  secondArg = sizeBlock.args_rbegin();

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&sizeBlock);

    Value sumResult =
        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
  }

  rewriter.replaceOpWithNewOp<mhlo::DivOp>(
      op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
  return success();
}

} // namespace

void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
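  // Each handled aten pooling op is marked illegal so the conversion fails
  // loudly if its pattern does not apply, and the matching specialization of
  // ConvertAtenPoolingOp is registered.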
  target.addIllegalOp<AtenMaxPool2dOp>();
  patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dOp>>(typeConverter, context);
  target.addIllegalOp<AtenAvgPool2dOp>();
  patterns.add<ConvertAtenPoolingOp<AtenAvgPool2dOp>>(typeConverter, context);
  target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
  patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
                                                                 context);
}
@@ -32,6 +32,10 @@ void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
                                          RewritePatternSet &patterns,
                                          ConversionTarget &target);
 
+void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
+                                          RewritePatternSet &patterns,
+                                          ConversionTarget &target);
+
 } // namespace torch_to_mhlo
 } // namespace torch
 } // namespace mlir
@@ -13,7 +13,6 @@
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
@@ -24,7 +23,6 @@
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -44,8 +42,9 @@ public:
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
-    target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect, tensor::TensorDialect,
-                           arith::ArithmeticDialect, Torch::TorchDialect>();
+    target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
+                           tensor::TensorDialect, arith::ArithmeticDialect,
+                           Torch::TorchDialect>();
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -55,14 +54,16 @@ public:
 
     torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
                                                       target);
-    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
-                                                         target);
+    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter,
+                                                         patterns, target);
     torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
                                                        target);
     torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter,
                                                           patterns, target);
     torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
                                                        target);
+    torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
+                                                        target);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
@@ -0,0 +1,218 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
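
// Note: each case below builds its pooling parameters from torch.constant
// ops; the lowering requires constant kernel_size/stride/padding/dilation
// lists, otherwise the conversion patterns reject the op with a match failure.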

// -----

// CHECK-LABEL: func.func @torch.aten.max_pool2d(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int0 = torch.constant.int 0
// CHECK: %false = torch.constant.bool false
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: "mhlo.return"(%[[VAL_10]]) : (tensor<f32>) -> ()
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
  return %4 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.max_pool2d$padding(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %false = torch.constant.bool false
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: "mhlo.return"(%[[VAL_10]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %false = torch.constant.bool false
  %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.max_pool2d %arg0, %0, %1, %2, %2, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
  return %3 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.max_pool2d_with_indices(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>) {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %int3 = torch.constant.int 3
// CHECK: %int2 = torch.constant.int 2
// CHECK: %false = torch.constant.bool false
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_7]] : index to i64
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : index to i64
// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_11:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i64
// CHECK: %[[VAL_13:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_10]], %[[VAL_12]] : tensor<3xi64>
// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_10]] : i64
// CHECK: %[[VAL_15:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_14]] : tensor<2xi64>
// CHECK: %[[VAL_16:.*]] = "mhlo.dynamic_iota"(%[[VAL_15]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_reshape"(%[[VAL_16]], %[[VAL_13]]) : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: %[[VAL_19:.*]]:2 = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_17]], %[[VAL_6]], %[[VAL_18]]) ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<i64>, %[[IVAL_2:.*]]: tensor<f32>, %[[IVAL_3:.*]]: tensor<i64>):
// CHECK: %[[IVAL_4:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[IVAL_5:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_0]], %[[IVAL_2]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[IVAL_6:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[IVAL_7:.*]] = mhlo.minimum %[[IVAL_1]], %[[IVAL_3]] : tensor<i64>
// CHECK: %[[IVAL_8:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_1]], %[[IVAL_3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: %[[IVAL_9:.*]] = "mhlo.select"(%[[IVAL_6]], %[[IVAL_7]], %[[IVAL_8]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: "mhlo.return"(%[[IVAL_5]], %[[IVAL_9]]) : (tensor<f32>, tensor<i64>) -> ()
// CHECK{LITERAL}: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_19]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
// CHECK: return %[[VAL_20]], %[[VAL_21]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>) -> (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>) {
  %int3 = torch.constant.int 3
  %int2 = torch.constant.int 2
  %false = torch.constant.bool false
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
  return %result0, %result1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.avg_pool2d(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %int3 = torch.constant.int 3
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %false = torch.constant.bool false
// CHECK: %none = torch.constant.none
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
// CHECK: "mhlo.return"(%[[IVAL_2]]) : (tensor<f32>) -> ()
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : index to i64
// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i64
// CHECK: %[[IDX_3:.*]] = arith.constant 3 : index
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
// CHECK: %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[IVAL_5]]) : (tensor<f32>) -> ()
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
  %int3 = torch.constant.int 3
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %false = torch.constant.bool false
  %none = torch.constant.none
  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
  return %3 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.avg_pool2d$count_include_pad(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %int3 = torch.constant.int 3
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %false = torch.constant.bool false
// CHECK: %none = torch.constant.none
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
// CHECK: "mhlo.return"(%[[IVAL_2]]) : (tensor<f32>) -> ()
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<9> : tensor<i64>
// CHECK: %[[VAL_8:.*]] = mhlo.convert(%[[VAL_7]]) : (tensor<i64>) -> tensor<f32>
// CHECK: %[[VAL_9:.*]] = chlo.broadcast_divide %[[VAL_6]], %[[VAL_8]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.avg_pool2d$count_include_pad(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
  %int3 = torch.constant.int 3
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %false = torch.constant.bool false
  %true = torch.constant.bool true
  %none = torch.constant.none
  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
  return %3 : !torch.vtensor<[?,?,?,?],f32>
}
@@ -448,11 +448,12 @@ cc_library(
     srcs = [
         "lib/Conversion/TorchToMhlo/TorchToMhlo.cpp",
         "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp",
-        "lib/Conversion/TorchToMhlo/BasicOp.cpp",
-        "lib/Conversion/TorchToMhlo/GatherOp.cpp",
+        "lib/Conversion/TorchToMhlo/Basic.cpp",
+        "lib/Conversion/TorchToMhlo/Gather.cpp",
         "lib/Conversion/TorchToMhlo/Linear.cpp",
-        "lib/Conversion/TorchToMhlo/ViewLikeOps.cpp",
-        "lib/Conversion/TorchToMhlo/ReductionOp.cpp",
+        "lib/Conversion/TorchToMhlo/ViewLike.cpp",
+        "lib/Conversion/TorchToMhlo/Reduction.cpp",
+        "lib/Conversion/TorchToMhlo/Pooling.cpp",
         "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h",
         "lib/Conversion/TorchToMhlo/PopulatePatterns.h",
         "lib/Conversion/PassDetail.h",