mirror of https://github.com/llvm/torch-mlir
add max_pool3d (#2386)
parent bc6bba9077
commit ca34b9c4fc
@@ -5454,6 +5454,93 @@ def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_in
  }];
}

def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$kernel_size,
    AnyTorchListOfTorchIntType:$stride,
    AnyTorchListOfTorchIntType:$padding,
    AnyTorchListOfTorchIntType:$dilation,
    Torch_BoolType:$ceil_mode
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenMaxPool3dOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 6, 1);
    }
    void AtenMaxPool3dOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 6, 1);
    }
  }];
}
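For orientation, the six arguments above correspond one-to-one to the eager PyTorch call this op models. A minimal sketch, assuming a stock `torch` install (illustrative only, not part of the diff):

```python
import torch
import torch.nn.functional as F

# NCDHW input, the layout the shape function later in this commit assumes.
x = torch.randn(1, 2, 8, 8, 8)
y = F.max_pool3d(x, kernel_size=[2, 2, 2], stride=[2, 2, 2],
                 padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=False)
print(y.shape)  # torch.Size([1, 2, 4, 4, 4])
```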

def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$kernel_size,
    AnyTorchListOfTorchIntType:$stride,
    AnyTorchListOfTorchIntType:$padding,
    AnyTorchListOfTorchIntType:$dilation,
    Torch_BoolType:$ceil_mode
  );
  let results = (outs
    AnyTorchTensorType:$result0,
    AnyTorchTensorType:$result1
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenMaxPool3dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 6, 2);
    }
    void AtenMaxPool3dWithIndicesOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 6, 2);
    }
  }];
}
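The two results are the pooled values and the flat indices of the maxima. A quick way to see the pairing is to call the underlying ATen op directly, again as an illustrative sketch against stock PyTorch:

```python
import torch

x = torch.randn(1, 1, 4, 4, 4)
values, indices = torch.ops.aten.max_pool3d_with_indices(
    x, [2, 2, 2], [2, 2, 2], [0, 0, 0], [1, 1, 1], False)
# Both results share the pooled shape; indices are int64, as the
# shape function later in this commit also assumes.
print(values.shape, indices.shape, indices.dtype)
# torch.Size([1, 1, 2, 2, 2]) torch.Size([1, 1, 2, 2, 2]) torch.int64
```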

def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$grad_output,
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$kernel_size,
    AnyTorchListOfTorchIntType:$stride,
    AnyTorchListOfTorchIntType:$padding,
    AnyTorchListOfTorchIntType:$dilation,
    Torch_BoolType:$ceil_mode,
    AnyTorchTensorType:$indices
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenMaxPool3dWithIndicesBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 8, 1);
    }
    void AtenMaxPool3dWithIndicesBackwardOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 8, 1);
    }
  }];
}
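The backward op takes the eight arguments listed above, ending with the indices produced by the forward op. A hedged sketch of the eager call it models, which also shows why the shape function later in this commit simply returns `self`'s shape:

```python
import torch

x = torch.randn(1, 1, 4, 4, 4)
out, idx = torch.ops.aten.max_pool3d_with_indices(
    x, [2, 2, 2], [2, 2, 2], [0, 0, 0], [1, 1, 1], False)
grad_input = torch.ops.aten.max_pool3d_with_indices_backward(
    torch.ones_like(out), x, [2, 2, 2], [2, 2, 2], [0, 0, 0],
    [1, 1, 1], False, idx)
assert grad_input.shape == x.shape  # the gradient has the input's shape
```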

def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
    AllowsTypeRefinement,
    HasValueSemantics,
@@ -27,6 +27,7 @@ std::vector<torch::lazy::Shape> compute_shape_add(const at::Tensor& self,
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_sub(const at::Tensor& self,
                                                  const at::Scalar& other,
                                                  const at::Scalar& alpha) {
@@ -38,6 +39,47 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
    const at::Tensor& self, at::IntArrayRef kernel_size,
    at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
    bool ceil_mode) {
  auto in_sizes = self.sizes().vec();
  std::vector<int64_t> dhw(3, 0);
  std::vector<int64_t> paddings = padding.vec();
  std::vector<int64_t> ksizes = kernel_size.vec();
  std::vector<int64_t> dilations = dilation.vec();
  std::vector<int64_t> strides = stride.vec();
  TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ",
              in_sizes);
  TORCH_CHECK(kernel_size.size() == 3 && stride.size() == 3 &&
                  padding.size() == 3 && dilation.size() == 3,
              "max_pool3d requires 3-element kernel_size, stride, padding, "
              "and dilation, but got ",
              kernel_size, stride, padding, dilation);
  int64_t batch = in_sizes[0];
  int64_t channel = in_sizes[1]; // NCDHW
  // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html
  for (auto i = 0UL; i < 3; ++i) {
    double out_size = (in_sizes[2 + i] + 2 * paddings[i] - dilations[i] *
                       (ksizes[i] - 1) - 1) / (double)strides[i] + 1;
    if (ceil_mode)
      dhw[i] = (int64_t)std::ceil(out_size);
    else
      dhw[i] = (int64_t)std::floor(out_size);
  }
  auto out_sizes = {batch, channel, dhw[0], dhw[1], dhw[2]};
  // `with_indices` returns both the pooled output and an index tensor.
  return {Shape(self.scalar_type(), out_sizes), Shape(at::kLong, out_sizes)};
}
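The loop above mirrors the output-size formula from the linked MaxPool3d documentation. As a cross-check, here is the same expression in Python compared against eager PyTorch (illustrative only; `pool3d_out_size` is a hypothetical helper, not part of this patch):

```python
import math
import torch
import torch.nn.functional as F

def pool3d_out_size(in_size, k, s, p, d, ceil_mode):
    # Same expression as the C++ loop above.
    out = (in_size + 2 * p - d * (k - 1) - 1) / s + 1
    return math.ceil(out) if ceil_mode else math.floor(out)

x = torch.randn(2, 3, 10, 11, 12)
out, idx = F.max_pool3d(x, kernel_size=3, stride=2, padding=1,
                        dilation=1, ceil_mode=True, return_indices=True)
expected = [pool3d_out_size(n, 3, 2, 1, 1, True) for n in (10, 11, 12)]
assert list(out.shape) == [2, 3] + expected  # [2, 3, 6, 6, 7]
assert idx.shape == out.shape and idx.dtype == torch.int64
```

One caveat: with ceil_mode, eager PyTorch additionally shrinks the output by one when the last window would start entirely inside the padding, a corner case the shape function above does not model; the sizes chosen here avoid it.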

std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices_backward(
    const at::Tensor& grad_output, const at::Tensor& self,
    at::IntArrayRef kernel_size, at::IntArrayRef stride,
    at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode,
    const at::Tensor& indices) {
  // The gradient w.r.t. the pooling input always has the input's own shape.
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_mse_loss_backward(
    const at::Tensor& grad_output, const at::Tensor& self,
    const at::Tensor& target, int64_t reduction) {
@@ -416,6 +416,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit(
        "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
    )
    emit(
        "aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
    )
    emit(
        "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
    )
    emit(
        "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
    )
    emit(
        "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
    )