# Mirrored from https://github.com/llvm/torch-mlir (upstream file: Python, 59 lines, 1.6 KiB).
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
import torch
|
|
from torch import nn
|
|
|
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
class QuantizedMLP(nn.Module):
    """Small float MLP wrapped in quant/dequant stubs for post-training quantization.

    The network is ``Linear(16, 8) -> Tanh -> Linear(8, 4)``. ``quantize`` and
    ``dequantize`` are ``QuantStub`` / ``DeQuantStub`` placeholders. Eager-mode
    ``torch.quantization.convert`` swaps them for real quantize/dequantize
    modules once the model has been prepared and calibrated.
    """

    def __init__(self):
        super().__init__()
        # Seed the global RNG before the Linear layers are created, so every
        # construction produces the same initial weights (deterministic test).
        torch.random.manual_seed(0)
        self.layers = nn.Sequential(
            nn.Linear(16, 8),
            nn.Tanh(),
            nn.Linear(8, 4),
        )
        self.quantize = torch.quantization.QuantStub()
        self.dequantize = torch.quantization.DeQuantStub()

    @export
    @annotate_args([
        None,
        ([1, 16], torch.float32, True),
    ])
    def forward(self, x):
        """Apply quantize -> MLP -> dequantize to ``x``.

        ``annotate_args`` describes ``x`` as a float32 tensor of shape [1, 16].
        The leading ``None`` entry lines up with ``self``. The trailing ``True``
        flag is interpreted by ``annotate_args`` itself.
        """
        x = self.quantize(x)
        x = self.layers(x)
        x = self.dequantize(x)
        return x
|
|
|
|
|
|
def get_mlp_input():
    """Return a random tensor of shape (1, 16) with entries in [-1, 1).

    Samples come from the global torch RNG, so the result depends on the
    current seed. The dtype is torch's default floating dtype.
    """
    unit_uniform = torch.rand((1, 16))  # samples drawn from [0, 1)
    # Affine map [0, 1) -> [-1, 1).
    return unit_uniform * 2 - 1
|
|
|
|
|
|
def get_quantized_mlp():
    """Build a ``QuantizedMLP`` and apply eager-mode post-training static quantization.

    Steps:
      1. Put the model in eval mode and attach ``default_qconfig``.
      2. Insert observers with ``prepare``.
      3. Calibrate the observers on 32 random inputs.
      4. Swap in quantized modules with ``convert``.

    Both ``prepare`` and ``convert`` modify ``model`` in place.

    Returns:
        The converted (quantized) ``QuantizedMLP`` instance.
    """
    model = QuantizedMLP()
    model.eval()
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)
    # Reseed so the calibration inputs, and therefore the observed ranges,
    # are reproducible.
    torch.manual_seed(0)
    # Calibration only feeds observer statistics; the outputs are discarded,
    # so there is no reason to record an autograd graph for each pass.
    with torch.no_grad():
        for _ in range(32):
            model(get_mlp_input())
    torch.quantization.convert(model, inplace=True)
    return model
|
|
|
|
|
|
@register_test_case(module_factory=get_quantized_mlp)
def QuantizedMLP_basic(module, tu: TestUtils):
    """E2E case: run ``forward`` once on a random (1, 16) input.

    ``tu`` is unused here. The input comes from ``get_mlp_input``, which draws
    from the global torch RNG.
    """
    sample = get_mlp_input()
    module.forward(sample)
|