[ONNXToTorch] Add conversion for Onnx range (#2752)

Implemented ONNX.Range. The spec says the data type for start, limit,
delta are 0-D can be double, float, int16, int32, int64, All int types
mapped to !torch.int and all float types mapped to !torch.float

---------

Co-authored-by: Kumar Deepak <kumar@xilinx.com>
pull/2592/merge
kumardeepakamd 2024-01-15 11:26:46 -08:00 committed by GitHub
parent 09421b1cf3
commit 87389f0762
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 113 additions and 0 deletions

View File

@ -27,6 +27,18 @@ using namespace mlir::torch::onnx_c;
// to be more normal and a direct translation vs a special case. This // to be more normal and a direct translation vs a special case. This
// results in a lot of ONNX test cases that all reduce to the exact same // results in a lot of ONNX test cases that all reduce to the exact same
// thing here, so we simplify. // thing here, so we simplify.
// utilities
// Templatized function to get an item op of a type
namespace {
template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
Value &ofItem) {
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
rewriter.getType<T>(), ofItem);
}
} // namespace
void mlir::torch::onnx_c::populateDefaultDomainQtoZ( void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
OnnxCustomOpConversionPattern &patterns) { OnnxCustomOpConversionPattern &patterns) {
patterns.onOp("Reciprocal", 1, patterns.onOp("Reciprocal", 1,
@ -1336,4 +1348,50 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
data, dimValueList); data, dimValueList);
return success(); return success();
}); });
patterns.onOp(
"Range", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// ONNX.Range(start, limit, delta) -- limit is exclusive
Torch::ValueTensorType resultType;
Value start, limit, delta;
auto loc = binder.getLoc();
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
if (binder.tensorOperandAtIndex(start, 0) ||
binder.tensorOperandAtIndex(limit, 1) ||
binder.tensorOperandAtIndex(delta, 2) ||
binder.tensorResultType(resultType))
return failure();
// Convert a 0-dimensional/Scalar Tensor ([]) to Scalar Torch Numeric
// Value torch.tensor(1.1) equivalent in ONNX to 1.1 as an example
// type of start, limit, delta can be one of: double, float, int16,
// int32, int64 Assuming start, limit and delta to be same type (could
// they be different?)
Torch::BaseTensorType startTensorType =
start.getType().cast<Torch::BaseTensorType>();
bool isFloatDType = startTensorType.getDtype().isF64() ||
startTensorType.getDtype().isF32();
bool isIntDType = startTensorType.getDtype().isInteger(16) ||
startTensorType.getDtype().isInteger(32) ||
startTensorType.getDtype().isInteger(64);
if (!isFloatDType && !isIntDType) {
return rewriter.notifyMatchFailure(
binder.op, "Expected the start, limit, delta to be one of "
"double, float, int16, int32, int64");
}
Value scalarStart, scalarLimit, scalarDelta;
if (isFloatDType) {
scalarStart = getItemOp<Torch::FloatType>(binder, rewriter, start);
scalarLimit = getItemOp<Torch::FloatType>(binder, rewriter, limit);
scalarDelta = getItemOp<Torch::FloatType>(binder, rewriter, delta);
} else {
scalarStart = getItemOp<Torch::IntType>(binder, rewriter, start);
scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, limit);
scalarDelta = getItemOp<Torch::IntType>(binder, rewriter, delta);
}
rewriter.replaceOpWithNewOp<Torch::AtenArangeStartStepOp>(
binder.op, resultType, scalarStart, scalarLimit, scalarDelta, none,
none, none, none);
return success();
});
} }

View File

@ -1179,3 +1179,58 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32>
return %0 : !torch.vtensor<[2,3,1,4],f32> return %0 : !torch.vtensor<[2,3,1,4],f32>
} }
// CHECK-LABEL: func.func @test_range_float64_type
func.func @test_range_float64_type(%arg0: !torch.vtensor<[],f64>, %arg1: !torch.vtensor<[],f64>, %arg2: !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] torch.constant.none
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f64> -> !torch.float
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f64> -> !torch.float
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f64> -> !torch.float
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f64>
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f64>, !torch.vtensor<[],f64>, !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64>
return %0 : !torch.vtensor<[2],f64>
}
// CHECK-LABEL: func.func @test_range_float32_type
func.func @test_range_float32_type(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] torch.constant.none
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f32>
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32>
return %0 : !torch.vtensor<[2],f32>
}
// CHECK-LABEL: func.func @test_range_int64_type
func.func @test_range_int64_type(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] torch.constant.none
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64>
return %0 : !torch.vtensor<[2],si64>
}
// CHECK-LABEL: func.func @test_range_int32_type
func.func @test_range_int32_type(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] torch.constant.none
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si32>
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32>
return %0 : !torch.vtensor<[2],si32>
}
// CHECK-LABEL: func.func @test_range_int16_type
func.func @test_range_int16_type(%arg0: !torch.vtensor<[],si16>, %arg1: !torch.vtensor<[],si16>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] torch.constant.none
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si16> -> !torch.int
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si16> -> !torch.int
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si16>
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si16>, !torch.vtensor<[],si16>, !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16>
return %0 : !torch.vtensor<[2],si16>
}