From 1510eae75d51cdf2a770875ce03a0cd1a217fd07 Mon Sep 17 00:00:00 2001
From: Henry Tu
Date: Fri, 24 Jun 2022 19:30:45 -0400
Subject: [PATCH] Upstream native_batch_norm and native_batch_norm_backward
 shape inference functions (#978)

* Removed compute_shape_native_batch_norm

* Removed compute_shape_native_batch_norm_backward
---
 .../base_lazy_backend/shape_inference.cpp | 65 -------------------
 1 file changed, 65 deletions(-)

diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
index 4ef9029aa..96cd7b3a9 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
@@ -29,71 +29,6 @@ compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
-std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
-    const at::Tensor& input, const c10::optional<at::Tensor>& weight,
-    const c10::optional<at::Tensor>& bias,
-    const c10::optional<at::Tensor>& running_mean,
-    const c10::optional<at::Tensor>& running_var, bool training,
-    double momentum, double eps) {
-  std::vector<torch::lazy::Shape> shapes;
-  shapes.reserve(3);
-  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
-
-  // A separate mean and var needs to be kept for each channel.
-  TORCH_CHECK(
-      input.sizes().size() >= 2,
-      "Input tensor must have at least batch and channel dimensions!");
-  int64_t num_features = input.sizes().vec()[1];
-
-  if (running_mean.has_value()) {
-    shapes.emplace_back(
-        running_mean.value().scalar_type(), running_mean.value().sizes().vec());
-  } else {
-    shapes.emplace_back(
-        at::get_default_dtype_as_scalartype(),
-        std::vector<int64_t>{num_features});
-  }
-
-  if (running_var.has_value()) {
-    shapes.emplace_back(
-        running_var.value().scalar_type(), running_var.value().sizes().vec());
-  } else {
-    shapes.emplace_back(
-        at::get_default_dtype_as_scalartype(),
-        std::vector<int64_t>{num_features});
-  }
-  return shapes;
-}
-
-std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
-    const at::Tensor& grad_out, const at::Tensor& input,
-    const c10::optional<at::Tensor>& weight,
-    const c10::optional<at::Tensor>& running_mean,
-    const c10::optional<at::Tensor>& running_var,
-    const c10::optional<at::Tensor>& save_mean,
-    const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
-    ::std::array<bool, 3> output_mask) {
-  std::vector<torch::lazy::Shape> shapes;
-  shapes.reserve(3);
-  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
-
-  // A separate mean and var needs to be kept for each channel.
-  TORCH_CHECK(
-      input.sizes().size() >= 2,
-      "Input tensor must have at least batch and channel dimensions!");
-  int64_t num_features = input.sizes().vec()[1];
-
-  // `weight` and `bias` are vectors of length C (number of channels)
-  shapes.emplace_back(
-      at::get_default_dtype_as_scalartype(),
-      std::vector<int64_t>{num_features});
-  shapes.emplace_back(
-      at::get_default_dtype_as_scalartype(),
-      std::vector<int64_t>{num_features});
-
-  return shapes;
-}
-
 std::vector<torch::lazy::Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
   if (dtype.has_value()) {
     return {Shape(*dtype, size)};
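
The deleted helpers encode a simple shape rule: native_batch_norm produces three
tensors, an output shaped exactly like the input plus a per-channel mean and
variance of length C = input.sizes()[1], reusing the running stats' shape and
dtype when they are supplied and falling back to the default dtype otherwise.
Per the subject line these functions were upstreamed into PyTorch's lazy tensor
core (presumably torch/csrc/lazy/core/shape_inference.cpp, where the LTC shape
functions live), so torch-mlir now inherits them from there. Below is a minimal,
dependency-free C++ sketch restating that rule; SimpleShape and
batch_norm_output_shapes are illustrative names invented for this sketch, not
torch-mlir or PyTorch APIs.

// Dependency-free sketch of the shape rule implemented by the removed
// compute_shape_native_batch_norm. SimpleShape stands in for
// torch::lazy::Shape; all names here are illustrative, not real APIs.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct SimpleShape {
  std::string dtype;           // stand-in for at::ScalarType
  std::vector<int64_t> sizes;  // dimension sizes, e.g. {N, C, H, W}
};

// Returns the three output shapes of native_batch_norm: the normalized
// output (same shape as the input) and per-channel mean and variance
// vectors of length C = input.sizes[1]. When running stats are provided,
// their shape and dtype are reused; otherwise a default-dtype vector of
// length C is assumed.
std::vector<SimpleShape> batch_norm_output_shapes(
    const SimpleShape& input,
    const std::optional<SimpleShape>& running_mean,
    const std::optional<SimpleShape>& running_var) {
  assert(input.sizes.size() >= 2 && "need batch and channel dimensions");
  const int64_t num_features = input.sizes[1];

  std::vector<SimpleShape> shapes;
  shapes.reserve(3);
  shapes.push_back(input);  // output shape == input shape
  shapes.push_back(running_mean ? *running_mean
                                : SimpleShape{"float", {num_features}});
  shapes.push_back(running_var ? *running_var
                               : SimpleShape{"float", {num_features}});
  return shapes;
}

int main() {
  SimpleShape input{"float", {8, 16, 32, 32}};  // NCHW input
  for (const auto& s :
       batch_norm_output_shapes(input, std::nullopt, std::nullopt)) {
    std::cout << s.dtype << " [";
    for (size_t i = 0; i < s.sizes.size(); ++i)
      std::cout << (i ? ", " : "") << s.sizes[i];
    std::cout << "]\n";  // prints [8, 16, 32, 32], then [16] twice
  }
  return 0;
}

The backward variant removed in the same hunk follows the same pattern:
grad_input shaped like the input, plus two length-C vectors for the weight
and bias gradients.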