mirror of https://github.com/llvm/torch-mlir
Add LTC shape inference for aten.reflection_pad2d
parent
345dfd5903
commit
4fb58002ab
|
@ -227,6 +227,34 @@ std::vector<torch::lazy::Shape> compute_shape_remainder(
|
|||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape>
|
||||
compute_shape_reflection_pad2d(const at::Tensor &self,
|
||||
at::IntArrayRef padding) {
|
||||
std::vector<int64_t> paddings = padding.vec();
|
||||
std::vector<int64_t> in_sizes = self.sizes().vec();
|
||||
auto num_dims = in_sizes.size();
|
||||
|
||||
TORCH_CHECK(padding.size() == 4);
|
||||
TORCH_CHECK(num_dims >= 2);
|
||||
|
||||
auto vdim = num_dims - 2;
|
||||
auto hdim = num_dims - 1;
|
||||
auto padding_left = padding[0];
|
||||
auto padding_right = padding[1];
|
||||
auto padding_top = padding[2];
|
||||
auto padding_bottom = padding[3];
|
||||
TORCH_CHECK(padding_left < in_sizes[hdim]);
|
||||
TORCH_CHECK(padding_right < in_sizes[hdim]);
|
||||
TORCH_CHECK(padding_top < in_sizes[vdim]);
|
||||
TORCH_CHECK(padding_bottom < in_sizes[vdim]);
|
||||
|
||||
std::vector<int64_t> out_sizes(in_sizes);
|
||||
out_sizes[hdim] += padding_left + padding_right;
|
||||
out_sizes[vdim] += padding_top + padding_bottom;
|
||||
|
||||
return {Shape(self.scalar_type(), out_sizes)};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_uniform(
|
||||
const at::Tensor& self, double from, double to,
|
||||
c10::optional<at::Generator> generator) {
|
||||
|
|
Loading…
Reference in New Issue