mirror of https://github.com/llvm/torch-mlir
add e2e support for aten.atan2 (#1117)
- Includes math-to-libm pass in refbackend for math::atan2 supportpull/1136/head
parent
704efdc259
commit
38d8498b21
|
@ -655,6 +655,53 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::atan2 : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$other
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenAtan2Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenAtan2Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenAtan2_Op : Torch_Op<"aten.atan2_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::atan2_ : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$other
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenAtan2_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenAtan2_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenNegOp : Torch_Op<"aten.neg", [
|
def Torch_AtenNegOp : Torch_Op<"aten.neg", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -391,6 +391,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
|
||||||
|
Type dtype = converter->convertType(atan2.getType())
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
if (!dtype.isa<mlir::FloatType>()) {
|
||||||
|
atan2.emitError("Atan2 requires floating point result type");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
return b.create<math::Atan2Op>(loc, lhs, rhs);
|
||||||
|
}
|
||||||
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
|
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
|
||||||
AtenGtTensorOp::Adaptor adaptor(operands);
|
AtenGtTensorOp::Adaptor adaptor(operands);
|
||||||
Type lhsDtype = payloadArgs[0].getType();
|
Type lhsDtype = payloadArgs[0].getType();
|
||||||
|
@ -926,7 +938,7 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
|
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
|
||||||
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
|
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
|
||||||
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp,
|
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenAtan2Op,
|
||||||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
|
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
|
||||||
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||||
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
|
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
|
||||||
|
@ -1669,14 +1681,14 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
||||||
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
|
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
|
||||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
|
AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
|
||||||
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
|
AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
|
||||||
AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp,
|
||||||
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
|
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||||
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
|
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||||
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
|
||||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||||
AtenLogicalOrOp, AtenTriuOp>();
|
AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||||
|
|
|
@ -733,6 +733,23 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dtype is always float32, except for bfloat16, float64 and nullptr after
|
||||||
|
// promotion and assuming possible-zero rank.
|
||||||
|
if (isa<AtenAtan2Op>(op)) {
|
||||||
|
ValueKnowledge knowledge =
|
||||||
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
|
Type promotedDtype = getPromotedResultType(
|
||||||
|
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
|
||||||
|
getRankIsNonZeroArray(op->getOperands()));
|
||||||
|
if (promotedDtype) {
|
||||||
|
knowledge.dtype = Float32Type::get(op->getContext());
|
||||||
|
if (promotedDtype.isa<BFloat16Type, Float64Type>())
|
||||||
|
knowledge.dtype = promotedDtype;
|
||||||
|
}
|
||||||
|
incorporateKnowledge(op->getResult(0), knowledge);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Promote three dtypes.
|
// Promote three dtypes.
|
||||||
if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) {
|
if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) {
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
|
|
|
@ -6248,6 +6248,10 @@ module {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func.func @"__torch_mlir_shape_fn.aten.atan2"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
|
|
@ -807,6 +807,9 @@ def aten〇div〇Tensor_mode(self: List[int], other: List[int], rounding_mode: O
|
||||||
def aten〇floor_divide(self: List[int], other: List[int]) -> List[int]:
|
def aten〇floor_divide(self: List[int], other: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.broadcast(self, other)
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
|
def aten〇atan2(self: List[int], other: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.broadcast(self, other)
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
|
|
|
@ -252,6 +252,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::exp : (Tensor) -> (Tensor)",
|
"aten::exp : (Tensor) -> (Tensor)",
|
||||||
"aten::expm1 : (Tensor) -> (Tensor)",
|
"aten::expm1 : (Tensor) -> (Tensor)",
|
||||||
"aten::cos : (Tensor) -> (Tensor)",
|
"aten::cos : (Tensor) -> (Tensor)",
|
||||||
|
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::neg : (Tensor) -> (Tensor)",
|
"aten::neg : (Tensor) -> (Tensor)",
|
||||||
"aten::floor : (Tensor) -> (Tensor)",
|
"aten::floor : (Tensor) -> (Tensor)",
|
||||||
"aten::ceil : (Tensor) -> (Tensor)",
|
"aten::ceil : (Tensor) -> (Tensor)",
|
||||||
|
|
|
@ -142,6 +142,8 @@ LOWERING_PIPELINE = ",".join([
|
||||||
"func.func(refback-expand-ops-for-llvm)",
|
"func.func(refback-expand-ops-for-llvm)",
|
||||||
"func.func(arith-expand)",
|
"func.func(arith-expand)",
|
||||||
"func.func(convert-math-to-llvm)",
|
"func.func(convert-math-to-llvm)",
|
||||||
|
# Handle some complex mlir::math ops (e.g. atan2)
|
||||||
|
"convert-math-to-libm",
|
||||||
"convert-linalg-to-llvm",
|
"convert-linalg-to-llvm",
|
||||||
"convert-memref-to-llvm",
|
"convert-memref-to-llvm",
|
||||||
"func.func(convert-arith-to-llvm)",
|
"func.func(convert-arith-to-llvm)",
|
||||||
|
|
|
@ -804,6 +804,77 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAtan2TensorFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.atan2(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatModule())
|
||||||
|
def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAtan2TensorIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int32, True),
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.atan2(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule())
|
||||||
|
def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
torch.randint(1, 10, [4]).type(torch.int32), torch.randint(1, 10, [4]))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseAtan2FloatIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
([-1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.atan2(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule())
|
||||||
|
def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, [4, 4], dtype=torch.int32),
|
||||||
|
tu.rand(4, 4).double())
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseLogModule(torch.nn.Module):
|
class ElementwiseLogModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue