[MLIR][TORCH] Add support for bitwise_right_shift and bitwise_and.Scalar ops

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/2497/head snapshot-20231002.979
Vivek Khandelwal 2023-09-28 12:53:02 +00:00
parent c434736ee9
commit 9293326e1e
7 changed files with 299 additions and 18 deletions
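For reference, a minimal PyTorch-level sketch of the two newly supported ops (the tensor values and the scalar mask below are illustrative, not taken from this patch):

import torch

x = torch.tensor([[-8, 12], [255, -1]], dtype=torch.int64)

# aten::bitwise_and.Scalar: AND every element with a Python scalar.
masked = torch.bitwise_and(x, 15)           # keeps the low 4 bits of each element

# aten::bitwise_right_shift.Tensor: elementwise right shift by another tensor.
shift = torch.tensor([[1, 2], [4, 1]], dtype=torch.int64)
shifted = torch.bitwise_right_shift(x, shift)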


@@ -1421,4 +1421,6 @@ LTC_XFAIL_SET = {
"UniformStaticShapeModule_basic",
"AtenEmbeddingBagStaticModule_basic",
"EmptyStridedModule_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
}


@@ -2844,6 +2844,53 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
}];
}
def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -2938,6 +2985,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [
}];
}
def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseRightShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseRightShiftTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenBitwiseRightShift_TensorOp : Torch_Op<"aten.bitwise_right_shift_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::bitwise_right_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseRightShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseRightShift_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -300,6 +300,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
if (auto bitwiseAndScalar = dyn_cast<AtenBitwiseAndScalarOp>(op)) {
Type dtype = converter->convertType(bitwiseAndScalar.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::IntegerType>()) {
bitwiseAndScalar.emitError(
"bitwise_and.Scalar does not support non-integer input dtype.");
return nullptr;
}
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<arith::AndIOp>(loc, self, other);
}
if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
if (bitwiseOrTensor.getType()
.cast<ValueTensorType>()
@@ -332,6 +345,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::XOrIOp>(loc, lhs, rhs);
}
if (auto bitwiseRightShiftTensor =
dyn_cast<AtenBitwiseRightShiftTensorOp>(op)) {
Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::IntegerType>()) {
bitwiseRightShiftTensor.emitError(
"Bitwise_Right_Shift op does not support non-integer input dtype.");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::ShRSIOp>(loc, lhs, rhs);
}
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
@@ -571,7 +598,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (dtype.isa<mlir::FloatType>()) {
return b.create<arith::MulFOp>(loc, lhs, rhs);
} else if(dtype.isa<mlir::ComplexType>()) {
} else if (dtype.isa<mlir::ComplexType>()) {
return b.create<complex::MulOp>(loc, lhs, rhs);
} else {
return b.create<arith::MulIOp>(loc, lhs, rhs);
@@ -1066,7 +1093,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType();
Value self = payloadArgs[0];
Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value threshold =
convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
Value predicate;
@@ -1088,7 +1116,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value threshold =
convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value predicate;
@@ -1197,10 +1226,11 @@ public:
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp,
AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
@@ -1699,7 +1729,8 @@ public:
return failure();
Type resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, adaptor.getSelf());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
}
};
@@ -1735,16 +1766,17 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp,
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp,
AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenRealOp, AtenImagOp>();
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp,
AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp,
AtenImagOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
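Note that the new right-shift payload above lowers to arith::ShRSIOp, i.e. an arithmetic (sign-preserving) shift. A quick illustrative check of the PyTorch behavior being matched (values are made up, not from this patch):

import torch

# PyTorch's bitwise_right_shift on signed integer tensors is an arithmetic
# shift, so the sign bit is replicated as the value is shifted.
print(torch.bitwise_right_shift(torch.tensor(-8), torch.tensor(1)))  # tensor(-4)
print(torch.bitwise_right_shift(torch.tensor(8), torch.tensor(1)))   # tensor(4)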


@@ -7410,10 +7410,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -9201,6 +9209,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -9217,6 +9234,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"


@@ -796,9 +796,15 @@ def aten〇bitwise_or〇Tensor〡shape(self: List[int], other: List[int]) -> Lis
def aten〇bitwise_and〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)
def aten〇bitwise_and〇Scalar〡shape(self: List[int], other: float) -> List[int]:
    return upstream_shape_functions.unary(self)
def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)
def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, other)
def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
@@ -2265,6 +2271,14 @@ def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
def aten〇bitwise_and〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op())
def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
@@ -2281,6 +2295,14 @@ def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op())
def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
# Different width
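The new aten〇bitwise_and〇Scalar〡dtype function above promotes the tensor dtype against the scalar's dtype with the scalar contributing at rank None, so an integer Python scalar does not widen an integer tensor. A small illustrative check (not part of this patch):

import torch

x32 = torch.ones(2, 3, dtype=torch.int32)
print(torch.bitwise_and(x32, 7).dtype)                   # torch.int32
print(torch.bitwise_and(x32.to(torch.int64), 7).dtype)   # torch.int64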


@@ -301,8 +301,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::abs : (Tensor) -> (Tensor)",
"aten::reciprocal : (Tensor) -> (Tensor)",
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::square : (Tensor) -> (Tensor)",
"aten::unsqueeze : (Tensor, int) -> (Tensor)",


@@ -3515,3 +3515,107 @@ class TupleModule(torch.nn.Module):
@register_test_case(module_factory=lambda: TupleModule())
def TupleModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2), tu.rand(2, 2))
# ==============================================================================
class ElementwiseBitwiseRightShiftInt64Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1, -1], torch.int64, True),
])
def forward(self, lhs, rhs):
return torch.bitwise_right_shift(lhs, rhs)
@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt64Module())
def ElementwiseBitwiseRightShiftInt64Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64))
class ElementwiseBitwiseRightShiftInt32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, 4], torch.int32, True),
([-1, 1], torch.int32, True),
])
def forward(self, lhs, rhs):
return torch.bitwise_right_shift(lhs, rhs)
@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt32Module())
def ElementwiseBitwiseRightShiftInt32Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32))
class ElementwiseBitwiseRightShiftInt8Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int8, True),
([-1, -1], torch.int8, True),
])
def forward(self, lhs, rhs):
return torch.bitwise_right_shift(lhs, rhs)
@register_test_case(module_factory=lambda: ElementwiseBitwiseRightShiftInt8Module())
def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8))
# ==============================================================================
class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.bitwise_and(x, 15)
@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt64Module())
def ElementwiseBitwiseAndScalarInt64Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-1000, high=1000))
class ElementwiseBitwiseAndScalarInt32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.bitwise_and(x, 100)
@register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt32Module())
def ElementwiseBitwiseAndScalarInt32Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32))
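The Int32 right-shift test above also exercises broadcasting of the shift amounts, shape (3, 1) against (3, 4). A hypothetical standalone sketch of that behavior (values are illustrative):

import torch

lhs = torch.tensor([[16, 32, 64, 128]], dtype=torch.int32).repeat(3, 1)  # shape (3, 4)
rhs = torch.tensor([[1], [2], [3]], dtype=torch.int32)                   # one shift amount per row
print(torch.bitwise_right_shift(lhs, rhs))
# row 0 is halved once, row 1 twice, row 2 three times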