Dynamic size support for flatten (#3005)

Added support for dynamic shapes in `flattenusingints` op in tosa
dialect. Due to this some Argmax tests pass
This PR fixes this issue https://github.com/llvm/torch-mlir/issues/3004

The following tests pass after this PR
 ```
1. "ArgmaxIntModule_basic"
2. "ArgmaxIntModule_multiple_maxs"
3. "ArgmaxModule_basic"
```
pull/3042/head
Abhishek-TyRnT 2024-03-20 03:49:29 +05:30 committed by GitHub
parent 7a9608bb69
commit df02692726
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 5 deletions

View File

@ -2485,10 +2485,9 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
// Not a ranked tensor type
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op,
"Only ranked tensor types with static shapes are currently supported");
if (!selfType)
return rewriter.notifyMatchFailure(op,
"Only ranked tensor types supported");
int64_t selfRank = selfType.getRank();
@ -2520,8 +2519,11 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
} else {
if (idx == start_dim)
newShape.push_back(s.value());
else
// Only updating when the shapes are static
else if (s.value() != kUnknownSize && newShape.back() != kUnknownSize)
newShape.back() *= s.value();
else
newShape.back() = kUnknownSize;
}
}

View File

@ -885,6 +885,9 @@ TOSA_PASS_SET = {
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartOutDtypeModule_basic",
"ArangeStartStepFloatModule_basic",
"ArgmaxIntModule_basic",
"ArgmaxIntModule_multiple_maxs",
"ArgmaxModule_basic",
"ArgmaxModule_keepDim",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic",
@ -1077,6 +1080,7 @@ TOSA_PASS_SET = {
"EmbeddingModuleI32Static_basic",
"FlattenRank0Module_basic",
"FlattenStaticModule_basic",
"FlattenDynamicModuleCollapseAll_basic",
"FullLikeModuleFloat3DStatic_basic",
"FullLikeModuleInt2DStatic_basic",
"FullModuleDefaultDtype_basic",
@ -1292,6 +1296,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
}) - {
### Test failing in make_fx_tosa but not in tosa
"FlattenDynamicModuleCollapseAll_basic",
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
"MatmulStaticBroadcast_basic",

View File

@ -391,6 +391,25 @@ class FlattenDynamicModule(torch.nn.Module):
def FlattenDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
class FlattenDynamicModuleCollapseAll(torch.nn.Module):
def __init__(self):
super().__init__()
self.flat = torch.nn.Flatten(0)
@export
@annotate_args([
None,
([-1, -1, -1, 9, 3, -1], torch.float32, True),
])
def forward(self, x):
return self.flat(x)
@register_test_case(module_factory=lambda: FlattenDynamicModuleCollapseAll())
def FlattenDynamicModuleCollapseAll_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
# ==============================================================================