mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.avg_pool2d op
This commit adds lowering of `aten.avg_pool2d` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/819/head snapshot-20220502.427
parent
81ee5bb58c
commit
4b11284440
|
@ -3048,6 +3048,35 @@ def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_in
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$kernel_size,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
Torch_BoolType:$ceil_mode,
|
||||
Torch_BoolType:$count_include_pad,
|
||||
AnyTorchOptionalIntType:$divisor_override
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||
}
|
||||
void AtenAvgPool2dOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 7, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -34,8 +34,7 @@ static LogicalResult
|
|||
checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
|
||||
SmallVectorImpl<int64_t> &kernelSizeInts,
|
||||
SmallVectorImpl<int64_t> &strideInts,
|
||||
SmallVectorImpl<int64_t> &paddingInts,
|
||||
SmallVectorImpl<int64_t> &dilationInts) {
|
||||
SmallVectorImpl<int64_t> &paddingInts) {
|
||||
// Pattern match against the op's original operands, because otherwise we
|
||||
// will get the lowered version of the operands which is harder to pattern
|
||||
// match.
|
||||
|
@ -46,9 +45,6 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
|
|||
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int paddings");
|
||||
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int dilations");
|
||||
bool ceilMode;
|
||||
if (!matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -60,30 +56,26 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Computes maxpool2d for AtenMaxPool2dOp and AtenMaxPool2dWithIndicesOp.
|
||||
static LogicalResult
|
||||
computeMaxPool2d(Operation *op, ConversionPatternRewriter &rewriter, Value self,
|
||||
SmallVectorImpl<int64_t> &kernelSizeInts,
|
||||
SmallVectorImpl<int64_t> &strideInts,
|
||||
SmallVectorImpl<int64_t> &paddingInts,
|
||||
SmallVectorImpl<int64_t> &dilationInts, Value &result) {
|
||||
// Creates a pooling operation based on the type specified by `OpTy` and
|
||||
// arguments passed.
|
||||
template <typename OpTy>
|
||||
static LogicalResult createPoolingOp(
|
||||
Operation *op, ConversionPatternRewriter &rewriter, Value self,
|
||||
bool supportFPInputOnly, SmallVectorImpl<int64_t> &kernelSizeInts,
|
||||
SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
|
||||
SmallVectorImpl<int64_t> &dilationInts, Attribute intiValueAttr,
|
||||
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
|
||||
Location loc = op->getLoc();
|
||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||
if (!elementType.isa<mlir::FloatType>())
|
||||
if (!elementType.isa<mlir::FloatType>() && !supportFPInputOnly)
|
||||
return op->emitError("unimplemented: non-floating point type");
|
||||
|
||||
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
|
||||
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
|
||||
paddingInts.end());
|
||||
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType, APFloat::getLargest(
|
||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
Value smallestFPValue =
|
||||
rewriter.create<arith::ConstantOp>(loc, smallestFPValueAttr);
|
||||
Value paddedInput =
|
||||
torch_to_linalg::getPaddedTensor(op, rewriter, self, paddingIncludingNC,
|
||||
paddingIncludingNC, smallestFPValue);
|
||||
Value initValue = rewriter.create<arith::ConstantOp>(loc, intiValueAttr);
|
||||
paddedInput = torch_to_linalg::getPaddedTensor(
|
||||
op, rewriter, self, paddingIncludingNC, paddingIncludingNC, initValue);
|
||||
|
||||
Value N = getDimOp(rewriter, loc, self, 0);
|
||||
Value C = getDimOp(rewriter, loc, self, 1);
|
||||
|
@ -107,9 +99,9 @@ computeMaxPool2d(Operation *op, ConversionPatternRewriter &rewriter, Value self,
|
|||
kernelSizeIntValues[1], strideIntValues[1]);
|
||||
|
||||
// Create output tensor initialized with smallest floating point value.
|
||||
outTensorShape.insert(outTensorShape.begin(), {N, C, hOut, wOut});
|
||||
Value outTensorInitialized =
|
||||
createInitTensor(rewriter, loc, ValueRange{N, C, hOut, wOut}, elementType,
|
||||
smallestFPValue);
|
||||
createInitTensor(rewriter, loc, outTensorShape, elementType, initValue);
|
||||
|
||||
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
|
||||
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
|
||||
|
@ -118,11 +110,11 @@ computeMaxPool2d(Operation *op, ConversionPatternRewriter &rewriter, Value self,
|
|||
elementType);
|
||||
|
||||
result = rewriter
|
||||
.create<linalg::PoolingNchwMaxOp>(
|
||||
loc, outTensorInitialized.getType(),
|
||||
ValueRange{paddedInput, windowTensor}, outTensorInitialized,
|
||||
stridesAttr, dilationAttr)
|
||||
.create<OpTy>(loc, outTensorInitialized.getType(),
|
||||
ValueRange{paddedInput, windowTensor},
|
||||
outTensorInitialized, stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -143,16 +135,28 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only support 4D input");
|
||||
|
||||
Value maxPool2d;
|
||||
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
|
||||
dilationInts;
|
||||
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int dilations");
|
||||
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dOp>(
|
||||
op, rewriter, kernelSizeInts, strideInts, paddingInts,
|
||||
dilationInts)))
|
||||
op, rewriter, kernelSizeInts, strideInts, paddingInts)))
|
||||
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
||||
|
||||
if (failed(computeMaxPool2d(op, rewriter, self, kernelSizeInts, strideInts,
|
||||
paddingInts, dilationInts, maxPool2d)))
|
||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getLargest(
|
||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
||||
Value maxPool2d, paddedInput;
|
||||
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
||||
op, rewriter, self, /*supportFPInput=*/false, kernelSizeInts,
|
||||
strideInts, paddingInts, dilationInts, smallestFPValueAttr,
|
||||
outTensorShape, paddedInput, maxPool2d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
|
||||
|
@ -208,36 +212,31 @@ public:
|
|||
|
||||
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
|
||||
dilationInts;
|
||||
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int dilations");
|
||||
if (failed(checkAndGetPoolingParameters<AtenMaxPool2dWithIndicesOp>(
|
||||
op, rewriter, kernelSizeInts, strideInts, paddingInts,
|
||||
dilationInts)))
|
||||
op, rewriter, kernelSizeInts, strideInts, paddingInts)))
|
||||
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
||||
|
||||
// Contains the result of maxpool2d operation over the input.
|
||||
Value maxPool2d;
|
||||
if (failed(computeMaxPool2d(op, rewriter, self, kernelSizeInts, strideInts,
|
||||
paddingInts, dilationInts, maxPool2d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
|
||||
|
||||
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
|
||||
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
|
||||
paddingInts.end());
|
||||
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
||||
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getLargest(
|
||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
Value smallestFPValue =
|
||||
rewriter.create<arith::ConstantOp>(loc, smallestFPValueAttr);
|
||||
Value paddedInput =
|
||||
torch_to_linalg::getPaddedTensor(op, rewriter, self, paddingIncludingNC,
|
||||
paddingIncludingNC, smallestFPValue);
|
||||
Value maxPool2d, paddedInput;
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
||||
op, rewriter, self, /*supportFPInput=*/false, kernelSizeInts,
|
||||
strideInts, paddingInts, dilationInts, smallestFPValueAttr,
|
||||
outTensorShape, paddedInput, maxPool2d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
|
||||
|
||||
SmallVector<Value> resultShape(getTensorSizes(rewriter, loc, maxPool2d));
|
||||
Value cstMinusOne =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(-1));
|
||||
Value indicesTensor =
|
||||
createInitTensor(rewriter, loc, resultShape,
|
||||
createInitTensor(rewriter, loc, outTensorShape,
|
||||
indicesRankedTensorType.getElementType(), cstMinusOne);
|
||||
|
||||
SmallVector<Value> kernelSize =
|
||||
|
@ -440,6 +439,88 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenAvgPool2dOp : public OpConversionPattern<AtenAvgPool2dOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenAvgPool2dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op->getLoc();
|
||||
Value self = adaptor.self();
|
||||
|
||||
Type inputElementType =
|
||||
self.getType().cast<RankedTensorType>().getElementType();
|
||||
Type resultType = getTypeConverter()->convertType(op.getType());
|
||||
Type resultElementType =
|
||||
resultType.cast<RankedTensorType>().getElementType();
|
||||
|
||||
SmallVector<int64_t, 2> dilationInts{1, 1};
|
||||
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts;
|
||||
if (failed(checkAndGetPoolingParameters<AtenAvgPool2dOp>(
|
||||
op, rewriter, kernelSizeInts, strideInts, paddingInts)))
|
||||
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
||||
|
||||
// TODO: Add support for count_include_pad equal to `False`.
|
||||
bool countIncludePad;
|
||||
if (!matchPattern(op.count_include_pad(),
|
||||
m_TorchConstantBool(&countIncludePad)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "count_include_pad must be a constant");
|
||||
if (!countIncludePad) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: count_include_pad is expected to be true");
|
||||
}
|
||||
|
||||
// `sumPool2d` contains the result of sumpool2d operation over the input.
|
||||
Value sumPool2d, paddedInput;
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
|
||||
op, rewriter, self, /*supportFPInput=*/true, kernelSizeInts,
|
||||
strideInts, paddingInts, dilationInts,
|
||||
rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput,
|
||||
sumPool2d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d");
|
||||
|
||||
SmallVector<Value> kernelSizeIntValues =
|
||||
getAsConstantIntValues(rewriter, loc, kernelSizeInts);
|
||||
Value kHtimeskW = rewriter.create<arith::MulIOp>(
|
||||
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
|
||||
Value divisor = op.divisor_override().getType().isa<Torch::NoneType>()
|
||||
? kHtimeskW
|
||||
: adaptor.divisor_override();
|
||||
divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
|
||||
|
||||
Value outputTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, outTensorShape, resultElementType);
|
||||
SmallVector<AffineMap> indexingMapsAvg(2,
|
||||
rewriter.getMultiDimIdentityMap(4));
|
||||
SmallVector<StringRef> iteratorTypesAvg(4, "parallel");
|
||||
|
||||
Value avgPool2d =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outputTensor.getType(), sumPool2d, outputTensor,
|
||||
/*indexingMaps=*/indexingMapsAvg,
|
||||
/*iteratorTypes=*/iteratorTypesAvg,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value avg;
|
||||
if (resultElementType.isa<mlir::IntegerType>())
|
||||
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
|
||||
else if (resultElementType.isa<mlir::FloatType>())
|
||||
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
|
||||
b.create<linalg::YieldOp>(loc, avg);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool2d);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
|
@ -450,4 +531,6 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
|
|||
patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
|
||||
patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -490,14 +490,14 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
|
||||
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
||||
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
|
||||
AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
|
||||
AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
|
||||
Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op,
|
||||
AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp,
|
||||
AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp, AtenExpandOp,
|
||||
AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
|
||||
AtenPadOp, AtenZero_Op, AtenIndexTensorOp,
|
||||
ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
AtenMaxPool2dOp, AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp,
|
||||
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,
|
||||
AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
|
||||
Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp,
|
||||
AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp,
|
||||
AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp,
|
||||
AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp,
|
||||
AtenIndexPutHackedTwinOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
|
|
|
@ -1584,6 +1584,159 @@ module {
|
|||
func @"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {
|
||||
return %arg1 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.avg_pool2d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {
|
||||
%int4 = torch.constant.int 4
|
||||
%int3 = torch.constant.int 3
|
||||
%int0 = torch.constant.int 0
|
||||
%int2 = torch.constant.int 2
|
||||
%int1 = torch.constant.int 1
|
||||
%true = torch.constant.bool true
|
||||
%str = torch.constant.str "AssertionError: avg_pool2d: kernel_size must either be a single int, or a tuple of two ints"
|
||||
%none = torch.constant.none
|
||||
%str_0 = torch.constant.str "AssertionError: avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
|
||||
%str_1 = torch.constant.str "AssertionError: avg_pool2d: padding must be either be a single int, or a tuple of two ints"
|
||||
%str_2 = torch.constant.str "AssertionError: "
|
||||
%int-4 = torch.constant.int -4
|
||||
%int-3 = torch.constant.int -3
|
||||
%int-2 = torch.constant.int -2
|
||||
%int-1 = torch.constant.int -1
|
||||
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
||||
%1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||
%2 = torch.prim.If %1 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%39 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
||||
%40 = torch.aten.eq.int %39, %int2 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %40 : !torch.bool
|
||||
}
|
||||
torch.prim.If %2 -> () {
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%4 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
||||
%5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||
%6 = torch.prim.If %5 -> (!torch.int) {
|
||||
torch.prim.If.yield %3 : !torch.int
|
||||
} else {
|
||||
%39 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %39 : !torch.int
|
||||
}
|
||||
%7 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
|
||||
%8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%9 = torch.prim.If %8 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%39 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
|
||||
%40 = torch.aten.eq.int %39, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %40 : !torch.bool
|
||||
}
|
||||
%10 = torch.prim.If %9 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%39 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
|
||||
%40 = torch.aten.eq.int %39, %int2 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %40 : !torch.bool
|
||||
}
|
||||
torch.prim.If %10 -> () {
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%11 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
|
||||
%12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%13 = torch.prim.If %12 -> (!torch.int) {
|
||||
torch.prim.If.yield %3 : !torch.int
|
||||
} else {
|
||||
%39 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %39 : !torch.int
|
||||
}
|
||||
%14 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
|
||||
%15 = torch.aten.eq.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%16 = torch.prim.If %15 -> (!torch.int) {
|
||||
torch.prim.If.yield %6 : !torch.int
|
||||
} else {
|
||||
%39 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
|
||||
%40 = torch.aten.eq.int %39, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||
%41 = torch.prim.If %40 -> (!torch.int) {
|
||||
torch.prim.If.yield %13 : !torch.int
|
||||
} else {
|
||||
%42 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %42 : !torch.int
|
||||
}
|
||||
torch.prim.If.yield %41 : !torch.int
|
||||
}
|
||||
%17 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int
|
||||
%18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||
%19 = torch.prim.If %18 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%39 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int
|
||||
%40 = torch.aten.eq.int %39, %int2 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %40 : !torch.bool
|
||||
}
|
||||
torch.prim.If %19 -> () {
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%21 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int
|
||||
%22 = torch.aten.eq.int %21, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||
%23 = torch.prim.If %22 -> (!torch.int) {
|
||||
torch.prim.If.yield %20 : !torch.int
|
||||
} else {
|
||||
%39 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list<int>, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %39 : !torch.int
|
||||
}
|
||||
%24 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%25 = torch.aten.eq.int %24, %int3 : !torch.int, !torch.int -> !torch.bool
|
||||
%26 = torch.prim.If %25 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%39 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%40 = torch.aten.eq.int %39, %int4 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %40 : !torch.bool
|
||||
}
|
||||
torch.prim.If %26 -> () {
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%27 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%28 = torch.aten.eq.int %27, %int4 : !torch.int, !torch.int -> !torch.bool
|
||||
%29 = torch.prim.If %28 -> (!torch.int) {
|
||||
%39 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list<int>, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %39 : !torch.int
|
||||
} else {
|
||||
torch.prim.If.yield %int1 : !torch.int
|
||||
}
|
||||
%30 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%31 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%32 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%33 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pooling_output_shape(%31, %3, %20, %13, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int
|
||||
%34 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pooling_output_shape(%32, %6, %23, %16, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int
|
||||
%35 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pool2d_shape_check(%arg0, %3, %6, %13, %16, %20, %23, %int1, %int1, %30, %31, %32, %33, %34) : (!torch.list<int>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.none
|
||||
%36 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%37 = torch.aten.eq.int %36, %int3 : !torch.int, !torch.int -> !torch.bool
|
||||
%38 = torch.prim.If %37 -> (!torch.list<int>) {
|
||||
%39 = torch.prim.ListConstruct %30, %33, %34 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
torch.prim.If.yield %39 : !torch.list<int>
|
||||
} else {
|
||||
%39 = torch.prim.ListConstruct %29, %30, %33, %34 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
torch.prim.If.yield %39 : !torch.list<int>
|
||||
}
|
||||
return %38 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
|
|
|
@ -574,6 +574,9 @@ def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stri
|
|||
def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇avg_pool2d(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]:
|
||||
return upstream_shape_helpers.avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
|
||||
|
||||
def aten〇adaptive_avg_pool2d(self: List[int], output_size: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.adaptive_avg_pool2d(self, output_size)
|
||||
|
||||
|
|
|
@ -334,6 +334,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
|
||||
)
|
||||
|
|
|
@ -213,6 +213,45 @@ def max_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
|
|||
else:
|
||||
return [nbatch, nInputPlane, outputHeight, outputWidth]
|
||||
|
||||
def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]):
|
||||
assert len(kernel_size) == 1 or len(kernel_size) == 2, "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints"
|
||||
kH = kernel_size[0]
|
||||
kW = kH if len(kernel_size) == 1 else kernel_size[1]
|
||||
|
||||
assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
|
||||
dH = kH if len(stride) == 0 else stride[0]
|
||||
if len(stride) == 0:
|
||||
dW = kW
|
||||
elif len(stride) == 1:
|
||||
dW = dH
|
||||
else:
|
||||
dW = stride[1]
|
||||
|
||||
assert len(padding) == 1 or len(padding) == 2, "avg_pool2d: padding must be either be a single int, or a tuple of two ints"
|
||||
padH = padding[0]
|
||||
padW = padH if len(padding) == 1 else padding[1]
|
||||
|
||||
dilationH = 1
|
||||
dilationW = 1
|
||||
|
||||
assert len(input) == 3 or len(input) == 4
|
||||
|
||||
nbatch = input[-4] if len(input) == 4 else 1
|
||||
nInputPlane = input[-3]
|
||||
inputHeight = input[-2]
|
||||
inputWidth = input[-1]
|
||||
|
||||
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
|
||||
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
|
||||
|
||||
pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW, nInputPlane,
|
||||
inputHeight, inputWidth, outputHeight, outputWidth)
|
||||
|
||||
if len(input) == 3:
|
||||
return [nInputPlane, outputHeight, outputWidth]
|
||||
else:
|
||||
return [nbatch, nInputPlane, outputHeight, outputWidth]
|
||||
|
||||
def max_pool2d_with_indices(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool):
|
||||
out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
return (out, out)
|
||||
|
|
|
@ -465,3 +465,106 @@ class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module):
|
|||
def MaxPool2dWithIndicesBackwardDynamic3DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 7, 6), tu.rand(2, 6, 5),
|
||||
torch.randint(16, (2, 7, 6)))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AvgPool2dFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8],
|
||||
stride=[2, 2],
|
||||
padding=[3, 4],
|
||||
ceil_mode=False,
|
||||
count_include_pad=True,
|
||||
divisor_override=None)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap2d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool2dFloatModule())
|
||||
def AvgPool2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 20, 20) - 0.5)
|
||||
|
||||
|
||||
class AvgPool2dIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8],
|
||||
stride=[2, 2],
|
||||
padding=[3, 4],
|
||||
ceil_mode=False,
|
||||
count_include_pad=True,
|
||||
divisor_override=None)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap2d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool2dIntModule())
|
||||
def AvgPool2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (2, 4, 20, 20)))
|
||||
|
||||
|
||||
class AvgPool2dStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8],
|
||||
stride=[2, 2],
|
||||
padding=[3, 4],
|
||||
ceil_mode=False,
|
||||
count_include_pad=True,
|
||||
divisor_override=None)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 2, 10, 20], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap2d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool2dStaticModule())
|
||||
def AvgPool2dStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 2, 10, 20) - 0.5)
|
||||
|
||||
|
||||
class AvgPool2dDivisorOverrideModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ap2d = torch.nn.AvgPool2d(kernel_size=[4, 8],
|
||||
stride=[2, 3],
|
||||
padding=[2, 4],
|
||||
ceil_mode=False,
|
||||
count_include_pad=True,
|
||||
divisor_override=22)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, 4, 20, 20], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap2d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool2dDivisorOverrideModule())
|
||||
def AvgPool2dDivisorOverrideModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4, 20, 20) - 0.5)
|
||||
|
|
Loading…
Reference in New Issue