[TorchToLinalg] Fix torch.aten.remainder for negative operands (#3581)

Closes #3575

The PyTorch remainder operator is meant to compute the Python modulus
operator entrywise:

https://pytorch.org/docs/stable/generated/torch.remainder.html#torch.remainder

In Python, the modulus operator always returns a result with
the same sign as the divisor:

https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations

In other words, torch.aten.remainder should return a Python-style
modulus instead of a C-style modulus. However, the remainder operator was
simply translated into arith.remsi or arith.remf, which both compute the
C-style remainder (the result takes the sign of the dividend). The lowering
has now been modified so that the modulus operator works properly with
negative numbers in both the dividend and the divisor.
pull/3631/head
pkapris-syrmia 2024-08-13 17:47:21 +02:00 committed by GitHub
parent c5b3cf299a
commit d11d6f6fea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 433 additions and 37 deletions

View File

@ -7,6 +7,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "PopulatePatterns.h"
@ -307,6 +308,70 @@ Value createDivModePayload(OpBuilder &b, Location loc,
return quotient;
}
/// Emits the scalar computation for a remainder op with Python `%` semantics,
/// i.e. a non-zero result always carries the sign of the divisor.
///
/// The dividend is `payloadArgs[0]`. The divisor is `operands[1]` when `OpT`
/// is `AtenRemainderScalarOp` and `payloadArgs[1]` otherwise. Both are
/// converted to the element type of the converted result type of `op` before
/// any arithmetic is emitted.
///
/// The emitted sequence is:
///   rem = lhs % rhs                 (arith.remf / arith.remsi)
///   if (rem != 0 && ((rem < 0) ^ (rhs < 0))) rem += rhs
///
/// \returns the value produced by the final `arith.select`.
template <typename OpT>
Value createRemainderPayload(OpBuilder &b, Location loc,
                             const TypeConverter *converter,
                             ValueRange payloadArgs, OpT op,
                             ArrayRef<Value> operands) {
  static_assert(
      llvm::is_one_of<OpT, AtenRemainderScalarOp, AtenRemainderTensorOp>(),
      "op must be a tensor/scalar remainder op");
  Type dtype = cast<RankedTensorType>(converter->convertType(op.getType()))
                   .getElementType();
  Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
  Value rhs = convertScalarToDtype(
      b, loc,
      std::is_same_v<OpT, AtenRemainderScalarOp> ? operands[1] : payloadArgs[1],
      dtype);

  // The remainder op we wish to create would look roughly like this:
  // rem = a % b
  // if rem != 0 AND (rem < 0 XOR b < 0) rem += b
  // This is how Python calculates remainders for floats and longs:
  // https://github.com/python/cpython/blob/2afd1751dd9a35d4ec03b708e3e5cddd72c43f7e/Objects/floatobject.c#L645
  // https://github.com/python/cpython/blob/2afd1751dd9a35d4ec03b708e3e5cddd72c43f7e/Objects/longobject.c#L3662
  if (isa<mlir::FloatType>(dtype)) {
    Value remainder = b.create<arith::RemFOp>(loc, lhs, rhs);
    Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
    // Ordered comparisons: any NaN operand makes the predicate false, so the
    // unadjusted remainder is selected in that case.
    Value remainderNotEqualToZero = b.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, remainder, zero);
    Value otherLessThanZero =
        b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, rhs, zero);
    Value remainderLessThanZero = b.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, remainder, zero);
    // True when remainder and divisor disagree in sign.
    Value xorCondition =
        b.create<arith::XOrIOp>(loc, otherLessThanZero, remainderLessThanZero);
    Value condition =
        b.create<arith::AndIOp>(loc, remainderNotEqualToZero, xorCondition);
    Value fixedRemainder = b.create<arith::AddFOp>(loc, remainder, rhs);
    return b.create<arith::SelectOp>(loc, condition, fixedRemainder,
                                     remainder);
  }

  assert(dtype.isInteger() &&
         "dtype should be a float or integer (signless or signed)");
  Value remainder = b.create<arith::RemSIOp>(loc, lhs, rhs);
  Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
  Value remainderNotEqualToZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, remainder, zero);
  Value otherLessThanZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, rhs, zero);
  Value remainderLessThanZero = b.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, remainder, zero);
  // True when remainder and divisor disagree in sign.
  Value xorCondition =
      b.create<arith::XOrIOp>(loc, otherLessThanZero, remainderLessThanZero);
  Value condition =
      b.create<arith::AndIOp>(loc, remainderNotEqualToZero, xorCondition);
  Value fixedRemainder = b.create<arith::AddIOp>(loc, remainder, rhs);
  return b.create<arith::SelectOp>(loc, condition, fixedRemainder, remainder);
}
static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, const TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
@ -1188,44 +1253,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::DivFOp>(loc, self, other);
}
if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
Type newResultType =
cast<RankedTensorType>(converter->convertType(remScalar.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
Value result;
if (isa<mlir::FloatType>(newResultType)) {
result = b.create<arith::RemFOp>(loc, self, other);
} else if (isa<mlir::IntegerType>(newResultType)) {
result = b.create<arith::RemSIOp>(loc, self, other);
} else {
remScalar.emitError(
"Unsupported type encountered for AtenRemainderScalarOp.");
}
return result;
return createRemainderPayload(b, loc, converter, payloadArgs, remScalar,
operands);
}
if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
Type newResultType =
cast<RankedTensorType>(converter->convertType(remTensor.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
Value result;
if (isa<mlir::FloatType>(newResultType)) {
result = b.create<arith::RemFOp>(loc, self, other);
} else if (isa<mlir::IntegerType>(newResultType)) {
result = b.create<arith::RemSIOp>(loc, self, other);
} else {
remTensor.emitError(
"Unsupported type encountered for AtenRemainderTensorOp.");
}
return result;
return createRemainderPayload(b, loc, converter, payloadArgs, remTensor,
operands);
}
if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
Type newResultType =

View File

@ -1080,6 +1080,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseReciprocalModule_basic",
"ElementwiseReluModule_basic",
"ElementwiseRemainderTensorModule_Float_basic",
"ElementwiseRemainderTensorModule_Float_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_Float_basic",
"ElementwiseRemainderTensorModule_Int_basic",
"ElementwiseRreluEvalStaticModule_basic",
@ -1801,6 +1802,10 @@ TOSA_PASS_SET = {
"ElementwiseReciprocalModule_basic",
"ElementwiseRelu6Module_basic",
"ElementwiseReluModule_basic",
"ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic",
"ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic",
"ElementwiseRemainderScalarModule_Float_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderScalarModule_Int_basic",
@ -2491,6 +2496,8 @@ ONNX_XFAIL_SET = {
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseRemainderTensorModule_Int_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
"ElementwiseSgnModule_basic",
"EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",

View File

@ -3285,6 +3285,60 @@ def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRemainderScalarModule_Int_Float_NegativeDividend(torch.nn.Module):
    """Remainder of an int32 tensor by the positive float scalar 5.0."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1], torch.int32, True)])
    def forward(self, x):
        # Positive float divisor; the dividend signs are chosen by the test driver.
        divisor = 5.0
        return torch.remainder(x, divisor)
@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float_NegativeDividend()
)
def ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic(
    module, tu: TestUtils
):
    """Run the module on int32 dividends sampled with low=-10, high=10."""
    dividend = tu.randint(30, low=-10, high=10).to(torch.int32)
    module.forward(dividend)
# ==============================================================================
class ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor(torch.nn.Module):
    """Remainder of an int32 tensor by the negative float scalar -5.0."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1], torch.int32, True)])
    def forward(self, x):
        # Negative float divisor applied to every element of x.
        divisor = -5.0
        return torch.remainder(x, divisor)
@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor()
)
def ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic(
    module, tu: TestUtils
):
    """Run the module (divisor fixed at -5.0) on dividends of both signs."""
    # The divisor lives inside the module, so only the dividend is sampled here.
    # Dividends that are all negative share the divisor's sign, in which case a
    # C-style and a Python-style remainder agree and the sign fix-up is never
    # exercised. Sampling with low=-10, high=10 includes positive dividends.
    # NOTE(review): this test is listed in TOSA_PASS_SET; re-run that config
    # after widening the input range.
    dividend = tu.randint(30, low=-10, high=10).to(torch.int32)
    module.forward(dividend)
# ==============================================================================
class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
def __init__(self):
super().__init__()
@ -3308,6 +3362,58 @@ def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRemainderScalarModule_Float_NegativeDividend(torch.nn.Module):
    """Remainder of a float32 tensor by the positive float scalar 5.0."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.float32, True)])
    def forward(self, x):
        # Positive float divisor; the dividend signs are chosen by the test driver.
        divisor = 5.0
        return torch.remainder(x, divisor)
@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Float_NegativeDividend()
)
def ElementwiseRemainderScalarModule_Float_NegativeDividend_basic(
    module, tu: TestUtils
):
    """Run the module on a 10x3 float input sampled with low=-10.0, high=10.0."""
    dividend = tu.rand(10, 3, low=-10.0, high=10.0)
    module.forward(dividend)
# ==============================================================================
class ElementwiseRemainderScalarModule_Float_NegativeDivisor(torch.nn.Module):
    """Remainder of a float32 tensor by the negative float scalar -5.0."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.float32, True)])
    def forward(self, x):
        # Negative float divisor applied to every element of x.
        divisor = -5.0
        return torch.remainder(x, divisor)
@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Float_NegativeDivisor()
)
def ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic(
    module, tu: TestUtils
):
    """Run the module on a 10x3 float input sampled with low=-10.0, high=10.0."""
    dividend = tu.rand(10, 3, low=-10.0, high=10.0)
    module.forward(dividend)
# ==============================================================================
class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
def __init__(self):
super().__init__()
@ -3331,6 +3437,56 @@ def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRemainderScalarModule_Int_NegativeDividend(torch.nn.Module):
    """Remainder of an int32 tensor by the positive integer scalar 5."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.int32, True)])
    def forward(self, x):
        # Positive integer divisor; the dividend signs are chosen by the test driver.
        divisor = 5
        return torch.remainder(x, divisor)
@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Int_NegativeDividend()
)
def ElementwiseRemainderScalarModule_Int_NegativeDividend_basic(
    module, tu: TestUtils
):
    """Run the module on a 3x2 int32 input sampled with low=-10, high=10."""
    dividend = tu.randint(3, 2, low=-10, high=10).to(torch.int32)
    module.forward(dividend)
# ==============================================================================
class ElementwiseRemainderScalarModule_Int_NegativeDivisor(torch.nn.Module):
    """Remainder of an int32 tensor by the negative integer scalar -5."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.int32, True)])
    def forward(self, x):
        # Negative integer divisor applied to every element of x.
        divisor = -5
        return torch.remainder(x, divisor)
@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Int_NegativeDivisor()
)
def ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic(
    module, tu: TestUtils
):
    """Run the module on a 3x2 int32 input sampled with low=-10, high=10."""
    dividend = tu.randint(3, 2, low=-10, high=10).to(torch.int32)
    module.forward(dividend)
# ==============================================================================
class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
def __init__(self):
super().__init__()
@ -3354,6 +3510,31 @@ def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRemainderScalarModule_Bool_NegativeDivisor(torch.nn.Module):
    """Remainder of a bool tensor by the negative integer scalar -3."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1], torch.bool, True)])
    def forward(self, x):
        # Negative integer divisor applied to a boolean input.
        divisor = -3
        return torch.remainder(x, divisor)
@register_test_case(
    module_factory=lambda: ElementwiseRemainderScalarModule_Bool_NegativeDivisor()
)
def ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic(
    module, tu: TestUtils
):
    """Run the module on a fixed five-element boolean tensor."""
    flags = torch.tensor([True, False, True, True, True])
    module.forward(flags)
# ==============================================================================
class ElementwiseFmodTensor_Float(torch.nn.Module):
def __init__(self):
super().__init__()
@ -3415,6 +3596,8 @@ def ElementwiseFmodTensor_Int_basic(module, tu: TestUtils):
tu.randint(100, low=0, high=1000).to(torch.int32),
tu.randint(100, low=1, high=1000).to(torch.int32),
)
# ==============================================================================
@ -3442,6 +3625,67 @@ def ElementwiseRemainderTensorModule_Int_Float_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRemainderTensorModule_Int_Float_NegativeDividend(torch.nn.Module):
    """Elementwise remainder of an int32 tensor by a float32 tensor."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [None, ([-1, -1], torch.int32, True), ([-1, -1], torch.float32, True)]
    )
    def forward(self, a, b):
        # a is the dividend, b the divisor.
        remainder = torch.remainder(a, b)
        return remainder
