torch-mlir/e2e_testing/torchscript/basic.py

266 lines
7.0 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
from npcomp_torchscript.e2e_test.framework import TestUtils
from npcomp_torchscript.e2e_test.registry import register_test_case
from npcomp_torchscript.annotations import annotate_args, export
# ==============================================================================
class MmModule(torch.nn.Module):
    """Test module whose exported `forward` computes `torch.mm(lhs, rhs)`."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        # One entry per `forward` parameter; the leading `None` lines up with
        # `self`.
        None,
        # (shape, dtype, flag) triples for `lhs` and `rhs`.
        # NOTE(review): -1 presumably marks a dimension of unknown size and
        # the trailing `True` presumably marks the argument as having value
        # semantics -- confirm against `annotate_args`.
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        """Return `torch.mm(lhs, rhs)`."""
        return torch.mm(lhs, rhs)
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
    """Call `module.forward` once with two `tu.rand(4, 4)` operands."""
    lhs = tu.rand(4, 4)
    rhs = tu.rand(4, 4)
    module.forward(lhs, rhs)
# TODO: Investigate why RefBackend sometimes can't handle two calls in a row in
# the trace.
# It actually works, if MmModule_chained is run by itself, but if other tests
# are mixed with it, it fails with a mysterious-sounding low level ctypes error
# that exceeds my current ability to debug.
#
# @register_test_case(module_factory=lambda: MmModule())
# def MmModule_chained(module, tu: TestUtils):
# res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
# module.forward(res, res)
# ==============================================================================
class BmmModule(torch.nn.Module):
    """Test module whose exported `forward` computes `torch.bmm(lhs, rhs)`."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,  # lines up with `self`
        # NOTE(review): -1 presumably means "unknown size"; the final `True`
        # presumably flags value semantics -- confirm against `annotate_args`.
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        """Return `torch.bmm(lhs, rhs)`."""
        batched_product = torch.bmm(lhs, rhs)
        return batched_product
@register_test_case(module_factory=lambda: BmmModule())
def BmmModule_basic(module, tu: TestUtils):
    """Call `module.forward` with `tu.rand(3, 4, 5)` and `tu.rand(3, 5, 4)`."""
    lhs = tu.rand(3, 4, 5)
    rhs = tu.rand(3, 5, 4)
    module.forward(lhs, rhs)
# ==============================================================================
# A subgraph with multiple mm ops.
class MmDagModule(torch.nn.Module):
    """Test module that chains two `torch.mm` calls sharing the `lhs` operand."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,  # lines up with `self`
        # NOTE(review): the trailing `True` presumably flags value semantics
        # -- confirm against `annotate_args`.
        ([4, 4], torch.float32, True),
        ([4, 4], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        """Return `torch.mm(lhs, torch.mm(lhs, rhs))`."""
        inner = torch.mm(lhs, rhs)
        outer = torch.mm(lhs, inner)
        return outer
@register_test_case(module_factory=lambda: MmDagModule())
def MmDagModule_basic(module, tu: TestUtils):
    """Call `module.forward` once with two `tu.rand(4, 4)` operands."""
    lhs = tu.rand(4, 4)
    rhs = tu.rand(4, 4)
    module.forward(lhs, rhs)
# ==============================================================================
class MmTanhModule(torch.nn.Module):
    """Test module applying `torch.tanh` to a matrix product.

    The product is computed by the non-exported helper method `matmul`, so
    the exported `forward` goes through a call to another method.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        # One entry per `forward` parameter; the leading `None` lines up with
        # `self`.
        None,
        # (shape, dtype, flag) triples for `lhs` and `rhs`.
        # NOTE(review): -1 presumably marks a dimension of unknown size and
        # the trailing `True` presumably marks the argument as having value
        # semantics -- confirm against `annotate_args`.
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        """Return `torch.tanh(self.matmul(lhs, rhs))`."""
        return torch.tanh(self.matmul(lhs, rhs))

    def matmul(self, lhs, rhs):
        """Return `torch.mm(lhs, rhs)`."""
        return torch.mm(lhs, rhs)
