[MLIR][ONNX] Add OnnxToTorch support for atan and bitwise ops

This commit adds OnnxToTorch support for the Atan, BitShift, BitwiseAnd,
and BitwiseNot ops.
It also adds TorchToLinalg support for AtenBitwiseLeftShiftTensorOp.
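
For reference, each new ONNX lowering here is a one-to-one rewrite onto a Torch op. A minimal before/after sketch for BitShift, mirroring the lit tests added in this commit (the !torch.vtensor<[3],ui8> types are illustrative):

    // Importer output, before convert-torch-onnx-to-torch:
    %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8>
    // After the rewrite:
    %0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>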

Signed-off-by: Vivek Khandelwal <vivekkhandelwal@nod-labs.com>
Vivek Khandelwal 2023-11-27 13:44:16 +00:00
parent 53fc995639
commit dc9ea08db5
9 changed files with 337 additions and 21 deletions


@@ -17,6 +17,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include <string>

namespace mlir::torch::onnx_c {
@@ -103,6 +104,22 @@ struct OpBinder {
    return failure();
  }

  ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
                                     std::string defaultValue = "") {
    SmallString<64> name("torch.onnx.");
    name.append(nameSuffix);
    auto attr = op->getAttr(name);
    if (!attr) {
      value = defaultValue;
      return success();
    }
    if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
      value = stringAttr.str();
      return success();
    }
    return failure();
  }

  Torch::ValueTensorType toValidTensorType(Type t) {
    auto tt = dyn_cast<Torch::ValueTensorType>(t);
    if (tt && tt.hasSizes())

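Note on the customOpNameStringAttr helper added above: like the other binder accessors, it resolves ONNX attributes under the importer's "torch.onnx." prefix. For example, on the op below (attribute spelling copied from this commit's lit tests; %lhs and %rhs are placeholder names), binder.customOpNameStringAttr(direction, "direction") sets direction to "LEFT":

    %0 = torch.operator "onnx.BitShift"(%lhs, %rhs) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8>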

@@ -2891,6 +2891,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [
  }];
}

def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenBitwiseLeftShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenBitwiseLeftShiftTensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}

def Torch_AtenBitwiseLeftShift_TensorOp : Torch_Op<"aten.bitwise_left_shift_.Tensor", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::bitwise_left_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self,
    Torch_NonValueTensorType:$other
  );
  let results = (outs
    Torch_NonValueTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenBitwiseLeftShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenBitwiseLeftShift_TensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}

def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,


@@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;
@@ -141,6 +142,57 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
      });
  // TODO: Asin unimplemented in torch-mlir
  // TODO: Asinh unimplemented in torch-mlir
  // TODO: Atanh unimplemented in torch-mlir
patterns.onOp("Atan", 7,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenAtanOp>(
binder.op, resultType, operand);
return success();
});
  patterns.onOp(
      "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value lhs, rhs;
        std::string direction;
        if (binder.tensorOperands(lhs, rhs) ||
            binder.tensorResultType(resultType) ||
            binder.customOpNameStringAttr(direction, "direction", ""))
          return failure();
        if (direction == "LEFT") {
          rewriter.replaceOpWithNewOp<Torch::AtenBitwiseLeftShiftTensorOp>(
              binder.op, resultType, lhs, rhs);
        } else {
          rewriter.replaceOpWithNewOp<Torch::AtenBitwiseRightShiftTensorOp>(
              binder.op, resultType, lhs, rhs);
        }
        return success();
      });
  patterns.onOp(
      "BitwiseAnd", 18,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value lhs, rhs;
        if (binder.tensorOperands(lhs, rhs) ||
            binder.tensorResultType(resultType))
          return failure();
        rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
            binder.op, resultType, lhs, rhs);
        return success();
      });
patterns.onOp("BitwiseNot", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(
binder.op, resultType, operand);
return success();
});
}


@@ -366,6 +366,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    return b.create<arith::ShRSIOp>(loc, lhs, rhs);
  }
  if (auto bitwiseLeftShiftTensor =
          dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) {
    Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    if (!dtype.isa<mlir::IntegerType>()) {
      bitwiseLeftShiftTensor.emitError(
          "Bitwise_Left_Shift op does not support non-integer input dtype.");
      return nullptr;
    }
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    return b.create<arith::ShLIOp>(loc, lhs, rhs);
  }
  if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
    MLIRContext *context = op->getContext();
    Type floatDtype = mlir::FloatType::getF64(context);
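
Note: createLinalgPayloadCalculationForElementwiseOp only builds the scalar body; ConvertElementwiseOp wraps it in a linalg.generic over the (broadcast) operands. A left shift on two 3x4 i32 tensors would lower to roughly the sketch below (a hand-written approximation, not verifier-checked pass output; %init and #map are illustrative names):

    #map = affine_map<(d0, d1) -> (d0, d1)>
    %shifted = linalg.generic
        {indexing_maps = [#map, #map, #map],
         iterator_types = ["parallel", "parallel"]}
        ins(%lhs, %rhs : tensor<3x4xi32>, tensor<3x4xi32>)
        outs(%init : tensor<3x4xi32>) {
    ^bb0(%a: i32, %b: i32, %out: i32):
      // Scalar payload from above: arith.shli on the dtype-converted operands.
      %s = arith.shli %a, %b : i32
      linalg.yield %s : i32
    } -> tensor<3x4xi32>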
@@ -1252,16 +1266,17 @@ public:
        AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
        AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
        AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
        AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
        AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
        AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
        AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp,
        AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
        AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
        AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
        AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
        AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
        AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
        AtenAtanOp, AtenRealOp, AtenImagOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1900,16 +1915,16 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
      AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
      AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
      AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
      AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
      AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
      AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
      AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
      AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
      AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
      AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
      AtenFillTensorOp, AtenRealOp, AtenImagOp>();
  patterns.add<ConvertElementwiseOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossForwardOp>();
  patterns.add<ConvertAtenDetachOp>(typeConverter, context);


@@ -7680,6 +7680,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@@ -9560,6 +9564,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"


@@ -908,6 +908,9 @@ def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)

def aten〇bitwise_left_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)

def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

@@ -2454,6 +2457,14 @@ def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)

@check_dtype_function(_check_two_tensor_op())
def aten〇bitwise_left_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
    other_rank, other_dtype = other_rank_dtype
    self_rank, self_dtype = self_rank_dtype
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)

@check_dtype_function(
    _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
    # Different width


@@ -317,6 +317,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
        "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
        "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
        "aten::square : (Tensor) -> (Tensor)",


@@ -3649,6 +3649,69 @@ def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils):

# ==============================================================================

class ElementwiseBitwiseLeftShiftInt64Module(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, lhs, rhs):
        return torch.bitwise_left_shift(lhs, rhs)


@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt64Module())
def ElementwiseBitwiseLeftShiftInt64Module_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=-1000, high=1000),
                   tu.randint(3, 4, low=0, high=64))

class ElementwiseBitwiseLeftShiftInt32Module(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, 4], torch.int32, True),
        ([-1, 1], torch.int32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.bitwise_left_shift(lhs, rhs)


@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt32Module())
def ElementwiseBitwiseLeftShiftInt32Module_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32),
                   tu.randint(3, 1, low=0, high=32).to(torch.int32))

class ElementwiseBitwiseLeftShiftInt8Module(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int8, True),
        ([-1, -1], torch.int8, True),
    ])
    def forward(self, lhs, rhs):
        return torch.bitwise_left_shift(lhs, rhs)


@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt8Module())
def ElementwiseBitwiseLeftShiftInt8Module_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8),
                   tu.randint(3, 4, low=0, high=8).to(torch.int8))


# ==============================================================================

class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module):

    def __init__(self):


@@ -95,3 +95,101 @@ func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
  %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
  return %0 : !torch.vtensor<[2],si64>
}

// CHECK-LABEL: @test_atan
func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.atan %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
  %0 = torch.operator "onnx.Atan"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// CHECK-LABEL: @test_bitshift_left_uint8
func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>
  %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8>
  return %0 : !torch.vtensor<[3],ui8>
}

// CHECK-LABEL: @test_bitshift_left_uint16
func.func @test_bitshift_left_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16>
  %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16>
  return %0 : !torch.vtensor<[3],ui16>
}

// CHECK-LABEL: @test_bitshift_left_uint32
func.func @test_bitshift_left_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32>
  %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32>
  return %0 : !torch.vtensor<[3],ui32>
}

// CHECK-LABEL: @test_bitshift_left_uint64
func.func @test_bitshift_left_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64>
  %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64>
  return %0 : !torch.vtensor<[3],ui64>
}

// CHECK-LABEL: @test_bitshift_right_uint8
func.func @test_bitshift_right_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>
  %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8>
  return %0 : !torch.vtensor<[3],ui8>
}

// CHECK-LABEL: @test_bitshift_right_uint16
func.func @test_bitshift_right_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16>
  %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16>
  return %0 : !torch.vtensor<[3],ui16>
}

// CHECK-LABEL: @test_bitshift_right_uint32
func.func @test_bitshift_right_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32>
  %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32>
  return %0 : !torch.vtensor<[3],ui32>
}

// CHECK-LABEL: @test_bitshift_right_uint64
func.func @test_bitshift_right_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64>
  %0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64>
  return %0 : !torch.vtensor<[3],ui64>
}

// CHECK-LABEL: @test_bitwise_and_i16_3d
func.func @test_bitwise_and_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16> -> !torch.vtensor<[3,4,5],si16>
  %0 = torch.operator "onnx.BitwiseAnd"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16>
  return %0 : !torch.vtensor<[3,4,5],si16>
}

// CHECK-LABEL: @test_bitwise_and_i32_2d
func.func @test_bitwise_and_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32>
  %0 = torch.operator "onnx.BitwiseAnd"(%arg0, %arg1) : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32>
  return %0 : !torch.vtensor<[3,4],si32>
}

// CHECK-LABEL: @test_bitwise_and_ui8_bcast_4v3d
func.func @test_bitwise_and_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8>
  %0 = torch.operator "onnx.BitwiseAnd"(%arg0, %arg1) : (!torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8>
  return %0 : !torch.vtensor<[3,4,5,6],ui8>
}

// CHECK-LABEL: @test_bitwise_not_2d
func.func @test_bitwise_not_2d(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32>
  %0 = torch.operator "onnx.BitwiseNot"(%arg0) : (!torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32>
  return %0 : !torch.vtensor<[3,4],si32>
}

// CHECK-LABEL: @test_bitwise_not_4d
func.func @test_bitwise_not_4d(%arg0: !torch.vtensor<[3,4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8>
  %0 = torch.operator "onnx.BitwiseNot"(%arg0) : (!torch.vtensor<[3,4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8>
  return %0 : !torch.vtensor<[3,4,5,6],ui8>
}