diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 3507bafb1..8919df43a 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -339,43 +339,128 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
-  patterns.onOp("BatchNormalization", 15,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value input, weight, bias, runningMean, runningVar;
-                  bool training;
-                  float momentum, eps;
-                  if (binder.s64BoolAttr(training, "training_mode", 0))
-                    return failure();
-                  if (training) {
-                    // TODO: Add support for training = true
-                    return rewriter.notifyMatchFailure(
-                        binder.op, "unsupported conversion: training = true");
-                  }
+  // Lowering for onnx.BatchNormalization. With training_mode set, the running
+  // mean/variance outputs are produced in addition to Y.
+  patterns.onOp(
+      "BatchNormalization", 15,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input, weight, bias, inputMean, inputVar;
+        bool training;
+        float momentum, eps;
+        if (binder.tensorOperandAtIndex(input, 0) ||
+            binder.tensorOperandAtIndex(weight, 1) ||
+            binder.tensorOperandAtIndex(bias, 2) ||
+            binder.tensorOperandAtIndex(inputMean, 3) ||
+            binder.tensorOperandAtIndex(inputVar, 4) ||
+            binder.f32FloatAttr(momentum, "momentum", 0.9f) ||
+            binder.f32FloatAttr(eps, "epsilon", 1e-05f) ||
+            binder.s64BoolAttr(training, "training_mode", 0) ||
+            binder.tensorResultTypeAtIndex(resultType, 0))
+          return failure();
 
-                  if (binder.tensorOperandAtIndex(input, 0) ||
-                      binder.tensorOperandAtIndex(weight, 1) ||
-                      binder.tensorOperandAtIndex(bias, 2) ||
-                      binder.tensorOperandAtIndex(runningMean, 3) ||
-                      binder.tensorOperandAtIndex(runningVar, 4) ||
-                      binder.f32FloatAttr(momentum, "momentum", 0.9f) ||
-                      binder.f32FloatAttr(eps, "epsilon", 1e-05f) ||
-                      binder.tensorResultType(resultType))
-                    return failure();
+        Location loc = binder.getLoc();
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value cstMomentum = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(momentum));
+        Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(eps));
 
-                  Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
-                      binder.getLoc(), false);
-                  Value cstMomentum = rewriter.create<Torch::ConstantFloatOp>(
-                      binder.getLoc(), rewriter.getF64FloatAttr(momentum));
-                  Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
-                      binder.getLoc(), rewriter.getF64FloatAttr(eps));
+        // When training_mode=False, the op outputs only Y, where
+        // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale +
+        // B
+        if (!training) {
+          rewriter.replaceOpWithNewOp<Torch::AtenBatchNormOp>(
+              binder.op, resultType, input, weight, bias, inputMean, inputVar,
+              /*training=*/cstFalse, cstMomentum, cstEps,
+              /*cudnn_enabled=*/cstFalse);
+          return success();
+        }
 
