mirror of https://github.com/llvm/torch-mlir
[Stablehlo] fix promoteType() when input doesn't have DefiningOp (#2262)
parent
f4e7344276
commit
0548e2ef3b
|
@ -45,7 +45,8 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value scalarValue, Type dtype);
|
||||
|
||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
|
||||
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
|
||||
TensorType outType);
|
||||
|
||||
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
||||
TensorType outType);
|
||||
|
|
|
@ -148,7 +148,7 @@ public:
|
|||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
self = hlo::promoteType(rewriter, self, outType);
|
||||
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
|
||||
return success();
|
||||
}
|
||||
|
@ -253,8 +253,8 @@ public:
|
|||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
|
||||
lhs = hlo::promoteType(rewriter, lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outTy);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
||||
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
|
||||
/*broadcast_attr*/ nullptr);
|
||||
|
@ -300,8 +300,8 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
lhs = hlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
|
||||
if (!skipMultiplyAlpha(op.getAlpha())) {
|
||||
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
|
||||
|
@ -354,8 +354,8 @@ public:
|
|||
outElemTy);
|
||||
}
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
lhs = hlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
auto loc = op.getLoc();
|
||||
Value result =
|
||||
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
||||
|
@ -427,7 +427,7 @@ public:
|
|||
}
|
||||
|
||||
// TODO: what is the PyTorch default type promotion?
|
||||
rhs = hlo::promoteType(rewriter, rhs, lhsTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
|
||||
|
||||
chlo::ComparisonTypeAttr compareTypeAttr;
|
||||
chlo::ComparisonDirectionAttr compareDirectionAttr;
|
||||
|
@ -494,8 +494,10 @@ public:
|
|||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType);
|
||||
Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType);
|
||||
Value lhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), adaptor.getSelf(), outType);
|
||||
Value rhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||
|
@ -610,8 +612,8 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
|||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
// promote self and other types
|
||||
self = hlo::promoteType(rewriter, self, outType);
|
||||
other = hlo::promoteType(rewriter, other, outType);
|
||||
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
||||
other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
|
||||
|
||||
if (failed(
|
||||
broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
|
||||
|
@ -807,8 +809,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
|
||||
}
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
lhs = hlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
auto loc = op.getLoc();
|
||||
Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
|
||||
bcastDimensions);
|
||||
|
@ -1212,7 +1214,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
|
||||
// Promote type
|
||||
for (auto &v : builtinTensors) {
|
||||
v = hlo::promoteType(rewriter, v, outType);
|
||||
v = hlo::promoteType(rewriter, op->getLoc(), v, outType);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
|
||||
|
@ -1404,8 +1406,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
|
|||
auto outTy =
|
||||
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||
|
||||
lhs = hlo::promoteType(rewriter, lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outTy);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
||||
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
|
||||
/*broadcast_attr*/ nullptr);
|
||||
|
|
|
@ -785,7 +785,7 @@ public:
|
|||
const auto &options = getOptions();
|
||||
bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
|
||||
options.dimSizeIndexBits);
|
||||
bias = hlo::promoteType(rewriter, bias, outTy);
|
||||
bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
|
||||
|
|
|
@ -484,7 +484,7 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
Value divisor = hlo::getConstTensor<int64_t>(
|
||||
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
||||
.value();
|
||||
divisor = hlo::promoteType(rewriter, divisor, outTy);
|
||||
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
|
||||
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
|
||||
|
@ -494,7 +494,8 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
// Use another stablehlo.ReduceWindowOp to get the divisor
|
||||
Value windowSizeConst =
|
||||
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||
windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy);
|
||||
windowSizeConst =
|
||||
hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeVec =
|
||||
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
|
|
|
@ -185,15 +185,14 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
|||
dtype_tensor);
|
||||
}
|
||||
|
||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
|
||||
Operation *op = input.getDefiningOp();
|
||||
TensorType in_type = input.getType().dyn_cast<TensorType>();
|
||||
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
|
||||
TensorType outType) {
|
||||
TensorType in_type = input.getType().cast<TensorType>();
|
||||
|
||||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promotedType =
|
||||
in_type.cloneWith(in_type.getShape(), outType.getElementType());
|
||||
return rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promotedType,
|
||||
input);
|
||||
return rewriter.create<stablehlo::ConvertOp>(loc, promotedType, input);
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue