mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for atan and bitwise ops
This commit adds the OnnxToTorch support for Atan, Bitshift, BitwiseAnd, and BitwiseNot op. This commit also adds the TorchToLinalg support for AtenBitwiseLeftShiftTensorOp. Signed-Off By: vivekkhandelwal@nod-labs.compull/2595/head
parent
53fc995639
commit
dc9ea08db5
|
@ -17,6 +17,7 @@
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/SmallString.h"
|
#include "llvm/ADT/SmallString.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include <string>
|
||||||
|
|
||||||
namespace mlir::torch::onnx_c {
|
namespace mlir::torch::onnx_c {
|
||||||
|
|
||||||
|
@ -103,6 +104,22 @@ struct OpBinder {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
|
||||||
|
std::string defaultValue = "") {
|
||||||
|
SmallString<64> name("torch.onnx.");
|
||||||
|
name.append(nameSuffix);
|
||||||
|
auto attr = op->getAttr(name);
|
||||||
|
if (!attr) {
|
||||||
|
value = defaultValue;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
|
||||||
|
value = stringAttr.str();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
Torch::ValueTensorType toValidTensorType(Type t) {
|
Torch::ValueTensorType toValidTensorType(Type t) {
|
||||||
auto tt = dyn_cast<Torch::ValueTensorType>(t);
|
auto tt = dyn_cast<Torch::ValueTensorType>(t);
|
||||||
if (tt && tt.hasSizes())
|
if (tt && tt.hasSizes())
|
||||||
|
|
|
@ -2891,6 +2891,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$other
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenBitwiseLeftShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenBitwiseLeftShiftTensorOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenBitwiseLeftShift_TensorOp : Torch_Op<"aten.bitwise_left_shift_.Tensor", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::bitwise_left_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_NonValueTensorType:$self,
|
||||||
|
Torch_NonValueTensorType:$other
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_NonValueTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenBitwiseLeftShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenBitwiseLeftShift_TensorOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [
|
def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -141,6 +142,57 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
});
|
});
|
||||||
// TODO: Asin unimplemented in torch-mlir
|
// TODO: Asin unimplemented in torch-mlir
|
||||||
// TODO: Asinh unimplemented in torch-mlir
|
// TODO: Asinh unimplemented in torch-mlir
|
||||||
// TODO: Atan unimplemented in torch-mlir
|
|
||||||
// TODO: Atanh unimplemented in torch-mlir
|
// TODO: Atanh unimplemented in torch-mlir
|
||||||
|
patterns.onOp("Atan", 7,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value operand;
|
||||||
|
if (binder.tensorOperand(operand) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenAtanOp>(
|
||||||
|
binder.op, resultType, operand);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value lhs, rhs;
|
||||||
|
std::string direction;
|
||||||
|
if (binder.tensorOperands(lhs, rhs) ||
|
||||||
|
binder.tensorResultType(resultType) ||
|
||||||
|
binder.customOpNameStringAttr(direction, "direction", ""))
|
||||||
|
return failure();
|
||||||
|
if (direction == "LEFT") {
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseLeftShiftTensorOp>(
|
||||||
|
binder.op, resultType, lhs, rhs);
|
||||||
|
} else {
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseRightShiftTensorOp>(
|
||||||
|
binder.op, resultType, lhs, rhs);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value lhs, rhs;
|
||||||
|
std::string direction;
|
||||||
|
if (binder.tensorOperands(lhs, rhs) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
|
||||||
|
binder.op, resultType, lhs, rhs);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp("BitwiseNot", 18,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value operand;
|
||||||
|
if (binder.tensorOperand(operand) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(
|
||||||
|
binder.op, resultType, operand);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -366,6 +366,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
return b.create<arith::ShRSIOp>(loc, lhs, rhs);
|
return b.create<arith::ShRSIOp>(loc, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
if (auto bitwiseLeftShiftTensor =
|
||||||
|
dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) {
|
||||||
|
Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
if (!dtype.isa<mlir::IntegerType>()) {
|
||||||
|
bitwiseLeftShiftTensor.emitError(
|
||||||
|
"Bitwise_Left_Shift op does not support non-integer input dtype.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
return b.create<arith::ShLIOp>(loc, lhs, rhs);
|
||||||
|
}
|
||||||
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
|
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
|
||||||
MLIRContext *context = op->getContext();
|
MLIRContext *context = op->getContext();
|
||||||
Type floatDtype = mlir::FloatType::getF64(context);
|
Type floatDtype = mlir::FloatType::getF64(context);
|
||||||
|
@ -1252,16 +1266,17 @@ public:
|
||||||
AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
|
AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
|
||||||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
|
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
|
||||||
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
||||||
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
|
||||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||||
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
|
AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
|
||||||
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
|
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp,
|
||||||
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
|
||||||
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
|
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
|
||||||
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp,
|
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
|
||||||
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
|
||||||
AtenFillTensorOp, AtenAtanOp, AtenRealOp, AtenImagOp>(op))
|
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
||||||
|
AtenAtanOp, AtenRealOp, AtenImagOp>(op))
|
||||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
@ -1900,16 +1915,16 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
||||||
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
||||||
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
|
||||||
AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp,
|
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
|
||||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||||
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
|
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||||
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp,
|
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
|
||||||
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp,
|
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||||
AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
|
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
|
||||||
AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp,
|
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
|
||||||
AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp,
|
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
|
||||||
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp,
|
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||||
AtenImagOp>();
|
AtenFillTensorOp, AtenRealOp, AtenImagOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||||
|
|
|
@ -7680,6 +7680,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -9560,6 +9564,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
" return %4 : !torch.int\n"
|
" return %4 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||||
|
" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %4 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
|
|
@ -908,6 +908,9 @@ def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> Li
|
||||||
def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
|
def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.broadcast(self, other)
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
|
def aten〇bitwise_left_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
|
def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
@ -2454,6 +2457,14 @@ def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int]
|
||||||
dtypes = [self_dtype, other_dtype]
|
dtypes = [self_dtype, other_dtype]
|
||||||
return promote_dtypes(ranks, dtypes)
|
return promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
|
@check_dtype_function(_check_two_tensor_op())
|
||||||
|
def aten〇bitwise_left_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
|
||||||
|
other_rank, other_dtype = other_rank_dtype
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
ranks: List[Optional[int]] = [self_rank, other_rank]
|
||||||
|
dtypes = [self_dtype, other_dtype]
|
||||||
|
return promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
|
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
|
||||||
# Different width
|
# Different width
|
||||||
|
|
|
@ -317,6 +317,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
|
"aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||||
"aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
|
"aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||||
"aten::square : (Tensor) -> (Tensor)",
|
"aten::square : (Tensor) -> (Tensor)",
|
||||||
|
|
|
@ -3649,6 +3649,69 @@ def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseBitwiseLeftShiftInt64Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return torch.bitwise_left_shift(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt64Module())
|
||||||
|
def ElementwiseBitwiseLeftShiftInt64Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseBitwiseLeftShiftInt32Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, 4], torch.int32, True),
|
||||||
|
([-1, 1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return torch.bitwise_left_shift(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt32Module())
|
||||||
|
def ElementwiseBitwiseLeftShiftInt32Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseBitwiseLeftShiftInt8Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int8, True),
|
||||||
|
([-1, -1], torch.int8, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return torch.bitwise_left_shift(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt8Module())
|
||||||
|
def ElementwiseBitwiseLeftShiftInt8Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module):
|
class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -95,3 +95,101 @@ func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) ->
|
||||||
%0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
|
%0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64>
|
||||||
return %0 : !torch.vtensor<[2],si64>
|
return %0 : !torch.vtensor<[2],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_atan
|
||||||
|
func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.atan %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
%0 = torch.operator "onnx.Atan"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitshift_left_uint8
|
||||||
|
func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>
|
||||||
|
%0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8>
|
||||||
|
return %0 : !torch.vtensor<[3],ui8>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitshift_left_uint16
|
||||||
|
func.func @test_bitshift_left_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16>
|
||||||
|
%0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16>
|
||||||
|
return %0 : !torch.vtensor<[3],ui16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitshift_left_uint32
|
||||||
|
func.func @test_bitshift_left_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32>
|
||||||
|
%0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32>
|
||||||
|
return %0 : !torch.vtensor<[3],ui32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitshift_left_uint64
|
||||||
|
func.func @test_bitshift_left_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64>
|
||||||
|
%0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "LEFT"} : (!torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64>
|
||||||
|
return %0 : !torch.vtensor<[3],ui64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitshift_right_uint8
|
||||||
|
func.func @test_bitshift_right_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>
|
||||||
|
%0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8>
|
||||||
|
return %0 : !torch.vtensor<[3],ui8>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitshift_right_uint16
|
||||||
|
func.func @test_bitshift_right_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16>
|
||||||
|
%0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16>
|
||||||
|
return %0 : !torch.vtensor<[3],ui16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitshift_right_uint32
|
||||||
|
func.func @test_bitshift_right_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32>
|
||||||
|
%0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32>
|
||||||
|
return %0 : !torch.vtensor<[3],ui32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitshift_right_uint64
|
||||||
|
func.func @test_bitshift_right_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64>
|
||||||
|
%0 = torch.operator "onnx.BitShift"(%arg0, %arg1) {torch.onnx.direction = "RIGHT"} : (!torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64>
|
||||||
|
return %0 : !torch.vtensor<[3],ui64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitwise_and_i16_3d
|
||||||
|
func.func @test_bitwise_and_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16> -> !torch.vtensor<[3,4,5],si16>
|
||||||
|
%0 = torch.operator "onnx.BitwiseAnd"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5],si16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitwise_and_i32_2d
|
||||||
|
func.func @test_bitwise_and_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32>
|
||||||
|
%0 = torch.operator "onnx.BitwiseAnd"(%arg0, %arg1) : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32>
|
||||||
|
return %0 : !torch.vtensor<[3,4],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitwise_and_ui8_bcast_4v3d
|
||||||
|
func.func @test_bitwise_and_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8>
|
||||||
|
%0 = torch.operator "onnx.BitwiseAnd"(%arg0, %arg1) : (!torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5,6],ui8>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitwise_not_2d
|
||||||
|
func.func @test_bitwise_not_2d(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32>
|
||||||
|
%0 = torch.operator "onnx.BitwiseNot"(%arg0) : (!torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32>
|
||||||
|
return %0 : !torch.vtensor<[3,4],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_bitwise_not_4d
|
||||||
|
func.func @test_bitwise_not_4d(%arg0: !torch.vtensor<[3,4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8>
|
||||||
|
%0 = torch.operator "onnx.BitwiseNot"(%arg0) : (!torch.vtensor<[3,4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5,6],ui8>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue