torch-mlir/frontends/pytorch/examples/div_inplace_e2e.py

33 lines
810 B
Python

# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import npcomp
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
from npcomp.compiler.utils import logging
import test_utils
# Enable npcomp's compiler logging (npcomp.compiler.utils.logging).
logging.enable()
# Seed PyTorch's RNG so any randomly generated inputs are reproducible.
torch.manual_seed(0)

# Two distinct 2x2 all-ones tensors: used as example arguments when
# capturing `fun`, and as inputs to the output comparisons.
arg0, arg1 = torch.ones(2, 2), torch.ones(2, 2)
def fun(a, b):
    """Divide ``a`` by ``b`` elementwise, in place.

    ``Tensor.div_`` overwrites ``a`` with ``a / b`` and hands back ``a``
    itself; this in-place op is what the example captures and compiles.
    """
    quotient = a.div_(b)
    return quotient
# Capture `fun` as an MLIR function named "test", using arg0/arg1 as the
# example arguments. `fun(arg0, arg1)` calls div_ on arg0 itself; if the
# capture executes it eagerly, arg0 is mutated -- harmless here because
# ones / ones leaves arg0's values unchanged.
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0, arg1]) as f:
    f.returns([fun(arg0, arg1)])

# Lower the captured module, compile it with the reference JIT backend and
# load it so the captured function is callable as `jit_module.test`.
backend = refjit.CompilerBackend()
lowered_module = frontend_lowering.lower_module(mb.module)
jit_module = backend.load(backend.compile(lowered_module))

# The reference must compute what `fun` computes: elementwise a / b.
# torch.div is the out-of-place form of div_, so the reference never
# mutates the tensors that are also handed to the compiled function.
test_utils.compare_outputs(torch.div, jit_module.test, arg0, arg1)
test_utils.compare_outputs(torch.div, jit_module.test, arg0 + 1, arg1 + 1)
# Both checks above divide equal tensors (quotient is all ones); also use
# non-uniform, seeded-random inputs so the division itself is exercised.
test_utils.compare_outputs(torch.div, jit_module.test,
                           torch.rand(2, 2) + 1, torch.rand(2, 2) + 1)