From c8b867b8760db151727fe9f1b71793d1d1d8a022 Mon Sep 17 00:00:00 2001
From: Gleb Kazantaev
Date: Tue, 10 Jan 2023 13:08:25 -0500
Subject: [PATCH] Added support for aten::norm.ScalarOpt_dim (#1774)

* Added support for aten::norm.ScalarOpt_dim

* Disable NormalizeModule_basic for linalg
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 +++++++++++++++++++
 .../jit_ir/build_tools/torch_ods_gen.py   |  3 +++
 .../test_suite/__init__.py                |  1 +
 .../test_suite/norm_like.py               | 19 ++++++++++++++
 4 files changed, 49 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 6d6b12d08..1ced47045 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4502,6 +4502,32 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
   }];
 }
 
+def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalScalarType:$p,
+    AnyTorchListOfTorchIntType:$dim,
+    Torch_BoolType:$keepdim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNormScalarOptDimOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenNormScalarOptDimOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 18a09f182..97d821194 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -371,6 +371,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
+    emit(
+        "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
+    )
     emit(
         "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
     )
diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py
index c107ceb37..7d133a87f 100644
--- a/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -8,6 +8,7 @@
 # to the backend contract.
 COMMON_TORCH_MLIR_LOWERING_XFAILS = {
     "QuantizedMLP_basic",
+    "NormalizeModule_basic",
 }
 
 def register_all_tests():
diff --git a/python/torch_mlir_e2e_test/test_suite/norm_like.py b/python/torch_mlir_e2e_test/test_suite/norm_like.py
index e8b006c30..6ff76d1a4 100644
--- a/python/torch_mlir_e2e_test/test_suite/norm_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -261,6 +261,25 @@ def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class NormalizeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.nn.functional.normalize(x)
+
+
+@register_test_case(module_factory=lambda: NormalizeModule())
+def NormalizeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3))
+
+# ==============================================================================
+
 class NativeLayerNormModule4D(torch.nn.Module):
     def __init__(self):
         super().__init__()
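
---
Note (not part of the patch): a minimal sketch of what the new overload
computes and why the NormalizeModule test exercises it.
torch.nn.functional.normalize(x) divides x by its p-norm reduced over a
dimension list with keepdim=True, and a scripted torch.norm call of that
shape should resolve to the aten::norm.ScalarOpt_dim schema registered
above. The helper name normalize_reference is illustrative, not from the
patch.

    import torch

    def normalize_reference(x: torch.Tensor, p: float = 2.0, dim: int = 1,
                            eps: float = 1e-12) -> torch.Tensor:
        # torch.norm with a scalar p, a dim list, and keepdim=True matches
        # the aten::norm.ScalarOpt_dim overload that this patch teaches the
        # jit_ir importer and ODS generator to recognize.
        denom = torch.norm(x, p, dim=[dim], keepdim=True).clamp_min(eps)
        return x / denom.expand_as(x)

    # Mirrors NormalizeModule_basic above: the reference reduction agrees
    # with torch.nn.functional.normalize on a 3x3 float tensor.
    x = torch.rand(3, 3)
    assert torch.allclose(normalize_reference(x),
                          torch.nn.functional.normalize(x))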