[Stablehlo] simplify promoteType (#3525)

Only provide `outElementType` when calling `promoteType`.
pull/3530/head
Yuanqiang Liu 2024-07-10 10:52:19 +08:00 committed by GitHub
parent dcb48dd46c
commit 5bee9aac63
5 changed files with 91 additions and 78 deletions
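In short, `hlo::promoteType` now takes only the desired element type and derives the result shape from its input, so call sites no longer build a full `TensorType`. A minimal sketch of the new call-site shape (the `AtenFooOp`/`SomeStablehloOp` pattern below is illustrative, not part of this commit):

// Hypothetical conversion pattern showing the simplified helper in use.
LogicalResult matchAndRewrite(AtenFooOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.getSelf();
  auto outType =
      cast<TensorType>(getTypeConverter()->convertType(op.getType()));
  // Old: hlo::promoteType(rewriter, op.getLoc(), self, outType);
  // New: pass only the element type; the shape stays that of `self`.
  self = hlo::promoteType(rewriter, op.getLoc(), self,
                          outType.getElementType());
  rewriter.replaceOpWithNewOp<SomeStablehloOp>(op, outType, self);
  return success();
}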


@@ -50,7 +50,7 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
                               Operation *op, Value scalarValue, Type dtype);
 
 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
-                  TensorType outType);
+                  Type outElementType);
 
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                           TensorType outType);


@@ -148,7 +148,8 @@ public:
     auto outType = cast<TensorType>(
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
-    self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
+    self =
+        hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType());
     rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
     return success();
   }
@@ -207,7 +208,8 @@ public:
             op.getType()));
 
     if (isa<mlir::FloatType>(resultTy.getElementType())) {
-      Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
+      Value src = hlo::promoteType(rewriter, op.getLoc(), self,
+                                   resultTy.getElementType());
       rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
       return success();
     } else {
@@ -334,8 +336,8 @@ public:
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
 
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());
 
     rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
                                          /*broadcast_attr*/ nullptr);
@@ -381,8 +383,8 @@ public:
       }
     }
 
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
 
     if (!skipMultiplyAlpha(op.getAlpha())) {
       Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
@@ -437,8 +439,8 @@ public:
       }
     }
     DenseI64ArrayAttr bcastDimensions;
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
     auto loc = op.getLoc();
     Value result =
         rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@@ -476,16 +478,17 @@ public:
       if (isa<mlir::FloatType>(outElemTy))
         result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
       else if (!outElemTy.isUnsignedInteger()) {
-        TensorType defaultIntToFloatType =
-            outType.cloneWith(outType.getShape(), rewriter.getF64Type());
+        Type defaultIntToFloatType = rewriter.getF64Type();
         lhs =
             hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType);
         rhs =
             hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType);
-        result = rewriter.create<ChloOpT>(loc, defaultIntToFloatType, lhs, rhs,
-                                          bcastDimensions);
+        result = rewriter.create<ChloOpT>(
+            loc, outType.cloneWith(outType.getShape(), defaultIntToFloatType),
+            lhs, rhs, bcastDimensions);
         result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
-        result = hlo::promoteType(rewriter, op.getLoc(), result, outType);
+        result = hlo::promoteType(rewriter, op.getLoc(), result,
+                                  outType.getElementType());
       }
     }
     rewriter.replaceOp(op, result);
@@ -517,7 +520,8 @@ public:
       rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
                                          rhs.getType());
       // use lhs's element type as compute type
-      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
+      rhs =
+          hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
       rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
     }
@@ -533,16 +537,16 @@ public:
     }
 
     if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
-      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
+      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
     } else if (isa<mlir::FloatType>(lhsElemTy) &&
                isa<mlir::IntegerType>(rhsElemTy)) {
-      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
+      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
     } else {
       if (lhsElemTy.getIntOrFloatBitWidth() >
           rhsElemTy.getIntOrFloatBitWidth()) {
-        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
+        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
       } else {
-        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
+        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
       }
     }
     lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
@@ -622,11 +626,11 @@ public:
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
     Type outElemTy = outType.getElementType();
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
     if (!rhsTy) {
       rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
     }
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
 
     DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@@ -736,8 +740,10 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
   auto outType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
   // promote self and other types
-  self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
-  other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
+  self =
+      hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType());
+  other =
+      hlo::promoteType(rewriter, op.getLoc(), other, outType.getElementType());
 
   if (failed(broadcastRanks(rewriter, op, self, cond)))
     return op.emitError("failed broadcast self and condition ranks");
@@ -940,8 +946,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
     rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
   }
   DenseI64ArrayAttr bcastDimensions;
-  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
   auto loc = op.getLoc();
   Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
                                                        bcastDimensions);
@@ -977,8 +983,8 @@ LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
     lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
   }
   DenseI64ArrayAttr bcastDimensions;
-  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
   auto loc = op.getLoc();
   Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
                                                        bcastDimensions);
@@ -1121,7 +1127,8 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
     return op.emitError("only ranked tensor type is supported.");
   }
   auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  input =
+      hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
 
   auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input);
   auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
@@ -1143,7 +1150,8 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
   }
   auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  input =
+      hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
 
   auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input);
   auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
@@ -1266,42 +1274,44 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
                        "non-bool cudnn_enabled unsupported");
   }
   if (training) {
-    Type outputTy = getTypeConverter()->convertType(op.getType());
-    Type batchMeanOrVarTy =
-        RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
+    TensorType outputTy =
+        cast<TensorType>(getTypeConverter()->convertType(op.getType()));
 
     Value output;
     // supported mixed types, like input type is fp16 and weight type is fp32.
     if (inputTy.getElementType() != weightTy.getElementType()) {
-      RankedTensorType convertedType = inputTy;
-      if (cast<FloatType>(weightTy.getElementType()).getWidth() >
-          cast<FloatType>(inputTy.getElementType()).getWidth()) {
-        convertedType = RankedTensorType::get(inputTy.getShape(),
-                                              weightTy.getElementType());
+      Type computeType = inputTy.getElementType();
+      if (weightTy.getElementType().getIntOrFloatBitWidth() >
+          inputTy.getElementType().getIntOrFloatBitWidth()) {
+        computeType = weightTy.getElementType();
       }
-      input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType);
-      weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
-      bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
+      input = hlo::promoteType(rewriter, op.getLoc(), input, computeType);
+      weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType);
+      bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType);
       auto batchNormTrainingResult =
          rewriter.create<stablehlo::BatchNormTrainingOp>(
-              op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
+              op.getLoc(),
+              RankedTensorType::get(inputTy.getShape(), computeType),
+              RankedTensorType::get(weightTy.getShape(), computeType),
+              RankedTensorType::get(weightTy.getShape(), computeType), input,
              weight, bias, rewriter.getF32FloatAttr(eps),
              rewriter.getI64IntegerAttr(feature_index));
      output = hlo::promoteType(rewriter, op.getLoc(),
                                batchNormTrainingResult.getResult(0),
-                               cast<TensorType>(outputTy));
+                               outputTy.getElementType());
     } else {
       auto batchNormTrainingResult =
          rewriter.create<stablehlo::BatchNormTrainingOp>(
-              op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
-              weight, bias, rewriter.getF32FloatAttr(eps),
+              op.getLoc(), outputTy, weightTy, weightTy, input, weight, bias,
+              rewriter.getF32FloatAttr(eps),
              rewriter.getI64IntegerAttr(feature_index));
      output = batchNormTrainingResult.getResult(0);
     }
     rewriter.replaceOp(op, output);
     return success();
   } else {
-    Type outputTy = getTypeConverter()->convertType(op.getType());
+    TensorType outputTy =
+        cast<TensorType>(getTypeConverter()->convertType(op.getType()));
     SmallVector<int64_t, 4> castShape{inputTy.getShape().begin(),
                                       inputTy.getShape().end()};
     castShape[1] = weightTy.getShape()[0];
@@ -1314,26 +1324,25 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
     Value output;
     // supported mixed types, like input type is fp16 and weight type is fp32.
     if (inputTy.getElementType() != weightTy.getElementType()) {
-      RankedTensorType convertedType = inputTy;
-      if (cast<FloatType>(weightTy.getElementType()).getWidth() >
-          cast<FloatType>(inputTy.getElementType()).getWidth()) {
-        convertedType = RankedTensorType::get(inputTy.getShape(),
-                                              weightTy.getElementType());
+      Type computeType = inputTy.getElementType();
+      if (weightTy.getElementType().getIntOrFloatBitWidth() >
+          inputTy.getElementType().getIntOrFloatBitWidth()) {
+        computeType = weightTy.getElementType();
       }
-      input =
-          hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType);
-      weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
-      bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
+      input = hlo::promoteType(rewriter, op.getLoc(), inputCasted, computeType);
+      weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType);
+      bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType);
      runningMean =
-          hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType);
+          hlo::promoteType(rewriter, op.getLoc(), runningMean, computeType);
      runningVar =
-          hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType);
+          hlo::promoteType(rewriter, op.getLoc(), runningVar, computeType);
      Value bnResult = rewriter.create<stablehlo::BatchNormInferenceOp>(
-          op.getLoc(), convertedType, input, weight, bias, runningMean,
-          runningVar, rewriter.getF32FloatAttr(eps),
+          op.getLoc(), RankedTensorType::get(inputTy.getShape(), computeType),
+          input, weight, bias, runningMean, runningVar,
+          rewriter.getF32FloatAttr(eps),
          rewriter.getI64IntegerAttr(feature_index));
      output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
-                               cast<TensorType>(outputTy));
+                               outputTy.getElementType());
     } else {
       output = rewriter.create<stablehlo::BatchNormInferenceOp>(
           op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
@@ -1515,7 +1524,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
   // Promote type
   for (auto &v : builtinTensors) {
-    v = hlo::promoteType(rewriter, op->getLoc(), v, outType);
+    v = hlo::promoteType(rewriter, op->getLoc(), v, outType.getElementType());
   }
 
   rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
@@ -1787,8 +1796,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
   auto outTy =
       cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
-  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
-  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());
 
   rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
                                                     /*broadcast_attr*/ nullptr);
@@ -1961,8 +1970,10 @@ LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
   auto resultType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
-  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
+  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
+                         resultType.getElementType());
+  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs,
+                         resultType.getElementType());
 
   rewriter.replaceOpWithNewOp<stablehlo::RemOp>(op, lhs, rhs);
   return success();
 }
@@ -1979,8 +1990,10 @@ LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
   auto resultType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
-  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
+  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
+                         resultType.getElementType());
+  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs,
+                         resultType.getElementType());
 
   stablehlo::MulOp mul;
   auto div = rewriter.create<stablehlo::DivOp>(loc, lhs, rhs);


@@ -835,7 +835,8 @@ public:
         llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
     bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
-    bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);
+    bias =
+        hlo::promoteType(rewriter, op.getLoc(), bias, outTy.getElementType());
 
     DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(


@@ -522,7 +522,8 @@ public:
       } else {
         assert(false && "Unsupported pooling dimension");
       }
-      divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
+      divisor = hlo::promoteType(rewriter, op.getLoc(), divisor,
+                                 outTy.getElementType());
       DenseI64ArrayAttr bcastDimensions;
       rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
           op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
@@ -532,8 +533,8 @@ public:
     // Use another mhlo.ReduceWindowOp to get the divisor
     Value windowSizeConst =
         hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
-    windowSizeConst =
-        hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
+    windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst,
+                                       outTy.getElementType());
     auto inputShapeVec = *hlo::getDimIndexOfTensor(rewriter, op, input);
     auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
         op->getLoc(), inputShapeVec);
@@ -583,7 +584,8 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
   auto inputTy = cast<RankedTensorType>(input.getType());
   auto outTy =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  input =
+      hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
   inputTy = cast<RankedTensorType>(input.getType());
   auto inputElemTy = inputTy.getElementType();
   auto inputRank = inputTy.getRank();


@@ -170,13 +170,10 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
 }
 
 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
-                  TensorType outType) {
-  TensorType in_type = cast<TensorType>(input.getType());
-
-  if (in_type.getElementType() != outType.getElementType()) {
-    TensorType promotedType =
-        in_type.cloneWith(in_type.getShape(), outType.getElementType());
-    return rewriter.create<stablehlo::ConvertOp>(loc, promotedType, input);
+                  Type outElementType) {
+  TensorType inType = cast<TensorType>(input.getType());
+  if (inType.getElementType() != outElementType) {
+    return rewriter.create<stablehlo::ConvertOp>(loc, input, outElementType);
   }
   return input;
 }