mirror of https://github.com/llvm/torch-mlir

[torch] `torch.aten.complex` operation with lowering (#3738)

Adds the `aten.complex` op with a lowering to linalg, and includes a test for end-to-end correctness.

Branch: pull/3760/head
Parent: f0b7ca72f5
Commit: 9ab0db5789
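
For context, the eager-mode behavior this commit lowers: torch.complex(real, imag) pairs two float tensors of matching shape and dtype into a complex tensor elementwise (float32 inputs yield complex64, float64 yield complex128). A minimal PyTorch check:

    import torch

    real = torch.tensor([1.0, 2.0, 3.0])
    imag = torch.tensor([4.0, 5.0, 6.0])
    out = torch.complex(real, imag)  # out[i] = real[i] + 1j * imag[i]
    print(out)        # tensor([1.+4.j, 2.+5.j, 3.+6.j])
    print(out.dtype)  # torch.complex64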
@@ -5122,6 +5122,30 @@ def Torch_AtenRad2degOp : Torch_Op<"aten.rad2deg", [
   }];
 }
 
+def Torch_AtenComplexOp : Torch_Op<"aten.complex", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::complex : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$real,
+    AnyTorchTensorType:$imag
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenComplexOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenComplexOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenRealOp : Torch_Op<"aten.real", [
   AllowsTypeRefinement,
   ReadOnly
@@ -575,6 +575,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
     return createEqual(b, loc, floatDtype, self, zero);
   }
+  if (auto complex = dyn_cast<AtenComplexOp>(op)) {
+    auto ctype = cast<ComplexType>(
+        cast<RankedTensorType>(converter->convertType(complex.getType()))
+            .getElementType());
+    Type stype = ctype.getElementType();
+
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], stype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], stype);
+    return b.create<complex::CreateOp>(loc, ctype, lhs, rhs);
+  }
   if (isa<AtenAbsOp>(op)) {
     if (isa<IntegerType>(payloadArgs[0].getType()))
       return b.create<math::AbsIOp>(loc, payloadArgs[0]);
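
The payload above runs once per element inside the generated linalg.generic: both scalar operands are converted to the complex element type's scalar type (e.g. f32 for complex<f32>), then combined with complex::CreateOp. A rough Python model of that per-element computation, with NumPy scalar conversion standing in for convertScalarToDtype (illustrative only, not a torch-mlir API):

    import numpy as np

    def complex_payload(lhs, rhs, scalar_ty=np.float32):
        # convertScalarToDtype(..., stype) for each operand, then
        # complex.CreateOp: result = lhs + 1j * rhs
        return complex(scalar_ty(lhs), scalar_ty(rhs))

    assert complex_payload(1.5, -2.0) == 1.5 - 2.0j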
@@ -1590,22 +1600,22 @@ public:
           AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
           AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
           AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
-          AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
-          AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
-          AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
-          Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp,
-          AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
-          AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp,
-          AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
-          AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
-          AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
-          AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
-          AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-          AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
-          AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
-          AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
-          AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
-          AtenDequantizeSelfOp, AtenDequantizeTensorOp,
+          AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
+          AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
+          AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
+          AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
+          Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
+          AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
+          AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
+          AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
+          AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
+          AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
+          AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
+          AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
+          AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
+          AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
+          AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
+          AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
           AtenQuantizePerTensorOp, AtenIscloseOp>(op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@@ -3351,7 +3361,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
       AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
       AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
-      AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
+      AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenComplexOp, AtenReciprocalOp,
       AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
       AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
       AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
@@ -2751,6 +2751,7 @@ ONNX_XFAIL_SET = {
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
     "ElementwiseFmodTensor_Int_basic",
+    "ElementwiseCreateComplexModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
     "ElementwiseMulTensorComplexDiffModule_basic",
     "ElementwiseOrTensorModule_basic",
@@ -3165,6 +3166,14 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
         "AtenIntMM_basic",
     }
 
+if torch_version_for_comparison() > version.parse("2.4.0.dev"):
+    STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
+        "ElementwiseCreateComplexModule_basic",
+    }
+    FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | {
+        "ElementwiseCreateComplexModule_basic",
+    }
+
 
 ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
     "FakeQuantizePerTensorAffineModule_basic",
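
The version gate above is plain Python set arithmetic over test-name sets: `-` removes the new test from the StableHLO pass set on newer torch builds, while `|` adds it to the fx-importer xfail set. A self-contained sketch of the same idiom (set contents here are invented for illustration):

    from packaging import version

    PASS_SET = {"OtherModule_basic", "ElementwiseCreateComplexModule_basic"}
    XFAIL_SET = {"AnotherModule_basic"}

    if version.parse("2.5.0.dev") > version.parse("2.4.0.dev"):
        PASS_SET = PASS_SET - {"ElementwiseCreateComplexModule_basic"}
        XFAIL_SET = XFAIL_SET | {"ElementwiseCreateComplexModule_basic"}

    assert "ElementwiseCreateComplexModule_basic" not in PASS_SET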
@@ -492,6 +492,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
     emit("aten::rad2deg : (Tensor) -> (Tensor)")
+    emit("aten::complex : (Tensor, Tensor) -> (Tensor)")
     emit("aten::real : (Tensor) -> (Tensor)")
     emit("aten::imag : (Tensor) -> (Tensor)")
     emit("aten::view_as_complex : (Tensor) -> (Tensor)")
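
Each emit() call registers a JIT operator signature from which the TableGen op above is generated; the string encodes the op name, operand types, and result types. A hypothetical decomposition of that string (the real TextEmitter/Registry machinery does far more):

    sig = "aten::complex : (Tensor, Tensor) -> (Tensor)"
    name, types = sig.split(" : ")
    operands, results = (t.strip("()").split(", ") for t in types.split(" -> "))
    print(name)      # aten::complex
    print(operands)  # ['Tensor', 'Tensor']
    print(results)   # ['Tensor']

The two operands and one result match the parseDefaultTorchOp(parser, result, 2, 1) arity in the generated definition.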
@@ -2012,6 +2012,33 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseCreateComplexModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, a, b):
+        return torch.complex(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseCreateComplexModule())
+def ElementwiseCreateComplexModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(4, high=10).type(torch.float32),
+        tu.randint(4, high=10).type(torch.float32),
+    )
+
+
+# ==============================================================================
+
+
 class ElementwiseMulTensorComplexModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
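
Note the test feeds integer-valued float32 inputs (tu.randint(...).type(torch.float32)), so the real and imaginary parts compare exactly across backends. A quick eager-mode check of the same pattern:

    import torch

    a = torch.randint(0, 10, (4,)).type(torch.float32)
    b = torch.randint(0, 10, (4,)).type(torch.float32)
    out = torch.complex(a, b)
    assert out.dtype == torch.complex64
    assert torch.equal(out.real, a) and torch.equal(out.imag, b)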