@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float_NegativeDividend()
)
def ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic(
    module, tu: TestUtils
):
    """Run the module on a mixed-sign int32 dividend and a float divisor."""
    # Sample the dividend first, then the divisor, to keep the draw order.
    dividend = tu.randint(3, 4, low=-10, high=10).to(torch.int32)
    divisor = tu.rand(3, 4, high=10)
    module.forward(dividend, divisor)
# ==============================================================================
class ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor(torch.nn.Module):
    """Elementwise remainder of an int32 tensor by a float32 tensor."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [None, ([-1, -1], torch.int32, True), ([-1, -1], torch.float32, True)]
    )
    def forward(self, a, b):
        # a is the dividend, b the divisor.
        remainder = torch.remainder(a, b)
        return remainder
@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor()
)
def ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic(
    module, tu: TestUtils
):
    """Run the module on a mixed-sign int32 dividend and a negative float divisor."""
    # Sample the dividend first, then the divisor, to keep the draw order.
    dividend = tu.randint(3, 4, low=-10, high=10).to(torch.int32)
    divisor = tu.rand(3, 4, low=-10, high=-1)
    module.forward(dividend, divisor)
# ==============================================================================
class ElementwiseRemainderTensorModule_Float(torch.nn.Module):
def __init__(self):
super().__init__()
@ -3466,6 +3710,60 @@ def ElementwiseRemainderTensorModule_Float_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRemainderTensorModule_Float_NegativeDividend(torch.nn.Module):
    """Elementwise remainder of two float32 tensors."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
    )
    def forward(self, a, b):
        # a is the dividend, b the divisor.
        remainder = torch.remainder(a, b)
        return remainder
@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Float_NegativeDividend()
)
def ElementwiseRemainderTensorModule_Float_NegativeDividend_basic(
    module, tu: TestUtils
):
    """Run the module on a mixed-sign float dividend and a float divisor."""
    # The dividend previously had no `low=` argument, unlike the other
    # negative-operand tests, so this "NegativeDividend" test did not request
    # negative dividends. Pass low=-10 so they are actually covered.
    # NOTE(review): this test is listed in STABLEHLO_PASS_SET; re-run that
    # config and move the entry to an xfail set if its lowering is still
    # C-style for negative dividends.
    dividend = tu.rand(3, 4, low=-10, high=10)
    divisor = tu.rand(3, 4, high=10)
    module.forward(dividend, divisor)
# ==============================================================================
class ElementwiseRemainderTensorModule_Float_NegativeDivisor(torch.nn.Module):
    """Elementwise remainder of two float32 tensors."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
    )
    def forward(self, a, b):
        # a is the dividend, b the divisor.
        remainder = torch.remainder(a, b)
        return remainder
@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Float_NegativeDivisor()
)
def ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic(
    module, tu: TestUtils
):
    """Run the module on a float dividend and a negative float divisor."""
    # Sample the dividend first, then the divisor, to keep the draw order.
    dividend = tu.rand(3, 4, high=10)
    divisor = tu.rand(3, 4, low=-10, high=-1)
    module.forward(dividend, divisor)
# ==============================================================================
class ElementwiseRemainderTensorModule_Int(torch.nn.Module):
def __init__(self):
super().__init__()
@ -3493,6 +3791,64 @@ def ElementwiseRemainderTensorModule_Int_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRemainderTensorModule_Int_NegativeDividend(torch.nn.Module):
    """Elementwise remainder of two int32 tensors."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [None, ([-1, -1], torch.int32, True), ([-1, -1], torch.int32, True)]
    )
    def forward(self, a, b):
        # a is the dividend, b the divisor.
        remainder = torch.remainder(a, b)
        return remainder
