mirror of https://github.com/llvm/torch-mlir
revert canonicalizer for PrimListConstructOp (#2408)
parent
3dd29f9d5d
commit
4c9d234b01
|
@ -413,7 +413,6 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasCanonicalizer = 1;
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$elements attr-dict `:` functional-type(operands, results)
|
$elements attr-dict `:` functional-type(operands, results)
|
||||||
|
|
|
@ -2099,32 +2099,6 @@ void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// PrimListConstructOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
void PrimListConstructOp::getCanonicalizationPatterns(
|
|
||||||
RewritePatternSet &patterns, MLIRContext *context) {
|
|
||||||
patterns.add(+[](PrimListConstructOp op, PatternRewriter &rewriter) {
|
|
||||||
if (isListPotentiallyMutated(op.getResult()))
|
|
||||||
return failure();
|
|
||||||
SmallVector<Value> elements = llvm::to_vector<4>(op.getElements());
|
|
||||||
if (elements.size() == 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto listUnpackOp = elements[0].getDefiningOp<PrimListUnpackOp>();
|
|
||||||
if (!listUnpackOp)
|
|
||||||
return failure();
|
|
||||||
if (listUnpackOp.getResults() != elements)
|
|
||||||
return failure();
|
|
||||||
if (isListPotentiallyMutated(listUnpackOp.getOperand()))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, listUnpackOp.getOperand());
|
|
||||||
return success();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// PrimListUnpackOp
|
// PrimListUnpackOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -965,15 +965,6 @@ func.func @torch.prim.If$fold_same_result$subset_of_results(%arg0: !torch.bool,
|
||||||
return %0, %1: !torch.int, !torch.int
|
return %0, %1: !torch.int, !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @prim.ListConstruct$fold_list(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.list<tensor>) -> !torch.list<tensor> {
|
|
||||||
// CHECK: return %[[ARG0]] : !torch.list<tensor>
|
|
||||||
func.func @prim.ListConstruct$fold_list(%arg0: !torch.list<tensor>) -> !torch.list<tensor> {
|
|
||||||
%0:2 = torch.prim.ListUnpack %arg0 : !torch.list<tensor> -> !torch.tensor, !torch.tensor
|
|
||||||
%1 = torch.prim.ListConstruct %0#0, %0#1 : (!torch.tensor, !torch.tensor) -> !torch.list<tensor>
|
|
||||||
return %1 : !torch.list<tensor>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.prim.TupleUnpack(
|
// CHECK-LABEL: func.func @torch.prim.TupleUnpack(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor,
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor,
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor) -> !torch.tensor {
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor) -> !torch.tensor {
|
||||||
|
|
Loading…
Reference in New Issue