[ONNX] Add training mode support for BatchNormalization op (#3597)

This commit extends the OnnxToTorch lowering of the BatchNormalization op
to support the case when training_mode=True.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3631/head
Vivek Khandelwal 2024-08-14 10:46:38 +05:30 committed by GitHub
parent 2511cf46b4
commit 4a0bed0ce0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 141 additions and 34 deletions

View File

@ -339,43 +339,122 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, operand);
return success();
});
patterns.onOp("BatchNormalization", 15,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input, weight, bias, runningMean, runningVar;
bool training;
float momentum, eps;
if (binder.s64BoolAttr(training, "training_mode", 0))
return failure();
if (training) {
// TODO: Add support for training = true
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: training = true");
}
patterns.onOp(
"BatchNormalization", 15,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input, weight, bias, inputMean, inputVar;
bool training;
float momentum, eps;
if (binder.tensorOperandAtIndex(input, 0) ||
binder.tensorOperandAtIndex(weight, 1) ||
binder.tensorOperandAtIndex(bias, 2) ||
binder.tensorOperandAtIndex(inputMean, 3) ||
binder.tensorOperandAtIndex(inputVar, 4) ||
binder.f32FloatAttr(momentum, "momentum", 0.9f) ||
binder.f32FloatAttr(eps, "epsilon", 1e-05f) ||
binder.s64BoolAttr(training, "training_mode", 0) ||
binder.tensorResultTypeAtIndex(resultType, 0))
return failure();
if (binder.tensorOperandAtIndex(input, 0) ||
binder.tensorOperandAtIndex(weight, 1) ||
binder.tensorOperandAtIndex(bias, 2) ||
binder.tensorOperandAtIndex(runningMean, 3) ||
binder.tensorOperandAtIndex(runningVar, 4) ||
binder.f32FloatAttr(momentum, "momentum", 0.9f) ||
binder.f32FloatAttr(eps, "epsilon", 1e-05f) ||
binder.tensorResultType(resultType))
return failure();
Location loc = binder.getLoc();
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value cstMomentum = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(momentum));
Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(eps));
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), false);
Value cstMomentum = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(momentum));
Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(eps));
// When training_mode=False, the op outputs only Y, where
// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale +
// B
if (!training) {
rewriter.replaceOpWithNewOp<Torch::AtenBatchNormOp>(
binder.op, resultType, input, weight, bias, inputMean, inputVar,
/*training=*/cstFalse, cstMomentum, cstEps,
/*cudnn_enabled=*/cstFalse);
return success();
}
rewriter.replaceOpWithNewOp<Torch::AtenBatchNormOp>(
binder.op, resultType, input, weight, bias, runningMean,
runningVar, /*training=*/cstFalse, cstMomentum, cstEps,
/*cudnn_enabled=*/cstFalse);
return success();
});
// ---- training_mode=True path -------------------------------------------
// The op produces three results: Y, running_mean and running_var, where
//   current_mean = ReduceMean(X, axis=all_except_channel_index)
//   current_var  = ReduceVar(X, axis=all_except_channel_index)
//   Y            = (X - current_mean) / sqrt(current_var + epsilon) *
//                  scale + B
//   running_mean = input_mean * momentum + current_mean * (1 - momentum)
//   running_var  = input_var * momentum + current_var * (1 - momentum)
Torch::ValueTensorType meanResultType, varResultType;
if (binder.tensorResultTypeAtIndex(meanResultType, 1) ||
    binder.tensorResultTypeAtIndex(varResultType, 2))
  return failure();

// The rank of the input is needed to build the list of reduction dims.
Torch::ValueTensorType inputType =
    cast<Torch::ValueTensorType>(input.getType());
if (!inputType.hasSizes())
  return rewriter.notifyMatchFailure(
      binder.op, "unimplemented: expected input to have sizes");

// Computing current_mean and current_var.
// Reduce over every dimension except the channel dimension (index 1).
int64_t inputRank = inputType.getSizes().size();
SmallVector<Value> dimsToReduce;
for (int64_t i = 0; i < inputRank; i++) {
  if (i != 1)
    dimsToReduce.push_back(rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(i)));
}
Value reduceDimsList = rewriter.create<Torch::PrimListConstructOp>(
    loc,
    Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
    dimsToReduce);
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
Value currentMean = rewriter.create<Torch::AtenMeanDimOp>(
    loc, meanResultType, input, reduceDimsList,
    /*keepdim=*/cstFalse,
    /*dtype=*/noneVal);
