mirror of https://github.com/llvm/torch-mlir
Add `hasDtype` checks everywhere dtypes are used in decompositions (#1750)
There are several decompositions that assume the operands of the op have dtypes available; however, the only time dtypes are guaranteed to be present is when the graph has reached the backend contract. In general, every pass that happens before reaching the backend contract should not assume dtypes are available and should use `hasDtype` to check first. This commit adds `hasDtype` checks to every decomposition that uses dtypes.
parent 273664ded6
commit d44bdd2728
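The recurring fix is the same guard in every pattern: check `hasDtype()` before calling `getDtype()`, and fail the match gracefully otherwise. A minimal sketch of that pattern, using the torch-mlir helpers that appear in the diff below (`SomeAtenOp` is a hypothetical placeholder, not an op touched by this commit):

LogicalResult matchAndRewrite(SomeAtenOp op,
                              PatternRewriter &rewriter) const override {
  auto resultType = op.getType().cast<BaseTensorType>();
  // Before the backend contract, the dtype may still be unknown; bail out
  // with a match failure instead of asserting inside getDtype().
  if (!resultType.hasDtype()) {
    return rewriter.notifyMatchFailure(
        op, "expected result type to have a dtype");
  }
  Type resultDtype = resultType.getDtype(); // safe after the guard
  // ... rest of the decomposition ...
}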
@@ -71,7 +71,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
   Type resultType = tensorType.getWithSizesAndDtype(
       sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
                         : llvm::makeArrayRef(sizes),
-      tensorType.getDtype());
+      tensorType.getOptionalDtype());
   return resultType;
 }
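Where a dtype is only forwarded into a newly constructed type, as in the hunk above, the commit switches from `getDtype()` to `getOptionalDtype()`. The assumption here is that `getDtype()` requires a dtype to be present, while `getOptionalDtype()` returns a null `Type` when the dtype is unknown, so `getWithSizesAndDtype` simply propagates the unknown-dtype state instead of tripping an assertion. A hedged sketch of that usage:

// Sketch, assuming getOptionalDtype() yields a null Type for unknown dtypes:
// the constructed type keeps "dtype unknown" rather than crashing.
static Type cloneWithSizes(BaseTensorType tensorType, ArrayRef<int64_t> sizes) {
  return tensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
                                         tensorType.getOptionalDtype());
}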
@@ -407,6 +407,11 @@ public:
           op, "Expected a boolean value for half_to_float");
 
     BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
+    if (!resultTensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultTensorDtype = resultTensorType.getDtype();
     // `torch.ops.aten._softmax`'s softmax with half to float conversion is not
     // supported on CPU, but we go ahead with the decomposing.
     // TODO: Add an e2e test once upstream support is added.
@@ -418,7 +423,7 @@ public:
       Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
       self = rewriter.create<AtenToDtypeOp>(
           loc, resultTensorType, self,
-          getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
+          getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
           /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
     }
     Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
@@ -558,8 +563,8 @@ public:
       return failure();
     BaseTensorType valueTensorType =
         inputType
-            .getWithSizesAndDtype(indicesTensorType.getSizes(),
-                                  inputType.getDtype())
+            .getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
+                                  inputType.getOptionalDtype())
             .cast<BaseTensorType>();
 
     // If the dim type is `NoneType` i.e. reduce along all the dimensions.
@@ -568,7 +573,9 @@ public:
     // 0th dimension.
     if (dim.getType().isa<Torch::NoneType>()) {
       BaseTensorType flattenType =
-          inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
+          inputType
+              .getWithSizesAndDtype({kUnknownSize},
+                                    inputType.getOptionalDtype())
               .cast<BaseTensorType>();
       dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
       Value end = rewriter.create<ConstantIntOp>(
@@ -923,7 +930,7 @@ public:
     sizes.append(inputShape.begin(), inputShape.end());
     sizes[cstDim] = kUnknownSize;
     Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
-                                               selfTy.getDtype());
+                                               selfTy.getOptionalDtype());
     Value slice0 = rewriter.create<AtenSliceTensorOp>(
         loc, sliceTy, input, dim, negShift, constNone, constOne);
     Value slice1 = rewriter.create<AtenSliceTensorOp>(
@@ -1057,7 +1064,7 @@ public:
       reshapedSizes.push_back(scaledSize);
     }
 
-    Type dtype = self.getType().cast<ValueTensorType>().getDtype();
+    Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
    Type unsqueezedType = ValueTensorType::get(
         context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
     Type expandedType = ValueTensorType::get(
@@ -1493,10 +1500,8 @@ public:
     }
 
     // TODO: Handle integer type operands.
-    if (!input.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    auto inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: non-floating point dtype");
     }
@@ -2067,7 +2072,7 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
     for (int i = 0; i < axis; i++)
       meanVarSizes[i] = input.getSizes()[i];
     auto meanVarType = input.getWithSizesAndDtype(
-        llvm::makeArrayRef(meanVarSizes), input.getDtype());
+        llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
     auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
         loc, op.getType(), meanVarType, meanVarType, op.getInput(),
         op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
@@ -2302,7 +2307,7 @@ class DecomposeAtenNativeBatchNormOp
 
     SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
     runningStatsShapeInt[1] = kUnknownSize;
-    Type dtype = input.getType().cast<ValueTensorType>().getDtype();
+    Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
     Type reshapeType = ValueTensorType::get(
         context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
@@ -2419,6 +2424,10 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
     if (dtype.getType().isa<Torch::NoneType>()) {
       BaseTensorType tensorType =
           op.getSelf().getType().template cast<BaseTensorType>();
+      if (!tensorType.hasDtype()) {
+        return rewriter.notifyMatchFailure(
+            op, "expected input tensor to have a dtype");
+      }
       dtype =
           getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
     }
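The `new`-like ops in the hunk above take an optional `dtype` argument; only when it is `None` does the decomposition fall back to the input tensor's dtype, so the guard sits inside that branch. A condensed sketch of the control flow (names taken from the diff):

Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::NoneType>()) {
  // Falling back to the input's dtype is only legal if that dtype is known.
  BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
  if (!tensorType.hasDtype())
    return rewriter.notifyMatchFailure(
        op, "expected input tensor to have a dtype");
  // Materialize the MLIR dtype as the torch integer dtype constant.
  dtype = getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
}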
@@ -2439,6 +2448,10 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    if (!outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@@ -2479,7 +2492,7 @@ public:
     SmallVector<int64_t> transposeShape =
         llvm::to_vector(llvm::reverse(weightType.getSizes()));
     Type transposeType = weightType.getWithSizesAndDtype(
-        llvm::makeArrayRef(transposeShape), weightType.getDtype());
+        llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
     Value transposeWeight =
         rewriter.create<AtenTOp>(loc, transposeType, weight);
@@ -2542,6 +2555,10 @@ public:
   LogicalResult matchAndRewrite(AtenFullLikeOp op,
                                 PatternRewriter &rewriter) const override {
     BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
+    if (!outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
@@ -2598,7 +2615,12 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    auto resultType = op.getType().cast<BaseTensorType>();
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultDtype = resultType.getDtype();
     Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
                                                    resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
@@ -2618,7 +2640,12 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    auto resultType = op.getType().cast<BaseTensorType>();
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+    Type resultDtype = resultType.getDtype();
     Value srcToDtype =
         convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype);
     rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), srcToDtype,
@@ -2638,6 +2665,10 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
     Value dtype = op.getDtype();
     if (dtype.getType().isa<Torch::NoneType>()) {
       BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
+      if (!tensorType.hasDtype()) {
+        return rewriter.notifyMatchFailure(
+            op, "expected input tensor to have a dtype");
+      }
       dtype =
           getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
     }
@@ -2980,6 +3011,10 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
   BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
   Type outputType = op.getType();
   BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
+  if (!outputTensorType.hasDtype()) {
+    return rewriter.notifyMatchFailure(op,
+                                       "expected result type to have a dtype");
+  }
   Type newOutputType = outputTensorType.getWithSizesAndDtype(
       outputTensorType.getSizes(), rewriter.getF64Type());
   if (!inputTensorTy.hasDtype() ||
@@ -3169,8 +3204,8 @@ public:
     } else {
       sizes.resize(srcShape.size() + 1, kUnknownSize);
     }
-    Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
-                                                      srcTensorType.getDtype());
+    Type srcType = srcTensorType.getWithSizesAndDtype(
+        llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
     src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
     rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
         op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
@@ -3269,7 +3304,7 @@ public:
     BaseTensorType subType =
         inputType
             .getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
-                                  resultType.getDtype())
+                                  resultType.getOptionalDtype())
             .cast<BaseTensorType>();
 
     Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
@@ -3305,6 +3340,10 @@ public:
     Location loc = op.getLoc();
     Type resultType = op.getType();
     BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
+    if (!resultTensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
 
     int64_t cstLow, cstHigh;
     if (!matchPattern(op.getLow(), m_TorchConstantInt(&cstLow)))
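For context, all of these patterns live in the DecomposeComplexOps pass. A hedged sketch of how that pass is typically scheduled in a pipeline; the factory name matches torch-mlir's `Passes.h`, but the exact signature may differ across revisions:

#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

// Run the decompositions on each function; with the hasDtype guards above,
// patterns now skip (rather than crash on) tensors whose dtype is unknown.
void addDecomposePass(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::torch::Torch::createDecomposeComplexOpsPass());
}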