mirror of https://github.com/llvm/torch-mlir
support for onnx.expand operator (#2729)
maps onnx.expand to torch aten broadcast_to, three tests added --------- Co-authored-by: Kumar Deepak <kumar@xilinx.com>pull/2772/head
parent
469c055190
commit
29569713f3
|
@ -1213,6 +1213,51 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.op, resultType, operand);
|
binder.op, resultType, operand);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"Expand", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
// uses ideas and code from onnx.Reshape
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value data, shape;
|
||||||
|
if (binder.tensorOperands(data, shape) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
Torch::BaseTensorType shapeType =
|
||||||
|
shape.getType().cast<Torch::BaseTensorType>();
|
||||||
|
SmallVector<int64_t> selectSizes;
|
||||||
|
selectSizes.push_back(1);
|
||||||
|
Type selectResultType = shapeType.getWithSizesAndDtype(
|
||||||
|
llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
|
||||||
|
// Variable to store 1-D onnx shape tensor, shapeSizes[0] has the
|
||||||
|
// dimension size
|
||||||
|
auto shapeSizes =
|
||||||
|
dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
|
||||||
|
// A constant zero value
|
||||||
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||||
|
// Variable to store pytorch int list of shape (dimension)
|
||||||
|
SmallVector<Value> dimList;
|
||||||
|
|
||||||
|
// Convert the shape tensor from vector of int64_t to torch int list as
|
||||||
|
// we are using torch implementation Torch::AtenBroadcastToOp which
|
||||||
|
// takes list of int
|
||||||
|
for (int i = 0; i < shapeSizes[0]; i++) {
|
||||||
|
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||||
|
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
|
binder.getLoc(), selectResultType, shape, zero, selectIndex);
|
||||||
|
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||||
|
dimList.push_back(dim);
|
||||||
|
}
|
||||||
|
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||||
|
dimList);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenBroadcastToOp>(
|
||||||
|
binder.op, resultType, data, dimValueList);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp("Floor", 13,
|
patterns.onOp("Floor", 13,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
|
|
|
@ -487,6 +487,7 @@ func.func @test_equal(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor
|
||||||
return %0 : !torch.vtensor<[3,4,5],i1>
|
return %0 : !torch.vtensor<[3,4,5],i1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: @test_floor_example
|
// CHECK-LABEL: @test_floor_example
|
||||||
func.func @test_floor_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_floor_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.floor %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
// CHECK: torch.aten.floor %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
||||||
|
@ -800,6 +801,59 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a
|
||||||
return %0 : !torch.vtensor<[4,2,2],f32>
|
return %0 : !torch.vtensor<[4,2,2],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_expand_dim2_shape2
|
||||||
|
func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>)
|
||||||
|
-> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si32> -> !torch.int
|
||||||
|
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
|
||||||
|
// CHECK: torch.aten.item %2 : !torch.vtensor<[1],si32> -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: torch.aten.broadcast_to %arg0, %4 : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
|
||||||
|
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4],f32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: @test_expand_dim2_shape3
|
||||||
|
func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %1, %3, %5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: torch.aten.broadcast_to %arg0, %6 : !torch.vtensor<[3,1],f32>, !torch.list<int> -> !torch.vtensor<[2,3,6],f32>
|
||||||
|
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,3,6],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_expand_dim3_shape4
|
||||||
|
func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %1, %3, %5, %7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list<int> -> !torch.vtensor<[3,3,3,3],f32>
|
||||||
|
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,3,3,3],f32>
|
||||||
|
}
|
||||||
// CHECK-LABEL: @test_dropout
|
// CHECK-LABEL: @test_dropout
|
||||||
func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32
|
// CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32
|
||||||
|
|
Loading…
Reference in New Issue