Implement e2e support for aten.acos op

This depends on a change in the LLVM core repository that adds acos
support (math.acos) to the MLIR Math dialect.
pull/2633/head
Frederik Harwath 2023-12-06 02:26:13 -08:00 committed by Frederik Harwath
parent 7acabafd84
commit b656c674ee
8 changed files with 130 additions and 6 deletions
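For context, torch.acos computes the elementwise inverse cosine. Its behavior in the two cases this change must handle, float inputs and integer inputs (which PyTorch promotes to the default float dtype), can be checked in eager mode:

import torch

# torch.acos is elementwise; inputs outside [-1, 1] yield NaN.
x = torch.tensor([1.0, 0.0, -1.0])
print(torch.acos(x))  # tensor([0.0000, 1.5708, 3.1416])

# Integer inputs are promoted to the default float dtype, which is
# the behavior the new dtype rule in this change encodes.
i = torch.randint(1, 10, (3, 4), dtype=torch.int32)
print(torch.acos(i).dtype)  # torch.float32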


@@ -886,6 +886,51 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [
  }];
}

def Torch_AtenAcosOp : Torch_Op<"aten.acos", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::acos : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAcosOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenAcosOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::acos_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self
  );
  let results = (outs
    Torch_NonValueTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAcos_Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenAcos_Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenAtanOp : Torch_Op<"aten.atan", [
    AllowsTypeRefinement,
    HasValueSemantics,
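The two new definitions in this hunk follow the repository's generated-op pattern: a functional op with value semantics (aten.acos) and the trailing-underscore in-place variant (aten.acos_) on non-value tensors. At the PyTorch level the correspondence is simply:

import torch

a = torch.rand(3, 4)
b = torch.acos(a)  # functional form, imported as torch.aten.acos
a.acos_()          # in-place form, imported as torch.aten.acos_
assert torch.equal(a, b)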


@@ -39,7 +39,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                      binder.op, resultType, operand);
                  return success();
                });
  // TODO: Acos unimplemented in torch-mlir
  // TODO: Acosh unimplemented in torch-mlir
  // Add became forward compatible with Torch in version 7.
  patterns.onOp("Add", 7,
@@ -154,6 +153,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                      binder.op, resultType, operand);
                  return success();
                });
  patterns.onOp("Acos", 7,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType))
                    return failure();
                  rewriter.replaceOpWithNewOp<Torch::AtenAcosOp>(
                      binder.op, resultType, operand);
                  return success();
                });
  patterns.onOp(
      "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
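The new pattern above binds the ONNX Acos operator, available since opset 7, to torch.aten.acos (and the stale "Acos unimplemented" TODO is dropped in the first hunk). A minimal sketch of producing such a node with the stock ONNX exporter (file name arbitrary):

import torch

class Acos(torch.nn.Module):
    def forward(self, x):
        return torch.acos(x)

# The exported graph contains a single "Acos" node, which the importer
# pattern above rewrites to torch.aten.acos.
torch.onnx.export(Acos(), torch.rand(3, 4, 5), "acos.onnx")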


@@ -272,6 +272,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    return createCalculationForMathOpWithDtypeConversion<math::AtanOp>(
        b, converter, payloadArgs[0], op);
  }
  if (isa<AtenAcosOp>(op)) {
    return createCalculationForMathOpWithDtypeConversion<math::AcosOp>(
        b, converter, payloadArgs[0], op);
  }
  if (auto clone = dyn_cast<AtenCloneOp>(op)) {
    int64_t memoryFormat;
    if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
@@ -1329,7 +1333,7 @@ public:
             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
             AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
             AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
             AtenAtanOp, AtenRealOp, AtenImagOp>(op))
             AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1975,10 +1979,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
      AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
      AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
      AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp,
      AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp,
      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
      AtenRealOp, AtenImagOp>();
      AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp,
      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
      AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
      AtenFillTensorOp, AtenRealOp, AtenImagOp>();
  patterns.add<ConvertElementwiseOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossForwardOp>();
  patterns.add<ConvertAtenDetachOp>(typeConverter, context);
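With these hooks in place, AtenAcosOp joins the generic elementwise lowering: the payload helper inserts the dtype conversion and emits math.acos, the MLIR Math dialect op this change depends on, inside a linalg.generic. A hedged way to inspect the result, assuming the torch_mlir.compile Python entry point of the torch-mlir releases around this change:

import torch
import torch_mlir  # entry-point name assumed from releases of this era

class M(torch.nn.Module):
    def forward(self, x):
        return torch.acos(x)

# The printed module should contain a linalg.generic whose region
# applies math.acos elementwise.
module = torch_mlir.compile(M(), torch.rand(3, 4),
                            output_type="linalg-on-tensors")
print(module)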


@@ -6290,6 +6290,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.cosine_similarity\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %int1 = torch.constant.int 1\n"
@@ -8566,6 +8570,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.acos\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"


@@ -98,6 +98,9 @@ def aten〇sin〡shape(self: List[int]) -> List[int]:
def aten〇cos〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇acos〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇cosine_similarity〡shape(x1: List[int], x2: List[int], dim: int = 1, eps: float = 1e-08) -> List[int]:
    broadcast = upstream_shape_functions.broadcast(x1, x2)
    return broadcast[:dim] + broadcast[dim + 1:]
@@ -1552,6 +1555,11 @@ def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇acos〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
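Both new rules delegate to shared helpers: the shape function is the identity on the input shape (upstream_shape_functions.unary), and the dtype function applies the library's floating-point promotion rule. A standalone approximation of that rule, assuming floating-point dtypes pass through unchanged and all other dtypes promote to the default float dtype:

import torch

def acos_result_dtype(self_dtype: torch.dtype) -> torch.dtype:
    # Approximation of _get_dtype_of_floating_point_op; the exact
    # helper lives in abstract_interp_lib_gen.py.
    if self_dtype.is_floating_point:
        return self_dtype
    return torch.get_default_dtype()

assert acos_result_dtype(torch.float64) == torch.float64
assert acos_result_dtype(torch.int32) == torch.float32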


@@ -274,6 +274,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
        "aten::exp : (Tensor) -> (Tensor)",
        "aten::expm1 : (Tensor) -> (Tensor)",
        "aten::cos : (Tensor) -> (Tensor)",
        "aten::acos : (Tensor) -> (Tensor)",
        "aten::atan : (Tensor) -> (Tensor)",
        "aten::atan2 : (Tensor, Tensor) -> (Tensor)",
        "aten::neg : (Tensor) -> (Tensor)",


@@ -2900,10 +2900,50 @@ class ElementwiseCosIntModule(torch.nn.Module):
def ElementwiseCosIntModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))


# ==============================================================================


class ElementwiseAcosModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.acos(a)


@register_test_case(module_factory=lambda: ElementwiseAcosModule())
def ElementwiseAcosModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))


# ==============================================================================


class ElementwiseAcosIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.acos(a)


@register_test_case(module_factory=lambda: ElementwiseAcosIntModule())
def ElementwiseAcosIntModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))


# ==============================================================================


class ElementwiseNegModule(torch.nn.Module):
    def __init__(self):
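The integer variant exercises the promotion path end to end. Note that tu.randint(3, 4, low=1, high=10) draws values in [1, 9], so every entry greater than 1 lies outside acos's domain; eager PyTorch returns NaN there, and the compiled module is expected to match. An eager-only illustration:

import torch

a = torch.randint(1, 10, (3, 4), dtype=torch.int32)
r = torch.acos(a)
print(r.dtype)                      # torch.float32: ints promote
print(torch.isnan(r)[a > 1].all())  # True: NaN wherever input > 1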


@@ -103,6 +103,13 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// CHECK-LABEL: @test_acos
func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.acos %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.Acos"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// CHECK-LABEL: @test_bitshift_left_uint8
func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>