// unbiased=false is passed for the variance reduction.
Value currentVar = rewriter.create<Torch::AtenVarDimOp>(
    loc, varResultType, input, reduceDimsList,
    /*unbiased=*/cstFalse,
    /*keepdim=*/cstFalse);

// Computing running_mean. The formula is evaluated in the expanded form
//   (input_mean * momentum - current_mean * momentum) + current_mean
Value inputMeanMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
    loc, meanResultType, inputMean, cstMomentum);
// This product is derived from current_mean, so it carries the mean
// result type (not the variance result type).
Value currentMeanMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
    loc, meanResultType, currentMean, cstMomentum);
// alpha = 1 for the aten.sub.Tensor / aten.add.Tensor ops below.
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
    loc, rewriter.getI64IntegerAttr(1));
Value inpMeanMMSubCurMeanMM = rewriter.create<Torch::AtenSubTensorOp>(
    loc, meanResultType, inputMeanMulMomentum, currentMeanMulMomentum,
    constantOne);
Value runningMean = rewriter.create<Torch::AtenAddTensorOp>(
    loc, meanResultType, inpMeanMMSubCurMeanMM, currentMean,
    constantOne);

// Computing running_var, using the same expanded form as running_mean.
Value inputVarMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
    loc, varResultType, inputVar, cstMomentum);
Value currentVarMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
    loc, varResultType, currentVar, cstMomentum);
Value inpVarMMSubCurVarMM = rewriter.create<Torch::AtenSubTensorOp>(
    loc, varResultType, inputVarMulMomentum, currentVarMulMomentum,
    constantOne);
Value runningVar = rewriter.create<Torch::AtenAddTensorOp>(
    loc, varResultType, inpVarMMSubCurVarMM, currentVar, constantOne);

// Computing Y. The batch statistics computed above are passed explicitly
// as the mean/var operands and the op is created with training=false.
Value y = rewriter.create<Torch::AtenBatchNormOp>(
    loc, resultType, input, weight, bias, currentMean, currentVar,
    /*training=*/cstFalse, cstMomentum, cstEps,
    /*cudnn_enabled=*/cstFalse);
rewriter.replaceOp(binder.op, {y, runningMean, runningVar});
return success();
});
patterns.onOp(
"AveragePool", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -1266,6 +1266,34 @@ func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !
// -----
// CHECK-LABEL: func.func @test_batchnorm_training
func.func @test_batchnorm_training(%arg0: !torch.vtensor<[1,16,27],f32>, %arg1: !torch.vtensor<[16],f32>, %arg2: !torch.vtensor<[16],f32>, %arg3: !torch.vtensor<[16],f32>, %arg4: !torch.vtensor<[16],f32>) -> (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[MOMENTUM:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST2:.*]] = torch.constant.int 2
// CHECK: %[[REDUCE_DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CURRENT_MEAN:.*]] = torch.aten.mean.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,16,27],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[16],f32>
// CHECK: %[[CURRENT_VAR:.*]] = torch.aten.var.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[16],f32>
// CHECK: %[[MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %arg3, %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
// CHECK: %[[CURR_MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %[[CURRENT_MEAN]], %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = torch.aten.sub.Tensor %[[MEAN_MUL_MOMENTUM]], %[[CURR_MEAN_MUL_MOMENTUM]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
// CHECK: %[[RUNNING_MEAN:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[CURRENT_MEAN]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
// CHECK: %[[VAR_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %arg4, %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
// CHECK: %[[CURR_VAR_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %[[CURRENT_VAR]], %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
// CHECK: %[[VAL_1:.*]] = torch.aten.sub.Tensor %[[VAR_MUL_MOMENTUM]], %[[CURR_VAR_MUL_MOMENTUM]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
// CHECK: %[[RUNNING_VAR:.*]] = torch.aten.add.Tensor %[[VAL_1]], %[[CURRENT_VAR]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
// CHECK: %[[Y:.*]] = torch.aten.batch_norm %arg0, %arg1, %arg2, %[[CURRENT_MEAN]], %[[CURRENT_VAR]], %[[FALSE]], %[[MOMENTUM]], %[[EPSILON]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,16,27],f32>
// CHECK: return %[[Y]], %[[RUNNING_MEAN]], %[[RUNNING_VAR]] : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>
%0:3 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.momentum = 1.000000e+00 : f32, torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>) -> (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>)
return %0#0, %0#1, %0#2 : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>
}
// -----
// CHECK-LABEL: @test_concat_1d_axis_0
func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list<vtensor>