[ONNX][MLIR] Add support for onnx.gather op (#2726)

This commit adds support for the onnx.Gather op in the ONNX-to-Torch conversion pipeline.
https://github.com/nod-ai/SHARK-Turbine/issues/242
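
For orientation, onnx.Gather selects whole slices of the data tensor along
the axis dimension, one slice per element of an arbitrary-rank indices
tensor. A minimal C++ reference sketch of the axis=0, rank-2 case (names
and types here are illustrative only, not part of this patch; negative
indices are omitted):

#include <cstdint>
#include <vector>

// out[j] = data[indices[j]]: each output row is the data row selected by
// indices[j]. For higher-rank data the same selection applies slice-wise.
std::vector<std::vector<float>>
gatherAxis0(const std::vector<std::vector<float>> &data,
            const std::vector<int64_t> &indices) {
  std::vector<std::vector<float>> out;
  out.reserve(indices.size());
  for (int64_t idx : indices)
    out.push_back(data[static_cast<size_t>(idx)]);
  return out;
}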

Signed-off-by: Gaurav Shukla <gaurav.shukla@amd.com>
Gaurav Shukla 2024-01-19 21:58:29 +05:30 committed by GitHub
parent 704cfdaf08
commit 3b85c70748
2 changed files with 189 additions and 4 deletions


@@ -275,10 +275,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
for (uint64_t i = 1; i < operands.size(); i++) {
result = rewriter.create<Torch::AtenMaximumOp>(
binder.getLoc(), resultType, result, operands[i]);
}
-rewriter.replaceOp(
-binder.op, result.getDefiningOp());
-return success();
-}
+rewriter.replaceOp(binder.op, result.getDefiningOp());
+return success();
+});
patterns.onOp("Min", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -334,6 +333,155 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"Gather", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data, indices;
int64_t axis;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorOperandAtIndex(indices, 1) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(axis, "axis", 0))
return failure();
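// Per the ONNX spec, the axis attribute defaults to 0 and may be negative,
// counting from the back; it is normalized below.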
Location loc = binder.getLoc();
// 1. Get data shape and rank.
auto dataTensorType = data.getType().dyn_cast<Torch::ValueTensorType>();
if (!dataTensorType || !dataTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"expected data tensor with known sizes");
}
ArrayRef<int64_t> dataShape = dataTensorType.getSizes();
unsigned dataRank = dataShape.size();
// Normalize a negative axis (allowed in [-dataRank, dataRank - 1]) to a
// non-negative dim before it is used to index shape vectors.
if (axis < 0)
axis += dataRank;
if (axis < 0 || axis >= static_cast<int64_t>(dataRank))
return rewriter.notifyMatchFailure(binder.op, "axis is out of range");
// 2. Get indices shape and rank.
auto indexType = indices.getType().dyn_cast<Torch::ValueTensorType>();
if (!indexType || !indexType.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"expected indices tensor with known sizes");
}
ArrayRef<int64_t> indexShape = indexType.getSizes();
unsigned indexRank = indexShape.size();
// 3. Compute the total number of elements in the indices tensor, since we
// will flatten it into a single dimension. Also materialize the index and
// data shapes as runtime values, as they are needed to build result types.
int64_t indexElemCount = 1;
for (int64_t dim : indexShape) {
if (dim == Torch::kUnknownSize) {
indexElemCount = Torch::kUnknownSize;
break;
}
indexElemCount *= dim;
}
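// A single dynamic dim poisons indexElemCount to kUnknownSize, keeping the
// static result types conservative; the exact runtime product is still
// computed below as indexElemCountVal.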
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
SmallVector<Value> indexShapeTensor;
Value indexElemCountVal = constOne;
for (unsigned i = 0; i < indexRank; ++i) {
Value indexDimVal = rewriter.create<Torch::AtenSizeIntOp>(
loc, indices,
rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
indexShapeTensor.emplace_back(indexDimVal);
indexElemCountVal = rewriter.create<Torch::AtenMulIntOp>(
loc, indexElemCountVal, indexDimVal);
}
SmallVector<Value> dataShapeTensor;
for (unsigned i = 0; i < dataRank; ++i) {
dataShapeTensor.emplace_back(rewriter.create<Torch::AtenSizeIntOp>(
loc, data,
rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i))));
}
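// dataShapeTensor holds the runtime size of every data dim; the axis
// entry is overwritten with the flattened index count before use.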
// 4. We cannot use torch.aten.gather directly, because onnx.Gather and
// torch.aten.gather index their outputs differently; see
// https://onnx.ai/onnx/operators/onnx__Gather.html for details. Instead,
// flatten the indices tensor, place the flattened dimension along the
// axis dimension of the data tensor, broadcast it along the non-axis
// dimensions, and then gather. For example, with indices of shape
// (4, 5, 6), data of shape (8, 10, 11, 12), and axis=1: flatten indices
// to (120,), reshape to (1, 120, 1, 1), expand to (8, 120, 11, 12), and
// gather along dim 1 to get a (8, 120, 11, 12) result. Finally, expand
// the flattened index dimension back by reshaping (8, 120, 11, 12) to
// (8, 4, 5, 6, 11, 12), the expected onnx.Gather result.
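// Concrete trace with the shapes of the lit test below: data (3, 4, 5),
// indices (8, 10, 20, 40), axis=0. View indices to (64000, 1, 1), expand
// to (64000, 4, 5), gather along dim 0 to get (64000, 4, 5), and view the
// result as (8, 10, 20, 40, 4, 5).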
SmallVector<int64_t> collapsedIndexShape(dataRank, 1);
collapsedIndexShape[axis] = indexElemCount;
Type collapsedIndexType = Torch::ValueTensorType::get(
indexType.getContext(), llvm::ArrayRef(collapsedIndexShape),
indexType.getOptionalDtype());
SmallVector<Value> collapsedIndexSize(dataRank, constOne);
collapsedIndexSize[axis] = indexElemCountVal;
auto collapsedIndexSizeList =
rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
collapsedIndexSize);
auto collapsedIndices = rewriter.create<Torch::AtenViewOp>(
loc, collapsedIndexType, indices, collapsedIndexSizeList);
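// collapsedIndices now has rank dataRank, with the flattened indices on
// the axis dim and size 1 elsewhere, ready to be broadcast below.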
// 5. Broadcast the collapsed indices along the non-axis dimensions of the
// data tensor before gathering. torch.aten.gather reads one index per
// output element (out[i][j][k] = data[i][index[i][j][k]][k] for dim = 1),
// so every non-axis position needs its own copy of the flattened indices;
// expanding the gather output instead would only replicate the values
// gathered from a single slice.
SmallVector<int64_t> expandedIndexShape(dataShape);
expandedIndexShape[axis] = indexElemCount;
Type expandedIndexType = Torch::ValueTensorType::get(
indexType.getContext(), llvm::ArrayRef(expandedIndexShape),
indexType.getOptionalDtype());
dataShapeTensor[axis] = indexElemCountVal;
auto expandSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
dataShapeTensor);
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getType<Torch::BoolType>(), rewriter.getBoolAttr(false));
auto expandedIndices = rewriter.create<Torch::AtenExpandOp>(
loc, expandedIndexType, collapsedIndices, expandSizeList,
/*implicit=*/constFalse);
// 6. Perform torch.aten.gather along the axis; the result has the
// expanded index shape with the data dtype.
Type gatherResultType = Torch::ValueTensorType::get(
dataTensorType.getContext(), llvm::ArrayRef(expandedIndexShape),
dataTensorType.getOptionalDtype());
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
auto gatherOp = rewriter.create<Torch::AtenGatherOp>(
loc, gatherResultType, data, constAxis, expandedIndices,
/*sparseGrad=*/constFalse);
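// Note: sparse_grad only affects autograd behavior in PyTorch; it has no
// effect on forward semantics, so it is hard-coded to false here.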
// 7. Compute the result type of the reshape op, which expands the
// collapsed index dimension back to the original indices shape, and
// reshape the gather output from step 6. This produces the expected
// result of the onnx.Gather op.
SmallVector<Value> resultShapeTensor;
for (unsigned i = 0; i < dataRank; ++i) {
if (static_cast<int64_t>(i) == axis) {
resultShapeTensor.insert(resultShapeTensor.end(),
indexShapeTensor.begin(),
indexShapeTensor.end());
continue;
}
resultShapeTensor.emplace_back(dataShapeTensor[i]);
}
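// The final shape keeps the data sizes on non-axis dims and substitutes
// the original indices dims for the axis dim, so the output rank is
// dataRank + indexRank - 1, as the ONNX Gather spec requires.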
auto resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
resultShapeTensor);
rewriter.replaceOpWithNewOp<Torch::AtenViewOp>(
binder.op, resultType, gatherOp, resultSizeList);
return success();
});
patterns.onOp(
"GatherElements", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {


@@ -37,6 +37,43 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[
// -----
// CHECK-LABEL: func.func @test_gather
func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[ARG1_SIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0]]
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1]], %[[ARG1_SIZE0]]
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
// CHECK: %[[ARG1_SIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_0]]
// CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[ARG1_SIZE1]]
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[ARG1_SIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int
// CHECK: %[[MUL3:.+]] = torch.aten.mul.int %[[MUL2]], %[[ARG1_SIZE2]] : !torch.int, !torch.int -> !torch.int
// CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[ARG1_SIZE3:.+]] = torch.aten.size.int %arg1, %[[INT3]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int
// CHECK: %[[MUL4:.+]] = torch.aten.mul.int %[[MUL3]], %[[ARG1_SIZE3]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT0_2:.+]] = torch.constant.int 0
// CHECK: %[[ARG0_SIZE0:.+]] = torch.aten.size.int %arg0, %[[INT0_2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
// CHECK: %[[INT1_3:.+]] = torch.constant.int 1
// CHECK: %[[ARG0_SIZE1:.+]] = torch.aten.size.int %arg0, %[[INT1_3]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
// CHECK: %[[INT2_4:.+]] = torch.constant.int 2
// CHECK: %[[ARG0_SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2_4]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
// CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VIEW1:.+]] = torch.aten.view %arg1, %[[LIST1]] : !torch.vtensor<[8,10,20,40],si64>, !torch.list<int> -> !torch.vtensor<[64000,1,1],si64>
// CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[VIEW1]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[64000,1,1],si64>, !torch.list<int>, !torch.bool -> !torch.vtensor<[64000,4,5],si64>
// CHECK: %[[INT0_1:.+]] = torch.constant.int 0
// CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0_1]], %[[EXPAND]], %[[FALSE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.vtensor<[64000,4,5],si64>, !torch.bool -> !torch.vtensor<[64000,4,5],f32>
// CHECK: %[[LIST3:.+]] = torch.prim.ListConstruct %[[ARG1_SIZE0]], %[[ARG1_SIZE1]], %[[ARG1_SIZE2]], %[[ARG1_SIZE3]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.+]] = torch.aten.view %[[GATHER]], %[[LIST3]] : !torch.vtensor<[64000,4,5],f32>, !torch.list<int> -> !torch.vtensor<[8,10,20,40,4,5],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[8,10,20,40,4,5],f32>
%0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32>
return %0 : !torch.vtensor<[8,10,20,40,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_gather_elements
func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0