diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 981263025..bcd8d6030 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -257,10 +257,7 @@ static Value lowerGetNextSeed(OpBuilder &b, Location loc) { loc, b.getI64IntegerAttr(1442695040888963407)); // temp = multiplier * currentSeed + incrementStep Value mul = b.create(loc, currentSeed, multiplier); - Value temp = b.create(loc, mul, incrementStep); - // temp mod 64 = temp & 63 - Value cst127 = b.create(loc, b.getI64IntegerAttr(127)); - Value nextSeed = b.create(loc, temp, cst127); + Value nextSeed = b.create(loc, mul, incrementStep); b.create(loc, nextSeed, globalVar); return nextSeed; } diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 9e04ab2d5..662d646a4 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1466,7 +1466,7 @@ class DropoutTrainModule(torch.nn.Module): @register_test_case(module_factory=lambda: DropoutTrainModule()) def DropoutTrainModule_basic(module, tu: TestUtils): - module.forward(tu.rand(256, 256)) + module.forward(tu.rand(1024, 1536)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py index 2fc1444ff..8e72723d1 100644 --- a/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/python/torch_mlir_e2e_test/test_suite/rng.py @@ -144,8 +144,8 @@ class BernoulliTensorModule(torch.nn.Module): @export @annotate_args([ None, - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), + ([-1, -1], torch.float64, True), + ([-1, -1], torch.float64, True), ]) def forward(self, x, px): a = torch.ops.aten.bernoulli_(x, px) @@ -157,8 +157,8 @@ class BernoulliTensorModule(torch.nn.Module): @register_test_case(module_factory=lambda: BernoulliTensorModule()) def BernoulliTensorModule_basic(module, tu: TestUtils): module.forward( - tu.rand(512, 512, 2).double(), - tu.rand(512, 512, 2).double()) + tu.rand(512, 512).double(), + tu.rand(512, 512).double()) # ============================================================================== diff --git a/test/RefBackend/insert-rng-globals.mlir b/test/RefBackend/insert-rng-globals.mlir index c44d3397a..51d836ee0 100644 --- a/test/RefBackend/insert-rng-globals.mlir +++ b/test/RefBackend/insert-rng-globals.mlir @@ -7,9 +7,7 @@ // CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64 // CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64 // CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64 -// CHECK: %[[TEMP:.*]] = arith.addi %[[MUL]], %[[INC]] : i64 -// CHECK: %[[CST127:.*]] = arith.constant 127 : i64 -// CHECK: %[[NEXT_SEED:.*]] = arith.andi %[[TEMP]], %[[CST127]] : i64 +// CHECK: %[[NEXT_SEED:.*]] = arith.addi %[[MUL]], %[[INC]] : i64 // CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref // CHECK: return %[[NEXT_SEED]] : i64 module {