//===- LazyShapeInference.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include <array>
#include <cmath>
#include <cstdint>
#include <vector>

#include "generated/shape_inference.h"
#include "utils/exception.h"

namespace torch {
namespace lazy {

// TODO(henrytu): Upstream these shape inference functions to PyTorch in the
// future.

std::vector<torch::lazy::Shape> compute_shape_add(const at::Tensor &self,
                                                  const at::Scalar &other,
                                                  const at::Scalar &alpha) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_sub(const at::Tensor &self,
                                                  const at::Scalar &other,
                                                  const at::Scalar &alpha) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor &self,
                                                  const at::Scalar &other) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape__make_per_channel_quantized_tensor(const at::Tensor &self,
                                                 const at::Tensor &scale,
                                                 const at::Tensor &zero_point,
                                                 int64_t axis) {
  if (self.scalar_type() == at::kChar)
    return {Shape(at::kQInt8, self.sizes().vec())};
  if (self.scalar_type() == at::kByte)
    return {Shape(at::kQUInt8, self.sizes().vec())};
  if (self.scalar_type() == at::kInt)
    return {Shape(at::kQInt32, self.sizes().vec())};
  assert(false);
}

std::vector<torch::lazy::Shape>
compute_shape__make_per_tensor_quantized_tensor(const at::Tensor &self,
                                                double scale,
                                                int64_t zero_point) {
  if (self.scalar_type() == at::kChar)
    return {Shape(at::kQInt8, self.sizes().vec())};
  if (self.scalar_type() == at::kByte)
    return {Shape(at::kQUInt8, self.sizes().vec())};
  if (self.scalar_type() == at::kInt)
    return {Shape(at::kQInt32, self.sizes().vec())};
  assert(false);
}

std::vector<torch::lazy::Shape> compute_shape_int_repr(const at::Tensor &self) {
  if (self.scalar_type() == at::kQInt8)
    return {Shape(at::kChar, self.sizes().vec())};
  if (self.scalar_type() == at::kQUInt8)
    return {Shape(at::kByte, self.sizes().vec())};
  if (self.scalar_type() == at::kQInt32)
    return {Shape(at::kInt, self.sizes().vec())};
  assert(false);
}

std::vector<torch::lazy::Shape>
compute_shape_dequantize(const at::Tensor &self) {
  return {Shape(at::kFloat, self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_quantize_per_tensor(const at::Tensor &self, double scale,
                                  int64_t zero_point, at::ScalarType dtype) {
  return {Shape(dtype, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor &self) {
  return {Shape(at::kBool, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_quantize_per_channel(
    const at::Tensor &self, const at::Tensor &scales,
    const at::Tensor &zero_points, int64_t axis, at::ScalarType dtype) {
  return {Shape(dtype, self.sizes().vec())};
}

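// The output size along each spatial dim (D, H, W) below follows the standard
// pooling arithmetic:
//   out = (in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
// rounded up when `ceil_mode` is set and rounded down otherwise.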
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
    const at::Tensor &self, at::IntArrayRef kernel_size,
    at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
    bool ceil_mode) {
  auto in_sizes = self.sizes().vec();
  std::vector<int64_t> dhw(3, 0);
  std::vector<int64_t> paddings = padding.vec();
  std::vector<int64_t> ksizes = kernel_size.vec();
  std::vector<int64_t> dilations = dilation.vec();
  std::vector<int64_t> strides = stride.vec();
  TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ",
              in_sizes);
  TORCH_CHECK(kernel_size.size() == 3 && stride.size() == 3 &&
                  padding.size() == 3 && dilation.size() == 3,
              "max_pool3d requires 3D operands, but got ", kernel_size, stride,
              padding, dilation);
  int64_t batch = in_sizes[0];
  int64_t channel = in_sizes[1]; // NCDHW
  // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html
  for (auto i = 0UL; i < 3; ++i) {
    double out_size = (in_sizes[2 + i] + 2 * paddings[i] -
                       dilations[i] * (ksizes[i] - 1) - 1) /
                          (double)strides[i] +
                      1;
    if (ceil_mode)
      dhw[i] = (int64_t)std::ceil(out_size);
    else
      dhw[i] = (int64_t)std::floor(out_size);
  }
  auto out_sizes = {batch, channel, dhw[0], dhw[1], dhw[2]};
  // `with_indices` returns output and index Tensor
  return {Shape(self.scalar_type(), out_sizes), Shape(at::kLong, out_sizes)};
}

std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices_backward(
    const at::Tensor &grad_output, const at::Tensor &self,
    at::IntArrayRef kernel_size, at::IntArrayRef stride,
    at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode,
    const at::Tensor &indices) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_mse_loss_backward(const at::Tensor &grad_output,
                                const at::Tensor &self,
                                const at::Tensor &target, int64_t reduction) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_mul(const at::Tensor &self,
                                                  const at::Scalar &other) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_var(const at::Tensor &self, at::OptionalIntArrayRef dim,
                  const ::std::optional<at::Scalar> &correction, bool keepdim) {
  // Result of variance is scalar tensor.
  return {Shape(self.scalar_type(), {})};
}

std::vector<torch::lazy::Shape>
compute_shape_var(const at::Tensor &self, at::OptionalIntArrayRef dim,
                  ::std::optional<at::Scalar> &correction, bool keepdim) {
  // Result of variance is scalar tensor.
  return {Shape(self.scalar_type(), {})};
}

std::vector<torch::lazy::Shape>
compute_shape_nan_to_num(const at::Tensor &self, ::std::optional<double> nan,
                         ::std::optional<double> posinf,
                         ::std::optional<double> neginf) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_hardtanh(const at::Tensor &self, const at::Scalar &min_val,
                       const at::Scalar &max_val) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_hardtanh_backward(
    const at::Tensor &grad_output, const at::Tensor &self,
    const at::Scalar &min_val, const at::Scalar &max_val) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

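// For ops whose result shape depends on broadcasting, the operands are
// materialized as empty tensors on the meta device (sizes, strides, and dtype
// only, no data), the op is dispatched on them, and the resulting meta tensor
// supplies the output dtype and sizes.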
std::vector<torch::lazy::Shape>
compute_shape_where(const at::Tensor &condition, const at::Tensor &self,
                    const at::Tensor &other) {
  // There are cases like -
  // torch.aten.where.self %42, %arg17, %37 : !torch.vtensor<[15,10],i1>,
  // !torch.vtensor<[],f32>, !torch.vtensor<[15,10],f32>.
  // So the result tensor would be the biggest of all the three operands.
  auto condition_meta = at::native::empty_strided_meta_symint(
      condition.sym_sizes(), condition.sym_strides(),
      /*dtype=*/c10::make_optional(condition.scalar_type()),
      /*layout=*/c10::make_optional(condition.layout()),
      /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/c10::nullopt);
  auto self_meta = at::native::empty_strided_meta_symint(
      self.sym_sizes(), self.sym_strides(),
      /*dtype=*/c10::make_optional(self.scalar_type()),
      /*layout=*/c10::make_optional(self.layout()),
      /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/c10::nullopt);
  auto other_meta = at::native::empty_strided_meta_symint(
      other.sym_sizes(), other.sym_strides(),
      /*dtype=*/c10::make_optional(other.scalar_type()),
      /*layout=*/c10::make_optional(other.layout()),
      /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/c10::nullopt);
  auto out_meta = at::where(condition_meta, self_meta, other_meta);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_bucketize(const at::Tensor &self, const at::Tensor &boundaries,
                        bool out_int32, bool right) {
  auto dtype = out_int32 ? at::kInt : at::kLong;
  return {Shape(dtype, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_copy(const at::Tensor &self,
                                                   const at::Tensor &src,
                                                   bool non_blocking) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_floor_divide(const at::Tensor &self, const at::Tensor &other) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_fmod(const at::Tensor &self,
                                                   const at::Scalar &other) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

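// native_group_norm produces three tensors: the normalized output (same shape
// as the input) plus the per-group mean and variance statistics, which are
// kept per batch element and per group and therefore have shape {N, group}.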
std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
    const at::Tensor &input, const ::std::optional<at::Tensor> &weight,
    const ::std::optional<at::Tensor> &bias, int64_t N, int64_t C, int64_t HxW,
    int64_t group, double eps) {
  TORCH_CHECK(input.sizes().size() >= 2,
              "Input tensor must have at least batch and channel dimensions!");
  std::vector<torch::lazy::Shape> shapes;
  shapes.reserve(3);
  shapes.emplace_back(input.scalar_type(), input.sizes().vec());

  // A separate mean and var needs to be kept for each group per N.
  shapes.emplace_back(at::get_default_dtype_as_scalartype(),
                      std::vector<int64_t>{N, group});
  shapes.emplace_back(at::get_default_dtype_as_scalartype(),
                      std::vector<int64_t>{N, group});
  return shapes;
}

std::vector<torch::lazy::Shape>
compute_shape_im2col(const at::Tensor &self, at::IntArrayRef kernel_size,
                     at::IntArrayRef dilation, at::IntArrayRef padding,
                     at::IntArrayRef stride) {
  auto self_meta = at::native::empty_strided_meta_symint(
      self.sym_sizes(), self.sym_strides(),
      /*dtype=*/c10::make_optional(self.scalar_type()),
      /*layout=*/c10::make_optional(self.layout()),
      /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/c10::nullopt);
  auto out_meta = at::im2col(self_meta, kernel_size, dilation, padding, stride);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
    const at::Tensor &grad_out, const at::Tensor &input, const at::Tensor &mean,
    const at::Tensor &rstd, const ::std::optional<at::Tensor> &weight,
    int64_t N, int64_t C, int64_t HxW, int64_t group,
    ::std::array<bool, 3> output_mask) {
  TORCH_CHECK(input.sizes().size() >= 2,
              "Input tensor must have at least batch and channel dimensions!");
  std::vector<torch::lazy::Shape> shapes;
  shapes.reserve(3);
  shapes.emplace_back(input.scalar_type(), input.sizes().vec());

  int64_t num_features = input.size(1);
  // `weight` and `bias` are vectors of length C (number of channels).
  shapes.emplace_back(at::get_default_dtype_as_scalartype(),
                      std::vector<int64_t>{num_features});
  shapes.emplace_back(at::get_default_dtype_as_scalartype(),
                      std::vector<int64_t>{num_features});
  return shapes;
}

std::vector<torch::lazy::Shape>
compute_shape_remainder(const at::Tensor &self, const at::Scalar &other) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_reflection_pad2d(const at::Tensor &self,
                               at::IntArrayRef padding) {
  std::vector<int64_t> paddings = padding.vec();
  std::vector<int64_t> in_sizes = self.sizes().vec();
  auto num_dims = in_sizes.size();
  TORCH_CHECK(padding.size() == 4);
  TORCH_CHECK(num_dims >= 2);

  auto vdim = num_dims - 2;
  auto hdim = num_dims - 1;
  auto padding_left = padding[0];
  auto padding_right = padding[1];
  auto padding_top = padding[2];
  auto padding_bottom = padding[3];
  TORCH_CHECK(padding_left < in_sizes[hdim]);
  TORCH_CHECK(padding_right < in_sizes[hdim]);
  TORCH_CHECK(padding_top < in_sizes[vdim]);
  TORCH_CHECK(padding_bottom < in_sizes[vdim]);

  std::vector<int64_t> out_sizes(in_sizes);
  out_sizes[hdim] += padding_left + padding_right;
  out_sizes[vdim] += padding_top + padding_bottom;
  return {Shape(self.scalar_type(), out_sizes)};
}

std::vector<torch::lazy::Shape>
compute_shape_uniform(const at::Tensor &self, double from, double to,
                      ::std::optional<at::Generator> generator) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_normal_functional(const at::Tensor &self, double mean,
                                double std,
                                ::std::optional<at::Generator> generator) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_multinomial(const at::Tensor &self, int64_t num_samples,
                          bool replacement,
                          ::std::optional<at::Generator> generator) {
  // Input tensor can be either 1D or 2D. The last dim of output
  // should be 'num_samples'. So the output shape can be either
  // [num_samples] or [m, num_samples].
  // Output type can only be long tensor.
  auto ishape = self.sizes().vec();
  ishape.back() = num_samples;
  return {Shape(at::kLong, ishape)};
}

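// The factory ops below (eye, arange, linspace) compute their result shape by
// calling the regular ATen kernel on the meta device, so ATen's own dtype
// defaulting and size computation are reused without allocating any data.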
std::vector<torch::lazy::Shape>
compute_shape_eye(int64_t n, ::std::optional<at::ScalarType> dtype,
                  ::std::optional<at::Layout> layout,
                  ::std::optional<at::Device> device,
                  ::std::optional<bool> pin_memory) {
  auto out_meta =
      at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_eye(int64_t n, int64_t m, ::std::optional<at::ScalarType> dtype,
                  ::std::optional<at::Layout> layout,
                  ::std::optional<at::Device> device,
                  ::std::optional<bool> pin_memory) {
  auto out_meta =
      at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_arange(const at::Scalar &end,
                     ::std::optional<at::ScalarType> dtype,
                     ::std::optional<at::Layout> layout,
                     ::std::optional<at::Device> device,
                     ::std::optional<bool> pin_memory) {
  auto out_meta =
      at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_arange(const at::Scalar &start, const at::Scalar &end,
                     ::std::optional<at::ScalarType> dtype,
                     ::std::optional<at::Layout> layout,
                     ::std::optional<at::Device> device,
                     ::std::optional<bool> pin_memory) {
  auto out_meta = at::arange(start, end, dtype, layout,
                             c10::Device(c10::kMeta), pin_memory);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_arange(const at::Scalar &start, const at::Scalar &end,
                     const at::Scalar &step,
                     ::std::optional<at::ScalarType> dtype,
                     ::std::optional<at::Layout> layout,
                     ::std::optional<at::Device> device,
                     ::std::optional<bool> pin_memory) {
  auto out_meta = at::arange(start, end, step, dtype, layout,
                             c10::Device(c10::kMeta), pin_memory);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_full(at::IntArrayRef size, const at::Scalar &fill_value,
                   ::std::optional<at::ScalarType> dtype,
                   ::std::optional<at::Layout> layout,
                   ::std::optional<at::Device> device,
                   ::std::optional<bool> pin_memory) {
  return {Shape(dtype.value_or(at::get_default_dtype_as_scalartype()),
                size.vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_ones(at::IntArrayRef size, ::std::optional<at::ScalarType> dtype,
                   ::std::optional<at::Layout> layout,
                   ::std::optional<at::Device> device,
                   ::std::optional<bool> pin_memory) {
  return {Shape(dtype.value_or(at::get_default_dtype_as_scalartype()),
                size.vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_zeros(at::IntArrayRef size, ::std::optional<at::ScalarType> dtype,
                    ::std::optional<at::Layout> layout,
                    ::std::optional<at::Device> device,
                    ::std::optional<bool> pin_memory) {
  return {Shape(dtype.value_or(at::get_default_dtype_as_scalartype()),
                size.vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_empty(at::IntArrayRef size, ::std::optional<at::ScalarType> dtype,
                    ::std::optional<at::Layout> layout,
                    ::std::optional<at::Device> device,
                    ::std::optional<bool> pin_memory,
                    ::std::optional<at::MemoryFormat> memory_format) {
  return {Shape(dtype.value_or(at::get_default_dtype_as_scalartype()),
                size.vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_empty_strided(at::IntArrayRef size, at::IntArrayRef stride,
                            ::std::optional<at::ScalarType> dtype,
                            ::std::optional<at::Layout> layout,
                            ::std::optional<at::Device> device,
                            ::std::optional<bool> pin_memory) {
  return {Shape(dtype.value_or(at::get_default_dtype_as_scalartype()),
                size.vec())};
}

std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor &self,
                                                   const at::Scalar &value) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor &self,
                                                   const at::Tensor &value) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_randn(at::IntArrayRef size, ::std::optional<at::ScalarType> dtype,
                    ::std::optional<at::Layout> layout,
                    ::std::optional<at::Device> device,
                    ::std::optional<bool> pin_memory) {
  return {Shape(dtype.value_or(at::get_default_dtype_as_scalartype()),
                size.vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_randint(int64_t high, at::IntArrayRef size,
                      ::std::optional<at::ScalarType> dtype,
                      ::std::optional<at::Layout> layout,
                      ::std::optional<at::Device> device,
                      ::std::optional<bool> pin_memory) {
  return {Shape(dtype.value_or(at::get_default_dtype_as_scalartype()),
                size.vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_randint(int64_t low, int64_t high, at::IntArrayRef size,
                      ::std::optional<at::ScalarType> dtype,
                      ::std::optional<at::Layout> layout,
                      ::std::optional<at::Device> device,
                      ::std::optional<bool> pin_memory) {
  return {Shape(dtype.value_or(at::get_default_dtype_as_scalartype()),
                size.vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_resize(const at::Tensor &self, at::IntArrayRef size,
                     ::std::optional<at::MemoryFormat> memory_format) {
  return {Shape(self.scalar_type(), size.vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_bernoulli(const at::Tensor &self, const at::Tensor &p,
                        ::std::optional<at::Generator> generator) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

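// scalar_tensor wraps a Scalar into a 0-dimensional tensor, so the resulting
// shape has no dimensions; unless a dtype is given it uses the scalar's own
// type.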
std::vector<torch::lazy::Shape>
compute_shape_scalar_tensor(const at::Scalar &s,
                            ::std::optional<at::ScalarType> dtype,
                            ::std::optional<at::Layout> layout,
                            ::std::optional<at::Device> device,
                            ::std::optional<bool> pin_memory) {
  return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
}

std::vector<torch::lazy::Shape> compute_shape_roll(const at::Tensor &self,
                                                   at::IntArrayRef shifts,
                                                   at::IntArrayRef dims) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_linspace(
    const at::Scalar &start, const at::Scalar &end, int64_t steps,
    ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout,
    ::std::optional<at::Device> device, ::std::optional<bool> pin_memory) {
  auto out_meta = at::linspace(start, end, steps, dtype, layout,
                               c10::Device(c10::kMeta), pin_memory);
  return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

} // namespace lazy
} // namespace torch