mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add canonicalization pattern for prim.ListUnpack op
This commit adds the canonicalization pattern for the `prim.ListUnpack` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1069/head
parent
247dd64a66
commit
4c25878e64
|
@ -310,8 +310,10 @@ def Torch_GlobalSlotSetOp : Torch_Op<"global_slot.set", []> {
|
||||||
// See `torch/csrc/jit/runtime/instruction.h`.
|
// See `torch/csrc/jit/runtime/instruction.h`.
|
||||||
|
|
||||||
|
|
||||||
def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
|
def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", [
|
||||||
[AllowsTypeRefinement]> {
|
AllowsTypeRefinement,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
let summary = "TorchScript prim::ListUnpack op";
|
let summary = "TorchScript prim::ListUnpack op";
|
||||||
let arguments = (ins AnyTorchType:$operand);
|
let arguments = (ins AnyTorchType:$operand);
|
||||||
let results = (outs Variadic<AnyTorchType>:$results);
|
let results = (outs Variadic<AnyTorchType>:$results);
|
||||||
|
@ -319,6 +321,7 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($results))
|
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($results))
|
||||||
}];
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
||||||
|
|
|
@ -1574,6 +1574,27 @@ void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PrimListUnpackOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
MLIRContext *context) {
|
||||||
|
patterns.add(+[](PrimListUnpackOp op, PatternRewriter &rewriter) {
|
||||||
|
auto torchList = op.operand();
|
||||||
|
if (isListPotentiallyMutated(torchList)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
|
||||||
|
if (!listConstruct)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, listConstruct.elements());
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
|
static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
|
||||||
if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) {
|
if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) {
|
||||||
return isa<Aten__Getitem__DictStrOp, Aten__Contains__StrOp,
|
return isa<Aten__Getitem__DictStrOp, Aten__Contains__StrOp,
|
||||||
|
|
|
@ -1357,3 +1357,13 @@ func.func @torch.aten.size.int$copy(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.
|
||||||
%size = torch.aten.size.int %value_tensor, %zero : !torch.vtensor, !torch.int -> !torch.int
|
%size = torch.aten.size.int %value_tensor, %zero : !torch.vtensor, !torch.int -> !torch.int
|
||||||
return %size : !torch.int
|
return %size : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @prim.ListUnpack$fold_list(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>,
|
||||||
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) {
|
||||||
|
// CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
|
||||||
|
func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) {
|
||||||
|
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
|
||||||
|
%1:2 = torch.prim.ListUnpack %0 : !torch.list<vtensor> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
|
||||||
|
return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue