Fix modulus calculation in LCG algorithm of refbackend (#1658)

The current implementation sets the `nextSeed` value to `temp & 127`,
which is wrong. For the multiplier and increment chosen, the last step
of the LCG algorithm should be `temp % 2^64 = temp & ((1 << 64) -
1)`. However, because we are dealing with i64 values, the arithmetic
already wraps modulo 2^64, so no explicit modulus operation is needed.

See Donald Knuth's LCG parameters here:
https://en.wikipedia.org/wiki/Linear_congruential_generator
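
As a quick illustration (not the refbackend code itself), here is a
minimal standalone C++ sketch of the same update step; the names
`nextSeed`, `kMultiplier`, and `kIncrement` are illustrative, but the
constants are the ones used in the patch. It uses `uint64_t`, whose
wrap-around is well-defined in C++ and bit-identical to the i64 wrap
that the generated code relies on:

    #include <cstdint>
    #include <cstdio>

    // Knuth's MMIX constants, matching the ones in the lowering below.
    static const uint64_t kMultiplier = 6364136223846793005ULL;
    static const uint64_t kIncrement = 1442695040888963407ULL;

    // nextSeed = (multiplier * seed + increment) mod 2^64.
    // The modulus is implicit: unsigned 64-bit arithmetic wraps around.
    uint64_t nextSeed(uint64_t seed) {
      return kMultiplier * seed + kIncrement;
    }

    int main() {
      uint64_t seed = 0;
      for (int i = 0; i < 3; ++i) {
        seed = nextSeed(seed);
        printf("%llu\n", static_cast<unsigned long long>(seed));
      }
      return 0;
    }

By contrast, masking with 127 keeps only the low 7 bits, so the seed
would cycle through at most 128 distinct values.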
Ramiro Leal-Cavazos 2022-11-30 08:46:52 -08:00 committed by GitHub
parent 44b185a46b
commit 0983a7f93a
4 changed files with 7 additions and 12 deletions

@@ -257,10 +257,7 @@ static Value lowerGetNextSeed(OpBuilder &b, Location loc) {
       loc, b.getI64IntegerAttr(1442695040888963407));
   // temp = multiplier * currentSeed + incrementStep
   Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier);
-  Value temp = b.create<arith::AddIOp>(loc, mul, incrementStep);
-  // temp mod 64 = temp & 63
-  Value cst127 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(127));
-  Value nextSeed = b.create<arith::AndIOp>(loc, temp, cst127);
+  Value nextSeed = b.create<arith::AddIOp>(loc, mul, incrementStep);
   b.create<memref::StoreOp>(loc, nextSeed, globalVar);
   return nextSeed;
 }

@@ -1466,7 +1466,7 @@ class DropoutTrainModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: DropoutTrainModule())
 def DropoutTrainModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(256, 256))
+    module.forward(tu.rand(1024, 1536))
 # ==============================================================================

@@ -144,8 +144,8 @@ class BernoulliTensorModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-1, -1, -1], torch.float64, True),
-        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1], torch.float64, True),
+        ([-1, -1], torch.float64, True),
     ])
     def forward(self, x, px):
         a = torch.ops.aten.bernoulli_(x, px)
@@ -157,8 +157,8 @@ class BernoulliTensorModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: BernoulliTensorModule())
 def BernoulliTensorModule_basic(module, tu: TestUtils):
     module.forward(
-        tu.rand(512, 512, 2).double(),
-        tu.rand(512, 512, 2).double())
+        tu.rand(512, 512).double(),
+        tu.rand(512, 512).double())
 # ==============================================================================

@@ -7,9 +7,7 @@
 // CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
 // CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
 // CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
-// CHECK: %[[TEMP:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
-// CHECK: %[[CST127:.*]] = arith.constant 127 : i64
-// CHECK: %[[NEXT_SEED:.*]] = arith.andi %[[TEMP]], %[[CST127]] : i64
+// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
 // CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
 // CHECK: return %[[NEXT_SEED]] : i64
 module {