[onnx] Add support for `fp8` `onnx.DequantizeLinear` (#3617)

Fp8 needs a slightly different path for dequantization as the `torch`
dequantize operation does not support `fp8` types.
pull/3622/head
Rob Suderman 2024-08-08 16:20:53 -07:00 committed by GitHub
parent 880e64bbbb
commit 8358e8c255
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 66 additions and 18 deletions

View File

@ -2117,41 +2117,73 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType))
return failure();
auto loc = binder.getLoc();
Value operand = operands[0];
Value scale = operands[1];
Value zeropoint = operands[2];
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
auto operandETy = operandTy.getDtype();
auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
if (!scaleTy || !scaleTy.hasSizes())
return rewriter.notifyMatchFailure(binder.op, "requires known rank");
if (!resultType.hasDtype())
return rewriter.notifyMatchFailure(binder.op,
"requires known result dtype");
if (scaleTy.getSizes().size() == 0 ||
(scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
if (!qTensorTy) {
return rewriter.notifyMatchFailure(binder.op,
"unsupported result dtype");
}
scale = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
zeropoint = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
bool rank0 = scaleTy.getSizes().size() == 0;
bool length1 =
scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
auto quantize =
rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), qTensorTy, operand, scale, zeropoint);
rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
binder.op, resultType, quantize);
if (!rank0 && !length1)
return rewriter.notifyMatchFailure(binder.op,
"unimplemented: non-scalar scale");
auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
if (!qTensorTy) {
return rewriter.notifyMatchFailure(binder.op,
"unsupported result dtype");
}
scale = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::FloatType>(), scale);
bool fpOperand = isa<mlir::FloatType>(operandETy);
Type zeropointTy = rewriter.getType<Torch::IntType>();
if (fpOperand)
zeropointTy = rewriter.getType<Torch::FloatType>();
zeropoint =
rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
if (fpOperand) {
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
auto tyVal = Torch::getScalarTypeForType(resultType.getDtype());
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
static_cast<int64_t>(tyVal)));
Value toDtype = rewriter.create<Torch::AtenToDtypeOp>(
loc, resultType, operand, tyConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
Value one = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(1.0));
Value sub = rewriter.create<Torch::AtenSubScalarOp>(
loc, resultType, toDtype, zeropoint, one);
rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
binder.op, resultType, sub, scale);
return success();
}
return rewriter.notifyMatchFailure(binder.op,
"unimplemented: non-scalar scale");
auto quantize =
rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
loc, qTensorTy, operand, scale, zeropoint);
rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
binder.op, resultType, quantize);
return success();
});
patterns.onOp("Div", 7,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -800,6 +800,22 @@ func.func @test_dequantizelinear_i32(%arg0: !torch.vtensor<[6],si32>, %arg1: !to
// -----
// CHECK-LABEL: @test_dequantizelinear_fp8
func.func @test_dequantizelinear_fp8(%arg0: !torch.vtensor<[6],f8E4M3FN>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f8E4M3FN>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
// CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f8E4M3FN> -> !torch.float
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[DTY:.+]] = torch.constant.int 6
// CHECK: %[[TO:.+]] = torch.aten.to.dtype %arg0, %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]]
// CHECK: %[[ONE:.+]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[TO]], %[[ZP]], %[[ONE]]
// CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[SCALE]]
%0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f8E4M3FN>, !torch.vtensor<[],f32>, !torch.vtensor<[],f8E4M3FN>) -> !torch.vtensor<[6],f32>
return %0 : !torch.vtensor<[6],f32>
}
// -----
// CHECK-LABEL: @test_div_bcast
func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {