From a635fd228775e676c8cc113ec9fad6412b77bfb9 Mon Sep 17 00:00:00 2001
From: Henry Tu
Date: Fri, 3 Jun 2022 13:49:02 -0400
Subject: [PATCH] Added support for native_batch_norm_backward (#890)

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 35 ++++++++++++++++++-
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 9ecb9e324..d2027cbd0 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7514,6 +7514,40 @@ def Torch_AtenEmbeddingDenseBackwardOp : Torch_Op<"aten.embedding_dense_backward
   }];
 }
 
+def Torch_AtenNativeBatchNormBackwardOp : Torch_Op<"aten.native_batch_norm_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_out,
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$running_mean,
+    AnyTorchOptionalTensorType:$running_var,
+    AnyTorchOptionalTensorType:$save_mean,
+    AnyTorchOptionalTensorType:$save_invstd,
+    Torch_BoolType:$train,
+    Torch_FloatType:$eps,
+    AnyTorchListOfTorchBoolType:$output_mask
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNativeBatchNormBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 10, 3);
+    }
+    void AtenNativeBatchNormBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 10, 3);
+    }
+  }];
+}
+
 def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
   AllowsTypeRefinement,
   HasValueSemantics,
@@ -7890,4 +7924,3 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
     }
   }];
 }
-
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 730ca37fa..5dea74bea 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -538,6 +538,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
     emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)")
     emit("aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)")
+    emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)")
 
     # ==========================================================================
     # `prim::` namespace.