[ONNX] simplify shapes fed to broadcast in Expand lowering (#3756)

Addresses ~200 onnx model compile failures in
<https://github.com/nod-ai/SHARK-TestSuite> related to
<https://github.com/iree-org/iree/issues/18631>.

This change simplifies the result of the generated broadcast op
substantially, but reduces the case coverage slightly.

The case which will become unsupported: 
- trying to actually broadcast a dynamic dim that is secretly 1. 

When does this case appear in practical scenarios?
- for a model where onnx shape inference cannot figure out that a dim
should be 1.

Why do I think we should not support this case for now?
1. For all models with dynamic dim expand ops, the previous path
uniformly generates uglier linalg IR (making it harder for IREE to fuse
properly with other ops).
2. For models failing shape inference castastrophically enough to fail
to see a dim is statically 1, we can try to apply constant folding in
the onnx model before importing.

Leaving this as a draft PR, since it may be more appropriate to fix the
compilation failure in IREE rather than torch-mlir.

### Example of broadcast required in previous path:

```mlir
    %300 = linalg.generic {indexing_maps = [#map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%299 : tensor<?x12x?x?xi1>) {
    ^bb0(%out: i1):
      %306 = linalg.index 0 : index
      %307 = linalg.index 3 : index
      %308 = arith.index_cast %285 : i64 to index
      %309 = arith.cmpi eq, %308, %c1 : index
      %310 = arith.select %309, %c0, %306 : index
      %311 = arith.index_cast %286 : i64 to index
      %312 = arith.cmpi eq, %311, %c1 : index
      %313 = arith.select %312, %c0, %307 : index
      %extracted_79 = tensor.extract %reshape_78[%310, %c0, %c0, %313] : tensor<?x1x1x?xi1>
      linalg.yield %extracted_79 : i1
    } -> tensor<?x12x?x?xi1>
```

### Example of broadcast with simplified shape list:

```mlir
    %409 = linalg.generic {indexing_maps = [#map15, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%reshape_135 : tensor<?x1x1x?xi1>) outs(%408 : tensor<?x12x?x?xi1>) {
    ^bb0(%in: i1, %out: i1):
      linalg.yield %in : i1
    } -> tensor<?x12x?x?xi1>
```
pull/3764/head
zjgarvey 2024-10-03 18:11:51 -07:00 committed by GitHub
parent 9ab0db5789
commit f08bfc4ff8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 37 additions and 19 deletions

View File

@ -2521,7 +2521,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return failure(); return failure();
auto shapeSizes = shapeType.getSizes(); auto shapeSizes = shapeType.getSizes();
int64_t dataRank = dataType.getSizes().size(); ArrayRef<int64_t> dataShape = dataType.getSizes();
int64_t dataRank = dataShape.size();
int64_t shapeRank = shapeSizes.size(); int64_t shapeRank = shapeSizes.size();
if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize) if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
return failure(); return failure();
@ -2543,22 +2544,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
// we are using torch implementation Torch::AtenBroadcastToOp which // we are using torch implementation Torch::AtenBroadcastToOp which
// takes list of int // takes list of int
for (int i = 0; i < shapeSizes[0]; i++) { for (int i = 0; i < shapeSizes[0]; i++) {
// extract dim from shape
Value selectIndex = rewriter.create<Torch::ConstantIntOp>( Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(), loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>( Value extract = rewriter.create<Torch::AtenSelectIntOp>(
loc, selectResultType, shape, zero, selectIndex); loc, selectResultType, shape, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>( Value selectDim = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), extract); loc, rewriter.getType<Torch::IntType>(), extract);
// compute dim to pass to broadcast op. For non-broadcastable dims,
if (i + rankDifference >= 0) { // pass -1
Value dim;
if (i + rankDifference >= 0 && dataShape[i + rankDifference] != 1) {
// 1. if dataShape[i + rankDiff] > 1, then this cannot be
// broadcasted
// 2. we will explicitly disallow broadcasting dynamic dims that are
// secretly 1.
dim = rewriter.create<Torch::ConstantIntOp>(loc, -1);
// Assert dataShape[i + rankDiff] >= selectDim. If both are
// constant, this should fold out.
Value iv = Value iv =
rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference); rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
auto sz = rewriter.create<Torch::AtenSizeIntOp>( auto sz = rewriter.create<Torch::AtenSizeIntOp>(
loc, rewriter.getType<Torch::IntType>(), data, iv); loc, rewriter.getType<Torch::IntType>(), data, iv);
dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz); Value gtSelect =
rewriter.create<Torch::AtenGeIntOp>(loc, sz, selectDim);
rewriter.create<Torch::RuntimeAssertOp>(
loc, gtSelect,
rewriter.getStringAttr(
"onnx.Expand input has a dim that is not statically 1; "
"expected this dim >= dim provided shape."));
} else {
// 1. excess selectDims get included in broadcast (shapeSizes[0] >
// dataRank)
// 2. selectDims which correspond to dataShape == 1 get included in
// broadcast
dim = selectDim;
} }
dimList.push_back(dim); dimList.push_back(dim);
} }
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>( Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(

View File

@ -42,7 +42,7 @@ def import_onnx(contents):
# Import the ONNX model proto from the file contents: # Import the ONNX model proto from the file contents:
raw_model = onnx.load_from_string(contents) raw_model = onnx.load_from_string(contents)
# since it does not affect current e2e tests, data_prop is left false here # since it does not affect current e2e tests, data_prop is left false here
model_proto = onnx.shape_inference.infer_shapes(raw_model) model_proto = onnx.shape_inference.infer_shapes(raw_model, data_prop=True)
# Import the ONNX module into an MLIR module: # Import the ONNX module into an MLIR module:
context = Context() context = Context()

View File

@ -1608,16 +1608,13 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor
// CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> // CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
// CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int // CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
// CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> // CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
// CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int // CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 // CHECK-DAG: %[[Im1:.+]] = torch.constant.int -1
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int // CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int // CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1_1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32> // CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32> return %0 : !torch.vtensor<[3,4],f32>
@ -1634,16 +1631,15 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor
// CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1 // CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
// CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]] // CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
// CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] // CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
// CHECK-NEXT: %[[Im1:.+]] = torch.constant.int -1
// CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0 // CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
// CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]] // CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
// CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int // CHECK-NEXT: %[[GE:.+]] = torch.aten.ge.int
// CHECK-NEXT: torch.runtime.assert %[[GE]]
// CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2 // CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
// CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]] // CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
// CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]] // CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
// CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1 // CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]], %[[ITEM2]]
// CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
// CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
// CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]] // CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
// CHECK: return %[[EXPAND]] // CHECK: return %[[EXPAND]]
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>