mirror of https://github.com/llvm/torch-mlir
[ONNXToTorch] Add conversion for Onnx range (#2752)
Implemented ONNX.Range. Per the spec, start, limit, and delta are 0-D tensors whose data type can be double, float, int16, int32, or int64. All integer types are mapped to !torch.int and all float types to !torch.float. --------- Co-authored-by: Kumar Deepak <kumar@xilinx.com>
parent 09421b1cf3
commit 87389f0762
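The lowering added here can be summarized with a small sketch (the value names and the [2] result shape are illustrative, mirroring the lit tests further down): each 0-D operand is extracted with torch.aten.item and fed to torch.aten.arange.start_step, whose limit is likewise exclusive. For example, start=0, limit=4, delta=2 produces [0, 2]:

    // Before: ONNX Range over 0-D si64 tensors.
    %0 = torch.operator "onnx.Range"(%start, %limit, %delta) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64>

    // After: scalars extracted via torch.aten.item, then arange.start_step.
    %none = torch.constant.none
    %s = torch.aten.item %start : !torch.vtensor<[],si64> -> !torch.int
    %l = torch.aten.item %limit : !torch.vtensor<[],si64> -> !torch.int
    %d = torch.aten.item %delta : !torch.vtensor<[],si64> -> !torch.int
    %1 = torch.aten.arange.start_step %s, %l, %d, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>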
@@ -27,6 +27,18 @@ using namespace mlir::torch::onnx_c;
// to be more normal and a direct translation vs a special case. This
// results in a lot of ONNX test cases that all reduce to the exact same
// thing here, so we simplify.

// Utilities
// Templatized helper that creates a torch.aten.item op with the requested
// scalar result type.
namespace {
template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
                Value &ofItem) {
  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                            rewriter.getType<T>(), ofItem);
}
} // namespace

void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
    OnnxCustomOpConversionPattern &patterns) {
  patterns.onOp("Reciprocal", 1,
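When instantiated with Torch::FloatType or Torch::IntType, the helper above emits a torch.aten.item whose result type is !torch.float or !torch.int respectively. A minimal sketch of the emitted IR (value names hypothetical, forms taken from the tests in this commit):

    // getItemOp<Torch::FloatType>(binder, rewriter, t) on a 0-D f32 tensor:
    %f = torch.aten.item %t_f32 : !torch.vtensor<[],f32> -> !torch.float
    // getItemOp<Torch::IntType>(binder, rewriter, t) on a 0-D si32 tensor:
    %i = torch.aten.item %t_i32 : !torch.vtensor<[],si32> -> !torch.int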
@@ -1336,4 +1348,50 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            data, dimValueList);
        return success();
      });
  patterns.onOp(
      "Range", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        // ONNX.Range(start, limit, delta) -- limit is exclusive.
        Torch::ValueTensorType resultType;
        Value start, limit, delta;
        auto loc = binder.getLoc();
        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
        if (binder.tensorOperandAtIndex(start, 0) ||
            binder.tensorOperandAtIndex(limit, 1) ||
            binder.tensorOperandAtIndex(delta, 2) ||
            binder.tensorResultType(resultType))
          return failure();

        // Convert each 0-dimensional (scalar) tensor to a scalar Torch
        // numeric value, e.g. the ONNX equivalent of torch.tensor(1.1)
        // becomes the !torch.float 1.1. The type of start, limit, and delta
        // can be one of: double, float, int16, int32, int64. The ONNX spec
        // constrains all three inputs to the same type T, so inspecting
        // start suffices.
        Torch::BaseTensorType startTensorType =
            start.getType().cast<Torch::BaseTensorType>();
        bool isFloatDType = startTensorType.getDtype().isF64() ||
                            startTensorType.getDtype().isF32();
        bool isIntDType = startTensorType.getDtype().isInteger(16) ||
                          startTensorType.getDtype().isInteger(32) ||
                          startTensorType.getDtype().isInteger(64);
        if (!isFloatDType && !isIntDType) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected the start, limit, delta to be one of "
                         "double, float, int16, int32, int64");
        }
        Value scalarStart, scalarLimit, scalarDelta;
        if (isFloatDType) {
          scalarStart = getItemOp<Torch::FloatType>(binder, rewriter, start);
          scalarLimit = getItemOp<Torch::FloatType>(binder, rewriter, limit);
          scalarDelta = getItemOp<Torch::FloatType>(binder, rewriter, delta);
        } else {
          scalarStart = getItemOp<Torch::IntType>(binder, rewriter, start);
          scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, limit);
          scalarDelta = getItemOp<Torch::IntType>(binder, rewriter, delta);
        }
        rewriter.replaceOpWithNewOp<Torch::AtenArangeStartStepOp>(
            binder.op, resultType, scalarStart, scalarLimit, scalarDelta, none,
            none, none, none);
        return success();
      });
}
@@ -1179,3 +1179,58 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
  %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32>
  return %0 : !torch.vtensor<[2,3,1,4],f32>
}

// CHECK-LABEL: func.func @test_range_float64_type
func.func @test_range_float64_type(%arg0: !torch.vtensor<[],f64>, %arg1: !torch.vtensor<[],f64>, %arg2: !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f64> -> !torch.float
  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f64> -> !torch.float
  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f64> -> !torch.float
  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f64>
  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f64>, !torch.vtensor<[],f64>, !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64>
  return %0 : !torch.vtensor<[2],f64>
}

// CHECK-LABEL: func.func @test_range_float32_type
func.func @test_range_float32_type(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f32> -> !torch.float
  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f32>
  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32>
  return %0 : !torch.vtensor<[2],f32>
}

// CHECK-LABEL: func.func @test_range_int64_type
func.func @test_range_int64_type(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int
  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64>
  return %0 : !torch.vtensor<[2],si64>
}

// CHECK-LABEL: func.func @test_range_int32_type
func.func @test_range_int32_type(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si32>
  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32>
  return %0 : !torch.vtensor<[2],si32>
}

// CHECK-LABEL: func.func @test_range_int16_type
func.func @test_range_int16_type(%arg0: !torch.vtensor<[],si16>, %arg1: !torch.vtensor<[],si16>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si16> -> !torch.int
  // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si16> -> !torch.int
  // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int
  // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si16>
  %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si16>, !torch.vtensor<[],si16>, !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16>
  return %0 : !torch.vtensor<[2],si16>
}
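If these tests live in torch-mlir's TorchOnnxToTorch lit suite, they would typically be driven by a RUN line like the following (the exact flags are an assumption based on that suite's conventions, not part of this diff):

    // RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch -split-input-file -verify-diagnostics | FileCheck %s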