mirror of https://github.com/llvm/torch-mlir
Port prior acap export tests to new dispatcher based versions.
* Sadly, non-trivial ones fail. * Bugs filed and marked XFAIL.pull/81/head
parent
30cfc6499f
commit
abb6fe8aa2
|
@ -0,0 +1,57 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import unittest
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import torch_mlir
|
||||
|
||||
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/79
|
||||
# XFAIL: *
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
class ResA(nn.Module):
|
||||
def __init__(self, channels):
|
||||
C = int(channels)
|
||||
C2 = int(channels/2)
|
||||
super(ResA, self).__init__()
|
||||
self.model = nn.Sequential(# A1
|
||||
nn.BatchNorm2d(C),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(C,C2,1,stride=1,padding=0,dilation=1,groups=1,bias=True),
|
||||
# B1
|
||||
nn.BatchNorm2d(C2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(C2,C2,3,stride=1,padding=1,dilation=1,groups=1,bias=True),
|
||||
# C1
|
||||
nn.BatchNorm2d(C2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(C2,C,1,stride=1,padding=0,dilation=1,groups=1,bias=True))
|
||||
def forward(self, x):
|
||||
res = self.model.forward(x)
|
||||
return x + res
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
model = ResA(16)
|
||||
inputs = torch.ones((1,16,128,128))
|
||||
with mb.capture_function("resa", [inputs]) as f:
|
||||
f.returns([model(inputs)])
|
||||
|
||||
# CHECK-LABEL: func @resa
|
||||
# TODO: Update checks when test passes to this point.
|
||||
# CHECK: [[V0:%[a-zA-Z0-9]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"({{.*}}) {layer_name = "L0-native_batch_norm-0"}
|
||||
# CHECK: [[V1:%[a-zA-Z0-9]+]] = "aten.relu"([[V0]]) {layer_name = "L1-relu-0"}
|
||||
# CHECK: [[V2:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V1]], {{.*}}) {layer_name = "L2-convolution_overrideable-0"}
|
||||
# CHECK: [[V3:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V2]]{{.*}}) {layer_name = "L3-native_batch_norm-1"}
|
||||
# CHECK: [[V4:%[a-zA-Z0-9]+]] = "aten.relu"([[V3]]) {layer_name = "L4-relu-1"}
|
||||
# CHECK: [[V5:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V4]],{{.*}}) {layer_name = "L5-convolution_overrideable-1"}
|
||||
# CHECK: [[V6:%[a-zA-Z0-9_]+]], %{{.*}}, %{{.*}} = "aten.native_batch_norm"([[V5]],{{.*}}) {layer_name = "L6-native_batch_norm-2"}
|
||||
# CHECK: [[V7:%[a-zA-Z0-9]+]] = "aten.relu"([[V6]]) {layer_name = "L7-relu-2"}
|
||||
# CHECK: [[V8:%[a-zA-Z0-9]+]] = "aten.convolution_overrideable"([[V7]],{{.*}}) {layer_name = "L8-convolution_overrideable-2"}
|
||||
# CHECK: {{.*}} = "aten.add"(%arg0, [[V8]], {{.*}}) {layer_name = "L9-add-0"}
|
||||
print(mb.module)
|
|
@ -0,0 +1,25 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
t0 = torch.randn((1,2,3,4))
|
||||
t1 = torch.randn((1,2,3,4))
|
||||
t2 = torch.randn((1,2,3,4))
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
with mb.capture_function("add3", [t0, t1, t2]) as f:
|
||||
t3 = t0 + t1 + t2
|
||||
f.returns([t3])
|
||||
|
||||
# CHECK-LABEL: func @add3
|
||||
# CHECK: %[[CST_1A:.*]] = constant 1 : i64
|
||||
# CHECK: %[[CST_1B:.*]] = constant 1 : i64
|
||||
# CHECK: %[[ADD0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[CST_1A]]
|
||||
# CHECK: %[[ADD1:.*]] = torch.kernel_call "aten::add" %[[ADD0]], %arg2, %[[CST_1B]]
|
||||
# CHECK: return %[[ADD1]]
|
||||
print(mb.module)
|
|
@ -0,0 +1,27 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# See bug references below and remove XFAIL when resolved.
|
||||
# XFAIL: *
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# TODO: Both of these fail with the "unsupported from an unboxed API yet" error.
|
||||
# The corresponding ops need to be manually coded. Then these can be moved into
|
||||
# the capture. https://github.com/llvm/mlir-npcomp/issues/78
|
||||
# TODO: These also create constant tensors (needs implementation of import of
|
||||
# DenseElements constants). https://github.com/llvm/mlir-npcomp/issues/79
|
||||
model = torch.nn.BatchNorm2d(123)
|
||||
ones = torch.ones(42,123,4,5)
|
||||
|
||||
with mb.capture_function("bn2d", []) as f:
|
||||
result = model(ones)
|
||||
f.returns([result])
|
||||
|
||||
# CHECK-LABEL: @bn2d
|
||||
print(mb.module)
|
|
@ -0,0 +1,50 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80
|
||||
# XFAIL: *
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
N = 3
|
||||
Cin = 16
|
||||
Cout = 4
|
||||
w = 10
|
||||
h = 10
|
||||
|
||||
model = torch.nn.Conv2d(Cin, Cout, (3,3))
|
||||
ref_model = torch.nn.Conv2d(Cin, Cout, (3,3))
|
||||
|
||||
ref_model.weight.data = model.weight.clone()
|
||||
ref_model.bias.data = model.bias.clone()
|
||||
|
||||
softmax = torch.nn.LogSoftmax(dim=1)
|
||||
loss = torch.nn.NLLLoss()
|
||||
|
||||
tensor = torch.randn(N, Cin, h, w)
|
||||
|
||||
with mb.capture_function("@conv2d_fwd", [tensor]) as f:
|
||||
result = model(tensor)
|
||||
f.returns([result])
|
||||
|
||||
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout)
|
||||
ref_target = target.clone()
|
||||
|
||||
with mb.capture_function("@conv2d_backward", [result, target]) as f:
|
||||
test_loss = loss(softmax(result), target)
|
||||
f.returns([test_loss.backward()])
|
||||
|
||||
# CHECK-LABEL: func @conv2d_fwd
|
||||
# TODO: Add checks when passing
|
||||
|
||||
# CHECK-LABEL: func @conv2d_backward
|
||||
# TODO: Update checks when passing
|
||||
# NO-CHECK: aten.convolution_overrideable
|
||||
# NO-CHECK: aten._log_softmax
|
||||
# NO-CHECK: aten.nll_loss2d_forward
|
||||
print(mb.module)
|
|
@ -0,0 +1,29 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
t0 = torch.randn(4)
|
||||
t1 = torch.randn(4)
|
||||
t2 = torch.randn(4)
|
||||
|
||||
with mb.capture_function("multi_output", [t0, t1, t2]) as f:
|
||||
t4 = t0 + t1 + t2
|
||||
t5 = t4 + t1
|
||||
t6 = t5 + t4
|
||||
f.returns([t4, t5, t6])
|
||||
|
||||
# CHECK-LABEL: func @multi_output
|
||||
# CHECK: %[[ADD0:.*]] = torch.kernel_call "aten::add" %arg0
|
||||
# CHECK: %[[ADD1:.*]] = torch.kernel_call "aten::add" %[[ADD0]]
|
||||
# CHECK: %[[ADD2:.*]] = torch.kernel_call "aten::add" %[[ADD1]]
|
||||
# CHECK: %[[ADD3:.*]] = torch.kernel_call "aten::add" %[[ADD2]]
|
||||
# CHECK: return %[[ADD1]], %[[ADD2]], %[[ADD3]]
|
||||
print(mb.module)
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
import torchvision.models as models
|
||||
|
||||
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80
|
||||
# XFAIL: *
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
model = models.resnet18()
|
||||
model.training = False
|
||||
|
||||
tensor = torch.randn(32,3,32,32)
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
with mb.capture_function("res18", [tensor]) as f:
|
||||
result = model(tensor)
|
||||
f.returns([result])
|
||||
|
||||
print(mb.module)
|
||||
|
||||
# for now we just check the output shape
|
||||
# CHECK-LABEL: @res18
|
||||
# TODO: Add checks once running to this point.
|
|
@ -0,0 +1,26 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
import torchvision.models as models
|
||||
|
||||
# TODO: Fix https://github.com/llvm/mlir-npcomp/issues/80
|
||||
# XFAIL: *
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
model = models.vgg11_bn()
|
||||
model.training = False
|
||||
|
||||
inputs = torch.ones(32,3,32,32)
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
with mb.capture_function("vgg11", [inputs]) as f:
|
||||
result = model(inputs)
|
||||
f.returns([result])
|
||||
|
||||
# CHECK-LABEL: func @vgg11
|
||||
# TODO: Add checks once passing this far.
|
||||
print(mb.module)
|
|
@ -67,6 +67,7 @@ llvm_config.with_environment('PYTHONPATH', [
|
|||
|
||||
tool_dirs = [config.npcomp_tools_dir, config.llvm_tools_dir]
|
||||
tools = [
|
||||
'npcomp-opt',
|
||||
]
|
||||
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
|
Loading…
Reference in New Issue