mirror of https://github.com/llvm/torch-mlir
Upstream native_batch_norm and native_batch_norm_backward shape inference functions (#978)
* Removed compute_shape_native_batch_norm
* Removed compute_shape_native_batch_norm_backward

pull/1125/head
parent 0cee0dc978
commit 1510eae75d
@@ -29,71 +29,6 @@ compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
-std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
-    const at::Tensor& input, const c10::optional<at::Tensor>& weight,
-    const c10::optional<at::Tensor>& bias,
-    const c10::optional<at::Tensor>& running_mean,
-    const c10::optional<at::Tensor>& running_var, bool training,
-    double momentum, double eps) {
-  std::vector<torch::lazy::Shape> shapes;
-  shapes.reserve(3);
-  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
-
-  // A separate mean and var needs to be kept for each channel.
-  TORCH_CHECK(
-      input.sizes().size() >= 2,
-      "Input tensor must have at least batch and channel dimensions!");
-  int64_t num_features = input.sizes().vec()[1];
-
-  if (running_mean.has_value()) {
-    shapes.emplace_back(
-        running_mean.value().scalar_type(), running_mean.value().sizes().vec());
-  } else {
-    shapes.emplace_back(
-        at::get_default_dtype_as_scalartype(),
-        std::vector<int64_t>{num_features});
-  }
-
-  if (running_var.has_value()) {
-    shapes.emplace_back(
-        running_var.value().scalar_type(), running_var.value().sizes().vec());
-  } else {
-    shapes.emplace_back(
-        at::get_default_dtype_as_scalartype(),
-        std::vector<int64_t>{num_features});
-  }
-  return shapes;
-}
-
-std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
-    const at::Tensor& grad_out, const at::Tensor& input,
-    const c10::optional<at::Tensor>& weight,
-    const c10::optional<at::Tensor>& running_mean,
-    const c10::optional<at::Tensor>& running_var,
-    const c10::optional<at::Tensor>& save_mean,
-    const c10::optional<at::Tensor>& save_invstd, bool train, double eps,
-    ::std::array<bool, 3> output_mask) {
-  std::vector<torch::lazy::Shape> shapes;
-  shapes.reserve(3);
-  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
-
-  // A separate mean and var needs to be kept for each channel.
-  TORCH_CHECK(
-      input.sizes().size() >= 2,
-      "Input tensor must have at least batch and channel dimensions!");
-  int64_t num_features = input.sizes().vec()[1];
-
-  // `weight` and `bias` are vectors of length C (number of channels)
-  shapes.emplace_back(
-      at::get_default_dtype_as_scalartype(),
-      std::vector<int64_t>{num_features});
-  shapes.emplace_back(
-      at::get_default_dtype_as_scalartype(),
-      std::vector<int64_t>{num_features});
-
-  return shapes;
-}
-
 std::vector<torch::lazy::Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
   if (dtype.has_value()) {
     return {Shape(*dtype, size)};
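For reference, the shape contract the removed helpers encoded matches what the eager ATen ops produce: native_batch_norm returns the normalized output (same shape as the input) plus per-channel save_mean / save_invstd vectors of length C, and native_batch_norm_backward returns an input-shaped grad_input plus per-channel grad_weight / grad_bias. The sketch below is not part of this commit; it assumes a standalone libtorch install, and the tensor names and sizes are illustrative only.

// Minimal sketch, assuming libtorch is available; prints the output shapes
// that the removed compute_shape_* helpers would have inferred.
#include <torch/torch.h>

#include <array>
#include <iostream>

int main() {
  const int64_t N = 4, C = 3, H = 8, W = 8;
  auto input = torch::randn({N, C, H, W});
  auto weight = torch::ones({C});
  auto bias = torch::zeros({C});
  auto running_mean = torch::zeros({C});
  auto running_var = torch::ones({C});

  // Forward: output is input-shaped; save_mean / save_invstd are length-C.
  auto fwd = at::native_batch_norm(
      input, weight, bias, running_mean, running_var,
      /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5);
  std::cout << "output      " << std::get<0>(fwd).sizes() << "\n"   // [4, 3, 8, 8]
            << "save_mean   " << std::get<1>(fwd).sizes() << "\n"   // [3]
            << "save_invstd " << std::get<2>(fwd).sizes() << "\n";  // [3]

  // Backward: grad_input is input-shaped; grad_weight / grad_bias are length-C.
  auto grad_out = torch::randn_like(input);
  std::array<bool, 3> output_mask = {true, true, true};
  auto bwd = at::native_batch_norm_backward(
      grad_out, input, weight, running_mean, running_var,
      std::get<1>(fwd), std::get<2>(fwd),
      /*train=*/true, /*eps=*/1e-5, output_mask);
  std::cout << "grad_input  " << std::get<0>(bwd).sizes() << "\n"   // [4, 3, 8, 8]
            << "grad_weight " << std::get<1>(bwd).sizes() << "\n"   // [3]
            << "grad_bias   " << std::get<2>(bwd).sizes() << "\n";  // [3]
  return 0;
}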