mirror of https://github.com/llvm/torch-mlir
[RFC] general support for Adaptive Pooling Ops (#2661)
Adaptive pooling ops can only be decomposed into their non-adaptive counterparts in trivial cases. For example, the current decomposition for AtenAdaptiveAvgPool1dOp in DecomposeComplexOps.cpp supports outSize = inSize (i.e., do literally nothing), and outSize = 1 (i.e., do a batched average). The reason adaptive pooling ops are difficult to lower to linalg is that they are not constantly strided. They are computed by taking an input tensor of shape (N, C, Hin), and an output size Hout, and computing the output tensor at position (n,c, h) in the following way: 1. compute st(h) = (h*Hin)//Hout 2. compute en(h) = 1 + ((h+1)*Hin -1)//Hout 3. apply a computation (max or avg) to the slice: INPUT[n, c, st(h):en(h)] The provided sample implementation (for ConvertAtenAdaptiveAvgPool1dOp) uses tensor.extract to access the input tensor inside the payload of a linalg generic op. This is likely an unattractive use of linalg generic ops, which is why I am asking for some more targeted feedback on the validity of this approach before attempting to support the many other adaptive pooling ops. Specifically: - Is the performance of this implementation bad enough to warrant targeting different dialects entirely? e.g. TMtensor/linalg ext/ etc. - If the provided implementation is of acceptable performance to the community, then is it permissable to remove the Adaptive pooling decompositions from DecomposeComplexOps.cpp? Based on the current structure of the -torch-decompose-complex-ops pass, it does not seem possible to only decompose the adaptive ops in special cases (it seems to get stuck in an infinite loop on a match failure). I would be happy to instead incorporate the case logic into the conversion directly, and remove the decompositions once they are rendered completely obsolete. As long as this approach is acceptable, I can clean up the implementation with some helper functions, and quickly add support for each of the remaining Adaptive pooling ops.pull/2720/head
parent
4dd17f0b71
commit
07d0645f64
|
@ -97,7 +97,8 @@ static LogicalResult createPoolingOp(
|
|||
}
|
||||
}
|
||||
|
||||
Value initValue = rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));
|
||||
Value initValue =
|
||||
rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));
|
||||
paddedInput = torch_to_linalg::getPaddedTensor(
|
||||
op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC,
|
||||
initValue);
|
||||
|
@ -141,7 +142,6 @@ static LogicalResult createPoolingOp(
|
|||
return success();
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
|
||||
public:
|
||||
|
@ -163,7 +163,8 @@ public:
|
|||
bool ceilMode;
|
||||
SmallVector<Value, 2> kernelSizeIntValues;
|
||||
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
|
||||
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
|
||||
if (!matchPattern(op.getDilation(),
|
||||
m_TorchListOfConstantInts(dilationInts)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int dilations");
|
||||
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dOp>(
|
||||
|
@ -241,7 +242,8 @@ public:
|
|||
bool ceilMode;
|
||||
SmallVector<Value, 2> kernelSizeIntValues;
|
||||
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
|
||||
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
|
||||
if (!matchPattern(op.getDilation(),
|
||||
m_TorchListOfConstantInts(dilationInts)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int dilations");
|
||||
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dWithIndicesOp>(
|
||||
|
@ -372,7 +374,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
|
||||
namespace {
|
||||
template <typename OpTy, typename PoolingOpTy, int Dim>
|
||||
class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
|
||||
|
@ -397,8 +398,8 @@ public:
|
|||
bool ceilMode;
|
||||
SmallVector<Value, Dim> kernelSizeIntValues;
|
||||
SmallVector<int64_t, Dim> strideInts, paddingInts, dilationInts(Dim, 1);
|
||||
if (failed(checkAndGetPoolingParameters<OpTy>(
|
||||
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
|
||||
if (failed(checkAndGetPoolingParameters<OpTy>(op, rewriter, typeConverter,
|
||||
ceilMode, kernelSizeIntValues,
|
||||
strideInts, paddingInts)))
|
||||
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
||||
|
||||
|
@ -418,15 +419,16 @@ public:
|
|||
SmallVector<Value, Dim + 2> outTensorShape;
|
||||
if (failed(createPoolingOp<PoolingOpTy>(
|
||||
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
|
||||
/*dimensionality=*/Dim, kernelSizeIntValues, strideInts, paddingInts,
|
||||
dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape,
|
||||
paddedInput, sumPool)))
|
||||
/*dimensionality=*/Dim, kernelSizeIntValues, strideInts,
|
||||
paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType),
|
||||
outTensorShape, paddedInput, sumPool)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
|
||||
Value divisor;
|
||||
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
||||
Value kHtimeskW = rewriter.create<arith::MulIOp>(
|
||||
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
|
||||
divisor = op.getDivisorOverride().getType().template isa<Torch::NoneType>()
|
||||
divisor =
|
||||
op.getDivisorOverride().getType().template isa<Torch::NoneType>()
|
||||
? kHtimeskW
|
||||
: adaptor.getDivisorOverride();
|
||||
} else {
|
||||
|
@ -436,7 +438,8 @@ public:
|
|||
|
||||
Value outputTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(outTensorShape), resultElementType);
|
||||
SmallVector<AffineMap> indexingMapsAvg(2, rewriter.getMultiDimIdentityMap(Dim+2));
|
||||
SmallVector<AffineMap> indexingMapsAvg(
|
||||
2, rewriter.getMultiDimIdentityMap(Dim + 2));
|
||||
SmallVector<utils::IteratorType> iteratorTypesAvg(
|
||||
Dim + 2, utils::IteratorType::parallel);
|
||||
Value avgPool =
|
||||
|
@ -459,8 +462,188 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/*
|
||||
This section is for lowering adaptive pooling ops, which cannot generally be
|
||||
decomposed into typical pooling ops. Given an input tensor of rank (N,C,Hin) and
|
||||
an output spatial size Hout, an element of the output tensor at position (n, c,
|
||||
h) is computed as follows.
|
||||
1. compute st(h) = (h*Hin)//Hout
|
||||
2. compute en(h) = 1 + ((h+1)*Hin - 1)//Hout
|
||||
3. apply the operation (max or avg) over input[n, c, st(h):en(h)]
|
||||
This is problematic for linalg ops for a few reasons:
|
||||
1. The access to the input tensor is not constantly strided
|
||||
2. The size of the window itself is not contant: en(h) - st(h) can vary with
|
||||
h! Although it is a bit like using a hammer to paint, our workaround is to use
|
||||
tensor.extract to access the elements of the input tensor inside our linalg
|
||||
generic op's payload.
|
||||
|
||||
Current TODO's:
|
||||
1. gather most of the boilerplate out of this op and make it into an
|
||||
adaptive pooling helper function.
|
||||
2. figure out what to do with the conflicting decompositions in
|
||||
DecomposeComplexOps.cpp
|
||||
3. Implement more efficient passes for when the kernel-size, input spatial
|
||||
dims, and output spatial dims are constant.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
class ConvertAtenAdaptiveAvgPool1dOp
|
||||
: public OpConversionPattern<AtenAdaptiveAvgPool1dOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenAdaptiveAvgPool1dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op->getLoc();
|
||||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
// get rank of input (same as rank of output)
|
||||
int64_t rank =
|
||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
||||
// input operand should be NCH (i.e. rank 3)
|
||||
if (rank != 3) {
|
||||
return rewriter.notifyMatchFailure(op, "only supports input type NCH");
|
||||
}
|
||||
|
||||
// input tensor and output shape
|
||||
Value input = adaptor.getSelf();
|
||||
Value outputShape = op.getOutputSize();
|
||||
SmallVector<Value> outShapeVector;
|
||||
getListConstructElements(outputShape, outShapeVector);
|
||||
outShapeVector =
|
||||
getTypeConvertedValues(rewriter, loc, typeConverter, outShapeVector);
|
||||
Value hIn = getDimOp(rewriter, loc, input, 2);
|
||||
Value hOut = outShapeVector[0];
|
||||
Value hOutIndex = castIntToIndex(rewriter, loc, hOut);
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType outputType =
|
||||
typeConverter->convertType(op.getResult().getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
// get elementType of input tensor
|
||||
Type elementType = inputType.getElementType();
|
||||
|
||||
// make an iteration space of size kMax = 1 + ceildiv (hIn - 1) , hOut
|
||||
Type boolType = rewriter.getI1Type();
|
||||
Value kIter;
|
||||
Value constantOne =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
Value hInPlusOne = rewriter.create<arith::SubIOp>(loc, hIn, constantOne);
|
||||
Value kMaxMinusOne =
|
||||
rewriter.create<arith::CeilDivSIOp>(loc, hInPlusOne, hOutIndex);
|
||||
Value kMax = rewriter.create<arith::AddIOp>(loc, constantOne, kMaxMinusOne);
|
||||
kIter = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(ValueRange({kMax})), boolType);
|
||||
|
||||
// need to buffer input, else there will possibly be an out of bounds access
|
||||
// later buffVal = 0 for avg pooling and -inf for max pooling
|
||||
Value buffVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, elementType, rewriter.getFloatAttr(elementType, 0));
|
||||
SmallVector<int64_t> lowPadding = {0, 0, 0};
|
||||
SmallVector<int64_t> highPadding = {0, 0, 1};
|
||||
Value buffInput = torch_to_linalg::getPaddedTensor(
|
||||
op, rewriter, input, lowPadding, highPadding, buffVal);
|
||||
|
||||
// make a list of outputSizes
|
||||
SmallVector<Value> outputSizes;
|
||||
for (unsigned i = 0; i < rank - 1; i++) {
|
||||
outputSizes.push_back(getDimOp(rewriter, loc, input, i));
|
||||
}
|
||||
outputSizes.push_back(hOutIndex);
|
||||
|
||||
// initialize a kernel size tensor (only for avg pooling)
|
||||
Value kSizeTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(ValueRange({hOutIndex})), elementType);
|
||||
|
||||
// initialize an output tensor
|
||||
Value initOutput =
|
||||
createInitTensor(rewriter, loc, outputSizes, elementType, buffVal);
|
||||
|
||||
// setup indexing maps and iterator types for linalg generic op
|
||||
// for kIter (d0,d1,d2,d3) -> (d3)
|
||||
// for output (d0,d1,d2,d3) -> (d0,d1,d2)
|
||||
// for kSizeTensor (d0,d1,d2,d3) -> (d2)
|
||||
SmallVector<AffineExpr> kIterExprs, outputExprs, kSizeTensorExprs;
|
||||
for (unsigned i = 0; i < 3; i++) {
|
||||
outputExprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
}
|
||||
kSizeTensorExprs.push_back(rewriter.getAffineDimExpr(2));
|
||||
kIterExprs.push_back(rewriter.getAffineDimExpr(3));
|
||||
SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
|
||||
{kIterExprs, outputExprs, kSizeTensorExprs});
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
3, utils::IteratorType::parallel);
|
||||
iteratorTypes.push_back(utils::IteratorType::reduction);
|
||||
|
||||
Value indexOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
auto sumPool = rewriter.create<linalg::GenericOp>(
|
||||
loc, /*resultTensorTypes=*/
|
||||
TypeRange({initOutput.getType(), kSizeTensor.getType()}),
|
||||
/*inputs=*/ValueRange({kIter}),
|
||||
/*outputs=*/ValueRange({initOutput, kSizeTensor}),
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value res = args[1];
|
||||
Value ind0 = b.create<linalg::IndexOp>(loc, 0);
|
||||
Value ind1 = b.create<linalg::IndexOp>(loc, 1);
|
||||
Value ind2 = b.create<linalg::IndexOp>(loc, 2);
|
||||
Value ind3 = b.create<linalg::IndexOp>(loc, 3);
|
||||
// compute start and end indices
|
||||
// st = s1( s0(ind2 * Hin) // Hout )
|
||||
Value s0 = b.create<arith::MulIOp>(loc, ind2, hIn);
|
||||
Value s1 = b.create<arith::FloorDivSIOp>(loc, s0, hOutIndex);
|
||||
// en = e4( 1 + e3( e2( e1( e0(ind2 + 1) * hIn ) - 1 ) // hOut ) )
|
||||
Value e0 = b.create<arith::AddIOp>(loc, ind2, indexOne);
|
||||
Value e1 = b.create<arith::MulIOp>(loc, e0, hIn);
|
||||
Value e2 = b.create<arith::SubIOp>(loc, e1, indexOne);
|
||||
Value e3 = b.create<arith::FloorDivSIOp>(loc, e2, hOutIndex);
|
||||
Value e4 = b.create<arith::AddIOp>(loc, indexOne, e3);
|
||||
// get input element @ st + ind3:
|
||||
Value wIndex = b.create<arith::AddIOp>(loc, s1, ind3);
|
||||
Value inElt = b.create<tensor::ExtractOp>(
|
||||
loc, elementType, buffInput, ValueRange({ind0, ind1, wIndex}));
|
||||
// check if we extracted at windex < end index
|
||||
Value cond =
|
||||
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate(6), wIndex, e4);
|
||||
// if inElt is in bounds, include it in the computation
|
||||
// else, use buffVal = 0 (for max pool use -infinity)
|
||||
Value out1 = b.create<arith::SelectOp>(loc, cond, inElt, buffVal);
|
||||
// compute Kernel size: we store this to kwTensor
|
||||
Value kSize = b.create<arith::SubIOp>(loc, e4, s1);
|
||||
Value kSizeInt = castIndexToInt64(b, loc, kSize);
|
||||
Value kSizeF = b.create<arith::SIToFPOp>(loc, elementType, kSizeInt);
|
||||
// accumulate out2 to res = args[1]
|
||||
Value out2 = b.create<arith::AddFOp>(loc, res, out1);
|
||||
b.create<linalg::YieldOp>(loc, ValueRange({out2, kSizeF}));
|
||||
});
|
||||
|
||||
// make a linalg generic to divide each element by the corresponding
|
||||
// Kernel Width. This step is only necessary for avg pooling.
|
||||
SmallVector<AffineMap> indexingMaps1 =
|
||||
AffineMap::inferFromExprList({kSizeTensorExprs, outputExprs});
|
||||
SmallVector<utils::IteratorType> iteratorTypes1(
|
||||
3, utils::IteratorType::parallel);
|
||||
auto output = rewriter.create<linalg::GenericOp>(
|
||||
loc, /*resultTensorTypes=*/initOutput.getType(),
|
||||
/*inputs=*/sumPool.getResultTensors()[1],
|
||||
/*outputs=*/sumPool.getResultTensors()[0],
|
||||
/*indexingMaps=*/indexingMaps1,
|
||||
/*iteratorTypes=*/iteratorTypes1,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value q = b.create<arith::DivFOp>(loc, args[1], args[0]);
|
||||
b.create<linalg::YieldOp>(loc, q);
|
||||
});
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType,
|
||||
output.getResultTensors());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
|
@ -471,8 +654,12 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
|
|||
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
||||
patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp>();
|
||||
patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
|
||||
patterns
|
||||
.add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool2dOp, linalg::PoolingNchwSumOp, 2>>(
|
||||
patterns
|
||||
.add<ConvertAtenAvgPoolOp<AtenAvgPool2dOp, linalg::PoolingNchwSumOp, 2>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
|
||||
patterns.add<ConvertAtenAdaptiveAvgPool1dOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -257,6 +257,8 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
|
||||
"ElementwiseDivRoundingModeFloorModule_basic",
|
||||
"ElementwiseDivRoundingModeTruncModule_basic",
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
|
||||
# ERROR: Exception: Unsupported op: get_attr
|
||||
"NumToTensorFloatModule_basic",
|
||||
|
@ -1324,6 +1326,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
|||
### Tests additionally passing in make_fx_tosa
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||
"NativeGroupNormBackwardModule_basic",
|
||||
"SliceWholeTensorModule_basic",
|
||||
"TensorFloatModule_basic",
|
||||
|
|
|
@ -248,7 +248,7 @@ class ExampleArgs:
|
|||
# compiler where each backend can "own" its set of legal ops.
|
||||
BACKEND_LEGAL_OPS = {
|
||||
OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'],
|
||||
OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints', ],
|
||||
OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d'],
|
||||
OutputType.STABLEHLO: [],
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AdaptiveAvgPool2dNonUnitOutputSizeStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -55,7 +54,6 @@ def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic(
|
|||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 512, 7, 7))
|
||||
|
||||
|
||||
class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -776,12 +774,71 @@ def AvgPool1dStaticModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class AdaptiveAvgPool1dStaticLargerOutput(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=13)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, 512, 7], torch.float32, True)
|
||||
])
|
||||
def forward(self,x):
|
||||
return self.aap1d(x)
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: AdaptiveAvgPool1dStaticLargerOutput())
|
||||
def AdaptiveAvgPool1dStaticLargerOutput_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 512, 7))
|
||||
|
||||
class AdaptiveAvgPool1dStaticEvenMultiple(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, 512, 147], torch.float32, True)
|
||||
])
|
||||
def forward(self,x):
|
||||
return self.aap1d(x)
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: AdaptiveAvgPool1dStaticEvenMultiple())
|
||||
def AdaptiveAvgPool1dStaticEvenMultiple_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 512, 147))
|
||||
|
||||
class AdaptiveAvgPool1dGeneralDynamic(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1,-1,-1], torch.float32, True)
|
||||
])
|
||||
def forward(self,x):
|
||||
return self.aap1d(x)
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: AdaptiveAvgPool1dGeneralDynamic())
|
||||
def AdaptiveAvgPool1dGeneralDynamic_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 512, 10))
|
||||
|
||||
class AdaptiveAvgPool1dNonUnitOutputSizeStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(7)
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
|
@ -801,7 +858,7 @@ class AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule(torch.nn.Module):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(7)
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
|
@ -821,7 +878,7 @@ class AdaptiveAvgPool1dUnitOutputSizeStaticModule(torch.nn.Module):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(1)
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=1)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
|
@ -841,7 +898,7 @@ class AdaptiveAvgPool1dUnitOutputSizeDynamicModule(torch.nn.Module):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(1)
|
||||
self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=1)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
|
|
Loading…
Reference in New Issue