add max_pool3d (#2386)

pull/2275/merge
David Gens 2023-08-28 16:01:55 -07:00 committed by GitHub
parent bc6bba9077
commit ca34b9c4fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 138 additions and 0 deletions

View File

@ -5454,6 +5454,93 @@ def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_in
}];
}
def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool3dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenMaxPool3dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool3dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 2);
}
void AtenMaxPool3dWithIndicesOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
}
def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode,
AnyTorchTensorType:$indices
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool3dWithIndicesBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenMaxPool3dWithIndicesBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}
def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -27,6 +27,7 @@ std::vector<torch::lazy::Shape> compute_shape_add(const at::Tensor& self,
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_sub(const at::Tensor& self,
const at::Scalar& other,
const at::Scalar& alpha) {
@ -38,6 +39,47 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
const at::Tensor& self, at::IntArrayRef kernel_size,
at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
bool ceil_mode) {
auto in_sizes = self.sizes().vec();
std::vector<int64_t> dhw(3, 0);
std::vector<int64_t> paddings = padding.vec();
std::vector<int64_t> ksizes = kernel_size.vec();
std::vector<int64_t> dilations = dilation.vec();
std::vector<int64_t> strides = stride.vec();
TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ",
in_sizes);
TORCH_CHECK(kernel_size.size() == 3 &&
stride.size() == 3 &&
padding.size() == 3 &&
dilation.size() == 3, "max_pool3d requires 3D operands, but got ",
kernel_size, stride, padding, dilation);
int64_t batch = in_sizes[0];
int64_t channel = in_sizes[1]; // NCDHW
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html
for (auto i = 0UL; i<3; ++i) {
double out_size = (in_sizes[2+i] + 2 * paddings[i] - dilations[i] *
(ksizes[i] - 1) - 1) / (double)strides[i] + 1;
if (ceil_mode)
dhw[i] = (int64_t)std::ceil(out_size);
else
dhw[i] = (int64_t)std::floor(out_size);
}
auto out_sizes = {batch, channel, dhw[0], dhw[1], dhw[2]};
// `with_indices` returns output and index Tensor
return {Shape(self.scalar_type(), out_sizes), Shape(at::kLong, out_sizes)};
}
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices_backward(
const at::Tensor & grad_output, const at::Tensor & self,
at::IntArrayRef kernel_size, at::IntArrayRef stride,
at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode,
const at::Tensor & indices) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_mse_loss_backward(
const at::Tensor& grad_output, const at::Tensor& self,
const at::Tensor& target, int64_t reduction) {

View File

@ -416,6 +416,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
)
emit(
"aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
)
emit(
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
)
emit(
"aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
)
emit(
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
)