mirror of https://github.com/llvm/torch-mlir
Add ODS for group_norm
- Add ODS for native_group_norm/backward. - Add shape-inference for native_group_norm/backward . Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>pull/2120/head
parent
86718cb203
commit
40a2c501a1
|
@ -4656,6 +4656,38 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$input,
|
||||
AnyTorchOptionalTensorType:$weight,
|
||||
AnyTorchOptionalTensorType:$bias,
|
||||
Torch_IntType:$N,
|
||||
Torch_IntType:$C,
|
||||
Torch_IntType:$HxW,
|
||||
Torch_IntType:$group,
|
||||
Torch_FloatType:$eps
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result0,
|
||||
AnyTorchTensorType:$result1,
|
||||
AnyTorchTensorType:$result2
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 8, 3);
|
||||
}
|
||||
void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 8, 3);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -11066,6 +11098,40 @@ def Torch_AtenNativeBatchNormBackwardOp : Torch_Op<"aten.native_batch_norm_backw
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenNativeGroupNormBackwardOp : Torch_Op<"aten.native_group_norm_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad_out,
|
||||
AnyTorchTensorType:$input,
|
||||
AnyTorchTensorType:$mean,
|
||||
AnyTorchTensorType:$rstd,
|
||||
AnyTorchOptionalTensorType:$weight,
|
||||
Torch_IntType:$N,
|
||||
Torch_IntType:$C,
|
||||
Torch_IntType:$HxW,
|
||||
Torch_IntType:$group,
|
||||
AnyTorchListOfTorchBoolType:$output_mask
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result0,
|
||||
AnyTorchTensorType:$result1,
|
||||
AnyTorchTensorType:$result2
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenNativeGroupNormBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 10, 3);
|
||||
}
|
||||
void AtenNativeGroupNormBackwardOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 10, 3);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenNativeDropoutBackwardOp : Torch_Op<"aten.native_dropout_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -63,5 +63,59 @@ std::vector<torch::lazy::Shape> compute_shape_copy(
|
|||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
|
||||
const at::Tensor& input,
|
||||
const c10::optional<at::Tensor>& weight,
|
||||
const c10::optional<at::Tensor>& bias,
|
||||
int64_t N, int64_t C, int64_t HxW,
|
||||
int64_t group, double eps) {
|
||||
|
||||
TORCH_CHECK(
|
||||
input.sizes().size() >= 2,
|
||||
"Input tensor must have at least batch and channel dimensions!");
|
||||
std::vector<torch::lazy::Shape> shapes;
|
||||
shapes.reserve(3);
|
||||
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
|
||||
|
||||
// A separate mean and var needs to be kept for each group per N.
|
||||
shapes.emplace_back(
|
||||
at::get_default_dtype_as_scalartype(),
|
||||
std::vector<int64_t>{N, group});
|
||||
|
||||
shapes.emplace_back(
|
||||
at::get_default_dtype_as_scalartype(),
|
||||
std::vector<int64_t>{N, group});
|
||||
return shapes;
|
||||
}
|
||||
|
||||
std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
|
||||
const at::Tensor& grad_out,
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& mean,
|
||||
const at::Tensor& rstd,
|
||||
const c10::optional<at::Tensor>& weight,
|
||||
int64_t N, int64_t C, int64_t HxW,
|
||||
int64_t group, ::std::array<bool, 3> output_mask) {
|
||||
|
||||
TORCH_CHECK(
|
||||
input.sizes().size() >= 2,
|
||||
"Input tensor must have at least batch and channel dimensions!");
|
||||
std::vector<torch::lazy::Shape> shapes;
|
||||
shapes.reserve(3);
|
||||
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
|
||||
|
||||
int64_t num_features = input.size(1);
|
||||
|
||||
// `weight` and `bias` are vectors of length C (number of channels)`
|
||||
shapes.emplace_back(
|
||||
at::get_default_dtype_as_scalartype(),
|
||||
std::vector<int64_t>{num_features});
|
||||
shapes.emplace_back(
|
||||
at::get_default_dtype_as_scalartype(),
|
||||
std::vector<int64_t>{num_features});
|
||||
|
||||
return shapes;
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -373,6 +373,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
|
||||
)
|
||||
|
@ -662,6 +665,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)")
|
||||
emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
|
||||
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
|
||||
|
||||
|
|
Loading…
Reference in New Issue