mirror of https://github.com/llvm/torch-mlir
35 lines
821 B
Python
35 lines
821 B
Python
|
# RUN: %PYTHON -s %s 2>&1 | FileCheck %s
|
||
|
|
||
|
import gc
|
||
|
import sys
|
||
|
import torch
|
||
|
import torch_mlir
|
||
|
|
||
|
|
||
|
def run_test(f):
|
||
|
print("TEST:", f.__name__, file=sys.stderr)
|
||
|
f()
|
||
|
gc.collect()
|
||
|
|
||
|
|
||
|
class TinyModel(torch.nn.Module):
|
||
|
def __init__(self):
|
||
|
super(TinyModel, self).__init__()
|
||
|
|
||
|
self.linear = torch.nn.Linear(20, 30)
|
||
|
|
||
|
def forward(self, x):
|
||
|
x = self.linear(x)
|
||
|
return x
|
||
|
|
||
|
|
||
|
# CHECK-LABEL: TEST: test_enable_ir_printing
|
||
|
@run_test
|
||
|
def test_enable_ir_printing():
|
||
|
torch_mlir.compile(TinyModel(),
|
||
|
torch.ones(1, 3, 20, 20),
|
||
|
output_type="linalg-on-tensors",
|
||
|
enable_ir_printing=True)
|
||
|
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize)
|
||
|
# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {
|