mirror of https://github.com/llvm/torch-mlir
[ONNX] Add training mode support for BatchNormalization op (#3597)
This commit extends the OnnxToTorch lowering for BatchNormalization op for supporting the case when training=True. Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3631/head
parent
2511cf46b4
commit
4a0bed0ce0
|
@ -339,42 +339,121 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.op, resultType, operand);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("BatchNormalization", 15,
|
||||
patterns.onOp(
|
||||
"BatchNormalization", 15,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value input, weight, bias, runningMean, runningVar;
|
||||
Value input, weight, bias, inputMean, inputVar;
|
||||
bool training;
|
||||
float momentum, eps;
|
||||
if (binder.s64BoolAttr(training, "training_mode", 0))
|
||||
return failure();
|
||||
if (training) {
|
||||
// TODO: Add support for training = true
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unsupported conversion: training = true");
|
||||
}
|
||||
|
||||
if (binder.tensorOperandAtIndex(input, 0) ||
|
||||
binder.tensorOperandAtIndex(weight, 1) ||
|
||||
binder.tensorOperandAtIndex(bias, 2) ||
|
||||
binder.tensorOperandAtIndex(runningMean, 3) ||
|
||||
binder.tensorOperandAtIndex(runningVar, 4) ||
|
||||
binder.tensorOperandAtIndex(inputMean, 3) ||
|
||||
binder.tensorOperandAtIndex(inputVar, 4) ||
|
||||
binder.f32FloatAttr(momentum, "momentum", 0.9f) ||
|
||||
binder.f32FloatAttr(eps, "epsilon", 1e-05f) ||
|
||||
binder.tensorResultType(resultType))
|
||||
binder.s64BoolAttr(training, "training_mode", 0) ||
|
||||
binder.tensorResultTypeAtIndex(resultType, 0))
|
||||
return failure();
|
||||
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), false);
|
||||
Location loc = binder.getLoc();
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
Value cstMomentum = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getF64FloatAttr(momentum));
|
||||
loc, rewriter.getF64FloatAttr(momentum));
|
||||
Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getF64FloatAttr(eps));
|
||||
loc, rewriter.getF64FloatAttr(eps));
|
||||
|
||||
// When training_mode=False, the op outputs only Y, where
|
||||
// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale +
|
||||
// B
|
||||
if (!training) {
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenBatchNormOp>(
|
||||
binder.op, resultType, input, weight, bias, runningMean,
|
||||
runningVar, /*training=*/cstFalse, cstMomentum, cstEps,
|
||||
binder.op, resultType, input, weight, bias, inputMean, inputVar,
|
||||
/*training=*/cstFalse, cstMomentum, cstEps,
|
||||
/*cudnn_enabled=*/cstFalse);
|
||||
return success();
|
||||
}
|
||||
|
||||
Torch::ValueTensorType meanResultType, varResultType;
|
||||
if (binder.tensorResultTypeAtIndex(meanResultType, 1) ||
|
||||
binder.tensorResultTypeAtIndex(varResultType, 2))
|
||||
return failure();
|
||||
|
||||
// When training_mode=True, the outputs are as follows:
|
||||
// Y, running_mean, running_var.
|
||||
// Y = (X - current_mean) / sqrt(current_var + epsilon) *
|
||||
// scale + B
|
||||
// running_mean = input_mean * momentum + current_mean * (1 -
|
||||
// momentum)
|
||||
// running_var = input_var * momentum + current_var * (1 -
|
||||
// momentum)
|
||||
// and
|
||||
// current_mean = ReduceMean(X, axis=all_except_channel_index)
|
||||
// current_var = ReduceVar(X, axis=all_except_channel_index)
|
||||
|
||||
Torch::ValueTensorType inputType =
|
||||
cast<Torch::ValueTensorType>(input.getType());
|
||||
if (!inputType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: expected input to have sizes");
|
||||
|
||||
// Computing current_mean and current_var.
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
// Reduce all dimensions except channel dim.
|
||||
SmallVector<Value> dimsToReduce;
|
||||
for (int64_t i = 0; i < inputRank; i++) {
|
||||
if (i != 1)
|
||||
dimsToReduce.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
Value reduceDimsList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
dimsToReduce);
|
||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
Value currentMean = rewriter.create<Torch::AtenMeanDimOp>(
|
||||
loc, meanResultType, input, reduceDimsList,
|
||||
/*keepdim=*/cstFalse,
|
||||
/*dtype=*/noneVal);
|
||||
Value currentVar = rewriter.create<Torch::AtenVarDimOp>(
|
||||
loc, varResultType, input, reduceDimsList,
|
||||
/*unbiased=*/cstFalse,
|
||||
/*keepdim=*/cstFalse);
|
||||
|
||||
// Computing running_mean.
|
||||
Value inputMeanMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
|
||||
loc, meanResultType, inputMean, cstMomentum);
|
||||
Value currentMeanMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
|
||||
loc, varResultType, currentMean, cstMomentum);
|
||||
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value inpMeanMMSubCurMeanMM = rewriter.create<Torch::AtenSubTensorOp>(
|
||||
loc, meanResultType, inputMeanMulMomentum, currentMeanMulMomentum,
|
||||
constantOne);
|
||||
Value runningMean = rewriter.create<Torch::AtenAddTensorOp>(
|
||||
loc, meanResultType, inpMeanMMSubCurMeanMM, currentMean,
|
||||
constantOne);
|
||||
|
||||
// Computing running_var.
|
||||
Value inputVarMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
|
||||
loc, varResultType, inputVar, cstMomentum);
|
||||
Value currentVarMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
|
||||
loc, varResultType, currentVar, cstMomentum);
|
||||
Value inpVarMMSubCurVarMM = rewriter.create<Torch::AtenSubTensorOp>(
|
||||
loc, varResultType, inputVarMulMomentum, currentVarMulMomentum,
|
||||
constantOne);
|
||||
Value runningVar = rewriter.create<Torch::AtenAddTensorOp>(
|
||||
loc, varResultType, inpVarMMSubCurVarMM, currentVar, constantOne);
|
||||
|
||||
// Computing Y.
|
||||
Value y = rewriter.create<Torch::AtenBatchNormOp>(
|
||||
loc, resultType, input, weight, bias, currentMean, currentVar,
|
||||
/*training=*/cstFalse, cstMomentum, cstEps,
|
||||
/*cudnn_enabled=*/cstFalse);
|
||||
|
||||
rewriter.replaceOp(binder.op, {y, runningMean, runningVar});
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"AveragePool", 11,
|
||||
|
|
|
@ -1266,6 +1266,34 @@ func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_batchnorm_training
|
||||
func.func @test_batchnorm_training(%arg0: !torch.vtensor<[1,16,27],f32>, %arg1: !torch.vtensor<[16],f32>, %arg2: !torch.vtensor<[16],f32>, %arg3: !torch.vtensor<[16],f32>, %arg4: !torch.vtensor<[16],f32>) -> (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[MOMENTUM:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[CST2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[REDUCE_DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[CURRENT_MEAN:.*]] = torch.aten.mean.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,16,27],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[CURRENT_VAR:.*]] = torch.aten.var.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %arg3, %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[CURR_MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %[[CURRENT_MEAN]], %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_0:.*]] = torch.aten.sub.Tensor %[[MEAN_MUL_MOMENTUM]], %[[CURR_MEAN_MUL_MOMENTUM]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[RUNNING_MEAN:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[CURRENT_MEAN]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[VAR_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %arg4, %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[CURR_VAR_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %[[CURRENT_VAR]], %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[VAL_1:.*]] = torch.aten.sub.Tensor %[[VAR_MUL_MOMENTUM]], %[[CURR_VAR_MUL_MOMENTUM]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[RUNNING_VAR:.*]] = torch.aten.add.Tensor %[[VAL_1]], %[[CURRENT_VAR]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
|
||||
// CHECK: %[[Y:.*]] = torch.aten.batch_norm %arg0, %arg1, %arg2, %[[CURRENT_MEAN]], %[[CURRENT_VAR]], %[[FALSE]], %[[MOMENTUM]], %[[EPSILON]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,16,27],f32>
|
||||
// CHECK: return %[[Y]], %[[RUNNING_MEAN]], %[[RUNNING_VAR]] : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>
|
||||
%0:3 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.momentum = 1.000000e+00 : f32, torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>) -> (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>)
|
||||
return %0#0, %0#1, %0#2 : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_concat_1d_axis_0
|
||||
func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list<vtensor>
|
||||
|
|
Loading…
Reference in New Issue