From 3db2197ef57a379558fad8c9195e3440bcf82393 Mon Sep 17 00:00:00 2001
From: Siavash Nazari
Date: Mon, 6 Jun 2022 21:00:27 -0400
Subject: [PATCH] Add sample model with block_quantize nodes from qtorch

- This example script fails in the torch_mlir.compile() API
- The model compiles fine with no block_quantize in SimpleModel.forward()
- Applying block_quantize only to the input tensors compiles fine
---
 examples/block_quantize_experiment.py | 38 ++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)
 create mode 100644 examples/block_quantize_experiment.py

diff --git a/examples/block_quantize_experiment.py b/examples/block_quantize_experiment.py
new file mode 100644
index 000000000..151a5d8bd
--- /dev/null
+++ b/examples/block_quantize_experiment.py
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+
+from qtorch.quant import block_quantize
+
+import torch_mlir
+
+
+class SimpleModel(nn.Module):
+    """A linear layer followed by ReLU, with a block_quantize node in between."""
+
+    def __init__(self, input_dim, output_size):
+        super(SimpleModel, self).__init__()
+        self.matmul = nn.Linear(input_dim, output_size)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        matmul_out = self.matmul(x.flatten(1))
+        # Quantize the linear output to an 8-bit block floating point format.
+        # This is the node that makes torch_mlir.compile() fail below.
+        quantized_matmul_out = block_quantize(matmul_out, wl=8, dim=0, rounding="nearest")
+        return self.relu(quantized_matmul_out)
+
+
+batches = 5
+input_dim = 64
+output_size = 4
+inputs = torch.randn(batches, input_dim)
+model = SimpleModel(input_dim, output_size)
+print("Forward pass results on the inputs:\n", model(inputs))
+
+# Quantizing only the inputs, with no block_quantize in forward(), compiles fine:
+# quantized_inputs = block_quantize(inputs, wl=8, dim=0, rounding="nearest")
+# print("Forward pass results on the quantized inputs:\n", model(quantized_inputs))
+
+# This call fails because of the block_quantize node in SimpleModel.forward():
+module = torch_mlir.compile(model, inputs, output_type=torch_mlir.OutputType.TOSA)
+print("Module compiled to TOSA:\n", module)
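
Note for reviewers: the last bullet in the commit message refers to the variant
where block_quantize is applied only to the input tensors, outside the model.
Below is a minimal sketch of that working configuration for comparison. The
name PlainModel is hypothetical, and this assumes the same qtorch and
torch_mlir APIs used in the patch; per the commit message, this variant is
expected to lower to TOSA without error:

import torch
import torch.nn as nn
from qtorch.quant import block_quantize
import torch_mlir

class PlainModel(nn.Module):
    # Same as SimpleModel above, but without the block_quantize node
    # in forward(), so the traced graph contains no qtorch op.
    def __init__(self, input_dim, output_size):
        super().__init__()
        self.matmul = nn.Linear(input_dim, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.matmul(x.flatten(1)))

inputs = torch.randn(5, 64)
# Quantize only the inputs, before the model ever sees them.
quantized_inputs = block_quantize(inputs, wl=8, dim=0, rounding="nearest")
model = PlainModel(64, 4)
module = torch_mlir.compile(model, quantized_inputs, output_type=torch_mlir.OutputType.TOSA)
print(module)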