mirror of https://github.com/llvm/torch-mlir
Fix modulus calculation in LCG algorithm of refbackend (#1658)
The current implementation sets the `nextSeed` value to `temp & 127`, which is wrong. The last step of the LCG algorithm for the multiplier and increment chosen should be `temp % 2^{64} = temp & ((1 << 64) - 1)`. However, because we are dealing with i64 values, the modulus operation happens automatically, so it is not needed. See Donald Knuth's values for LCG here: https://en.wikipedia.org/wiki/Linear_congruential_generator
pull/1633/head
parent
44b185a46b
commit
0983a7f93a
|
@ -257,10 +257,7 @@ static Value lowerGetNextSeed(OpBuilder &b, Location loc) {
|
|||
loc, b.getI64IntegerAttr(1442695040888963407));
|
||||
// temp = multiplier * currentSeed + incrementStep
|
||||
Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier);
|
||||
Value temp = b.create<arith::AddIOp>(loc, mul, incrementStep);
|
||||
// temp mod 64 = temp & 63
|
||||
Value cst127 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(127));
|
||||
Value nextSeed = b.create<arith::AndIOp>(loc, temp, cst127);
|
||||
Value nextSeed = b.create<arith::AddIOp>(loc, mul, incrementStep);
|
||||
b.create<memref::StoreOp>(loc, nextSeed, globalVar);
|
||||
return nextSeed;
|
||||
}
|
||||
|
|
|
@ -1466,7 +1466,7 @@ class DropoutTrainModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: DropoutTrainModule())
|
||||
def DropoutTrainModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(256, 256))
|
||||
module.forward(tu.rand(1024, 1536))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -144,8 +144,8 @@ class BernoulliTensorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, px):
|
||||
a = torch.ops.aten.bernoulli_(x, px)
|
||||
|
@ -157,8 +157,8 @@ class BernoulliTensorModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
||||
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 512, 2).double(),
|
||||
tu.rand(512, 512, 2).double())
|
||||
tu.rand(512, 512).double(),
|
||||
tu.rand(512, 512).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -7,9 +7,7 @@
|
|||
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
|
||||
// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
|
||||
// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
|
||||
// CHECK: %[[TEMP:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
|
||||
// CHECK: %[[CST127:.*]] = arith.constant 127 : i64
|
||||
// CHECK: %[[NEXT_SEED:.*]] = arith.andi %[[TEMP]], %[[CST127]] : i64
|
||||
// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
|
||||
// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
|
||||
// CHECK: return %[[NEXT_SEED]] : i64
|
||||
module {
|
||||
|
|
Loading…
Reference in New Issue