From 4fb58002abcdddd3c360eeb88da9d2a21b9c03ea Mon Sep 17 00:00:00 2001 From: Frederik Harwath Date: Tue, 9 Jan 2024 07:16:49 -0800 Subject: [PATCH] Add LTC shape inference for aten.reflection_pad2d --- .../base_lazy_backend/shape_inference.cpp | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index 244ee7b88..3971fdd32 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -227,6 +227,34 @@ std::vector compute_shape_remainder( return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector +compute_shape_reflection_pad2d(const at::Tensor &self, + at::IntArrayRef padding) { + std::vector paddings = padding.vec(); + std::vector in_sizes = self.sizes().vec(); + auto num_dims = in_sizes.size(); + + TORCH_CHECK(padding.size() == 4); + TORCH_CHECK(num_dims >= 2); + + auto vdim = num_dims - 2; + auto hdim = num_dims - 1; + auto padding_left = padding[0]; + auto padding_right = padding[1]; + auto padding_top = padding[2]; + auto padding_bottom = padding[3]; + TORCH_CHECK(padding_left < in_sizes[hdim]); + TORCH_CHECK(padding_right < in_sizes[hdim]); + TORCH_CHECK(padding_top < in_sizes[vdim]); + TORCH_CHECK(padding_bottom < in_sizes[vdim]); + + std::vector out_sizes(in_sizes); + out_sizes[hdim] += padding_left + padding_right; + out_sizes[vdim] += padding_top + padding_bottom; + + return {Shape(self.scalar_type(), out_sizes)}; +} + std::vector compute_shape_uniform( const at::Tensor& self, double from, double to, c10::optional generator) {