[ONNX][TorchToLinalg] Add support for dynamic dims in Interpolate lowering (#3351)

Addresses [SHARK-TestSuite #196](https://github.com/nod-ai/SHARK-TestSuite/issues/196).

Related tracker: [SHARK-Turbine #566](https://github.com/nod-ai/SHARK-Turbine/issues/566)

Related onnx.Resize issue: [SHARK-Turbine #616](https://github.com/nod-ai/SHARK-Turbine/issues/616)
commit 6cba93b16e (parent 513d89c16d), zjgarvey, 2024-05-17 14:18:57 -05:00
3 changed files with 13 additions and 28 deletions
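
The core of the change: the lowering previously bailed out on dynamic spatial dims and, where it did proceed, folded `inputType.getShape()[2]`/`[3]` into `arith.constant` ops. It now queries the sizes at runtime and hoists them to the top of the pattern, so every later use (the scale-factor path, the explicit-sizes path, and the `linalg.generic` body) shares the same SSA values. A minimal sketch of that hoisted computation, using the torch-mlir helpers visible in the diff below:

```cpp
// Query the input's H and W once, as i64 SSA values. getDimOp emits a
// tensor.dim, which works for static and dynamic dims alike; the
// index_cast brings the result into the i64 domain the rest of the
// lowering computes in.
Value inputSizeH = getDimOp(rewriter, loc, input, /*dim=*/2);
inputSizeH = rewriter.create<arith::IndexCastOp>(
    loc, rewriter.getIntegerType(64), inputSizeH);
Value inputSizeW = getDimOp(rewriter, loc, input, /*dim=*/3);
inputSizeW = rewriter.create<arith::IndexCastOp>(
    loc, rewriter.getIntegerType(64), inputSizeW);
```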

File 1 of 3 (TorchToLinalg Interpolate lowering)

@@ -2912,11 +2912,13 @@ public:
     auto inputType = input.getType().cast<RankedTensorType>();
     auto inputRank = inputType.getRank();
-    if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) {
-      return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op");
-    }
     SmallVector<Value, 2> outputSizeIntValues;
+    Value inputSizeH = getDimOp(rewriter, loc, input, 2);
+    inputSizeH = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIntegerType(64), inputSizeH);
+    Value inputSizeW = getDimOp(rewriter, loc, input, 3);
+    inputSizeW = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIntegerType(64), inputSizeW);
     if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
       SmallVector<Value, 2> ScaleFactorTorchFloat;
@@ -2927,8 +2929,6 @@ public:
       SmallVector<Value, 2> ScaleFactorFloatValues;
       ScaleFactorFloatValues = getTypeConvertedValues(
           rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
-      Value inputSizeH = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI64IntegerAttr(inputType.getShape()[2]));
       Value inputHFP = rewriter.create<arith::SIToFPOp>(
           loc, rewriter.getF32Type(), inputSizeH);
       Value scale = rewriter.create<arith::TruncFOp>(loc, inputHFP.getType(),
@@ -2938,8 +2938,6 @@ public:
       outputH =
           rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputH);
-      Value inputSizeW = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI64IntegerAttr(inputType.getShape()[3]));
       Value inputWFP = rewriter.create<arith::SIToFPOp>(
           loc, rewriter.getF32Type(), inputSizeW);
       scale = rewriter.create<arith::TruncFOp>(loc, inputWFP.getType(),
@@ -2960,11 +2958,9 @@ public:
       outputSizeIntValues = getTypeConvertedValues(
           rewriter, loc, getTypeConverter(), outputSizeTorchInt);
     }
-    int hDimOffset = 2;
-    SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
-    dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
-    dims[hDimOffset + 1] =
-        castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
+    SmallVector<Value> dims = getTensorSizesUntilDim(rewriter, loc, input, 1);
+    dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[0]));
+    dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[1]));
     Value outTensor = rewriter.create<tensor::EmptyOp>(
         loc, getAsOpFoldResult(dims), inputType.getElementType());
@@ -2983,10 +2979,6 @@ public:
           [&](OpBuilder &b, Location loc, ValueRange args) {
             Value outputSizeH = outputSizeIntValues[0];
             Value outputSizeW = outputSizeIntValues[1];
-            Value inputSizeH = b.create<arith::ConstantOp>(
-                loc, b.getI64IntegerAttr(inputType.getShape()[2]));
-            Value inputSizeW = b.create<arith::ConstantOp>(
-                loc, b.getI64IntegerAttr(inputType.getShape()[3]));
             Value retVal;
             if (mode == "nearest") {
               retVal =

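With `inputSizeH`/`inputSizeW` hoisted, the scale-factor branch computes output sizes from runtime values rather than shape constants, and the duplicate constants inside the `linalg.generic` body (last hunk above) simply go away. A condensed sketch of the height computation; the multiply between the `truncf` and the `fptosi` falls outside the hunks shown, so `arith::MulFOp` here is an assumption, not a line from the patch:

```cpp
// output_h = (i64)((f32)input_h * scale): widen the runtime i64 size to
// f32, narrow the f64 Torch scale factor to f32 (truncf), multiply
// (assumed arith::MulFOp), then convert back to i64 (fptosi truncates
// toward zero, which matches floor for the positive sizes involved).
Value inputHFP = rewriter.create<arith::SIToFPOp>(
    loc, rewriter.getF32Type(), inputSizeH);
Value scale = rewriter.create<arith::TruncFOp>(
    loc, inputHFP.getType(), ScaleFactorFloatValues[0]);
Value outputH = rewriter.create<arith::MulFOp>(loc, inputHFP, scale);
outputH =
    rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputH);
```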
File 2 of 3 (ONNX e2e xfail set)

@@ -2607,9 +2607,6 @@ ONNX_XFAIL_SET = {
     "BernoulliTensorModule_basic",
     # Failure - onnx_lowering: onnx.ReduceProd
     "ReduceProdDimIntFloatModule_basic",
-    # Failure - onnx_lowering: onnx.Resize
-    "UpSampleNearest2dDynamicSize_basic",
-    "UpSampleNearest2dStaticSize_basic",
     # Failure - onnx_lowering: onnx.ScatterElements
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMinModuleIncludeSelf",

File 3 of 3 (TorchToLinalg resize lit tests)

@@ -4,15 +4,13 @@
 func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4]
 ,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 // CHECK: %[[generic:.*]] = linalg.generic
-// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
-// CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
 // CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32
 // CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[x13:.*]] = linalg.index 2 : index
 // CHECK: %[[x14:.*]] = linalg.index 3 : index
-// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
+// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
 // CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
 // CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32
 // CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64
@@ -23,7 +21,7 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
 // CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32
 // CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32
 // CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32
-// CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
+// CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
 // CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
 // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32
 // CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64
@@ -96,12 +94,10 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
 func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 // CHECK: %[[GENERIC:.*]] = linalg.generic
-// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
-// CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
 // CHECK: %[[x13:.*]] = linalg.index 2 : index
 // CHECK: %[[x14:.*]] = linalg.index 3 : index
-// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
-// CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
+// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
+// CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
 // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
 // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
 // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
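
The lit-test updates mirror the lowering change: the input sizes are no longer created by the pattern as `arith.constant` ops at a fixed point before the `linalg.generic`, so the `CHECK` lines stop pinning those definitions and instead bind `%[[c2_i64:.*]]`/`%[[c4_i64:.*]]` to whatever SSA value feeds each `sitofp`. For the static `[1,1,2,4]` inputs the sizes can still fold to constants; a sketch of why, assuming `getDimOp` is the usual createOrFold wrapper (its exact definition in torch-mlir's conversion utils is an assumption here):

```cpp
// Plausible shape of the helper: materialize tensor.dim via createOrFold,
// which collapses to a constant index when the dimension is static and
// leaves a real tensor.dim op when it is dynamic.
Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
  return b.createOrFold<tensor::DimOp>(loc, v, dim);
}
```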