From 96b14e952e7551a5e2ec05596b5c6cb52e1a3f8e Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Fri, 23 Jun 2023 01:07:14 +0800 Subject: [PATCH] [Torch Dialect] Support aten.device.with_index (#2254) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++++++++++++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 23 +++++++++++++++++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + test/Dialect/Torch/canonicalize.mlir | 10 ++++++++ 4 files changed, 59 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 9f8784f72..a82721a5d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6996,6 +6996,31 @@ def Torch_AtenDetachOp : Torch_Op<"aten.detach", [ let hasFolder = 1; } +def Torch_AtenDeviceWithIndexOp : Torch_Op<"aten.device.with_index", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::device.with_index : (str, int) -> (Device)`"; + let arguments = (ins + Torch_StringType:$type, + Torch_IntType:$index + ); + let results = (outs + Torch_DeviceType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDeviceWithIndexOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenDeviceWithIndexOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasCanonicalizer = 1; +} + def Torch_AtenCudaOp : Torch_Op<"aten.cuda", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 60dffbaf5..19006af2e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2455,6 +2455,29 @@ void AtenCudaOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenDeviceWithIndexOp +//===----------------------------------------------------------------------===// + +void AtenDeviceWithIndexOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenDeviceWithIndexOp op, PatternRewriter &rewriter) { + std::string type; + int64_t index; + if (!matchPattern(op.getType(), m_TorchConstantStr(type))) { + return rewriter.notifyMatchFailure( + op, "unimplemented: type must be a constant string"); + } + if (!matchPattern(op.getIndex(), m_TorchConstantInt(&index))) { + return rewriter.notifyMatchFailure( + op, "unimplemented: index must be a constant integer"); + } + rewriter.replaceOpWithNewOp( + op, type + ":" + std::to_string(index)); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenIntTensorOp //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 38c3b7e8e..490f26274 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -484,6 +484,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)") emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)") emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True) + emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True) emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)") emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index ab19e6038..b7b21477f 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2001,3 +2001,13 @@ func.func @torch.aten.cuda$canonicalize(%arg0: !torch.tensor) -> !torch.tensor { %0 = torch.aten.cuda %arg0 : !torch.tensor -> !torch.tensor return %0 : !torch.tensor } + +// CHECK-LABEL: func.func @torch.aten.device.with_index$canonicalize +// CHECK-NEXT: %[[VAL:.*]] = torch.constant.device "cuda:0" +// CHECK-NEXT: return %[[VAL]] : !torch.Device +func.func @torch.aten.device.with_index$canonicalize() -> !torch.Device { + %str = torch.constant.str "cuda" + %int0 = torch.constant.int 0 + %0 = torch.aten.device.with_index %str, %int0 : !torch.str, !torch.int -> !torch.Device + return %0 : !torch.Device +}