mirror of https://github.com/llvm/torch-mlir
[onnx] Support `fp8` for `onnx.QuantizeLinear` (#3619)
We need to decompose `onnx.QuantizeLinear` directly for `fp8` types, as the equivalent torch quantization ops (`aten.quantize_per_tensor` / `aten.int_repr`) do not support fp8 dtypes.
parent 8358e8c255
commit 44266ab0c4
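For context on the commit message: torch's quantize machinery only accepts the dedicated quantized integer dtypes, so the existing `aten.quantize_per_tensor` + `aten.int_repr` route cannot express an fp8 result. A minimal Python sketch of the distinction (assuming a PyTorch build with `float8_e4m3fn` support; illustrative only, not part of this patch):

    import torch

    x = torch.tensor([1.0, 2.5, 7.3])

    # Integer path: torch's quantize machinery handles this directly.
    q = torch.quantize_per_tensor(x, scale=0.5, zero_point=0, dtype=torch.qint8)
    print(q.int_repr())  # raw int8 payload, analogous to aten.int_repr

    # fp8 path: quantize_per_tensor rejects float8_e4m3fn (it is not a
    # quantized dtype), so the lowering below decomposes the op by hand:
    y = (x / 0.5 + 0.0).to(torch.float8_e4m3fn)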
@@ -214,6 +214,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.tensorResultType(resultType))
           return failure();

+        auto loc = binder.getLoc();
         Value operand = operands[0];
         Value scale = operands[1];
         Value zeropoint = operands[2];
@@ -225,33 +226,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           return rewriter.notifyMatchFailure(binder.op,
                                              "requires known result dtype");

-        if (scaleTy.getSizes().size() == 0) {
-          auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
-          if (!qTensorTy) {
-            return rewriter.notifyMatchFailure(binder.op,
-                                               "unsupported result dtype");
-          }
-
-          auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype());
-
-          Value tyConst = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                      static_cast<int64_t>(torchqTy)));
-
-          scale = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
-          zeropoint = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
-
-          auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
-              binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst);
-          rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(
-              binder.op, resultType, quantize);
+        auto resultETy = resultType.getDtype();
+
+        bool rank0 = scaleTy.getSizes().size() == 0;
+        bool length1 =
+            scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
+
+        if (!rank0 && !length1)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unimplemented: non-scalar scale");
+
+        auto qTensorTy = getQTorchTypeFromTorchIntType(resultType);
+        if (!qTensorTy) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unsupported result dtype");
+        }
+
+        auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype());
+
+        Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    static_cast<int64_t>(torchqTy)));
+
+        scale = rewriter.create<Torch::AtenItemOp>(
+            loc, rewriter.getType<Torch::FloatType>(), scale);
+
+        bool fpResult = isa<mlir::FloatType>(resultETy);
+        Type zeropointTy = rewriter.getType<Torch::IntType>();
+        if (fpResult)
+          zeropointTy = rewriter.getType<Torch::FloatType>();
+        zeropoint =
+            rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
+
+        if (fpResult) {
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+          Value one = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(1.0));
+          Value div = rewriter.create<Torch::AtenDivScalarOp>(
+              loc, operand.getType(), operand, scale);
+          Value add = rewriter.create<Torch::AtenAddScalarOp>(
+              loc, operand.getType(), div, zeropoint, one);
+
+          rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+              binder.op, resultType, add, tyConst,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
           return success();
         }

-        return failure();
+        auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            loc, qTensorTy, operand, scale, zeropoint, tyConst);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
+                                                          quantize);
+        return success();
       });
   patterns.onOp(
       "QLinearConv", 1,
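In effect, the new `fpResult` branch replaces `aten.quantize_per_tensor` with a hand-rolled `y = (x / scale + zero_point).to(dtype)`, and itemizes the zero point as a torch float rather than an int when the result element type is floating point. A rough Python restatement of what the emitted div/add/to.dtype chain computes (a sketch, assuming the plain fp8 cast approximates the op's rounding behavior; `quantize_linear_fp8` is an illustrative name, not an API from this patch):

    import torch

    def quantize_linear_fp8(x: torch.Tensor, scale: float,
                            zero_point: float) -> torch.Tensor:
        # Mirrors AtenDivScalarOp -> AtenAddScalarOp -> AtenToDtypeOp above.
        return (x / scale + zero_point).to(torch.float8_e4m3fn)

The next hunk is from the corresponding FileCheck lit test for this lowering.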
@@ -47,6 +47,23 @@ func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch

 // -----

+// CHECK-LABEL: @test_quantizelinear_f8
+func.func @test_quantizelinear_f8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[6],f8E4M3FN> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[DTYPE:.+]] = torch.constant.int 24
+  // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[DIV:.+]] = torch.aten.div.Scalar %arg0, %[[SCALE]]
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[DIV]], %[[ZP]], %[[ONE]]
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[ADD]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]]
+  %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[6],f8E4M3FN>
+  return %0 : !torch.vtensor<[6],f8E4M3FN>
+}
+
+// -----
+
 // CHECK-LABEL: @test_qlinearconv_nobias
 func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
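A reading aid for the CHECK lines above: the `torch.constant.int 24` captured as `%[[DTYPE]]` is torch's ScalarType code for `float8_e4m3fn` (23 would be `float8_e5m2`), and it feeds `torch.aten.to.dtype` as the target dtype; note that no `torch.aten.quantize_per_tensor` appears anywhere in the fp8 path.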