From 87389f0762c1626a56f3afaafcf51bd9f5e28518 Mon Sep 17 00:00:00 2001
From: kumardeepakamd <123522031+kumardeepakamd@users.noreply.github.com>
Date: Mon, 15 Jan 2024 11:26:46 -0800
Subject: [PATCH] [ONNXToTorch] Add conversion for Onnx range (#2752)

Implemented ONNX.Range. The spec says start, limit, and delta are 0-D
tensors whose data type can be double, float, int16, int32, or int64.
All int types are mapped to !torch.int and all float types to
!torch.float.

---------

Co-authored-by: Kumar Deepak
---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp  | 58 +++++++++++++++++++
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 55 ++++++++++++++++++
 2 files changed, 113 insertions(+)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 11a05ea41..0833af54d 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -27,6 +27,18 @@ using namespace mlir::torch::onnx_c;
 // to be more normal and a direct translation vs a special case. This
 // results in a lot of ONNX test cases that all reduce to the exact same
 // thing here, so we simplify.
+
+// Utilities.
+// Templatized function to create an item op of a given type.
+namespace {
+template <typename T>
+Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
+                Value &ofItem) {
+  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
+                                            rewriter.getType<T>(), ofItem);
+}
+} // namespace
+
 void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
     OnnxCustomOpConversionPattern &patterns) {
   patterns.onOp("Reciprocal", 1,
@@ -1336,4 +1348,50 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         data, dimValueList);
     return success();
   });
+  patterns.onOp(
+      "Range", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // ONNX.Range(start, limit, delta) -- limit is exclusive.
+
+        Torch::ValueTensorType resultType;
+        Value start, limit, delta;
+        auto loc = binder.getLoc();
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+        if (binder.tensorOperandAtIndex(start, 0) ||
+            binder.tensorOperandAtIndex(limit, 1) ||
+            binder.tensorOperandAtIndex(delta, 2) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        // Convert a 0-dimensional (scalar) tensor to a Torch numeric value;
+        // e.g. the ONNX equivalent of torch.tensor(1.1) becomes 1.1.
+        // The type of start, limit, and delta can be one of: double, float,
+        // int16, int32, int64.
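+        // Per the ONNX spec, all three inputs share the single type
+        // parameter T, so inspecting start's dtype below is sufficient.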
+        Torch::BaseTensorType startTensorType =
+            start.getType().cast<Torch::BaseTensorType>();
+        bool isFloatDType = startTensorType.getDtype().isF64() ||
+                            startTensorType.getDtype().isF32();
+        bool isIntDType = startTensorType.getDtype().isInteger(16) ||
+                          startTensorType.getDtype().isInteger(32) ||
+                          startTensorType.getDtype().isInteger(64);
+        if (!isFloatDType && !isIntDType) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected the start, limit, delta to be one of "
+                         "double, float, int16, int32, int64");
+        }
+        Value scalarStart, scalarLimit, scalarDelta;
+        if (isFloatDType) {
+          scalarStart = getItemOp<Torch::FloatType>(binder, rewriter, start);
+          scalarLimit = getItemOp<Torch::FloatType>(binder, rewriter, limit);
+          scalarDelta = getItemOp<Torch::FloatType>(binder, rewriter, delta);
+        } else {
+          scalarStart = getItemOp<Torch::IntType>(binder, rewriter, start);
+          scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, limit);
+          scalarDelta = getItemOp<Torch::IntType>(binder, rewriter, delta);
+        }
+        rewriter.replaceOpWithNewOp<Torch::AtenArangeStartStepOp>(
+            binder.op, resultType, scalarStart, scalarLimit, scalarDelta, none,
+            none, none, none);
+        return success();
+      });
 }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 91421d944..593a993c8 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -1179,3 +1179,58 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
   %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32>
   return %0 : !torch.vtensor<[2,3,1,4],f32>
 }
+
+// CHECK-LABEL: func.func @test_range_float64_type
+func.func @test_range_float64_type(%arg0: !torch.vtensor<[],f64>, %arg1: !torch.vtensor<[],f64>, %arg2: !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f64> -> !torch.float
+  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f64> -> !torch.float
+  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f64> -> !torch.float
+  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f64>
+  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f64>, !torch.vtensor<[],f64>, !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64>
+  return %0 : !torch.vtensor<[2],f64>
+}
+
+// CHECK-LABEL: func.func @test_range_float32_type
+func.func @test_range_float32_type(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f32>
+  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32>
+  return %0 : !torch.vtensor<[2],f32>
+}
+
+// CHECK-LABEL: func.func @test_range_int64_type
+func.func @test_range_int64_type(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
+  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64>
+  return %0 : !torch.vtensor<[2],si64>
+}
+
+// CHECK-LABEL: func.func @test_range_int32_type
+func.func @test_range_int32_type(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si32>
+  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32>
+  return %0 : !torch.vtensor<[2],si32>
+}
+
+// CHECK-LABEL: func.func @test_range_int16_type
+func.func @test_range_int16_type(%arg0: !torch.vtensor<[],si16>, %arg1: !torch.vtensor<[],si16>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si16> -> !torch.int
+  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si16> -> !torch.int
+  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int
+  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si16>
+  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si16>, !torch.vtensor<[],si16>, !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16>
+  return %0 : !torch.vtensor<[2],si16>
+}
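
A quick sanity check on the Range semantics exercised above: a minimal
standalone C++ sketch, assuming only the ONNX spec's output-size rule
max(ceil((limit - start) / delta), 0), with limit excluded. The helper
name rangeSize and the sample values are illustrative, not part of the
patch.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Per the ONNX Range spec, the output holds
    // max(ceil((limit - start) / delta), 0) elements; limit is excluded.
    static long rangeSize(double start, double limit, double delta) {
      return static_cast<long>(
          std::max(std::ceil((limit - start) / delta), 0.0));
    }

    int main() {
      // E.g. start=1, limit=5, delta=2 yields 2 elements, [1, 3] --
      // a !torch.vtensor<[2],f64>-shaped result like the tests above.
      std::printf("%ld\n", rangeSize(1.0, 5.0, 2.0)); // prints 2
      return 0;
    }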