mirror of https://github.com/llvm/torch-mlir
[TOSA] Fix Tensor.hacked_twin to support diff size indexes (#3547)
- Broadcasts index list tensors
- Adds torch.nn.Unfold test

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>

pull/3576/head
parent 8bd1b9751f
commit d3efab984b
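For context, the pattern this change enables is PyTorch advanced indexing in which the index tensors in the list have different shapes and must be broadcast against one another before the gather. A minimal eager-mode sketch of that pattern (the shapes and values are illustrative only, not taken from this patch):

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Index tensors of different shapes: [2, 1] and [3] broadcast to [2, 3].
rows = torch.tensor([[0], [2]])   # shape [2, 1]
cols = torch.tensor([1, 2, 3])    # shape [3]

# result[i, j] == x[rows[i, 0], cols[j]]; the result has shape [2, 3].
result = x[rows, cols]
print(result.shape)  # torch.Size([2, 3])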
@@ -3797,13 +3797,126 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
     indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
   }
 
-  // Right now only support multiple indexes with same shape
-  // TODO for different shape multiple indexes, add broadcast_to for small
-  // shape
+  auto getRankExtendedShape =
+      [](SmallVector<int64_t> inputShape,
+         SmallVector<int64_t> maxRank1DimShape) -> SmallVector<int64_t> {
+    SmallVector<int64_t> rankExtendedShape(maxRank1DimShape);
+    auto inputRank = inputShape.size();
+    auto maxRank = maxRank1DimShape.size();
+    auto startIdx = maxRank - inputRank;
+    for (size_t i = startIdx; i < maxRank; i++) {
+      rankExtendedShape[i] = inputShape[i - startIdx];
+    }
+    return rankExtendedShape;
+  };
+
+  bool hasDiffShapedIndexes = false;
   for (auto indexShapeOneDim : indexesShape) {
     if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: Only support multi indexes with same shape");
+      hasDiffShapedIndexes = true;
+      break;
     }
   }
+
+  if (hasDiffShapedIndexes) {
+    int64_t maxRank = 1;
+    for (auto idxRank : indexesRank) {
+      if (idxRank > maxRank)
+        maxRank = idxRank;
+    }
+    // Tensor shape of max rank, each dim being 1
+    SmallVector<int64_t> maxRank1DimShape;
+    for (int i = 0; i < maxRank; i++)
+      maxRank1DimShape.push_back(1);
+    // Tensor shape of max rank, each dim being the max dim.
+    SmallVector<int64_t> maxRankMaxDimShape(maxRank1DimShape);
+
+    auto updateMaxRankMaxDimShape =
+        [&](SmallVector<int64_t> broadcastedShape) -> LogicalResult {
+      for (size_t i = 0; i < maxRankMaxDimShape.size(); i++) {
+        // check for malformed index tensors
+        if (broadcastedShape[i] != 1 && maxRankMaxDimShape[i] != 1 &&
+            maxRankMaxDimShape[i] != broadcastedShape[i]) {
+          return failure();
+        }
+        if (broadcastedShape[i] > maxRankMaxDimShape[i])
+          maxRankMaxDimShape[i] = broadcastedShape[i];
+      }
+      return success();
+    };
+
+    for (size_t i = 0; i < indexesRank.size(); i++) {
+      // Reshape all index tensors to same maxRank
+      auto idxRank = indexesRank[i];
+      auto unreshapedIdxTensor = indicesTfConcatTensors[i];
+      SmallVector<int64_t> broadcastedShape =
+          getRankExtendedShape(indexesShape[i], maxRank1DimShape);
+
+      if (idxRank < maxRank) {
+        auto idxType =
+            dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+        // indicesTfConcatTensors has a trailing [1] dim for the final concat.
+        auto broadcastedShapeTf(broadcastedShape);
+        broadcastedShapeTf.push_back(1);
+        auto reshapeOutputTy = RankedTensorType::get(
+            broadcastedShapeTf, idxType.getElementType());
+        // Update the tensor array with the max rank-extended form
+        indicesTfConcatTensors[i] = rewriter.create<tosa::ReshapeOp>(
+            op->getLoc(), reshapeOutputTy, unreshapedIdxTensor,
+            rewriter.getDenseI64ArrayAttr(broadcastedShapeTf));
+      }
+
+      // Construct the max rank broadcasted form of all index tensors with
+      // each index tensor.
+      if (updateMaxRankMaxDimShape(broadcastedShape).failed()) {
+        return rewriter.notifyMatchFailure(
+            op, "Malformed index tensors that have mismatched dim shapes");
+      }
+
+      // Every index now has the same rank but not yet same shape until
+      // tosa.tile below.
+      indexesShape[i] = broadcastedShape;
+      indexesRank[i] = maxRank;
+    }
+
+    auto getTileOpShape = [&](SmallVector<int64_t> indexShape,
+                              SmallVector<int64_t> &tileOpShape) -> bool {
+      bool needsTiling = false;
+      for (size_t i = 0; i < indexShape.size(); i++) {
+        if (1 == indexShape[i]) {
+          tileOpShape.push_back(maxRankMaxDimShape[i]);
+          needsTiling = true;
+        } else {
+          tileOpShape.push_back(1);
+        }
+      }
+      return needsTiling;
+    };
+
+    // Use tosa.tile to broadcast in multiple dims so all index tensors have
+    // the same shape. This materializes new tensors.
+    for (size_t i = 0; i < indexesRank.size(); i++) {
+      SmallVector<int64_t> tileOpShape;
+      bool needsTiling = getTileOpShape(indexesShape[i], tileOpShape);
+
+      if (needsTiling) {
+        auto idxType =
+            dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+        // indicesTfConcatTensors has a trailing [1] dim for the final concat.
+        auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
+        maxRankMaxDimShapeTf.push_back(1);
+        auto tileOpShapeTf(tileOpShape);
+        tileOpShapeTf.push_back(1);
+        auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
+                                                  idxType.getElementType());
+        auto reshapedIdxTensor = indicesTfConcatTensors[i];
+        indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
+            op->getLoc(), tileOutputTy, reshapedIdxTensor,
+            rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
+      }
+
+      // Every index tensor now has the same rank and shape
+      indexesShape[i] = maxRankMaxDimShape;
+    }
+  }
 
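The shape bookkeeping in the hunk above is easier to follow outside the rewriter plumbing. Below is a rough Python sketch of the same computation, mirroring what getRankExtendedShape, updateMaxRankMaxDimShape, and getTileOpShape do; the helper names are hypothetical and this is a sketch, not the production code path.

def rank_extend(shape, max_rank):
    # Left-pad with 1s up to max_rank, e.g. [3] -> [1, 3] when max_rank == 2.
    return [1] * (max_rank - len(shape)) + list(shape)


def broadcast_index_shapes(index_shapes):
    max_rank = max(len(s) for s in index_shapes)
    extended = [rank_extend(s, max_rank) for s in index_shapes]

    # Accumulate the per-dim maximum; mismatched non-1 dims are malformed.
    target = [1] * max_rank
    for shape in extended:
        for i, d in enumerate(shape):
            if d != 1 and target[i] != 1 and target[i] != d:
                raise ValueError("mismatched index tensor dims")
            target[i] = max(target[i], d)

    # Per-tensor tile multiples: repeat size-1 dims up to the target shape.
    tile_multiples = [
        [t if d == 1 else 1 for d, t in zip(shape, target)] for shape in extended
    ]
    return target, extended, tile_multiples


# Index tensors of shapes [2, 1] and [3] broadcast to a common [2, 3].
print(broadcast_index_shapes([[2, 1], [3]]))
# ([2, 3], [[2, 1], [1, 3]], [[1, 3], [2, 1]])

In the lowering, the rank-extended shapes are realized with tosa.reshape (plus the trailing [1] concat dim) and the tile multiples feed tosa.tile.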
@@ -30,6 +30,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     # this is added to check the torch.onnx.export -> import_onnx -> torch path
     "DeformConv2D_basic",
     "ReduceAnyDimFloatModule_basic",
+    "UnfoldModule_basic",
 }
 
 LINALG_CRASHING_SET = {
@@ -1983,6 +1984,8 @@ TOSA_PASS_SET = {
     "TorchPrimLoopForLikeTensorArgModule_basic",
     "RenormModuleFloat32NegativeDim_basic",
    "RenormModuleFloat32_basic",
+    "IndexTensorStaticContiguousWithNoneModule_basic",
+    "IndexTensorStaticNonContiguousWithNoneModule_basic",
 }
 
 MAKE_FX_TOSA_PASS_SET = (
@@ -2750,6 +2753,7 @@ ONNX_XFAIL_SET = {
     "ReduceAnyFloatModule_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
+    "UnfoldModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3189,7 +3193,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "IndexSelectWholeTensorModule_basic",
     "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
     "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
-    "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
     "IndexTensorMultiInputContiguousCenter_basic",
     "IndexTensorMultiInputContiguousOneDimDynamic_basic",
     "IndexTensorMultiInputNonContiguousDynamic_basic",
@@ -5646,3 +5646,27 @@ def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils):
     module.forward(
         torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3)
     )
+
+
+# ==============================================================================
+
+
+class UnfoldModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.unfold = torch.nn.Unfold(kernel_size=(2, 3))
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, input):
+        return self.unfold(input)
+
+
+@register_test_case(module_factory=lambda: UnfoldModule())
+def UnfoldModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 3, 4))
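For reference on what the new test exercises: torch.nn.Unfold extracts sliding local blocks from a 4-D input, and its lowering is what reaches the hacked_twin indexing path here (implied by the test being added alongside this fix). A quick shape check using the same configuration as UnfoldModule above; the expected sizes follow from the standard Unfold output-shape formula:

import torch

unfold = torch.nn.Unfold(kernel_size=(2, 3))
x = torch.rand(2, 5, 3, 4)  # same input shape as UnfoldModule_basic

out = unfold(x)
# Each block holds 5 * 2 * 3 = 30 values; there are (3-2+1) * (4-3+1) = 4 positions.
print(out.shape)  # torch.Size([2, 30, 4])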