mirror of https://github.com/llvm/torch-mlir
parent
e98f2ba04a
commit
6f7d9e83df
|
@ -213,6 +213,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"BatchNorm1DWith2DInputModule_basic",
|
||||
"BatchNorm2DModule_basic",
|
||||
"BatchNorm3DModule_basic",
|
||||
"BatchNorm1DStaticShapeModule_basic",
|
||||
"ElementwiseAddScalarFloatModule_basic",
|
||||
"ElementwiseAddScalarInt64Module_basic",
|
||||
"ElementwiseAddScalarIntModule_basic",
|
||||
|
@ -317,6 +318,12 @@ STABLEHLO_PASS_SET = {
|
|||
"ArangeStartStepIntModule_basic",
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"BatchMlpLayerModule_basic",
|
||||
"BatchNorm1DModule_basic",
|
||||
"BatchNorm1DWith2DInputModule_basic",
|
||||
"BatchNorm2DModule_basic",
|
||||
"BatchNorm3DModule_basic",
|
||||
"BatchNorm1DStaticShapeModule_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
"BmmModule_basic",
|
||||
"BroadcastToModule_basic",
|
||||
"BroadcastToSameRankStaticModule_basic",
|
||||
|
@ -805,6 +812,7 @@ TOSA_PASS_SET = {
|
|||
"BatchNorm1DWith2DInputModule_basic",
|
||||
"BatchNorm2DModule_basic",
|
||||
"BatchNorm3DModule_basic",
|
||||
"BatchNorm1DStaticShapeModule_basic",
|
||||
"FlattenStaticModule_basic",
|
||||
"FlattenRank0Module_basic",
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
|
|
|
@ -931,10 +931,10 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
Value momentum = adaptor.getMomentum();
|
||||
(void)momentum;
|
||||
|
||||
if (inputTy.getRank() <= 2) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"input should have rank larger than 2");
|
||||
}
|
||||
// handle feature index, see torch's BatchNorm1d, BatchNorm2d, BatchNorm3d,
|
||||
// all of NC, NCL, NCHW, NCDHW's feature index is 1.
|
||||
int64_t feature_index = 1;
|
||||
|
||||
if (!inputTy.getElementType().template isa<mlir::FloatType>()) {
|
||||
return op.emitError("only input tensor of float type is supported");
|
||||
}
|
||||
|
@ -1020,7 +1020,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
|
||||
return success();
|
||||
} else {
|
||||
|
@ -1037,7 +1037,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
||||
runningMean, runningVar,
|
||||
// 'epsilon' must satisfy constraint: 32-bit float attribute.
|
||||
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
|
||||
rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, output);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -44,5 +44,6 @@ void mlir::torch::registerAllPasses() {
|
|||
mlir::mhlo::registerStablehloLegalizeToHloPass();
|
||||
mlir::mhlo::registerChloLegalizeToHloPass();
|
||||
mlir::mhlo::registerHloLegalizeToLinalgPass();
|
||||
mlir::mhlo::registerTestUnfuseBatchNormPass();
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
}
|
||||
|
|
|
@ -131,6 +131,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
|
|||
# emit things in that form from the high level (e.g. single linalg-generic).
|
||||
# Other backends are likely to benefit more.
|
||||
"func.func(linalg-fuse-elementwise-ops)",
|
||||
"convert-shape-to-std",
|
||||
# Bufferize.
|
||||
"func.func(scf-bufferize)",
|
||||
"func.func(tm-tensor-bufferize)",
|
||||
|
|
|
@ -40,7 +40,7 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend):
|
|||
"""
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
"builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))",
|
||||
"builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,mhlo-test-unfuse-batch-norm,canonicalize,hlo-legalize-to-linalg,canonicalize))",
|
||||
"Lowering StableHLO to Linalg-on-Tensors",
|
||||
)
|
||||
return self.refbackend.compile(imported_module)
|
||||
|
|
|
@ -115,6 +115,32 @@ def BatchNorm3DModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class BatchNorm1DStaticShapeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 5], torch.float32, True),
|
||||
([5], torch.float32, True),
|
||||
([5], torch.float32, True),
|
||||
([5], torch.float32, True),
|
||||
([5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, weight, bias, running_mean, running_var):
|
||||
return torch.ops.aten.batch_norm(
|
||||
x, weight, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001, cudnn_enabled=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BatchNorm1DStaticShapeModule())
|
||||
def BatchNorm1DStaticShapeModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(2, 5), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class NativeBatchNorm1DModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
#include "torch-mlir/InitAll.h"
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "mhlo/transforms/passes.h"
|
||||
#include "stablehlo/dialect/Register.h"
|
||||
#endif
|
||||
|
||||
|
@ -28,6 +30,12 @@ int main(int argc, char **argv) {
|
|||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
mlir::stablehlo::registerAllDialects(registry);
|
||||
registry.insert<mlir::mhlo::MhloDialect>();
|
||||
mlir::mhlo::registerSymbolicShapeOptimizationPass();
|
||||
mlir::mhlo::registerStablehloLegalizeToHloPass();
|
||||
mlir::mhlo::registerChloLegalizeToHloPass();
|
||||
mlir::mhlo::registerHloLegalizeToLinalgPass();
|
||||
mlir::mhlo::registerTestUnfuseBatchNormPass();
|
||||
#endif
|
||||
return mlir::asMainReturnCode(mlir::MlirOptMain(
|
||||
argc, argv, "MLIR modular optimizer driver\n", registry));
|
||||
|
|
Loading…
Reference in New Issue