[Stablehlo Dialect] fix lowering batch_norm with mixed types (#2383)

* [Stablehlo Dialect] fix lowering bn inference with mixed types

* update
pull/2408/head
Yuanqiang Liu 2023-08-21 17:36:56 +08:00 committed by GitHub
parent 8ffe5d17da
commit b636e0c40c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 64 additions and 14 deletions

View File

@ -982,7 +982,6 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
AtenBatchNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getInput();
// shape = [N, C, H, W]
auto inputTy = input.getType().cast<RankedTensorType>();
Value weight = adaptor.getWeight();
Value bias = adaptor.getBias();
@ -1001,7 +1000,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
}
auto inputElemTy = inputTy.getElementType().cast<mlir::FloatType>();
Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);
Value channelDim =
rewriter.create<tensor::DimOp>(op->getLoc(), input, feature_index);
if (options.dimSizeIndexBits == 32) {
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
@ -1077,12 +1077,36 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Type outputTy = getTypeConverter()->convertType(op.getType());
Type batchMeanOrVarTy =
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
auto batchNormTrainingResult =
rewriter.create<stablehlo::BatchNormTrainingOp>(
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
weight, bias, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(feature_index));
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
Value output;
// Support mixed element types, e.g. the input is fp16 while the weight is fp32.
if (inputTy.getElementType() != weightTy.getElementType()) {
RankedTensorType convertedType = inputTy;
if (weightTy.getElementType().cast<FloatType>().getWidth() >
inputTy.getElementType().cast<FloatType>().getWidth()) {
convertedType = RankedTensorType::get(inputTy.getShape(),
weightTy.getElementType());
}
input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType);
weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
auto batchNormTrainingResult =
rewriter.create<stablehlo::BatchNormTrainingOp>(
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
weight, bias, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(feature_index));
output = hlo::promoteType(rewriter, op.getLoc(),
batchNormTrainingResult.getResult(0),
outputTy.cast<TensorType>());
} else {
auto batchNormTrainingResult =
rewriter.create<stablehlo::BatchNormTrainingOp>(
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
weight, bias, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(feature_index));
output = batchNormTrainingResult.getResult(0);
}
rewriter.replaceOp(op, output);
return success();
} else {
Type outputTy = getTypeConverter()->convertType(op.getType());
@ -1094,12 +1118,38 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
// stablehlo::BatchNormInferenceOp.
Value inputCasted =
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
Value output = rewriter.create<stablehlo::BatchNormInferenceOp>(
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
runningMean, runningVar,
// 'epsilon' must satisfy constraint: 32-bit float attribute.
rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(feature_index));
Value output;
// Support mixed element types, e.g. the input is fp16 while the weight is fp32.
if (inputTy.getElementType() != weightTy.getElementType()) {
RankedTensorType convertedType = inputTy;
if (weightTy.getElementType().cast<FloatType>().getWidth() >
inputTy.getElementType().cast<FloatType>().getWidth()) {
convertedType = RankedTensorType::get(inputTy.getShape(),
weightTy.getElementType());
}
input =
hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType);
weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
runningMean =
hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType);
runningVar =
hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType);
Value bnResult = rewriter.create<stablehlo::BatchNormInferenceOp>(
op.getLoc(), convertedType, input, weight, bias, runningMean,
runningVar, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(feature_index));
output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
outputTy.cast<TensorType>());
} else {
output = rewriter.create<stablehlo::BatchNormInferenceOp>(
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
runningMean, runningVar,
// 'epsilon' must satisfy constraint: 32-bit float attribute.
rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(feature_index));
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, output);
return success();
}