From 0cf9ee340bc3264922e03b0df002e41c69184702 Mon Sep 17 00:00:00 2001
From: Zhekun Zhang <32320144+zhekunz2@users.noreply.github.com>
Date: Tue, 2 May 2023 20:06:02 -0700
Subject: [PATCH] [Torch Dialect] Add to.dtype_layout canonicalize patterns
 (#2062)

* add to.dtype_layout canonicalize patterns

* update comment

---------

Co-authored-by: zhekun.zhang
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td |  1 +
 lib/Dialect/Torch/IR/TorchOps.cpp         | 42 +++++++++++++++++++
 .../jit_ir/build_tools/torch_ods_gen.py   |  2 +-
 test/Dialect/Torch/canonicalize.mlir      | 32 ++++++++++++++
 4 files changed, 76 insertions(+), 1 deletion(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 9c366eb01..357f95fd2 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7479,6 +7479,7 @@ def Torch_AtenToDtypeLayoutOp : Torch_Op<"aten.to.dtype_layout", [
     }
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 382423c98..28506d6ea 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -797,6 +797,48 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
   return getOperand(0);
 }
 
+void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  // `to.dtype_layout` -> `to.device/to.dtype` if layout is none and pin memory
+  // is false
+  patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) {
+    // The pin_memory arg should be either constant `False` or `none`.
+    if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
+      bool pinMemory;
+      if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
+        return failure();
+      else if (pinMemory)
+        return failure();
+    }
+
+    // The layout arg should be either `none` or `0` i.e. strided.
+    if (!op.getLayout().getType().isa<Torch::NoneType>()) {
+      int64_t tensorLayout;
+      if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
+        return failure();
+      else if (tensorLayout != torch_upstream::Layout::Strided)
+        return failure();
+    }
+
+    if (op.getDevice().getType().isa<Torch::NoneType>()) {
+      // The device arg is `none`. Rewrite to to.dtype.
+      AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>(
+          op.getLoc(), op.getType(), op.getSelf(), op.getDtype(),
+          op.getNonBlocking(), op.getCopy(), op.getMemoryFormat());
+      rewriter.replaceOp(op, toDtype->getResults());
+    } else {
+      // The device arg is not `none`. Rewrite to to.device.
+      AtenToDeviceOp toDevice = rewriter.create<AtenToDeviceOp>(
+          op.getLoc(), op.getType(), op.getSelf(), op.getDevice(),
+          op.getDtype(), op.getNonBlocking(), op.getCopy(),
+          op.getMemoryFormat());
+      rewriter.replaceOp(op, toDevice->getResults());
+    }
+
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenViewOp
 //===----------------------------------------------------------------------===//
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index f280fd6e0..8d36f5645 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -502,7 +502,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
     emit("aten::amax : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
-    emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True)
+    emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer=True)
     emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
     emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
     emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 1657527ab..b4f9db5df 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1451,6 +1451,38 @@ func.func @torch.aten.to.dtype_layout$same_dtype(%arg0: !torch.tensor<[?,?],f32>
   return %0 : !torch.tensor<[?,?],f32>
 }
 
+// CHECK-LABEL: func.func @torch.aten.to.dtype_layout$to_device(
+// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
+// CHECK-NEXT: %[[INT6:.*]] = torch.constant.int 6
+// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
+// CHECK-NEXT: %[[CPU:.*]] = torch.constant.device "cpu"
+// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.to.device %[[ARG]], %[[CPU]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[?,?],f32>, !torch.Device, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
+// CHECK-NEXT: return %[[RESULT]] : !torch.tensor<[?,?],f32>
+func.func @torch.aten.to.dtype_layout$to_device(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
+  %none = torch.constant.none
+  %device = torch.constant.device "cpu"
+  %false = torch.constant.bool false
+  %int6 = torch.constant.int 6
+  %0 = torch.aten.to.dtype_layout %arg0, %int6, %none, %device, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.Device, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
+  return %0 : !torch.tensor<[?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.to.dtype_layout$to_dtype(
+// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f16> {
+// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
+// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK-NEXT: %[[INT5:.*]] = torch.constant.int 5
+// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.to.dtype %[[ARG]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f16>
+// CHECK-NEXT: return %[[RESULT]] : !torch.tensor<[?,?],f16>
+func.func @torch.aten.to.dtype_layout$to_dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f16> {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %int5 = torch.constant.int 5
+  %0 = torch.aten.to.dtype_layout %arg0, %int5, %none, %none, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f16>
+  return %0 : !torch.tensor<[?,?],f16>
+}
+
 // CHECK-LABEL: func.func @torch.aten.ge.float$same_operand(
 // CHECK-SAME: %{{.*}}: !torch.float) -> !torch.bool {
 // CHECK: %[[TRUE:.*]] = torch.constant.bool true
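
Illustration (a sketch derived from the to_device test above, not an additional hunk of the patch): when layout and pin_memory are `none`/`false` and a device is given, the new pattern rewrites to.dtype_layout into to.device, e.g.

    // Before canonicalization: layout is `none`, pin_memory is `false`, device is "cpu".
    %0 = torch.aten.to.dtype_layout %arg0, %int6, %none, %device, %none, %false, %false, %none
           : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.Device, !torch.none,
             !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>

    // After canonicalization: the device operand is present, so the op becomes to.device.
    %0 = torch.aten.to.device %arg0, %device, %int6, %false, %false, %none
           : !torch.tensor<[?,?],f32>, !torch.Device, !torch.int, !torch.bool, !torch.bool,
             !torch.none -> !torch.tensor<[?,?],f32>

With the device operand left as `none`, the same pattern instead emits to.dtype, as exercised by the to_dtype test.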