mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Patch up Ops and their lowerings to deal with +ve `dim`
-- In Python we have the concept of negative dimension indexing. -- We would want to normalize such dimensions to be +ve and within the expected range instead. -- This commit takes care of a few remaining set of Ops and their lowerings by applying `toPositiveDim` and `isValidDim` to the extracted integer `dim` value. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>pull/2034/head snapshot-20230414.808
parent
1bd5747ca3
commit
318fe13468
|
@ -231,6 +231,7 @@ STABLEHLO_PASS_SET = {
|
|||
"GatherModule_basic",
|
||||
"Gather2DInputModdule_basic",
|
||||
"GatherRandomIndexModule_basic",
|
||||
"GatherNegativeDimModule_basic",
|
||||
"GeluBackwardModule_basic",
|
||||
"HardswishModule_basic",
|
||||
"HardswishRandomModule_basic",
|
||||
|
@ -243,6 +244,7 @@ STABLEHLO_PASS_SET = {
|
|||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexSelectNegativeDimModule_basic",
|
||||
"LayerNormLastDimModule_basic",
|
||||
"LayerNormModule_basic",
|
||||
"LayerNormNormalizeOverAllDimsModule_basic",
|
||||
|
|
|
@ -856,10 +856,9 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "dim must be constant");
|
||||
auto inputRank =
|
||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
||||
if (dim < 0)
|
||||
dim += inputRank + 1;
|
||||
if (!(0 <= dim && dim <= inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "statically invalid");
|
||||
dim = toPositiveDim(dim, inputRank + 1);
|
||||
if (!isValidDim(dim, inputRank + 1))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
SmallVector<ReassociationIndices> reassociationMap(inputRank);
|
||||
// From the perspective of the reassociation map, the situation of
|
||||
|
@ -1083,11 +1082,6 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
Value dimValue = op.getDim();
|
||||
int64_t dim;
|
||||
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
|
||||
return op.emitError("unimplemented: dim is not constant");
|
||||
|
||||
// Collect all the tensors to be concatenated.
|
||||
auto tensorList = op.getTensors();
|
||||
SmallVector<Value> tensorsTorchType;
|
||||
|
@ -1112,6 +1106,14 @@ public:
|
|||
}
|
||||
|
||||
int rank = newResultType.getRank();
|
||||
Value dimValue = op.getDim();
|
||||
int64_t dim;
|
||||
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
|
||||
return op.emitError("unimplemented: dim is not constant");
|
||||
dim = toPositiveDim(dim, rank);
|
||||
if (!isValidDim(dim, rank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
SmallVector<Value> offsets, sizes, strides;
|
||||
sizes.reserve(rank);
|
||||
strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
|
||||
|
@ -1120,10 +1122,6 @@ public:
|
|||
for (int i = 0; i < rank; ++i)
|
||||
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, tensors[0], i));
|
||||
|
||||
dim = toPositiveDim(dim, rank);
|
||||
if (!isValidDim(dim, rank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
// Calculate the size of the `dim` result dimension by adding the dim size
|
||||
// of each tensor together.
|
||||
Value resultDimSize = sizes[dim];
|
||||
|
|
|
@ -79,6 +79,10 @@ public:
|
|||
int64_t dim;
|
||||
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
|
||||
return op.emitError("unimplemented: dim is not constant");
|
||||
int64_t inputRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
Value indices = adaptor.getIndex();
|
||||
Value self = adaptor.getSelf();
|
||||
|
@ -476,6 +480,9 @@ public:
|
|||
int64_t dimInt;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
|
||||
return op->emitError("unimplemented: dim is not constant");
|
||||
dimInt = toPositiveDim(dimInt, inputRank);
|
||||
if (!isValidDim(dimInt, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
|
||||
resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0];
|
||||
|
|
|
@ -577,6 +577,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
|||
int64_t dimInt;
|
||||
if (matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) {
|
||||
dimInt = toPositiveDim(dimInt, selfType.getRank());
|
||||
if (!isValidDim(dimInt, selfType.getRank()))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
dim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dimInt);
|
||||
} else {
|
||||
Value inputRank = rewriter.create<arith::ConstantOp>(
|
||||
|
@ -1189,6 +1191,9 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"only constant dim param is supported");
|
||||
}
|
||||
dim = toPositiveDim(dim, outType.getRank());
|
||||
if (!isValidDim(dim, outType.getRank()))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
SmallVector<Value> torchTensors;
|
||||
if (!getListConstructElements(op.getTensors(), torchTensors)) {
|
||||
|
@ -1203,9 +1208,8 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
v = hlo::promoteType(rewriter, v, outType);
|
||||
}
|
||||
|
||||
size_t posDim = toPositiveDim(dim, outType.getRank());
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
|
||||
op, outType, ValueRange(builtinTensors), posDim);
|
||||
op, outType, ValueRange(builtinTensors), dim);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -228,6 +228,10 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant dim is currently supported");
|
||||
int64_t inputRank = selfTy.getRank();
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
Value output = gatherTensorAlongSingleAxis(
|
||||
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
|
||||
|
|
|
@ -268,6 +268,10 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
|||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant dim is currently supported");
|
||||
int64_t inputRank = selfTy.getRank();
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
auto getOptionalVal = [&](Value val) -> std::optional<Value> {
|
||||
if (val.getType().isa<Torch::NoneType>()) {
|
||||
|
@ -343,17 +347,20 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
|||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
if (!selfTy)
|
||||
return op.emitError("only ranked tensor types are supported");
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant dim is currently supported");
|
||||
|
||||
auto rank = selfTy.getRank();
|
||||
if (rank == 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "the rank of tensor must be greater than 0");
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant dim is currently supported");
|
||||
dim = toPositiveDim(dim, rank);
|
||||
if (!isValidDim(dim, rank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
if (selfTy.getShape()[dim] != 1) {
|
||||
if (selfTy.getShape()[dim] == ShapedType::kDynamic)
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -396,6 +403,10 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("dim must be a Scalar constant");
|
||||
int64_t inputRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
||||
dim = toPositiveDim(dim, inputRank + 1);
|
||||
if (!isValidDim(dim, inputRank + 1))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
|
||||
{dim}, options.dimSizeIndexBits);
|
||||
|
|
|
@ -717,6 +717,12 @@ class ConvertAtenMultipleDimsReductionOp
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const dim parameter unsupported");
|
||||
int64_t N = reduceDims.size();
|
||||
int64_t inputRank = adaptor.getSelf().getType().template cast<RankedTensorType>().getRank();
|
||||
for (unsigned i=0; i<N; i++) {
|
||||
reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
|
||||
if (!isValidDim(reduceDims[i], inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "reduce dim is statically invalid");
|
||||
}
|
||||
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
|
||||
reduceDimsAttr =
|
||||
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
|
||||
|
@ -747,6 +753,10 @@ class ConvertAtenOneDimReductionOp
|
|||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&reduceDim)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const dim parameter unsupported");
|
||||
int64_t inputRank = adaptor.getSelf().getType().template cast<RankedTensorType>().getRank();
|
||||
reduceDim = toPositiveDim(reduceDim, inputRank);
|
||||
if (!isValidDim(reduceDim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
|
||||
reduceDimsAttr =
|
||||
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim}));
|
||||
|
@ -806,6 +816,11 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&reduceDim))) {
|
||||
// NoneType indicates reduce on all dims
|
||||
reduceDim = -1;
|
||||
} else {
|
||||
int64_t inputRank = selfTy.getRank();
|
||||
reduceDim = toPositiveDim(reduceDim, inputRank);
|
||||
if (!isValidDim(reduceDim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "reduce dim is statically invalid");
|
||||
}
|
||||
|
||||
bool keepDim = false;
|
||||
|
|
|
@ -1146,6 +1146,8 @@ OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
ArrayRef<int64_t> sizes = type->getSizes();
|
||||
dim = toPositiveDim(dim, sizes.size());
|
||||
if (!isValidDim(dim, sizes.size()))
|
||||
return nullptr;
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 64), sizes[dim]);
|
||||
}
|
||||
|
||||
|
|
|
@ -213,12 +213,16 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected a constant boolean value for keepDim");
|
||||
|
||||
Value input = op.getSelf();
|
||||
Value input = op.getSelf();
|
||||
// For every dimension included in `dim` of the op, iterated over in
|
||||
// reverse order, we create a call to aten.max.dim.
|
||||
std::sort(dims.begin(), dims.end());
|
||||
std::reverse(dims.begin(), dims.end());
|
||||
for (int64_t dimInt : dims) {
|
||||
int64_t inputRank = input.getType().cast<Torch::ValueTensorType>().getSizes().size();
|
||||
dimInt = toPositiveDim(dimInt, inputRank);
|
||||
if (!isValidDim(dimInt, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dimInt));
|
||||
// The input to the next invocation of aten.max.dim is the output of the
|
||||
|
|
|
@ -744,6 +744,29 @@ def GatherModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class GatherNegativeDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor, indices):
|
||||
return torch.gather(tensor, -1, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GatherNegativeDimModule())
|
||||
def GatherNegativeDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class GatherRandomIndexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -31,6 +31,25 @@ def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(4, 5, 6), torch.tensor([2]))
|
||||
|
||||
|
||||
class IndexSelectNegativeDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, 5, 6], torch.float32, True),
|
||||
([1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input, indices):
|
||||
return torch.index_select(input, -1, indices)
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexSelectNegativeDimModule())
|
||||
def IndexSelectNegativeDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6), torch.tensor([2]))
|
||||
|
||||
|
||||
class IndexSelectTwoIdxModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue