mirror of https://github.com/llvm/torch-mlir
43 lines
2.7 KiB
MLIR
43 lines
2.7 KiB
MLIR
|
// RUN: torch-mlir-opt --split-input-file --torch-match-quantized-custom-ops %s | FileCheck %s
|
||
|
|
||
|
// CHECK-LABEL: func.func @quantize_per_tensor
|
||
|
func.func @quantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],f32>) -> !torch.vtensor<[1,3,8,8],si8> {
|
||
|
%float = torch.constant.float 0.5
|
||
|
%zp = torch.constant.int 17
|
||
|
%min = torch.constant.int -128
|
||
|
%max = torch.constant.int 127
|
||
|
%dtype = torch.constant.int 1
|
||
|
|
||
|
// CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
|
||
|
// CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
|
||
|
// CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
|
||
|
// CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
|
||
|
// CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 1
|
||
|
// CHECK-DAG: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[DTYPE]] : !torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
|
||
|
// CHECK-DAG: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],si8>
|
||
|
// CHECK: torch.aten.clamp %[[REPR]], %[[MIN]], %[[MAX]]
|
||
|
%0 = torch.operator "torch.quantized_decomposed.quantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],si8>
|
||
|
return %0 : !torch.vtensor<[1,3,8,8],si8>
|
||
|
}
|
||
|
|
||
|
// -----
|
||
|
|
||
|
// CHECK-LABEL: func.func @dequantize_per_tensor
|
||
|
func.func @dequantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch.vtensor<[1,3,8,8],f32> {
|
||
|
%float = torch.constant.float 0.5
|
||
|
%zp = torch.constant.int 17
|
||
|
%min = torch.constant.int -128
|
||
|
%max = torch.constant.int 127
|
||
|
%dtype = torch.constant.int 1
|
||
|
|
||
|
// CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
|
||
|
// CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
|
||
|
// CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
|
||
|
// CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
|
||
|
// CHECK-DAG: %[[CLAMP:.+]] = torch.aten.clamp %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[1,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],si8>
|
||
|
// CHECK-DAG: %[[QINT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CLAMP]], %[[SCALE]], %[[ZP]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
|
||
|
// CHECK: %[[DEQUANT:.+]] = torch.aten.dequantize.tensor %[[QINT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
|
||
|
%13 = torch.operator "torch.quantized_decomposed.dequantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],f32>
|
||
|
return %13 : !torch.vtensor<[1,3,8,8],f32>
|
||
|
}
|