mirror of https://github.com/llvm/torch-mlir
79 lines
3.2 KiB
Python
79 lines
3.2 KiB
Python
# -*- Python -*-
|
|
# This file is licensed under a pytorch-style license
|
|
# See frontends/pytorch/LICENSE for license information.
|
|
|
|
import unittest
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import npcomp.frontends.pytorch as torch_mlir
|
|
|
|
import inspect
|
|
|
|
# RUN: python %s | FileCheck %s
|
|
|
|
class ResA(nn.Module):
|
|
def __init__(self, channels):
|
|
C = int(channels)
|
|
C2 = int(channels/2)
|
|
super(ResA, self).__init__()
|
|
self.model = nn.Sequential(# A1
|
|
nn.BatchNorm2d(C),
|
|
nn.ReLU(),
|
|
nn.Conv2d(C,C2,1,stride=1,padding=0,dilation=1,groups=1,bias=True),
|
|
# B1
|
|
nn.BatchNorm2d(C2),
|
|
nn.ReLU(),
|
|
nn.Conv2d(C2,C2,3,stride=1,padding=1,dilation=1,groups=1,bias=True),
|
|
# C1
|
|
nn.BatchNorm2d(C2),
|
|
nn.ReLU(),
|
|
nn.Conv2d(C2,C,1,stride=1,padding=0,dilation=1,groups=1,bias=True))
|
|
def forward(self, x):
|
|
res = self.model.forward(x)
|
|
return x + res
|
|
|
|
# Prints `str` prefixed by the current test function name so we can use it in
|
|
# Filecheck label directives.
|
|
# This is achieved by inspecting the stack and getting the parent name.
|
|
def printWithCurrentFunctionName(s):
|
|
# stack[1] is the caller, i.e. "_test_model"
|
|
# stack[2] is the caller's caller, e.g. "test_conv_1"
|
|
print(inspect.stack()[2][3], s)
|
|
|
|
class TestMLIRExport(unittest.TestCase):
|
|
def setUp(self):
|
|
pass
|
|
|
|
def _test_model(self, model, model_args):
|
|
result = model(model_args)
|
|
|
|
mlir = torch_mlir.get_mlir(result)
|
|
printWithCurrentFunctionName (mlir)
|
|
return True
|
|
|
|
def test_ResA_16(self):
|
|
dev = torch_mlir.mlir_device()
|
|
model = ResA(16).to(dev)
|
|
passed = self._test_model(model, torch.ones((1,16,128,128), device=dev))
|
|
# CHECK-LABEL: test_ResA_16
|
|
# CHECK: [[V0:%[a-zA-Z0-9]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"({{.*}}) {layer_name = "L0-native_batch_norm-0"}
|
|
# CHECK: [[V1:%[a-zA-Z0-9]+]] = "aten.relu"([[V0]]) {layer_name = "L1-relu-0"}
|
|
# CHECK: [[V2:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V1]], {{.*}}) {layer_name = "L2-convolution_overrideable-0"}
|
|
# CHECK: [[V3:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V2]]{{.*}}) {layer_name = "L3-native_batch_norm-1"}
|
|
# CHECK: [[V4:%[a-zA-Z0-9]+]] = "aten.relu"([[V3]]) {layer_name = "L4-relu-1"}
|
|
# CHECK: [[V5:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V4]],{{.*}}) {layer_name = "L5-convolution_overrideable-1"}
|
|
# CHECK: [[V6:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V5]],{{.*}}) {layer_name = "L6-native_batch_norm-2"}
|
|
# CHECK: [[V7:%[a-zA-Z0-9]+]] = "aten.relu"([[V6]]) {layer_name = "L7-relu-2"}
|
|
# CHECK: [[V8:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V7]],{{.*}}) {layer_name = "L8-convolution_overrideable-2"}
|
|
# CHECK: {{.*}} = "aten.add"(%arg0, [[V8]], {{.*}}) {layer_name = "L9-add-0"}
|
|
self.assertTrue(passed)
|
|
|
|
verbose = False
|
|
if __name__ == '__main__':
|
|
verbose = True
|
|
unittest.main()
|