torch-mlir/frontends/pytorch/test/ivalue_import/quantization.py

Get simple quantized model importing.

This is enough to import the program and get it through the compilation pipeline. It of course fails at the VerifyBackendContract pass, since there is a lot missing. Even so, the final IR for a simple quantized MLP already looks pretty decent: [IR](https://gist.github.com/silvasean/f76bccd76e9b193d396cfb2f9a11f54d)

Main changes:

- Add support for importing torch quantized tensors, including the `torch.per_tensor_affine.create` op and the `!torch.qint8` element type.
- Add support for importing `LinearPackedParamsBase`. This is basically a weight + optional bias, but modeling it requires a `torch.linear_params.create` op + a `!torch.LinearParams` type. This was less painful than I expected, as it has the necessary methods to opaquely unpack itself. I factored things so it should be easy to extend to other custom classes like `ConvPackedParamsBase`.
- Add minimal boilerplate for importing `quantized::*` ops, with `quantized::linear` being a motivating example.
- Add an e2e test with a simple quantized MLP (courtesy of @phoenix-meadowlark).

This is somewhat of an abuse of `!numpy.ndarray` / `tensor`. The proper semantics of the `!torch.qint8` dtype on a Torch tensor are really "check the quantizer object of the tensor for side data (scale/offset, possibly per-channel) that defines the full semantics of the tensor". We don't have any such notion of "side data" for `!numpy.ndarray` / `tensor`. We also have nothing that would key off the dtype to determine whether the side data is present. This will be fixed by a proper `!torch.tensor` type.
2021-05-20 02:40:48 +08:00
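The `torch.per_tensor_affine.create` op in this test takes three operands: an integer representation, a scale, and a zero point. The scale and zero point are the "side data" the commit message refers to. As a rough, non-authoritative illustration of what that side data means, the sketch below applies the usual affine mapping x ≈ scale * (q - zero_point), clamped to the qint8 range [-128, 127]. It also includes a per-channel variant in which each row has its own scale and zero point. The function names are hypothetical helpers, not torch-mlir or PyTorch APIs. PyTorch's own kernels may differ in floating-point precision and rounding details.

```python
from typing import List

# Representable range of a signed 8-bit quantized value (qint8).
QINT8_MIN, QINT8_MAX = -128, 127


def quantize_per_tensor_affine(values: List[float], scale: float,
                               zero_point: int) -> List[int]:
    """Map floats to qint8 codes: q = clamp(round(x / scale) + zero_point)."""
    # Python's round() rounds exact halves to the nearest even integer.
    return [
        max(QINT8_MIN, min(QINT8_MAX, round(x / scale) + zero_point))
        for x in values
    ]


def dequantize_per_tensor_affine(codes: List[int], scale: float,
                                 zero_point: int) -> List[float]:
    """Recover approximate floats: x ~= scale * (q - zero_point)."""
    return [scale * (q - zero_point) for q in codes]


def dequantize_per_channel_affine(rows: List[List[int]], scales: List[float],
                                  zero_points: List[int]) -> List[List[float]]:
    """Hypothetical per-channel variant: row i uses scales[i], zero_points[i]."""
    return [
        dequantize_per_tensor_affine(row, s, zp)
        for row, s, zp in zip(rows, scales, zero_points)
    ]


if __name__ == "__main__":
    codes = quantize_per_tensor_affine([1.0, -3.0, 100.0], scale=0.5,
                                       zero_point=2)
    print(codes)  # [4, -4, 127]  (100.0 saturates at the qint8 maximum)
    print(dequantize_per_tensor_affine(codes, scale=0.5, zero_point=2))
    # [1.0, -3.0, 62.5]
```

The integer codes alone do not determine the real values. Without the scale and zero point, `[4, -4, 127]` could mean anything. This is why a plain `tensor<...x!torch.qint8>` without attached side data only approximates the Torch semantics.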
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing

import torch
import torch_mlir

# RUN: %PYTHON %s | npcomp-opt | FileCheck %s

mb = torch_mlir.ModuleBuilder()


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Two quantized linear layers: one with a bias (the default), one without.
        self.linear = torch.nn.quantized.Linear(5, 2, dtype=torch.qint8)
        self.linear_no_bias = torch.nn.quantized.Linear(6,
                                                        2,
                                                        bias_=False,
                                                        dtype=torch.qint8)

    # The weight is expected to import as a per-tensor affine quantized tensor,
    # packed together with the float bias into a linear-params value.
    # CHECK-DAG: %[[SCALE:.*]] = basicpy.numeric_constant {{.*}} : f64
    # CHECK-DAG: %[[ZERO_POINT:.*]] = basicpy.numeric_constant 0 : i64
    # CHECK-DAG: %[[INT_REPR:.*]] = constant dense<{{.*}}> : tensor<2x5xi8>
    # CHECK-DAG: %[[WEIGHTS:.*]] = torch.per_tensor_affine.create %[[INT_REPR]], %[[SCALE]], %[[ZERO_POINT]] : tensor<2x5xi8>, f64, i64 -> tensor<2x5x!torch.qint8>
    # CHECK-DAG: %[[WEIGHTS_ARRAY:.*]] = numpy.create_array_from_tensor %[[WEIGHTS]] : (tensor<2x5x!torch.qint8>) -> !numpy.ndarray<*:!numpy.any_dtype>
    # CHECK-DAG: %[[BIAS:.*]] = constant dense<{{.*}}> : tensor<2xf32>
    # CHECK-DAG: %[[BIAS_ARRAY:.*]] = numpy.create_array_from_tensor %[[BIAS]] : (tensor<2xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
    # CHECK-DAG: %[[LINEAR_PARAMS:.*]] = torch.linear_params.create %[[WEIGHTS_ARRAY]], %[[BIAS_ARRAY]] : !numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>
    @torch.jit.export
    def test_linear(self, t):
        return self.linear(t)

    # Without a bias, linear_params.create is expected to take only the weight.
    # CHECK: %[[LINEAR_PARAMS_NO_BIAS:.*]] = torch.linear_params.create %{{.*}} : !numpy.ndarray<*:!numpy.any_dtype>{{$}}
    @torch.jit.export
    def test_linear_no_bias(self, t):
        return self.linear_no_bias(t)

test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()
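The test checks `torch.linear_params.create` in two forms: with a weight and a bias for `self.linear`, and with only a weight for `self.linear_no_bias`. The sketch below is a hedged illustration of what such a params value carries and how a `quantized::linear`-style op might consume it. It models the params as a qint8 weight (codes plus per-tensor scale and zero point) and an optional float bias. It then computes a reference result in three steps: dequantize, run the float matmul, and requantize to an output scale and zero point.

`QTensor`, `LinearParams`, and `reference_quantized_linear` are hypothetical names. This is not PyTorch's FBGEMM/QNNPACK kernel, which accumulates in integer arithmetic. For simplicity, activations here also use the qint8 range, whereas PyTorch's quantized linear commonly takes quint8 activations.

```python
from dataclasses import dataclass
from typing import List, Optional

QINT8_MIN, QINT8_MAX = -128, 127


@dataclass
class QTensor:
    """A 2-D qint8 tensor: integer codes plus per-tensor side data."""
    codes: List[List[int]]  # shape [rows][cols]
    scale: float
    zero_point: int

    def dequantize(self) -> List[List[float]]:
        return [[self.scale * (q - self.zero_point) for q in row]
                for row in self.codes]


@dataclass
class LinearParams:
    """Hypothetical model of packed linear params: weight plus optional bias."""
    weight: QTensor              # shape [out_features][in_features]
    bias: Optional[List[float]]  # shape [out_features], or None


def _requantize(x: float, scale: float, zero_point: int) -> int:
    # Same per-tensor affine mapping as quantization: clamp(round(x/s) + zp).
    return max(QINT8_MIN, min(QINT8_MAX, round(x / scale) + zero_point))


def reference_quantized_linear(x: QTensor, params: LinearParams,
                               out_scale: float,
                               out_zero_point: int) -> QTensor:
    """Computes y = x @ W^T (+ b) in float, then requantizes to qint8."""
    xf = x.dequantize()              # [batch][in_features]
    wf = params.weight.dequantize()  # [out_features][in_features]
    out_codes = []
    for x_row in xf:
        out_row = []
        for j, w_row in enumerate(wf):
            acc = sum(a * b for a, b in zip(x_row, w_row))
            if params.bias is not None:
                acc += params.bias[j]
            out_row.append(_requantize(acc, out_scale, out_zero_point))
        out_codes.append(out_row)
    return QTensor(out_codes, out_scale, out_zero_point)


if __name__ == "__main__":
    # Dequantized weight is [[1.0, 0.0], [0.0, 2.0]]; dequantized input is [[2.0, 2.0]].
    w = QTensor([[2, 0], [0, 4]], scale=0.5, zero_point=0)
    x = QTensor([[2, 2]], scale=1.0, zero_point=0)
    with_bias = LinearParams(weight=w, bias=[1.0, 0.5])
    without_bias = LinearParams(weight=w, bias=None)
    print(reference_quantized_linear(x, with_bias, 0.5, 0).codes)     # [[6, 9]]
    print(reference_quantized_linear(x, without_bias, 0.5, 0).codes)  # [[4, 8]]
```

This sketch keeps the bias optional within a single record. That mirrors how the test imports both layers through the same `torch.linear_params.create` op, with the bias operand simply absent in the no-bias case.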