mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] support aten.fake_quantize_per_tensor_affine (#3014)
parent
798bfd7dff
commit
4282eb9e76
|
@ -4366,6 +4366,33 @@ def Torch_AtenAddcdiv_Op : Torch_Op<"aten.addcdiv_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFakeQuantizePerTensorAffineOp : Torch_Op<"aten.fake_quantize_per_tensor_affine", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_FloatType:$scale,
|
||||
Torch_IntType:$zero_point,
|
||||
Torch_IntType:$quant_min,
|
||||
Torch_IntType:$quant_max
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenFakeQuantizePerTensorAffineOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||
}
|
||||
void AtenFakeQuantizePerTensorAffineOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -6306,6 +6306,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %18 = torch.aten.append.t %7, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" return %7 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_tensor_affine\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -9432,6 +9436,40 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n"
|
||||
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() : () -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" return %1 : !torch.bool\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() -> !torch.list<int> {\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.cosh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
|
@ -9461,19 +9499,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n"
|
||||
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() : () -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" return %1 : !torch.bool\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() -> !torch.list<int> {\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.acosh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
|
|
|
@ -7196,6 +7196,57 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenFakeQuantizePerTensorAffineOp
|
||||
: public OpRewritePattern<AtenFakeQuantizePerTensorAffineOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenFakeQuantizePerTensorAffineOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenFakeQuantizePerTensorAffineOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = getContext();
|
||||
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
Value one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||
|
||||
// input/scale
|
||||
Value divScale = rewriter.create<AtenDivScalarOp>(
|
||||
loc, op.getType(), op.getSelf(), op.getScale());
|
||||
// std::nearby_int(input/scale)
|
||||
Value round = rewriter.create<AtenRoundOp>(loc, op.getType(), divScale);
|
||||
// std::nearby_int(input/scale) + zero_point
|
||||
Value addZeroPoint = rewriter.create<AtenAddScalarOp>(
|
||||
loc, op.getType(), round, op.getZeroPoint(), one);
|
||||
// max(quant_min, std::nearby_int(input/scale) + zero_point)
|
||||
Value max = rewriter.create<AtenMaximumOp>(
|
||||
loc, op.getType(), addZeroPoint,
|
||||
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMin(),
|
||||
/*dtype=*/none,
|
||||
/*device=*/none,
|
||||
/*requires_grad=*/falseVal));
|
||||
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
|
||||
Value min = rewriter.create<AtenMinimumOp>(
|
||||
loc, op.getType(), max,
|
||||
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMax(),
|
||||
/*dtype=*/none, /*device=*/none,
|
||||
/*requires_grad=*/falseVal));
|
||||
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
|
||||
// - zero_point
|
||||
Value subZeroPoint = rewriter.create<AtenSubScalarOp>(
|
||||
loc, op.getType(), min, op.getZeroPoint(), one);
|
||||
// (min(quant_max, max(quant_min, std::nearby_int(input/scale) +
|
||||
// zero_point)) - zero_point) * scale
|
||||
Value result = rewriter.create<AtenMulScalarOp>(
|
||||
loc, op.getType(), subZeroPoint, op.getScale());
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -7382,6 +7433,8 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenNormalFunctionalOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFakeQuantizePerTensorAffineOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
|
||||
|
|
|
@ -449,6 +449,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenHardsigmoidOp>();
|
||||
target.addIllegalOp<AtenRelu6Op>();
|
||||
target.addIllegalOp<AtenEluOp>();
|
||||
target.addIllegalOp<AtenFakeQuantizePerTensorAffineOp>();
|
||||
target.addIllegalOp<AtenGluOp>();
|
||||
target.addIllegalOp<AtenSeluOp>();
|
||||
target.addIllegalOp<AtenHardswishOp>();
|
||||
|
|
|
@ -308,6 +308,9 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# Others
|
||||
"GridSamplerBasic1_basic",
|
||||
"GridSamplerBasic2_basic",
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||
}
|
||||
|
||||
TORCHDYNAMO_CRASHING_SET = {
|
||||
|
@ -846,6 +849,8 @@ STABLEHLO_PASS_SET = {
|
|||
"LinspaceModule_basic",
|
||||
"LinspaceOneSizeModule_basic",
|
||||
"LinspaceTwoSizeModule_basic",
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_CRASHING_SET = {
|
||||
|
@ -2120,5 +2125,8 @@ ONNX_XFAIL_SET = {
|
|||
"AtenLinalgCrossDynamic_basic"
|
||||
}
|
||||
|
||||
ONNX_CRASHING_SET = { }
|
||||
ONNX_CRASHING_SET = {
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
}
|
||||
|
||||
|
|
|
@ -89,6 +89,9 @@ def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim
|
|||
|
||||
return diagonal
|
||||
|
||||
def aten〇fake_quantize_per_tensor_affine〡shape(self: List[int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇sin〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -1892,6 +1895,13 @@ def prims〇split_dim〡dtype(a_rank_dtype: Tuple[int, int], dim: int, outer_len
|
|||
_, a_dtype = a_rank_dtype
|
||||
return a_dtype
|
||||
|
||||
# note: fake_quantize_per_tensor_affine doesn't support "meta" device, use "cpu" instead.
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device="cpu", scale=0.1, zero_point=0, quant_min=0, quant_max=255, error_types={torch.complex128, torch.complex64, torch.bfloat16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool}))
|
||||
def aten〇fake_quantize_per_tensor_affine〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert is_float_dtype(self_dtype)
|
||||
assert self_dtype != torch.bfloat16
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
|
|
|
@ -353,6 +353,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
|
||||
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)")
|
||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::mish : (Tensor) -> (Tensor)")
|
||||
|
|
|
@ -4739,5 +4739,52 @@ class GluStaticModule(torch.nn.Module):
|
|||
return torch.ops.aten.glu(x, dim=1)
|
||||
|
||||
@register_test_case(module_factory=lambda: GluStaticModule())
|
||||
def GluStaticModule_basic(module, tu: TestUtils):
|
||||
def GluStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 24, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class FakeQuantizePerTensorAffineModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, 50], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.fake_quantize_per_tensor_affine(x, 0.1, 1, 0, 255)
|
||||
|
||||
@register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineModule())
|
||||
def FakeQuantizePerTensorAffineModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 50))
|
||||
|
||||
class FakeQuantizePerTensorAffineDynamicShapeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.fake_quantize_per_tensor_affine(x, 0.1, 1, 0, 255)
|
||||
|
||||
@register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineDynamicShapeModule())
|
||||
def FakeQuantizePerTensorAffineDynamicShapeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 50))
|
||||
|
||||
class FakeQuantizePerTensorAffineRoundToEvenModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127)
|
||||
|
||||
@register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineRoundToEvenModule())
|
||||
def FakeQuantizePerTensorAffineRoundToEvenModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.FloatTensor([0.5, 1.5, -0.5, -1.5]))
|
||||
|
|
Loading…
Reference in New Issue