mirror of https://github.com/llvm/torch-mlir
Register fake_quantize related ops (#3522)
Register `aten.fake_quantize_per_channel_affine` and `aten.fake_quantize_per_tensor_affine.tensor_qparams` ops --------- Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>pull/3525/head
parent
0fe74845da
commit
d466d5b809
|
@ -4623,6 +4623,61 @@ def Torch_AtenFakeQuantizePerTensorAffineCachemaskOp : Torch_Op<"aten.fake_quant
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFakeQuantizePerTensorAffineTensorQparamsOp : Torch_Op<"aten.fake_quantize_per_tensor_affine.tensor_qparams", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::fake_quantize_per_tensor_affine.tensor_qparams : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$scale,
|
||||
AnyTorchTensorType:$zero_point,
|
||||
Torch_IntType:$quant_min,
|
||||
Torch_IntType:$quant_max
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenFakeQuantizePerTensorAffineTensorQparamsOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||
}
|
||||
void AtenFakeQuantizePerTensorAffineTensorQparamsOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFakeQuantizePerChannelAffineOp : Torch_Op<"aten.fake_quantize_per_channel_affine", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::fake_quantize_per_channel_affine : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$scale,
|
||||
AnyTorchTensorType:$zero_point,
|
||||
Torch_IntType:$axis,
|
||||
Torch_IntType:$quant_min,
|
||||
Torch_IntType:$quant_max
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenFakeQuantizePerChannelAffineOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenFakeQuantizePerChannelAffineOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -6363,6 +6363,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_tensor_affine.tensor_qparams\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_channel_affine\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10551,6 +10559,48 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = torch.prim.TupleConstruct %3, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %4 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine.tensor_qparams\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_channel_affine\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.int {\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.cosh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
|
|
|
@ -135,6 +135,12 @@ def aten〇fake_quantize_per_tensor_affine〡shape(self: List[int], scale: float
|
|||
def aten〇fake_quantize_per_tensor_affine_cachemask〡shape(self: List[int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]:
|
||||
return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self))
|
||||
|
||||
def aten〇fake_quantize_per_tensor_affine〇tensor_qparams〡shape(self: List[int], scale: List[int], zero_point: List[int], quant_min: int, quant_max: int) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇fake_quantize_per_channel_affine〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int, quant_min: int, quant_max: int) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇sin〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -2301,6 +2307,22 @@ def aten〇fake_quantize_per_tensor_affine_cachemask〡dtype(self_rank_dtype: Tu
|
|||
assert self_dtype != torch.bfloat16
|
||||
return (self_rank_dtype[1], torch.bool)
|
||||
|
||||
# note: fake_quantize_per_tensor_affine.tensor_qparams doesn't support "meta" device, use "cpu" instead.
|
||||
@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.int32, device="cpu"), 0, 255) for dtype in [torch.float64, torch.float32, torch.float16])
|
||||
def aten〇fake_quantize_per_tensor_affine〇tensor_qparams〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], quant_min: int, quant_max: int) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert is_float_dtype(self_dtype)
|
||||
assert self_dtype != torch.bfloat16
|
||||
return self_dtype
|
||||
|
||||
# note: fake_quantize_per_channel_affine doesn't support "meta" device, use "cpu" instead.
|
||||
@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(3, dtype=torch.float32, device="cpu"), TensorOfShape(3, dtype=torch.int32, device="cpu"), 0, 0, 255) for dtype in [torch.float64, torch.float32, torch.float16])
|
||||
def aten〇fake_quantize_per_channel_affine〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int, quant_min: int, quant_max: int) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert is_float_dtype(self_dtype)
|
||||
assert self_dtype != torch.bfloat16
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -461,6 +461,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::fake_quantize_per_tensor_affine_cachemask : (Tensor, float, int, int, int) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::fake_quantize_per_tensor_affine.tensor_qparams : (Tensor, Tensor, Tensor, int, int) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::fake_quantize_per_channel_affine : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor)"
|
||||
)
|
||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::fmax : (Tensor, Tensor) -> (Tensor)")
|
||||
|
|
|
@ -188,5 +188,20 @@ func.func @torch.permute$negative_index_valid (%arg0: !torch.vtensor<[1,2,3],f32
|
|||
%int1 = torch.constant.int 1
|
||||
%perm = torch.prim.ListConstruct %int0, %int1, %intm1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
|
||||
return %3 : !torch.vtensor<[1,2,3],f32>
|
||||
return %3 : !torch.vtensor<[1,2,3],f32>
|
||||
}
|
||||
|
||||
// Check fake quantize ops
|
||||
func.func @torch.aten.fake_quantize_per_channel_affine (%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],si32>) -> !torch.vtensor<[3,3],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int255 = torch.constant.int 255
|
||||
%1 = torch.aten.fake_quantize_per_channel_affine %arg0, %arg1, %arg2, %int0, %int0, %int255 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
|
||||
return %1 : !torch.vtensor<[3,3],f32>
|
||||
}
|
||||
|
||||
func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],si32>) -> !torch.vtensor<[3,3],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int255 = torch.constant.int 255
|
||||
%1 = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %arg0, %arg1, %arg2, %int0, %int255 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
|
||||
return %1 : !torch.vtensor<[3,3],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue