[TorchToLinalg] Lower AtenRoundOp to math::RoundEvenOp (Fixes #1811) (#1823)

[TorchToLinalg] Lower AtenRoundOp to math::RoundEvenOp (Fixes #1811)
pull/1841/head snapshot-20230125.729
Matthias Gehre 2023-01-25 08:51:29 +01:00 committed by GitHub
parent 3930588a7e
commit adaf05f03e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 1 deletions

View File

@ -293,7 +293,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
round.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
return b.create<math::RoundOp>(loc, payloadArgs[0]);
return b.create<math::RoundEvenOp>(loc, payloadArgs[0]);
}
if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
if (!prelu.getType()

View File

@ -2647,6 +2647,26 @@ class AtenRoundFloatModule(torch.nn.Module):
def AtenRoundFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 5, low = -3.0, high = 3.0))
class AtenRoundFloatHalfToEvenModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.round(x)
@register_test_case(module_factory=lambda: AtenRoundFloatHalfToEvenModule())
def AtenRoundFloatHalfToEvenModule_basic(module, tu: TestUtils):
module.forward(torch.FloatTensor([[0.5, 1.5], [-0.5, -1.5]]))
class AtenRoundIntModule(torch.nn.Module):
def __init__(self):