[MLIR][TORCH] Add canonicalization pattern for prim.ListUnpack op

This commit adds a canonicalization pattern for the `prim.ListUnpack` op: when the unpacked list is produced by a `prim.ListConstruct` and is not potentially mutated, the unpack results are replaced directly with the list elements.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/1069/head
Vivek Khandelwal 2022-07-15 17:40:23 +05:30
parent 247dd64a66
commit 4c25878e64
3 changed files with 36 additions and 2 deletions
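In essence, the new pattern forwards the elements of a statically visible `prim.ListConstruct` directly to the users of the matching `prim.ListUnpack` results, so the unpack goes away. A minimal before/after sketch (hypothetical IR and value names, mirroring the test added at the end of this diff):

  // Before canonicalization:
  %list = torch.prim.ListConstruct %a, %b : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
  %x:2 = torch.prim.ListUnpack %list : !torch.list<vtensor> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>

  // After canonicalization: every use of %x#0 and %x#1 is rewritten to use
  // %a and %b, and the prim.ListUnpack is erased; the now-unused
  // prim.ListConstruct can be cleaned up separately if nothing else uses it.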


@@ -310,8 +310,10 @@ def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
 // See `torch/csrc/jit/runtime/instruction.h`.
-def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
-    [AllowsTypeRefinement]> {
+def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
   let summary = "TorchScript prim::ListUnpack op";
   let arguments = (ins AnyTorchType:$operand);
   let results = (outs Variadic<AnyTorchType>:$results);
@@ -319,6 +321,7 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($results))
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [


@@ -1574,6 +1574,27 @@ void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
+
+//===----------------------------------------------------------------------===//
+// PrimListUnpackOp
+//===----------------------------------------------------------------------===//
+
+void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                   MLIRContext *context) {
+  patterns.add(+[](PrimListUnpackOp op, PatternRewriter &rewriter) {
+    auto torchList = op.operand();
+    if (isListPotentiallyMutated(torchList)) {
+      return failure();
+    }
+
+    auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
+    if (!listConstruct)
+      return failure();
+
+    rewriter.replaceOp(op, listConstruct.elements());
+    return success();
+  });
+}
 
 static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
   if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) {
         return isa<Aten__Getitem__DictStrOp, Aten__Contains__StrOp,
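Note the `isListPotentiallyMutated` guard: the rewrite only fires when no user of the list can mutate it. A hypothetical example where the pattern conservatively returns failure (`torch.aten.append.t` is used here purely as an illustrative mutating user):

  %list = torch.prim.ListConstruct %a, %b : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
  %x:2 = torch.prim.ListUnpack %list : !torch.list<vtensor> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
  // The in-place append makes %list potentially mutated, so the unpack above
  // is left untouched even though it precedes the mutation.
  %grown = torch.aten.append.t %list, %a : !torch.list<vtensor>, !torch.vtensor<[2,3],f32> -> !torch.list<vtensor>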


@@ -1357,3 +1357,13 @@ func.func @torch.aten.size.int$copy(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.
   %size = torch.aten.size.int %value_tensor, %zero : !torch.vtensor, !torch.int -> !torch.int
   return %size : !torch.int
 }
+
+// CHECK-LABEL: func.func @prim.ListUnpack$fold_list(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) {
+// CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
+func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) {
+  %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
+  %1:2 = torch.prim.ListUnpack %0 : !torch.list<vtensor> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
+  return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
+}
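For reference, roughly the IR this test function should canonicalize to (a sketch, not output captured from the tool; whether the now-dead `prim.ListConstruct` is also erased depends on later cleanup):

  func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) {
    // %0 is now unused; later dead-code cleanup may erase it.
    %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
    return %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
  }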