[MLIR][TORCH] Patch up Ops and their lowerings to deal with +ve `dim`

-- In Python we have the concept of negative dimension indexing.
-- We would want to normalize such dimensions to be +ve and within the
   expected range instead.
-- This commit takes care of a few remaining Ops and their
   lowerings by applying `toPositiveDim` and `isValidDim` to the
   extracted integer `dim` value.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
pull/2034/head snapshot-20230414.808
Abhishek Varma 2023-04-07 11:49:35 +00:00 committed by Abhishek Varma
parent 1bd5747ca3
commit 318fe13468
11 changed files with 109 additions and 20 deletions

View File

@ -231,6 +231,7 @@ STABLEHLO_PASS_SET = {
"GatherModule_basic",
"Gather2DInputModdule_basic",
"GatherRandomIndexModule_basic",
"GatherNegativeDimModule_basic",
"GeluBackwardModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
@ -243,6 +244,7 @@ STABLEHLO_PASS_SET = {
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexSelectNegativeDimModule_basic",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",

View File

@ -856,10 +856,9 @@ public:
return rewriter.notifyMatchFailure(op, "dim must be constant");
auto inputRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
if (dim < 0)
dim += inputRank + 1;
if (!(0 <= dim && dim <= inputRank))
return rewriter.notifyMatchFailure(op, "statically invalid");
dim = toPositiveDim(dim, inputRank + 1);
if (!isValidDim(dim, inputRank + 1))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
SmallVector<ReassociationIndices> reassociationMap(inputRank);
// From the perspective of the reassociation map, the situation of
@ -1083,11 +1082,6 @@ public:
Location loc = op.getLoc();
TypeConverter *typeConverter = getTypeConverter();
Value dimValue = op.getDim();
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
// Collect all the tensors to be concatenated.
auto tensorList = op.getTensors();
SmallVector<Value> tensorsTorchType;
@ -1112,6 +1106,14 @@ public:
}
int rank = newResultType.getRank();
Value dimValue = op.getDim();
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
dim = toPositiveDim(dim, rank);
if (!isValidDim(dim, rank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
SmallVector<Value> offsets, sizes, strides;
sizes.reserve(rank);
strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
@ -1120,10 +1122,6 @@ public:
for (int i = 0; i < rank; ++i)
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, tensors[0], i));
dim = toPositiveDim(dim, rank);
if (!isValidDim(dim, rank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
// Calculate the size of the `dim` result dimension by adding the dim size
// of each tensor together.
Value resultDimSize = sizes[dim];

View File

@ -79,6 +79,10 @@ public:
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
int64_t inputRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
Value indices = adaptor.getIndex();
Value self = adaptor.getSelf();
@ -476,6 +480,9 @@ public:
int64_t dimInt;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
return op->emitError("unimplemented: dim is not constant");
dimInt = toPositiveDim(dimInt, inputRank);
if (!isValidDim(dimInt, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0];

View File

@ -577,6 +577,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
int64_t dimInt;
if (matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, selfType.getRank());
if (!isValidDim(dimInt, selfType.getRank()))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
dim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dimInt);
} else {
Value inputRank = rewriter.create<arith::ConstantOp>(
@ -1189,6 +1191,9 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op,
"only constant dim param is supported");
}
dim = toPositiveDim(dim, outType.getRank());
if (!isValidDim(dim, outType.getRank()))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
SmallVector<Value> torchTensors;
if (!getListConstructElements(op.getTensors(), torchTensors)) {
@ -1203,9 +1208,8 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
v = hlo::promoteType(rewriter, v, outType);
}
size_t posDim = toPositiveDim(dim, outType.getRank());
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
op, outType, ValueRange(builtinTensors), posDim);
op, outType, ValueRange(builtinTensors), dim);
return success();
}

View File

@ -228,6 +228,10 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");
int64_t inputRank = selfTy.getRank();
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
Value output = gatherTensorAlongSingleAxis(
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);

View File

