torch-mlir/e2e_testing/torchscript/quantized_models.py

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch
from torch import nn

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================
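
# A small float MLP bracketed by QuantStub/DeQuantStub markers so that
# PyTorch's eager-mode post-training quantization flow can convert it
# into a fully quantized model for the e2e test below.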
class QuantizedMLP(nn.Module):
    def __init__(self):
        super().__init__()
        torch.random.manual_seed(0)
        self.layers = nn.Sequential(
            nn.Linear(16, 8),
            nn.Tanh(),
            nn.Linear(8, 4),
        )
        self.quantize = torch.quantization.QuantStub()
        self.dequantize = torch.quantization.DeQuantStub()

    @export
    @annotate_args([
        None,
        ([1, 16], torch.float32, True),
    ])
    def forward(self, x):
        x = self.quantize(x)
        x = self.layers(x)
        x = self.dequantize(x)
        return x

def get_mlp_input():
    # Random input in [-1, 1) with the [1, 16] shape declared in the
    # forward() annotation.
    return 2 * torch.rand((1, 16)) - 1

def get_quantized_mlp():
    model = QuantizedMLP()
    model.eval()
    # Eager-mode post-training static quantization: attach a qconfig,
    # insert observers, calibrate on representative inputs, then convert
    # the observed model into a quantized one.
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)
    torch.manual_seed(0)
    for _ in range(32):
        model(get_mlp_input())
    torch.quantization.convert(model, inplace=True)
    return model
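
# After conversion, the nn.Linear layers are replaced by quantized linear
# modules whose weight/bias live in LinearPackedParamsBase objects; the
# importer models these with the torch.linear_params.create op and the
# !torch.LinearParams type.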

@register_test_case(module_factory=get_quantized_mlp)
def QuantizedMLP_basic(module, tu: TestUtils):
    module.forward(get_mlp_input())
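
# To run just this test through the e2e harness, something like the following
# should work (exact entry point and flag names may vary between checkouts):
#   python -m e2e_testing.torchscript.main --filter QuantizedMLP_basic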