[LINALG] Decompose `aten.batch_norm` into `aten.native_batch_norm`

- This commit decomposes the `aten.batch_norm` op into the
  `aten.native_batch_norm` op, instead of lowering it to the
  `linalg.generic` op.
- It also adds runtime asserts in the `aten.native_batch_norm` decomposition
  to ensure that the shapes of the weight, bias, running_mean, and
  running_var tensors match the number of features.
- Since the `aten.native_batch_norm` op is not supported by the TOSA backend,
  all the modules that depend on it fail on that backend and have therefore
  been removed from the TOSA `passing` set.
- It also moves `checkNotNone` to the Torch dialect utilities.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/590/head
Gaurav Shukla 2022-02-15 00:09:36 +05:30
parent c60468f141
commit 442ff4605c
8 changed files with 179 additions and 132 deletions
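As context for the change above, here is a minimal eager-mode PyTorch sketch (illustrative only, not part of this commit) of the equivalence the decomposition relies on: in inference mode, the value produced by `aten.batch_norm` is the first result of `aten.native_batch_norm`, so `DecomposeAtenBatchNormOp` can replace the former with result 0 of the latter. The shapes and argument values mirror the `BatchNormModule` e2e test in this commit.

```python
import torch

x = torch.rand(2, 5, 3, 2)
weight, bias = torch.rand(5), torch.rand(5)
running_mean, running_var = torch.rand(5), torch.rand(5)

# `aten.batch_norm` in inference mode.
bn = torch.ops.aten.batch_norm(
    x, weight, bias, running_mean, running_var, training=False,
    momentum=0.1, eps=1e-5, cudnn_enabled=False)

# `aten.native_batch_norm` returns (output, save_mean, save_invstd); only the
# first result corresponds to the `aten.batch_norm` value.
native_out, _, _ = torch.ops.aten.native_batch_norm(
    x, weight, bias, running_mean, running_var, training=False,
    momentum=0.1, eps=1e-5)

assert torch.allclose(bn, native_out)
```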

View File

@@ -89,6 +89,32 @@ def BatchNorm3DModule_basic(module, tu: TestUtils):
# ==============================================================================
class BatchNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, weight, bias, running_mean, running_var):
return torch.ops.aten.batch_norm(
x, weight, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001, cudnn_enabled=False)
@register_test_case(module_factory=lambda: BatchNormModule())
def BatchNormModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 3, 2), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
# ==============================================================================
class NativeBatchNorm1DModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -167,7 +193,7 @@ def NativeBatchNorm3DModule_basic(module, tu: TestUtils):
# ==============================================================================
class NativeBatchNormNoneWeightModule(torch.nn.Module):
class NativeBatchNormWeightNoneModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -181,16 +207,39 @@ class NativeBatchNormNoneWeightModule(torch.nn.Module):
])
def forward(self, x, bias, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, None, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001)
x, weight=None, bias=bias, running_mean=running_mean,
running_var=running_var, training=False, momentum=0.1, eps=0.00001)
@register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule())
def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: NativeBatchNormWeightNoneModule())
def NativeBatchNormWeightNoneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5))
# ==============================================================================
class NativeBatchNormWeightNoneBiasNoneModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, weight=None, bias=None, running_mean=running_mean,
running_var=running_var, training=False, momentum=0.1, eps=0.00001)
@register_test_case(module_factory=lambda: NativeBatchNormWeightNoneBiasNoneModule())
def NativeBatchNormWeightNoneBiasNoneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5), tu.rand(5), tu.rand(5))
# ==============================================================================
class NativeLayerNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
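The new `NativeBatchNormWeightNoneModule` and `NativeBatchNormWeightNoneBiasNoneModule` tests above exercise the optional-argument paths of the decomposition. As a rough eager-mode sketch (illustrative, not part of the commit), the weight-and-bias-None case reduces to the plain normalization `(x - running_mean) / sqrt(running_var + eps)`, which is the part the decomposition always emits before the optional scale and shift:

```python
import torch

x = torch.rand(2, 5)
running_mean, running_var = torch.rand(5), torch.rand(5)
eps = 1e-5

out, _, _ = torch.ops.aten.native_batch_norm(
    x, weight=None, bias=None, running_mean=running_mean,
    running_var=running_var, training=False, momentum=0.1, eps=eps)

# With weight and bias omitted, only the normalization step remains.
expected = (x - running_mean) * torch.rsqrt(running_var + eps)
assert torch.allclose(out, expected)
```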

View File

@@ -85,15 +85,11 @@ TOSA_PASS_SET = {
"ElementwiseReciprocalModule_basic",
"TypePromotionAlphaWiderModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"FlattenStaticModule_basic",
"FlattenRank0Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
"MaxPool2dStaticModule_basic",
"ResNet18StaticModule_basic",
"NativeLayerNormModule4D_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"PermuteModule_basic",

View File

@@ -9,6 +9,7 @@
#ifndef TORCHMLIR_DIALECT_TORCH_UTILS_H
#define TORCHMLIR_DIALECT_TORCH_UTILS_H
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
@@ -22,6 +23,7 @@ int64_t toPositiveDim(int64_t dim, int64_t inputRank);
bool isValidDim(int64_t dim, int64_t inputRank);
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
torch_upstream::ScalarType getScalarTypeForType(Type type);
LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v);
} // namespace Torch
} // namespace torch

View File

@@ -68,15 +68,6 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
return success();
}
static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
Value v) {
Type type = v.getType();
if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
type.isa<mlir::NoneType>())
return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
return success();
}
// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
Value inputRank) {
@@ -604,111 +595,6 @@ static void createLinalgPayloadCalculationForGatherOps(
b.create<linalg::YieldOp>(loc, extract);
}
namespace {
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
Location loc = op->getLoc();
Value input = adaptor.input();
Value weight = adaptor.weight();
Value bias = adaptor.bias();
Value runningMean = adaptor.running_mean();
Value runningVar = adaptor.running_var();
Value training = adaptor.training();
Value eps = adaptor.eps();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Handle the None cases for the optional parameters:
// weight, bias.
if (failed(checkNotNone(rewriter, op, weight)) ||
failed(checkNotNone(rewriter, op, bias)) ||
failed(checkNotNone(rewriter, op, runningMean)) ||
failed(checkNotNone(rewriter, op, runningVar)))
return failure();
auto inputType = input.getType().cast<RankedTensorType>();
auto weightType = weight.getType().cast<RankedTensorType>();
auto biasType = bias.getType().cast<RankedTensorType>();
auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
auto runningVarType = runningVar.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
if (inputRank <= 2)
return rewriter.notifyMatchFailure(
op, "input should have rank larger than 2");
if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
return rewriter.notifyMatchFailure(
op, "expect weight, bias, running_mean and running_var to be rank 1");
}
// TODO: Add support for training.
auto constFalse = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(IntegerType::get(context, 1), 0));
auto trainingFalse = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, training, constFalse);
rewriter.create<AssertOp>(
loc, trainingFalse,
rewriter.getStringAttr("training is not supported for now"));
// num_features C from an expected input of size (N,C,D,H,W ...)
Value numFeatures = rewriter.create<tensor::DimOp>(loc, input, 1);
auto contractingDim0EqualsNumFeatures = [&](Value v) {
auto dim0 = rewriter.create<tensor::DimOp>(loc, v, 0);
auto dim0Equal = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, numFeatures, dim0);
rewriter.create<AssertOp>(
loc, dim0Equal,
rewriter.getStringAttr(
"expect the size of dim 0 equal to the number of features"));
};
contractingDim0EqualsNumFeatures(weight);
contractingDim0EqualsNumFeatures(bias);
contractingDim0EqualsNumFeatures(runningMean);
contractingDim0EqualsNumFeatures(runningVar);
auto indexingMap = AffineMap::get(
/*dimCount=*/inputRank,
/*symbolCount=*/0, rewriter.getAffineDimExpr(1), context);
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(inputRank), // input
indexingMap, // weight
indexingMap, // bias
indexingMap, // runningMean
indexingMap, // runningVar
rewriter.getMultiDimIdentityMap(inputRank), // output
};
SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
Value batchNorm =
rewriter
.create<linalg::GenericOp>(
loc, input.getType(),
ValueRange{input, weight, bias, runningMean, runningVar}, input,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], weight = args[1], bias = args[2],
mean = args[3], var = args[4];
Value result = createLinalgPayloadCalculationForNormOps(
b, loc, var.getType(), input, mean, var, eps, weight,
bias);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, batchNorm);
return success();
}
};
} // namespace
// For layernorm, the mean and standard-deviation are calculated separately over
// the last certain number dimensions which have to be of the shape specified by
// normalized_shape.
@@ -4628,8 +4514,6 @@ public:
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenLinearOp>();
patterns.add<ConvertAtenLinearOp>(typeConverter, context);
target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
target.addIllegalOp<
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,

View File

@@ -900,7 +900,9 @@ class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
return success();
}
};
} // namespace
namespace {
class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLayerNormOp op,
@@ -929,6 +931,40 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
};
} // namespace
namespace {
class DecomposeAtenBatchNormOp : public OpRewritePattern<AtenBatchNormOp> {
using OpRewritePattern<AtenBatchNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBatchNormOp op,
PatternRewriter &rewriter) const override {
// TODO: Add support for `training` mode.
bool training = false;
if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
training)
return rewriter.notifyMatchFailure(
op, "unimplemented: training mode is not supported");
// The `mean` and `invstd` outputs shape should be {0} in the inference
// mode.
BaseTensorType tensorType = op.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: non-floating point type input");
Type emptyType =
tensorType.getWithSizesAndDtype({0}, tensorType.getDtype());
// The first output tensor of the `AtenNativeBatchNormOp` is essentially
// `AtenBatchNormOp` result.
auto nativeBatchNorm = rewriter.create<AtenNativeBatchNormOp>(
op.getLoc(), op.getType(), /*meanType=*/emptyType,
/*invStdType=*/emptyType, op.input(), op.weight(), op.bias(),
op.running_mean(), op.running_var(), op.training(), op.momentum(),
op.eps());
rewriter.replaceOp(op, nativeBatchNorm.getResult(0));
return success();
}
};
} // namespace
namespace {
// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
@@ -1027,6 +1063,14 @@ class DecomposeAtenNativeBatchNormOp
Value runningVar = op.running_var();
Value eps = op.eps();
// TODO: Add support for optional type parameters.
if (weight.getType().isa<OptionalType>() ||
bias.getType().isa<OptionalType>() ||
runningMean.getType().isa<OptionalType>() ||
runningVar.getType().isa<OptionalType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: optional type arg is not supported");
// TODO: Add support for `training` mode.
bool training = false;
if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
@@ -1053,13 +1097,24 @@ class DecomposeAtenNativeBatchNormOp
return rewriter.notifyMatchFailure(
op, "expected running_mean and running_var to be rank 1");
// The shape of `runningMean` and `runningVar` must be (numFeatures). Here,
// 'numFeatures' is C from an expected 'input' of size (N,C,D?,H?,W?).
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value numFeatures = rewriter.create<AtenSizeIntOp>(loc, input, /*dim=*/one);
// TODO: Add Runtime Asserts to check the shape of weight, bias,
// running_mean and running_var to be (numFeatures).
auto dim0EqualsNumFeatures = [&](Value v) {
Value dim0 = rewriter.create<AtenSizeIntOp>(loc, v, /*dim=*/zero);
Value eqCmp = rewriter.create<AtenEqIntOp>(loc, BoolType::get(context),
dim0, numFeatures);
rewriter.create<RuntimeAssertOp>(
loc, eqCmp,
rewriter.getStringAttr("size of the 0th dimension must be equal to "
"the number of features"));
};
dim0EqualsNumFeatures(runningMean);
dim0EqualsNumFeatures(runningVar);
// The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?)
// to make it broadcast-compatible with (N, C, D?, H?, W?).
@@ -1097,18 +1152,22 @@ class DecomposeAtenNativeBatchNormOp
// 3. output = normalizedInput * weight + bias
Value batchNormOutput = normalizedInput;
if (!weight.getType().isa<Torch::NoneType>()) {
// Rank of `weight` must be exactly 1.
// The shape of the `weight` tensor must be (numFeatures).
if (getTensorRank(weight) != 1)
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
dim0EqualsNumFeatures(weight);
weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
runningStatsSizeList);
batchNormOutput = rewriter.create<AtenMulTensorOp>(
loc, batchNormOutput.getType(), batchNormOutput, weight);
}
if (!bias.getType().isa<Torch::NoneType>()) {
// Rank of `bias` must be exactly 1.
// The shape of the `bias` tensor must be (numFeatures).
if (getTensorRank(bias) != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
dim0EqualsNumFeatures(bias);
bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
runningStatsSizeList);
batchNormOutput = rewriter.create<AtenAddTensorOp>(
@@ -1219,6 +1278,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeBatchNormOp>();
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
target.addIllegalOp<AtenBatchNormOp>();
patterns.add<DecomposeAtenBatchNormOp>(context);
patterns.add<DecomposeAtenArangeOp>(context);
target.addIllegalOp<AtenArangeOp>();
patterns.add<DecomposeAtenArangeStartOp>(context);
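The decomposition and the new runtime asserts added above can be restated as the following Python reference (an illustrative sketch under the inference-mode assumption; the function and helper names are made up, not part of the pass):

```python
import torch

def native_batch_norm_inference_reference(x, weight, bias,
                                           running_mean, running_var, eps):
    # num_features is C from an expected input of shape (N, C, D?, H?, W?).
    num_features = x.shape[1]

    def check_dim0(t):
        # Mirrors the emitted torch.runtime.assert ops.
        assert t.shape[0] == num_features, (
            "size of the 0th dimension must be equal to the number of features")

    check_dim0(running_mean)
    check_dim0(running_var)

    # Reshape the running stats to (1, C, 1, ..., 1) so they broadcast
    # against (N, C, D?, H?, W?).
    stats_shape = [1, num_features] + [1] * (x.dim() - 2)
    mean = running_mean.view(stats_shape)
    var = running_var.view(stats_shape)

    # normalized = (x - mean) * rsqrt(var + eps)
    out = (x - mean) * torch.rsqrt(var + eps)

    # Optional scale and shift, each guarded by the same dim-0 check.
    if weight is not None:
        check_dim0(weight)
        out = out * weight.view(stats_shape)
    if bias is not None:
        check_dim0(bias)
        out = out + bias.view(stats_shape)
    return out
```

The structure matches the IR checked by the new `torch.aten.batch_norm` FileCheck test at the end of this commit: runtime asserts on `running_mean` and `running_var`, broadcast views, `rsqrt`, and the optional weight/bias steps.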

View File

@@ -1878,9 +1878,10 @@ ChangeResult TypeAnalyzer::visitAtenNativeBatchNormOp(
meanKnowledge.dtype = input.dtype;
invStdKnowledge.dtype = input.dtype;
// Rank of the input tensor must be greater than or equal to 2. The size of
// the input tensor as well as the output tensor should be (N, C, D?, H?, W?).
// The running_mean, running_var, weight, and bias should be of size (C).
// Rank of the input tensor must be greater than or equal to 2. The shape
// of the input tensor as well as the batch norm output tensor should be
// (N, C, D?, H?, W?). In inference mode, the mean and inv-std outputs should
// be empty tensors, whereas they should be of shape (C) in the training mode.
bool training = false;
if (matchPattern(op.training(), m_TorchConstantBool(&training)) &&
input.hasSizes && input.sizes.size() >= 2) {

View File

@@ -46,6 +46,14 @@ ScalarType getScalarTypeForType(Type type) {
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}
LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
Type type = v.getType();
if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
type.isa<mlir::NoneType>())
return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
return success();
}
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@@ -448,3 +448,49 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
%0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.batch_norm(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?,?],f32>,
// CHECK-SAME: %[[WEIGHT:.*]]: !torch.vtensor<[?],f32>, %[[BIAS:.*]]: !torch.vtensor<[?],f32>, %[[RMEAN:.*]]: !torch.vtensor<[?],f32>, %[[RVAR:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[EPS:.*]] = torch.constant.float 1.000000e-05
// CHECK: %[[MOM:.*]] = torch.constant.float 1.000000e-01
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INPUT_DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[INT1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[RMEAN_DIM0:.*]] = torch.aten.size.int %[[RMEAN]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
// CHECK: %[[PRED_MEAN:.*]] = torch.aten.eq.int %[[RMEAN_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[PRED_MEAN]], "size of the 0th dimension must be equal to the number of features"
// CHECK: %[[RVAR_DIM0:.*]] = torch.aten.size.int %[[RVAR]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
// CHECK: %[[PRED_VAR:.*]] = torch.aten.eq.int %[[RVAR_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[PRED_VAR]], "size of the 0th dimension must be equal to the number of features"
// CHECK: %[[SIZE_LIST:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INPUT_DIM1]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[RMEAN_VIEW:.*]] = torch.aten.view %[[RMEAN]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,?,1,1],f32>
// CHECK: %[[RVAR_VIEW:.*]] = torch.aten.view %[[RVAR]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,?,1,1],f32>
// CHECK: %[[X_SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[RMEAN_VIEW]], %[[INT1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[VAR_EPS:.*]] = torch.aten.add.Scalar %[[RVAR_VIEW]], %[[EPS]], %[[INT1]] : !torch.vtensor<[1,?,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,?,1,1],f32>
// CHECK: %[[SQRT_VAR_EPS:.*]] = torch.aten.rsqrt %[[VAR_EPS]] : !torch.vtensor<[1,?,1,1],f32> -> !torch.vtensor<[1,?,1,1],f32>
// CHECK: %[[NORM_INPUT:.*]] = torch.aten.mul.Tensor %[[X_SUB_MEAN]], %[[SQRT_VAR_EPS]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[WEIGHT_DIM0:.*]] = torch.aten.size.int %[[WEIGHT]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
// CHECK: %[[PRED_WEIGHT:.*]] = torch.aten.eq.int %[[WEIGHT_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[PRED_WEIGHT]], "size of the 0th dimension must be equal to the number of features"
// CHECK: %[[WEIGHT_VIEW:.*]] = torch.aten.view %[[WEIGHT]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,?,1,1],f32>
// CHECK: %[[SCALED_INPUT:.*]] = torch.aten.mul.Tensor %[[NORM_INPUT]], %[[WEIGHT_VIEW]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[BIAS_DIM0:.*]] = torch.aten.size.int %[[BIAS]], %[[INT0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
// CHECK: %[[PRED_BIAS:.*]] = torch.aten.eq.int %[[BIAS_DIM0]], %[[INPUT_DIM1]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[PRED_BIAS]], "size of the 0th dimension must be equal to the number of features"
// CHECK: %[[BIAS_VIEW:.*]] = torch.aten.view %[[BIAS]], %[[SIZE_LIST]] : !torch.vtensor<[?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[1,?,1,1],f32>
// CHECK: %[[OUTPUT:.*]] = torch.aten.add.Tensor %[[SCALED_INPUT]], %[[BIAS_VIEW]], %[[INT1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1,?,1,1],f32>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[ZERO_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[MEAN_OUT:.*]] = torch.aten.empty.memory_format %[[ZERO_LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[0],f32>
// CHECK: %[[INV_STD_OUT:.*]] = torch.aten.empty.memory_format %[[ZERO_LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[0],f32>
// CHECK: return %[[OUTPUT]] : !torch.vtensor<[?,?,?,?],f32>
func @torch.aten.batch_norm(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>, %arg2: !torch.vtensor<[?],f32>, %arg3: !torch.vtensor<[?],f32>, %arg4: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%float1.000000e-05 = torch.constant.float 1.000000e-05
%float1.000000e-01 = torch.constant.float 1.000000e-01
%false = torch.constant.bool false
%0 = torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %false, %float1.000000e-01, %float1.000000e-05, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}