Add ODS for group_norm

- Add ODS for native_group_norm/backward.
- Add shape inference for native_group_norm/backward.

Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>
pull/2120/head
rahul shrivastava 2023-05-05 04:24:08 -07:00 committed by rahuls-cerebras
parent 86718cb203
commit 40a2c501a1
3 changed files with 124 additions and 0 deletions

@@ -4656,6 +4656,38 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
  }];
}
def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    AnyTorchOptionalTensorType:$weight,
    AnyTorchOptionalTensorType:$bias,
    Torch_IntType:$N,
    Torch_IntType:$C,
    Torch_IntType:$HxW,
    Torch_IntType:$group,
    Torch_FloatType:$eps
  );
  let results = (outs
    AnyTorchTensorType:$result0,
    AnyTorchTensorType:$result1,
    AnyTorchTensorType:$result2
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 8, 3);
    }
    void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 8, 3);
    }
  }];
}
def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
    AllowsTypeRefinement,
    HasValueSemantics,
@@ -11066,6 +11098,40 @@ def Torch_AtenNativeBatchNormBackwardOp : Torch_Op<"aten.native_batch_norm_backw
  }];
}
def Torch_AtenNativeGroupNormBackwardOp : Torch_Op<"aten.native_group_norm_backward", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$grad_out,
    AnyTorchTensorType:$input,
    AnyTorchTensorType:$mean,
    AnyTorchTensorType:$rstd,
    AnyTorchOptionalTensorType:$weight,
    Torch_IntType:$N,
    Torch_IntType:$C,
    Torch_IntType:$HxW,
    Torch_IntType:$group,
    AnyTorchListOfTorchBoolType:$output_mask
  );
  let results = (outs
    AnyTorchTensorType:$result0,
    AnyTorchTensorType:$result1,
    AnyTorchTensorType:$result2
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenNativeGroupNormBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 10, 3);
    }
    void AtenNativeGroupNormBackwardOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 10, 3);
    }
  }];
}
def Torch_AtenNativeDropoutBackwardOp : Torch_Op<"aten.native_dropout_backward", [
    AllowsTypeRefinement,
    HasValueSemantics,

@@ -63,5 +63,59 @@ std::vector<torch::lazy::Shape> compute_shape_copy(
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
    const at::Tensor& input,
    const c10::optional<at::Tensor>& weight,
    const c10::optional<at::Tensor>& bias,
    int64_t N, int64_t C, int64_t HxW,
    int64_t group, double eps) {
  TORCH_CHECK(
      input.sizes().size() >= 2,
      "Input tensor must have at least batch and channel dimensions!");
  std::vector<torch::lazy::Shape> shapes;
  shapes.reserve(3);
  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
  // A separate mean and rstd value is kept for each group per batch element (N).
  shapes.emplace_back(
      at::get_default_dtype_as_scalartype(),
      std::vector<int64_t>{N, group});
  shapes.emplace_back(
      at::get_default_dtype_as_scalartype(),
      std::vector<int64_t>{N, group});
  return shapes;
}
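The {N, group} shapes above mirror what eager PyTorch returns for the mean and rstd outputs of this op. A minimal cross-check, as a sketch only (it assumes a PyTorch build where aten::native_group_norm is reachable through torch.ops; the sizes are arbitrary):

import torch

# Illustrative shape check for aten::native_group_norm.
N, C, H, W, group = 2, 6, 4, 4, 3
x = torch.randn(N, C, H, W)
weight, bias = torch.randn(C), torch.randn(C)
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, N, C, H * W, group, 1e-5)
print(out.shape)   # torch.Size([2, 6, 4, 4]), same shape as the input
print(mean.shape)  # torch.Size([2, 3]), one statistic per (N, group)
print(rstd.shape)  # torch.Size([2, 3])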
std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
    const at::Tensor& grad_out,
    const at::Tensor& input,
    const at::Tensor& mean,
    const at::Tensor& rstd,
    const c10::optional<at::Tensor>& weight,
    int64_t N, int64_t C, int64_t HxW,
    int64_t group, ::std::array<bool, 3> output_mask) {
  TORCH_CHECK(
      input.sizes().size() >= 2,
      "Input tensor must have at least batch and channel dimensions!");
  std::vector<torch::lazy::Shape> shapes;
  shapes.reserve(3);
  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
  int64_t num_features = input.size(1);
  // The grad_weight and grad_bias outputs are vectors of length C (the number
  // of channels), matching the shapes of `weight` and `bias`.
  shapes.emplace_back(
      at::get_default_dtype_as_scalartype(),
      std::vector<int64_t>{num_features});
  shapes.emplace_back(
      at::get_default_dtype_as_scalartype(),
      std::vector<int64_t>{num_features});
  return shapes;
}
} // namespace lazy
} // namespace torch
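Likewise, a sketch of the backward shape contract under the same assumptions (the all-True output_mask asks the kernel for all three gradients):

import torch

# Illustrative shape check for aten::native_group_norm_backward.
N, C, H, W, group = 2, 6, 4, 4, 3
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, torch.randn(C), N, C, H * W, group, 1e-5)
grad_in, grad_w, grad_b = torch.ops.aten.native_group_norm_backward(
    torch.randn_like(out), x, mean, rstd, weight,
    N, C, H * W, group, [True, True, True])
print(grad_in.shape)  # torch.Size([2, 6, 4, 4]), same shape as the input
print(grad_w.shape)   # torch.Size([6]), length C like weight
print(grad_b.shape)   # torch.Size([6]), length C like bias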

@@ -373,6 +373,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit(
        "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
    )
    emit(
        "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
    )
    emit(
        "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
    )
@@ -662,6 +665,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)")
emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")