support for onnx.expand operator (#2729)

maps onnx.expand to torch aten broadcast_to, three tests added

---------

Co-authored-by: Kumar Deepak <kumar@xilinx.com>
pull/2772/head
kumardeepakamd 2024-01-10 13:05:37 -08:00 committed by GitHub
parent 469c055190
commit 29569713f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 0 deletions

View File

@ -1213,6 +1213,51 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"Expand", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// uses ideas and code from onnx.Reshape
Torch::ValueTensorType resultType;
Value data, shape;
if (binder.tensorOperands(data, shape) ||
binder.tensorResultType(resultType))
return failure();
Torch::BaseTensorType shapeType =
shape.getType().cast<Torch::BaseTensorType>();
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = shapeType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
// Variable to store 1-D onnx shape tensor, shapeSizes[0] has the
// dimension size
auto shapeSizes =
dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
// A constant zero value
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
// Variable to store pytorch int list of shape (dimension)
SmallVector<Value> dimList;
// Convert the shape tensor from vector of int64_t to torch int list as
// we are using torch implementation Torch::AtenBroadcastToOp which
// takes list of int
for (int i = 0; i < shapeSizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, shape, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
dimList.push_back(dim);
}
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
dimList);
rewriter.replaceOpWithNewOp<Torch::AtenBroadcastToOp>(
binder.op, resultType, data, dimValueList);
return success();
});
patterns.onOp("Floor", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;

View File

@ -487,6 +487,7 @@ func.func @test_equal(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor
return %0 : !torch.vtensor<[3,4,5],i1>
}
// CHECK-LABEL: @test_floor_example
func.func @test_floor_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.floor %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
@ -800,6 +801,59 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a
return %0 : !torch.vtensor<[4,2,2],f32>
}
// CHECK-LABEL: @test_expand_dim2_shape2
func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>)
-> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si32> -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
// CHECK: torch.aten.item %2 : !torch.vtensor<[1],si32> -> !torch.int
// CHECK: torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.broadcast_to %arg0, %4 : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}
// CHECK-LABEL: @test_expand_dim2_shape3
func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.prim.ListConstruct %1, %3, %5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.broadcast_to %arg0, %6 : !torch.vtensor<[3,1],f32>, !torch.list<int> -> !torch.vtensor<[2,3,6],f32>
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
return %0 : !torch.vtensor<[2,3,6],f32>
}
// CHECK-LABEL: @test_expand_dim3_shape4
func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.prim.ListConstruct %1, %3, %5, %7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list<int> -> !torch.vtensor<[3,3,3,3],f32>
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32>
return %0 : !torch.vtensor<[3,3,3,3],f32>
}
// CHECK-LABEL: @test_dropout
func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32