mirror of https://github.com/llvm/torch-mlir
Implement e2e support for aten.acos op
This depends on a change in the LLVM core repository which adds acos support to the MLIR Math dialect.pull/2633/head
parent
7acabafd84
commit
b656c674ee
|
@ -886,6 +886,51 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAcosOp : Torch_Op<"aten.acos", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::acos : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAcosOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenAcosOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::acos_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_NonValueTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
Torch_NonValueTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAcos_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenAcos_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAtanOp : Torch_Op<"aten.atan", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -39,7 +39,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.op, resultType, operand);
|
||||
return success();
|
||||
});
|
||||
// TODO: Acos unimplemented in torch-mlir
|
||||
// TODO: Acosh unimplemented in torch-mlir
|
||||
// Add became forward compatible with Torch in version 7.
|
||||
patterns.onOp("Add", 7,
|
||||
|
@ -154,6 +153,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.op, resultType, operand);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("Acos", 7,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
if (binder.tensorOperand(operand) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenAcosOp>(
|
||||
binder.op, resultType, operand);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
|
|
@ -272,6 +272,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return createCalculationForMathOpWithDtypeConversion<math::AtanOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenAcosOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::AcosOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
|
||||
int64_t memoryFormat;
|
||||
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
|
||||
|
@ -1329,7 +1333,7 @@ public:
|
|||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
|
||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
||||
AtenAtanOp, AtenRealOp, AtenImagOp>(op))
|
||||
AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -1975,10 +1979,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
|
||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
|
||||
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp,
|
||||
AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp,
|
||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
||||
AtenRealOp, AtenImagOp>();
|
||||
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
|
||||
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||
AtenFillTensorOp, AtenRealOp, AtenImagOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -6290,6 +6290,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.cosine_similarity\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
|
@ -8566,6 +8570,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.acos\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
|
|
|
@ -98,6 +98,9 @@ def aten〇sin〡shape(self: List[int]) -> List[int]:
|
|||
def aten〇cos〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇acos〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇cosine_similarity〡shape(x1: List[int], x2: List[int], dim: int = 1, eps: float = 1e-08) -> List[int]:
|
||||
broadcast = upstream_shape_functions.broadcast(x1, x2)
|
||||
return broadcast[:dim] + broadcast[dim + 1:]
|
||||
|
@ -1552,6 +1555,11 @@ def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇acos〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -274,6 +274,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::exp : (Tensor) -> (Tensor)",
|
||||
"aten::expm1 : (Tensor) -> (Tensor)",
|
||||
"aten::cos : (Tensor) -> (Tensor)",
|
||||
"aten::acos : (Tensor) -> (Tensor)",
|
||||
"aten::atan : (Tensor) -> (Tensor)",
|
||||
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::neg : (Tensor) -> (Tensor)",
|
||||
|
|
|
@ -2900,10 +2900,50 @@ class ElementwiseCosIntModule(torch.nn.Module):
|
|||
def ElementwiseCosIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAcosModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.acos(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAcosModule())
|
||||
def ElementwiseAcosModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAcosIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.acos(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAcosIntModule())
|
||||
def ElementwiseAcosIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseNegModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
@ -103,6 +103,13 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
|
|||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_acos
|
||||
func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.acos %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||
%0 = torch.operator "onnx.Acos"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_bitshift_left_uint8
|
||||
func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>
|
||||
|
|
Loading…
Reference in New Issue