From d0933b0eb6c94c38132d5d80fd82323cda80a159 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:55:37 -0700 Subject: [PATCH] [TorchToLinalg] Fix possible OOB access in Interpolate lowering (#3570) Following up from the discussion in , I've edited the lowering to prevent OOB extracts in a more direct fashion (i.e., just clamping directly). I don't think this affects the lit tests at all, but I've tested the changes in our external test suite at . I found the issue when I was unexpectedly getting `nan`'s along the output image border for a resize test there. --- .../TorchToLinalg/Uncategorized.cpp | 28 ++++++++------- test/Conversion/TorchToLinalg/resize.mlir | 35 ++++++++++++++++--- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index b4b245f6d..958047ee3 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2713,8 +2713,6 @@ static Value BilinearInterpolate(OpBuilder &b, auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); - Value cstOneEps = - b.create(loc, b.getF32FloatAttr(1.000001)); Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); Value zero = b.create(loc, b.getF32FloatAttr(0.0)); @@ -2790,28 +2788,34 @@ static Value BilinearInterpolate(OpBuilder &b, outputSizeFP, cstOneFloat); preClip = b.create(loc, cmp, zero, preClip); } - // clip to 0,inf + // preClip is the fp position inside the input image to extract from. + // clip to [0,inf) Value max = b.create(loc, preClip, zero); - // length_original - 1.001 - Value inputSubOneEps = b.create(loc, inputFP, cstOneEps); Value inputSubOne = b.create(loc, inputFP, cstOneFloat); - // clip to [0,length_original - 1.001] - projEps.push_back(b.create(loc, max, inputSubOneEps)); + // clip to [0,length_original - 1]. + // proj is properly within the input image. proj.push_back(b.create(loc, max, inputSubOne)); - lowFP.push_back(b.create(loc, projEps[i])); - Value projPlusOne = b.create(loc, cstOneFloat, projEps[i]); + // for bilinear interpolation, we look for the nearest indices below and + // above proj + lowFP.push_back(b.create(loc, proj[i])); + Value projPlusOne = b.create(loc, cstOneFloat, proj[i]); highFP.push_back(b.create(loc, projPlusOne)); Value lowInt = b.create(loc, b.getI64Type(), lowFP[i]); low.push_back(b.create(loc, b.getIndexType(), lowInt)); - Value highInt = b.create(loc, b.getI64Type(), highFP[i]); + // highFP could be out-of-bounds, so make sure to clip it down before + // extracting. If highFP actually gets clipped here, then high[i] will + // extract at the last pixel, but will treat it as if it were extracted from + // one further position when computing the interpolation weights. + Value highExtract = + b.create(loc, projPlusOne, inputSubOne); + highExtract = b.create(loc, b.getI64Type(), highExtract); high.push_back( - b.create(loc, b.getIndexType(), highInt)); + b.create(loc, b.getIndexType(), highExtract)); } - SmallVector cornerValues; indices[dimOffset] = low[0]; indices[dimOffset + 1] = low[1]; Value p00 = b.create(loc, input, indices); diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 64198d03f..7976b1ad8 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -3,11 +3,38 @@ // CHECK-LABEL: func.func @test_resize_sizes_linear func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] ,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[x0:.*]] = torch_c.to_builtin_tensor %arg0 // CHECK: %[[generic:.*]] = linalg.generic - // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32> - // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] - // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] - // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK-DAG: %[[cst:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK-DAG: %[[cst_4:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK-DAG: %[[cst_5:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[x15:.*]] = linalg.index 0 : index + // CHECK-DAG: %[[x16:.*]] = linalg.index 1 : index + // CHECK-DAG: %[[x17:.*]] = linalg.index 2 : index + // CHECK-DAG: %[[x18:.*]] = linalg.index 3 : index + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x20:.*]] = arith.sitofp %[[x8:.*]] : i64 to f32 + // CHECK-DAG: %[[x21:.*]] = arith.divf %[[x20]], %[[x19]] : f32 + // CHECK-DAG: %[[x22:.*]] = arith.index_cast %[[x17]] : index to i64 + // CHECK-DAG: %[[x23:.*]] = arith.sitofp %[[x22]] : i64 to f32 + // CHECK-DAG: %[[x24:.*]] = arith.addf %[[x23]], %[[cst_4]] : f32 + // CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32 + // CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], %[[cst_5]] : f32 + // CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %[[cst]] : f32 + // CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32 + // CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32 + // CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32 + // CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32 + // CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 + // CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index + // CHECK-DAG: %[[x35:.*]] = arith.minimumf %[[x31]], %[[x28]] : f32 + // CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64 + // CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[high:.*]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x37]], %[[low]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x37]], %[[high]]] : tensor<1x1x2x4xf32> // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]] // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]] // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]]