mirror of https://github.com/llvm/torch-mlir
Fix modulus calculation in LCG algorithm of refbackend (#1658)
The current implementation sets the `nextSeed` value to `temp & 127`, which is wrong. The last step of the LCG algorithm for the multiplier and increment chosen should be `temp % 2^{64} = temp & ((1 << 64) - 1)`. However, because we are dealing with i64 values, the modulus operation happens automatically, so it is not needed. See Donald Knuth's values for LCG here: https://en.wikipedia.org/wiki/Linear_congruential_generator pull/1633/head
parent
44b185a46b
commit
0983a7f93a
|
@ -257,10 +257,7 @@ static Value lowerGetNextSeed(OpBuilder &b, Location loc) {
|
||||||
loc, b.getI64IntegerAttr(1442695040888963407));
|
loc, b.getI64IntegerAttr(1442695040888963407));
|
||||||
// temp = multiplier * currentSeed + incrementStep
|
// temp = multiplier * currentSeed + incrementStep
|
||||||
Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier);
|
Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier);
|
||||||
Value temp = b.create<arith::AddIOp>(loc, mul, incrementStep);
|
Value nextSeed = b.create<arith::AddIOp>(loc, mul, incrementStep);
|
||||||
// temp mod 64 = temp & 63
|
|
||||||
Value cst127 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(127));
|
|
||||||
Value nextSeed = b.create<arith::AndIOp>(loc, temp, cst127);
|
|
||||||
b.create<memref::StoreOp>(loc, nextSeed, globalVar);
|
b.create<memref::StoreOp>(loc, nextSeed, globalVar);
|
||||||
return nextSeed;
|
return nextSeed;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1466,7 +1466,7 @@ class DropoutTrainModule(torch.nn.Module):
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: DropoutTrainModule())
|
@register_test_case(module_factory=lambda: DropoutTrainModule())
|
||||||
def DropoutTrainModule_basic(module, tu: TestUtils):
|
def DropoutTrainModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(256, 256))
|
module.forward(tu.rand(1024, 1536))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
|
@ -144,8 +144,8 @@ class BernoulliTensorModule(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args([
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
([-1, -1, -1], torch.float64, True),
|
([-1, -1], torch.float64, True),
|
||||||
])
|
])
|
||||||
def forward(self, x, px):
|
def forward(self, x, px):
|
||||||
a = torch.ops.aten.bernoulli_(x, px)
|
a = torch.ops.aten.bernoulli_(x, px)
|
||||||
|
@ -157,8 +157,8 @@ class BernoulliTensorModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
||||||
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.rand(512, 512, 2).double(),
|
tu.rand(512, 512).double(),
|
||||||
tu.rand(512, 512, 2).double())
|
tu.rand(512, 512).double())
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,7 @@
|
||||||
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
|
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
|
||||||
// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
|
// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
|
||||||
// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
|
// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
|
||||||
// CHECK: %[[TEMP:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
|
// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
|
||||||
// CHECK: %[[CST127:.*]] = arith.constant 127 : i64
|
|
||||||
// CHECK: %[[NEXT_SEED:.*]] = arith.andi %[[TEMP]], %[[CST127]] : i64
|
|
||||||
// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
|
// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
|
||||||
// CHECK: return %[[NEXT_SEED]] : i64
|
// CHECK: return %[[NEXT_SEED]] : i64
|
||||||
module {
|
module {
|
||||||
|
|
Loading…
Reference in New Issue