From 6f7d9e83dfefc61da9ca623441a3cd0aebc549a4 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Thu, 18 May 2023 00:04:40 +0800 Subject: [PATCH] [Stablehlo] add e2e test for aten.batch_norm (#2129) --- e2e_testing/xfail_sets.py | 8 ++++++ lib/Conversion/TorchToStablehlo/Basic.cpp | 13 +++++----- lib/InitAll.cpp | 1 + .../linalg_on_tensors_backends/refbackend.py | 1 + .../stablehlo_backends/linalg_on_tensors.py | 2 +- .../test_suite/norm_like.py | 26 +++++++++++++++++++ tools/torch-mlir-opt/torch-mlir-opt.cpp | 8 ++++++ 7 files changed, 52 insertions(+), 7 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 1fc89ca5f..3cb5b2e2f 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -213,6 +213,7 @@ TORCHDYNAMO_XFAIL_SET = { "BatchNorm1DWith2DInputModule_basic", "BatchNorm2DModule_basic", "BatchNorm3DModule_basic", + "BatchNorm1DStaticShapeModule_basic", "ElementwiseAddScalarFloatModule_basic", "ElementwiseAddScalarInt64Module_basic", "ElementwiseAddScalarIntModule_basic", @@ -317,6 +318,12 @@ STABLEHLO_PASS_SET = { "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", "BatchMlpLayerModule_basic", + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "BatchNorm1DStaticShapeModule_basic", + "ResNet18StaticModule_basic", "BmmModule_basic", "BroadcastToModule_basic", "BroadcastToSameRankStaticModule_basic", @@ -805,6 +812,7 @@ TOSA_PASS_SET = { "BatchNorm1DWith2DInputModule_basic", "BatchNorm2DModule_basic", "BatchNorm3DModule_basic", + "BatchNorm1DStaticShapeModule_basic", "FlattenStaticModule_basic", "FlattenRank0Module_basic", "ElementwiseFlattenBroadcastModule_basic", diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 8e240a745..0b5db344b 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -931,10 +931,10 @@ 
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite( Value momentum = adaptor.getMomentum(); (void)momentum; - if (inputTy.getRank() <= 2) { - return rewriter.notifyMatchFailure(op, - "input should have rank larger than 2"); - } + // handle feature index, see torch's BatchNorm1d, BatchNorm2d, BatchNorm3d, + // all of NC, NCL, NCHW, NCDHW's feature index is 1. + int64_t feature_index = 1; + if (!inputTy.getElementType().template isa<mlir::FloatType>()) { return op.emitError("only input tensor of float type is supported"); } @@ -1020,7 +1020,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite( rewriter.create<stablehlo::BatchNormTrainingOp>( op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, weight, bias, rewriter.getF32FloatAttr(eps), - rewriter.getI64IntegerAttr(1)); + rewriter.getI64IntegerAttr(feature_index)); rewriter.replaceOp(op, batchNormTrainingResult.getResult(0)); return success(); } else { @@ -1037,7 +1037,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite( op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, runningMean, runningVar, // 'epsilon' must satisfy constraint: 32-bit float attribute. 
- rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); + rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(feature_index)); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, output); return success(); } diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 87a2b8f39..43b45d32e 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -44,5 +44,6 @@ void mlir::torch::registerAllPasses() { mlir::mhlo::registerStablehloLegalizeToHloPass(); mlir::mhlo::registerChloLegalizeToHloPass(); mlir::mhlo::registerHloLegalizeToLinalgPass(); + mlir::mhlo::registerTestUnfuseBatchNormPass(); #endif // TORCH_MLIR_ENABLE_STABLEHLO } diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 66762fdf9..f4c4e5176 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -131,6 +131,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([ # emit things in that form from the high level (e.g. single linalg-generic). # Other backends are likely to benefit more. "func.func(linalg-fuse-elementwise-ops)", + "convert-shape-to-std", # Bufferize. 
"func.func(scf-bufferize)", "func.func(tm-tensor-bufferize)", diff --git a/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py index b285c46b8..6a36dd196 100644 --- a/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ b/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -40,7 +40,7 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend): """ run_pipeline_with_repro_report( imported_module, - "builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))", + "builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,mhlo-test-unfuse-batch-norm,canonicalize,hlo-legalize-to-linalg,canonicalize))", "Lowering StableHLO to Linalg-on-Tensors", ) return self.refbackend.compile(imported_module) diff --git a/python/torch_mlir_e2e_test/test_suite/norm_like.py b/python/torch_mlir_e2e_test/test_suite/norm_like.py index 38ad7bebd..f59695620 100644 --- a/python/torch_mlir_e2e_test/test_suite/norm_like.py +++ b/python/torch_mlir_e2e_test/test_suite/norm_like.py @@ -115,6 +115,32 @@ def BatchNorm3DModule_basic(module, tu: TestUtils): # ============================================================================== +class BatchNorm1DStaticShapeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 5], torch.float32, True), + ([5], torch.float32, True), + ([5], torch.float32, True), + ([5], torch.float32, True), + ([5], torch.float32, True), + ]) + def forward(self, x, weight, bias, running_mean, running_var): + return torch.ops.aten.batch_norm( + x, weight, bias, running_mean, running_var, training=False, + momentum=0.1, eps=0.00001, cudnn_enabled=False) + + +@register_test_case(module_factory=lambda: 
BatchNorm1DStaticShapeModule()) def BatchNorm1DStaticShapeModule_basic(module, tu: TestUtils): module.forward( tu.rand(2, 5), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5)) # ============================================================================== class NativeBatchNorm1DModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/tools/torch-mlir-opt/torch-mlir-opt.cpp b/tools/torch-mlir-opt/torch-mlir-opt.cpp index c2e975fed..af76cc56d 100644 --- a/tools/torch-mlir-opt/torch-mlir-opt.cpp +++ b/tools/torch-mlir-opt/torch-mlir-opt.cpp @@ -13,6 +13,8 @@ #include "torch-mlir/InitAll.h" #ifdef TORCH_MLIR_ENABLE_STABLEHLO +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "stablehlo/dialect/Register.h" #endif @@ -28,6 +30,12 @@ int main(int argc, char **argv) { #ifdef TORCH_MLIR_ENABLE_STABLEHLO mlir::stablehlo::registerAllDialects(registry); + registry.insert<mlir::mhlo::MhloDialect>(); + mlir::mhlo::registerSymbolicShapeOptimizationPass(); + mlir::mhlo::registerStablehloLegalizeToHloPass(); + mlir::mhlo::registerChloLegalizeToHloPass(); + mlir::mhlo::registerHloLegalizeToLinalgPass(); + mlir::mhlo::registerTestUnfuseBatchNormPass(); #endif return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "MLIR modular optimizer driver\n", registry));