diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 7ccd0449c..c8089f6f3 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11632,7 +11632,6 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasFolder = 1; } def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 111f1a7b6..32a550ce8 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2444,17 +2444,6 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { return list.getElements()[0]; } -//===----------------------------------------------------------------------===// -// AtenStackOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { - auto list = getOperand(0).getDefiningOp(); - if (!list || !list->hasOneUse() || list.getElements().size() != 1) - return nullptr; - return list.getElements()[0]; -} - //===----------------------------------------------------------------------===// // AtenBroadcastToOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index f50eb461b..6da354120 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -691,7 +691,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # List ops. emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True) - emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True) + emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index a9d9e2e9c..971aa1efc 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -884,6 +884,28 @@ def TensorsStackModule_basic(module, tu: TestUtils): # ============================================================================== +class TensorsStackSingleElementListModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.stack([x], dim=1) + + +@register_test_case(module_factory=lambda: TensorsStackSingleElementListModule()) +def TensorsStackSingleElementListModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 32)) + + +# ============================================================================== + + class TensorsStackNegativeDimModule(torch.nn.Module): def __init__(self):