From 79b9cf9468e44c2a46c537ef14018f0554b788cc Mon Sep 17 00:00:00 2001
From: gpetters94
Date: Wed, 10 Aug 2022 19:24:02 -0400
Subject: [PATCH] Add lowering for aten.to.device (#1107)

---
 e2e_testing/torchscript/xfail_sets.py         |  1 +
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 27 +++++++++++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 21 +++++++++++++++
 .../Transforms/MaximizeValueSemantics.cpp     |  2 +-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  5 ++++
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |  4 +++
 .../jit_ir/build_tools/shape_lib_gen.py       |  3 +++
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 .../torch_mlir_e2e_test/test_suite/basic.py   | 19 +++++++++++++
 9 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 196a67e5b..93bb4d202 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -53,6 +53,7 @@ TOSA_PASS_SET = {
     "TModuleRank1_basic",
     "TModuleRank0_basic",
     "ElementwiseToDtypeIdentityModule_basic",
+    "AtenToDeviceModule_basic",
     "View1DFoldModule_basic",
     "UnsafeView1DFoldModule_basic",
     "SqueezeDimModule_static",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 6d6459d5d..51378d800 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5880,6 +5880,33 @@ def Torch_AtenToPrimDeviceOp : Torch_Op<"aten.to.prim_Device", [
   }];
 }
 
+def Torch_AtenToDeviceOp : Torch_Op<"aten.to.device", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_DeviceType:$device,
+    Torch_IntType:$dtype,
+    Torch_BoolType:$non_blocking,
+    Torch_BoolType:$copy,
+    AnyTorchOptionalIntType:$memory_format
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenToDeviceOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenToDeviceOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 4e00a872a..f07ffb19d 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1996,6 +1996,25 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.to.device` op into `aten.to.dtype` op.
+class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenToDeviceOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Device information isn't relevant to torch-mlir, so we can drop that
+    // info here.
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.self(),
+                                               op.dtype(), op.non_blocking(),
+                                               op.copy(), op.memory_format());
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
 //
@@ -2586,6 +2605,8 @@ class DecomposeComplexOpsPass
     patterns.add(context);
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenToDeviceOp>(context);
+    target.addIllegalOp<AtenToDeviceOp>();
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index 9a31c30ed..a4db59642 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -38,7 +38,7 @@ static bool isViewLikeOp(Operation *op) {
              AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
              AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
              TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
-             AtenNarrowOp>(op);
+             AtenNarrowOp, AtenToDeviceOp>(op);
 }
 
 namespace {
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 97b7a4326..74071d735 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -1024,6 +1024,11 @@ void TypeAnalysis::visitOperation(Operation *op,
     return;
   }
 
+  if (auto toDtype = dyn_cast<AtenToDeviceOp>(op)) {
+    visitAtenToDtypeLikeOp(toDtype, operands);
+    return;
+  }
+
   if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
     visitTypeConversionOp(toOther, operands);
     return;
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 610fdfa6c..bc70edde4 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5448,6 +5448,10 @@ module {
   func.func @"__torch_mlir_shape_fn.aten.to.dtype_layout"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional<int>) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.to.device"(%arg0: !torch.list<int>, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.to.other"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index b4865c6c2..bd67f6844 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -427,6 +427,9 @@ def aten〇to〇dtype(self: List[int], dtype: int, non_blocking: bool = False, c
 def aten〇to〇dtype_layout(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]:
     return self
 
+def aten〇to〇device(self: List[int], device: device, dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇to〇other(self: List[int], other: List[int], non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 1e28b74b9..5dce6bc0b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -456,6 +456,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True)
     emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
     emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
+    emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
     emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
     emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
     emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index b5d88e4f5..018977e2c 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2627,3 +2627,22 @@ def Aten_EmbeddingBagExample_basic(module, tu: TestUtils):
     indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
     offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
     module.forward(weight, indices, offsets)
+
+# ==============================================================================
+
+class AtenToDeviceModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, val):
+        return torch.ops.aten.to(val, device='cpu', dtype=torch.float, non_blocking=False)
+
+@register_test_case(module_factory=lambda: AtenToDeviceModule())
+def AtenToDeviceModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(2, 4))
\ No newline at end of file
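
A note on why the decomposition is sound (not part of the patch): torch-mlir compiles device-agnostic programs, so once the device operand is dropped, `aten::to.device` and `aten::to.dtype` agree on values, dtype, and shape. A minimal eager-mode sketch of that equivalence, assuming a CPU-only setup like the e2e test above; this snippet is illustrative only and is not in the test suite:

    # Hypothetical sanity check of the equivalence DecomposeAtenToDeviceOp
    # relies on; run in eager mode, outside torch-mlir.
    import torch

    t = torch.randn(2, 4)
    # Dispatches to the `aten::to.device` overload the patch lowers.
    via_device = torch.ops.aten.to(t, device='cpu', dtype=torch.float,
                                   non_blocking=False)
    # Dispatches to the `aten::to.dtype` overload it decomposes into.
    via_dtype = torch.ops.aten.to(t, dtype=torch.float, non_blocking=False)
    assert torch.equal(via_device, via_dtype)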