@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Int_NegativeDividend()
)
def ElementwiseRemainderTensorModule_Int_NegativeDividend_basic(
    module, tu: TestUtils
):
    """Run the module on a mixed-sign int32 dividend and a positive int32 divisor."""
    dividend = tu.randint(3, 4, low=-10, high=10, dtype=torch.int32)
    # Start the divisor at 1 so it can never be 0: integer remainder by zero
    # is not a meaningful comparison point. This mirrors the `low=1` divisor
    # used by the integer fmod test in this file.
    divisor = tu.randint(3, 4, low=1, high=10, dtype=torch.int32)
    module.forward(dividend, divisor)
# ==============================================================================
class ElementwiseRemainderTensorModule_Int_NegativeDivisor(torch.nn.Module):
    """Elementwise remainder of two int32 tensors."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [None, ([-1, -1], torch.int32, True), ([-1, -1], torch.int32, True)]
    )
    def forward(self, a, b):
        # a is the dividend, b the divisor.
        remainder = torch.remainder(a, b)
        return remainder
@register_test_case(
    module_factory=lambda: ElementwiseRemainderTensorModule_Int_NegativeDivisor()
)
def ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic(
    module, tu: TestUtils
):
    """Run the module on a mixed-sign int32 dividend and a negative int32 divisor."""
    # Sample the dividend first, then the divisor, to keep the draw order.
    dividend = tu.randint(3, 4, low=-10, high=10, dtype=torch.int32)
    divisor = tu.randint(3, 4, low=-10, high=-1, dtype=torch.int32)
    module.forward(dividend, divisor)
# ==============================================================================
class ElementwiseDivTensorFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()