-                  rewriter.replaceOpWithNewOp<Torch::AtenBatchNormOp>(
-                      binder.op, resultType, input, weight, bias, runningMean,
-                      runningVar, /*training=*/cstFalse, cstMomentum, cstEps,
-                      /*cudnn_enabled=*/cstFalse);
-                  return success();
-                });
+        Torch::ValueTensorType meanResultType, varResultType;
+        if (binder.tensorResultTypeAtIndex(meanResultType, 1) ||
+            binder.tensorResultTypeAtIndex(varResultType, 2))
+          return failure();
+
+        // When training_mode=True, the outputs are as follows:
+        // Y, running_mean, running_var.
+        // Y = (X - current_mean) / sqrt(current_var + epsilon) *
+        // scale + B
+        // running_mean = input_mean * momentum + current_mean * (1 -
+        // momentum)
+        // running_var = input_var * momentum + current_var * (1 -
+        // momentum)
+        // and
+        // current_mean = ReduceMean(X, axis=all_except_channel_index)
+        // current_var = ReduceVar(X, axis=all_except_channel_index)
+
+        Torch::ValueTensorType inputType =
+            cast<Torch::ValueTensorType>(input.getType());
+        if (!inputType.hasSizes())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: expected input to have sizes");
+
+        // Computing current_mean and current_var.
+        int64_t inputRank = inputType.getSizes().size();
+        // Reduce all dimensions except channel dim.
+        SmallVector<Value> dimsToReduce;
+        for (int64_t i = 0; i < inputRank; i++) {
+          if (i != 1)
+            dimsToReduce.push_back(rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(i)));
+        }
+        Value reduceDimsList = rewriter.create<Torch::PrimListConstructOp>(
+            loc,
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            dimsToReduce);
+        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value currentMean = rewriter.create<Torch::AtenMeanDimOp>(
+            loc, meanResultType, input, reduceDimsList,
+            /*keepdim=*/cstFalse,
+            /*dtype=*/noneVal);
+        Value currentVar = rewriter.create<Torch::AtenVarDimOp>(
+            loc, varResultType, input, reduceDimsList,
+            /*unbiased=*/cstFalse,
+            /*keepdim=*/cstFalse);
+
+        // Computing running_mean as
+        //   input_mean * momentum - current_mean * momentum + current_mean,
+        // an algebraic rearrangement of the running_mean formula above.
+        Value inputMeanMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
+            loc, meanResultType, inputMean, cstMomentum);
+        Value currentMeanMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
+            loc, meanResultType, currentMean, cstMomentum);
+        Value constantOne = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        Value inpMeanMMSubCurMeanMM = rewriter.create<Torch::AtenSubTensorOp>(
+            loc, meanResultType, inputMeanMulMomentum, currentMeanMulMomentum,
+            constantOne);
+        Value runningMean = rewriter.create<Torch::AtenAddTensorOp>(
+            loc, meanResultType, inpMeanMMSubCurMeanMM, currentMean,
+            constantOne);
+
+        // Computing running_var.
+        Value inputVarMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
+            loc, varResultType, inputVar, cstMomentum);
+        Value currentVarMulMomentum = rewriter.create<Torch::AtenMulScalarOp>(
+            loc, varResultType, currentVar, cstMomentum);
+        Value inpVarMMSubCurVarMM = rewriter.create<Torch::AtenSubTensorOp>(
+            loc, varResultType, inputVarMulMomentum, currentVarMulMomentum,
+            constantOne);
+        Value runningVar = rewriter.create<Torch::AtenAddTensorOp>(
+            loc, varResultType, inpVarMMSubCurVarMM, currentVar, constantOne);
+
+        // Computing Y. The batch statistics computed above are passed as the
+        // mean/var operands and batch_norm is invoked with training=false, so
+        // only the normalized output Y is produced by this op.
+        Value y = rewriter.create<Torch::AtenBatchNormOp>(
+            loc, resultType, input, weight, bias, currentMean, currentVar,
+            /*training=*/cstFalse, cstMomentum, cstEps,
+            /*cudnn_enabled=*/cstFalse);
+
+        rewriter.replaceOp(binder.op, {y, runningMean, runningVar});
+        return success();
+      });
   patterns.onOp(
       "AveragePool", 11,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index 2c70d6730..10cca7f80 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -1266,6 +1266,34 @@ func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !
 
 // -----
 
+// CHECK-LABEL: func.func @test_batchnorm_training
+func.func @test_batchnorm_training(%arg0: !torch.vtensor<[1,16,27],f32>, %arg1: !torch.vtensor<[16],f32>, %arg2: !torch.vtensor<[16],f32>, %arg3: !torch.vtensor<[16],f32>, %arg4: !torch.vtensor<[16],f32>) -> (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[MOMENTUM:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[CST2:.*]] = torch.constant.int 2
+// CHECK: %[[REDUCE_DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[CURRENT_MEAN:.*]] = torch.aten.mean.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,16,27],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[16],f32>
+// CHECK: %[[CURRENT_VAR:.*]] = torch.aten.var.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[16],f32>
+// CHECK: %[[MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %arg3, %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
+// CHECK: %[[CURR_MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %[[CURRENT_MEAN]], %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_0:.*]] = torch.aten.sub.Tensor %[[MEAN_MUL_MOMENTUM]], %[[CURR_MEAN_MUL_MOMENTUM]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
+// CHECK: %[[RUNNING_MEAN:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[CURRENT_MEAN]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
+// CHECK: %[[VAR_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %arg4, %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
+// CHECK: %[[CURR_VAR_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %[[CURRENT_VAR]], %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32>
+// CHECK: %[[VAL_1:.*]] = torch.aten.sub.Tensor %[[VAR_MUL_MOMENTUM]], %[[CURR_VAR_MUL_MOMENTUM]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
+// CHECK: %[[RUNNING_VAR:.*]] = torch.aten.add.Tensor %[[VAL_1]], %[[CURRENT_VAR]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32>
+// CHECK: %[[Y:.*]] = torch.aten.batch_norm %arg0, %arg1, %arg2, %[[CURRENT_MEAN]], %[[CURRENT_VAR]], %[[FALSE]], %[[MOMENTUM]], %[[EPSILON]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,16,27],f32>
+// CHECK: return %[[Y]], %[[RUNNING_MEAN]], %[[RUNNING_VAR]] : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>
+  %0:3 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.momentum = 1.000000e+00 : f32, torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>) -> (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>)
+  return %0#0, %0#1, %0#2 : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_concat_1d_axis_0
 func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list<vtensor>