diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 357f95fd2..bad41668e 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4656,6 +4656,38 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [
   }];
 }
 
+def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_IntType:$N,
+    Torch_IntType:$C,
+    Torch_IntType:$HxW,
+    Torch_IntType:$group,
+    Torch_FloatType:$eps
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 3);
+    }
+    void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 3);
+    }
+  }];
+}
+
 def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -11066,6 +11098,40 @@ def Torch_AtenNativeBatchNormBackwardOp : Torch_Op<"aten.native_batch_norm_backw
   }];
 }
 
+def Torch_AtenNativeGroupNormBackwardOp : Torch_Op<"aten.native_group_norm_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_out,
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$mean,
+    AnyTorchTensorType:$rstd,
+    AnyTorchOptionalTensorType:$weight,
+    Torch_IntType:$N,
+    Torch_IntType:$C,
+    Torch_IntType:$HxW,
+    Torch_IntType:$group,
+    AnyTorchListOfTorchBoolType:$output_mask
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNativeGroupNormBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 10, 3);
+    }
+    void AtenNativeGroupNormBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 10, 3);
+    }
+  }];
+}
+
 def Torch_AtenNativeDropoutBackwardOp : Torch_Op<"aten.native_dropout_backward", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
index 3e52c20c2..760194a22 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
@@ -63,5 +63,59 @@ std::vector<torch::lazy::Shape> compute_shape_copy(
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
+    const at::Tensor& input,
+    const c10::optional<at::Tensor>& weight,
+    const c10::optional<at::Tensor>& bias,
+    int64_t N, int64_t C, int64_t HxW,
+    int64_t group, double eps) {
+
+  TORCH_CHECK(
+      input.sizes().size() >= 2,
+      "Input tensor must have at least batch and channel dimensions!");
+  std::vector<torch::lazy::Shape> shapes;
+  shapes.reserve(3);
+  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
+
+  // A separate mean and var need to be kept for each group per N.
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{N, group});
+
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{N, group});
+  return shapes;
+}
+
+std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const at::Tensor& mean,
+    const at::Tensor& rstd,
+    const c10::optional<at::Tensor>& weight,
+    int64_t N, int64_t C, int64_t HxW,
+    int64_t group, ::std::array<bool, 3> output_mask) {
+
+  TORCH_CHECK(
+      input.sizes().size() >= 2,
+      "Input tensor must have at least batch and channel dimensions!");
+  std::vector<torch::lazy::Shape> shapes;
+  shapes.reserve(3);
+  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
+
+  int64_t num_features = input.size(1);
+
+  // `weight` and `bias` are vectors of length C (number of channels).
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{num_features});
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{num_features});
+
+  return shapes;
+}
+
 } // namespace lazy
 } // namespace torch
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 8d36f5645..f3bd937cc 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -373,6 +373,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)"
     )
+    emit(
+        "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
+    )
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
@@ -662,6 +665,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)")
     emit("aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)")
     emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)")
+    emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)")
     emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
     emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
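
Note (not part of the patch): a minimal Python sanity check of the shapes the new compute_shape_native_group_norm / compute_shape_native_group_norm_backward functions are expected to report, assuming a PyTorch build that exposes these ops through torch.ops.aten. The forward op returns an output shaped like the input plus per-(N, group) mean and rstd tensors; the backward op returns an input-shaped gradient plus length-C gradients for weight and bias. The concrete sizes below are hypothetical and chosen only for illustration.

import torch

N, C, H, W, group = 2, 6, 4, 5, 3
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
bias = torch.randn(C)

# Forward: out matches the input shape; mean and rstd are (N, group).
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, N, C, H * W, group, 1e-5)
print(out.shape, mean.shape, rstd.shape)  # (2, 6, 4, 5), (2, 3), (2, 3)

# Backward: grad_input matches the input shape; grad_weight and grad_bias are (C,).
grad_in, grad_w, grad_b = torch.ops.aten.native_group_norm_backward(
    torch.ones_like(out), x, mean, rstd, weight, N, C, H * W, group,
    [True, True, True])
print(grad_in.shape, grad_w.shape, grad_b.shape)  # (2, 6, 4, 5), (6,), (6,)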