Make `getTensorRank` safer by changing return to `Optional<unsigned>` (#1707)

Currently `getTensorRank` returns -1 if it was unable to get the rank
of the tensor. However, not every use in the codebase was checking the
return value, and in some cases, the return value was casted to
unsigned leading to some infinte loops when an unranked tensor reached
a decomposition.

This commit changes the return of `getTensorRank` to
`Optional<unsigned>` to make it clear to the user that the function
can fail.

This commit also changes a couple of for loops that iterate a vector
in reverse order that can potentially become infinite loops into
range-based for loops.
pull/1715/head
Ramiro Leal-Cavazos 2022-12-12 08:56:28 -08:00 committed by GitHub
parent 430737b820
commit 73bd32d06c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 111 additions and 56 deletions

View File

@ -47,8 +47,8 @@ Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
bool isBuiltInType(Type type);
// Helper funtion to get rank of `Base tensor type`.
// -1 is returned if the tensorRank can't be determined.
int getTensorRank(Value tensor);
// llvm::None is returned if the tensorRank can't be determined.
Optional<unsigned> getTensorRank(Value tensor);
bool isViewLikeOp(Operation *op);

View File

@ -279,7 +279,8 @@ public:
// to i32 as required for the scatter op.
// 2.) `values` is mapped to `updates` in scatter op.
// 3.) `input` is mapped to `original` in scatter op.
if (getTensorRank(indexTensor) != 1)
Optional<unsigned> indexTensorRank = getTensorRank(indexTensor);
if (!indexTensorRank || *indexTensorRank != 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: index tensor with rank != 1 is not supported");
auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();

View File

@ -136,8 +136,9 @@ static Value getScalarValue(Value input, Location loc,
}
Value scalar = nullptr;
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
if (valueTensorLiteralOp &&
getTensorRank(valueTensorLiteralOp.getResult()) == 0) {
Optional<unsigned> tensorRank =
getTensorRank(valueTensorLiteralOp.getResult());
if (valueTensorLiteralOp && tensorRank && *tensorRank == 0) {
auto tensorType =
valueTensorLiteralOp.getValue().getType().cast<RankedTensorType>();
if (tensorType.getElementType().isa<mlir::IntegerType>()) {

View File

@ -202,12 +202,13 @@ public:
op, "Expected a constant boolean value for keepDim");
Value input = op.getSelf();
std::sort(dims.begin(), dims.end());
// For every dimension included in `dim` of the op, iterated over in
// reverse order, we create a call to aten.max.dim.
for (int64_t i = dims.size() - 1; i >= 0; i--) {
std::sort(dims.begin(), dims.end());
std::reverse(dims.begin(), dims.end());
for (int64_t dimInt : dims) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dims[i]));
loc, rewriter.getI64IntegerAttr(dimInt));
// The input to the next invocation of aten.max.dim is the output of the
// previous aten.max.dim op.
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
@ -227,11 +228,12 @@ public:
Location loc = op.getLoc();
Value self = op.getSelf();
MLIRContext *context = op.getContext();
int64_t rank = getTensorRank(self);
if (rank < 0)
Optional<unsigned> maybeRank = getTensorRank(self);
if (!maybeRank)
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
unsigned rank = *maybeRank;
SmallVector<Value> sizes;
for (int i = 0; i < rank; i++) {
for (unsigned i = 0; i < rank; i++) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
@ -546,7 +548,12 @@ public:
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
Optional<unsigned> maybeInputRank = getTensorRank(input);
if (!maybeInputRank) {
return rewriter.notifyMatchFailure(
op, "expected input tensor to have a rank");
}
unsigned inputRank = *maybeInputRank;
if (!indicesTensorType.hasSizes())
return failure();
BaseTensorType valueTensorType =
@ -565,7 +572,7 @@ public:
.cast<BaseTensorType>();
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(getTensorRank(input) - 1));
loc, rewriter.getI64IntegerAttr(inputRank - 1));
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
dim, end);
}
@ -674,8 +681,14 @@ public:
Value lhs = op.getSelf();
Value rhs = op.getOther();
int lhsRank = getTensorRank(lhs);
int rhsRank = getTensorRank(rhs);
Optional<unsigned> maybeLhsRank = getTensorRank(lhs);
Optional<unsigned> maybeRhsRank = getTensorRank(rhs);
if (!maybeLhsRank || !maybeRhsRank) {
return rewriter.notifyMatchFailure(
op, "expected input tensors to have a rank");
}
unsigned lhsRank = *maybeLhsRank;
unsigned rhsRank = *maybeRhsRank;
if (lhsRank == 2 && rhsRank == 2) {
// If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
@ -773,15 +786,17 @@ public:
LogicalResult matchAndRewrite(AtenTOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.getSelf();
int lhsRank = getTensorRank(lhs);
Optional<unsigned> lhsRank = getTensorRank(lhs);
auto loc = op.getLoc();
if (lhsRank > 2 || lhsRank < 0) {
if (!lhsRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
} else if (*lhsRank > 2) {
std::string errorMessage =
"t() expects a tensor with <=2 dimensions, but self is " +
std::to_string(lhsRank) + "D";
std::to_string(*lhsRank) + "D";
return rewriter.notifyMatchFailure(op, errorMessage.c_str());
} else if (lhsRank < 2)
} else if (*lhsRank < 2)
rewriter.replaceOp(op, lhs);
else {
Value zero =
@ -846,9 +861,10 @@ public:
loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
};
int rank = getTensorRank(self);
if (rank < 0)
Optional<unsigned> maybeRank = getTensorRank(self);
if (!maybeRank)
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
unsigned rank = *maybeRank;
Value output = self;
auto nShifts = shifts.size();
for (size_t k = 0; k < nShifts; ++k) {
@ -901,16 +917,17 @@ public:
Location loc = op.getLoc();
Value self = op.getSelf();
MLIRContext *context = op.getContext();
int rank = getTensorRank(self);
if (rank < 0)
Optional<unsigned> maybeRank = getTensorRank(self);
if (!maybeRank)
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
unsigned rank = *maybeRank;
SmallVector<Value> repeats;
if (!getListConstructElements(op.getRepeats(), repeats))
return rewriter.notifyMatchFailure(
op, "Unimplemented: repeats not list of Scalar");
if (rank > (int)repeats.size()) {
if (rank > repeats.size()) {
return rewriter.notifyMatchFailure(
op, "repeats are not matched with self's rank");
}
@ -946,7 +963,7 @@ public:
auto selfType = self.getType().dyn_cast<BaseTensorType>();
auto selfShape = selfType.getSizes();
for (int i = 0; i < rank; i++) {
for (unsigned i = 0; i < rank; i++) {
auto scale = repeats[i + leadingRank];
Value dimSize;
if (selfShape[i] == kUnknownSize) {
@ -1003,9 +1020,10 @@ public:
Location loc = op.getLoc();
Value self = op.getSelf();
MLIRContext *context = op.getContext();
int64_t rank = getTensorRank(self);
if (rank < 0)
Optional<unsigned> maybeRank = getTensorRank(self);
if (!maybeRank)
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
unsigned rank = *maybeRank;
int64_t start, end;
if (!matchPattern(op.getStartDim(), m_TorchConstantInt(&start)) ||
@ -1239,6 +1257,13 @@ public:
Location loc = op.getLoc();
MLIRContext *context = op.getContext();
Value gradOutput = op.getGradOutput();
Optional<unsigned> maybeGradRank = getTensorRank(gradOutput);
if (!maybeGradRank) {
return rewriter.notifyMatchFailure(op,
"expected grad output to have a rank");
}
unsigned gradRank = *maybeGradRank;
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
@ -1249,10 +1274,8 @@ public:
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(false));
Value gradOutput = op.getGradOutput();
Value input = op.getInput();
Value weight = op.getWeight();
auto gradRank = getTensorRank(gradOutput);
if (gradRank != 4)
return rewriter.notifyMatchFailure(
@ -1299,7 +1322,7 @@ public:
// Rotate weight.
SmallVector<Value> axes;
for (auto i = 2; i < gradRank; i++) {
for (unsigned i = 2; i < gradRank; i++) {
axes.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
}
@ -1309,7 +1332,7 @@ public:
axesList);
// Calculate padding for first convolution.
SmallVector<Value> gradInputPaddingValues;
for (auto i = 2; i < gradRank; i++) {
for (unsigned i = 2; i < gradRank; i++) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
Value outDim = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
@ -1359,7 +1382,7 @@ public:
loc, gradWeight.getType(), gradWeight, cstZero, cstOne);
SmallVector<Value> dimIntList{cstZero};
for (auto i = 2; i < gradRank; i++)
for (unsigned i = 2; i < gradRank; i++)
dimIntList.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
Value gradIntList = rewriter.create<Torch::PrimListConstructOp>(
@ -1387,9 +1410,11 @@ public:
Value input = op.getSelf();
Value mat1 = op.getMat1();
Value mat2 = op.getMat2();
Optional<unsigned> mat1Rank = getTensorRank(mat1);
Optional<unsigned> mat2Rank = getTensorRank(mat2);
// The operands `mat1`, `mat2` to aten.addmm must be of rank 2.
if (getTensorRank(mat1) != 2 || getTensorRank(mat2) != 2) {
if (!mat1Rank || !mat2Rank || *mat1Rank != 2 || *mat2Rank != 2) {
return rewriter.notifyMatchFailure(
op, "expected mat1, mat2 operands to aten.addmm to be rank 2");
}
@ -1447,7 +1472,12 @@ public:
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
unsigned inputRank = getTensorRank(input);
Optional<unsigned> maybeInputRank = getTensorRank(input);
if (!maybeInputRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
unsigned inputRank = *maybeInputRank;
Value dimList = op.getDim();
Value keepDim = op.getKeepdim();
Value dtype = op.getDtype();
@ -1572,7 +1602,11 @@ public:
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
unsigned inputRank = getTensorRank(self);
Optional<unsigned> maybeInputRank = getTensorRank(self);
if (!maybeInputRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
unsigned inputRank = *maybeInputRank;
BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
if (!rank0FloatTensorTy.hasSizes() ||
rank0FloatTensorTy.getSizes().size() != 0) {
@ -2129,10 +2163,11 @@ class DecomposeAtenNativeBatchNormOp
// Rank of the input tensor must be greater than or equal to 2. The shape of
// the `input` is supposed to be (N, C, D?, H?, W?).
int64_t inputRank = getTensorRank(input);
if (inputRank < 2)
Optional<unsigned> maybeInputRank = getTensorRank(input);
if (!maybeInputRank || *maybeInputRank < 2)
return rewriter.notifyMatchFailure(
op, "input must have rank greater than or equal to 2");
unsigned inputRank = *maybeInputRank;
// In the inference mode, the `runningMean` and `runningVar` must not be
// None.
@ -2142,7 +2177,10 @@ class DecomposeAtenNativeBatchNormOp
op, "running stats must not be None in inference mode");
// Rank of `runningMean` and `runningVar` must be exactly 1.
if (getTensorRank(runningMean) != 1 || getTensorRank(runningVar) != 1)
Optional<unsigned> runningMeanRank = getTensorRank(runningMean);
Optional<unsigned> runningVarRank = getTensorRank(runningVar);
if (!runningMeanRank || !runningVarRank || *runningMeanRank != 1 ||
*runningVarRank != 1)
return rewriter.notifyMatchFailure(
op, "expected runningMean and runningVar to be rank 1");
@ -2191,7 +2229,8 @@ class DecomposeAtenNativeBatchNormOp
Value batchNormOutput = normalizedInput;
if (!weight.getType().isa<Torch::NoneType>()) {
// Rank of `weight` must be exactly 1.
if (getTensorRank(weight) != 1)
Optional<unsigned> weightRank = getTensorRank(weight);
if (!weightRank || *weightRank != 1)
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
runningStatsSizeList);
@ -2200,7 +2239,8 @@ class DecomposeAtenNativeBatchNormOp
}
if (!bias.getType().isa<Torch::NoneType>()) {
// Rank of `bias` must be exactly 1.
if (getTensorRank(bias) != 1)
Optional<unsigned> biasRank = getTensorRank(bias);
if (!biasRank || *biasRank != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
runningStatsSizeList);
@ -2619,7 +2659,11 @@ class DecomposeAtenAdaptiveAvgPool2dOp
MLIRContext *context = op.getContext();
Value input = op.getSelf();
int64_t rank = getTensorRank(input);
Optional<unsigned> maybeRank = getTensorRank(input);
if (!maybeRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
unsigned rank = *maybeRank;
SmallVector<Value, 2> inputHW;
Value dimH = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 2));
@ -2788,12 +2832,19 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
int64_t inputRank = getTensorRank(self);
Optional<unsigned> maybeInputRank = getTensorRank(self);
if (!maybeInputRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
unsigned inputRank = *maybeInputRank;
SmallVector<Value> dimListElements;
for (int64_t i = inputRank - 1; i >= 0; i--)
SmallVector<int> dimListInts(llvm::reverse(
llvm::iota_range<int>(0, inputRank, /*inclusive=*/false)));
for (int dimListInt : dimListInts) {
dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
loc, rewriter.getI64IntegerAttr(dimListInt)));
}
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
dimListElements);
@ -2828,7 +2879,11 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
inputTensorTy = self.getType().cast<BaseTensorType>();
}
unsigned inputRank = getTensorRank(self);
Optional<unsigned> maybeInputRank = getTensorRank(self);
if (!maybeInputRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
unsigned inputRank = *maybeInputRank;
SmallVector<Value> dimListElements;
bool isNoneOrEmpty = true;
if (!dimList.getType().template isa<Torch::NoneType>()) {

View File

@ -348,10 +348,12 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenTOp>();
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
int lhsRank = getTensorRank(op.getSelf());
int rhsRank = getTensorRank(op.getOther());
Optional<unsigned> lhsRank = getTensorRank(op.getSelf());
Optional<unsigned> rhsRank = getTensorRank(op.getOther());
if (!lhsRank || !rhsRank)
return false;
// Make aten.matmul legal if the following condition is satisfied.
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
return (*lhsRank != 2 || *rhsRank != 2) && (*lhsRank != 3 || *rhsRank != 3);
});
target.addIllegalOp<AtenAddcmulOp>();
target.addIllegalOp<AtenAddcdivOp>();

View File

@ -139,15 +139,11 @@ bool Torch::isBuiltInType(Type type) {
return isa<BuiltinDialect>(type.getDialect());
}
int Torch::getTensorRank(Value tensor) {
int tensorRank = -1;
Optional<unsigned> Torch::getTensorRank(Value tensor) {
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
if (tensorType.hasSizes()) {
ArrayRef<int64_t> tensorShape = tensorType.getSizes();
tensorRank = tensorShape.size();
}
return tensorRank;
if (!tensorType.hasSizes())
return llvm::None;
return tensorType.getSizes().size();
}
bool Torch::isViewLikeOp(Operation *op) {