mirror of https://github.com/llvm/torch-mlir
[Stablehlo] simplify promoteType (#3525)
Only provide `outElementType` when calling `promoteType`.
parent dcb48dd46c
commit 5bee9aac63
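In short: `hlo::promoteType` now takes only the target element type instead of a full `TensorType`; the result keeps the input's shape, and a `stablehlo.ConvertOp` is emitted only when the element types actually differ. A minimal sketch of the call-site change, using names from this diff (surrounding conversion-pattern boilerplate elided):

  // Before: callers passed a full tensor type, of which only the
  // element type was ever used by promoteType.
  self = hlo::promoteType(rewriter, op.getLoc(), self, outType);

  // After: callers pass the element type directly.
  self = hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType());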
@@ -50,7 +50,7 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
                               Operation *op, Value scalarValue, Type dtype);

 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
-                  TensorType outType);
+                  Type outElementType);

 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                           TensorType outType);
@@ -148,7 +148,8 @@ public:
     auto outType = cast<TensorType>(
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
-    self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
+    self =
+        hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType());
     rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
     return success();
   }
@@ -207,7 +208,8 @@ public:
             op.getType()));

     if (isa<mlir::FloatType>(resultTy.getElementType())) {
-      Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
+      Value src = hlo::promoteType(rewriter, op.getLoc(), self,
+                                   resultTy.getElementType());
       rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
       return success();
     } else {
@@ -334,8 +336,8 @@ public:
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));

-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());

     rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
                                          /*broadcast_attr*/ nullptr);
@@ -381,8 +383,8 @@ public:
       }
     }

-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);

     if (!skipMultiplyAlpha(op.getAlpha())) {
       Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
@@ -437,8 +439,8 @@ public:
       }
     }
     DenseI64ArrayAttr bcastDimensions;
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
     auto loc = op.getLoc();
     Value result =
         rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@@ -476,16 +478,17 @@ public:
       if (isa<mlir::FloatType>(outElemTy))
         result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
       else if (!outElemTy.isUnsignedInteger()) {
-        TensorType defaultIntToFloatType =
-            outType.cloneWith(outType.getShape(), rewriter.getF64Type());
+        Type defaultIntToFloatType = rewriter.getF64Type();
         lhs =
             hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType);
         rhs =
             hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType);
-        result = rewriter.create<ChloOpT>(loc, defaultIntToFloatType, lhs, rhs,
-                                          bcastDimensions);
+        result = rewriter.create<ChloOpT>(
+            loc, outType.cloneWith(outType.getShape(), defaultIntToFloatType),
+            lhs, rhs, bcastDimensions);
         result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
-        result = hlo::promoteType(rewriter, op.getLoc(), result, outType);
+        result = hlo::promoteType(rewriter, op.getLoc(), result,
+                                  outType.getElementType());
       }
     }
     rewriter.replaceOp(op, result);
@@ -517,7 +520,8 @@ public:
       rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
                                          rhs.getType());
       // use lhs's element type as compute type
-      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
+      rhs =
+          hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
       rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
     }

@@ -533,16 +537,16 @@ public:
     }

     if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
-      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
+      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
     } else if (isa<mlir::FloatType>(lhsElemTy) &&
                isa<mlir::IntegerType>(rhsElemTy)) {
-      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
+      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
     } else {
       if (lhsElemTy.getIntOrFloatBitWidth() >
           rhsElemTy.getIntOrFloatBitWidth()) {
-        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
+        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
       } else {
-        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
+        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
       }
     }
     lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
@@ -622,11 +626,11 @@ public:
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
     Type outElemTy = outType.getElementType();
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
     if (!rhsTy) {
       rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
     }
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);

     DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@@ -736,8 +740,10 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
   auto outType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
   // promote self and other types
-  self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
-  other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
+  self =
+      hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType());
+  other =
+      hlo::promoteType(rewriter, op.getLoc(), other, outType.getElementType());

   if (failed(broadcastRanks(rewriter, op, self, cond)))
     return op.emitError("failed broadcast self and condition ranks");
@@ -940,8 +946,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
     rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
   }
   DenseI64ArrayAttr bcastDimensions;
-  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
   auto loc = op.getLoc();
   Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
                                                        bcastDimensions);
@@ -977,8 +983,8 @@ LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
     lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
   }
   DenseI64ArrayAttr bcastDimensions;
-  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
-  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
   auto loc = op.getLoc();
   Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
                                                        bcastDimensions);
@@ -1121,7 +1127,8 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
     return op.emitError("only ranked tensor type is supported.");
   }
   auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  input =
+      hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());

   auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input);
   auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
@@ -1143,7 +1150,8 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
   }

   auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  input =
+      hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());

   auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input);
   auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
@@ -1266,42 +1274,44 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
                                        "non-bool cudnn_enabled unsupported");
   }
   if (training) {
-    Type outputTy = getTypeConverter()->convertType(op.getType());
-    Type batchMeanOrVarTy =
-        RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
+    TensorType outputTy =
+        cast<TensorType>(getTypeConverter()->convertType(op.getType()));

     Value output;
     // supported mixed types, like input type is fp16 and weight type is fp32.
     if (inputTy.getElementType() != weightTy.getElementType()) {
-      RankedTensorType convertedType = inputTy;
-      if (cast<FloatType>(weightTy.getElementType()).getWidth() >
-          cast<FloatType>(inputTy.getElementType()).getWidth()) {
-        convertedType = RankedTensorType::get(inputTy.getShape(),
-                                              weightTy.getElementType());
+      Type computeType = inputTy.getElementType();
+      if (weightTy.getElementType().getIntOrFloatBitWidth() >
+          inputTy.getElementType().getIntOrFloatBitWidth()) {
+        computeType = weightTy.getElementType();
       }
-      input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType);
-      weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
-      bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
+      input = hlo::promoteType(rewriter, op.getLoc(), input, computeType);
+      weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType);
+      bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType);
       auto batchNormTrainingResult =
           rewriter.create<stablehlo::BatchNormTrainingOp>(
-              op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
+              op.getLoc(),
+              RankedTensorType::get(inputTy.getShape(), computeType),
+              RankedTensorType::get(weightTy.getShape(), computeType),
+              RankedTensorType::get(weightTy.getShape(), computeType), input,
               weight, bias, rewriter.getF32FloatAttr(eps),
               rewriter.getI64IntegerAttr(feature_index));
       output = hlo::promoteType(rewriter, op.getLoc(),
                                 batchNormTrainingResult.getResult(0),
-                                cast<TensorType>(outputTy));
+                                outputTy.getElementType());
     } else {
       auto batchNormTrainingResult =
           rewriter.create<stablehlo::BatchNormTrainingOp>(
-              op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
-              weight, bias, rewriter.getF32FloatAttr(eps),
+              op.getLoc(), outputTy, weightTy, weightTy, input, weight, bias,
+              rewriter.getF32FloatAttr(eps),
               rewriter.getI64IntegerAttr(feature_index));
       output = batchNormTrainingResult.getResult(0);
     }
     rewriter.replaceOp(op, output);
     return success();
   } else {
-    Type outputTy = getTypeConverter()->convertType(op.getType());
+    TensorType outputTy =
+        cast<TensorType>(getTypeConverter()->convertType(op.getType()));
     SmallVector<int64_t, 4> castShape{inputTy.getShape().begin(),
                                       inputTy.getShape().end()};
     castShape[1] = weightTy.getShape()[0];
@@ -1314,26 +1324,25 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
     Value output;
     // supported mixed types, like input type is fp16 and weight type is fp32.
     if (inputTy.getElementType() != weightTy.getElementType()) {
-      RankedTensorType convertedType = inputTy;
-      if (cast<FloatType>(weightTy.getElementType()).getWidth() >
-          cast<FloatType>(inputTy.getElementType()).getWidth()) {
-        convertedType = RankedTensorType::get(inputTy.getShape(),
-                                              weightTy.getElementType());
+      Type computeType = inputTy.getElementType();
+      if (weightTy.getElementType().getIntOrFloatBitWidth() >
+          inputTy.getElementType().getIntOrFloatBitWidth()) {
+        computeType = weightTy.getElementType();
       }
-      input =
-          hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType);
-      weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
-      bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
+      input = hlo::promoteType(rewriter, op.getLoc(), inputCasted, computeType);
+      weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType);
+      bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType);
       runningMean =
-          hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType);
+          hlo::promoteType(rewriter, op.getLoc(), runningMean, computeType);
       runningVar =
-          hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType);
+          hlo::promoteType(rewriter, op.getLoc(), runningVar, computeType);
       Value bnResult = rewriter.create<stablehlo::BatchNormInferenceOp>(
-          op.getLoc(), convertedType, input, weight, bias, runningMean,
-          runningVar, rewriter.getF32FloatAttr(eps),
+          op.getLoc(), RankedTensorType::get(inputTy.getShape(), computeType),
+          input, weight, bias, runningMean, runningVar,
+          rewriter.getF32FloatAttr(eps),
           rewriter.getI64IntegerAttr(feature_index));
       output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
-                                cast<TensorType>(outputTy));
+                                outputTy.getElementType());
     } else {
       output = rewriter.create<stablehlo::BatchNormInferenceOp>(
           op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
@@ -1515,7 +1524,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(

   // Promote type
   for (auto &v : builtinTensors) {
-    v = hlo::promoteType(rewriter, op->getLoc(), v, outType);
+    v = hlo::promoteType(rewriter, op->getLoc(), v, outType.getElementType());
   }

   rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
@@ -1787,8 +1796,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
   auto outTy =
       cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));

-  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
-  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());

   rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
                                                     /*broadcast_attr*/ nullptr);
@@ -1961,8 +1970,10 @@ LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(

   auto resultType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
-  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
+  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
+                         resultType.getElementType());
+  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs,
+                         resultType.getElementType());
   rewriter.replaceOpWithNewOp<stablehlo::RemOp>(op, lhs, rhs);
   return success();
 }
@@ -1979,8 +1990,10 @@ LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(

   auto resultType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
-  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
+  lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
+                         resultType.getElementType());
+  rhs = hlo::promoteType(rewriter, op->getLoc(), rhs,
+                         resultType.getElementType());

   stablehlo::MulOp mul;
   auto div = rewriter.create<stablehlo::DivOp>(loc, lhs, rhs);
@@ -835,7 +835,8 @@ public:
         llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));

     bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
-    bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);
+    bias =
+        hlo::promoteType(rewriter, op.getLoc(), bias, outTy.getElementType());

     DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
@@ -522,7 +522,8 @@ public:
     } else {
      assert(false && "Unsupported pooling dimension");
     }
-    divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
+    divisor = hlo::promoteType(rewriter, op.getLoc(), divisor,
+                               outTy.getElementType());
     DenseI64ArrayAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
         op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
@@ -532,8 +533,8 @@ public:
     // Use another mhlo.ReduceWindowOp to get the divisor
     Value windowSizeConst =
         hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
-    windowSizeConst =
-        hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
+    windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst,
+                                       outTy.getElementType());
     auto inputShapeVec = *hlo::getDimIndexOfTensor(rewriter, op, input);
     auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
         op->getLoc(), inputShapeVec);
@@ -583,7 +584,8 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
   auto inputTy = cast<RankedTensorType>(input.getType());
   auto outTy =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
+  input =
+      hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
   inputTy = cast<RankedTensorType>(input.getType());
   auto inputElemTy = inputTy.getElementType();
   auto inputRank = inputTy.getRank();
@@ -170,13 +170,10 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
 }

 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
-                  TensorType outType) {
-  TensorType in_type = cast<TensorType>(input.getType());
-
-  if (in_type.getElementType() != outType.getElementType()) {
-    TensorType promotedType =
-        in_type.cloneWith(in_type.getShape(), outType.getElementType());
-    return rewriter.create<stablehlo::ConvertOp>(loc, promotedType, input);
+                  Type outElementType) {
+  TensorType inType = cast<TensorType>(input.getType());
+  if (inType.getElementType() != outElementType) {
+    return rewriter.create<stablehlo::ConvertOp>(loc, input, outElementType);
   }
   return input;
 }
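As a usage sketch of the simplified helper (hypothetical values, not part of this commit; `input`, `loc`, and `rewriter` are assumed to be in scope):

  // Hypothetical example: promote an f32 tensor to f64.
  // For input : tensor<2x3xf32>, this creates a single stablehlo.ConvertOp
  // producing tensor<2x3xf64>; the shape is taken from the input.
  Value promoted =
      hlo::promoteType(rewriter, loc, input, rewriter.getF64Type());
  // If input were already f64, promoteType would return it unchanged
  // without creating any op.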