mirror of https://github.com/llvm/torch-mlir
[TOSA] Fix Tensor.hacked_twin to support diff size indexes (#3547)
- Broadcasts index list tensors
- Adds torch.nn.Unfold test

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
pull/3576/head
parent
8bd1b9751f
commit
d3efab984b
|
@ -3797,13 +3797,126 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
|
||||
}
|
||||
|
||||
// Right now only support multiple indexes with same shape
|
||||
// TODO for different shape multiple indexes, add broadcast_to for small
|
||||
// shape
|
||||
auto getRankExtendedShape =
|
||||
[](SmallVector<int64_t> inputShape,
|
||||
SmallVector<int64_t> maxRank1DimShape) -> SmallVector<int64_t> {
|
||||
SmallVector<int64_t> rankExtendedShape(maxRank1DimShape);
|
||||
auto inputRank = inputShape.size();
|
||||
auto maxRank = maxRank1DimShape.size();
|
||||
auto startIdx = maxRank - inputRank;
|
||||
for (size_t i = startIdx; i < maxRank; i++) {
|
||||
rankExtendedShape[i] = inputShape[i - startIdx];
|
||||
}
|
||||
return rankExtendedShape;
|
||||
};
|
||||
|
||||
bool hasDiffShapedIndexes = false;
|
||||
for (auto indexShapeOneDim : indexesShape) {
|
||||
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: Only support multi indexes with same shape");
|
||||
hasDiffShapedIndexes = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasDiffShapedIndexes) {
|
||||
int64_t maxRank = 1;
|
||||
for (auto idxRank : indexesRank) {
|
||||
if (idxRank > maxRank)
|
||||
maxRank = idxRank;
|
||||
}
|
||||
// Tensor shape of max rank, each dim being 1
|
||||
SmallVector<int64_t> maxRank1DimShape;
|
||||
for (int i = 0; i < maxRank; i++)
|
||||
maxRank1DimShape.push_back(1);
|
||||
// Tensor shape of max rank, each dim being the max dim.
|
||||
SmallVector<int64_t> maxRankMaxDimShape(maxRank1DimShape);
|
||||
|
||||
// Folds one rank-extended index shape into `maxRankMaxDimShape`, keeping the
// larger size per dim. Returns failure() as soon as a dim is found where both
// sizes differ and neither of them is 1; dims visited before that point stay
// updated.
auto updateMaxRankMaxDimShape =
    [&](SmallVector<int64_t> broadcastedShape) -> LogicalResult {
  const size_t numDims = maxRankMaxDimShape.size();
  for (size_t dim = 0; dim != numDims; ++dim) {
    const int64_t incoming = broadcastedShape[dim];
    int64_t &current = maxRankMaxDimShape[dim];

    // Malformed index tensors: two different non-unit sizes in the same dim.
    const bool bothNonUnit = incoming != 1 && current != 1;
    if (bothNonUnit && current != incoming)
      return failure();

    if (incoming > current)
      current = incoming;
  }
  return success();
};
|
||||
|
||||
// Bring every index tensor to rank `maxRank` and fold its rank-extended shape
// into `maxRankMaxDimShape`.
for (size_t i = 0; i < indexesRank.size(); i++) {
  auto idxRank = indexesRank[i];
  auto unreshapedIdxTensor = indicesTfConcatTensors[i];
  SmallVector<int64_t> broadcastedShape =
      getRankExtendedShape(indexesShape[i], maxRank1DimShape);

  // Validate the shape before emitting any op for this index tensor, so that
  // a rejected index does not leave a freshly created reshape behind.
  if (updateMaxRankMaxDimShape(broadcastedShape).failed()) {
    return rewriter.notifyMatchFailure(
        op, "Malformed index tensors that have mismatched dim shapes");
  }

  if (idxRank < maxRank) {
    // cast<> instead of dyn_cast<>: the result is dereferenced right away, so
    // a failed cast must assert rather than yield a null type.
    auto idxType = cast<RankedTensorType>(unreshapedIdxTensor.getType());
    // indicesTfConcatTensors has a trailing [1] dim for the final concat.
    auto broadcastedShapeTf(broadcastedShape);
    broadcastedShapeTf.push_back(1);
    auto reshapeOutputTy = RankedTensorType::get(
        broadcastedShapeTf, idxType.getElementType());
    // Update the tensor array with the max rank-extended form.
    indicesTfConcatTensors[i] = rewriter.create<tosa::ReshapeOp>(
        op->getLoc(), reshapeOutputTy, unreshapedIdxTensor,
        rewriter.getDenseI64ArrayAttr(broadcastedShapeTf));
  }

  // Every index now has the same rank but not yet the same shape until
  // tosa.tile is applied.
  indexesShape[i] = broadcastedShape;
  indexesRank[i] = maxRank;
}
|
||||
|
||||
auto getTileOpShape = [&](SmallVector<int64_t> indexShape,
|
||||
SmallVector<int64_t> &tileOpShape) -> bool {
|
||||
bool needsTiling = false;
|
||||
for (size_t i = 0; i < indexShape.size(); i++) {
|
||||
if (1 == indexShape[i]) {
|
||||
tileOpShape.push_back(maxRankMaxDimShape[i]);
|
||||
needsTiling = true;
|
||||
} else {
|
||||
tileOpShape.push_back(1);
|
||||
}
|
||||
}
|
||||
return needsTiling;
|
||||
};
|
||||
|
||||
// Use tosa.tile to broadcast in multiple dims so all index tensors have
|
||||
// the same shape. This materializes new tensors.
|
||||
// Use tosa.tile to broadcast in multiple dims so all index tensors have the
// same shape. This materializes new tensors.
for (size_t i = 0; i < indexesRank.size(); i++) {
  SmallVector<int64_t> tileOpShape;
  bool needsTiling = getTileOpShape(indexesShape[i], tileOpShape);

  if (needsTiling) {
    auto reshapedIdxTensor = indicesTfConcatTensors[i];
    // cast<> instead of dyn_cast<>: the result is dereferenced right away, so
    // a failed cast must assert rather than yield a null type.
    auto idxType = cast<RankedTensorType>(reshapedIdxTensor.getType());
    // indicesTfConcatTensors has a trailing [1] dim for the final concat.
    auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
    maxRankMaxDimShapeTf.push_back(1);
    auto tileOpShapeTf(tileOpShape);
    tileOpShapeTf.push_back(1);
    auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
                                              idxType.getElementType());
    indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
        op->getLoc(), tileOutputTy, reshapedIdxTensor,
        rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
  }

  // Every index tensor now has the same rank and shape.
  indexesShape[i] = maxRankMaxDimShape;
}
|
||||
}
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
|||
# this is added to check the torch.onnx.export -> import_onnx -> torch path
|
||||
"DeformConv2D_basic",
|
||||
"ReduceAnyDimFloatModule_basic",
|
||||
"UnfoldModule_basic",
|
||||
}
|
||||
|
||||
LINALG_CRASHING_SET = {
|
||||
|
@ -1983,6 +1984,8 @@ TOSA_PASS_SET = {
|
|||
"TorchPrimLoopForLikeTensorArgModule_basic",
|
||||
"RenormModuleFloat32NegativeDim_basic",
|
||||
"RenormModuleFloat32_basic",
|
||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_PASS_SET = (
|
||||
|
@ -2750,6 +2753,7 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceAnyFloatModule_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"UnfoldModule_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
||||
|
@ -3189,7 +3193,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
||||
"IndexTensorMultiInputContiguousCenter_basic",
|
||||
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
||||
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
||||
|
|
|
@ -5646,3 +5646,27 @@ def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils):
|
|||
module.forward(
|
||||
torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UnfoldModule(torch.nn.Module):
    """Test module that runs its input through ``torch.nn.Unfold`` with a
    (2, 3) kernel."""

    def __init__(self):
        super().__init__()
        kernel_size = (2, 3)
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size)

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, input):
        # The argument is annotated as a rank-4 float32 tensor, each dim
        # given as -1.
        unfolded = self.unfold(input)
        return unfolded
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UnfoldModule())
def UnfoldModule_basic(module, tu: TestUtils):
    """Call ``module.forward`` on the tensor returned by
    ``tu.rand(2, 5, 3, 4)``."""
    sample = tu.rand(2, 5, 3, 4)
    module.forward(sample)
|
||||
|
|
Loading…
Reference in New Issue