diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index e0a23df9a..c86244f5f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -413,7 +413,6 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [ ); let hasVerifier = 1; - let hasCanonicalizer = 1; let assemblyFormat = [{ $elements attr-dict `:` functional-type(operands, results) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index d271243f7..d346d9db4 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2099,32 +2099,6 @@ void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } -//===----------------------------------------------------------------------===// -// PrimListConstructOp -//===----------------------------------------------------------------------===// - -void PrimListConstructOp::getCanonicalizationPatterns( - RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(+[](PrimListConstructOp op, PatternRewriter &rewriter) { - if (isListPotentiallyMutated(op.getResult())) - return failure(); - SmallVector elements = llvm::to_vector<4>(op.getElements()); - if (elements.size() == 0) - return failure(); - - auto listUnpackOp = elements[0].getDefiningOp(); - if (!listUnpackOp) - return failure(); - if (listUnpackOp.getResults() != elements) - return failure(); - if (isListPotentiallyMutated(listUnpackOp.getOperand())) - return failure(); - - rewriter.replaceOp(op, listUnpackOp.getOperand()); - return success(); - }); -} - //===----------------------------------------------------------------------===// // PrimListUnpackOp //===----------------------------------------------------------------------===// diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index d98de4631..a056aa2f0 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -965,15 +965,6 @@ func.func @torch.prim.If$fold_same_result$subset_of_results(%arg0: !torch.bool, return %0, %1: !torch.int, !torch.int } -// CHECK-LABEL: func.func @prim.ListConstruct$fold_list( -// CHECK-SAME: %[[ARG0:.*]]: !torch.list) -> !torch.list { -// CHECK: return %[[ARG0]] : !torch.list -func.func @prim.ListConstruct$fold_list(%arg0: !torch.list) -> !torch.list { - %0:2 = torch.prim.ListUnpack %arg0 : !torch.list -> !torch.tensor, !torch.tensor - %1 = torch.prim.ListConstruct %0#0, %0#1 : (!torch.tensor, !torch.tensor) -> !torch.list - return %1 : !torch.list -} - // CHECK-LABEL: func.func @torch.prim.TupleUnpack( // CHECK-SAME: %[[ARG0:.*]]: !torch.tensor, // CHECK-SAME: %[[ARG1:.*]]: !torch.tensor) -> !torch.tensor {