Remove folder from `AtenStackOp` for single element list inputs (#2626)

`AtenStackOp` defines this folder for list operand containing single
element:
```
OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
  auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
  if (!list || !list->hasOneUse() || list.getElements().size() != 1)
    return nullptr;
  return list.getElements()[0];
}
```
However, unlike `AtenCatOp`, `AtenStackOp` cannot be folded away for
single element list operand because the result from a stack operation
contains an additional dimension (of size 1, like expand_shape).

This PR removes the `AtenStackOp::fold` method, and adds an e2e test for
single element list input case, which fails on current `main` as
follows:
```
Unexpected outcome summary: (linalg)                                                                                                                                                                   
                                                                                                                                                                                                       
****** Failed tests - 1 tests                                                                                                                                                                          
    FAIL - "TensorsStackSingleElementListModule_basic"                                                                                                                                                 
        @ trace item #0 - call to "forward"                                                                                                                                                            
        @ output of call to "forward"                                                                                                                                                                  
        ERROR: shape (torch.Size([10, 32])) is not equal to golden shape (torch.Size([10, 1, 32]))     
```
Thanks Chris Lalau Keraly for the bug report.
pull/2628/merge
Sambhav Jain 2023-12-11 10:52:50 -08:00 committed by GitHub
parent 0b4422a253
commit 7acabafd84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 13 deletions

View File

@ -11632,7 +11632,6 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [

View File

@ -2444,17 +2444,6 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
return list.getElements()[0];
}
//===----------------------------------------------------------------------===//
// AtenStackOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr;
return list.getElements()[0];
}
//===----------------------------------------------------------------------===//
// AtenBroadcastToOp
//===----------------------------------------------------------------------===//

View File

@ -691,7 +691,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# List ops.
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::append.t : (t[], t) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)

View File

@ -884,6 +884,28 @@ def TensorsStackModule_basic(module, tu: TestUtils):
# ==============================================================================
class TensorsStackSingleElementListModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.stack([x], dim=1)
@register_test_case(module_factory=lambda: TensorsStackSingleElementListModule())
def TensorsStackSingleElementListModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 32))
# ==============================================================================
class TensorsStackNegativeDimModule(torch.nn.Module):
def __init__(self):