mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add canonicalization pattern for prim.ListUnpack op
This commit adds the canonicalization pattern for the `prim.ListUnpack` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1069/head
parent
247dd64a66
commit
4c25878e64
|
@ -310,8 +310,10 @@ def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
|
|||
// See `torch/csrc/jit/runtime/instruction.h`.
|
||||
|
||||
|
||||
def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
|
||||
[AllowsTypeRefinement]> {
|
||||
def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "TorchScript prim::ListUnpack op";
|
||||
let arguments = (ins AnyTorchType:$operand);
|
||||
let results = (outs Variadic<AnyTorchType>:$results);
|
||||
|
@ -319,6 +321,7 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
|
|||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($results))
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
||||
|
|
|
@ -1574,6 +1574,27 @@ void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimListUnpackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](PrimListUnpackOp op, PatternRewriter &rewriter) {
|
||||
auto torchList = op.operand();
|
||||
if (isListPotentiallyMutated(torchList)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
|
||||
if (!listConstruct)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(op, listConstruct.elements());
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
|
||||
if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) {
|
||||
return isa<Aten__Getitem__DictStrOp, Aten__Contains__StrOp,
|
||||
|
|
|
@ -1357,3 +1357,13 @@ func.func @torch.aten.size.int$copy(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.
|
|||
%size = torch.aten.size.int %value_tensor, %zero : !torch.vtensor, !torch.int -> !torch.int
|
||||
return %size : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @prim.ListUnpack$fold_list(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) {
|
||||
// CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
|
||||
func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) {
|
||||
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
|
||||
%1:2 = torch.prim.ListUnpack %0 : !torch.list<vtensor> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
|
||||
return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue