[MLIR][TORCH] Add decomposition of aten.adaptive_avg_pool2d op

This commit adds the decomposition of `aten.adaptive_avg_pool2d` op into
`aten.avg_pool2d` op. The current decomposition only supports cases where
the input size is equal to the output size.
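For illustration only (not part of this patch), a minimal PyTorch sketch of the equivalence the decomposition relies on in the supported case, where the output size matches the input size:

import torch
import torch.nn.functional as F

x = torch.randn(1, 512, 7, 7)

# Output size equal to input size: stride 1, kernel 1, padding 0,
# so the average pooling reduces to an element-wise identity.
assert torch.allclose(F.adaptive_avg_pool2d(x, (7, 7)),
                      F.avg_pool2d(x, kernel_size=1, stride=1, padding=0))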

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/882/merge
Vivek Khandelwal 2022-05-13 17:36:24 +05:30
parent b76c8c82dc
commit 6f548fc3ad
7 changed files with 269 additions and 134 deletions

View File

@ -45,6 +45,10 @@ Value castIntToIndex(OpBuilder &b, Location loc, Value v);
Value castIndexToInt64(OpBuilder &b, Location loc, Value idx);
SmallVector<Value>
castIntVectorToIndexVector(OpBuilder &b, Location loc,
SmallVectorImpl<Value> &intValues);
Value getDimOp(OpBuilder &b, Location loc, Value v, int dim);
SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,

View File

@ -32,15 +32,21 @@ using namespace mlir::torch::Torch;
template <typename OpTy>
static LogicalResult
checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
bool &ceilMode,
SmallVectorImpl<int64_t> &kernelSizeInts,
TypeConverter *typeConverter, bool &ceilMode,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts) {
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
return rewriter.notifyMatchFailure(op, "only support kernel size ints");
SmallVector<Value, 2> kernelSizeTorchInt;
if (!getListConstructElements(op.kernel_size(), kernelSizeTorchInt)) {
return rewriter.notifyMatchFailure(op,
"unimplemented: the kernel size is "
"not constructed from ListConstruct");
}
kernelSizeIntValues = getTypeConvertedValues(
rewriter, op.getLoc(), typeConverter, kernelSizeTorchInt);
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op, "only support constant int strides");
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
@ -58,7 +64,7 @@ template <typename OpTy>
static LogicalResult createPoolingOp(
Operation *op, ConversionPatternRewriter &rewriter, Value self,
bool supportFPInputOnly, bool ceilMode,
SmallVectorImpl<int64_t> &kernelSizeInts,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
@ -88,8 +94,6 @@ static LogicalResult createPoolingOp(
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> kernelSizeIntValues =
getAsConstantIntValues(rewriter, loc, kernelSizeInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
@ -108,7 +112,7 @@ static LogicalResult createPoolingOp(
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
Value windowTensor = rewriter.create<linalg::InitTensorOp>(
loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
loc, castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues),
elementType);
result = rewriter
@ -130,6 +134,7 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
TypeConverter *typeConverter = getTypeConverter();
Value self = adaptor.self();
int64_t selfRank = self.getType().cast<RankedTensorType>().getRank();
// TODO: Add support for 3D inputs.
@ -137,14 +142,15 @@ public:
return rewriter.notifyMatchFailure(
op, "unimplemented: only support 4D input");
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
dilationInts;
bool ceilMode;
SmallVector<Value, 2> kernelSizeIntValues;
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
bool ceilMode;
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dOp>(
op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts)))
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
strideInts, paddingInts)))
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
@ -158,7 +164,7 @@ public:
Value maxPool2d, paddedInput;
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportFPInput=*/false, ceilMode,
kernelSizeInts, strideInts, paddingInts, dilationInts,
kernelSizeIntValues, strideInts, paddingInts, dilationInts,
smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
Type newResultType = getTypeConverter()->convertType(op.getType());
@ -200,6 +206,7 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
TypeConverter *typeConverter = getTypeConverter();
Value self = adaptor.self();
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
Type elementType = selfType.getElementType();
@ -213,14 +220,15 @@ public:
return rewriter.notifyMatchFailure(
op, "unimplemented: only support 4D input");
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
dilationInts;
bool ceilMode;
SmallVector<Value, 2> kernelSizeIntValues;
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
bool ceilMode;
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dWithIndicesOp>(
op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts)))
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
strideInts, paddingInts)))
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
// `maxpool2d` contains the result of maxpool2d operation over the input.
@ -233,7 +241,7 @@ public:
SmallVector<Value, 4> outTensorShape;
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportFPInput=*/false, ceilMode,
kernelSizeInts, strideInts, paddingInts, dilationInts,
kernelSizeIntValues, strideInts, paddingInts, dilationInts,
smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
@ -244,7 +252,7 @@ public:
indicesRankedTensorType.getElementType(), cstMinusOne);
SmallVector<Value> kernelSize =
getAsConstantIndexValues(rewriter, loc, kernelSizeInts);
castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
SmallVector<Value> padding =
getAsConstantIndexValues(rewriter, loc, paddingInts);
SmallVector<Value> dilation =
@ -344,105 +352,6 @@ public:
};
} // namespace
namespace {
class ConvertAtenAdaptiveAvgPool2dOp
: public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenAdaptiveAvgPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
Value input = adaptor.self(); /* in form of N*C*H*W */
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type elementType = inputType.getElementType();
if (!elementType.isa<mlir::FloatType>())
return op.emitError("unimplemented: non-floating point type");
auto inputRank = inputType.getRank();
if (inputRank != 4)
return rewriter.notifyMatchFailure(op, "input should be rank 4");
SmallVector<int64_t, 2> expects{1, 1};
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
if (!isConstantIntListMatching(op.output_size(), expects))
return rewriter.notifyMatchFailure(
op, "only support output_size with H and W both equal to constant 1");
Value N = getDimOp(rewriter, loc, input, 0);
Value C = getDimOp(rewriter, loc, input, 1);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, C}, elementType);
Value c0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
Value initTensor0 =
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
SmallVector<AffineExpr, 2> ncExprs;
ncExprs.push_back(mlir::getAffineDimExpr(0, context));
ncExprs.push_back(mlir::getAffineDimExpr(1, context));
auto ncIndexingMap = AffineMap::get(
/*dimCount=*/4,
/*symbolCount=*/0, ncExprs, context);
SmallVector<AffineMap, 2> indexingMaps = {
rewriter.getMultiDimIdentityMap(4), // input
ncIndexingMap, // output
};
SmallVector<StringRef, 4> iteratorTypesSum{"parallel", "parallel",
"reduction", "reduction"};
Value sumPool2d = rewriter
.create<linalg::GenericOp>(
loc, initTensor0.getType(), input, initTensor0,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypesSum,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], sum = args[1];
Value result = rewriter.create<arith::AddFOp>(
loc, sum, input);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
// Calculate H*W so that the average can be computed as sum / (H*W)
Value H = getDimOp(rewriter, loc, input, 2);
Value W = getDimOp(rewriter, loc, input, 3);
auto castIndexToInt = [&](Value v) {
return rewriter.create<arith::IndexCastOp>(
loc, IntegerType::get(context, 64), v);
};
Value HtimesW = rewriter.create<arith::MulIOp>(loc, castIndexToInt(H),
castIndexToInt(W));
Value HtimesWf =
rewriter.create<arith::SIToFPOp>(loc, elementType, HtimesW);
Value c1Index = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
Value outputTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, C, c1Index, c1Index}, elementType);
SmallVector<AffineMap, 2> indexingMapsAvg{
ncIndexingMap, rewriter.getMultiDimIdentityMap(4)};
SmallVector<StringRef, 4> iteratorTypesAvg(4, "parallel");
Value avgPool2d =
rewriter
.create<linalg::GenericOp>(
loc, outputTensor.getType(), sumPool2d, outputTensor,
/*indexingMaps=*/indexingMapsAvg,
/*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value avg = b.create<arith::DivFOp>(loc, args[0], HtimesWf);
b.create<linalg::YieldOp>(loc, avg);
})
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, avgPool2d);
return success();
}
};
} // namespace
namespace {
class ConvertAtenAvgPool2dOp : public OpConversionPattern<AtenAvgPool2dOp> {
public:
@ -453,6 +362,7 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
TypeConverter *typeConverter = getTypeConverter();
Value self = adaptor.self();
Type inputElementType =
@ -461,11 +371,12 @@ public:
Type resultElementType =
resultType.cast<RankedTensorType>().getElementType();
SmallVector<int64_t, 2> dilationInts{1, 1};
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts;
bool ceilMode;
SmallVector<Value, 2> kernelSizeIntValues;
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts{1, 1};
if (failed(checkAndGetPoolingParameters<AtenAvgPool2dOp>(
op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts)))
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
strideInts, paddingInts)))
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
// TODO: Add support for count_include_pad equal to `False`.
@ -484,13 +395,11 @@ public:
SmallVector<Value, 4> outTensorShape;
if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
op, rewriter, self, /*supportFPInput=*/true, ceilMode,
kernelSizeInts, strideInts, paddingInts, dilationInts,
kernelSizeIntValues, strideInts, paddingInts, dilationInts,
rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput,
sumPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d");
SmallVector<Value> kernelSizeIntValues =
getAsConstantIntValues(rewriter, loc, kernelSizeInts);
Value kHtimeskW = rewriter.create<arith::MulIOp>(
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
Value divisor = op.divisor_override().getType().isa<Torch::NoneType>()
@ -534,8 +443,6 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool2dOp>();
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
}

View File

@ -143,6 +143,15 @@ Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
}
SmallVector<Value>
castIntVectorToIndexVector(OpBuilder &b, Location loc,
SmallVectorImpl<Value> &intValues) {
SmallVector<Value> indexValues;
for (Value v : intValues)
indexValues.push_back(castIntToIndex(b, loc, v));
return indexValues;
}
Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
return b.createOrFold<tensor::DimOp>(loc, v, dim);
}

View File

@ -1551,7 +1551,8 @@ class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
// Note that this is the same decomposition as in AOTAutograd
// https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/_src/aot_autograd.py#L104
namespace {
class DecomposeAten_ReshapeAliasOp : public OpRewritePattern<Aten_ReshapeAliasOp> {
class DecomposeAten_ReshapeAliasOp
: public OpRewritePattern<Aten_ReshapeAliasOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op,
PatternRewriter &rewriter) const override {
@ -1775,6 +1776,116 @@ public:
};
} // namespace
namespace {
// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
//
// For the AdaptiveAvgPool2d op, when the input size is an integer multiple of
// the output size, the kernel_size, stride, and padding are calculated as
// follows:
// strideH = inH // outH
// strideW = inW // outW
// kernelH = inH - [(outH - 1) * strideH]
// kernelW = inW - [(outW - 1) * strideW]
// paddingH = 0, paddingW = 0
//
// For the special case where the output size is one for all dimensions,
// the kernel size is the same as the input size.
class DecomposeAtenAdaptiveAvgPool2dOp
: public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = op.getContext();
Value input = op.self();
int64_t rank = getTensorRank(input);
SmallVector<Value, 2> inputHW;
Value dimH = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 2));
inputHW.push_back(
/*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
Value dimW = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 1));
inputHW.push_back(
/*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
Value outputShape = op.output_size();
SmallVector<Value> outputShapeSizesTorchInt;
getListConstructElements(outputShape, outputShapeSizesTorchInt);
// TODO: Add support for cases other than:
// 1.) inH == outH and inW == outW.
// 2.) outH == outW == 1
bool unitOutputSize = true;
for (Value outShape : outputShapeSizesTorchInt) {
int64_t outShapeInt;
if (!matchPattern(outShape, m_TorchConstantInt(&outShapeInt))) {
return rewriter.notifyMatchFailure(
op, "output size is expected to be a constant");
}
if (outShapeInt != 1) {
unitOutputSize = false;
break;
}
}
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(loc);
SmallVector<Value, 2> kernelSize;
for (unsigned i = 0; i < inputHW.size(); i++) {
if (unitOutputSize) {
BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
kernelSize.push_back(inputShape[rank - 2 + i] == kUnknownSize
? inputHW[i]
: rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(
inputShape[rank - 2 + i])));
} else {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputHW[i],
outputShapeSizesTorchInt[i]);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"unimplemented: only support cases where input and output size are "
"equal for non-unit output size");
Value outMinusOne = rewriter.create<AtenSubIntOp>(
loc, outputShapeSizesTorchInt[i], constantOne);
kernelSize.push_back(
rewriter.create<AtenSubIntOp>(loc, inputHW[i], outMinusOne));
}
}
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
// Currently we only support cases where the input size is equal to the output
// size or the output size is unit. For the former case, the stride is always
// equal to one, and for the latter the stride value doesn't matter, since the
// kernel size is the same as the input size. Therefore, keep the stride as one
// for the latter case as well, for ease of implementation.
Value strideList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne, constantOne});
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero, constantZero});
rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
/*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue,
/*divisor_override=*/constantNone);
return success();
}
};
} // namespace
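To make the kernel-size/stride formulas in the comment above concrete, here is a small Python sketch (the decomposed_adaptive_avg_pool2d helper is illustrative, not part of the patch) that applies them and checks the result against PyTorch for the two supported cases:

import torch
import torch.nn.functional as F

def decomposed_adaptive_avg_pool2d(x, output_size):
    # Apply the formulas from the decomposition comment:
    # stride = in // out, kernel = in - (out - 1) * stride, padding = 0.
    in_h, in_w = x.shape[-2], x.shape[-1]
    out_h, out_w = output_size
    stride_h, stride_w = in_h // out_h, in_w // out_w
    kernel_h = in_h - (out_h - 1) * stride_h
    kernel_w = in_w - (out_w - 1) * stride_w
    return F.avg_pool2d(x, kernel_size=(kernel_h, kernel_w),
                        stride=(stride_h, stride_w), padding=0)

x = torch.randn(2, 3, 7, 7)
# Case 1: output size equals input size (kernel 1, stride 1).
assert torch.allclose(decomposed_adaptive_avg_pool2d(x, (7, 7)),
                      F.adaptive_avg_pool2d(x, (7, 7)))
# Case 2: unit output size (kernel covers the full spatial extent).
assert torch.allclose(decomposed_adaptive_avg_pool2d(x, (1, 1)),
                      F.adaptive_avg_pool2d(x, (1, 1)))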
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -1910,6 +2021,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenPadOp>(context);
patterns.add<DecomposeAtenToDtypeLayoutOp>(context);
target.addIllegalOp<AtenToDtypeLayoutOp>();
patterns.add<DecomposeAtenAdaptiveAvgPool2dOp>(context);
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {

View File

@ -12,7 +12,72 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class AdaptiveAvgPool2dModule(torch.nn.Module):
class AdaptiveAvgPool2dNonUnitOutputSizeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.aap2d = torch.nn.AdaptiveAvgPool2d((7, 7))
@export
@annotate_args([
None,
([1, 512, 7, 7], torch.float32, True),
])
def forward(self, x):
return self.aap2d(x)
@register_test_case(
module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeStaticModule())
def AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic(
module, tu: TestUtils):
module.forward(tu.rand(1, 512, 7, 7))
class AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.aap2d = torch.nn.AdaptiveAvgPool2d((7, 7))
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.aap2d(x)
@register_test_case(
module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule())
def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic(
module, tu: TestUtils):
module.forward(tu.rand(1, 512, 7, 7))
class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.aap2d = torch.nn.AdaptiveAvgPool2d((1, 1))
@export
@annotate_args([
None,
([1, 512, 7, 7], torch.float32, True),
])
def forward(self, x):
return self.aap2d(x)
@register_test_case(
module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeStaticModule())
def AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 7, 7))
class AdaptiveAvgPool2dUnitOutputSizeDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -27,9 +92,10 @@ class AdaptiveAvgPool2dModule(torch.nn.Module):
return self.aap2d(x)
@register_test_case(module_factory=lambda: AdaptiveAvgPool2dModule())
def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9))
@register_test_case(
module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeDynamicModule())
def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 7, 7))
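As a point of reference (not part of the test suite), the unit-output-size modules above compute a plain mean over the spatial dimensions, which is exactly what the decomposition's full-kernel avg_pool2d produces; a quick PyTorch sanity check:

import torch

x = torch.rand(1, 512, 7, 7)
pooled = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
# Global average pooling is just a mean over H and W.
assert torch.allclose(pooled, x.mean(dim=(2, 3), keepdim=True))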
# ==============================================================================

View File

@ -11,12 +11,14 @@ func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,
%int7 = torch.constant.int 7
%int8 = torch.constant.int 8
%false = torch.constant.bool false
// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
// CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
// CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[C1]], %[[C2]]] : tensor<?x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
// CHECK: %[[T2:.*]] = arith.index_cast %[[C2]] : i64 to index
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[T1]], %[[T2]]] : tensor<?x?xf32>
// CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
%kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>

View File

@ -949,3 +949,37 @@ func.func @torch.aten.to.dtype_layout(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
%0 = torch.aten.to.dtype_layout %arg0, %int7, %int0, %none, %none, %false, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
return %0 : !torch.vtensor<[?,?],f64>
}
// -----
// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[CST7]], %[[CST7]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[CST2:.*]] = torch.constant.int 2
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[CST3:.*]] = torch.constant.int 3
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM2]], %[[T1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
// CHECK: %[[T3:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[T4:.*]] = torch.aten.sub.int %[[DIM3]], %[[T3]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T2]], %[[T4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T6:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T7:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUT:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[T5]], %[[T6]], %[[T7]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.adaptive_avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int7 = torch.constant.int 7
%output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}