mirror of https://github.com/llvm/torch-mlir
[ONNXToTorch] Add conversion for Onnx range (#2752)
Implemented ONNX.Range. The spec says the data type for start, limit, delta are 0-D can be double, float, int16, int32, int64, All int types mapped to !torch.int and all float types mapped to !torch.float --------- Co-authored-by: Kumar Deepak <kumar@xilinx.com>pull/2592/merge
parent
09421b1cf3
commit
87389f0762
|
@ -27,6 +27,18 @@ using namespace mlir::torch::onnx_c;
|
||||||
// to be more normal and a direct translation vs a special case. This
|
// to be more normal and a direct translation vs a special case. This
|
||||||
// results in a lot of ONNX test cases that all reduce to the exact same
|
// results in a lot of ONNX test cases that all reduce to the exact same
|
||||||
// thing here, so we simplify.
|
// thing here, so we simplify.
|
||||||
|
|
||||||
|
// utilities
|
||||||
|
// Templatized function to get an item op of a type
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
|
||||||
|
Value &ofItem) {
|
||||||
|
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
|
||||||
|
rewriter.getType<T>(), ofItem);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
OnnxCustomOpConversionPattern &patterns) {
|
OnnxCustomOpConversionPattern &patterns) {
|
||||||
patterns.onOp("Reciprocal", 1,
|
patterns.onOp("Reciprocal", 1,
|
||||||
|
@ -1336,4 +1348,50 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
data, dimValueList);
|
data, dimValueList);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"Range", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
// ONNX.Range(start, limit, delta) -- limit is exclusive
|
||||||
|
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value start, limit, delta;
|
||||||
|
auto loc = binder.getLoc();
|
||||||
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
|
if (binder.tensorOperandAtIndex(start, 0) ||
|
||||||
|
binder.tensorOperandAtIndex(limit, 1) ||
|
||||||
|
binder.tensorOperandAtIndex(delta, 2) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Convert a 0-dimensional/Scalar Tensor ([]) to Scalar Torch Numeric
|
||||||
|
// Value torch.tensor(1.1) equivalent in ONNX to 1.1 as an example
|
||||||
|
// type of start, limit, delta can be one of: double, float, int16,
|
||||||
|
// int32, int64 Assuming start, limit and delta to be same type (could
|
||||||
|
// they be different?)
|
||||||
|
Torch::BaseTensorType startTensorType =
|
||||||
|
start.getType().cast<Torch::BaseTensorType>();
|
||||||
|
bool isFloatDType = startTensorType.getDtype().isF64() ||
|
||||||
|
startTensorType.getDtype().isF32();
|
||||||
|
bool isIntDType = startTensorType.getDtype().isInteger(16) ||
|
||||||
|
startTensorType.getDtype().isInteger(32) ||
|
||||||
|
startTensorType.getDtype().isInteger(64);
|
||||||
|
if (!isFloatDType && !isIntDType) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "Expected the start, limit, delta to be one of "
|
||||||
|
"double, float, int16, int32, int64");
|
||||||
|
}
|
||||||
|
Value scalarStart, scalarLimit, scalarDelta;
|
||||||
|
if (isFloatDType) {
|
||||||
|
scalarStart = getItemOp<Torch::FloatType>(binder, rewriter, start);
|
||||||
|
scalarLimit = getItemOp<Torch::FloatType>(binder, rewriter, limit);
|
||||||
|
scalarDelta = getItemOp<Torch::FloatType>(binder, rewriter, delta);
|
||||||
|
} else {
|
||||||
|
scalarStart = getItemOp<Torch::IntType>(binder, rewriter, start);
|
||||||
|
scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, limit);
|
||||||
|
scalarDelta = getItemOp<Torch::IntType>(binder, rewriter, delta);
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenArangeStartStepOp>(
|
||||||
|
binder.op, resultType, scalarStart, scalarLimit, scalarDelta, none,
|
||||||
|
none, none, none);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -1179,3 +1179,58 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
|
||||||
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32>
|
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32>
|
||||||
return %0 : !torch.vtensor<[2,3,1,4],f32>
|
return %0 : !torch.vtensor<[2,3,1,4],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_range_float64_type
|
||||||
|
func.func @test_range_float64_type(%arg0: !torch.vtensor<[],f64>, %arg1: !torch.vtensor<[],f64>, %arg2: !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.*]] torch.constant.none
|
||||||
|
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f64> -> !torch.float
|
||||||
|
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f64> -> !torch.float
|
||||||
|
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f64> -> !torch.float
|
||||||
|
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f64>
|
||||||
|
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f64>, !torch.vtensor<[],f64>, !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64>
|
||||||
|
return %0 : !torch.vtensor<[2],f64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_range_float32_type
|
||||||
|
func.func @test_range_float32_type(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.*]] torch.constant.none
|
||||||
|
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f32> -> !torch.float
|
||||||
|
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
|
||||||
|
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
|
||||||
|
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f32>
|
||||||
|
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32>
|
||||||
|
return %0 : !torch.vtensor<[2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_range_int64_type
|
||||||
|
func.func @test_range_int64_type(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.*]] torch.constant.none
|
||||||
|
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
|
||||||
|
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64>
|
||||||
|
return %0 : !torch.vtensor<[2],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_range_int32_type
|
||||||
|
func.func @test_range_int32_type(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.*]] torch.constant.none
|
||||||
|
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
|
||||||
|
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
|
||||||
|
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
|
||||||
|
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si32>
|
||||||
|
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32>
|
||||||
|
return %0 : !torch.vtensor<[2],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_range_int16_type
|
||||||
|
func.func @test_range_int16_type(%arg0: !torch.vtensor<[],si16>, %arg1: !torch.vtensor<[],si16>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.*]] torch.constant.none
|
||||||
|
// CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si16> -> !torch.int
|
||||||
|
// CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si16> -> !torch.int
|
||||||
|
// CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int
|
||||||
|
// CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si16>
|
||||||
|
%0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si16>, !torch.vtensor<[],si16>, !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16>
|
||||||
|
return %0 : !torch.vtensor<[2],si16>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue