mirror of https://github.com/llvm/torch-mlir
[ONNX][TorchToLinalg] Add support for dynamic dims in Interpolate lowering (#3351)
Addresses [Shark-Turbine #196](https://github.com/nod-ai/SHARK-TestSuite/issues/196) Related tracker [Shark-Turbine #566](https://github.com/nod-ai/SHARK-Turbine/issues/566) Related onnx.Resize issues [Shark-Turbine #616](https://github.com/nod-ai/SHARK-Turbine/issues/616)pull/3268/merge
parent
513d89c16d
commit
6cba93b16e
|
@ -2912,11 +2912,13 @@ public:
|
|||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputRank = inputType.getRank();
|
||||
|
||||
if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) {
|
||||
return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op");
|
||||
}
|
||||
|
||||
SmallVector<Value, 2> outputSizeIntValues;
|
||||
Value inputSizeH = getDimOp(rewriter, loc, input, 2);
|
||||
inputSizeH = rewriter.create<arith::IndexCastOp>(
|
||||
loc, rewriter.getIntegerType(64), inputSizeH);
|
||||
Value inputSizeW = getDimOp(rewriter, loc, input, 3);
|
||||
inputSizeW = rewriter.create<arith::IndexCastOp>(
|
||||
loc, rewriter.getIntegerType(64), inputSizeW);
|
||||
|
||||
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
|
||||
SmallVector<Value, 2> ScaleFactorTorchFloat;
|
||||
|
@ -2927,8 +2929,6 @@ public:
|
|||
SmallVector<Value, 2> ScaleFactorFloatValues;
|
||||
ScaleFactorFloatValues = getTypeConvertedValues(
|
||||
rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
|
||||
Value inputSizeH = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI64IntegerAttr(inputType.getShape()[2]));
|
||||
Value inputHFP = rewriter.create<arith::SIToFPOp>(
|
||||
loc, rewriter.getF32Type(), inputSizeH);
|
||||
Value scale = rewriter.create<arith::TruncFOp>(loc, inputHFP.getType(),
|
||||
|
@ -2938,8 +2938,6 @@ public:
|
|||
outputH =
|
||||
rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputH);
|
||||
|
||||
Value inputSizeW = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI64IntegerAttr(inputType.getShape()[3]));
|
||||
Value inputWFP = rewriter.create<arith::SIToFPOp>(
|
||||
loc, rewriter.getF32Type(), inputSizeW);
|
||||
scale = rewriter.create<arith::TruncFOp>(loc, inputWFP.getType(),
|
||||
|
@ -2960,11 +2958,9 @@ public:
|
|||
outputSizeIntValues = getTypeConvertedValues(
|
||||
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
|
||||
}
|
||||
int hDimOffset = 2;
|
||||
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
|
||||
dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
|
||||
dims[hDimOffset + 1] =
|
||||
castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
|
||||
SmallVector<Value> dims = getTensorSizesUntilDim(rewriter, loc, input, 1);
|
||||
dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[0]));
|
||||
dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[1]));
|
||||
|
||||
Value outTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(dims), inputType.getElementType());
|
||||
|
@ -2983,10 +2979,6 @@ public:
|
|||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value outputSizeH = outputSizeIntValues[0];
|
||||
Value outputSizeW = outputSizeIntValues[1];
|
||||
Value inputSizeH = b.create<arith::ConstantOp>(
|
||||
loc, b.getI64IntegerAttr(inputType.getShape()[2]));
|
||||
Value inputSizeW = b.create<arith::ConstantOp>(
|
||||
loc, b.getI64IntegerAttr(inputType.getShape()[3]));
|
||||
Value retVal;
|
||||
if (mode == "nearest") {
|
||||
retVal =
|
||||
|
|
|
@ -2607,9 +2607,6 @@ ONNX_XFAIL_SET = {
|
|||
"BernoulliTensorModule_basic",
|
||||
# Failure - onnx_lowering: onnx.ReduceProd
|
||||
"ReduceProdDimIntFloatModule_basic",
|
||||
# Failure - onnx_lowering: onnx.Resize
|
||||
"UpSampleNearest2dDynamicSize_basic",
|
||||
"UpSampleNearest2dStaticSize_basic",
|
||||
# Failure - onnx_lowering: onnx.ScatterElements
|
||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||
"ScatterReduceFloatMinModuleIncludeSelf",
|
||||
|
|
|
@ -4,15 +4,13 @@
|
|||
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4]
|
||||
,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[generic:.*]] = linalg.generic
|
||||
// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
|
||||
// CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
|
||||
// CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32
|
||||
// CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32
|
||||
// CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[x13:.*]] = linalg.index 2 : index
|
||||
// CHECK: %[[x14:.*]] = linalg.index 3 : index
|
||||
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
|
||||
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
|
||||
// CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
|
||||
// CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32
|
||||
// CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64
|
||||
|
@ -23,7 +21,7 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
|
|||
// CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32
|
||||
// CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32
|
||||
// CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32
|
||||
// CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
|
||||
// CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
|
||||
// CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
|
||||
// CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32
|
||||
// CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64
|
||||
|
@ -96,12 +94,10 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
|
|||
|
||||
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[GENERIC:.*]] = linalg.generic
|
||||
// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
|
||||
// CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
|
||||
// CHECK: %[[x13:.*]] = linalg.index 2 : index
|
||||
// CHECK: %[[x14:.*]] = linalg.index 3 : index
|
||||
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
|
||||
// CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
|
||||
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
|
||||
// CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
|
||||
// CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
|
||||
// CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
|
||||
// CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
|
||||
|
|
Loading…
Reference in New Issue