@register_test_case(module_factory=lambda: MmTanhModule())
def MmTanhModule_basic(module, tu: TestUtils):
    """Call `module.forward` with `tu.rand(4, 2)` and `tu.rand(2, 4)`."""
    lhs = tu.rand(4, 2)
    rhs = tu.rand(2, 4)
    module.forward(lhs, rhs)
class AdaptiveAvgPool2dModule(torch.nn.Module):
    """Test module wrapping `torch.nn.AdaptiveAvgPool2d((1, 1))`."""

    def __init__(self):
        super().__init__()
        output_size = (1, 1)
        self.aap2d = torch.nn.AdaptiveAvgPool2d(output_size)

    @export
    @annotate_args([
        None,  # lines up with `self`
        # NOTE(review): -1 presumably means "unknown size"; the final `True`
        # presumably flags value semantics -- confirm against `annotate_args`.
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        """Return the result of applying the pooling layer to `x`."""
        pooled = self.aap2d(x)
        return pooled
@register_test_case(module_factory=lambda: AdaptiveAvgPool2dModule())
def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
    """Call `module.forward` once with `tu.rand(10, 3, 8, 9)`."""
    sample = tu.rand(10, 3, 8, 9)
    module.forward(sample)
class FlattenStaticModule(torch.nn.Module):
    """Test module flattening dims 2..4 of a fully static 6-d annotation."""

    def __init__(self):
        super().__init__()
        self.flat = torch.nn.Flatten(start_dim=2, end_dim=4)

    @export
    @annotate_args([
        None,  # lines up with `self`
        # NOTE(review): the trailing `True` presumably flags value semantics
        # -- confirm against `annotate_args`.
        ([10, 3, 8, 9, 3, 4], torch.float32, True),
    ])
    def forward(self, x):
        """Return `x` passed through the `Flatten(2, 4)` layer."""
        flattened = self.flat(x)
        return flattened
@register_test_case(module_factory=lambda: FlattenStaticModule())
def FlattenStaticModule_basic(module, tu: TestUtils):
    """Call `module.forward` once with `tu.rand(10, 3, 8, 9, 3, 4)`."""
    sample = tu.rand(10, 3, 8, 9, 3, 4)
    module.forward(sample)
class FlattenRank0Module(torch.nn.Module):
    """Test module applying `Flatten(-1, -1)` to an input annotated with shape `[]`."""

    def __init__(self):
        super().__init__()
        self.flat = torch.nn.Flatten(start_dim=-1, end_dim=-1)

    @export
    @annotate_args([
        None,  # lines up with `self`
        # Empty shape list: the argument is annotated as rank 0.
        # NOTE(review): the trailing `True` presumably flags value semantics
        # -- confirm against `annotate_args`.
        ([], torch.float32, True),
    ])
    def forward(self, x):
        """Return `x` passed through the `Flatten(-1, -1)` layer."""
        flattened = self.flat(x)
        return flattened
@register_test_case(module_factory=lambda: FlattenRank0Module())
def FlattenRank0Module_basic(module, tu: TestUtils):
    """Call `module.forward` once with the scalar tensor `torch.tensor(4.0)`."""
    scalar = torch.tensor(4.0)
    module.forward(scalar)
class FlattenDynamicModule(torch.nn.Module):
    """Test module flattening dims 2..4 of a partially dynamic 6-d annotation."""

    def __init__(self):
        super().__init__()
        self.flat = torch.nn.Flatten(start_dim=2, end_dim=4)

    @export
    @annotate_args([
        None,  # lines up with `self`
        # Mix of -1 and fixed entries (9 and 3).
        # NOTE(review): -1 presumably means "unknown size"; the final `True`
        # presumably flags value semantics -- confirm against `annotate_args`.
        ([-1, -1, -1, 9, 3, -1], torch.float32, True),
    ])
    def forward(self, x):
        """Return `x` passed through the `Flatten(2, 4)` layer."""
        flattened = self.flat(x)
        return flattened
@register_test_case(module_factory=lambda: FlattenDynamicModule())
def FlattenDynamicModule_basic(module, tu: TestUtils):
    """Call `module.forward` once with `tu.rand(10, 3, 8, 9, 3, 4)`."""
    sample = tu.rand(10, 3, 8, 9, 3, 4)
    module.forward(sample)
class MaxPool2dModule(torch.nn.Module):
    """Test module wrapping a `torch.nn.MaxPool2d` with padding and dilation."""

    def __init__(self):
        super().__init__()
        pool = torch.nn.MaxPool2d(
            kernel_size=[6, 8],
            stride=[2, 2],
            padding=[3, 4],
            dilation=2,
        )
        self.mp2d = pool

    @export
    @annotate_args([
        None,  # lines up with `self`
        # NOTE(review): -1 presumably means "unknown size"; the final `True`
        # presumably flags value semantics -- confirm against `annotate_args`.
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        """Return the result of applying the max-pooling layer to `x`."""
        pooled = self.mp2d(x)
        return pooled
@register_test_case(module_factory=lambda: MaxPool2dModule())
def MaxPool2dModule_basic(module, tu: TestUtils):
    """Call `module.forward` once with `tu.rand(1, 1, 20, 20) - 0.5`."""
    # NOTE(review): the 0.5 shift presumably makes some inputs negative so
    # padded positions matter -- confirm the value range of `tu.rand`.
    shifted = tu.rand(1, 1, 20, 20) - 0.5
    module.forward(shifted)
2021-09-17 14:49:04 +08:00
class TransposeIntModule(torch.nn.Module):
    """Test module whose exported `forward` swaps dims 0 and 1 of its input."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,  # lines up with `self`
        # NOTE(review): the trailing `True` presumably flags value semantics
        # -- confirm against `annotate_args`.
        ([3, 4, 2], torch.float32, True),
    ])
    def forward(self, x):
        """Return `torch.transpose(x, 0, 1)`."""
        swapped = torch.transpose(x, 0, 1)
        return swapped
@register_test_case(module_factory=lambda: TransposeIntModule())
def TransposeIntModule_basic(module, tu: TestUtils):
    """Call `module.forward` once with `tu.rand(3, 4, 2)`."""
    sample = tu.rand(3, 4, 2)
    module.forward(sample)
class TensorsConcatModule(torch.nn.Module):
    """Test module concatenating three tensors along dimension 1."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,  # lines up with `self`
        # NOTE(review): -1 presumably means "unknown size"; the final `True`
        # presumably flags value semantics -- confirm against `annotate_args`.
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x, y, z):
        """Return `torch.cat([x, y, z], 1)`."""
        pieces = [x, y, z]
        return torch.cat(pieces, 1)
@register_test_case(module_factory=lambda: TensorsConcatModule())
def TensorsConcatModule_basic(module, tu: TestUtils):
    """Call `module.forward` with three inputs that differ only in dim 1."""
    first = tu.rand(2, 2, 4)
    second = tu.rand(2, 1, 4)
    third = tu.rand(2, 3, 4)
    module.forward(first, second, third)
class GatherModule(torch.nn.Module):
    """Test module whose exported `forward` gathers along dimension 2."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,  # lines up with `self`
        # `tensor` is annotated float32 and `indices` int64.
        # NOTE(review): -1 presumably means "unknown size"; the final `True`
        # presumably flags value semantics -- confirm against `annotate_args`.
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, tensor, indices):
        """Return `torch.gather(tensor, 2, indices)`."""
        gathered = torch.gather(tensor, 2, indices)
        return gathered
#@register_test_case(module_factory=lambda: GatherModule())
#def GatherModule_basic(module, tu: TestUtils):
# module.forward(tu.rand(2, 3, 4), torch.tensor([[[1,2,3],[1,2,3]]]))