mirror of https://github.com/llvm/torch-mlir
Register fake_quantize_cachemask ops and add their decompose patterns (#3556)
Test: `cmake --build build --target check-torch-mlir-all`pull/3559/head
parent
21ad890009
commit
d1e172f418
|
@ -4650,6 +4650,35 @@ def Torch_AtenFakeQuantizePerTensorAffineTensorQparamsOp : Torch_Op<"aten.fake_q
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp : Torch_Op<"aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams : (Tensor, Tensor, Tensor, Tensor, int, int) -> (Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$scale,
|
||||
AnyTorchTensorType:$zero_point,
|
||||
AnyTorchTensorType:$fake_quant_enabled,
|
||||
Torch_IntType:$quant_min,
|
||||
Torch_IntType:$quant_max
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$output,
|
||||
AnyTorchOptionalTensorType:$mask
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 2);
|
||||
}
|
||||
void Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 2);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFakeQuantizePerChannelAffineOp : Torch_Op<"aten.fake_quantize_per_channel_affine", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -4678,6 +4707,35 @@ def Torch_AtenFakeQuantizePerChannelAffineOp : Torch_Op<"aten.fake_quantize_per_
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFakeQuantizePerChannelAffineCachemaskOp : Torch_Op<"aten.fake_quantize_per_channel_affine_cachemask", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::fake_quantize_per_channel_affine_cachemask : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$scale,
|
||||
AnyTorchTensorType:$zero_point,
|
||||
Torch_IntType:$axis,
|
||||
Torch_IntType:$quant_min,
|
||||
Torch_IntType:$quant_max
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$output,
|
||||
AnyTorchOptionalTensorType:$mask
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenFakeQuantizePerChannelAffineCachemaskOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 2);
|
||||
}
|
||||
void AtenFakeQuantizePerChannelAffineCachemaskOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 2);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -6367,10 +6367,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_channel_affine\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_channel_affine_cachemask\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10735,6 +10747,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple<int, int> {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
|
||||
" %4 = torch.prim.TupleConstruct %3, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %4 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_channel_affine\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.int {\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %none = torch.constant.none\n"
|
||||
|
@ -10756,6 +10793,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_channel_affine_cachemask\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple<int, int> {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
|
||||
" %4 = torch.prim.TupleConstruct %3, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %4 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.cosh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
|
|
|
@ -9029,6 +9029,61 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams
|
||||
// into aten.fake_quantize_per_tensor_affine.tensor_qparams
|
||||
// when the second result is unused.
|
||||
class DecomposeAten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp
|
||||
: public OpRewritePattern<
|
||||
Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp> {
|
||||
public:
|
||||
using OpRewritePattern<
|
||||
Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp>::
|
||||
OpRewritePattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op->getResult(1).use_empty())
|
||||
return failure();
|
||||
|
||||
auto newOp =
|
||||
rewriter.create<AtenFakeQuantizePerTensorAffineTensorQparamsOp>(
|
||||
op.getLoc(), op->getResult(0).getType(), op.getSelf(),
|
||||
op.getScale(), op.getZeroPoint(), op.getQuantMin(),
|
||||
op.getQuantMax());
|
||||
|
||||
rewriter.replaceAllUsesWith(op->getResult(0), newOp);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose aten.fake_quantize_per_channel_affine_cachemask
|
||||
// into aten.fake_quantize_per_channel_affine
|
||||
// when the second result is unused.
|
||||
class DecomposeAtenFakeQuantizePerChannelAffineCachemaskOp
|
||||
: public OpRewritePattern<AtenFakeQuantizePerChannelAffineCachemaskOp> {
|
||||
public:
|
||||
using OpRewritePattern<
|
||||
AtenFakeQuantizePerChannelAffineCachemaskOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenFakeQuantizePerChannelAffineCachemaskOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op->getResult(1).use_empty())
|
||||
return failure();
|
||||
|
||||
auto newOp = rewriter.create<AtenFakeQuantizePerChannelAffineOp>(
|
||||
op.getLoc(), op->getResult(0).getType(), op.getSelf(), op.getScale(),
|
||||
op.getZeroPoint(), op.getAxis(), op.getQuantMin(), op.getQuantMax());
|
||||
|
||||
rewriter.replaceAllUsesWith(op->getResult(0), newOp);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose aten.fmax/fmin to aten.maximum/minimum + aten.where(nanMask)
|
||||
template <typename AtenFOpT, typename AtenOpT>
|
||||
|
@ -9306,6 +9361,11 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgSlogdetOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenFakeQuantizePerChannelAffineCachemaskOp>(patterns);
|
||||
// More specific conv ops
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTbcOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConv1dOp>(patterns);
|
||||
|
|
|
@ -138,9 +138,15 @@ def aten〇fake_quantize_per_tensor_affine_cachemask〡shape(self: List[int], sc
|
|||
def aten〇fake_quantize_per_tensor_affine〇tensor_qparams〡shape(self: List[int], scale: List[int], zero_point: List[int], quant_min: int, quant_max: int) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇_fake_quantize_per_tensor_affine_cachemask_tensor_qparams〡shape(self: List[int], scale: List[int], zero_point: List[int], fake_quant_enabled: List[int], quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]:
|
||||
return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self))
|
||||
|
||||
def aten〇fake_quantize_per_channel_affine〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int, quant_min: int, quant_max: int) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇fake_quantize_per_channel_affine_cachemask〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int, quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]:
|
||||
return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self))
|
||||
|
||||
def aten〇sin〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -2372,6 +2378,14 @@ def aten〇fake_quantize_per_tensor_affine〇tensor_qparams〡dtype(self_rank_dt
|
|||
assert self_dtype != torch.bfloat16
|
||||
return self_dtype
|
||||
|
||||
# note: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams doesn't support "meta" device, use "cpu" instead.
|
||||
@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.int32, device="cpu"), TensorOfShape(1, dtype=torch.int32, device="cpu"), 0, 255) for dtype in [torch.float64, torch.float32, torch.float16])
|
||||
def aten〇_fake_quantize_per_tensor_affine_cachemask_tensor_qparams〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], fake_quant_enabled_rank_dtype: Tuple[int, int], quant_min: int, quant_max: int) -> Tuple[int, int]:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert is_float_dtype(self_dtype)
|
||||
assert self_dtype != torch.bfloat16
|
||||
return (self_rank_dtype[1], torch.bool)
|
||||
|
||||
# note: fake_quantize_per_channel_affine doesn't support "meta" device, use "cpu" instead.
|
||||
@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(3, dtype=torch.float32, device="cpu"), TensorOfShape(3, dtype=torch.int32, device="cpu"), 0, 0, 255) for dtype in [torch.float64, torch.float32, torch.float16])
|
||||
def aten〇fake_quantize_per_channel_affine〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int, quant_min: int, quant_max: int) -> int:
|
||||
|
@ -2380,6 +2394,14 @@ def aten〇fake_quantize_per_channel_affine〡dtype(self_rank_dtype: Tuple[int,
|
|||
assert self_dtype != torch.bfloat16
|
||||
return self_dtype
|
||||
|
||||
# note: fake_quantize_per_channel_affine_cachemask doesn't support "meta" device, use "cpu" instead.
|
||||
@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(3, dtype=torch.float32, device="cpu"), TensorOfShape(3, dtype=torch.int32, device="cpu"), 0, 0, 255) for dtype in [torch.float64, torch.float32, torch.float16])
|
||||
def aten〇fake_quantize_per_channel_affine_cachemask〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int, quant_min: int, quant_max: int) -> Tuple[int, int]:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert is_float_dtype(self_dtype)
|
||||
assert self_dtype != torch.bfloat16
|
||||
return (self_rank_dtype[1], torch.bool)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -464,9 +464,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::fake_quantize_per_tensor_affine.tensor_qparams : (Tensor, Tensor, Tensor, int, int) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams : (Tensor, Tensor, Tensor, Tensor, int, int) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::fake_quantize_per_channel_affine : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::fake_quantize_per_channel_affine_cachemask : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor, Tensor)"
|
||||
)
|
||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::fmax : (Tensor, Tensor) -> (Tensor)")
|
||||
|
|
|
@ -97,3 +97,37 @@ func.func @torch.aten.one_hot$fold(%arg0: !torch.vtensor<[3],si64>, %arg1: !torc
|
|||
%0 = torch.aten.one_hot %arg0, %arg1 : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si64>
|
||||
return %0 : !torch.vtensor<[3,?],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[1],f32>,
|
||||
// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[1],si32>, %[[ARG_3:.*]]: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
// CHECK: %[[CONST1:.*]] = torch.constant.int 127
|
||||
// CHECK: %[[CONST2:.*]] = torch.constant.int -128
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[CONST2]], %[[CONST1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],si32>, %arg3: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
%int127 = torch.constant.int 127
|
||||
%int-128 = torch.constant.int -128
|
||||
%0:2 = torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams %arg0, %arg1, %arg2, %arg3, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
|
||||
return %0#0 : !torch.vtensor<[?,?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?],f32>,
|
||||
// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
// CHECK: %[[CONST0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[CONST1:.*]] = torch.constant.int 127
|
||||
// CHECK: %[[CONST2:.*]] = torch.constant.int -128
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.fake_quantize_per_channel_affine %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[CONST0]], %[[CONST2]], %[[CONST1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>, %arg2: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int127 = torch.constant.int 127
|
||||
%int-128 = torch.constant.int -128
|
||||
%0:2 = torch.aten.fake_quantize_per_channel_affine_cachemask %arg0, %arg1, %arg2, %int0, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
|
||||
return %0#0 : !torch.vtensor<[?,?,?,?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue