mirror of https://github.com/llvm/torch-mlir
[onnx] Fix onnx.gather lowering to use torch.aten.index_select (#2913)
ONNX's gather maps directly to `torch.aten.index_select`. We should just use that path.
pull/2922/head
parent
468c533942
commit
7a0d0e954b
|
@ -471,146 +471,66 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
binder.s64IntegerAttr(axis, "axis", 0))
|
binder.s64IntegerAttr(axis, "axis", 0))
|
||||||
return failure();
|
return failure();
|
||||||
Location loc = binder.getLoc();
|
Location loc = binder.getLoc();
|
||||||
|
auto ctx = binder.op->getContext();
|
||||||
|
auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
|
||||||
|
auto dataTy = cast<Torch::ValueTensorType>(data.getType());
|
||||||
|
if (!dataTy || !dataTy.hasSizes())
|
||||||
|
return failure();
|
||||||
|
if (axis < 0)
|
||||||
|
axis += dataTy.getSizes().size();
|
||||||
|
|
||||||
// 1. Get data shape and rank.
|
Value index = rewriter.create<Torch::ConstantIntOp>(
|
||||||
auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
|
loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis));
|
||||||
if (!dataTensorType || !dataTensorType.hasSizes()) {
|
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
|
||||||
"Expect non empty input data");
|
|
||||||
}
|
|
||||||
ArrayRef<int64_t> dataShape = dataTensorType.getSizes();
|
|
||||||
unsigned dataRank = dataShape.size();
|
|
||||||
|
|
||||||
// 2. Get indices shape and rank.
|
// Apply bounds checking on the input:
|
||||||
auto indexType = indices.getType().cast<Torch::ValueTensorType>();
|
auto intTy = rewriter.getType<Torch::IntType>();
|
||||||
if (!indexType || !indexType.hasSizes()) {
|
auto boolTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
indicesTy.getSizes(), rewriter.getI1Type());
|
||||||
"Expect non empty index tensor");
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
}
|
loc, intTy, rewriter.getI64IntegerAttr(0));
|
||||||
ArrayRef<int64_t> indexShape = indexType.getSizes();
|
Value one = rewriter.create<Torch::ConstantIntOp>(
|
||||||
unsigned indexRank = indexShape.size();
|
loc, intTy, rewriter.getI64IntegerAttr(1));
|
||||||
|
Value lt =
|
||||||
|
rewriter.create<Torch::AtenLeScalarOp>(loc, boolTy, indices, zero);
|
||||||
|
Value dim =
|
||||||
|
rewriter.create<Torch::AtenSizeIntOp>(loc, intTy, data, index);
|
||||||
|
Value add = rewriter.create<Torch::AtenAddScalarOp>(loc, indicesTy,
|
||||||
|
indices, dim, one);
|
||||||
|
indices = rewriter.create<Torch::AtenWhereSelfOp>(loc, indicesTy, lt,
|
||||||
|
add, indices);
|
||||||
|
|
||||||
// 3. Compute total elements in the indices tensor, as we will collapse
|
auto intListTy = rewriter.getType<Torch::ListType>(
|
||||||
// the indices tensor to a unary tensor. Also compute index shape and
|
rewriter.getType<Torch::IntType>());
|
||||||
// data shape tensors as they will be used for creating output types.
|
auto indicesSize =
|
||||||
int64_t indexElemCount = 1;
|
rewriter.create<Torch::AtenSizeOp>(loc, intListTy, indices);
|
||||||
for (int64_t dim : indexShape) {
|
|
||||||
if (dim == Torch::kUnknownSize) {
|
// Determine the collapsed dim size:
|
||||||
indexElemCount = Torch::kUnknownSize;
|
auto indicesCt = 1;
|
||||||
|
for (auto sz : indicesTy.getSizes()) {
|
||||||
|
if (sz == Torch::kUnknownSize) {
|
||||||
|
indicesCt = Torch::kUnknownSize;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
indexElemCount *= dim;
|
|
||||||
|
indicesCt *= sz;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
auto flattenTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
loc, rewriter.getI64IntegerAttr(1));
|
SmallVector<int64_t>{indicesCt}, indicesTy.getOptionalDtype());
|
||||||
SmallVector<Value> indexShapeTensor;
|
Value rank = rewriter.create<Torch::AtenDimOp>(loc, intTy, indices);
|
||||||
Value indexElemCountVal = constOne;
|
Value end = rewriter.create<Torch::AtenSubIntOp>(loc, rank, one);
|
||||||
for (unsigned i = 0; i < indexRank; ++i) {
|
indices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||||
Value indexDimVal = rewriter.create<Torch::AtenSizeIntOp>(
|
loc, flattenTy, indices, zero, end);
|
||||||
loc, indices,
|
|
||||||
rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(i)));
|
|
||||||
indexShapeTensor.emplace_back(indexDimVal);
|
|
||||||
indexElemCountVal = rewriter.create<Torch::AtenMulIntOp>(
|
|
||||||
loc, indexElemCountVal, indexDimVal);
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> dataShapeTensor;
|
llvm::SmallVector<int64_t> gatherShape(dataTy.getSizes());
|
||||||
for (unsigned i = 0; i < dataRank; ++i) {
|
gatherShape[axis] = indicesCt;
|
||||||
dataShapeTensor.emplace_back(rewriter.create<Torch::AtenSizeIntOp>(
|
|
||||||
loc, data,
|
|
||||||
rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(i))));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Correct for negative axis:
|
auto gatherTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
if (axis < 0)
|
gatherShape, dataTy.getOptionalDtype());
|
||||||
axis += dataRank;
|
Value gather = rewriter.create<Torch::AtenIndexSelectOp>(
|
||||||
|
loc, gatherTy, data, index, indices);
|
||||||
// 4. We can not directly perform torch.gather as the onnx.gather op
|
rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
|
||||||
// collects the input data at different location of output compared to
|
binder.op, resultType, gather, index, indicesSize);
|
||||||
// torch.gather op. The output of torch.gather and onnx.gather ops are
|
|
||||||
// indexed differently.
|
|
||||||
// check https://onnx.ai/onnx/operators/onnx__Gather.html for more
|
|
||||||
// details. So we will collapse indices tensor to a unary tensor and
|
|
||||||
// materialize to non-axis dimension of data tensor. For example,
|
|
||||||
// assuming indices is of shape (4, 5, 6), data is (8, 10, 11, 12) and
|
|
||||||
// axis=1. we will collapse indices into a (120,) unary tensor,
|
|
||||||
// materialize to non-axis dimension of data i.e. reshaping the unary
|
|
||||||
// indices tensor to (1, 120, 1, 1) and then perform the torch.gather
|
|
||||||
// operation. Now broadcast the output of gather operation to non-axis
|
|
||||||
// dimensions of data tensor. This would make the result of shape (8,
|
|
||||||
// 10, 120, 12). Post the broadcasting, expand the indices dimensions by
|
|
||||||
// reshaping (8, 10, 120, 12) to (8, 10, 4, 5, 6, 12) tensor, which is
|
|
||||||
// our expected final result.
|
|
||||||
SmallVector<int64_t> collapsedIndexShape(dataRank, 1);
|
|
||||||
collapsedIndexShape[axis] = indexElemCount;
|
|
||||||
Type collapsedIndexType = Torch::ValueTensorType::get(
|
|
||||||
indexType.getContext(), llvm::ArrayRef(collapsedIndexShape),
|
|
||||||
indexType.getOptionalDtype());
|
|
||||||
|
|
||||||
SmallVector<Value> collapsedIndexSize(dataRank, constOne);
|
|
||||||
collapsedIndexSize[axis] = indexElemCountVal;
|
|
||||||
auto collapsedIndexSizeList =
|
|
||||||
rewriter.create<Torch::PrimListConstructOp>(
|
|
||||||
loc,
|
|
||||||
rewriter.getType<Torch::ListType>(
|
|
||||||
rewriter.getType<Torch::IntType>()),
|
|
||||||
collapsedIndexSize);
|
|
||||||
|
|
||||||
auto collapsedIndices = rewriter.create<Torch::AtenViewOp>(
|
|
||||||
loc, collapsedIndexType, indices, collapsedIndexSizeList);
|
|
||||||
|
|
||||||
// 5. Compute gather result type and perform gather operation.
|
|
||||||
Type gatherResultType = Torch::ValueTensorType::get(
|
|
||||||
dataTensorType.getContext(), llvm::ArrayRef(collapsedIndexShape),
|
|
||||||
dataTensorType.getOptionalDtype());
|
|
||||||
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
|
|
||||||
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
|
|
||||||
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
|
|
||||||
rewriter.getBoolAttr(false));
|
|
||||||
auto gatherOp = rewriter.create<Torch::AtenGatherOp>(
|
|
||||||
loc, gatherResultType, data, constAxis, collapsedIndices,
|
|
||||||
/*sparseGrad=*/constFalse);
|
|
||||||
|
|
||||||
// 6. Broadcast the gather output to non-axis dimensions of data tensor.
|
|
||||||
SmallVector<int64_t> dataShapeVector(dataShape);
|
|
||||||
dataShapeVector[axis] = indexElemCount;
|
|
||||||
Type expandResultType = Torch::ValueTensorType::get(
|
|
||||||
dataTensorType.getContext(), llvm::ArrayRef(dataShapeVector),
|
|
||||||
dataTensorType.getOptionalDtype());
|
|
||||||
|
|
||||||
dataShapeTensor[axis] = indexElemCountVal;
|
|
||||||
auto expandSizeList = rewriter.create<Torch::PrimListConstructOp>(
|
|
||||||
loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
|
|
||||||
dataShapeTensor);
|
|
||||||
auto expandedGather = rewriter.create<Torch::AtenExpandOp>(
|
|
||||||
loc, expandResultType, gatherOp, expandSizeList,
|
|
||||||
/*implicit=*/constFalse);
|
|
||||||
|
|
||||||
// 7. Compute the result type of reshape op which expands the collapsed
|
|
||||||
// indices shapes back to the original indices shapes and reshape the
|
|
||||||
// output produced at step 6. This will produce our expected result of
|
|
||||||
// onnx.gather op.
|
|
||||||
SmallVector<Value> resultShapeTensor;
|
|
||||||
for (unsigned i = 0; i < dataRank; ++i) {
|
|
||||||
if (i == axis) {
|
|
||||||
resultShapeTensor.insert(resultShapeTensor.end(),
|
|
||||||
indexShapeTensor.begin(),
|
|
||||||
indexShapeTensor.end());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
resultShapeTensor.emplace_back(dataShapeTensor[i]);
|
|
||||||
}
|
|
||||||
auto resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
|
|
||||||
loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
|
|
||||||
resultShapeTensor);
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenViewOp>(
|
|
||||||
binder.op, resultType, expandedGather, resultSizeList);
|
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
|
|
|
@ -2134,17 +2134,14 @@ ONNX_XFAIL_SET = {
|
||||||
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
|
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
|
||||||
"ElementwiseSeluModule_basic",
|
"ElementwiseSeluModule_basic",
|
||||||
"EmbeddingModule1DIndices_basic",
|
"EmbeddingModule1DIndices_basic",
|
||||||
"EmbeddingModuleI32Static_basic",
|
|
||||||
"FlipNegativeIndexModule_basic",
|
"FlipNegativeIndexModule_basic",
|
||||||
"HardsigmoidModule_basic",
|
"HardsigmoidModule_basic",
|
||||||
"HardsigmoidRandomModule_basic",
|
"HardsigmoidRandomModule_basic",
|
||||||
"IndexSelectDynamicIndexSizeModule_basic",
|
"IndexSelectDynamicIndexSizeModule_basic",
|
||||||
"IndexSelectDynamicInputSizeModule_basic",
|
"IndexSelectDynamicInputSizeModule_basic",
|
||||||
"IndexSelectDynamicModulebasic",
|
"IndexSelectDynamicModulebasic",
|
||||||
"IndexSelectNegativeDimModule_basic",
|
|
||||||
"IndexSelectSingleIdxModule_basic",
|
|
||||||
"IndexSelectTwoIdxModule_basic",
|
|
||||||
"IndexSelectWholeDimensionModule_basic",
|
"IndexSelectWholeDimensionModule_basic",
|
||||||
|
"IndexSelectWholeTensorModule_basic",
|
||||||
"IndexTensorStaticModule_basic",
|
"IndexTensorStaticModule_basic",
|
||||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||||
"PixelShuffleModuleStaticRank4Float32_basic",
|
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||||
|
|
|
@ -39,35 +39,20 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_gather
|
// CHECK-LABEL: func.func @test_gather
|
||||||
func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
|
func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
|
||||||
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
// CHECK: %[[AXIS:.+]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
|
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||||
// CHECK: %[[ARG1_SIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0]]
|
// CHECK: %[[ONE:.+]] = torch.constant.int 1
|
||||||
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1]], %[[ARG1_SIZE0]]
|
// CHECK: %[[LT:.+]] = torch.aten.le.Scalar %arg1, %[[ZERO]]
|
||||||
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
|
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
|
||||||
// CHECK: %[[ARG1_SIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_0]]
|
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
|
||||||
// CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[ARG1_SIZE1]]
|
// CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1
|
||||||
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
|
// CHECK: %[[SZ:.+]] = torch.aten.size %[[SEL]]
|
||||||
// CHECK: %[[ARG1_SIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int
|
// CHECK: %[[DIM:.+]] = torch.aten.dim %[[SEL]]
|
||||||
// CHECK: %[[MUL3:.+]] = torch.aten.mul.int %[[MUL2]], %[[ARG1_SIZE2]] : !torch.int, !torch.int -> !torch.int
|
// CHECK: %[[SUB:.+]] = torch.aten.sub.int %[[DIM]], %[[ONE]]
|
||||||
// CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
|
// CHECK: %[[FLAT:.+]] = torch.aten.flatten.using_ints %[[SEL]], %[[ZERO]], %[[SUB]]
|
||||||
// CHECK: %[[ARG1_SIZE3:.+]] = torch.aten.size.int %arg1, %[[INT3]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int
|
// CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]]
|
||||||
// CHECK: %[[MUL4:.+]] = torch.aten.mul.int %[[MUL3]], %[[ARG1_SIZE3]] : !torch.int, !torch.int -> !torch.int
|
// CHECK: %[[RES:.+]] = torch.aten.unflatten.int %[[ISEL]], %[[AXIS]], %[[SZ]]
|
||||||
// CHECK: %[[INT0_2:.+]] = torch.constant.int 0
|
// CHECK: return %[[RES]]
|
||||||
// CHECK: %[[ARG0_SIZE0:.+]] = torch.aten.size.int %arg0, %[[INT0_2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
|
|
||||||
// CHECK: %[[INT1_3:.+]] = torch.constant.int 1
|
|
||||||
// CHECK: %[[ARG0_SIZE1:.+]] = torch.aten.size.int %arg0, %[[INT1_3]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
|
|
||||||
// CHECK: %[[INT2_4:.+]] = torch.constant.int 2
|
|
||||||
// CHECK: %[[ARG0_SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2_4]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
|
|
||||||
// CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
|
||||||
// CHECK: %[[VIEW1:.+]] = torch.aten.view %arg1, %[[LIST1]] : !torch.vtensor<[8,10,20,40],si64>, !torch.list<int> -> !torch.vtensor<[64000,1,1],si64>
|
|
||||||
// CHECK: %[[INT0_1:.+]] = torch.constant.int 0
|
|
||||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
|
||||||
// CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0_1]], %[[VIEW1]], %[[FALSE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.vtensor<[64000,1,1],si64>, !torch.bool -> !torch.vtensor<[64000,1,1],f32>
|
|
||||||
// CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
|
||||||
// CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[GATHER]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[64000,1,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[64000,4,5],f32>
|
|
||||||
// CHECK: %[[LIST3:.+]] = torch.prim.ListConstruct %[[ARG1_SIZE0]], %[[ARG1_SIZE1]], %[[ARG1_SIZE2]], %[[ARG1_SIZE3]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
|
||||||
// CHECK: %[[RES:.+]] = torch.aten.view %[[EXPAND]], %[[LIST3]] : !torch.vtensor<[64000,4,5],f32>, !torch.list<int> -> !torch.vtensor<[8,10,20,40,4,5],f32>
|
|
||||||
// CHECK: return %[[RES]] : !torch.vtensor<[8,10,20,40,4,5],f32>
|
|
||||||
%0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32>
|
%0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32>
|
||||||
return %0 : !torch.vtensor<[8,10,20,40,4,5],f32>
|
return %0 : !torch.vtensor<[8,10,20,40,4,5],f32>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue