mirror of https://github.com/llvm/torch-mlir
Remove folder from `AtenStackOp` for single element list inputs (#2626)
`AtenStackOp` defines this folder for list operand containing single element: ``` OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { auto list = getOperand(0).getDefiningOp<PrimListConstructOp>(); if (!list || !list->hasOneUse() || list.getElements().size() != 1) return nullptr; return list.getElements()[0]; } ``` However, unlike `AtenCatOp`, `AtenStackOp` cannot be folded away for single element list operand because the result from a stack operation contains an additional dimension (of size 1, like expand_shape). This PR removes the `AtenStackOp::fold` method, and adds an e2e test for single element list input case, which fails on current `main` as follows: ``` Unexpected outcome summary: (linalg) ****** Failed tests - 1 tests FAIL - "TensorsStackSingleElementListModule_basic" @ trace item #0 - call to "forward" @ output of call to "forward" ERROR: shape (torch.Size([10, 32])) is not equal to golden shape (torch.Size([10, 1, 32])) ``` Thanks Chris Lalau Keraly for the bug report.pull/2628/merge
parent
0b4422a253
commit
7acabafd84
|
@ -11632,7 +11632,6 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let hasFolder = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
|
def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
|
||||||
|
|
|
@ -2444,17 +2444,6 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
|
||||||
return list.getElements()[0];
|
return list.getElements()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// AtenStackOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
|
|
||||||
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
|
|
||||||
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
|
|
||||||
return nullptr;
|
|
||||||
return list.getElements()[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenBroadcastToOp
|
// AtenBroadcastToOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -691,7 +691,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
|
|
||||||
# List ops.
|
# List ops.
|
||||||
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
|
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
|
||||||
emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True)
|
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||||
emit("aten::append.t : (t[], t) -> (t[])")
|
emit("aten::append.t : (t[], t) -> (t[])")
|
||||||
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
|
||||||
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
|
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
|
||||||
|
|
|
@ -884,6 +884,28 @@ def TensorsStackModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TensorsStackSingleElementListModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.stack([x], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorsStackSingleElementListModule())
|
||||||
|
def TensorsStackSingleElementListModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(10, 32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TensorsStackNegativeDimModule(torch.nn.Module):
|
class TensorsStackNegativeDimModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue