Mirror of https://github.com/llvm/torch-mlir

Fix Slice Folder OOB Crash and onnx.Shape lowering (#3843)

Commit 3104b66560, parent 738d45d3bb, branch pull/3845/head.

This commit makes three changes:

1. Clamps an out-of-bounds start index to 0 in the slice folder.
2. Adds a more descriptive `emitError` in the slice folder when creating the `DenseElementsAttr` would fail due to a bad result shape.
3. Fixes the `onnx.Shape` lowering to default `end` to `inputRank` instead of `-1`. With `end == -1`, the last shape element went missing when slicing.
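The first two hunks rework the `onnx.Shape` lowering in `populateDefaultDomainQtoZ`. (The diff page did not preserve file names; judging by that function, this is the TorchOnnxToTorch conversion, presumably `lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp`.) Because `end` now defaults to `inputRank`, the `start`/`end` attribute reads move below the point where the input rank is known: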
```diff
@@ -1654,14 +1654,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         Value operand;
         int64_t start, end;
         if (binder.tensorOperand(operand) ||
-            binder.tensorResultType(resultType) ||
-            binder.s64IntegerAttr(start, "start", 0) ||
-            binder.s64IntegerAttr(end, "end", -1))
+            binder.tensorResultType(resultType))
           return failure();
 
         auto inputType = dyn_cast<Torch::ValueTensorType>(operand.getType());
+        if (!inputType || !inputType.hasSizes())
+          return failure();
+
         int64_t inputRank = inputType.getSizes().size();
 
+        if (binder.s64IntegerAttr(start, "start", 0) ||
+            binder.s64IntegerAttr(end, "end", inputRank))
+          return failure();
+
         auto shapeType = Torch::ValueTensorType::get(
             binder.op->getContext(), SmallVector<int64_t>{inputRank},
             resultType.getOptionalDtype());
```
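With `end` defaulting to the rank, the lowering's fast path for returning the whole shape changes its sentinel from `-1` to `inputRank`: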
```diff
@@ -1674,7 +1679,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           return success();
         }
 
-        if (start == 0 && end == -1) {
+        if (start == 0 && end == inputRank) {
           rewriter.replaceOp(binder.op, shape);
           return success();
         }
```
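To see why `-1` was a bad default: `onnx.Shape` slices the shape vector with Python-style half-open semantics, so `end == -1` normalizes to `rank - 1` and drops the final dimension. A minimal standalone sketch of that arithmetic (illustration only, not torch-mlir code):

```cpp
#include <cstdio>

int main() {
  long rank = 4; // e.g. a [2,3,5,7] input tensor
  long start = 0;

  // Old default: end = -1 normalizes to rank - 1 = 3, so the
  // half-open slice [0, 3) keeps only 3 of the 4 dimensions.
  long end = -1;
  long e = end < 0 ? end + rank : end;
  std::printf("end = -1:   keeps %ld of %ld dims\n", e - start, rank);

  // New default: end = inputRank, so [0, 4) keeps the whole shape.
  e = rank;
  std::printf("end = rank: keeps %ld of %ld dims\n", e - start, rank);
  return 0;
}
```

The next two hunks fix the slice folder, `AtenSliceTensorOp::fold` (judging by the signature, presumably `lib/Dialect/Torch/IR/TorchOps.cpp`). The first clamps a start index that is still negative after normalization: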
```diff
@@ -3998,6 +3998,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
   int64_t limit = end.getValue().getSExtValue();
   int64_t stride = step.getValue().getSExtValue();
   begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin;
+  begin = std::max<int64_t>(begin, 0);
   limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit;
   limit = limit < 0 ? -1 : limit;
   limit = std::min(limit, inType.getSizes()[dimInt]);
```
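Without the clamp, a sufficiently negative start survives the `begin + dimSize` adjustment and stays negative, and the folder then indexes out of bounds. With `dimSize = 4` and `start = -10` (the case the new regression test exercises), `-10 + 4 = -6`. A small sketch of the normalization, assuming it mirrors the folder's index handling (`normalizeBegin` is a name invented here):

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper mirroring the folder's begin normalization.
int64_t normalizeBegin(int64_t begin, int64_t dimSize) {
  if (begin < 0)
    begin += dimSize;                 // -10 + 4 = -6: still out of bounds
  return std::max<int64_t>(begin, 0); // the new clamp: -6 -> 0
}
```

The second folder hunk guards the `DenseElementsAttr` construction: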
```diff
@@ -4038,6 +4039,14 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
     }
   };
   recursiveIter(recursiveIter, 0, 0);
+  if (static_cast<int64_t>(values.size()) != count) {
+    emitError(
+        "Op has incorrect result shape for provided arguments.\nNum elements "
+        "present in slice: " +
+        std::to_string(values.size()) +
+        "\nNum elements implied by result type: " + std::to_string(count));
+    return nullptr;
+  }
   return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
 }
 
```
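`DenseElementsAttr::get` requires the value list to exactly fill the result tensor, whose element count is the product of its dimensions; previously a mismatch crashed the folder, whereas now the fold emits a diagnostic and returns `nullptr` (no fold). A sketch of the invariant being checked, with names chosen here for illustration:

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical restatement of the folder's new guard: the folded
// values must exactly fill the result tensor.
bool resultShapeAgrees(const std::vector<int64_t> &resultSizes,
                       size_t numValues) {
  int64_t count = std::accumulate(resultSizes.begin(), resultSizes.end(),
                                  int64_t{1}, std::multiplies<int64_t>());
  return static_cast<int64_t>(numValues) == count;
}
```

Finally, the patch adds two regression tests next to the existing slice-folder tests (presumably `test/Dialect/Torch/canonicalize.mlir`): one checks that a slice with an out-of-bounds start now folds to the clamped result, and one checks that a slice whose result type disagrees with the computed element count is left unfolded: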
```diff
@@ -2330,6 +2330,32 @@ func.func @torch.aten.slice.tensor$fold_full_slice(%arg0: !torch.vtensor<[?],f32
   return %0 : !torch.vtensor<[?],f32>
 }
 
+// CHECK-LABEL: @torch.aten.slice.tensor$fold_oob_start
+// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[0, 1, 2]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
+// CHECK: return %[[LIT]] : !torch.vtensor<[3],si64>
+func.func @torch.aten.slice.tensor$fold_oob_start() -> !torch.vtensor<[3],si64> {
+  %0 = torch.vtensor.literal(dense<[0,1,2,3]> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %int1 = torch.constant.int 1
+  %int-1 = torch.constant.int -1
+  %int-10 = torch.constant.int -10
+  %int0 = torch.constant.int 0
+  %1 = torch.aten.slice.Tensor %0, %int0, %int-10, %int-1, %int1 : !torch.vtensor<[4], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3], si64>
+  return %1 : !torch.vtensor<[3],si64>
+}
+
+// CHECK-LABEL: @torch.aten.slice.tensor$nofold_invalid_shape
+// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor
+// CHECK: return %[[SLICE]]
+func.func @torch.aten.slice.tensor$nofold_invalid_shape() -> !torch.vtensor<[4],si64> {
+  %0 = torch.vtensor.literal(dense<[0,1,2,3]> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %int1 = torch.constant.int 1
+  %int-1 = torch.constant.int -1
+  %int-10 = torch.constant.int -10
+  %int0 = torch.constant.int 0
+  %1 = torch.aten.slice.Tensor %0, %int0, %int-10, %int-1, %int1 : !torch.vtensor<[4], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], si64>
+  return %1 : !torch.vtensor<[4],si64>
+}
+
 // CHECK-LABEL: @torch.aten.slice.tensor$no_fold_step
 // CHECK: torch.aten.slice.Tensor
 func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> {
```
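These are standard lit/FileCheck regression tests; in a torch-mlir checkout they should run as part of the `check-torch-mlir` build target, or directly via the RUN line at the top of the test file.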