From 38d8498b21d0892c2c7196aeaab64aeec6852a62 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Tue, 2 Aug 2022 11:39:41 -0400
Subject: [PATCH] add e2e support for aten.atan2 (#1117)

- Includes math-to-libm pass in refbackend for math::atan2 support
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td      | 47 ++++++++++++
 .../TorchToLinalg/Uncategorized.cpp            | 30 +++++---
 lib/Dialect/Torch/Transforms/RefineTypes.cpp   | 17 +++++
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp  |  4 ++
 .../jit_ir/build_tools/shape_lib_gen.py        |  3 +
 .../jit_ir/build_tools/torch_ods_gen.py        |  1 +
 .../linalg_on_tensors_backends/refbackend.py   |  2 +
 .../test_suite/elementwise.py                  | 71 +++++++++++++++++++
 8 files changed, 166 insertions(+), 9 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index eeff0fa9e..a13d1ab5c 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -655,6 +655,53 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [
   }];
 }
 
+def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::atan2 : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAtan2Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenAtan2Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenAtan2_Op : Torch_Op<"aten.atan2_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::atan2_ : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAtan2_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenAtan2_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenNegOp : Torch_Op<"aten.neg", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index b982ca54b..f11f8a52e 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -391,6 +391,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       return b.create(loc, lhs, rhs);
     }
   }
+  if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
+    Type dtype = converter->convertType(atan2.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::FloatType>()) {
+      atan2.emitError("Atan2 requires floating point result type");
+      return nullptr;
+    }
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<math::Atan2Op>(loc, lhs, rhs);
+  }
   if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
     AtenGtTensorOp::Adaptor adaptor(operands);
     Type lhsDtype = payloadArgs[0].getType();
@@ -926,7 +938,7 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     if (!isa();
+      AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
+      AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
+      AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp,
+      AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
+      AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
+      AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
+      AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
+      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 657880213..60fa8692f 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -733,6 +733,23 @@ void TypeAnalysis::visitOperation(Operation *op,
     return;
   }
 
+  // Dtype is always float32, except for bfloat16, float64 and nullptr after
+  // promotion and assuming possible-zero rank.
+  if (isa<AtenAtan2Op>(op)) {
+    ValueKnowledge knowledge =
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
+    Type promotedDtype = getPromotedResultType(
+        op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
+        getRankIsNonZeroArray(op->getOperands()));
+    if (promotedDtype) {
+      knowledge.dtype = Float32Type::get(op->getContext());
+      if (promotedDtype.isa<BFloat16Type, Float64Type>())
+        knowledge.dtype = promotedDtype;
+    }
+    incorporateKnowledge(op->getResult(0), knowledge);
+    return;
+  }
+
   // Promote three dtypes.
   if (isa<AtenAddcmulOp, AtenAddcdivOp>(op)) {
     auto knowledge =
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 70d23c812..d52c5db17 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6248,6 +6248,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.atan2"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 0f47ddfd0..e31456249 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -807,6 +807,9 @@ def aten〇div〇Tensor_mode(self: List[int], other: List[int], rounding_mode: O
 def aten〇floor_divide(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
+def aten〇atan2(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, other)
+
 def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index e4d8a049d..0d8cf9da6 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -252,6 +252,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
             "aten::exp : (Tensor) -> (Tensor)",
             "aten::expm1 : (Tensor) -> (Tensor)",
             "aten::cos : (Tensor) -> (Tensor)",
+            "aten::atan2 : (Tensor, Tensor) -> (Tensor)",
             "aten::neg : (Tensor) -> (Tensor)",
             "aten::floor : (Tensor) -> (Tensor)",
             "aten::ceil : (Tensor) -> (Tensor)",
diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
index 8791f3b32..09ab07c4d 100644
--- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
+++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -142,6 +142,8 @@ LOWERING_PIPELINE = ",".join([
     "func.func(refback-expand-ops-for-llvm)",
     "func.func(arith-expand)",
     "func.func(convert-math-to-llvm)",
+    # Handle some complex mlir::math ops (e.g. atan2)
+    "convert-math-to-libm",
     "convert-linalg-to-llvm",
     "convert-memref-to-llvm",
     "func.func(convert-arith-to-llvm)",
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index f13a0e172..536982c06 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -804,6 +804,77 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseAtan2TensorFloatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return torch.atan2(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatModule())
+def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 4), tu.rand(4, 4))
+
+
+# ==============================================================================
+
+
+class ElementwiseAtan2TensorIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int32, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, a, b):
+        return torch.atan2(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule())
+def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.randint(1, 10, [4]).type(torch.int32), torch.randint(1, 10, [4]))
+
+
+# ==============================================================================
+
+
+class ElementwiseAtan2FloatIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.atan2(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule())
+def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(1, 10, [4, 4], dtype=torch.int32),
+                   tu.rand(4, 4).double())
+
+
+# ==============================================================================
+
+
 class ElementwiseLogModule(torch.nn.Module):
 
     def __init__(self):