add e2e support for aten.atan2 (#1117)

- Includes math-to-libm pass in refbackend for math::atan2 support
Quinn Dawkins 2022-08-02 11:39:41 -04:00 committed by GitHub
parent 704efdc259
commit 38d8498b21
8 changed files with 166 additions and 9 deletions
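
For context, aten::atan2 computes the elementwise two-argument arctangent, using the signs of both inputs to select the quadrant. A minimal eager PyTorch sketch of the semantics this commit lowers (values chosen only for illustration):

import torch

y = torch.tensor([1.0, 1.0, -1.0])
x = torch.tensor([1.0, -1.0, -1.0])
# atan2(y, x) is the angle of the point (x, y); unlike atan(y / x) it
# distinguishes quadrants, e.g. atan2(1, -1) = 3*pi/4.
print(torch.atan2(y, x))  # ~[0.7854, 2.3562, -2.3562]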

View File

@@ -655,6 +655,53 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [
  }];
}
def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::atan2 : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAtan2Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenAtan2Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}
def Torch_AtenAtan2_Op : Torch_Op<"aten.atan2_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::atan2_ : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAtan2_Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenAtan2_Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}
def Torch_AtenNegOp : Torch_Op<"aten.neg", [
    AllowsTypeRefinement,
    HasValueSemantics,
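
For reference, the ODS above mirrors the JIT operator schema it was generated from; a hedged introspection sketch (assumes a recent PyTorch where torch.ops.aten.atan2.default exposes _schema):

import torch

# Print the registered schema that drives the generated op definition.
print(torch.ops.aten.atan2.default._schema)
# Expected: aten::atan2(Tensor self, Tensor other) -> Tensor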

View File

@@ -391,6 +391,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
      return b.create<arith::MulIOp>(loc, lhs, rhs);
    }
  }
  if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
    Type dtype = converter->convertType(atan2.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    if (!dtype.isa<mlir::FloatType>()) {
      atan2.emitError("Atan2 requires floating point result type");
      return nullptr;
    }
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    return b.create<math::Atan2Op>(loc, lhs, rhs);
  }
  if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
    AtenGtTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
@@ -926,7 +938,7 @@ public:
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
             AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
             AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenAtan2Op,
             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
             AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
             AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
@@ -1669,14 +1681,14 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
      AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
      AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
      AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
      AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
      AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
      AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp,
      AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
      AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
      AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
      AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>();
  patterns.add<ConvertElementwiseOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossForwardOp>();
  patterns.add<ConvertAtenDetachOp>(typeConverter, context);
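
The payload above casts both operands to the floating-point result element type before emitting math.atan2, and rejects non-float result types. A small eager-mode sketch of the equivalent behavior (illustrative only, not part of the change):

import torch

a = torch.randint(1, 10, (4,), dtype=torch.int32)
b = torch.randint(1, 10, (4,), dtype=torch.int64)
# Integer operands are promoted, so the result is a floating-point tensor,
# consistent with the "floating point result type" check in the payload.
print(torch.atan2(a, b).dtype)  # torch.float32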

View File

@@ -733,6 +733,23 @@ void TypeAnalysis::visitOperation(Operation *op,
    return;
  }
  // The result dtype is always float32, except when promotion (which treats
  // possibly zero-rank operands) yields bfloat16 or float64, or fails (nullptr).
  if (isa<AtenAtan2Op>(op)) {
    ValueKnowledge knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    Type promotedDtype = getPromotedResultType(
        op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
        getRankIsNonZeroArray(op->getOperands()));
    if (promotedDtype) {
      knowledge.dtype = Float32Type::get(op->getContext());
      if (promotedDtype.isa<BFloat16Type, Float64Type>())
        knowledge.dtype = promotedDtype;
    }
    incorporateKnowledge(op->getResult(0), knowledge);
    return;
  }
  // Promote three dtypes.
  if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) {
    auto knowledge =
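
The dtype rule added above (float32 unless promotion yields bfloat16 or float64) can be spot-checked against eager PyTorch; a brief sketch:

import torch

f32 = torch.rand(2)
f64 = torch.rand(2, dtype=torch.float64)

print(torch.atan2(f32, f32).dtype)  # torch.float32
print(torch.atan2(f64, f32).dtype)  # torch.float64 (the promoted dtype is kept)
# Per the rule above, bfloat16 operands would likewise keep bfloat16.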

View File

@@ -6248,6 +6248,10 @@ module {
    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func.func @"__torch_mlir_shape_fn.aten.atan2"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>

View File

@@ -807,6 +807,9 @@ def aten〇div〇Tensor_mode(self: List[int], other: List[int], rounding_mode: O
def aten〇floor_divide(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)

def aten〇atan2(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)

def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)
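
The shape function simply broadcasts the two operand shapes; an illustrative check against PyTorch's own broadcasting:

import torch

# aten.atan2's result shape is the broadcast of its operands' shapes.
print(torch.broadcast_shapes((4, 1), (1, 5)))                 # torch.Size([4, 5])
print(torch.atan2(torch.rand(4, 1), torch.rand(1, 5)).shape)  # torch.Size([4, 5])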

View File

@@ -252,6 +252,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
        "aten::exp : (Tensor) -> (Tensor)",
        "aten::expm1 : (Tensor) -> (Tensor)",
        "aten::cos : (Tensor) -> (Tensor)",
        "aten::atan2 : (Tensor, Tensor) -> (Tensor)",
        "aten::neg : (Tensor) -> (Tensor)",
        "aten::floor : (Tensor) -> (Tensor)",
        "aten::ceil : (Tensor) -> (Tensor)",

View File

@@ -142,6 +142,8 @@ LOWERING_PIPELINE = ",".join([
    "func.func(refback-expand-ops-for-llvm)",
    "func.func(arith-expand)",
    "func.func(convert-math-to-llvm)",
    # Handle some complex mlir::math ops (e.g. atan2)
    "convert-math-to-libm",
    "convert-linalg-to-llvm",
    "convert-memref-to-llvm",
    "func.func(convert-arith-to-llvm)",
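
convert-math-to-libm lowers math.atan2 into a call to libm's atan2, so refbackend results should agree with the host math library; a quick sanity sketch (purely illustrative):

import math
import torch

y, x = 1.0, -1.0
# Eager PyTorch and libm's atan2 agree to within floating-point tolerance.
assert abs(torch.atan2(torch.tensor(y), torch.tensor(x)).item() - math.atan2(y, x)) < 1e-6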

View File

@@ -804,6 +804,77 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseAtan2TensorFloatModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatModule())
def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 4), tu.rand(4, 4))
# ==============================================================================
class ElementwiseAtan2TensorIntModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, a, b):
        return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule())
def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
    module.forward(
        torch.randint(1, 10, [4]).type(torch.int32), torch.randint(1, 10, [4]))
# ==============================================================================
class ElementwiseAtan2FloatIntModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
        ([-1, -1], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule())
def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(1, 10, [4, 4], dtype=torch.int32),
                   tu.rand(4, 4).double())
# ==============================================================================
class ElementwiseLogModule(torch.nn.Module):

    def __init__(self):