mirror of https://github.com/llvm/torch-mlir
[onnx] Fix expand operation for dynamic shape max (#3001)
If the broadcast shape is length-1 at a dim while `?` in the input dim then we need to broadcast to the dynamic dim. This is equivalent to taking a max of two dimensions.pull/3003/head
parent
0723584936
commit
bd7f1baa42
|
@ -1495,23 +1495,34 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
patterns.onOp(
|
||||
"Expand", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
// uses ideas and code from onnx.Reshape
|
||||
auto loc = binder.getLoc();
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data, shape;
|
||||
if (binder.tensorOperands(data, shape) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
Torch::BaseTensorType shapeType =
|
||||
shape.getType().cast<Torch::BaseTensorType>();
|
||||
|
||||
auto dataType = cast<Torch::BaseTensorType>(data.getType());
|
||||
auto shapeType = cast<Torch::BaseTensorType>(shape.getType());
|
||||
if (!dataType.hasSizes() || !shapeType.hasSizes())
|
||||
return failure();
|
||||
|
||||
auto shapeSizes = shapeType.getSizes();
|
||||
int64_t dataRank = dataType.getSizes().size();
|
||||
int64_t shapeRank = shapeSizes.size();
|
||||
if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
|
||||
return failure();
|
||||
|
||||
auto rankDifference = dataRank - shapeSizes[0];
|
||||
|
||||
SmallVector<int64_t> selectSizes;
|
||||
Type selectResultType = shapeType.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
|
||||
// Variable to store 1-D onnx shape tensor, shapeSizes[0] has the
|
||||
// dimension size
|
||||
auto shapeSizes =
|
||||
dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
|
||||
// A constant zero value
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
// Variable to store pytorch int list of shape (dimension)
|
||||
SmallVector<Value> dimList;
|
||||
|
||||
|
@ -1520,12 +1531,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
// takes list of int
|
||||
for (int i = 0; i < shapeSizes[0]; i++) {
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
loc, rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selectResultType, shape, zero, selectIndex);
|
||||
loc, selectResultType, shape, zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
loc, rewriter.getType<Torch::IntType>(), extract);
|
||||
|
||||
if (i + rankDifference >= 0) {
|
||||
Value iv =
|
||||
rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
|
||||
auto sz = rewriter.create<Torch::AtenSizeIntOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), data, iv);
|
||||
dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz);
|
||||
}
|
||||
|
||||
dimList.push_back(dim);
|
||||
}
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
|
|
|
@ -1507,8 +1507,6 @@ ONNX_XFAIL_SET = {
|
|||
"ArangeStartOutDtypeModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
"BroadcastDynamicDimModule_basic",
|
||||
"BroadcastToModule_basic",
|
||||
"ExpandModule_basic",
|
||||
"MoveDimIntNegativeIndexModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
|
||||
|
|
|
@ -1164,15 +1164,21 @@ func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f3
|
|||
// CHECK-LABEL: @test_expand_dim2_shape2
|
||||
func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>)
|
||||
-> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||
// CHECK: torch.aten.item %2 : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.broadcast_to %arg0, %4 : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
|
||||
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||
// CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||
// CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
|
||||
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
|
||||
return %0 : !torch.vtensor<[3,4],f32>
|
||||
}
|
||||
|
@ -1181,47 +1187,31 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor
|
|||
|
||||
// CHECK-LABEL: @test_expand_dim2_shape3
|
||||
func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %1, %3, %5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.broadcast_to %arg0, %6 : !torch.vtensor<[3,1],f32>, !torch.list<int> -> !torch.vtensor<[2,3,6],f32>
|
||||
// CHECK: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[I0_0:.+]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I0_0]]
|
||||
// CHECK-NEXT: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]]
|
||||
// CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
|
||||
// CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
|
||||
// CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
|
||||
// CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
|
||||
// CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
|
||||
// CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
|
||||
// CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
|
||||
// CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
|
||||
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
|
||||
// CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
|
||||
// CHECK: return %[[EXPAND]]
|
||||
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
|
||||
return %0 : !torch.vtensor<[2,3,6],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_expand_dim3_shape4
|
||||
func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %1, %3, %5, %7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list<int> -> !torch.vtensor<[3,3,3,3],f32>
|
||||
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32>
|
||||
return %0 : !torch.vtensor<[3,3,3,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_dropout
|
||||
func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32
|
||||
|
|
Loading…
Reference in New Issue