mirror of https://github.com/llvm/torch-mlir
[Stablehlo Dialect] fix lowering batch_norm with mixed types (#2383)
* [Stablehlo Dialect] fix lowering bn inference with mixed types * updatepull/2408/head
parent
8ffe5d17da
commit
b636e0c40c
|
@ -982,7 +982,6 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
AtenBatchNormOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getInput();
|
||||
// shape = [N, C, H, W]
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
Value weight = adaptor.getWeight();
|
||||
Value bias = adaptor.getBias();
|
||||
|
@ -1001,7 +1000,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
}
|
||||
auto inputElemTy = inputTy.getElementType().cast<mlir::FloatType>();
|
||||
|
||||
Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);
|
||||
Value channelDim =
|
||||
rewriter.create<tensor::DimOp>(op->getLoc(), input, feature_index);
|
||||
|
||||
if (options.dimSizeIndexBits == 32) {
|
||||
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
|
||||
|
@ -1077,12 +1077,36 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
Type outputTy = getTypeConverter()->convertType(op.getType());
|
||||
Type batchMeanOrVarTy =
|
||||
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
|
||||
auto batchNormTrainingResult =
|
||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
|
||||
|
||||
Value output;
|
||||
// supported mixed types, like input type is fp16 and weight type is fp32.
|
||||
if (inputTy.getElementType() != weightTy.getElementType()) {
|
||||
RankedTensorType convertedType = inputTy;
|
||||
if (weightTy.getElementType().cast<FloatType>().getWidth() >
|
||||
inputTy.getElementType().cast<FloatType>().getWidth()) {
|
||||
convertedType = RankedTensorType::get(inputTy.getShape(),
|
||||
weightTy.getElementType());
|
||||
}
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType);
|
||||
weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
|
||||
bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
|
||||
auto batchNormTrainingResult =
|
||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
output = hlo::promoteType(rewriter, op.getLoc(),
|
||||
batchNormTrainingResult.getResult(0),
|
||||
outputTy.cast<TensorType>());
|
||||
} else {
|
||||
auto batchNormTrainingResult =
|
||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
output = batchNormTrainingResult.getResult(0);
|
||||
}
|
||||
rewriter.replaceOp(op, output);
|
||||
return success();
|
||||
} else {
|
||||
Type outputTy = getTypeConverter()->convertType(op.getType());
|
||||
|
@ -1094,12 +1118,38 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
// stablehlo::BatchNormInferenceOp.
|
||||
Value inputCasted =
|
||||
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
|
||||
Value output = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
||||
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
||||
runningMean, runningVar,
|
||||
// 'epsilon' must satisfy constraint: 32-bit float attribute.
|
||||
rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
|
||||
Value output;
|
||||
// supported mixed types, like input type is fp16 and weight type is fp32.
|
||||
if (inputTy.getElementType() != weightTy.getElementType()) {
|
||||
RankedTensorType convertedType = inputTy;
|
||||
if (weightTy.getElementType().cast<FloatType>().getWidth() >
|
||||
inputTy.getElementType().cast<FloatType>().getWidth()) {
|
||||
convertedType = RankedTensorType::get(inputTy.getShape(),
|
||||
weightTy.getElementType());
|
||||
}
|
||||
input =
|
||||
hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType);
|
||||
weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
|
||||
bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
|
||||
runningMean =
|
||||
hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType);
|
||||
runningVar =
|
||||
hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType);
|
||||
Value bnResult = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
||||
op.getLoc(), convertedType, input, weight, bias, runningMean,
|
||||
runningVar, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
|
||||
outputTy.cast<TensorType>());
|
||||
} else {
|
||||
output = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
||||
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
||||
runningMean, runningVar,
|
||||
// 'epsilon' must satisfy constraint: 32-bit float attribute.
|
||||
rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, output);
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue