[Stablehlo] add e2e test for aten.batch_norm (#2129)

Yuanqiang Liu 2023-05-18 00:04:40 +08:00 committed by GitHub
parent e98f2ba04a
commit 6f7d9e83df
7 changed files with 52 additions and 7 deletions

e2e_testing/xfail_sets.py

@@ -213,6 +213,7 @@ TORCHDYNAMO_XFAIL_SET = {
     "BatchNorm1DWith2DInputModule_basic",
     "BatchNorm2DModule_basic",
     "BatchNorm3DModule_basic",
+    "BatchNorm1DStaticShapeModule_basic",
     "ElementwiseAddScalarFloatModule_basic",
     "ElementwiseAddScalarInt64Module_basic",
     "ElementwiseAddScalarIntModule_basic",
@@ -317,6 +318,12 @@ STABLEHLO_PASS_SET = {
     "ArangeStartStepIntModule_basic",
     "ArangeZeroElementOutputModule_basic",
     "BatchMlpLayerModule_basic",
+    "BatchNorm1DModule_basic",
+    "BatchNorm1DWith2DInputModule_basic",
+    "BatchNorm2DModule_basic",
+    "BatchNorm3DModule_basic",
+    "BatchNorm1DStaticShapeModule_basic",
+    "ResNet18StaticModule_basic",
     "BmmModule_basic",
     "BroadcastToModule_basic",
     "BroadcastToSameRankStaticModule_basic",
@@ -805,6 +812,7 @@ TOSA_PASS_SET = {
     "BatchNorm1DWith2DInputModule_basic",
     "BatchNorm2DModule_basic",
     "BatchNorm3DModule_basic",
+    "BatchNorm1DStaticShapeModule_basic",
     "FlattenStaticModule_basic",
     "FlattenRank0Module_basic",
     "ElementwiseFlattenBroadcastModule_basic",

lib/Conversion/TorchToStablehlo/Basic.cpp

@@ -931,10 +931,10 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
   Value momentum = adaptor.getMomentum();
   (void)momentum;
 
-  if (inputTy.getRank() <= 2) {
-    return rewriter.notifyMatchFailure(op,
-                                       "input should have rank larger than 2");
-  }
+  // Handle the feature index: torch's BatchNorm1d, BatchNorm2d and BatchNorm3d
+  // (NC/NCL, NCHW, NCDHW inputs) all keep the feature dimension at index 1.
+  int64_t feature_index = 1;
+
   if (!inputTy.getElementType().template isa<mlir::FloatType>()) {
     return op.emitError("only input tensor of float type is supported");
   }
@@ -1020,7 +1020,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
         rewriter.create<stablehlo::BatchNormTrainingOp>(
             op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
             weight, bias, rewriter.getF32FloatAttr(eps),
-            rewriter.getI64IntegerAttr(1));
+            rewriter.getI64IntegerAttr(feature_index));
     rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
     return success();
   } else {
@@ -1037,7 +1037,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
         op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
         runningMean, runningVar,
         // 'epsilon' must satisfy constraint: 32-bit float attribute.
-        rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
+        rewriter.getF32FloatAttr(eps),
+        rewriter.getI64IntegerAttr(feature_index));
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, output);
     return success();
   }
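
The hard-coded feature_index = 1 is correct for every rank this op can see: torch's BatchNorm1d, BatchNorm2d and BatchNorm3d all place the channel dimension at index 1 (NC, NCL, NCHW, NCDHW). A quick plain-PyTorch check of that invariant, with shapes matching the e2e tests (not part of the commit):

    import torch

    # The feature (channel) dimension is index 1 for every layout batch_norm accepts.
    for shape in [(2, 5), (2, 5, 3), (2, 5, 3, 6), (2, 5, 3, 6, 4)]:
        x = torch.rand(shape)
        c = shape[1]
        # args: input, weight, bias, running_mean, running_var,
        #       training, momentum, eps, cudnn_enabled
        y = torch.ops.aten.batch_norm(
            x, torch.rand(c), torch.rand(c), torch.rand(c), torch.rand(c),
            False, 0.1, 1e-5, False)
        assert y.shape == x.shape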

lib/InitAll.cpp

@@ -44,5 +44,6 @@ void mlir::torch::registerAllPasses() {
   mlir::mhlo::registerStablehloLegalizeToHloPass();
   mlir::mhlo::registerChloLegalizeToHloPass();
   mlir::mhlo::registerHloLegalizeToLinalgPass();
+  mlir::mhlo::registerTestUnfuseBatchNormPass();
 #endif // TORCH_MLIR_ENABLE_STABLEHLO
 }
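
Registering mlir::mhlo::registerTestUnfuseBatchNormPass() exposes the pass under its textual name, mhlo-test-unfuse-batch-norm, which is what lets the StableHLO backend pipeline below (and torch-mlir-opt) reference it in a pass-pipeline string.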

python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py

@@ -131,6 +131,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
     # emit things in that form from the high level (e.g. single linalg-generic).
     # Other backends are likely to benefit more.
     "func.func(linalg-fuse-elementwise-ops)",
+    "convert-shape-to-std",
     # Bufferize.
     "func.func(scf-bufferize)",
     "func.func(tm-tensor-bufferize)",

python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py

@@ -40,7 +40,7 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend):
         """
         run_pipeline_with_repro_report(
             imported_module,
-            "builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))",
+            "builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,mhlo-test-unfuse-batch-norm,canonicalize,hlo-legalize-to-linalg,canonicalize))",
             "Lowering StableHLO to Linalg-on-Tensors",
         )
         return self.refbackend.compile(imported_module)
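
mhlo-test-unfuse-batch-norm rewrites the coarse-grained batch-norm op into elementwise arithmetic that hlo-legalize-to-linalg can then handle. Numerically, inference-mode batch norm is just a broadcasted affine transform, as this plain-PyTorch sanity check shows (not part of the commit):

    import torch

    x = torch.rand(2, 5, 3, 6)
    w, b, mean, var = (torch.rand(5) for _ in range(4))
    eps = 1e-5

    fused = torch.ops.aten.batch_norm(x, w, b, mean, var, False, 0.1, eps, False)

    # Elementwise equivalent, broadcast over the feature dimension (index 1):
    s = (1, 5, 1, 1)
    unfused = ((x - mean.view(s)) / torch.sqrt(var.view(s) + eps)
               * w.view(s) + b.view(s))
    torch.testing.assert_close(fused, unfused)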

python/torch_mlir_e2e_test/test_suite/norm_like.py

@@ -115,6 +115,32 @@ def BatchNorm3DModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
+class BatchNorm1DStaticShapeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5], torch.float32, True),
+        ([5], torch.float32, True),
+        ([5], torch.float32, True),
+        ([5], torch.float32, True),
+        ([5], torch.float32, True),
+    ])
+    def forward(self, x, weight, bias, running_mean, running_var):
+        return torch.ops.aten.batch_norm(
+            x, weight, bias, running_mean, running_var, training=False,
+            momentum=0.1, eps=0.00001, cudnn_enabled=False)
+
+
+@register_test_case(module_factory=lambda: BatchNorm1DStaticShapeModule())
+def BatchNorm1DStaticShapeModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 5), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
+
+
+# ==============================================================================
 
 class NativeBatchNorm1DModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
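
The new module exercises the same aten.batch_norm path as BatchNorm1DWith2DInputModule, but with static shapes and all five operands passed in explicitly. To run just this case against the StableHLO backend (flag spellings as of this snapshot; check e2e_testing/main.py if they have moved):

    python -m e2e_testing.main --config=stablehlo --filter BatchNorm1DStaticShapeModule_basic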

tools/torch-mlir-opt/torch-mlir-opt.cpp

@@ -13,6 +13,8 @@
 #include "torch-mlir/InitAll.h"
 
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
+#include "mhlo/IR/hlo_ops.h"
+#include "mhlo/transforms/passes.h"
 #include "stablehlo/dialect/Register.h"
 #endif
 
@@ -28,6 +30,12 @@ int main(int argc, char **argv) {
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
   mlir::stablehlo::registerAllDialects(registry);
+  registry.insert<mlir::mhlo::MhloDialect>();
+  mlir::mhlo::registerSymbolicShapeOptimizationPass();
+  mlir::mhlo::registerStablehloLegalizeToHloPass();
+  mlir::mhlo::registerChloLegalizeToHloPass();
+  mlir::mhlo::registerHloLegalizeToLinalgPass();
+  mlir::mhlo::registerTestUnfuseBatchNormPass();
 #endif
 
   return mlir::asMainReturnCode(mlir::MlirOptMain(
       argc, argv, "MLIR modular optimizer driver\n", registry));
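
With the mhlo dialect and these passes registered, torch-mlir-opt can exercise the same legalizations in isolation, e.g. running --mhlo-test-unfuse-batch-norm over a StableHLO-level reproducer when debugging this lowering (pass flag inferred from the backend pipeline string above).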