[Stablehlo] simplify promoteType (#3525)

Only provide `outElementType` when calling `promoteType`.
pull/3530/head
Yuanqiang Liu 2024-07-10 10:52:19 +08:00 committed by GitHub
parent dcb48dd46c
commit 5bee9aac63
5 changed files with 91 additions and 78 deletions
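In short, `hlo::promoteType` now takes only the target element type instead of a full result tensor type, and every call site passes `getElementType()` of the type it used to pass. A minimal before/after sketch of a typical call site (the names `lhs` and `outType` mirror the call sites in this diff):

    // Before: the whole result TensorType was passed, even though only its
    // element type mattered.
    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
    // After: pass just the element type; the value keeps its own shape.
    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType.getElementType());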

View File

@@ -50,7 +50,7 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
                               Operation *op, Value scalarValue, Type dtype);
 
 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
-                  TensorType outType);
+                  Type outElementType);
 
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                           TensorType outType);

View File

@@ -148,7 +148,8 @@ public:
     auto outType = cast<TensorType>(
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()));
-    self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
+    self =
+        hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType());
     rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
     return success();
   }
@@ -207,7 +208,8 @@ public:
            op.getType()));
     if (isa<mlir::FloatType>(resultTy.getElementType())) {
-      Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
+      Value src = hlo::promoteType(rewriter, op.getLoc(), self,
+                                   resultTy.getElementType());
       rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
       return success();
     } else {
@ -334,8 +336,8 @@ public:
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType())); op.getType()));
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs, rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
/*broadcast_attr*/ nullptr); /*broadcast_attr*/ nullptr);
@@ -381,8 +383,8 @@ public:
       }
     }
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
 
     if (!skipMultiplyAlpha(op.getAlpha())) {
       Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
@@ -437,8 +439,8 @@ public:
       }
     }
     DenseI64ArrayAttr bcastDimensions;
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
     auto loc = op.getLoc();
     Value result =
         rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@@ -476,16 +478,17 @@ public:
     if (isa<mlir::FloatType>(outElemTy))
       result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
     else if (!outElemTy.isUnsignedInteger()) {
-      TensorType defaultIntToFloatType =
-          outType.cloneWith(outType.getShape(), rewriter.getF64Type());
+      Type defaultIntToFloatType = rewriter.getF64Type();
       lhs =
           hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType);
       rhs =
           hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType);
-      result = rewriter.create<ChloOpT>(loc, defaultIntToFloatType, lhs, rhs,
-                                        bcastDimensions);
+      result = rewriter.create<ChloOpT>(
+          loc, outType.cloneWith(outType.getShape(), defaultIntToFloatType),
+          lhs, rhs, bcastDimensions);
       result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
-      result = hlo::promoteType(rewriter, op.getLoc(), result, outType);
+      result = hlo::promoteType(rewriter, op.getLoc(), result,
+                                outType.getElementType());
     }
   }
   rewriter.replaceOp(op, result);
@@ -517,7 +520,8 @@ public:
       rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
                                          rhs.getType());
       // use lhs's element type as compute type
-      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
+      rhs =
+          hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
       rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
     }
@@ -533,16 +537,16 @@ public:
     }
 
     if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
-      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
+      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
     } else if (isa<mlir::FloatType>(lhsElemTy) &&
                isa<mlir::IntegerType>(rhsElemTy)) {
-      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
+      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
     } else {
       if (lhsElemTy.getIntOrFloatBitWidth() >
           rhsElemTy.getIntOrFloatBitWidth()) {
-        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
+        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
       } else {
-        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
+        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
       }
     }
     lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
@@ -622,11 +626,11 @@ public:
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()));
     Type outElemTy = outType.getElementType();
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
     if (!rhsTy) {
       rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
     }
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
 
     DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@@ -736,8 +740,10 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
   auto outType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
   // promote self and other types
-  self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
-  other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
+  self =
+      hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType());
+  other =
+      hlo::promoteType(rewriter, op.getLoc(), other, outType.getElementType());
 
   if (failed(broadcastRanks(rewriter, op, self, cond)))
     return op.emitError("failed broadcast self and condition ranks");
@@ -940,8 +946,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
     rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
   }
   DenseI64ArrayAttr bcastDimensions;
-  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
   auto loc = op.getLoc();
   Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
                                                        bcastDimensions);
@@ -977,8 +983,8 @@ LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
     lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
   }
   DenseI64ArrayAttr bcastDimensions;
-  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
   auto loc = op.getLoc();
   Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
                                                        bcastDimensions);
@@ -1121,7 +1127,8 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
     return op.emitError("only ranked tensor type is supported.");
   }
   auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  input =
+      hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
 
   auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input);
   auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
@@ -1143,7 +1150,8 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
   }
 
   auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  input =
+      hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
 
   auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input);
   auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
@@ -1266,42 +1274,44 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
                        "non-bool cudnn_enabled unsupported");
   }
   if (training) {
-    Type outputTy = getTypeConverter()->convertType(op.getType());
-    Type batchMeanOrVarTy =
-        RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
+    TensorType outputTy =
+        cast<TensorType>(getTypeConverter()->convertType(op.getType()));
     Value output;
     // supported mixed types, like input type is fp16 and weight type is fp32.
     if (inputTy.getElementType() != weightTy.getElementType()) {
-      RankedTensorType convertedType = inputTy;
-      if (cast<FloatType>(weightTy.getElementType()).getWidth() >
-          cast<FloatType>(inputTy.getElementType()).getWidth()) {
-        convertedType = RankedTensorType::get(inputTy.getShape(),
-                                              weightTy.getElementType());
+      Type computeType = inputTy.getElementType();
+      if (weightTy.getElementType().getIntOrFloatBitWidth() >
+          inputTy.getElementType().getIntOrFloatBitWidth()) {
+        computeType = weightTy.getElementType();
       }
-      input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType);
-      weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
-      bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
+      input = hlo::promoteType(rewriter, op.getLoc(), input, computeType);
+      weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType);
+      bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType);
       auto batchNormTrainingResult =
           rewriter.create<stablehlo::BatchNormTrainingOp>(
-              op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
+              op.getLoc(),
+              RankedTensorType::get(inputTy.getShape(), computeType),
+              RankedTensorType::get(weightTy.getShape(), computeType),
+              RankedTensorType::get(weightTy.getShape(), computeType), input,
               weight, bias, rewriter.getF32FloatAttr(eps),
              rewriter.getI64IntegerAttr(feature_index));
       output = hlo::promoteType(rewriter, op.getLoc(),
                                 batchNormTrainingResult.getResult(0),
-                                cast<TensorType>(outputTy));
+                                outputTy.getElementType());
     } else {
       auto batchNormTrainingResult =
          rewriter.create<stablehlo::BatchNormTrainingOp>(
-              op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
-              weight, bias, rewriter.getF32FloatAttr(eps),
+              op.getLoc(), outputTy, weightTy, weightTy, input, weight, bias,
+              rewriter.getF32FloatAttr(eps),
              rewriter.getI64IntegerAttr(feature_index));
       output = batchNormTrainingResult.getResult(0);
     }
     rewriter.replaceOp(op, output);
     return success();
   } else {
-    Type outputTy = getTypeConverter()->convertType(op.getType());
+    TensorType outputTy =
+        cast<TensorType>(getTypeConverter()->convertType(op.getType()));
     SmallVector<int64_t, 4> castShape{inputTy.getShape().begin(),
                                       inputTy.getShape().end()};
     castShape[1] = weightTy.getShape()[0];
@@ -1314,26 +1324,25 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
     Value output;
     // supported mixed types, like input type is fp16 and weight type is fp32.
     if (inputTy.getElementType() != weightTy.getElementType()) {
-      RankedTensorType convertedType = inputTy;
-      if (cast<FloatType>(weightTy.getElementType()).getWidth() >
-          cast<FloatType>(inputTy.getElementType()).getWidth()) {
-        convertedType = RankedTensorType::get(inputTy.getShape(),
-                                              weightTy.getElementType());
+      Type computeType = inputTy.getElementType();
+      if (weightTy.getElementType().getIntOrFloatBitWidth() >
+          inputTy.getElementType().getIntOrFloatBitWidth()) {
+        computeType = weightTy.getElementType();
       }
-      input =
-          hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType);
-      weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
-      bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
+      input = hlo::promoteType(rewriter, op.getLoc(), inputCasted, computeType);
+      weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType);
+      bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType);
       runningMean =
-          hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType);
+          hlo::promoteType(rewriter, op.getLoc(), runningMean, computeType);
       runningVar =
-          hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType);
+          hlo::promoteType(rewriter, op.getLoc(), runningVar, computeType);
       Value bnResult = rewriter.create<stablehlo::BatchNormInferenceOp>(
-          op.getLoc(), convertedType, input, weight, bias, runningMean,
-          runningVar, rewriter.getF32FloatAttr(eps),
+          op.getLoc(), RankedTensorType::get(inputTy.getShape(), computeType),
+          input, weight, bias, runningMean, runningVar,
+          rewriter.getF32FloatAttr(eps),
          rewriter.getI64IntegerAttr(feature_index));
       output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
-                                cast<TensorType>(outputTy));
+                                outputTy.getElementType());
     } else {
       output = rewriter.create<stablehlo::BatchNormInferenceOp>(
           op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
@@ -1515,7 +1524,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
 
   // Promote type
   for (auto &v : builtinTensors) {
-    v = hlo::promoteType(rewriter, op->getLoc(), v, outType);
+    v = hlo::promoteType(rewriter, op->getLoc(), v, outType.getElementType());
   }
 
   rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
@@ -1787,8 +1796,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
   auto outTy =
       cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
-  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
-  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());
 
   rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
                                                     /*broadcast_attr*/ nullptr);
@@ -1961,8 +1970,10 @@ LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
   auto resultType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
-  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
+  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
+                         resultType.getElementType());
+  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs,
+                         resultType.getElementType());
 
   rewriter.replaceOpWithNewOp<stablehlo::RemOp>(op, lhs, rhs);
   return success();
 }
@@ -1979,8 +1990,10 @@ LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
   auto resultType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
-  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
+  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
+                         resultType.getElementType());
+  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs,
+                         resultType.getElementType());
 
   stablehlo::MulOp mul;
   auto div = rewriter.create<stablehlo::DivOp>(loc, lhs, rhs);

View File

@@ -835,7 +835,8 @@ public:
         llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
     bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
-    bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);
+    bias =
+        hlo::promoteType(rewriter, op.getLoc(), bias, outTy.getElementType());
 
     DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(

View File

@@ -522,7 +522,8 @@ public:
     } else {
       assert(false && "Unsupported pooling dimension");
     }
-    divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
+    divisor = hlo::promoteType(rewriter, op.getLoc(), divisor,
+                               outTy.getElementType());
     DenseI64ArrayAttr bcastDimensions;
 
     rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
         op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
@@ -532,8 +533,8 @@ public:
     // Use another mhlo.ReduceWindowOp to get the divisor
     Value windowSizeConst =
         hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
-    windowSizeConst =
-        hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
+    windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst,
+                                       outTy.getElementType());
     auto inputShapeVec = *hlo::getDimIndexOfTensor(rewriter, op, input);
     auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
         op->getLoc(), inputShapeVec);
@@ -583,7 +584,8 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
   auto inputTy = cast<RankedTensorType>(input.getType());
   auto outTy =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  input =
+      hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
   inputTy = cast<RankedTensorType>(input.getType());
   auto inputElemTy = inputTy.getElementType();
   auto inputRank = inputTy.getRank();

View File

@@ -170,13 +170,10 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
 }
 
 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
-                  TensorType outType) {
-  TensorType in_type = cast<TensorType>(input.getType());
-
-  if (in_type.getElementType() != outType.getElementType()) {
-    TensorType promotedType =
-        in_type.cloneWith(in_type.getShape(), outType.getElementType());
-    return rewriter.create<stablehlo::ConvertOp>(loc, promotedType, input);
+                  Type outElementType) {
+  TensorType inType = cast<TensorType>(input.getType());
+  if (inType.getElementType() != outElementType) {
+    return rewriter.create<stablehlo::ConvertOp>(loc, input, outElementType);
   }
   return input;
 }
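Dropping the explicit result tensor type here relies on the `stablehlo::ConvertOp` builder that takes an operand plus a result element type and derives the result shape from the operand. A minimal sketch of the equivalent explicit form (essentially what the previous implementation spelled out):

    // Equivalent explicit construction: clone the input's shape with the new
    // element type, then convert to that tensor type.
    TensorType promotedType =
        inType.cloneWith(inType.getShape(), outElementType);
    return rewriter.create<stablehlo::ConvertOp>(loc, promotedType, input);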