mirror of https://github.com/llvm/torch-mlir
Dynamic size support for flatten (#3005)
Added support for dynamic shapes in `flattenusingints` op in tosa dialect. Due to this some Argmax tests pass This PR fixes this issue https://github.com/llvm/torch-mlir/issues/3004 The following tests pass after this PR ``` 1. "ArgmaxIntModule_basic" 2. "ArgmaxIntModule_multiple_maxs" 3. "ArgmaxModule_basic" ```pull/3042/head
parent
7a9608bb69
commit
df02692726
|
@ -2485,10 +2485,9 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
|
||||||
|
|
||||||
// Not a ranked tensor type
|
// Not a ranked tensor type
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!selfType || !selfType.hasStaticShape())
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(op,
|
||||||
op,
|
"Only ranked tensor types supported");
|
||||||
"Only ranked tensor types with static shapes are currently supported");
|
|
||||||
|
|
||||||
int64_t selfRank = selfType.getRank();
|
int64_t selfRank = selfType.getRank();
|
||||||
|
|
||||||
|
@ -2520,8 +2519,11 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
|
||||||
} else {
|
} else {
|
||||||
if (idx == start_dim)
|
if (idx == start_dim)
|
||||||
newShape.push_back(s.value());
|
newShape.push_back(s.value());
|
||||||
else
|
// Only updating when the shapes are static
|
||||||
|
else if (s.value() != kUnknownSize && newShape.back() != kUnknownSize)
|
||||||
newShape.back() *= s.value();
|
newShape.back() *= s.value();
|
||||||
|
else
|
||||||
|
newShape.back() = kUnknownSize;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -885,6 +885,9 @@ TOSA_PASS_SET = {
|
||||||
"ArangeStartNegativeStepFloatModule_basic",
|
"ArangeStartNegativeStepFloatModule_basic",
|
||||||
"ArangeStartOutDtypeModule_basic",
|
"ArangeStartOutDtypeModule_basic",
|
||||||
"ArangeStartStepFloatModule_basic",
|
"ArangeStartStepFloatModule_basic",
|
||||||
|
"ArgmaxIntModule_basic",
|
||||||
|
"ArgmaxIntModule_multiple_maxs",
|
||||||
|
"ArgmaxModule_basic",
|
||||||
"ArgmaxModule_keepDim",
|
"ArgmaxModule_keepDim",
|
||||||
"ArgmaxModule_with_dim",
|
"ArgmaxModule_with_dim",
|
||||||
"AtenComplex64Module_basic",
|
"AtenComplex64Module_basic",
|
||||||
|
@ -1077,6 +1080,7 @@ TOSA_PASS_SET = {
|
||||||
"EmbeddingModuleI32Static_basic",
|
"EmbeddingModuleI32Static_basic",
|
||||||
"FlattenRank0Module_basic",
|
"FlattenRank0Module_basic",
|
||||||
"FlattenStaticModule_basic",
|
"FlattenStaticModule_basic",
|
||||||
|
"FlattenDynamicModuleCollapseAll_basic",
|
||||||
"FullLikeModuleFloat3DStatic_basic",
|
"FullLikeModuleFloat3DStatic_basic",
|
||||||
"FullLikeModuleInt2DStatic_basic",
|
"FullLikeModuleInt2DStatic_basic",
|
||||||
"FullModuleDefaultDtype_basic",
|
"FullModuleDefaultDtype_basic",
|
||||||
|
@ -1292,6 +1296,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||||
}) - {
|
}) - {
|
||||||
### Test failing in make_fx_tosa but not in tosa
|
### Test failing in make_fx_tosa but not in tosa
|
||||||
|
|
||||||
|
"FlattenDynamicModuleCollapseAll_basic",
|
||||||
# Dynamic shape, has extra unsupported broadcast ops
|
# Dynamic shape, has extra unsupported broadcast ops
|
||||||
"Matmul_3d",
|
"Matmul_3d",
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
|
|
|
@ -391,6 +391,25 @@ class FlattenDynamicModule(torch.nn.Module):
|
||||||
def FlattenDynamicModule_basic(module, tu: TestUtils):
|
def FlattenDynamicModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
|
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
|
||||||
|
|
||||||
|
class FlattenDynamicModuleCollapseAll(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.flat = torch.nn.Flatten(0)
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, 9, 3, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return self.flat(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: FlattenDynamicModuleCollapseAll())
|
||||||
|
def FlattenDynamicModuleCollapseAll_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue