From 442ff4605c54f836afa17d2780bd9bddcc2a5ed8 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Tue, 15 Feb 2022 00:09:36 +0530
Subject: [PATCH] [LINALG] Decompose `aten.batch_norm` into
 `aten.native_batch_norm`

- This commit decomposes the `aten.batch_norm` op into the
  `aten.native_batch_norm` op, instead of lowering it directly to the
  `linalg.generic` op.
- It also adds runtime asserts in the `aten.native_batch_norm`
  decomposition to ensure that the shapes of the weight, bias,
  running_mean, and running_var tensors match the number of features.
- Since the `aten.native_batch_norm` op is not supported by the TOSA
  backend, all modules that depend on it now fail there, so they are
  removed from the TOSA `passing` set.
- It also moves `checkNotNone` into the Torch dialect utilities.

Signed-off-by: Gaurav Shukla
---
 e2e_testing/torchscript/norm_like.py          |  59 ++++++++-
 e2e_testing/torchscript/xfail_sets.py         |   4 -
 .../torch-mlir/Dialect/Torch/Utils/Utils.h    |   2 +
 .../TorchToLinalg/TorchToLinalg.cpp           | 116 ------------------
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  69 ++++++++++-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |   7 +-
 lib/Dialect/Torch/Utils/Utils.cpp             |   8 ++
 test/Dialect/Torch/decompose-complex-ops.mlir |  46 +++++++
 8 files changed, 179 insertions(+), 132 deletions(-)

diff --git a/e2e_testing/torchscript/norm_like.py b/e2e_testing/torchscript/norm_like.py
index 16c1412b9..748137f8b 100644
--- a/e2e_testing/torchscript/norm_like.py
+++ b/e2e_testing/torchscript/norm_like.py
@@ -89,6 +89,32 @@ def BatchNorm3DModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class BatchNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, x, weight, bias, running_mean, running_var):
+        return torch.ops.aten.batch_norm(
+            x, weight, bias, running_mean, running_var, training=False,
+            momentum=0.1, eps=0.00001, cudnn_enabled=False)
+
+
+@register_test_case(module_factory=lambda: BatchNormModule())
+def BatchNormModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 5, 3, 2), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
+
+# ==============================================================================
+
 class NativeBatchNorm1DModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -167,7 +193,7 @@ def NativeBatchNorm3DModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-class NativeBatchNormNoneWeightModule(torch.nn.Module):
+class NativeBatchNormWeightNoneModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -181,16 +207,39 @@ class NativeBatchNormNoneWeightModule(torch.nn.Module):
     ])
     def forward(self, x, bias, running_mean, running_var):
         return torch.ops.aten.native_batch_norm(
-            x, None, bias, running_mean, running_var, training=False,
-            momentum=0.1, eps=0.00001)
+            x, weight=None, bias=bias, running_mean=running_mean,
+            running_var=running_var, training=False, momentum=0.1, eps=0.00001)
 
 
-@register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule())
-def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: NativeBatchNormWeightNoneModule())
+def NativeBatchNormWeightNoneModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5))
 
 # ==============================================================================
 
+class NativeBatchNormWeightNoneBiasNoneModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, x, running_mean, running_var):
+        return torch.ops.aten.native_batch_norm(
+            x, weight=None, bias=None, running_mean=running_mean,
+            running_var=running_var, training=False, momentum=0.1, eps=0.00001)
+
+
+@register_test_case(module_factory=lambda: NativeBatchNormWeightNoneBiasNoneModule())
+def NativeBatchNormWeightNoneBiasNoneModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5), tu.rand(5), tu.rand(5))
+
+# ==============================================================================
+
 class NativeLayerNormModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 24965ab9e..99528fcb2 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -85,15 +85,11 @@ TOSA_PASS_SET = {
     "ElementwiseReciprocalModule_basic",
     "TypePromotionAlphaWiderModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
-    "BatchNorm1DModule_basic",
-    "BatchNorm2DModule_basic",
-    "BatchNorm3DModule_basic",
     "FlattenStaticModule_basic",
     "FlattenRank0Module_basic",
     "ElementwiseFlattenBroadcastModule_basic",
     "SquareModule_basic",
     "MaxPool2dStaticModule_basic",
-    "ResNet18StaticModule_basic",
     "NativeLayerNormModule4D_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
     "PermuteModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 63a8fc257..2607a8c3b 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -9,6 +9,7 @@
 #ifndef TORCHMLIR_DIALECT_TORCH_UTILS_H
 #define TORCHMLIR_DIALECT_TORCH_UTILS_H
 
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
@@ -22,6 +23,7 @@ int64_t toPositiveDim(int64_t dim, int64_t inputRank);
 bool isValidDim(int64_t dim, int64_t inputRank);
 bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
 torch_upstream::ScalarType getScalarTypeForType(Type type);
+LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v);
 
 } // namespace Torch
 } // namespace torch
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 83a519c00..fa8baa34c 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -68,15 +68,6 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
   return success();
 }
 
-static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
-                                  Value v) {
-  Type type = v.getType();
-  if (type.isa<Torch::OptionalType>() || type.isa<Torch::NoneType>() ||
-      type.isa<mlir::NoneType>())
-    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
-  return success();
-}
-
 // Generate IR: dim = dim >= 0 ? dim : dim + inputRank
 static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
                                   Value inputRank) {
@@ -604,111 +595,6 @@ static void createLinalgPayloadCalculationForGatherOps(
   b.create<linalg::YieldOp>(loc, extract);
 }
 
-namespace {
-class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *context = op->getContext();
-    Location loc = op->getLoc();
-    Value input = adaptor.input();
-    Value weight = adaptor.weight();
-    Value bias = adaptor.bias();
-    Value runningMean = adaptor.running_mean();
-    Value runningVar = adaptor.running_var();
-    Value training = adaptor.training();
-    Value eps = adaptor.eps();
-
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-      return failure();
-
-    // TODO: Handle the None cases for the optional parameters:
-    // weight, bias.
-    if (failed(checkNotNone(rewriter, op, weight)) ||
-        failed(checkNotNone(rewriter, op, bias)) ||
-        failed(checkNotNone(rewriter, op, runningMean)) ||
-        failed(checkNotNone(rewriter, op, runningVar)))
-      return failure();
-
-    auto inputType = input.getType().cast<RankedTensorType>();
-    auto weightType = weight.getType().cast<RankedTensorType>();
-    auto biasType = bias.getType().cast<RankedTensorType>();
-    auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
-    auto runningVarType = runningVar.getType().cast<RankedTensorType>();
-
-    auto inputRank = inputType.getRank();
-    if (inputRank <= 2)
-      return rewriter.notifyMatchFailure(
-          op, "input should have rank larger than 2");
-
-    if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
-        runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "expect weight, bias, running_mean and running_var to be rank 1");
-    }
-
-    // TODO: Add support for training.
-    auto constFalse = rewriter.create<arith::ConstantOp>(
-        loc, IntegerAttr::get(IntegerType::get(context, 1), 0));
-    auto trainingFalse = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, training, constFalse);
-    rewriter.create<AssertOp>(
-        loc, trainingFalse,
-        rewriter.getStringAttr("training is not supported for now"));
-
-    // num_features – C from an expected input of size (N,C,D,H,W ...)
-    Value numFeatures = rewriter.create<tensor::DimOp>(loc, input, 1);
-    auto contractingDim0EqualsNumFeatures = [&](Value v) {
-      auto dim0 = rewriter.create<tensor::DimOp>(loc, v, 0);
-      auto dim0Equal = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::eq, numFeatures, dim0);
-      rewriter.create<AssertOp>(
-          loc, dim0Equal,
-          rewriter.getStringAttr(
-              "expect the size of dim 0 equal to the number of features"));
-    };
-    contractingDim0EqualsNumFeatures(weight);
-    contractingDim0EqualsNumFeatures(bias);
-    contractingDim0EqualsNumFeatures(runningMean);
-    contractingDim0EqualsNumFeatures(runningVar);
-
-    auto indexingMap = AffineMap::get(
-        /*dimCount=*/inputRank,
-        /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context);
-    SmallVector<AffineMap> indexingMaps = {
-        rewriter.getMultiDimIdentityMap(inputRank), // input
-        indexingMap,                                // weight
-        indexingMap,                                // bias
-        indexingMap,                                // runningMean
-        indexingMap,                                // runningVar
-        rewriter.getMultiDimIdentityMap(inputRank), // output
-    };
-    SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
-    Value batchNorm =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, input.getType(),
-                ValueRange{input, weight, bias, runningMean, runningVar}, input,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value input = args[0], weight = args[1], bias = args[2],
-                        mean = args[3], var = args[4];
-                  Value result = createLinalgPayloadCalculationForNormOps(
-                      b, loc, var.getType(), input, mean, var, eps, weight,
-                      bias);
-                  b.create<linalg::YieldOp>(loc, result);
-                })
-            .getResult(0);
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, batchNorm);
-    return success();
-  }
-};
-} // namespace
-
 // For layernorm, the mean and standard-deviation are calculated separately over
 // the last certain number dimensions which have to be of the shape specified by
 // normalized_shape.
@@ -4628,8 +4514,6 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
-    target.addIllegalOp<AtenBatchNormOp>();
-    patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
     target.addIllegalOp<
         AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
         AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 1b2745690..fdbc52a39 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -900,7 +900,9 @@ class DecomposeAtenAddCLikeOp : public OpRewritePattern {
     return success();
   }
 };
+} // namespace
 
+namespace {
 class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenLayerNormOp op,
@@ -929,6 +931,40 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenBatchNormOp : public OpRewritePattern<AtenBatchNormOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBatchNormOp op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Add support for `training` mode.
+    bool training = false;
+    if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
+        training)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: training mode is not supported");
+
+    // In inference mode, the shapes of the `mean` and `invstd` outputs
+    // should be {0}.
+    BaseTensorType tensorType = op.getType().cast<BaseTensorType>();
+    if (!tensorType.hasDtype() ||
+        !tensorType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: non-floating point type input");
+    Type emptyType =
+        tensorType.getWithSizesAndDtype({0}, tensorType.getDtype());
+
+    // The first output tensor of the `AtenNativeBatchNormOp` is essentially
+    // the `AtenBatchNormOp` result.
+    auto nativeBatchNorm = rewriter.create<AtenNativeBatchNormOp>(
+        op.getLoc(), op.getType(), /*meanType=*/emptyType,
+        /*invStdType=*/emptyType, op.input(), op.weight(), op.bias(),
+        op.running_mean(), op.running_var(), op.training(), op.momentum(),
+        op.eps());
+    rewriter.replaceOp(op, nativeBatchNorm.getResult(0));
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
 class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
@@ -1027,6 +1063,14 @@ class DecomposeAtenNativeBatchNormOp
     Value runningVar = op.running_var();
     Value eps = op.eps();
 
+    // TODO: Add support for optional type parameters.
+    if (weight.getType().isa<Torch::NoneType>() ||
+        bias.getType().isa<Torch::NoneType>() ||
+        runningMean.getType().isa<Torch::NoneType>() ||
+        runningVar.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: optional type arg is not supported");
+
     // TODO: Add support for `training` mode.
     bool training = false;
     if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
@@ -1053,13 +1097,24 @@ class DecomposeAtenNativeBatchNormOp
       return rewriter.notifyMatchFailure(
           op, "expected running_mean and running_var to be rank 1");
 
+    // The shape of `runningMean` and `runningVar` must be (numFeatures). Here,
+    // 'numFeatures' is C from an expected 'input' of size (N,C,D?,H?,W?).
     Value zero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
     Value one =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
     Value numFeatures = rewriter.create<AtenSizeIntOp>(loc, input, /*dim=*/one);
-    // TODO: Add Runtime Asserts to check the shape of weight, bias,
-    // running_mean and running_var to be (numFeatures).
+    auto dim0EqualsNumFeatures = [&](Value v) {
+      Value dim0 = rewriter.create<AtenSizeIntOp>(loc, v, /*dim=*/zero);
+      Value eqCmp = rewriter.create<AtenEqIntOp>(loc, BoolType::get(context),
+                                                 dim0, numFeatures);
+      rewriter.create<RuntimeAssertOp>(
+          loc, eqCmp,
+          rewriter.getStringAttr("size of the 0th dimension must be equal to "
+                                 "the number of features"));
+    };
+    dim0EqualsNumFeatures(runningMean);
+    dim0EqualsNumFeatures(runningVar);
 
     // The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?)
     // to make it broadcast-compatible with (N, C, D?, H?, W?).
@@ -1097,18 +1152,22 @@ class DecomposeAtenNativeBatchNormOp
     // 3. output = normalizedInput * weight + bias
     Value batchNormOutput = normalizedInput;
     if (!weight.getType().isa<Torch::NoneType>()) {
-      // Rank of `weight` must be exactly 1.
+      // The shape of the `weight` tensor must be (numFeatures).
       if (getTensorRank(weight) != 1)
         return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
+      dim0EqualsNumFeatures(weight);
+
       weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
                                            runningStatsSizeList);
       batchNormOutput = rewriter.create<AtenMulTensorOp>(
           loc, batchNormOutput.getType(), batchNormOutput, weight);
     }
     if (!bias.getType().isa<Torch::NoneType>()) {
-      // Rank of `bias` must be exactly 1.
+      // The shape of the `bias` tensor must be (numFeatures).
       if (getTensorRank(bias) != 1)
         return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
+      dim0EqualsNumFeatures(bias);
+
       bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
                                          runningStatsSizeList);
       batchNormOutput = rewriter.create<AtenAddTensorOp>(
@@ -1219,6 +1278,8 @@ class DecomposeComplexOpsPass
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
+    target.addIllegalOp<AtenBatchNormOp>();
+    patterns.add<DecomposeAtenBatchNormOp>(context);
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index fe4ceed93..566c64a9e 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -1878,9 +1878,10 @@ ChangeResult TypeAnalyzer::visitAtenNativeBatchNormOp(
   meanKnowledge.dtype = input.dtype;
   invStdKnowledge.dtype = input.dtype;
 
-  // Rank of the input tensor must be greater than or equal to 2. The size of
-  // the input tensor as well as the output tensor should be (N, C, D?, H?, W?).
-  // The running_mean, running_var, weight, and bias should be of size (C).
+  // Rank of the input tensor must be greater than or equal to 2. The shape
+  // of the input tensor as well as the batch norm output tensor should be
+  // (N, C, D?, H?, W?). In inference mode, the mean and inv-std outputs should
+  // be empty tensors, whereas they should be of shape (C) in training mode.
   bool training = false;
   if (matchPattern(op.training(), m_TorchConstantBool(&training)) &&
       input.hasSizes && input.sizes.size() >= 2) {
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 4198bb4b2..9ca6da38d 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -46,6 +46,14 @@ ScalarType getScalarTypeForType(Type type) {
   llvm::report_fatal_error("unhandled type for getScalarTypeForType");
 }
 
+LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
+  Type type = v.getType();
+  if (type.isa<Torch::OptionalType>() || type.isa<Torch::NoneType>() ||
+      type.isa<mlir::NoneType>())
+    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
+  return success();
+}
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 003ec6b1b..c5d9b952d 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -448,3 +448,49 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
   %0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+// CHECK-LABEL: func @torch.aten.batch_norm(
+// CHECK-SAME:    %[[INPUT:.*]]: !torch.vtensor<[?,?,?,?],f32>,
+// CHECK-SAME:    %[[WEIGHT:.*]]: !torch.vtensor<[?],f32>, %[[BIAS:.*]]: !torch.vtensor<[?],f32>, %[[RMEAN:.*]]: !torch.vtensor<[?],f32>, %[[RVAR:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK:   %[[EPS:.*]] = torch.constant.float 1.000000e-05
+// CHECK:   %[[MOM:.*]] = torch.constant.float 1.000000e-01
+// CHECK:   %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:   %[[INT0:.*]] = torch.constant.int 0
+// CHECK:   %[[INT1:.*]] = torch.constant.int 1
+// CHECK:   %[[INPUT_DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[INT1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+// CHECK:   %[[RMEAN_DIM0:.*]] = torch.aten.size.int %[[RMEAN]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
+// CHECK:   %[[PRED_MEAN:.*]] = torch.aten.eq.int %[[RMEAN_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
+// CHECK:   torch.runtime.assert %[[PRED_MEAN]], "size of the 0th dimension must be equal to the number of features"
+// CHECK:   %[[RVAR_DIM0:.*]] = torch.aten.size.int %[[RVAR]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
+// CHECK:   %[[PRED_VAR:.*]] = torch.aten.eq.int %[[RVAR_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
+// CHECK:   torch.runtime.assert %[[PRED_VAR]], "size of the 0th dimension must be equal to the number of features"
+// CHECK:   %[[SIZE_LIST:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INPUT_DIM1]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK:   %[[RMEAN_VIEW:.*]] = torch.aten.view %[[RMEAN]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[1,?,1,1],f32>
+// CHECK:   %[[RVAR_VIEW:.*]] = torch.aten.view %[[RVAR]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[1,?,1,1],f32>
+// CHECK:   %[[X_SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[RMEAN_VIEW]], %[[INT1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:   %[[VAR_EPS:.*]] = torch.aten.add.Scalar %[[RVAR_VIEW]], %[[EPS]], %[[INT1]] : !torch.vtensor<[1,?,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,?,1,1],f32>
+// CHECK:   %[[SQRT_VAR_EPS:.*]] = torch.aten.rsqrt %[[VAR_EPS]] : !torch.vtensor<[1,?,1,1],f32> -> !torch.vtensor<[1,?,1,1],f32>
+// CHECK:   %[[NORM_INPUT:.*]] = torch.aten.mul.Tensor %[[X_SUB_MEAN]], %[[SQRT_VAR_EPS]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:   %[[WEIGHT_DIM0:.*]] = torch.aten.size.int %[[WEIGHT]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
+// CHECK:   %[[PRED_WEIGHT:.*]] = torch.aten.eq.int %[[WEIGHT_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
+// CHECK:   torch.runtime.assert %[[PRED_WEIGHT]], "size of the 0th dimension must be equal to the number of features"
+// CHECK:   %[[WEIGHT_VIEW:.*]] = torch.aten.view %[[WEIGHT]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[1,?,1,1],f32>
+// CHECK:   %[[SCALED_INPUT:.*]] = torch.aten.mul.Tensor %[[NORM_INPUT]], %[[WEIGHT_VIEW]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:   %[[BIAS_DIM0:.*]] = torch.aten.size.int %[[BIAS]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
+// CHECK:   %[[PRED_BIAS:.*]] = torch.aten.eq.int %[[BIAS_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
+// CHECK:   torch.runtime.assert %[[PRED_BIAS]], "size of the 0th dimension must be equal to the number of features"
+// CHECK:   %[[BIAS_VIEW:.*]] = torch.aten.view %[[BIAS]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[1,?,1,1],f32>
+// CHECK:   %[[OUTPUT:.*]] = torch.aten.add.Tensor %[[SCALED_INPUT]], %[[BIAS_VIEW]], %[[INT1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK:   %[[ZERO_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<int>
+// CHECK:   %[[NONE:.*]] = torch.constant.none
+// CHECK:   %[[MEAN_OUT:.*]] = torch.aten.empty.memory_format %[[ZERO_LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[0],f32>
+// CHECK:   %[[INV_STD_OUT:.*]] = torch.aten.empty.memory_format %[[ZERO_LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[0],f32>
+// CHECK:   return %[[OUTPUT]] : !torch.vtensor<[?,?,?,?],f32>
+func @torch.aten.batch_norm(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>, %arg2: !torch.vtensor<[?],f32>, %arg3: !torch.vtensor<[?],f32>, %arg4: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %float1.000000e-05 = torch.constant.float 1.000000e-05
+  %float1.000000e-01 = torch.constant.float 1.000000e-01
+  %false = torch.constant.bool false
+  %0 = torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %false, %float1.000000e-01, %float1.000000e-05, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
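
--
Illustrative note, not part of the diff: after only the new
`DecomposeAtenBatchNormOp` pattern has applied, i.e. before
`DecomposeAtenNativeBatchNormOp` rewrites the op further, the IR for the
test case above should look roughly like the following sketch (the SSA
value names here are hypothetical):

  %result:3 = torch.aten.native_batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %false, %float1.000000e-01, %float1.000000e-05 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.bool, !torch.float, !torch.float -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[0],f32>, !torch.vtensor<[0],f32>
  return %result#0 : !torch.vtensor<[?,?,?,?],f32>

Only the first result replaces the `aten.batch_norm` value; the mean and
inv-std results are empty `[0]` tensors in inference mode, matching the
result types refined in RefineTypes.cpp. The fully decomposed end state is
what the FileCheck test above verifies when the file is run through
`torch-mlir-opt -torch-decompose-complex-ops`.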