@ -268,6 +268,10 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");
int64_t inputRank = selfTy.getRank();
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
auto getOptionalVal = [&](Value val) -> std::optional<Value> {
if (val.getType().isa<Torch::NoneType>()) {
@ -343,17 +347,20 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
auto selfTy = self.getType().cast<RankedTensorType>();
if (!selfTy)
return op.emitError("only ranked tensor types are supported");
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");
auto rank = selfTy.getRank();
if (rank == 0)
return rewriter.notifyMatchFailure(
op, "the rank of tensor must be greater than 0");
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");
dim = toPositiveDim(dim, rank);
if (!isValidDim(dim, rank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
if (selfTy.getShape()[dim] != 1) {
if (selfTy.getShape()[dim] == ShapedType::kDynamic)
return rewriter.notifyMatchFailure(
@ -396,6 +403,10 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant");
int64_t inputRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
dim = toPositiveDim(dim, inputRank + 1);
if (!isValidDim(dim, inputRank + 1))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
{dim}, options.dimSizeIndexBits);

View File

@ -717,6 +717,12 @@ class ConvertAtenMultipleDimsReductionOp
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
int64_t N = reduceDims.size();
int64_t inputRank = adaptor.getSelf().getType().template cast<RankedTensorType>().getRank();
for (unsigned i=0; i<N; i++) {
reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
if (!isValidDim(reduceDims[i], inputRank))
return rewriter.notifyMatchFailure(op, "reduce dim is statically invalid");
}
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr =
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
@ -747,6 +753,10 @@ class ConvertAtenOneDimReductionOp
if (!matchPattern(op.getDim(), m_TorchConstantInt(&reduceDim)))
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
int64_t inputRank = adaptor.getSelf().getType().template cast<RankedTensorType>().getRank();
reduceDim = toPositiveDim(reduceDim, inputRank);
if (!isValidDim(reduceDim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
reduceDimsAttr =
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim}));
@ -806,6 +816,11 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
if (!matchPattern(op.getDim(), m_TorchConstantInt(&reduceDim))) {
// NoneType indicates reduce on all dims
reduceDim = -1;
} else {
int64_t inputRank = selfTy.getRank();
reduceDim = toPositiveDim(reduceDim, inputRank);
if (!isValidDim(reduceDim, inputRank))
return rewriter.notifyMatchFailure(op, "reduce dim is statically invalid");
}
bool keepDim = false;

View File

@ -1146,6 +1146,8 @@ OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) {
return nullptr;
ArrayRef<int64_t> sizes = type->getSizes();
dim = toPositiveDim(dim, sizes.size());
if (!isValidDim(dim, sizes.size()))
return nullptr;
return IntegerAttr::get(IntegerType::get(getContext(), 64), sizes[dim]);
}

View File

@ -213,12 +213,16 @@ public:
return rewriter.notifyMatchFailure(
op, "Expected a constant boolean value for keepDim");
Value input = op.getSelf();
Value input = op.getSelf();
// For every dimension included in `dim` of the op, iterated over in
// reverse order, we create a call to aten.max.dim.
std::sort(dims.begin(), dims.end());
std::reverse(dims.begin(), dims.end());
for (int64_t dimInt : dims) {
int64_t inputRank = input.getType().cast<Torch::ValueTensorType>().getSizes().size();
dimInt = toPositiveDim(dimInt, inputRank);
if (!isValidDim(dimInt, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimInt));
// The input to the next invocation of aten.max.dim is the output of the

View File

@ -744,6 +744,29 @@ def GatherModule_basic(module, tu: TestUtils):
# ==============================================================================
class GatherNegativeDimModule(torch.nn.Module):
    """E2E test module exercising torch.gather with a negative dim (-1).

    Gathering along dim=-1 is what drives the negative-dimension
    normalization (toPositiveDim / isValidDim) in the gather lowering.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, tensor, indices):
        # Method-call form; traces to the same aten.gather op.
        return tensor.gather(-1, indices)
@register_test_case(module_factory=lambda: GatherNegativeDimModule())
def GatherNegativeDimModule_basic(module, tu: TestUtils):
    # Rank-3 index tensor to match the rank-3 input annotation.
    inp = tu.rand(2, 3, 4)
    idx = torch.tensor([[[1, 2, 3], [1, 2, 3]]])
    module.forward(inp, idx)
# ==============================================================================
class GatherRandomIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -31,6 +31,25 @@ def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6), torch.tensor([2]))
class IndexSelectNegativeDimModule(torch.nn.Module):
    """E2E test module exercising torch.index_select with a negative dim.

    dim=-1 must be normalized to the last axis by the lowering's
    toPositiveDim / isValidDim handling.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([1], torch.int64, True),
    ])
    def forward(self, input, indices):
        # Method-call form; traces to the same aten.index_select op.
        return input.index_select(-1, indices)
@register_test_case(module_factory=lambda: IndexSelectNegativeDimModule())
def IndexSelectNegativeDimModule_basic(module, tu: TestUtils):
    # Select index 2 along the (negative) last dimension.
    idx = torch.tensor([2])
    module.forward(tu.rand(4, 5, 6), idx)
class IndexSelectTwoIdxModule(torch.nn.Module):
def __init__(self):
super().__init__()