mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Improve the lowering of pool op in stablehlo (#3259)
1. Handle case stride == None 2. add avgpool3d maxpool1d maxpool3d loweringpull/3269/head
parent
fb8aed0907
commit
f32ada993d
|
@ -6637,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchListOfTorchIntType:$kernel_size,
|
||||||
|
AnyTorchListOfTorchIntType:$stride,
|
||||||
|
AnyTorchListOfTorchIntType:$padding,
|
||||||
|
AnyTorchListOfTorchIntType:$dilation,
|
||||||
|
Torch_BoolType:$ceil_mode
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||||
|
}
|
||||||
|
void AtenMaxPool1dOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 6, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
auto constType = RankedTensorType::get({}, elementTy);
|
auto constType = RankedTensorType::get({}, elementTy);
|
||||||
// Avg pooling
|
// Avg pooling
|
||||||
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
|
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
|
||||||
AtenCumsumOp>(op)) {
|
AtenAvgPool3dOp, AtenCumsumOp>(op)) {
|
||||||
if (isa<mlir::FloatType>(elementTy)) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType, {APFloat::getZero(
|
constType, {APFloat::getZero(
|
||||||
|
@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Max pooling
|
// Max pooling
|
||||||
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
|
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
|
||||||
|
AtenMaxPool2dWithIndicesOp>(op)) {
|
||||||
if (isa<mlir::FloatType>(elementTy)) {
|
if (isa<mlir::FloatType>(elementTy)) {
|
||||||
auto constAttr = DenseElementsAttr::get(
|
auto constAttr = DenseElementsAttr::get(
|
||||||
constType,
|
constType,
|
||||||
|
@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// AtenMaxPool2dOp
|
|
||||||
template <>
|
|
||||||
LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
|
||||||
AtenMaxPool2dOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const {
|
|
||||||
Value input = adaptor.getSelf();
|
|
||||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
|
||||||
auto inputElemTy = inputTy.getElementType();
|
|
||||||
|
|
||||||
auto inputRank = inputTy.getRank();
|
|
||||||
auto outTy =
|
|
||||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
|
||||||
|
|
||||||
if (inputRank <= 2) {
|
|
||||||
return op.emitError(
|
|
||||||
"max_pooling2d only supports inputs with rank higher than 2");
|
|
||||||
}
|
|
||||||
SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
|
|
||||||
bool ceilMode = false;
|
|
||||||
|
|
||||||
if (!(matchPattern(op.getKernelSize(),
|
|
||||||
m_TorchListOfConstantInts(kernelSize)))) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "non-const int kernel size unsupported!");
|
|
||||||
}
|
|
||||||
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
|
|
||||||
}
|
|
||||||
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"non-const int padding unsupported!");
|
|
||||||
}
|
|
||||||
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"non-const int dilation unsupported!");
|
|
||||||
}
|
|
||||||
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"non-const bool ceil_mode unsupported!");
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
|
||||||
// input
|
|
||||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
|
||||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
|
||||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
|
||||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
|
||||||
std::copy(dilation.begin(), dilation.end(),
|
|
||||||
stablehloDilation.begin() + inputRank - 2);
|
|
||||||
std::copy(stride.begin(), stride.end(),
|
|
||||||
stablehloStride.begin() + inputRank - 2);
|
|
||||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
|
||||||
stablehloKernelSize.begin() + inputRank - 2);
|
|
||||||
|
|
||||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
|
||||||
|
|
||||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
|
||||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
|
||||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
|
||||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
|
||||||
|
|
||||||
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
|
||||||
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
|
||||||
DenseI64ArrayAttr baseDilations;
|
|
||||||
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
|
||||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
|
||||||
RankedTensorType::get(
|
|
||||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
|
||||||
rewriter.getI64Type()),
|
|
||||||
stablehloPadding);
|
|
||||||
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
|
||||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
|
||||||
baseDilations, windowDilations, pad);
|
|
||||||
|
|
||||||
Block &block = reduceWindowOp.getBody().emplaceBlock();
|
|
||||||
|
|
||||||
auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
|
|
||||||
block.addArgument(blockArgumentTy, op->getLoc());
|
|
||||||
block.addArgument(blockArgumentTy, op->getLoc());
|
|
||||||
|
|
||||||
auto *firstArg = block.args_begin();
|
|
||||||
auto secondArg = block.args_rbegin();
|
|
||||||
|
|
||||||
{
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
|
||||||
rewriter.setInsertionPointToStart(&block);
|
|
||||||
Value result =
|
|
||||||
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
|
|
||||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// AtenMaxPool2dWithIndicesOp
|
// AtenMaxPool2dWithIndicesOp
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
|
@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename AtenOpT, int Dim>
|
||||||
|
class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
|
||||||
|
public:
|
||||||
|
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||||
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Value input = adaptor.getSelf();
|
||||||
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
auto inputRank = inputTy.getRank();
|
||||||
|
auto outTy = cast<RankedTensorType>(
|
||||||
|
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
|
if (inputRank <= Dim) {
|
||||||
|
return op.emitError(
|
||||||
|
"max_pooling1d/2d only supports inputs with rank higher than 1/2");
|
||||||
|
}
|
||||||
|
SmallVector<int64_t, Dim> padding, kernelSize, stride, dilation;
|
||||||
|
bool ceilMode = false;
|
||||||
|
|
||||||
|
if (!(matchPattern(op.getKernelSize(),
|
||||||
|
m_TorchListOfConstantInts(kernelSize)))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "non-const int kernel size unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const int stride unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const int padding unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getDilation(),
|
||||||
|
m_TorchListOfConstantInts(dilation)))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"non-const int dilation unsupported!");
|
||||||
|
}
|
||||||
|
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "non-const bool ceil_mode unsupported!");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stride.empty()) {
|
||||||
|
stride = kernelSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepend 1 to kernelSize, stride, dilation until they are of same rank
|
||||||
|
// as input
|
||||||
|
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||||
|
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||||
|
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||||
|
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||||
|
std::copy(dilation.begin(), dilation.end(),
|
||||||
|
stablehloDilation.begin() + inputRank - Dim);
|
||||||
|
std::copy(stride.begin(), stride.end(),
|
||||||
|
stablehloStride.begin() + inputRank - Dim);
|
||||||
|
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||||
|
stablehloKernelSize.begin() + inputRank - Dim);
|
||||||
|
|
||||||
|
Value initVal =
|
||||||
|
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||||
|
|
||||||
|
if (Dim == 1) {
|
||||||
|
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
||||||
|
} else if (Dim == 2) {
|
||||||
|
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||||
|
} else if (Dim == 3) {
|
||||||
|
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
|
||||||
|
} else {
|
||||||
|
assert(false && "Unsupported pooling dimension");
|
||||||
|
}
|
||||||
|
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||||
|
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||||
|
DenseI64ArrayAttr baseDilations;
|
||||||
|
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||||
|
|
||||||
|
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||||
|
RankedTensorType::get(
|
||||||
|
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||||
|
rewriter.getI64Type()),
|
||||||
|
stablehloPadding);
|
||||||
|
|
||||||
|
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||||
|
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||||
|
baseDilations, windowDilations, pad);
|
||||||
|
|
||||||
|
Block &block = reduceWindowOp.getBody().emplaceBlock();
|
||||||
|
|
||||||
|
// Add bb argument
|
||||||
|
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
|
||||||
|
block.addArgument(blockArgumentType, op->getLoc());
|
||||||
|
block.addArgument(blockArgumentType, op->getLoc());
|
||||||
|
auto *firstArg = block.args_begin();
|
||||||
|
auto secondArg = block.args_rbegin();
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&block);
|
||||||
|
|
||||||
|
Value result = rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg,
|
||||||
|
*secondArg);
|
||||||
|
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename AtenOpT, int Dim>
|
template <typename AtenOpT, int Dim>
|
||||||
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
|
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
|
||||||
|
@ -375,8 +404,8 @@ public:
|
||||||
auto outShape = outTy.getShape();
|
auto outShape = outTy.getShape();
|
||||||
|
|
||||||
if (inputRank <= Dim) {
|
if (inputRank <= Dim) {
|
||||||
return op.emitError(
|
return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank "
|
||||||
"avg_pooling1d/2d only supports inputs with rank higher than 1/2");
|
"higher than 1/2/3");
|
||||||
}
|
}
|
||||||
SmallVector<int64_t, Dim> padding, kernelSize, stride;
|
SmallVector<int64_t, Dim> padding, kernelSize, stride;
|
||||||
bool ceilMode = false;
|
bool ceilMode = false;
|
||||||
|
@ -405,6 +434,10 @@ public:
|
||||||
op, "non-const bool count_include_pad unsupported!");
|
op, "non-const bool count_include_pad unsupported!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (stride.empty()) {
|
||||||
|
stride = kernelSize;
|
||||||
|
}
|
||||||
|
|
||||||
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
|
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
|
||||||
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
|
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -425,11 +458,20 @@ public:
|
||||||
if (Dim == 1) {
|
if (Dim == 1) {
|
||||||
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
||||||
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
||||||
} else {
|
} else if (Dim == 2) {
|
||||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||||
|
} else if (Dim == 3) {
|
||||||
|
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
|
||||||
|
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
|
||||||
|
} else {
|
||||||
|
assert(false && "Unsupported pooling dimension");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value initVal =
|
Value initVal =
|
||||||
|
@ -474,10 +516,17 @@ public:
|
||||||
divisor =
|
divisor =
|
||||||
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
|
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
|
||||||
.value();
|
.value();
|
||||||
} else {
|
} else if (Dim == 2) {
|
||||||
divisor = hlo::getConstTensor<int64_t>(
|
divisor = hlo::getConstTensor<int64_t>(
|
||||||
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
||||||
.value();
|
.value();
|
||||||
|
} else if (Dim == 3) {
|
||||||
|
divisor = hlo::getConstTensor<int64_t>(
|
||||||
|
rewriter, op,
|
||||||
|
{kernelSize[0] * kernelSize[1] * kernelSize[2]}, {})
|
||||||
|
.value();
|
||||||
|
} else {
|
||||||
|
assert(false && "Unsupported pooling dimension");
|
||||||
}
|
}
|
||||||
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
|
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
|
||||||
DenseI64ArrayAttr bcastDimensions;
|
DenseI64ArrayAttr bcastDimensions;
|
||||||
|
@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
target.addIllegalOp<AtenAvgPool1dOp>();
|
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
|
||||||
patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options);
|
target.addIllegalOp<AtenOp>(); \
|
||||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||||||
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
|
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
|
||||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
|
||||||
patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
|
#undef INSERT_ATEN_POOLING_PATTERN
|
||||||
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
|
||||||
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
|
#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim) \
|
||||||
context, options);
|
target.addIllegalOp<AtenOp>(); \
|
||||||
target.addIllegalOp<AtenCumsumOp>();
|
patterns.add<ConvertAtenMaxPoolOp<AtenOp, Dim>>(typeConverter, context, \
|
||||||
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
|
options)
|
||||||
|
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1);
|
||||||
|
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2);
|
||||||
|
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3);
|
||||||
|
#undef INSERT_ATEN_MAXPOOL_PATTERN
|
||||||
|
|
||||||
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
|
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
|
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
|
||||||
options)
|
options)
|
||||||
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
|
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
|
||||||
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
|
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
|
||||||
|
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3);
|
||||||
#undef INSERT_ATEN_AVGPOOL_PATTERN
|
#undef INSERT_ATEN_AVGPOOL_PATTERN
|
||||||
}
|
}
|
||||||
|
|
|
@ -7845,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" return %arg2 : !torch.list<int>\n"
|
" return %arg2 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n"
|
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
" func.func @__torch__.pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool) -> !torch.list<int> {\n"
|
||||||
" %int-1 = torch.constant.int -1\n"
|
" %int-1 = torch.constant.int -1\n"
|
||||||
" %int-2 = torch.constant.int -2\n"
|
" %int-2 = torch.constant.int -2\n"
|
||||||
" %int-3 = torch.constant.int -3\n"
|
" %int-3 = torch.constant.int -3\n"
|
||||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n"
|
" %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n"
|
||||||
" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n"
|
" %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n"
|
||||||
" %true = torch.constant.bool true\n"
|
" %true = torch.constant.bool true\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n"
|
" %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n"
|
||||||
" %int1 = torch.constant.int 1\n"
|
" %int1 = torch.constant.int 1\n"
|
||||||
" %int0 = torch.constant.int 0\n"
|
" %int0 = torch.constant.int 0\n"
|
||||||
" %int2 = torch.constant.int 2\n"
|
" %int2 = torch.constant.int 2\n"
|
||||||
|
@ -7940,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %23 : !torch.list<int>\n"
|
" return %23 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
|
|
@ -1075,6 +1075,9 @@ STABLEHLO_PASS_SET = {
|
||||||
"Matmul_vecmat",
|
"Matmul_vecmat",
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
"MaxPool2dStaticModule_basic",
|
"MaxPool2dStaticModule_basic",
|
||||||
|
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||||
|
"MaxPool3dStaticModule_basic",
|
||||||
|
"MaxPool3dEmptyStrideStaticModule_basic",
|
||||||
"MeanDimAllReduceModule_basic",
|
"MeanDimAllReduceModule_basic",
|
||||||
"MeanDimEmptyDimModule_basic",
|
"MeanDimEmptyDimModule_basic",
|
||||||
"MeanDimNoneDimModule_basic",
|
"MeanDimNoneDimModule_basic",
|
||||||
|
|
|
@ -961,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
|
||||||
|
|
||||||
# TODO: This should be upstreamed.
|
# TODO: This should be upstreamed.
|
||||||
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
|
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
|
||||||
def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
|
def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool):
|
||||||
assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
|
assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int"
|
||||||
kL = kernel_size[0]
|
kL = kernel_size[0]
|
||||||
|
|
||||||
assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
|
assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int"
|
||||||
dL = kL if len(stride) == 0 else stride[0]
|
dL = kL if len(stride) == 0 else stride[0]
|
||||||
|
|
||||||
assert len(padding) == 1, "avg_pool1d: padding must be a single int"
|
assert len(padding) == 1, "pool1d: padding must be a single int"
|
||||||
padL = padding[0]
|
padL = padding[0]
|
||||||
|
|
||||||
dilationL = 1
|
dilationL = 1
|
||||||
|
@ -1004,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]):
|
||||||
return shape
|
return shape
|
||||||
|
|
||||||
def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
|
def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
|
||||||
return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
|
return pool1d(self, kernel_size, stride, padding, ceil_mode)
|
||||||
|
|
||||||
|
def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
|
||||||
|
return pool1d(self, kernel_size, stride, padding, ceil_mode)
|
||||||
|
|
||||||
def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
|
def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
|
||||||
return adaptive_avg_pool1d(self, output_size)
|
return adaptive_avg_pool1d(self, output_size)
|
||||||
|
|
|
@ -591,6 +591,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit(
|
emit(
|
||||||
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
||||||
)
|
)
|
||||||
|
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||||
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||||
emit(
|
emit(
|
||||||
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
||||||
|
|
Loading…
Reference in New Issue