[MLIR][TORCH] Add E2E support for aten.avg_pool2d op

This commit adds lowering of the `aten.avg_pool2d` op.
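
For reference, the new ConvertAtenAvgPool2dOp pattern decomposes the op into a
zero-padded sum pooling (`linalg.pooling_nchw_sum`) followed by a `linalg.generic`
that divides every output element by the kernel area, or by `divisor_override`
when one is supplied; only `count_include_pad=True` is handled for now. A minimal
PyTorch-level sketch of that equivalence (the shapes and parameters below simply
mirror the AvgPool2dFloatModule e2e test added in this patch and are otherwise
illustrative):

    import torch
    import torch.nn.functional as F

    # avg_pool2d with count_include_pad=True == zero-pad, sliding-window sum,
    # then divide by the kernel area (divisor_override would replace kH * kW).
    x = torch.randn(2, 4, 20, 20)
    kH, kW = 6, 8
    stride, padding = (2, 2), (3, 4)

    reference = F.avg_pool2d(x, (kH, kW), stride, padding, count_include_pad=True)

    padded = F.pad(x, (padding[1], padding[1], padding[0], padding[0]))  # zero-pad W, then H
    windows = padded.unfold(2, kH, stride[0]).unfold(3, kW, stride[1])   # sliding windows
    decomposed = windows.sum(dim=(-2, -1)) / (kH * kW)                   # sum pool / area

    assert torch.allclose(reference, decomposed, atol=1e-5)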

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/819/head snapshot-20220502.427
Vivek Khandelwal 2022-03-30 20:44:23 +05:30
parent 81ee5bb58c
commit 4b11284440
8 changed files with 472 additions and 59 deletions

View File

@@ -3048,6 +3048,35 @@ def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_in
}];
}
def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
Torch_BoolType:$ceil_mode,
Torch_BoolType:$count_include_pad,
AnyTorchOptionalIntType:$divisor_override
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenAvgPool2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@@ -34,8 +34,7 @@ static LogicalResult
checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
SmallVectorImpl<int64_t> &kernelSizeInts,
SmallVectorImpl<int64_t> &strideInts,
-SmallVectorImpl<int64_t> &paddingInts,
-SmallVectorImpl<int64_t> &dilationInts) {
+SmallVectorImpl<int64_t> &paddingInts) {
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
@@ -46,9 +45,6 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int paddings");
-if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
-return rewriter.notifyMatchFailure(op,
-"only support constant int dilations");
bool ceilMode;
if (!matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))
return rewriter.notifyMatchFailure(op,
@@ -60,30 +56,26 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
return success();
}
-// Computes maxpool2d for AtenMaxPool2dOp and AtenMaxPool2dWithIndicesOp.
-static LogicalResult
-computeMaxPool2d(Operation *op, ConversionPatternRewriter &rewriter, Value self,
-SmallVectorImpl<int64_t> &kernelSizeInts,
-SmallVectorImpl<int64_t> &strideInts,
-SmallVectorImpl<int64_t> &paddingInts,
-SmallVectorImpl<int64_t> &dilationInts, Value &result) {
+// Creates a pooling operation based on the type specified by `OpTy` and
+// arguments passed.
+template <typename OpTy>
+static LogicalResult createPoolingOp(
+Operation *op, ConversionPatternRewriter &rewriter, Value self,
+bool supportFPInputOnly, SmallVectorImpl<int64_t> &kernelSizeInts,
+SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
+SmallVectorImpl<int64_t> &dilationInts, Attribute intiValueAttr,
+SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
Location loc = op->getLoc();
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
-if (!elementType.isa<mlir::FloatType>())
+if (!elementType.isa<mlir::FloatType>() && !supportFPInputOnly)
return op->emitError("unimplemented: non-floating point type");
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
-auto smallestFPValueAttr = rewriter.getFloatAttr(
-elementType, APFloat::getLargest(
-elementType.cast<mlir::FloatType>().getFloatSemantics(),
-/*Negative=*/true));
-Value smallestFPValue =
-rewriter.create<arith::ConstantOp>(loc, smallestFPValueAttr);
-Value paddedInput =
-torch_to_linalg::getPaddedTensor(op, rewriter, self, paddingIncludingNC,
-paddingIncludingNC, smallestFPValue);
+Value initValue = rewriter.create<arith::ConstantOp>(loc, intiValueAttr);
+paddedInput = torch_to_linalg::getPaddedTensor(
+op, rewriter, self, paddingIncludingNC, paddingIncludingNC, initValue);
Value N = getDimOp(rewriter, loc, self, 0);
Value C = getDimOp(rewriter, loc, self, 1);
@@ -107,9 +99,9 @@ computeMaxPool2d(Operation *op, ConversionPatternRewriter &rewriter, Value self,
kernelSizeIntValues[1], strideIntValues[1]);
// Create output tensor initialized with smallest floating point value.
+outTensorShape.insert(outTensorShape.begin(), {N, C, hOut, wOut});
Value outTensorInitialized =
-createInitTensor(rewriter, loc, ValueRange{N, C, hOut, wOut}, elementType,
-smallestFPValue);
+createInitTensor(rewriter, loc, outTensorShape, elementType, initValue);
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
@@ -118,11 +110,11 @@ computeMaxPool2d(Operation *op, ConversionPatternRewriter &rewriter, Value self,
elementType);
result = rewriter
-.create<linalg::PoolingNchwMaxOp>(
-loc, outTensorInitialized.getType(),
-ValueRange{paddedInput, windowTensor}, outTensorInitialized,
-stridesAttr, dilationAttr)
+.create<OpTy>(loc, outTensorInitialized.getType(),
+ValueRange{paddedInput, windowTensor},
+outTensorInitialized, stridesAttr, dilationAttr)
.getResult(0);
return success();
}
@@ -143,16 +135,28 @@ public:
return rewriter.notifyMatchFailure(
op, "unimplemented: only support 4D input");
-Value maxPool2d;
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
dilationInts;
+if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
+return rewriter.notifyMatchFailure(op,
+"only support constant int dilations");
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dOp>(
-op, rewriter, kernelSizeInts, strideInts, paddingInts,
-dilationInts)))
+op, rewriter, kernelSizeInts, strideInts, paddingInts)))
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
-if (failed(computeMaxPool2d(op, rewriter, self, kernelSizeInts, strideInts,
-paddingInts, dilationInts, maxPool2d)))
+Type elementType = self.getType().cast<RankedTensorType>().getElementType();
+auto smallestFPValueAttr = rewriter.getFloatAttr(
+elementType,
+APFloat::getLargest(
+elementType.cast<mlir::FloatType>().getFloatSemantics(),
+/*Negative=*/true));
+SmallVector<Value, 4> outTensorShape;
+// `maxpool2d` contains the result of maxpool2d operation over the input.
+Value maxPool2d, paddedInput;
+if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
+op, rewriter, self, /*supportFPInput=*/false, kernelSizeInts,
+strideInts, paddingInts, dilationInts, smallestFPValueAttr,
+outTensorShape, paddedInput, maxPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
@@ -208,36 +212,31 @@ public:
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
dilationInts;
+if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
+return rewriter.notifyMatchFailure(op,
+"only support constant int dilations");
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dWithIndicesOp>(
-op, rewriter, kernelSizeInts, strideInts, paddingInts,
-dilationInts)))
+op, rewriter, kernelSizeInts, strideInts, paddingInts)))
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
-// Contains the result of maxpool2d operation over the input.
-Value maxPool2d;
-if (failed(computeMaxPool2d(op, rewriter, self, kernelSizeInts, strideInts,
-paddingInts, dilationInts, maxPool2d)))
-return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
-SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
-paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
-paddingInts.end());
+// `maxpool2d` contains the result of maxpool2d operation over the input.
auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getLargest(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
-Value smallestFPValue =
-rewriter.create<arith::ConstantOp>(loc, smallestFPValueAttr);
-Value paddedInput =
-torch_to_linalg::getPaddedTensor(op, rewriter, self, paddingIncludingNC,
-paddingIncludingNC, smallestFPValue);
-SmallVector<Value> resultShape(getTensorSizes(rewriter, loc, maxPool2d));
+Value maxPool2d, paddedInput;
+SmallVector<Value, 4> outTensorShape;
+if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
+op, rewriter, self, /*supportFPInput=*/false, kernelSizeInts,
+strideInts, paddingInts, dilationInts, smallestFPValueAttr,
+outTensorShape, paddedInput, maxPool2d)))
+return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
Value cstMinusOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(-1));
Value indicesTensor =
-createInitTensor(rewriter, loc, resultShape,
+createInitTensor(rewriter, loc, outTensorShape,
indicesRankedTensorType.getElementType(), cstMinusOne);
SmallVector<Value> kernelSize =
@@ -440,6 +439,88 @@ public:
};
} // namespace
namespace {
class ConvertAtenAvgPool2dOp : public OpConversionPattern<AtenAvgPool2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenAvgPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
Value self = adaptor.self();
Type inputElementType =
self.getType().cast<RankedTensorType>().getElementType();
Type resultType = getTypeConverter()->convertType(op.getType());
Type resultElementType =
resultType.cast<RankedTensorType>().getElementType();
SmallVector<int64_t, 2> dilationInts{1, 1};
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts;
if (failed(checkAndGetPoolingParameters<AtenAvgPool2dOp>(
op, rewriter, kernelSizeInts, strideInts, paddingInts)))
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
// TODO: Add support for count_include_pad equal to `False`.
bool countIncludePad;
if (!matchPattern(op.count_include_pad(),
m_TorchConstantBool(&countIncludePad)))
return rewriter.notifyMatchFailure(
op, "count_include_pad must be a constant");
if (!countIncludePad) {
return rewriter.notifyMatchFailure(
op, "unimplemented: count_include_pad is expected to be true");
}
// `sumPool2d` contains the result of sumpool2d operation over the input.
Value sumPool2d, paddedInput;
SmallVector<Value, 4> outTensorShape;
if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
op, rewriter, self, /*supportFPInput=*/true, kernelSizeInts,
strideInts, paddingInts, dilationInts,
rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput,
sumPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d");
SmallVector<Value> kernelSizeIntValues =
getAsConstantIntValues(rewriter, loc, kernelSizeInts);
Value kHtimeskW = rewriter.create<arith::MulIOp>(
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
Value divisor = op.divisor_override().getType().isa<Torch::NoneType>()
? kHtimeskW
: adaptor.divisor_override();
divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
Value outputTensor = rewriter.create<linalg::InitTensorOp>(
loc, outTensorShape, resultElementType);
SmallVector<AffineMap> indexingMapsAvg(2,
rewriter.getMultiDimIdentityMap(4));
SmallVector<StringRef> iteratorTypesAvg(4, "parallel");
Value avgPool2d =
rewriter
.create<linalg::GenericOp>(
loc, outputTensor.getType(), sumPool2d, outputTensor,
/*indexingMaps=*/indexingMapsAvg,
/*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value avg;
if (resultElementType.isa<mlir::IntegerType>())
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
else if (resultElementType.isa<mlir::FloatType>())
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
b.create<linalg::YieldOp>(loc, avg);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool2d);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -450,4 +531,6 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool2dOp>();
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
}

View File

@@ -490,14 +490,14 @@ ChangeResult TypeAnalyzer::visitOperation(
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
-AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
-AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
-Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op,
-AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp,
-AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp, AtenExpandOp,
-AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
-AtenPadOp, AtenZero_Op, AtenIndexTensorOp,
-ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
+AtenMaxPool2dOp, AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp,
+AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,
+AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
+Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp,
+AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp,
+AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp,
+AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
+AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp,
AtenIndexPutHackedTwinOp>(op)) {
ValueKnowledge knowledge =

View File

@@ -1584,6 +1584,159 @@ module {
func @"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {
return %arg1 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.avg_pool2d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%true = torch.constant.bool true
%str = torch.constant.str "AssertionError: avg_pool2d: kernel_size must either be a single int, or a tuple of two ints"
%none = torch.constant.none
%str_0 = torch.constant.str "AssertionError: avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
%str_1 = torch.constant.str "AssertionError: avg_pool2d: padding must be either be a single int, or a tuple of two ints"
%str_2 = torch.constant.str "AssertionError: "
%int-4 = torch.constant.int -4
%int-3 = torch.constant.int -3
%int-2 = torch.constant.int -2
%int-1 = torch.constant.int -1
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%39 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%40 = torch.aten.eq.int %39, %int2 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %40 : !torch.bool
}
torch.prim.If %2 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
%4 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool
%6 = torch.prim.If %5 -> (!torch.int) {
torch.prim.If.yield %3 : !torch.int
} else {
%39 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %39 : !torch.int
}
%7 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.prim.If %8 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%39 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%40 = torch.aten.eq.int %39, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %40 : !torch.bool
}
%10 = torch.prim.If %9 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%39 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%40 = torch.aten.eq.int %39, %int2 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %40 : !torch.bool
}
torch.prim.If %10 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%11 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool
%13 = torch.prim.If %12 -> (!torch.int) {
torch.prim.If.yield %3 : !torch.int
} else {
%39 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %39 : !torch.int
}
%14 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%15 = torch.aten.eq.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
%16 = torch.prim.If %15 -> (!torch.int) {
torch.prim.If.yield %6 : !torch.int
} else {
%39 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%40 = torch.aten.eq.int %39, %int1 : !torch.int, !torch.int -> !torch.bool
%41 = torch.prim.If %40 -> (!torch.int) {
torch.prim.If.yield %13 : !torch.int
} else {
%42 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %42 : !torch.int
}
torch.prim.If.yield %41 : !torch.int
}
%17 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int
%18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool
%19 = torch.prim.If %18 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%39 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int
%40 = torch.aten.eq.int %39, %int2 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %40 : !torch.bool
}
torch.prim.If %19 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list<int>, !torch.int -> !torch.int
%21 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int
%22 = torch.aten.eq.int %21, %int1 : !torch.int, !torch.int -> !torch.bool
%23 = torch.prim.If %22 -> (!torch.int) {
torch.prim.If.yield %20 : !torch.int
} else {
%39 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %39 : !torch.int
}
%24 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%25 = torch.aten.eq.int %24, %int3 : !torch.int, !torch.int -> !torch.bool
%26 = torch.prim.If %25 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%39 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%40 = torch.aten.eq.int %39, %int4 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %40 : !torch.bool
}
torch.prim.If %26 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%27 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%28 = torch.aten.eq.int %27, %int4 : !torch.int, !torch.int -> !torch.bool
%29 = torch.prim.If %28 -> (!torch.int) {
%39 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list<int>, !torch.int -> !torch.int
torch.prim.If.yield %39 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
%30 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int
%31 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int
%32 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int
%33 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pooling_output_shape(%31, %3, %20, %13, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int
%34 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pooling_output_shape(%32, %6, %23, %16, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int
%35 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pool2d_shape_check(%arg0, %3, %6, %13, %16, %20, %23, %int1, %int1, %30, %31, %32, %33, %34) : (!torch.list<int>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.none
%36 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%37 = torch.aten.eq.int %36, %int3 : !torch.int, !torch.int -> !torch.bool
%38 = torch.prim.If %37 -> (!torch.list<int>) {
%39 = torch.prim.ListConstruct %30, %33, %34 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.prim.If.yield %39 : !torch.list<int>
} else {
%39 = torch.prim.ListConstruct %29, %30, %33, %34 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
torch.prim.If.yield %39 : !torch.list<int>
}
return %38 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> { func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int> %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int> return %0 : !torch.list<int>

View File

@@ -574,6 +574,9 @@ def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stri
def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
    return self

def aten〇avg_pool2d(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]:
    return upstream_shape_helpers.avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)

def aten〇adaptive_avg_pool2d(self: List[int], output_size: List[int]) -> List[int]:
    return upstream_shape_helpers.adaptive_avg_pool2d(self, output_size)

View File

@@ -334,6 +334,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit(
        "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
    )
    emit(
        "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
    )
    emit(
        "aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
    )

View File

@@ -213,6 +213,45 @@ def max_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
    else:
        return [nbatch, nInputPlane, outputHeight, outputWidth]

def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]):
    assert len(kernel_size) == 1 or len(kernel_size) == 2, "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints"
    kH = kernel_size[0]
    kW = kH if len(kernel_size) == 1 else kernel_size[1]

    assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
    dH = kH if len(stride) == 0 else stride[0]
    if len(stride) == 0:
        dW = kW
    elif len(stride) == 1:
        dW = dH
    else:
        dW = stride[1]

    assert len(padding) == 1 or len(padding) == 2, "avg_pool2d: padding must be either be a single int, or a tuple of two ints"
    padH = padding[0]
    padW = padH if len(padding) == 1 else padding[1]

    dilationH = 1
    dilationW = 1

    assert len(input) == 3 or len(input) == 4

    nbatch = input[-4] if len(input) == 4 else 1
    nInputPlane = input[-3]
    inputHeight = input[-2]
    inputWidth = input[-1]

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)

    pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW, nInputPlane,
                       inputHeight, inputWidth, outputHeight, outputWidth)

    if len(input) == 3:
        return [nInputPlane, outputHeight, outputWidth]
    else:
        return [nbatch, nInputPlane, outputHeight, outputWidth]

def max_pool2d_with_indices(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool):
    out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
    return (out, out)

View File

@@ -465,3 +465,106 @@ class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module):
def MaxPool2dWithIndicesBackwardDynamic3DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 7, 6), tu.rand(2, 6, 5),
                   torch.randint(16, (2, 7, 6)))
# ==============================================================================

class AvgPool2dFloatModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8],
                                       stride=[2, 2],
                                       padding=[3, 4],
                                       ceil_mode=False,
                                       count_include_pad=True,
                                       divisor_override=None)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dFloatModule())
def AvgPool2dFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4, 20, 20) - 0.5)


class AvgPool2dIntModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8],
                                       stride=[2, 2],
                                       padding=[3, 4],
                                       ceil_mode=False,
                                       count_include_pad=True,
                                       divisor_override=None)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.int64, True),
    ])
    def forward(self, x):
        return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dIntModule())
def AvgPool2dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(100, (2, 4, 20, 20)))


class AvgPool2dStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8],
                                       stride=[2, 2],
                                       padding=[3, 4],
                                       ceil_mode=False,
                                       count_include_pad=True,
                                       divisor_override=None)

    @export
    @annotate_args([
        None,
        ([2, 2, 10, 20], torch.float32, True),
    ])
    def forward(self, x):
        return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dStaticModule())
def AvgPool2dStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 2, 10, 20) - 0.5)


class AvgPool2dDivisorOverrideModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(kernel_size=[4, 8],
                                       stride=[2, 3],
                                       padding=[2, 4],
                                       ceil_mode=False,
                                       count_include_pad=True,
                                       divisor_override=22)

    @export
    @annotate_args([
        None,
        ([4, 4, 20, 20], torch.float32, True),
    ])
    def forward(self, x):
        return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dDivisorOverrideModule())
def AvgPool2dDivisorOverrideModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 4, 20, 20) - 0.5)