[TorchToLinalg] Fix possible OOB access in Interpolate lowering (#3570)

Following up on the discussion in
<https://github.com/llvm/torch-mlir/pull/3550>, I've edited the lowering
to prevent OOB extracts in a more direct fashion (i.e., by clamping the
extraction indices directly).

Beyond updating the CHECK lines to match the new IR, I don't think this
affects the lit tests, but I've tested the changes in our external test
suite at <https://github.com/nod-ai/SHARK-TestSuite/tree/main/>. I found
the issue when I was unexpectedly getting `nan`s along the output image
border for a resize test there.
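
For reference, here is a minimal scalar sketch of what the lowering now
computes along each spatial dimension, assuming the half-pixel coordinate
transform. The names mirror values in the pass (preClip, proj, lowFP,
highFP), but the helper itself and its `scale` parameter are illustrative
only; the pass emits the equivalent arith/math ops inside a linalg.generic.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar sketch of the clamping scheme along one dimension (illustrative,
// not the pass itself).
float interpolate1d(const std::vector<float> &line, int64_t outIdx,
                    float scale) {
  const float inputLen = static_cast<float>(line.size());
  // Project the output index back into input coordinates (half-pixel).
  float preClip = (static_cast<float>(outIdx) + 0.5f) / scale - 0.5f;
  // Clamp to [0, inputLen - 1] so the sample position lies inside the image.
  float proj = std::min(std::max(preClip, 0.0f), inputLen - 1.0f);
  // Nearest integer positions below and above proj.
  float lowFP = std::floor(proj);
  float highFP = std::floor(proj + 1.0f);
  int64_t low = static_cast<int64_t>(lowFP);
  // highFP can be one past the end; clamp only the extraction index. The
  // weights still use the unclamped highFP, so the clamped sample carries
  // weight 0 whenever the clamp actually kicks in.
  int64_t high = static_cast<int64_t>(std::min(proj + 1.0f, inputLen - 1.0f));
  float wHigh = proj - lowFP; // weight of the upper sample
  float wLow = highFP - proj; // weight of the lower sample
  return wLow * line[low] + wHigh * line[high];
}
```

When proj already sits on the last pixel, highFP lands one past the end, but
the clamped high index just re-reads the border pixel with weight 0, so
nothing is read out of bounds and the interpolated value is unchanged.
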
pull/3608/head
zjgarvey 2024-08-02 11:55:37 -07:00 committed by GitHub
parent 79ae0afc2f
commit d0933b0eb6
2 changed files with 47 additions and 16 deletions

@@ -2713,8 +2713,6 @@ static Value BilinearInterpolate(OpBuilder &b,
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
Value cstOneEps =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.000001));
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
@@ -2790,28 +2788,34 @@ static Value BilinearInterpolate(OpBuilder &b,
outputSizeFP, cstOneFloat);
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
}
// clip to 0,inf
// preClip is the fp position inside the input image to extract from.
// clip to [0,inf)
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
// length_original - 1.001
Value inputSubOneEps = b.create<arith::SubFOp>(loc, inputFP, cstOneEps);
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
// clip to [0,length_original - 1.001]
projEps.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOneEps));
// clip to [0,length_original - 1].
// proj is properly within the input image.
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
lowFP.push_back(b.create<math::FloorOp>(loc, projEps[i]));
Value projPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, projEps[i]);
// for bilinear interpolation, we look for the nearest indices below and
// above proj
lowFP.push_back(b.create<math::FloorOp>(loc, proj[i]));
Value projPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, proj[i]);
highFP.push_back(b.create<math::FloorOp>(loc, projPlusOne));
Value lowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), lowFP[i]);
low.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), lowInt));
Value highInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), highFP[i]);
// highFP could be out-of-bounds, so make sure to clip it down before
// extracting. If highFP actually gets clipped here, then high[i] will
// extract at the last pixel, but will treat it as if it were extracted from
// one further position when computing the interpolation weights.
Value highExtract =
b.create<arith::MinimumFOp>(loc, projPlusOne, inputSubOne);
highExtract = b.create<arith::FPToSIOp>(loc, b.getI64Type(), highExtract);
high.push_back(
b.create<arith::IndexCastOp>(loc, b.getIndexType(), highInt));
b.create<arith::IndexCastOp>(loc, b.getIndexType(), highExtract));
}
SmallVector<Value> cornerValues;
indices[dimOffset] = low[0];
indices[dimOffset + 1] = low[1];
Value p00 = b.create<tensor::ExtractOp>(loc, input, indices);
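
To make the weight argument in the comment above concrete (numbers chosen to
match the 4-wide lit test input below), suppose the projected position lands
exactly on the last column of a width-4 image. With the standard bilinear
weights (proj - lowFP for the upper sample, highFP - proj for the lower one):

    proj         = 3.0  (already clamped to [0, 3])
    lowFP        = floor(3.0)    = 3
    projPlusOne  = 3.0 + 1.0     = 4.0
    highFP       = floor(4.0)    = 4    <- one past the end
    highExtract  = min(4.0, 3.0) = 3    <- clamped back in bounds
    high weight  = proj - lowFP  = 0.0
    low weight   = highFP - proj = 1.0

The clamped sample is re-read from column 3 but contributes nothing, so the
output equals the border pixel and no extract at column 4 is generated.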

@@ -3,11 +3,38 @@
// CHECK-LABEL: func.func @test_resize_sizes_linear
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[x0:.*]] = torch_c.to_builtin_tensor %arg0
// CHECK: %[[generic:.*]] = linalg.generic
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
// CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
// CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
// CHECK-DAG: %[[cst:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[cst_4:.*]] = arith.constant 5.000000e-01 : f32
// CHECK-DAG: %[[cst_5:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[x15:.*]] = linalg.index 0 : index
// CHECK-DAG: %[[x16:.*]] = linalg.index 1 : index
// CHECK-DAG: %[[x17:.*]] = linalg.index 2 : index
// CHECK-DAG: %[[x18:.*]] = linalg.index 3 : index
// CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
// CHECK: %[[x20:.*]] = arith.sitofp %[[x8:.*]] : i64 to f32
// CHECK-DAG: %[[x21:.*]] = arith.divf %[[x20]], %[[x19]] : f32
// CHECK-DAG: %[[x22:.*]] = arith.index_cast %[[x17]] : index to i64
// CHECK-DAG: %[[x23:.*]] = arith.sitofp %[[x22]] : i64 to f32
// CHECK-DAG: %[[x24:.*]] = arith.addf %[[x23]], %[[cst_4]] : f32
// CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
// CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32
// CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], %[[cst_5]] : f32
// CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %[[cst]] : f32
// CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32
// CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32
// CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32
// CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32
// CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64
// CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index
// CHECK-DAG: %[[x35:.*]] = arith.minimumf %[[x31]], %[[x28]] : f32
// CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64
// CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[high:.*]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x37]], %[[low]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x37]], %[[high]]] : tensor<1x1x2x4xf32>
// CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]]
// CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]]
// CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]]
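
As a sanity check on the updated CHECK lines, the dim-2 (height) coordinate
chain can be traced by hand. The input height is 2 (from the 1x1x2x4 input);
the output sizes are dynamic (taken from %arg1), so assume an output height
of 4 purely for illustration and look at output row 3:

    x21 (scale)       = 4.0 / 2.0             = 2.0
    x26 (preClip)     = (3 + 0.5) / 2.0 - 0.5 = 1.25
    x27               = max(1.25, 0)          = 1.25
    x29 (proj)        = min(1.25, 2 - 1)      = 1.0
    x30 -> x33 -> x34 = floor(1.0)            = 1    (low row index)
    x31               = 1.0 + 1.0             = 2.0
    x35 -> x36 -> x37 = min(2.0, 1.0)         = 1.0  (high row index, clamped)

Without the final arith.minimumf against %[[x28]], the high row index would
be fptosi(2.0) = 2, one past the last row of the 2-row input; with it, both
row extracts stay at row 1, and since proj coincides with its floor, the
clamped sample carries zero weight in the final sum.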