mirror of https://github.com/llvm/torch-mlir
[onnx] Lowerings from `onnx.tan` (#2642)
Started work on the `tan` lowerings for ONNX to Torch. Uses `sin` and `cos` to represent a `tan`.pull/2411/merge
parent
a24aadbfab
commit
11cc92d4ab
|
@ -1066,6 +1066,51 @@ def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenTanOp : Torch_Op<"aten.tan", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::tan : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenTanOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenTanOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenTan_Op : Torch_Op<"aten.tan_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::tan_ : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_NonValueTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_NonValueTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenTan_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenTan_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAtanOp : Torch_Op<"aten.atan", [
|
def Torch_AtenAtanOp : Torch_Op<"aten.atan", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -795,6 +795,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
patterns.onOp("Tan", 7,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value operand;
|
||||||
|
if (binder.tensorOperand(operand) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenTanOp>(
|
||||||
|
binder.op, resultType, operand);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"Transpose", 13,
|
"Transpose", 13,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
|
|
@ -216,6 +216,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<math::FloorOp>(loc, payloadArgs[0]);
|
return b.create<math::FloorOp>(loc, payloadArgs[0]);
|
||||||
if (isa<AtenCeilOp>(op))
|
if (isa<AtenCeilOp>(op))
|
||||||
return b.create<math::CeilOp>(loc, payloadArgs[0]);
|
return b.create<math::CeilOp>(loc, payloadArgs[0]);
|
||||||
|
if (isa<AtenTanOp>(op)) {
|
||||||
|
return createCalculationForMathOpWithDtypeConversion<math::TanOp>(
|
||||||
|
b, converter, payloadArgs[0], op);
|
||||||
|
}
|
||||||
if (isa<AtenTanhOp>(op)) {
|
if (isa<AtenTanhOp>(op)) {
|
||||||
return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
|
return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
|
||||||
b, converter, payloadArgs[0], op);
|
b, converter, payloadArgs[0], op);
|
||||||
|
@ -1319,15 +1323,15 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
if (!isa<AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenPreluOp,
|
if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp,
|
||||||
AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
|
AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
|
||||||
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenAtan2Op,
|
AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
||||||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
|
AtenSubTensorOp, AtenAtan2Op, AtenLerpTensorOp, AtenSigmoidOp,
|
||||||
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp,
|
||||||
AtenClampTensorOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
|
AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenRsubScalarOp,
|
||||||
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
|
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
|
||||||
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp,
|
||||||
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
|
||||||
AtenRemainderScalarOp, AtenAbsOp, AtenReciprocalOp,
|
AtenRemainderScalarOp, AtenAbsOp, AtenReciprocalOp,
|
||||||
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
|
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
|
||||||
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
||||||
|
@ -1972,7 +1976,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
ConversionTarget &target) {
|
ConversionTarget &target) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
target.addIllegalOp<
|
target.addIllegalOp<
|
||||||
AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenGeluOp,
|
AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenGeluOp,
|
||||||
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp,
|
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp,
|
||||||
AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
|
AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
|
||||||
AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||||
|
|
|
@ -6238,6 +6238,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -11396,6 +11400,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %4 : !torch.tuple<int, int>\n"
|
" return %4 : !torch.tuple<int, int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.tan\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
|
" %int6 = torch.constant.int 6\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||||
|
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" return %2 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %int6 = torch.constant.int 6\n"
|
" %int6 = torch.constant.int 6\n"
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
|
|
@ -59,6 +59,9 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
|
||||||
def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
|
def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
def aten〇tan〡shape(self: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
def aten〇atan〡shape(self: List[int]) -> List[int]:
|
def aten〇atan〡shape(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
@ -3721,6 +3724,13 @@ def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = T
|
||||||
return torch.float64, self_dtype
|
return torch.float64, self_dtype
|
||||||
return self_dtype, self_dtype
|
return self_dtype, self_dtype
|
||||||
|
|
||||||
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||||
|
def aten〇tan〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
if is_integer_dtype(self_dtype):
|
||||||
|
return torch.float32
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(_check_two_tensor_op())
|
@check_dtype_function(_check_two_tensor_op())
|
||||||
def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
|
|
@ -278,6 +278,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::expm1 : (Tensor) -> (Tensor)",
|
"aten::expm1 : (Tensor) -> (Tensor)",
|
||||||
"aten::cos : (Tensor) -> (Tensor)",
|
"aten::cos : (Tensor) -> (Tensor)",
|
||||||
"aten::acos : (Tensor) -> (Tensor)",
|
"aten::acos : (Tensor) -> (Tensor)",
|
||||||
|
"aten::tan : (Tensor) -> (Tensor)",
|
||||||
"aten::atan : (Tensor) -> (Tensor)",
|
"aten::atan : (Tensor) -> (Tensor)",
|
||||||
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
|
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::neg : (Tensor) -> (Tensor)",
|
"aten::neg : (Tensor) -> (Tensor)",
|
||||||
|
|
|
@ -3009,6 +3009,46 @@ def ElementwiseAcosIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseTanModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.tan(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseTanModule())
|
||||||
|
def ElementwiseTanModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseTanIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.tan(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseTanIntModule())
|
||||||
|
def ElementwiseTanIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ElementwiseNegModule(torch.nn.Module):
|
class ElementwiseNegModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -795,6 +795,15 @@ func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_tan
|
||||||
|
func.func @test_tan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[TAN:.+]] = torch.aten.tan %arg0
|
||||||
|
%0 = torch.operator "onnx.Tan"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_transpose_default
|
// CHECK-LABEL: func.func @test_transpose_default
|
||||||
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||||
|
|
Loading…
Reference in New Issue