torch-mlir/test/Dialect/Torch/match-quantized-customs-ops...

43 lines
2.7 KiB
MLIR

// RUN: torch-mlir-opt --split-input-file --torch-match-quantized-custom-ops %s | FileCheck %s
// CHECK-LABEL: func.func @quantize_per_tensor
func.func @quantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],f32>) -> !torch.vtensor<[1,3,8,8],si8> {
%float = torch.constant.float 0.5
%zp = torch.constant.int 17
%min = torch.constant.int -128
%max = torch.constant.int 127
%dtype = torch.constant.int 1
// CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
// CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
// CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
// CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 1
// CHECK-DAG: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[DTYPE]] : !torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
// CHECK-DAG: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],si8>
// CHECK: torch.aten.clamp %[[REPR]], %[[MIN]], %[[MAX]]
%0 = torch.operator "torch.quantized_decomposed.quantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],si8>
return %0 : !torch.vtensor<[1,3,8,8],si8>
}
// -----
// CHECK-LABEL: func.func @dequantize_per_tensor
func.func @dequantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch.vtensor<[1,3,8,8],f32> {
%float = torch.constant.float 0.5
%zp = torch.constant.int 17
%min = torch.constant.int -128
%max = torch.constant.int 127
%dtype = torch.constant.int 1
// CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
// CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
// CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
// CHECK-DAG: %[[CLAMP:.+]] = torch.aten.clamp %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[1,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],si8>
// CHECK-DAG: %[[QINT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CLAMP]], %[[SCALE]], %[[ZP]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
// CHECK: %[[DEQUANT:.+]] = torch.aten.dequantize.tensor %[[QINT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
%13 = torch.operator "torch.quantized_decomposed.dequantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],f32>
return %13 : !torch.vtensor<[1,3,8,8],f32>
}