From adaf05f03ebafd1cc94aa4c8d2cadbf4235dbb32 Mon Sep 17 00:00:00 2001 From: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com> Date: Wed, 25 Jan 2023 08:51:29 +0100 Subject: [PATCH] [TorchToLinalg] Lower AtenRoundOp to math::RoundEvenOp (Fixes #1811) (#1823) [TorchToLinalg] Lower AtenRoundOp to math::RoundEvenOp (Fixes #1811) --- .../TorchToLinalg/Uncategorized.cpp | 2 +- .../test_suite/elementwise.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index bc16c8c1e..d5ad4f50b 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -293,7 +293,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( round.emitError("unimplemented: non-floating point dtype"); return nullptr; } - return b.create(loc, payloadArgs[0]); + return b.create(loc, payloadArgs[0]); } if (auto prelu = dyn_cast(op)) { if (!prelu.getType() diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index 41f7c8364..46411df66 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2647,6 +2647,26 @@ class AtenRoundFloatModule(torch.nn.Module): def AtenRoundFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 5, low = -3.0, high = 3.0)) + +class AtenRoundFloatHalfToEvenModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.round(x) + + +@register_test_case(module_factory=lambda: AtenRoundFloatHalfToEvenModule()) +def AtenRoundFloatHalfToEvenModule_basic(module, tu: TestUtils): + module.forward(torch.FloatTensor([[0.5, 1.5], [-0.5, -1.5]])) + + class AtenRoundIntModule(torch.nn.Module): def __init__(self):