From 9ab0db5789d3980f3055c613c9847de1755afb1f Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 3 Oct 2024 11:09:52 -0700 Subject: [PATCH] [torch] `torch.aten.complex` operation with lowering (#3738) Add the operation with lowering to linalg. Includes a test for end-to-end correctness. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 ++++++++++ .../TorchToLinalg/Uncategorized.cpp | 44 ++++++++++++------- projects/pt1/e2e_testing/xfail_sets.py | 9 ++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 27 ++++++++++++ 5 files changed, 88 insertions(+), 17 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 6f02a9476..2f329e782 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5122,6 +5122,30 @@ def Torch_AtenRad2degOp : Torch_Op<"aten.rad2deg", [ }]; } +def Torch_AtenComplexOp : Torch_Op<"aten.complex", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::complex : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$real, + AnyTorchTensorType:$imag + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenComplexOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenComplexOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenRealOp : Torch_Op<"aten.real", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 4688ffc78..0f6f92bd7 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -575,6 +575,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp( b.create(loc, b.getFloatAttr(floatDtype, 0)); return createEqual(b, loc, floatDtype, self, zero); } + if (auto complex = dyn_cast(op)) { + auto ctype = cast( + cast(converter->convertType(complex.getType())) + .getElementType()); + Type stype = ctype.getElementType(); + + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], stype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], stype); + return b.create(loc, ctype, lhs, rhs); + } if (isa(op)) { if (isa(payloadArgs[0].getType())) return b.create(loc, payloadArgs[0]); @@ -1590,22 +1600,22 @@ public: AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp, - AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, - AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, - AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, - Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp, - AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, - AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, - AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, - AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, - AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, - AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, - AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, - AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, - AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp, - AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp, - AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, + AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, + AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, + AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp, + Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, + AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, + AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, + AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, + AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp, + AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, + AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, + AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, + AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, + AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp, + AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp, + AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); @@ -3351,7 +3361,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, - AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, + AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp, diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2852611fe..33cc23993 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2751,6 +2751,7 @@ ONNX_XFAIL_SET = { "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", "ElementwiseFmodTensor_Int_basic", + "ElementwiseCreateComplexModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseOrTensorModule_basic", @@ -3165,6 +3166,14 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"): "AtenIntMM_basic", } +if torch_version_for_comparison() > version.parse("2.4.0.dev"): + STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - { + "ElementwiseCreateComplexModule_basic", + } + FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | { + "ElementwiseCreateComplexModule_basic", + } + ONNX_CRASHING_SET = LINALG_CRASHING_SET | { "FakeQuantizePerTensorAffineModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ea5c50428..7d6680fe9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -492,6 +492,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::prelu : (Tensor, Tensor) -> (Tensor)") emit("aten::rad2deg : (Tensor) -> (Tensor)") + emit("aten::complex : (Tensor, Tensor) -> (Tensor)") emit("aten::real : (Tensor) -> (Tensor)") emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 9b4dbe659..ed5254353 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2012,6 +2012,33 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseCreateComplexModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.complex(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseCreateComplexModule()) +def ElementwiseCreateComplexModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, high=10).type(torch.float32), + tu.randint(4, high=10).type(torch.float32), + ) + + +# ============================================================================== + + class ElementwiseMulTensorComplexModule(torch.nn.Module): def __init__(self): super().__init__()