[TOSA] Fix Tensor.hacked_twin to support diff size indexes (#3547)

- Broadcasts index list tensors
- Adds torch.nn.Unfold test

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
pull/3576/head
Suraj Sudhir 2024-07-30 14:32:05 -07:00 committed by GitHub
parent 8bd1b9751f
commit d3efab984b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 146 additions and 6 deletions

View File

@ -3797,13 +3797,126 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
}
// Right now only support multiple indexes with same shape
// TODO for different shape multiple indexes, add broadcast_to for small
// shape
auto getRankExtendedShape =
[](SmallVector<int64_t> inputShape,
SmallVector<int64_t> maxRank1DimShape) -> SmallVector<int64_t> {
SmallVector<int64_t> rankExtendedShape(maxRank1DimShape);
auto inputRank = inputShape.size();
auto maxRank = maxRank1DimShape.size();
auto startIdx = maxRank - inputRank;
for (size_t i = startIdx; i < maxRank; i++) {
rankExtendedShape[i] = inputShape[i - startIdx];
}
return rankExtendedShape;
};
bool hasDiffShapedIndexes = false;
for (auto indexShapeOneDim : indexesShape) {
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: Only support multi indexes with same shape");
hasDiffShapedIndexes = true;
break;
}
}
if (hasDiffShapedIndexes) {
int64_t maxRank = 1;
for (auto idxRank : indexesRank) {
if (idxRank > maxRank)
maxRank = idxRank;
}
// Tensor shape of max rank, each dim being 1
SmallVector<int64_t> maxRank1DimShape;
for (int i = 0; i < maxRank; i++)
maxRank1DimShape.push_back(1);
// Tensor shape of max rank, each dim being the max dim.
SmallVector<int64_t> maxRankMaxDimShape(maxRank1DimShape);
auto updateMaxRankMaxDimShape =
[&](SmallVector<int64_t> broadcastedShape) -> LogicalResult {
for (size_t i = 0; i < maxRankMaxDimShape.size(); i++) {
// check for malformed index tensors
if (broadcastedShape[i] != 1 && maxRankMaxDimShape[i] != 1 &&
maxRankMaxDimShape[i] != broadcastedShape[i]) {
return failure();
}
if (broadcastedShape[i] > maxRankMaxDimShape[i])
maxRankMaxDimShape[i] = broadcastedShape[i];
}
return success();
};
for (size_t i = 0; i < indexesRank.size(); i++) {
// Reshape all index tensors to same maxRank
auto idxRank = indexesRank[i];
auto unreshapedIdxTensor = indicesTfConcatTensors[i];
SmallVector<int64_t> broadcastedShape =
getRankExtendedShape(indexesShape[i], maxRank1DimShape);
if (idxRank < maxRank) {
auto idxType =
dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
// indicesTfConcatTensors has a trailing [1] dim for the final concat.
auto broadcastedShapeTf(broadcastedShape);
broadcastedShapeTf.push_back(1);
auto reshapeOutputTy = RankedTensorType::get(
broadcastedShapeTf, idxType.getElementType());
// Update the tensor array with the max rank-extended form
indicesTfConcatTensors[i] = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), reshapeOutputTy, unreshapedIdxTensor,
rewriter.getDenseI64ArrayAttr(broadcastedShapeTf));
}
// Construct the max rank broadcasted form of all index tensors with
// each index tensor.
if (updateMaxRankMaxDimShape(broadcastedShape).failed()) {
return rewriter.notifyMatchFailure(
op, "Malformed index tensors that have mismatched dim shapes");
}
// Every index now has the same rank but not yet same shape until
// tosa.tile below.
indexesShape[i] = broadcastedShape;
indexesRank[i] = maxRank;
}
auto getTileOpShape = [&](SmallVector<int64_t> indexShape,
SmallVector<int64_t> &tileOpShape) -> bool {
bool needsTiling = false;
for (size_t i = 0; i < indexShape.size(); i++) {
if (1 == indexShape[i]) {
tileOpShape.push_back(maxRankMaxDimShape[i]);
needsTiling = true;
} else {
tileOpShape.push_back(1);
}
}
return needsTiling;
};
// Use tosa.tile to broadcast in multiple dims so all index tensors have
// the same shape. This materializes new tensors.
for (size_t i = 0; i < indexesRank.size(); i++) {
SmallVector<int64_t> tileOpShape;
bool needsTiling = getTileOpShape(indexesShape[i], tileOpShape);
if (needsTiling) {
auto idxType =
dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
// indicesTfConcatTensors has a trailing [1] dim for the final concat.
auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
maxRankMaxDimShapeTf.push_back(1);
auto tileOpShapeTf(tileOpShape);
tileOpShapeTf.push_back(1);
auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
idxType.getElementType());
auto reshapedIdxTensor = indicesTfConcatTensors[i];
indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
op->getLoc(), tileOutputTy, reshapedIdxTensor,
rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
}
// Every index tensor now has the same rank and shape
indexesShape[i] = maxRankMaxDimShape;
}
}

View File

@ -30,6 +30,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
# this is added to check the torch.onnx.export -> import_onnx -> torch path
"DeformConv2D_basic",
"ReduceAnyDimFloatModule_basic",
"UnfoldModule_basic",
}
LINALG_CRASHING_SET = {
@ -1983,6 +1984,8 @@ TOSA_PASS_SET = {
"TorchPrimLoopForLikeTensorArgModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
}
MAKE_FX_TOSA_PASS_SET = (
@ -2750,6 +2753,7 @@ ONNX_XFAIL_SET = {
"ReduceAnyFloatModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"UnfoldModule_basic",
}
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@ -3189,7 +3193,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"IndexSelectWholeTensorModule_basic",
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",

View File

@ -5646,3 +5646,27 @@ def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils):
module.forward(
torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
)
# ==============================================================================
class UnfoldModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.unfold = torch.nn.Unfold(kernel_size=(2, 3))
@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, input):
return self.unfold(input)
@register_test_case(module_factory=lambda: UnfoldModule())
def UnfoldModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 3, 4))