mirror of https://github.com/llvm/torch-mlir
add e2e support for aten.atan2 (#1117)
- Includes math-to-libm pass in refbackend for math::atan2 supportpull/1136/head
parent
704efdc259
commit
38d8498b21
|
@ -655,6 +655,53 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::atan2 : (Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAtan2Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenAtan2Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAtan2_Op : Torch_Op<"aten.atan2_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::atan2_ : (Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAtan2_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenAtan2_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenNegOp : Torch_Op<"aten.neg", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -391,6 +391,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
||||
}
|
||||
}
|
||||
if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
|
||||
Type dtype = converter->convertType(atan2.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType>()) {
|
||||
atan2.emitError("Atan2 requires floating point result type");
|
||||
return nullptr;
|
||||
}
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
return b.create<math::Atan2Op>(loc, lhs, rhs);
|
||||
}
|
||||
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
|
||||
AtenGtTensorOp::Adaptor adaptor(operands);
|
||||
Type lhsDtype = payloadArgs[0].getType();
|
||||
|
@ -926,7 +938,7 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
|
||||
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
|
||||
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp,
|
||||
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenAtan2Op,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
|
||||
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
|
||||
|
@ -1669,14 +1681,14 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
||||
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
|
||||
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
|
||||
AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
||||
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
|
||||
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
||||
AtenLogicalOrOp, AtenTriuOp>();
|
||||
AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
|
||||
AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
|
||||
AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp,
|
||||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||
AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -733,6 +733,23 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
return;
|
||||
}
|
||||
|
||||
// Dtype is always float32, except for bfloat16, float64 and nullptr after
|
||||
// promotion and assuming possible-zero rank.
|
||||
if (isa<AtenAtan2Op>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
Type promotedDtype = getPromotedResultType(
|
||||
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
|
||||
getRankIsNonZeroArray(op->getOperands()));
|
||||
if (promotedDtype) {
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
if (promotedDtype.isa<BFloat16Type, Float64Type>())
|
||||
knowledge.dtype = promotedDtype;
|
||||
}
|
||||
incorporateKnowledge(op->getResult(0), knowledge);
|
||||
return;
|
||||
}
|
||||
|
||||
// Promote three dtypes.
|
||||
if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) {
|
||||
auto knowledge =
|
||||
|
|
|
@ -6248,6 +6248,10 @@ module {
|
|||
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.atan2"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
|
|
|
@ -807,6 +807,9 @@ def aten〇div〇Tensor_mode(self: List[int], other: List[int], rounding_mode: O
|
|||
def aten〇floor_divide(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.broadcast(self, other)
|
||||
|
||||
def aten〇atan2(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.broadcast(self, other)
|
||||
|
||||
def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.broadcast(self, other)
|
||||
|
||||
|
|
|
@ -252,6 +252,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::exp : (Tensor) -> (Tensor)",
|
||||
"aten::expm1 : (Tensor) -> (Tensor)",
|
||||
"aten::cos : (Tensor) -> (Tensor)",
|
||||
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::neg : (Tensor) -> (Tensor)",
|
||||
"aten::floor : (Tensor) -> (Tensor)",
|
||||
"aten::ceil : (Tensor) -> (Tensor)",
|
||||
|
|
|
@ -142,6 +142,8 @@ LOWERING_PIPELINE = ",".join([
|
|||
"func.func(refback-expand-ops-for-llvm)",
|
||||
"func.func(arith-expand)",
|
||||
"func.func(convert-math-to-llvm)",
|
||||
# Handle some complex mlir::math ops (e.g. atan2)
|
||||
"convert-math-to-libm",
|
||||
"convert-linalg-to-llvm",
|
||||
"convert-memref-to-llvm",
|
||||
"func.func(convert-arith-to-llvm)",
|
||||
|
|
|
@ -804,6 +804,77 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtan2TensorFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.atan2(a, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatModule())
|
||||
def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtan2TensorIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.atan2(a, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule())
|
||||
def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.randint(1, 10, [4]).type(torch.int32), torch.randint(1, 10, [4]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtan2FloatIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.atan2(a, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule())
|
||||
def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, [4, 4], dtype=torch.int32),
|
||||
tu.rand(4, 4).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLogModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue