Add `hasDtype` checks everywhere dtypes are used in decompositions (#1750)

There are several decompositions that assume the operands of the op
have dtypes available; however, the only time dtypes are guaranteed to
be present is when the graph has reached the backend contract. In
general, every pass that happens before reaching the backend contract
should not assume dtypes are available and should use `hasDtype` to
check first.

This commit adds `hasDtype` checks to every decomposition that uses
dtypes.
pull/1766/head
Ramiro Leal-Cavazos 2023-01-03 14:19:18 -08:00 committed by GitHub
parent 273664ded6
commit d44bdd2728
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 58 additions and 19 deletions

View File

@ -71,7 +71,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
Type resultType = tensorType.getWithSizesAndDtype(
sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(sizes),
tensorType.getDtype());
tensorType.getOptionalDtype());
return resultType;
}
@ -407,6 +407,11 @@ public:
op, "Expected a boolean value for half_to_float");
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
}
Type resultTensorDtype = resultTensorType.getDtype();
// `torch.ops.aten._softmax`'s softmax with half to float conversion is not
// supported on CPU, but we go ahead with the decomposing.
// TODO: Add an e2e test once upstream support is added.
@ -418,7 +423,7 @@ public:
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
self = rewriter.create<AtenToDtypeOp>(
loc, resultTensorType, self,
getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
/*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
}
Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
@ -558,8 +563,8 @@ public:
return failure();
BaseTensorType valueTensorType =
inputType
.getWithSizesAndDtype(indicesTensorType.getSizes(),
inputType.getDtype())
.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
inputType.getOptionalDtype())
.cast<BaseTensorType>();
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
@ -568,7 +573,9 @@ public:
// 0th dimension.
if (dim.getType().isa<Torch::NoneType>()) {
BaseTensorType flattenType =
inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
inputType
.getWithSizesAndDtype({kUnknownSize},
inputType.getOptionalDtype())
.cast<BaseTensorType>();
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value end = rewriter.create<ConstantIntOp>(
@ -923,7 +930,7 @@ public:
sizes.append(inputShape.begin(), inputShape.end());
sizes[cstDim] = kUnknownSize;
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
selfTy.getDtype());
selfTy.getOptionalDtype());
Value slice0 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, negShift, constNone, constOne);
Value slice1 = rewriter.create<AtenSliceTensorOp>(
@ -1057,7 +1064,7 @@ public:
reshapedSizes.push_back(scaledSize);
}
Type dtype = self.getType().cast<ValueTensorType>().getDtype();
Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
Type unsqueezedType = ValueTensorType::get(
context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
Type expandedType = ValueTensorType::get(
@ -1493,10 +1500,8 @@ public:
}
// TODO: Handle integer type operands.
if (!input.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
auto inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non-floating point dtype");
}
@ -2067,7 +2072,7 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
for (int i = 0; i < axis; i++)
meanVarSizes[i] = input.getSizes()[i];
auto meanVarType = input.getWithSizesAndDtype(
llvm::makeArrayRef(meanVarSizes), input.getDtype());
llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
loc, op.getType(), meanVarType, meanVarType, op.getInput(),
op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
@ -2302,7 +2307,7 @@ class DecomposeAtenNativeBatchNormOp
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
runningStatsShapeInt[1] = kUnknownSize;
Type dtype = input.getType().cast<ValueTensorType>().getDtype();
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
Type reshapeType = ValueTensorType::get(
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
@ -2419,6 +2424,10 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
if (dtype.getType().isa<Torch::NoneType>()) {
BaseTensorType tensorType =
op.getSelf().getType().template cast<BaseTensorType>();
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected input tensor to have a dtype");
}
dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
}
@ -2439,6 +2448,10 @@ public:
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
if (!outTy.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
}
SmallVector<int64_t> empty;
auto dtype =
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@ -2479,7 +2492,7 @@ public:
SmallVector<int64_t> transposeShape =
llvm::to_vector(llvm::reverse(weightType.getSizes()));
Type transposeType = weightType.getWithSizesAndDtype(
llvm::makeArrayRef(transposeShape), weightType.getDtype());
llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
Value transposeWeight =
rewriter.create<AtenTOp>(loc, transposeType, weight);
@ -2542,6 +2555,10 @@ public:
LogicalResult matchAndRewrite(AtenFullLikeOp op,
PatternRewriter &rewriter) const override {
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
if (!outTy.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
}
SmallVector<int64_t> empty;
auto dtype =
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@ -2598,7 +2615,12 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_ToCopyOp op,
PatternRewriter &rewriter) const override {
Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
auto resultType = op.getType().cast<BaseTensorType>();
if (!resultType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
}
Type resultDtype = resultType.getDtype();
Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
resultDtype);
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
@ -2618,7 +2640,12 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCopyOp op,
PatternRewriter &rewriter) const override {
Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
auto resultType = op.getType().cast<BaseTensorType>();
if (!resultType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
}
Type resultDtype = resultType.getDtype();
Value srcToDtype =
convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype);
rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), srcToDtype,
@ -2638,6 +2665,10 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::NoneType>()) {
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected input tensor to have a dtype");
}
dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
}
@ -2980,6 +3011,10 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
Type outputType = op.getType();
BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
if (!outputTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(op,
"expected result type to have a dtype");
}
Type newOutputType = outputTensorType.getWithSizesAndDtype(
outputTensorType.getSizes(), rewriter.getF64Type());
if (!inputTensorTy.hasDtype() ||
@ -3169,8 +3204,8 @@ public:
} else {
sizes.resize(srcShape.size() + 1, kUnknownSize);
}
Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
srcTensorType.getDtype());
Type srcType = srcTensorType.getWithSizesAndDtype(
llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
@ -3269,7 +3304,7 @@ public:
BaseTensorType subType =
inputType
.getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
resultType.getDtype())
resultType.getOptionalDtype())
.cast<BaseTensorType>();
Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
@ -3305,6 +3340,10 @@ public:
Location loc = op.getLoc();
Type resultType = op.getType();
BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
}
int64_t cstLow, cstHigh;
if (!matchPattern(op.getLow(), m_TorchConstantInt(&cstLow)))