From 9b89f8eb3f9825a6c13d20dd960b8c7499699f5e Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Wed, 9 Feb 2022 19:31:03 -0800
Subject: [PATCH] [TORCH][MLIR] Add E2E support for aten.clone (#571)

This commit adds support for the aten.clone op.
---
 e2e_testing/torchscript/elementwise.py         | 18 ++++++++++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td       | 15 +++++++++++++++
 lib/Conversion/TorchToLinalg/TorchToLinalg.cpp | 11 +++++++++--
 lib/Dialect/Torch/Transforms/RefineTypes.cpp   |  4 ++--
 .../jit_ir/build_tools/torch_ods_gen.py        |  1 +
 5 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index a571f4799..bc6bbe6f5 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -1106,3 +1106,21 @@ class ElementwiseAddScalarFloatModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwiseAddScalarFloatModule())
 def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
+
+
+class ElementwiseCloneModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.clone(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseCloneModule())
+def ElementwiseCloneModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index c5c947bfe..87897ed66 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -2262,6 +2262,21 @@ def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
   let assemblyFormat = "$self `,` $boundaries `,` $out_int32 `,` $right attr-dict `:` qualified(type($self)) `,` qualified(type($boundaries)) `,` qualified(type($out_int32)) `,` qualified(type($right)) `->` qualified(type($result))";
 }
 
+def Torch_AtenCloneOp : Torch_Op<"aten.clone", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::clone : (Tensor, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    TorchOptionalIntType:$memory_format
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $memory_format attr-dict `:` qualified(type($self)) `,` qualified(type($memory_format)) `->` qualified(type($result))";
+}
+
 def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
     AllowsTypeRefinement
   ]> {
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 93c2d8ad2..205f67603 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1660,6 +1660,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create(loc, payloadArgs[0]);
   if (isa(op))
     return b.create(loc, payloadArgs[0]);
+  if (auto clone = dyn_cast<AtenCloneOp>(op)) {
+    if (!clone.memory_format().getType().isa<Torch::NoneType>()) {
+      clone.emitError("unimplemented: only default memory format is supported");
+      return nullptr;
+    }
+    return payloadArgs[0];
+  }
   if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
     if (bitwiseAndTensor.getType()
             .cast<RankedTensorType>()
@@ -2450,7 +2457,7 @@ struct ConvertElementwiseOp : ConversionPattern {
              AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
              AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
              AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
-             AtenThresholdOp, AtenThresholdBackwardOp>(op))
+             AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -4571,7 +4578,7 @@ public:
         AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
         AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
         AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
-        AtenThresholdBackwardOp>();
+        AtenThresholdBackwardOp, AtenCloneOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 9b5609e9a..3308e5ed0 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -242,8 +242,8 @@ public:
             AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
             AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
             AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
-            AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp>(
-            op)) {
+            AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
+            AtenCloneOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 39ef5acd9..41da0c9ad 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -577,6 +577,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
     emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
+    emit("aten::clone : (Tensor, int?) -> (Tensor)")
     emit("aten::contiguous : (Tensor, int) -> (Tensor)")
     emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
     emit("aten::detach : (Tensor) -> (Tensor)")