Add LTC shape inference for aten.reflection_pad2d

pre_fixup_20240110
Frederik Harwath 2024-01-09 07:16:49 -08:00 committed by Frederik Harwath
parent 345dfd5903
commit 4fb58002ab
1 changed files with 28 additions and 0 deletions

View File

@ -227,6 +227,34 @@ std::vector<torch::lazy::Shape> compute_shape_remainder(
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape>
compute_shape_reflection_pad2d(const at::Tensor &self,
at::IntArrayRef padding) {
std::vector<int64_t> paddings = padding.vec();
std::vector<int64_t> in_sizes = self.sizes().vec();
auto num_dims = in_sizes.size();
TORCH_CHECK(padding.size() == 4);
TORCH_CHECK(num_dims >= 2);
auto vdim = num_dims - 2;
auto hdim = num_dims - 1;
auto padding_left = padding[0];
auto padding_right = padding[1];
auto padding_top = padding[2];
auto padding_bottom = padding[3];
TORCH_CHECK(padding_left < in_sizes[hdim]);
TORCH_CHECK(padding_right < in_sizes[hdim]);
TORCH_CHECK(padding_top < in_sizes[vdim]);
TORCH_CHECK(padding_bottom < in_sizes[vdim]);
std::vector<int64_t> out_sizes(in_sizes);
out_sizes[hdim] += padding_left + padding_right;
out_sizes[vdim] += padding_top + padding_bottom;
return {Shape(self.scalar_type(), out_sizes)};
}
std::vector<torch::lazy::Shape> compute_shape_uniform(
const at::Tensor& self, double from, double to,
c10::optional<at::Generator> generator) {