mirror of https://github.com/llvm/torch-mlir
[ONNX] simplify shapes fed to broadcast in Expand lowering (#3756)
Addresses ~200 onnx model compile failures in <https://github.com/nod-ai/SHARK-TestSuite> related to <https://github.com/iree-org/iree/issues/18631>. This change simplifies the result of the generated broadcast op substantially, but reduces the case coverage slightly. The case which will become unsupported: - trying to actually broadcast a dynamic dim that is secretly 1. When does this case appear in practical scenarios? - for a model where onnx shape inference cannot figure out that a dim should be 1. Why do I think we should not support this case for now? 1. For all models with dynamic dim expand ops, the previous path uniformly generates uglier linalg IR (making it harder for IREE to fuse properly with other ops). 2. For models failing shape inference castastrophically enough to fail to see a dim is statically 1, we can try to apply constant folding in the onnx model before importing. Leaving this as a draft PR, since it may be more appropriate to fix the compilation failure in IREE rather than torch-mlir. ### Example of broadcast required in previous path: ```mlir %300 = linalg.generic {indexing_maps = [#map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%299 : tensor<?x12x?x?xi1>) { ^bb0(%out: i1): %306 = linalg.index 0 : index %307 = linalg.index 3 : index %308 = arith.index_cast %285 : i64 to index %309 = arith.cmpi eq, %308, %c1 : index %310 = arith.select %309, %c0, %306 : index %311 = arith.index_cast %286 : i64 to index %312 = arith.cmpi eq, %311, %c1 : index %313 = arith.select %312, %c0, %307 : index %extracted_79 = tensor.extract %reshape_78[%310, %c0, %c0, %313] : tensor<?x1x1x?xi1> linalg.yield %extracted_79 : i1 } -> tensor<?x12x?x?xi1> ``` ### Example of broadcast with simplified shape list: ```mlir %409 = linalg.generic {indexing_maps = [#map15, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%reshape_135 : tensor<?x1x1x?xi1>) outs(%408 : tensor<?x12x?x?xi1>) { ^bb0(%in: i1, %out: i1): linalg.yield %in : i1 } -> tensor<?x12x?x?xi1> ```pull/3766/head
parent
9ab0db5789
commit
f08bfc4ff8
|
@ -2521,7 +2521,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
return failure();
|
||||
|
||||
auto shapeSizes = shapeType.getSizes();
|
||||
int64_t dataRank = dataType.getSizes().size();
|
||||
ArrayRef<int64_t> dataShape = dataType.getSizes();
|
||||
int64_t dataRank = dataShape.size();
|
||||
int64_t shapeRank = shapeSizes.size();
|
||||
if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
|
||||
return failure();
|
||||
|
@ -2543,22 +2544,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
// we are using torch implementation Torch::AtenBroadcastToOp which
|
||||
// takes list of int
|
||||
for (int i = 0; i < shapeSizes[0]; i++) {
|
||||
// extract dim from shape
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
loc, selectResultType, shape, zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
Value selectDim = rewriter.create<Torch::AtenItemOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), extract);
|
||||
|
||||
if (i + rankDifference >= 0) {
|
||||
// compute dim to pass to broadcast op. For non-broadcastable dims,
|
||||
// pass -1
|
||||
Value dim;
|
||||
if (i + rankDifference >= 0 && dataShape[i + rankDifference] != 1) {
|
||||
// 1. if dataShape[i + rankDiff] > 1, then this cannot be
|
||||
// broadcasted
|
||||
// 2. we will explicitly disallow broadcasting dynamic dims that are
|
||||
// secretly 1.
|
||||
dim = rewriter.create<Torch::ConstantIntOp>(loc, -1);
|
||||
// Assert dataShape[i + rankDiff] >= selectDim. If both are
|
||||
// constant, this should fold out.
|
||||
Value iv =
|
||||
rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
|
||||
auto sz = rewriter.create<Torch::AtenSizeIntOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), data, iv);
|
||||
dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz);
|
||||
Value gtSelect =
|
||||
rewriter.create<Torch::AtenGeIntOp>(loc, sz, selectDim);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
loc, gtSelect,
|
||||
rewriter.getStringAttr(
|
||||
"onnx.Expand input has a dim that is not statically 1; "
|
||||
"expected this dim >= dim provided shape."));
|
||||
} else {
|
||||
// 1. excess selectDims get included in broadcast (shapeSizes[0] >
|
||||
// dataRank)
|
||||
// 2. selectDims which correspond to dataShape == 1 get included in
|
||||
// broadcast
|
||||
dim = selectDim;
|
||||
}
|
||||
|
||||
dimList.push_back(dim);
|
||||
}
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
|
|
|
@ -42,7 +42,7 @@ def import_onnx(contents):
|
|||
# Import the ONNX model proto from the file contents:
|
||||
raw_model = onnx.load_from_string(contents)
|
||||
# since it does not affect current e2e tests, data_prop is left false here
|
||||
model_proto = onnx.shape_inference.infer_shapes(raw_model)
|
||||
model_proto = onnx.shape_inference.infer_shapes(raw_model, data_prop=True)
|
||||
|
||||
# Import the ONNX module into an MLIR module:
|
||||
context = Context()
|
||||
|
|
|
@ -1608,16 +1608,13 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor
|
|||
// CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||
// CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
|
||||
// CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[Im1:.+]] = torch.constant.int -1
|
||||
// CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1_1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
|
||||
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
|
||||
return %0 : !torch.vtensor<[3,4],f32>
|
||||
|
@ -1634,16 +1631,15 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor
|
|||
// CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
|
||||
// CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
|
||||
// CHECK-NEXT: %[[Im1:.+]] = torch.constant.int -1
|
||||
// CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
|
||||
// CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK-NEXT: %[[GE:.+]] = torch.aten.ge.int
|
||||
// CHECK-NEXT: torch.runtime.assert %[[GE]]
|
||||
// CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
|
||||
// CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
|
||||
// CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
|
||||
// CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
|
||||
// CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
|
||||
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
|
||||
// CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]], %[[ITEM2]]
|
||||
// CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
|
||||
// CHECK: return %[[EXPAND]]
|
||||
%0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
|
||||
|
|
Loading…
Reference in New Issue