From 7a0d0e954b145d28c6e495b5324d11cb03402f60 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 16 Feb 2024 13:05:44 -0800 Subject: [PATCH] [onnx] Fix onnx.gather lowering to use torch.aten.index_select (#2913) Onnx's gather maps directly to `torch.aten.index_select`. We should just use that path. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 182 +++++------------- projects/pt1/e2e_testing/xfail_sets.py | 5 +- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 43 ++--- 3 files changed, 66 insertions(+), 164 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 9b2f3673c..50d4fae53 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -471,146 +471,66 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.s64IntegerAttr(axis, "axis", 0)) return failure(); Location loc = binder.getLoc(); + auto ctx = binder.op->getContext(); + auto indicesTy = cast(indices.getType()); + auto dataTy = cast(data.getType()); + if (!dataTy || !dataTy.hasSizes()) + return failure(); + if (axis < 0) + axis += dataTy.getSizes().size(); - // 1. Get data shape and rank. - auto dataTensorType = data.getType().cast(); - if (!dataTensorType || !dataTensorType.hasSizes()) { - return rewriter.notifyMatchFailure(binder.op, - "Expect non empty input data"); - } - ArrayRef dataShape = dataTensorType.getSizes(); - unsigned dataRank = dataShape.size(); + Value index = rewriter.create( + loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis)); - // 2. Get indices shape and rank. - auto indexType = indices.getType().cast(); - if (!indexType || !indexType.hasSizes()) { - return rewriter.notifyMatchFailure(binder.op, - "Expect non empty index tensor"); - } - ArrayRef indexShape = indexType.getSizes(); - unsigned indexRank = indexShape.size(); + // Apply bounds checking on the input: + auto intTy = rewriter.getType(); + auto boolTy = rewriter.getType( + indicesTy.getSizes(), rewriter.getI1Type()); + Value zero = rewriter.create( + loc, intTy, rewriter.getI64IntegerAttr(0)); + Value one = rewriter.create( + loc, intTy, rewriter.getI64IntegerAttr(1)); + Value lt = + rewriter.create(loc, boolTy, indices, zero); + Value dim = + rewriter.create(loc, intTy, data, index); + Value add = rewriter.create(loc, indicesTy, + indices, dim, one); + indices = rewriter.create(loc, indicesTy, lt, + add, indices); - // 3. Compute total elements in the indices tensor, as we will collapse - // the indices tensor to a unary tensor. Also compute index shape and - // data shape tensors as they will be used for creating output types. - int64_t indexElemCount = 1; - for (int64_t dim : indexShape) { - if (dim == Torch::kUnknownSize) { - indexElemCount = Torch::kUnknownSize; + auto intListTy = rewriter.getType( + rewriter.getType()); + auto indicesSize = + rewriter.create(loc, intListTy, indices); + + // Determine the collapsed dim size: + auto indicesCt = 1; + for (auto sz : indicesTy.getSizes()) { + if (sz == Torch::kUnknownSize) { + indicesCt = Torch::kUnknownSize; break; } - indexElemCount *= dim; + + indicesCt *= sz; } - Value constOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - SmallVector indexShapeTensor; - Value indexElemCountVal = constOne; - for (unsigned i = 0; i < indexRank; ++i) { - Value indexDimVal = rewriter.create( - loc, indices, - rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); - indexShapeTensor.emplace_back(indexDimVal); - indexElemCountVal = rewriter.create( - loc, indexElemCountVal, indexDimVal); - } + auto flattenTy = rewriter.getType( + SmallVector{indicesCt}, indicesTy.getOptionalDtype()); + Value rank = rewriter.create(loc, intTy, indices); + Value end = rewriter.create(loc, rank, one); + indices = rewriter.create( + loc, flattenTy, indices, zero, end); - SmallVector dataShapeTensor; - for (unsigned i = 0; i < dataRank; ++i) { - dataShapeTensor.emplace_back(rewriter.create( - loc, data, - rewriter.create( - loc, rewriter.getI64IntegerAttr(i)))); - } + llvm::SmallVector gatherShape(dataTy.getSizes()); + gatherShape[axis] = indicesCt; - // Correct for negative axis: - if (axis < 0) - axis += dataRank; - - // 4. We can not directly perform torch.gather as the onnx.gather op - // collects the input data at different location of output compared to - // torch.gather op. The output of torch.gather and onnx.gather ops are - // indexed differently. - // check https://onnx.ai/onnx/operators/onnx__Gather.html for more - // details. So we will collapse indices tensor to a unary tensor and - // materialize to non-axis dimension of data tensor. For example, - // assuming indices is of shape (4, 5, 6), data is (8, 10, 11, 12) and - // axis=1. we will collapse indices into a (120,) unary tensor, - // materialize to non-axis dimension of data i.e. reshaping the unary - // indices tensor to (1, 120, 1, 1) and then perform the torch.gather - // operation. Now broadcast the output of gather operation to non-axis - // dimensions of data tensor. This would make the result of shape (8, - // 10, 120, 12). Post the broadcasting, expand the indices dimensions by - // reshaping (8, 10, 120, 12) to (8, 10, 4, 5, 6, 12) tensor, which is - // our expected final result. - SmallVector collapsedIndexShape(dataRank, 1); - collapsedIndexShape[axis] = indexElemCount; - Type collapsedIndexType = Torch::ValueTensorType::get( - indexType.getContext(), llvm::ArrayRef(collapsedIndexShape), - indexType.getOptionalDtype()); - - SmallVector collapsedIndexSize(dataRank, constOne); - collapsedIndexSize[axis] = indexElemCountVal; - auto collapsedIndexSizeList = - rewriter.create( - loc, - rewriter.getType( - rewriter.getType()), - collapsedIndexSize); - - auto collapsedIndices = rewriter.create( - loc, collapsedIndexType, indices, collapsedIndexSizeList); - - // 5. Compute gather result type and perform gather operation. - Type gatherResultType = Torch::ValueTensorType::get( - dataTensorType.getContext(), llvm::ArrayRef(collapsedIndexShape), - dataTensorType.getOptionalDtype()); - Value constAxis = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); - Value constFalse = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getBoolAttr(false)); - auto gatherOp = rewriter.create( - loc, gatherResultType, data, constAxis, collapsedIndices, - /*sparseGrad=*/constFalse); - - // 6. Broadcast the gather output to non-axis dimensions of data tensor. - SmallVector dataShapeVector(dataShape); - dataShapeVector[axis] = indexElemCount; - Type expandResultType = Torch::ValueTensorType::get( - dataTensorType.getContext(), llvm::ArrayRef(dataShapeVector), - dataTensorType.getOptionalDtype()); - - dataShapeTensor[axis] = indexElemCountVal; - auto expandSizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(data.getContext())), - dataShapeTensor); - auto expandedGather = rewriter.create( - loc, expandResultType, gatherOp, expandSizeList, - /*implicit=*/constFalse); - - // 7. Compute the result type of reshape op which expands the collapsed - // indices shapes back to the original indices shapes and reshape the - // output produced at step 6. This will produce our expected result of - // onnx.gather op. - SmallVector resultShapeTensor; - for (unsigned i = 0; i < dataRank; ++i) { - if (i == axis) { - resultShapeTensor.insert(resultShapeTensor.end(), - indexShapeTensor.begin(), - indexShapeTensor.end()); - continue; - } - resultShapeTensor.emplace_back(dataShapeTensor[i]); - } - auto resultSizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(data.getContext())), - resultShapeTensor); - - rewriter.replaceOpWithNewOp( - binder.op, resultType, expandedGather, resultSizeList); + auto gatherTy = rewriter.getType( + gatherShape, dataTy.getOptionalDtype()); + Value gather = rewriter.create( + loc, gatherTy, data, index, indices); + rewriter.replaceOpWithNewOp( + binder.op, resultType, gather, index, indicesSize); return success(); }); patterns.onOp( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 440b7d730..66fbc4158 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2134,17 +2134,14 @@ ONNX_XFAIL_SET = { "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", "ElementwiseSeluModule_basic", "EmbeddingModule1DIndices_basic", - "EmbeddingModuleI32Static_basic", "FlipNegativeIndexModule_basic", "HardsigmoidModule_basic", "HardsigmoidRandomModule_basic", "IndexSelectDynamicIndexSizeModule_basic", "IndexSelectDynamicInputSizeModule_basic", "IndexSelectDynamicModulebasic", - "IndexSelectNegativeDimModule_basic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", "IndexTensorStaticModule_basic", "IndexTensorStaticNonContiguousWithNoneModule_basic", "PixelShuffleModuleStaticRank4Float32_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 0a154db29..c5b28156a 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -39,35 +39,20 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[ // CHECK-LABEL: func.func @test_gather func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { - // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 - // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[ARG1_SIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0]] - // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1]], %[[ARG1_SIZE0]] - // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK: %[[ARG1_SIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_0]] - // CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[ARG1_SIZE1]] - // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[ARG1_SIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int - // CHECK: %[[MUL3:.+]] = torch.aten.mul.int %[[MUL2]], %[[ARG1_SIZE2]] : !torch.int, !torch.int -> !torch.int - // CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[ARG1_SIZE3:.+]] = torch.aten.size.int %arg1, %[[INT3]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int - // CHECK: %[[MUL4:.+]] = torch.aten.mul.int %[[MUL3]], %[[ARG1_SIZE3]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 - // CHECK: %[[ARG0_SIZE0:.+]] = torch.aten.size.int %arg0, %[[INT0_2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int - // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 - // CHECK: %[[ARG0_SIZE1:.+]] = torch.aten.size.int %arg0, %[[INT1_3]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int - // CHECK: %[[INT2_4:.+]] = torch.constant.int 2 - // CHECK: %[[ARG0_SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2_4]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int - // CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[VIEW1:.+]] = torch.aten.view %arg1, %[[LIST1]] : !torch.vtensor<[8,10,20,40],si64>, !torch.list -> !torch.vtensor<[64000,1,1],si64> - // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0_1]], %[[VIEW1]], %[[FALSE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.vtensor<[64000,1,1],si64>, !torch.bool -> !torch.vtensor<[64000,1,1],f32> - // CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[GATHER]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[64000,1,1],f32>, !torch.list, !torch.bool -> !torch.vtensor<[64000,4,5],f32> - // CHECK: %[[LIST3:.+]] = torch.prim.ListConstruct %[[ARG1_SIZE0]], %[[ARG1_SIZE1]], %[[ARG1_SIZE2]], %[[ARG1_SIZE3]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[RES:.+]] = torch.aten.view %[[EXPAND]], %[[LIST3]] : !torch.vtensor<[64000,4,5],f32>, !torch.list -> !torch.vtensor<[8,10,20,40,4,5],f32> - // CHECK: return %[[RES]] : !torch.vtensor<[8,10,20,40,4,5],f32> + // CHECK: %[[AXIS:.+]] = torch.constant.int 0 + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[ONE:.+]] = torch.constant.int 1 + // CHECK: %[[LT:.+]] = torch.aten.le.Scalar %arg1, %[[ZERO]] + // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] + // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 + // CHECK: %[[SZ:.+]] = torch.aten.size %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %[[SEL]] + // CHECK: %[[SUB:.+]] = torch.aten.sub.int %[[DIM]], %[[ONE]] + // CHECK: %[[FLAT:.+]] = torch.aten.flatten.using_ints %[[SEL]], %[[ZERO]], %[[SUB]] + // CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]] + // CHECK: %[[RES:.+]] = torch.aten.unflatten.int %[[ISEL]], %[[AXIS]], %[[SZ]] + // CHECK: return %[[RES]] %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> return %0 : !torch.vtensor<[8,10,20,40,4,5],f32> }