[Stablehlo] fix promoteType() when input doesn't have DefiningOp (#2262)

pull/2265/head snapshot-20230626.881
Yuanqiang Liu 2023-06-26 00:04:17 +08:00 committed by GitHub
parent f4e7344276
commit 0548e2ef3b
5 changed files with 29 additions and 26 deletions
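Why this change: the old hlo::promoteType() recovered a Location by calling input.getDefiningOp()->getLoc(). When the input Value has no defining op (typically a block argument such as a function parameter), getDefiningOp() returns nullptr and the call crashes. The fix threads a Location parameter through the helper and updates every call site to pass op.getLoc(). A minimal before/after sketch, assuming a caller inside a conversion pattern (the names blockArg and outType are illustrative):

// Before (old signature): the helper derived the location from
// input.getDefiningOp()->getLoc(), which dereferences nullptr when
// `blockArg` is a BlockArgument with no defining op.
//   Value promoted = hlo::promoteType(rewriter, blockArg, outType);

// After (new signature): the caller supplies the location explicitly,
// typically op.getLoc() of the op being rewritten.
Value promoted = hlo::promoteType(rewriter, op.getLoc(), blockArg, outType);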

View File

@@ -45,7 +45,8 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
 Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
                               Operation *op, Value scalarValue, Type dtype);
-Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
+Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
+                  TensorType outType);
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                           TensorType outType);

View File

@@ -148,7 +148,7 @@ public:
     auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(op.getType())
                        .template cast<TensorType>();
-    self = hlo::promoteType(rewriter, self, outType);
+    self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
     rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
     return success();
   }
@@ -253,8 +253,8 @@ public:
                      ->convertType(op.getType())
                      .template cast<TensorType>();
-    lhs = hlo::promoteType(rewriter, lhs, outTy);
-    rhs = hlo::promoteType(rewriter, rhs, outTy);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
     rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
                                          /*broadcast_attr*/ nullptr);
@@ -300,8 +300,8 @@ public:
       }
     }
-    lhs = hlo::promoteType(rewriter, lhs, outType);
-    rhs = hlo::promoteType(rewriter, rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
     if (!skipMultiplyAlpha(op.getAlpha())) {
       Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
@@ -354,8 +354,8 @@ public:
                                          outElemTy);
     }
     DenseIntElementsAttr bcastDimensions;
-    lhs = hlo::promoteType(rewriter, lhs, outType);
-    rhs = hlo::promoteType(rewriter, rhs, outType);
+    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
     auto loc = op.getLoc();
     Value result =
         rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@@ -427,7 +427,7 @@ public:
     }
     // TODO: what is the PyTorch default type promotion?
-    rhs = hlo::promoteType(rewriter, rhs, lhsTy);
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
     chlo::ComparisonTypeAttr compareTypeAttr;
     chlo::ComparisonDirectionAttr compareDirectionAttr;
@@ -494,8 +494,10 @@ public:
     TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
                              ->convertType(op.getType())
                              .template cast<TensorType>();
-    Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType);
-    Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType);
+    Value lhs =
+        hlo::promoteType(rewriter, op.getLoc(), adaptor.getSelf(), outType);
+    Value rhs =
+        hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);
     DenseIntElementsAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@@ -610,8 +612,8 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
   auto outType =
       getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
   // promote self and other types
-  self = hlo::promoteType(rewriter, self, outType);
-  other = hlo::promoteType(rewriter, other, outType);
+  self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
+  other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
   if (failed(
           broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
@@ -807,8 +809,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
     rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
   }
   DenseIntElementsAttr bcastDimensions;
-  lhs = hlo::promoteType(rewriter, lhs, outType);
-  rhs = hlo::promoteType(rewriter, rhs, outType);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
   auto loc = op.getLoc();
   Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
                                                        bcastDimensions);
@@ -1212,7 +1214,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
   // Promote type
   for (auto &v : builtinTensors) {
-    v = hlo::promoteType(rewriter, v, outType);
+    v = hlo::promoteType(rewriter, op->getLoc(), v, outType);
   }
   rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
@@ -1404,8 +1406,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
   auto outTy =
       this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
-  lhs = hlo::promoteType(rewriter, lhs, outTy);
-  rhs = hlo::promoteType(rewriter, rhs, outTy);
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
   rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
                                                     /*broadcast_attr*/ nullptr);

View File

@@ -785,7 +785,7 @@ public:
     const auto &options = getOptions();
     bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
                                  options.dimSizeIndexBits);
-    bias = hlo::promoteType(rewriter, bias, outTy);
+    bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);
     DenseIntElementsAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(

View File

@@ -484,7 +484,7 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
     Value divisor = hlo::getConstTensor<int64_t>(
                         rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
                         .value();
-    divisor = hlo::promoteType(rewriter, divisor, outTy);
+    divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
     DenseIntElementsAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
         op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
@@ -494,7 +494,8 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
   // Use another stablehlo.ReduceWindowOp to get the divisor
   Value windowSizeConst =
       hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
-  windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy);
+  windowSizeConst =
+      hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
   const auto &options = getOptions();
   auto inputShapeVec =
       *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);

View File

@@ -185,15 +185,14 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
                                               dtype_tensor);
 }
-Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
-  Operation *op = input.getDefiningOp();
-  TensorType in_type = input.getType().dyn_cast<TensorType>();
+Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
+                  TensorType outType) {
+  TensorType in_type = input.getType().cast<TensorType>();
   if (in_type.getElementType() != outType.getElementType()) {
     TensorType promotedType =
         in_type.cloneWith(in_type.getShape(), outType.getElementType());
-    return rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promotedType,
-                                                 input);
+    return rewriter.create<stablehlo::ConvertOp>(loc, promotedType, input);
   }
   return input;
 }
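Note on the new implementation: besides taking the Location as a parameter, it uses cast<TensorType>() instead of dyn_cast<TensorType>(), so a non-tensor input now trips the cast assertion up front rather than producing a null TensorType that is dereferenced on the next line. A hedged usage sketch from inside a ConvertAtenOp pattern; op, adaptor, and outType stand for whatever the surrounding pattern provides:

// adaptor.getSelf() may be a block argument of the converted function; that is
// now safe because the location comes from `op`, not from the value's producer.
Value self =
    hlo::promoteType(rewriter, op.getLoc(), adaptor.getSelf(), outType);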