mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Fix torch.aten.remainder for negative operands (#3581)
Closes #3575 The PyTorch remainder operator is meant to compute the Python modulus operator entrywise: https://pytorch.org/docs/stable/generated/torch.remainder.html#torch.remainder In python the modulus operator is meant to always return a result with the same sign as the divisor: https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations In other words, torch.aten.remainder should return a Python-style modulus instead of a C-style modulus. However the remainder operator was simply translated into arith.ModSI or arith.ModF, which both effectively compute the C-style modulus. Now the lowering has been modified so that the modulus operator works properly with negative numbers, both in the dividend, and the divisor.pull/3631/head
parent
c5b3cf299a
commit
d11d6f6fea
|
@ -7,6 +7,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
|
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
@ -307,6 +308,70 @@ Value createDivModePayload(OpBuilder &b, Location loc,
|
||||||
return quotient;
|
return quotient;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename OpT>
|
||||||
|
Value createRemainderPayload(OpBuilder &b, Location loc,
|
||||||
|
const TypeConverter *converter,
|
||||||
|
ValueRange payloadArgs, OpT op,
|
||||||
|
ArrayRef<Value> operands) {
|
||||||
|
static_assert(
|
||||||
|
llvm::is_one_of<OpT, AtenRemainderScalarOp, AtenRemainderTensorOp>(),
|
||||||
|
"op must be a tensor/scalar remainder op");
|
||||||
|
typename OpT::Adaptor adaptor(operands);
|
||||||
|
Type dtype = cast<RankedTensorType>(converter->convertType(op.getType()))
|
||||||
|
.getElementType();
|
||||||
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
|
Value rhs = convertScalarToDtype(
|
||||||
|
b, loc,
|
||||||
|
std::is_same_v<OpT, AtenRemainderScalarOp> ? operands[1] : payloadArgs[1],
|
||||||
|
dtype);
|
||||||
|
|
||||||
|
// The remainder op we wish to create would look roughly like this:
|
||||||
|
// rem = a % b
|
||||||
|
// if rem != 0 AND (rem < 0 XOR b < 0) rem += b
|
||||||
|
// This is how python calucates remainders for floats and longs:
|
||||||
|
// https://github.com/python/cpython/blob/2afd1751dd9a35d4ec03b708e3e5cddd72c43f7e/Objects/floatobject.c#L645
|
||||||
|
// https://github.com/python/cpython/blob/2afd1751dd9a35d4ec03b708e3e5cddd72c43f7e/Objects/longobject.c#L3662
|
||||||
|
Value result;
|
||||||
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
|
Value remainder = b.create<arith::RemFOp>(loc, lhs, rhs);
|
||||||
|
|
||||||
|
Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||||
|
Value remainderNotEqualToZero = b.create<arith::CmpFOp>(
|
||||||
|
loc, arith::CmpFPredicate::ONE, remainder, zero);
|
||||||
|
Value otherLessThanZero =
|
||||||
|
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, rhs, zero);
|
||||||
|
Value remainderLessThanZero = b.create<arith::CmpFOp>(
|
||||||
|
loc, arith::CmpFPredicate::OLT, remainder, zero);
|
||||||
|
Value xorCondition =
|
||||||
|
b.create<arith::XOrIOp>(loc, otherLessThanZero, remainderLessThanZero);
|
||||||
|
Value condition =
|
||||||
|
b.create<arith::AndIOp>(loc, remainderNotEqualToZero, xorCondition);
|
||||||
|
Value fixedRemainder = b.create<arith::AddFOp>(loc, remainder, rhs);
|
||||||
|
result =
|
||||||
|
b.create<arith::SelectOp>(loc, condition, fixedRemainder, remainder);
|
||||||
|
} else {
|
||||||
|
assert(dtype.isInteger() &&
|
||||||
|
"dtype should be a float or integer (signless or signed)");
|
||||||
|
Value remainder = b.create<arith::RemSIOp>(loc, lhs, rhs);
|
||||||
|
|
||||||
|
Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||||
|
Value remainderNotEqualToZero =
|
||||||
|
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, remainder, zero);
|
||||||
|
Value otherLessThanZero =
|
||||||
|
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, rhs, zero);
|
||||||
|
Value remainderLessThanZero = b.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::slt, remainder, zero);
|
||||||
|
Value xorCondition =
|
||||||
|
b.create<arith::XOrIOp>(loc, otherLessThanZero, remainderLessThanZero);
|
||||||
|
Value condition =
|
||||||
|
b.create<arith::AndIOp>(loc, remainderNotEqualToZero, xorCondition);
|
||||||
|
Value fixedRemainder = b.create<arith::AddIOp>(loc, remainder, rhs);
|
||||||
|
result =
|
||||||
|
b.create<arith::SelectOp>(loc, condition, fixedRemainder, remainder);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
static Value createLinalgPayloadCalculationForElementwiseOp(
|
static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
OpBuilder &b, Location loc, const TypeConverter *converter,
|
OpBuilder &b, Location loc, const TypeConverter *converter,
|
||||||
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
||||||
|
@ -1188,44 +1253,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<arith::DivFOp>(loc, self, other);
|
return b.create<arith::DivFOp>(loc, self, other);
|
||||||
}
|
}
|
||||||
if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
|
if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
|
||||||
Type newResultType =
|
return createRemainderPayload(b, loc, converter, payloadArgs, remScalar,
|
||||||
cast<RankedTensorType>(converter->convertType(remScalar.getType()))
|
operands);
|
||||||
.getElementType();
|
|
||||||
|
|
||||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
|
|
||||||
Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
|
|
||||||
Value result;
|
|
||||||
|
|
||||||
if (isa<mlir::FloatType>(newResultType)) {
|
|
||||||
result = b.create<arith::RemFOp>(loc, self, other);
|
|
||||||
} else if (isa<mlir::IntegerType>(newResultType)) {
|
|
||||||
result = b.create<arith::RemSIOp>(loc, self, other);
|
|
||||||
} else {
|
|
||||||
remScalar.emitError(
|
|
||||||
"Unsupported type encountered for AtenRemainderScalarOp.");
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
|
if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
|
||||||
Type newResultType =
|
return createRemainderPayload(b, loc, converter, payloadArgs, remTensor,
|
||||||
cast<RankedTensorType>(converter->convertType(remTensor.getType()))
|
operands);
|
||||||
.getElementType();
|
|
||||||
|
|
||||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
|
|
||||||
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
|
|
||||||
Value result;
|
|
||||||
|
|
||||||
if (isa<mlir::FloatType>(newResultType)) {
|
|
||||||
result = b.create<arith::RemFOp>(loc, self, other);
|
|
||||||
} else if (isa<mlir::IntegerType>(newResultType)) {
|
|
||||||
result = b.create<arith::RemSIOp>(loc, self, other);
|
|
||||||
} else {
|
|
||||||
remTensor.emitError(
|
|
||||||
"Unsupported type encountered for AtenRemainderTensorOp.");
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
|
if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
|
||||||
Type newResultType =
|
Type newResultType =
|
||||||
|
|
|
@ -1080,6 +1080,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseReciprocalModule_basic",
|
"ElementwiseReciprocalModule_basic",
|
||||||
"ElementwiseReluModule_basic",
|
"ElementwiseReluModule_basic",
|
||||||
"ElementwiseRemainderTensorModule_Float_basic",
|
"ElementwiseRemainderTensorModule_Float_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_basic",
|
"ElementwiseRemainderTensorModule_Int_basic",
|
||||||
"ElementwiseRreluEvalStaticModule_basic",
|
"ElementwiseRreluEvalStaticModule_basic",
|
||||||
|
@ -1801,6 +1802,10 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseReciprocalModule_basic",
|
"ElementwiseReciprocalModule_basic",
|
||||||
"ElementwiseRelu6Module_basic",
|
"ElementwiseRelu6Module_basic",
|
||||||
"ElementwiseReluModule_basic",
|
"ElementwiseReluModule_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
|
||||||
"ElementwiseRemainderScalarModule_Float_basic",
|
"ElementwiseRemainderScalarModule_Float_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_basic",
|
"ElementwiseRemainderScalarModule_Int_basic",
|
||||||
|
@ -2491,6 +2496,8 @@ ONNX_XFAIL_SET = {
|
||||||
"ElementwiseQuantizePerTensorModule_basic",
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_basic",
|
"ElementwiseRemainderTensorModule_Int_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
|
||||||
"ElementwiseSgnModule_basic",
|
"ElementwiseSgnModule_basic",
|
||||||
"EmptyStridedModule_basic",
|
"EmptyStridedModule_basic",
|
||||||
"EmptyStridedSizeIntStrideModule_basic",
|
"EmptyStridedSizeIntStrideModule_basic",
|
||||||
|
|
|
@ -3285,6 +3285,60 @@ def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderScalarModule_Int_Float_NegativeDividend(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1], torch.int32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.remainder(x, 5.0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float_NegativeDividend()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic(
|
||||||
|
module, tu: TestUtils
|
||||||
|
):
|
||||||
|
module.forward(tu.randint(30, low=-10, high=10).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1], torch.int32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.remainder(x, -5.0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic(
|
||||||
|
module, tu: TestUtils
|
||||||
|
):
|
||||||
|
module.forward(tu.randint(30, low=-10, high=-1).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
|
class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -3308,6 +3362,58 @@ def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderScalarModule_Float_NegativeDividend(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.remainder(x, 5.0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderScalarModule_Float_NegativeDividend()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderScalarModule_Float_NegativeDividend_basic(
|
||||||
|
module, tu: TestUtils
|
||||||
|
):
|
||||||
|
module.forward(tu.rand(10, 3, low=-10.0, high=10.0))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderScalarModule_Float_NegativeDivisor(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.remainder(x, -5.0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderScalarModule_Float_NegativeDivisor()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(10, 3, low=-10.0, high=10.0))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
|
class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -3331,6 +3437,56 @@ def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderScalarModule_Int_NegativeDividend(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.remainder(x, 5)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderScalarModule_Int_NegativeDividend()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderScalarModule_Int_NegativeDividend_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 2, low=-10, high=10).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderScalarModule_Int_NegativeDivisor(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.remainder(x, -5)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderScalarModule_Int_NegativeDivisor()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 2, low=-10, high=10).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
|
class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -3354,6 +3510,31 @@ def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderScalarModule_Bool_NegativeDivisor(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1], torch.bool, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.remainder(x, -3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderScalarModule_Bool_NegativeDivisor()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor([True, False, True, True, True]))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseFmodTensor_Float(torch.nn.Module):
|
class ElementwiseFmodTensor_Float(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -3415,7 +3596,9 @@ def ElementwiseFmodTensor_Int_basic(module, tu: TestUtils):
|
||||||
tu.randint(100, low=0, high=1000).to(torch.int32),
|
tu.randint(100, low=0, high=1000).to(torch.int32),
|
||||||
tu.randint(100, low=1, high=1000).to(torch.int32),
|
tu.randint(100, low=1, high=1000).to(torch.int32),
|
||||||
)
|
)
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseRemainderTensorModule_Int_Float(torch.nn.Module):
|
class ElementwiseRemainderTensorModule_Int_Float(torch.nn.Module):
|
||||||
|
@ -3442,6 +3625,67 @@ def ElementwiseRemainderTensorModule_Int_Float_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderTensorModule_Int_Float_NegativeDividend(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.remainder(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float_NegativeDividend()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic(
|
||||||
|
module, tu: TestUtils
|
||||||
|
):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(3, 4, low=-10, high=10).to(torch.int32), tu.rand(3, 4, high=10)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.remainder(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic(
|
||||||
|
module, tu: TestUtils
|
||||||
|
):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(3, 4, low=-10, high=10).to(torch.int32),
|
||||||
|
tu.rand(3, 4, low=-10, high=-1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseRemainderTensorModule_Float(torch.nn.Module):
|
class ElementwiseRemainderTensorModule_Float(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -3466,6 +3710,60 @@ def ElementwiseRemainderTensorModule_Float_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderTensorModule_Float_NegativeDividend(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.remainder(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderTensorModule_Float_NegativeDividend()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderTensorModule_Float_NegativeDividend_basic(
|
||||||
|
module, tu: TestUtils
|
||||||
|
):
|
||||||
|
module.forward(tu.rand(3, 4, high=10), tu.rand(3, 4, high=10))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderTensorModule_Float_NegativeDivisor(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.remainder(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderTensorModule_Float_NegativeDivisor()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, high=10), tu.rand(3, 4, low=-10, high=-1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseRemainderTensorModule_Int(torch.nn.Module):
|
class ElementwiseRemainderTensorModule_Int(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -3493,6 +3791,64 @@ def ElementwiseRemainderTensorModule_Int_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderTensorModule_Int_NegativeDividend(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.remainder(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderTensorModule_Int_NegativeDividend()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderTensorModule_Int_NegativeDividend_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(3, 4, low=-10, high=10, dtype=torch.int32),
|
||||||
|
tu.randint(3, 4, high=10, dtype=torch.int32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRemainderTensorModule_Int_NegativeDivisor(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.remainder(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseRemainderTensorModule_Int_NegativeDivisor()
|
||||||
|
)
|
||||||
|
def ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(3, 4, low=-10, high=10, dtype=torch.int32),
|
||||||
|
tu.randint(3, 4, low=-10, high=-1, dtype=torch.int32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseDivTensorFloatModule(torch.nn.Module):
|
class ElementwiseDivTensorFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue