mirror of https://github.com/llvm/torch-mlir

[TorchToLinalg] Fix torch.aten.remainder for negative operands (#3581)

Closes #3575

The PyTorch remainder operator is meant to compute the Python modulus
operator elementwise:
https://pytorch.org/docs/stable/generated/torch.remainder.html#torch.remainder

In Python, the modulus operator always returns a result with the same sign as
the divisor:
https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations

In other words, torch.aten.remainder should return a Python-style modulus
instead of a C-style modulus. However, the remainder operator was simply
lowered to arith.remsi or arith.remf, both of which compute a C-style
modulus. The lowering has now been fixed so that the modulus operator handles
negative numbers correctly, in both the dividend and the divisor.

parent c5b3cf299a
commit d11d6f6fea
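For illustration (not part of the patch): the two conventions only disagree
when the dividend and divisor have opposite signs. In plain Python, with
math.fmod standing in for the C-style remainder:

    import math

    import torch

    for a, b in [(7, 5), (-7, 5), (7, -5), (-7, -5)]:
        py_mod = a % b            # Python-style: sign follows the divisor b
        c_mod = math.fmod(a, b)   # C-style: sign follows the dividend a
        t_mod = torch.remainder(torch.tensor(a), torch.tensor(b)).item()
        assert t_mod == py_mod    # torch.remainder must match Python's %
        print(f"a={a:3} b={b:3}  a%b={py_mod:3}  fmod={c_mod:5.1f}  torch={t_mod:3}")

    # e.g. -7 %  5 ==  3 while fmod(-7,  5) == -2.0
    #       7 % -5 == -3 while fmod( 7, -5) ==  2.0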
@@ -7,6 +7,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "PopulatePatterns.h"
@@ -307,6 +308,70 @@ Value createDivModePayload(OpBuilder &b, Location loc,
  return quotient;
}

template <typename OpT>
Value createRemainderPayload(OpBuilder &b, Location loc,
                             const TypeConverter *converter,
                             ValueRange payloadArgs, OpT op,
                             ArrayRef<Value> operands) {
  static_assert(
      llvm::is_one_of<OpT, AtenRemainderScalarOp, AtenRemainderTensorOp>(),
      "op must be a tensor/scalar remainder op");
  typename OpT::Adaptor adaptor(operands);
  Type dtype = cast<RankedTensorType>(converter->convertType(op.getType()))
                   .getElementType();
  Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
  Value rhs = convertScalarToDtype(
      b, loc,
      std::is_same_v<OpT, AtenRemainderScalarOp> ? operands[1] : payloadArgs[1],
      dtype);

  // The remainder op we wish to create would look roughly like this:
  //   rem = a % b
  //   if rem != 0 AND (rem < 0 XOR b < 0) rem += b
  // This is how Python calculates remainders for floats and longs:
  // https://github.com/python/cpython/blob/2afd1751dd9a35d4ec03b708e3e5cddd72c43f7e/Objects/floatobject.c#L645
  // https://github.com/python/cpython/blob/2afd1751dd9a35d4ec03b708e3e5cddd72c43f7e/Objects/longobject.c#L3662
  Value result;
  if (isa<mlir::FloatType>(dtype)) {
    Value remainder = b.create<arith::RemFOp>(loc, lhs, rhs);

    Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
    Value remainderNotEqualToZero = b.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, remainder, zero);
    Value otherLessThanZero =
        b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, rhs, zero);
    Value remainderLessThanZero = b.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, remainder, zero);
    Value xorCondition =
        b.create<arith::XOrIOp>(loc, otherLessThanZero, remainderLessThanZero);
    Value condition =
        b.create<arith::AndIOp>(loc, remainderNotEqualToZero, xorCondition);
    Value fixedRemainder = b.create<arith::AddFOp>(loc, remainder, rhs);
    result =
        b.create<arith::SelectOp>(loc, condition, fixedRemainder, remainder);
  } else {
    assert(dtype.isInteger() &&
           "dtype should be a float or integer (signless or signed)");
    Value remainder = b.create<arith::RemSIOp>(loc, lhs, rhs);

    Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
    Value remainderNotEqualToZero =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, remainder, zero);
    Value otherLessThanZero =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, rhs, zero);
    Value remainderLessThanZero = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, remainder, zero);
    Value xorCondition =
        b.create<arith::XOrIOp>(loc, otherLessThanZero, remainderLessThanZero);
    Value condition =
        b.create<arith::AndIOp>(loc, remainderNotEqualToZero, xorCondition);
    Value fixedRemainder = b.create<arith::AddIOp>(loc, remainder, rhs);
    result =
        b.create<arith::SelectOp>(loc, condition, fixedRemainder, remainder);
  }
  return result;
}

static Value createLinalgPayloadCalculationForElementwiseOp(
    OpBuilder &b, Location loc, const TypeConverter *converter,
    ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
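A minimal Python sketch of the same fix-up rule (added here for clarity, not
part of the patch): compute the C-style remainder, then add the divisor back
in exactly when the remainder is nonzero and its sign disagrees with the
divisor's sign:

    import math

    def python_style_remainder(a: float, b: float) -> float:
        # Mirrors the lowering above: C-style rem, then one conditional fix-up.
        rem = math.fmod(a, b)                    # sign follows the dividend a
        if rem != 0 and ((rem < 0) != (b < 0)):  # rem != 0 AND (rem < 0 XOR b < 0)
            rem += b                             # flip the sign toward the divisor
        return rem

    assert python_style_remainder(-7.0, 5.0) == 3.0   # -7 % 5 in Python
    assert python_style_remainder(7.0, -5.0) == -3.0  #  7 % -5 in Python
    assert python_style_remainder(-7.0, -5.0) == -2.0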
@@ -1188,44 +1253,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create<arith::DivFOp>(loc, self, other);
   }
   if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
-    Type newResultType =
-        cast<RankedTensorType>(converter->convertType(remScalar.getType()))
-            .getElementType();
-
-    Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
-    Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
-    Value result;
-
-    if (isa<mlir::FloatType>(newResultType)) {
-      result = b.create<arith::RemFOp>(loc, self, other);
-    } else if (isa<mlir::IntegerType>(newResultType)) {
-      result = b.create<arith::RemSIOp>(loc, self, other);
-    } else {
-      remScalar.emitError(
-          "Unsupported type encountered for AtenRemainderScalarOp.");
-    }
-
-    return result;
+    return createRemainderPayload(b, loc, converter, payloadArgs, remScalar,
+                                  operands);
   }
   if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
-    Type newResultType =
-        cast<RankedTensorType>(converter->convertType(remTensor.getType()))
-            .getElementType();
-
-    Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
-    Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
-    Value result;
-
-    if (isa<mlir::FloatType>(newResultType)) {
-      result = b.create<arith::RemFOp>(loc, self, other);
-    } else if (isa<mlir::IntegerType>(newResultType)) {
-      result = b.create<arith::RemSIOp>(loc, self, other);
-    } else {
-      remTensor.emitError(
-          "Unsupported type encountered for AtenRemainderTensorOp.");
-    }
-
-    return result;
+    return createRemainderPayload(b, loc, converter, payloadArgs, remTensor,
+                                  operands);
   }
   if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
     Type newResultType =
@@ -1080,6 +1080,7 @@ STABLEHLO_PASS_SET = {
    "ElementwiseReciprocalModule_basic",
    "ElementwiseReluModule_basic",
    "ElementwiseRemainderTensorModule_Float_basic",
    "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
    "ElementwiseRemainderTensorModule_Int_Float_basic",
    "ElementwiseRemainderTensorModule_Int_basic",
    "ElementwiseRreluEvalStaticModule_basic",
@@ -1801,6 +1802,10 @@ TOSA_PASS_SET = {
    "ElementwiseReciprocalModule_basic",
    "ElementwiseRelu6Module_basic",
    "ElementwiseReluModule_basic",
    "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
    "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
    "ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
    "ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
    "ElementwiseRemainderScalarModule_Float_basic",
    "ElementwiseRemainderScalarModule_Int_Float_basic",
    "ElementwiseRemainderScalarModule_Int_basic",
@@ -2491,6 +2496,8 @@ ONNX_XFAIL_SET = {
    "ElementwiseQuantizePerTensorModule_basic",
    "ElementwiseQuantizePerTensorUIntModule_basic",
    "ElementwiseRemainderTensorModule_Int_basic",
    "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
    "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
    "ElementwiseSgnModule_basic",
    "EmptyStridedModule_basic",
    "EmptyStridedSizeIntStrideModule_basic",
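The NegativeDividend/NegativeDivisor e2e tests added below pin down exactly
these sign combinations. A quick eager-mode check (illustrative only, not part
of the patch) of the values they lock in:

    import torch

    x = torch.tensor([-7, -3, 4], dtype=torch.int32)
    print(torch.remainder(x, 5))   # tensor([3, 2, 4], ...)    sign follows +5
    print(torch.remainder(x, -5))  # tensor([-2, -3, -1], ...) sign follows -5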
@@ -3285,6 +3285,60 @@ def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseRemainderScalarModule_Int_Float_NegativeDividend(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.int32, True),
        ]
    )
    def forward(self, x):
        return torch.remainder(x, 5.0)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float_NegativeDividend()
)
def ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic(
    module, tu: TestUtils
):
    module.forward(tu.randint(30, low=-10, high=10).to(torch.int32))


# ==============================================================================


class ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.int32, True),
        ]
    )
    def forward(self, x):
        return torch.remainder(x, -5.0)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor()
)
def ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic(
    module, tu: TestUtils
):
    module.forward(tu.randint(30, low=-10, high=-1).to(torch.int32))


# ==============================================================================


class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -3308,6 +3362,58 @@ def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseRemainderScalarModule_Float_NegativeDividend(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.remainder(x, 5.0)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Float_NegativeDividend()
)
def ElementwiseRemainderScalarModule_Float_NegativeDividend_basic(
    module, tu: TestUtils
):
    module.forward(tu.rand(10, 3, low=-10.0, high=10.0))


# ==============================================================================


class ElementwiseRemainderScalarModule_Float_NegativeDivisor(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.remainder(x, -5.0)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Float_NegativeDivisor()
)
def ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 3, low=-10.0, high=10.0))


# ==============================================================================


class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -3331,6 +3437,56 @@ def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseRemainderScalarModule_Int_NegativeDividend(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int32, True),
        ]
    )
    def forward(self, x):
        return torch.remainder(x, 5)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Int_NegativeDividend()
)
def ElementwiseRemainderScalarModule_Int_NegativeDividend_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 2, low=-10, high=10).to(torch.int32))


# ==============================================================================


class ElementwiseRemainderScalarModule_Int_NegativeDivisor(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int32, True),
        ]
    )
    def forward(self, x):
        return torch.remainder(x, -5)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Int_NegativeDivisor()
)
def ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 2, low=-10, high=10).to(torch.int32))


# ==============================================================================


class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -3354,6 +3510,31 @@ def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseRemainderScalarModule_Bool_NegativeDivisor(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.bool, True),
        ]
    )
    def forward(self, x):
        return torch.remainder(x, -3)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Bool_NegativeDivisor()
)
def ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic(module, tu: TestUtils):
    module.forward(torch.tensor([True, False, True, True, True]))


# ==============================================================================


class ElementwiseFmodTensor_Float(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -3415,6 +3596,8 @@ def ElementwiseFmodTensor_Int_basic(module, tu: TestUtils):
        tu.randint(100, low=0, high=1000).to(torch.int32),
        tu.randint(100, low=1, high=1000).to(torch.int32),
    )


# ==============================================================================
@@ -3442,6 +3625,67 @@ def ElementwiseRemainderTensorModule_Int_Float_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseRemainderTensorModule_Int_Float_NegativeDividend(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, a, b):
        return torch.remainder(a, b)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float_NegativeDividend()
)
def ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic(
    module, tu: TestUtils
):
    module.forward(
        tu.randint(3, 4, low=-10, high=10).to(torch.int32), tu.rand(3, 4, high=10)
    )


# ==============================================================================


class ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, a, b):
        return torch.remainder(a, b)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor()
)
def ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic(
    module, tu: TestUtils
):
    module.forward(
        tu.randint(3, 4, low=-10, high=10).to(torch.int32),
        tu.rand(3, 4, low=-10, high=-1),
    )


# ==============================================================================


class ElementwiseRemainderTensorModule_Float(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -3466,6 +3710,60 @@ def ElementwiseRemainderTensorModule_Float_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseRemainderTensorModule_Float_NegativeDividend(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, a, b):
        return torch.remainder(a, b)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Float_NegativeDividend()
)
def ElementwiseRemainderTensorModule_Float_NegativeDividend_basic(
    module, tu: TestUtils
):
    module.forward(tu.rand(3, 4, high=10), tu.rand(3, 4, high=10))


# ==============================================================================


class ElementwiseRemainderTensorModule_Float_NegativeDivisor(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, a, b):
        return torch.remainder(a, b)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Float_NegativeDivisor()
)
def ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, high=10), tu.rand(3, 4, low=-10, high=-1))


# ==============================================================================


class ElementwiseRemainderTensorModule_Int(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -3493,6 +3791,64 @@ def ElementwiseRemainderTensorModule_Int_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseRemainderTensorModule_Int_NegativeDividend(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int32, True),
            ([-1, -1], torch.int32, True),
        ]
    )
    def forward(self, a, b):
        return torch.remainder(a, b)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Int_NegativeDividend()
)
def ElementwiseRemainderTensorModule_Int_NegativeDividend_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(3, 4, low=-10, high=10, dtype=torch.int32),
        tu.randint(3, 4, high=10, dtype=torch.int32),
    )


# ==============================================================================


class ElementwiseRemainderTensorModule_Int_NegativeDivisor(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.int32, True),
            ([-1, -1], torch.int32, True),
        ]
    )
    def forward(self, a, b):
        return torch.remainder(a, b)


@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Int_NegativeDivisor()
)
def ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(3, 4, low=-10, high=10, dtype=torch.int32),
        tu.randint(3, 4, low=-10, high=-1, dtype=torch.int32),
    )


# ==============================================================================


class ElementwiseDivTensorFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()