[Torch Dialect] add avg_pool 2d and 3d op variants (#2473)

Adds ODS for `avg_pool2d` and `avg_pool3d`, including their backward and
`adaptive_` variants.
pull/2480/head
David Gens 2023-09-20 10:47:08 -07:00 committed by GitHub
parent 20ea1c9e91
commit 023fc90072
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 236 additions and 17 deletions

View File

@ -1317,10 +1317,6 @@ LTC_XFAIL_SET = {
"_ConvolutionDeprecated2DBenchmarkModule_basic", "_ConvolutionDeprecated2DBenchmarkModule_basic",
"_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddIntModule_basic", "AddIntModule_basic",
"AtenIntBoolOpModule_basic", "AtenIntBoolOpModule_basic",
"BernoulliTensorModule_basic", "BernoulliTensorModule_basic",

View File

@ -5563,6 +5563,34 @@ def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_in
}]; }];
} }
def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
Torch_BoolType:$ceil_mode,
Torch_BoolType:$count_include_pad
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenAvgPool1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -5592,30 +5620,91 @@ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
}]; }];
} }
def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [ def Torch_AtenAvgPool2dBackwardOp : Torch_Op<"aten.avg_pool2d_backward", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
]> { ]> {
let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`"; let summary = "Generated op for `aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self, AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size, AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride, AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding, AnyTorchListOfTorchIntType:$padding,
Torch_BoolType:$ceil_mode, Torch_BoolType:$ceil_mode,
Torch_BoolType:$count_include_pad Torch_BoolType:$count_include_pad,
AnyTorchOptionalIntType:$divisor_override
); );
let results = (outs let results = (outs
AnyTorchTensorType:$result AnyTorchTensorType:$result
); );
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{ let extraClassDefinition = [{
ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { ParseResult AtenAvgPool2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1); return parseDefaultTorchOp(parser, result, 8, 1);
} }
void AtenAvgPool1dOp::print(OpAsmPrinter &printer) { void AtenAvgPool2dBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1); printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}
def Torch_AtenAvgPool3dOp : Torch_Op<"aten.avg_pool3d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
Torch_BoolType:$ceil_mode,
Torch_BoolType:$count_include_pad,
AnyTorchOptionalIntType:$divisor_override
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenAvgPool3dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
def Torch_AtenAvgPool3dBackwardOp : Torch_Op<"aten.avg_pool3d_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
Torch_BoolType:$ceil_mode,
Torch_BoolType:$count_include_pad,
AnyTorchOptionalIntType:$divisor_override
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAvgPool3dBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenAvgPool3dBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
} }
}]; }];
} }
@ -5846,6 +5935,30 @@ def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [
}]; }];
} }
def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -5870,12 +5983,12 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
}]; }];
} }
def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [ def Torch_Aten_AdaptiveAvgPool2dOp : Torch_Op<"aten._adaptive_avg_pool2d", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
]> { ]> {
let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`"; let summary = "Generated op for `aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$self, AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size AnyTorchListOfTorchIntType:$output_size
@ -5885,10 +5998,106 @@ def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
); );
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{ let extraClassDefinition = [{
ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) { ParseResult Aten_AdaptiveAvgPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1); return parseDefaultTorchOp(parser, result, 2, 1);
} }
void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) { void Aten_AdaptiveAvgPool2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_Aten_AdaptiveAvgPool2dBackwardOp : Torch_Op<"aten._adaptive_avg_pool2d_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_AdaptiveAvgPool2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten_AdaptiveAvgPool2dBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenAdaptiveAvgPool3dOp : Torch_Op<"aten.adaptive_avg_pool3d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAdaptiveAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenAdaptiveAvgPool3dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_Aten_AdaptiveAvgPool3dOp : Torch_Op<"aten._adaptive_avg_pool3d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_AdaptiveAvgPool3dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten_AdaptiveAvgPool3dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_AdaptiveAvgPool3dBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten_AdaptiveAvgPool3dBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1); printDefaultTorchOp(printer, *this, 2, 1);
} }
}]; }];

View File

@ -426,11 +426,20 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit( emit(
"aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
) )
emit(
"aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)"
)
emit( emit(
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
) )
emit( emit(
"aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)" "aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
)
emit(
"aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
)
emit(
"aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
) )
emit( emit(
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)" "aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
@ -444,8 +453,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)") emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")
emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")
emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
emit("aten::permute : (Tensor, int[]) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)")