mirror of https://github.com/llvm/torch-mlir
Remove folder from `AtenStackOp` for single element list inputs (#2626)
`AtenStackOp` defines this folder for list operand containing single element: ``` OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { auto list = getOperand(0).getDefiningOp<PrimListConstructOp>(); if (!list || !list->hasOneUse() || list.getElements().size() != 1) return nullptr; return list.getElements()[0]; } ``` However, unlike `AtenCatOp`, `AtenStackOp` cannot be folded away for single element list operand because the result from a stack operation contains an additional dimension (of size 1, like expand_shape). This PR removes the `AtenStackOp::fold` method, and adds an e2e test for single element list input case, which fails on current `main` as follows: ``` Unexpected outcome summary: (linalg) ****** Failed tests - 1 tests FAIL - "TensorsStackSingleElementListModule_basic" @ trace item #0 - call to "forward" @ output of call to "forward" ERROR: shape (torch.Size([10, 32])) is not equal to golden shape (torch.Size([10, 1, 32])) ``` Thanks Chris Lalau Keraly for the bug report.pull/2628/merge
parent
0b4422a253
commit
7acabafd84
|
@ -11632,7 +11632,6 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
|
||||
|
|
|
@ -2444,17 +2444,6 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
|
|||
return list.getElements()[0];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenStackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
|
||||
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
|
||||
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
|
||||
return nullptr;
|
||||
return list.getElements()[0];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenBroadcastToOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -691,7 +691,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
|
||||
# List ops.
|
||||
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||
emit("aten::append.t : (t[], t) -> (t[])")
|
||||
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
||||
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
|
||||
|
|
|
@ -884,6 +884,28 @@ def TensorsStackModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorsStackSingleElementListModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.stack([x], dim=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorsStackSingleElementListModule())
|
||||
def TensorsStackSingleElementListModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorsStackNegativeDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue