mirror of https://github.com/llvm/torch-mlir
[ONNX][TorchToLinalg] Add support for dynamic dims in Interpolate lowering (#3351)
Addresses [Shark-Turbine #196](https://github.com/nod-ai/SHARK-TestSuite/issues/196) Related tracker [Shark-Turbine #566](https://github.com/nod-ai/SHARK-Turbine/issues/566) Related onnx.Resize issues [Shark-Turbine #616](https://github.com/nod-ai/SHARK-Turbine/issues/616)pull/3268/merge
parent
513d89c16d
commit
6cba93b16e
|
@ -2912,11 +2912,13 @@ public:
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = input.getType().cast<RankedTensorType>();
|
||||||
auto inputRank = inputType.getRank();
|
auto inputRank = inputType.getRank();
|
||||||
|
|
||||||
if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op");
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 2> outputSizeIntValues;
|
SmallVector<Value, 2> outputSizeIntValues;
|
||||||
|
Value inputSizeH = getDimOp(rewriter, loc, input, 2);
|
||||||
|
inputSizeH = rewriter.create<arith::IndexCastOp>(
|
||||||
|
loc, rewriter.getIntegerType(64), inputSizeH);
|
||||||
|
Value inputSizeW = getDimOp(rewriter, loc, input, 3);
|
||||||
|
inputSizeW = rewriter.create<arith::IndexCastOp>(
|
||||||
|
loc, rewriter.getIntegerType(64), inputSizeW);
|
||||||
|
|
||||||
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
|
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
|
||||||
SmallVector<Value, 2> ScaleFactorTorchFloat;
|
SmallVector<Value, 2> ScaleFactorTorchFloat;
|
||||||
|
@ -2927,8 +2929,6 @@ public:
|
||||||
SmallVector<Value, 2> ScaleFactorFloatValues;
|
SmallVector<Value, 2> ScaleFactorFloatValues;
|
||||||
ScaleFactorFloatValues = getTypeConvertedValues(
|
ScaleFactorFloatValues = getTypeConvertedValues(
|
||||||
rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
|
rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
|
||||||
Value inputSizeH = rewriter.create<arith::ConstantOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(inputType.getShape()[2]));
|
|
||||||
Value inputHFP = rewriter.create<arith::SIToFPOp>(
|
Value inputHFP = rewriter.create<arith::SIToFPOp>(
|
||||||
loc, rewriter.getF32Type(), inputSizeH);
|
loc, rewriter.getF32Type(), inputSizeH);
|
||||||
Value scale = rewriter.create<arith::TruncFOp>(loc, inputHFP.getType(),
|
Value scale = rewriter.create<arith::TruncFOp>(loc, inputHFP.getType(),
|
||||||
|
@ -2938,8 +2938,6 @@ public:
|
||||||
outputH =
|
outputH =
|
||||||
rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputH);
|
rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputH);
|
||||||
|
|
||||||
Value inputSizeW = rewriter.create<arith::ConstantOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(inputType.getShape()[3]));
|
|
||||||
Value inputWFP = rewriter.create<arith::SIToFPOp>(
|
Value inputWFP = rewriter.create<arith::SIToFPOp>(
|
||||||
loc, rewriter.getF32Type(), inputSizeW);
|
loc, rewriter.getF32Type(), inputSizeW);
|
||||||
scale = rewriter.create<arith::TruncFOp>(loc, inputWFP.getType(),
|
scale = rewriter.create<arith::TruncFOp>(loc, inputWFP.getType(),
|
||||||
|
@ -2960,11 +2958,9 @@ public:
|
||||||
outputSizeIntValues = getTypeConvertedValues(
|
outputSizeIntValues = getTypeConvertedValues(
|
||||||
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
|
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
|
||||||
}
|
}
|
||||||
int hDimOffset = 2;
|
SmallVector<Value> dims = getTensorSizesUntilDim(rewriter, loc, input, 1);
|
||||||
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
|
dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[0]));
|
||||||
dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
|
dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[1]));
|
||||||
dims[hDimOffset + 1] =
|
|
||||||
castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
|
|
||||||
|
|
||||||
Value outTensor = rewriter.create<tensor::EmptyOp>(
|
Value outTensor = rewriter.create<tensor::EmptyOp>(
|
||||||
loc, getAsOpFoldResult(dims), inputType.getElementType());
|
loc, getAsOpFoldResult(dims), inputType.getElementType());
|
||||||
|
@ -2983,10 +2979,6 @@ public:
|
||||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
Value outputSizeH = outputSizeIntValues[0];
|
Value outputSizeH = outputSizeIntValues[0];
|
||||||
Value outputSizeW = outputSizeIntValues[1];
|
Value outputSizeW = outputSizeIntValues[1];
|
||||||
Value inputSizeH = b.create<arith::ConstantOp>(
|
|
||||||
loc, b.getI64IntegerAttr(inputType.getShape()[2]));
|
|
||||||
Value inputSizeW = b.create<arith::ConstantOp>(
|
|
||||||
loc, b.getI64IntegerAttr(inputType.getShape()[3]));
|
|
||||||
Value retVal;
|
Value retVal;
|
||||||
if (mode == "nearest") {
|
if (mode == "nearest") {
|
||||||
retVal =
|
retVal =
|
||||||
|
|
|
@ -2607,9 +2607,6 @@ ONNX_XFAIL_SET = {
|
||||||
"BernoulliTensorModule_basic",
|
"BernoulliTensorModule_basic",
|
||||||
# Failure - onnx_lowering: onnx.ReduceProd
|
# Failure - onnx_lowering: onnx.ReduceProd
|
||||||
"ReduceProdDimIntFloatModule_basic",
|
"ReduceProdDimIntFloatModule_basic",
|
||||||
# Failure - onnx_lowering: onnx.Resize
|
|
||||||
"UpSampleNearest2dDynamicSize_basic",
|
|
||||||
"UpSampleNearest2dStaticSize_basic",
|
|
||||||
# Failure - onnx_lowering: onnx.ScatterElements
|
# Failure - onnx_lowering: onnx.ScatterElements
|
||||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||||
"ScatterReduceFloatMinModuleIncludeSelf",
|
"ScatterReduceFloatMinModuleIncludeSelf",
|
||||||
|
|
|
@ -4,15 +4,13 @@
|
||||||
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4]
|
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4]
|
||||||
,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[generic:.*]] = linalg.generic
|
// CHECK: %[[generic:.*]] = linalg.generic
|
||||||
// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
|
|
||||||
// CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
|
|
||||||
// CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32
|
// CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32
|
||||||
// CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32
|
// CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32
|
||||||
// CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32
|
// CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32
|
||||||
// CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32
|
// CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32
|
||||||
// CHECK: %[[x13:.*]] = linalg.index 2 : index
|
// CHECK: %[[x13:.*]] = linalg.index 2 : index
|
||||||
// CHECK: %[[x14:.*]] = linalg.index 3 : index
|
// CHECK: %[[x14:.*]] = linalg.index 3 : index
|
||||||
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
|
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
|
||||||
// CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
|
// CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
|
||||||
// CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32
|
// CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32
|
||||||
// CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64
|
// CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64
|
||||||
|
@ -23,7 +21,7 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
|
||||||
// CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32
|
// CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32
|
||||||
// CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32
|
// CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32
|
||||||
// CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32
|
// CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32
|
||||||
// CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
|
// CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
|
||||||
// CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
|
// CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
|
||||||
// CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32
|
// CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32
|
||||||
// CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64
|
// CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64
|
||||||
|
@ -96,12 +94,10 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
|
||||||
|
|
||||||
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[GENERIC:.*]] = linalg.generic
|
// CHECK: %[[GENERIC:.*]] = linalg.generic
|
||||||
// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
|
|
||||||
// CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
|
|
||||||
// CHECK: %[[x13:.*]] = linalg.index 2 : index
|
// CHECK: %[[x13:.*]] = linalg.index 2 : index
|
||||||
// CHECK: %[[x14:.*]] = linalg.index 3 : index
|
// CHECK: %[[x14:.*]] = linalg.index 3 : index
|
||||||
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
|
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
|
||||||
// CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
|
// CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
|
||||||
// CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
|
// CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
|
||||||
// CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
|
// CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
|
||||||
// CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
|
// CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
|
||||||
|
|
Loading…
Reference in New Issue