diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 0dd6620a4..3507bafb1 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -2117,41 +2117,73 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType))
           return failure();
 
+        auto loc = binder.getLoc();
         Value operand = operands[0];
         Value scale = operands[1];
         Value zeropoint = operands[2];
 
         auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
+        auto operandETy = operandTy.getDtype();
 
         auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
         if (!scaleTy || !scaleTy.hasSizes())
           return rewriter.notifyMatchFailure(binder.op, "requires known rank");
         if (!resultType.hasDtype())
           return rewriter.notifyMatchFailure(binder.op,
                                              "requires known result dtype");
 
-        if (scaleTy.getSizes().size() == 0 ||
-            (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
-          auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
-          if (!qTensorTy) {
-            return rewriter.notifyMatchFailure(binder.op,
-                                               "unsupported result dtype");
-          }
-          scale = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
-          zeropoint = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
+        bool rank0 = scaleTy.getSizes().size() == 0;
+        bool length1 =
+            scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
 
-          auto quantize =
-              rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-                  binder.getLoc(), qTensorTy, operand, scale, zeropoint);
-          rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
-              binder.op, resultType, quantize);
+        if (!rank0 && !length1)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unimplemented: non-scalar scale");
+
+        auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
+        if (!qTensorTy) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unsupported result dtype");
+        }
+
+        scale = rewriter.create<Torch::AtenItemOp>(
+            loc, rewriter.getType<Torch::FloatType>(), scale);
+
+        bool fpOperand = isa<mlir::FloatType>(operandETy);
+        Type zeropointTy = rewriter.getType<Torch::IntType>();
+        if (fpOperand)
+          zeropointTy = rewriter.getType<Torch::FloatType>();
+
+        zeropoint =
+            rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
+
+        if (fpOperand) {
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+          auto tyVal = Torch::getScalarTypeForType(resultType.getDtype());
+          Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                      static_cast<int64_t>(tyVal)));
+          Value toDtype = rewriter.create<Torch::AtenToDtypeOp>(
+              loc, resultType, operand, tyConst,
+              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+              /*memory_format=*/none);
+
+          Value one = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(1.0));
+          Value sub = rewriter.create<Torch::AtenSubScalarOp>(
+              loc, resultType, toDtype, zeropoint, one);
+          rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
+              binder.op, resultType, sub, scale);
           return success();
         }
-        return rewriter.notifyMatchFailure(binder.op,
-                                           "unimplemented: non-scalar scale");
+
+        auto quantize =
+            rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+                loc, qTensorTy, operand, scale, zeropoint);
+        rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
+            binder.op, resultType, quantize);
+        return success();
       });
   patterns.onOp("Div", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index d143e1832..2c70d6730 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -800,6 +800,22 @@ func.func @test_dequantizelinear_i32(%arg0: !torch.vtensor<[6],si32>, %arg1: !to
 
 // -----
 
+// CHECK-LABEL: @test_dequantizelinear_fp8
+func.func @test_dequantizelinear_fp8(%arg0: !torch.vtensor<[6],f8E4M3FN>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f8E4M3FN>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f8E4M3FN> -> !torch.float
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[DTY:.+]] = torch.constant.int 6
+  // CHECK: %[[TO:.+]] = torch.aten.to.dtype %arg0, %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]]
+  // CHECK: %[[ONE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[TO]], %[[ZP]], %[[ONE]]
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[SCALE]]
+  %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f8E4M3FN>, !torch.vtensor<[],f32>, !torch.vtensor<[],f8E4M3FN>) -> !torch.vtensor<[6],f32>
+  return %0 : !torch.vtensor<[6],f32>
+}
+
+// -----
 
 // CHECK-LABEL: @test_div_bcast
 func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
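For reference, the fp8 branch above reduces DequantizeLinear to (to_dtype(operand, f32) - zeropoint) * scale, applied elementwise. Below is a minimal standalone sketch of that arithmetic, not part of the patch; the function name is illustrative and the input is assumed to have already been widened to f32, which is the step torch.aten.to.dtype performs in the lowering.

    #include <cstdio>

    // Dequantize one value that has already been widened from fp8 to float,
    // mirroring aten.sub.Scalar followed by aten.mul.Scalar in the lowering.
    float dequantizeOne(float widened, float zeroPoint, float scale) {
      return (widened - zeroPoint) * scale;
    }

    int main() {
      // Example: stored value 3.0, zero point 1.0, scale 0.5 -> prints 1.000000.
      std::printf("%f\n", dequantizeOne(3.0f, 1.0f, 0.5f));
      return 0;
    }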