[torch] `torch.aten.complex` operation with lowering (#3738)

Add the operation with lowering to linalg. Includes a test for
end-to-end correctness.
pull/3760/head
Rob Suderman 2024-10-03 11:09:52 -07:00 committed by GitHub
parent f0b7ca72f5
commit 9ab0db5789
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 88 additions and 17 deletions

View File

@ -5122,6 +5122,30 @@ def Torch_AtenRad2degOp : Torch_Op<"aten.rad2deg", [
}];
}
def Torch_AtenComplexOp : Torch_Op<"aten.complex", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::complex : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$real,
AnyTorchTensorType:$imag
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenComplexOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenComplexOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenRealOp : Torch_Op<"aten.real", [
AllowsTypeRefinement,
ReadOnly

View File

@ -575,6 +575,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
return createEqual(b, loc, floatDtype, self, zero);
}
if (auto complex = dyn_cast<AtenComplexOp>(op)) {
auto ctype = cast<ComplexType>(
cast<RankedTensorType>(converter->convertType(complex.getType()))
.getElementType());
Type stype = ctype.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], stype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], stype);
return b.create<complex::CreateOp>(loc, ctype, lhs, rhs);
}
if (isa<AtenAbsOp>(op)) {
if (isa<IntegerType>(payloadArgs[0].getType()))
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
@ -1590,22 +1600,22 @@ public:
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp,
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@ -3351,7 +3361,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenComplexOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,

View File

@ -2751,6 +2751,7 @@ ONNX_XFAIL_SET = {
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseCreateComplexModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseOrTensorModule_basic",
@ -3165,6 +3166,14 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
"AtenIntMM_basic",
}
if torch_version_for_comparison() > version.parse("2.4.0.dev"):
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
"ElementwiseCreateComplexModule_basic",
}
FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | {
"ElementwiseCreateComplexModule_basic",
}
ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
"FakeQuantizePerTensorAffineModule_basic",

View File

@ -492,6 +492,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
emit("aten::rad2deg : (Tensor) -> (Tensor)")
emit("aten::complex : (Tensor, Tensor) -> (Tensor)")
emit("aten::real : (Tensor) -> (Tensor)")
emit("aten::imag : (Tensor) -> (Tensor)")
emit("aten::view_as_complex : (Tensor) -> (Tensor)")

View File

@ -2012,6 +2012,33 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseCreateComplexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1], torch.float32, True),
([-1], torch.float32, True),
]
)
def forward(self, a, b):
return torch.complex(a, b)
@register_test_case(module_factory=lambda: ElementwiseCreateComplexModule())
def ElementwiseCreateComplexModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, high=10).type(torch.float32),
tu.randint(4, high=10).type(torch.float32),
)
# ==============================================================================
class ElementwiseMulTensorComplexModule(torch.nn.Module):
def __init__(self):
super().__init__()