Upstream native_batch_norm and native_batch_norm_backward shape inference functions (#978)

* Removed compute_shape_native_batch_norm

* Removed compute_shape_native_batch_norm_backward
pull/1125/head
Henry Tu 2022-06-24 19:30:45 -04:00 committed by Henry Tu
parent 0cee0dc978
commit 1510eae75d
1 changed files with 0 additions and 65 deletions

View File

@ -29,71 +29,6 @@ compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var, bool training,
double momentum, double eps) {
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
// A separate mean and var needs to be kept for each channel.
TORCH_CHECK(
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
int64_t num_features = input.sizes().vec()[1];
if (running_mean.has_value()) {
shapes.emplace_back(
running_mean.value().scalar_type(), running_mean.value().sizes().vec());
} else {
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
}
if (running_var.has_value()) {
shapes.emplace_back(
running_var.value().scalar_type(), running_var.value().sizes().vec());
} else {
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
}
return shapes;
}
std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var,
const c10::optional<at::Tensor>& save_mean,
const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
::std::array<bool, 3> output_mask) {
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
// A separate mean and var needs to be kept for each channel.
TORCH_CHECK(
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
int64_t num_features = input.sizes().vec()[1];
// `weight` and `bias` are vectors of length C (number of channels)`
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
return shapes;
}
std::vector<torch::lazy::Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
if (dtype.has_value()) {
return {Shape(*dtype, size)};