mirror of https://github.com/llvm/torch-mlir
[onnx] Fix onnx.gather lowering for rank-0 indices (#2973)
We assumed the indices rank was at least 1; however, it can be rank-0, which generated an illegal pair of flatten / unflatten operations. Corrected this.
parent 916554f270
commit d51e80b648
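For context on the fix: ONNX Gather computes the NumPy-style output = np.take(data, indices, axis), so the output shape is data.shape[:axis] + indices.shape + data.shape[axis+1:]. With rank-0 (scalar) indices the gathered axis simply disappears, leaving no indices dimension to flatten before the index_select or unflatten afterwards, which is why the previous rank-at-least-1 lowering produced illegal ops. A minimal NumPy sketch of the shape behaviour (illustration only, not code from this patch):

import numpy as np

# Shapes mirror the lit tests below: data is [3,4,5], axis = 0.
data = np.arange(3 * 4 * 5, dtype=np.float32).reshape(3, 4, 5)

# Rank-1 indices keep an indices dimension on the gathered axis.
vec = np.array([2, 0], dtype=np.int64)
assert np.take(data, vec, axis=0).shape == (2, 4, 5)

# Rank-0 (scalar) indices drop the gathered axis entirely, so the result is
# rank data.ndim - 1 and there is no indices dimension to flatten/unflatten.
scalar = np.array(1, dtype=np.int64)
assert np.take(data, scalar, axis=0).shape == (4, 5)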
@@ -572,10 +572,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           auto ctx = binder.op->getContext();
           auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
           auto dataTy = cast<Torch::ValueTensorType>(data.getType());
-          if (!dataTy || !dataTy.hasSizes())
+          if (!dataTy || !dataTy.hasSizes() || !indicesTy.hasSizes())
             return failure();
-          if (axis < 0)
-            axis += dataTy.getSizes().size();
+
+          int64_t dataRank = dataTy.getSizes().size();
+          int64_t indicesRank = indicesTy.getSizes().size();
+          axis = axis < 0 ? axis + dataRank : axis;
 
           Value index = rewriter.create<Torch::ConstantIntOp>(
               loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis));
@@ -599,8 +601,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 
           auto intListTy = rewriter.getType<Torch::ListType>(
               rewriter.getType<Torch::IntType>());
-          auto indicesSize =
-              rewriter.create<Torch::AtenSizeOp>(loc, intListTy, indices);
+
+          llvm::SmallVector<Value> indicesDims;
+          for (int i = 0, s = indicesTy.getSizes().size(); i < s; ++i) {
+            Value k = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i);
+            indicesDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
+                binder.getLoc(), indices, k));
+          }
+
+          Value indicesSizeList = rewriter.create<Torch::PrimListConstructOp>(
+              binder.getLoc(), intListTy, indicesDims);
 
           // Determine the collapsed dim size:
           auto indicesCt = 1;
@@ -615,20 +625,37 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 
           auto flattenTy = rewriter.getType<Torch::ValueTensorType>(
               SmallVector<int64_t>{indicesCt}, indicesTy.getOptionalDtype());
 
-          Value rank = rewriter.create<Torch::AtenDimOp>(loc, intTy, indices);
-          Value end = rewriter.create<Torch::AtenSubIntOp>(loc, rank, one);
-          indices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
-              loc, flattenTy, indices, zero, end);
+          if (indicesRank == 0) {
+            indices = rewriter.create<Torch::AtenUnsqueezeOp>(
+                binder.getLoc(), flattenTy, indices, zero);
+          } else if (indicesRank > 1) {
+            Value rank = rewriter.create<Torch::AtenDimOp>(loc, intTy, indices);
+            Value end = rewriter.create<Torch::AtenSubIntOp>(loc, rank, one);
+            indices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+                loc, flattenTy, indices, zero, end);
+          }
 
           llvm::SmallVector<int64_t> gatherShape(dataTy.getSizes());
           gatherShape[axis] = indicesCt;
 
           auto gatherTy = rewriter.getType<Torch::ValueTensorType>(
               gatherShape, dataTy.getOptionalDtype());
           Value gather = rewriter.create<Torch::AtenIndexSelectOp>(
               loc, gatherTy, data, index, indices);
-          rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
-              binder.op, resultType, gather, index, indicesSize);
+
+          if (indicesRank == 1) {
+            rewriter.replaceOp(binder.op, gather);
+            return success();
+          }
+
+          if (indicesRank > 1) {
+            gather = rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
+                binder.op, resultType, gather, index, indicesSizeList);
+            return success();
+          }
+
+          rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op, resultType,
+                                                            gather);
           return success();
         });
     patterns.onOp(
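The control flow above can be summarised as: collapse the indices to rank 1 (unsqueeze for rank 0, flatten.using_ints for rank > 1), do a single index_select, then restore the result shape (use the result directly for rank 1, unflatten by the dynamically queried indices sizes for rank > 1, squeeze the gathered axis away for rank 0). A rough NumPy analogue of that shape bookkeeping, with axis and shapes assumed purely for illustration:

import numpy as np

def gather_like_lowering(data, indices, axis=0):
    # Rough NumPy analogue of the lowering's shape handling (not torch-mlir code).
    flat = np.reshape(indices, (-1,))                   # unsqueeze / flatten to rank 1
    gathered = np.take(data, flat, axis=axis)           # index_select analogue
    if indices.ndim == 1:
        return gathered                                  # already the right shape
    if indices.ndim > 1:                                 # unflatten by indices sizes
        return gathered.reshape(
            data.shape[:axis] + indices.shape + data.shape[axis + 1:])
    return np.squeeze(gathered, axis=axis)               # rank 0: drop the gathered axis

data = np.zeros((3, 4, 5), dtype=np.float32)
assert gather_like_lowering(data, np.zeros((8, 10), dtype=np.int64)).shape == (8, 10, 4, 5)
assert gather_like_lowering(data, np.array([1, 2], dtype=np.int64)).shape == (2, 4, 5)
assert gather_like_lowering(data, np.array(1, dtype=np.int64)).shape == (4, 5)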
@@ -2190,14 +2190,12 @@ ONNX_XFAIL_SET = {
    "ElementwiseUnaryIntModule_basic",
    "ElementwiseUnsqueezeNegDimsModule_basic",
    "ElementwiseWhereScalarModule_basic",
    "EmbeddingModule1DIndices_basic",
    "EmbeddingModuleF16_basic",
    "EmbeddingModuleI32_basic",
    "EmbeddingModuleI64_basic",
    "FlattenDynamicModule_basic",
    "GluStaticModule_basic",
    "GroupNormModule_basic",
    "IndexSelectDynamicIndexSizeModule_basic",
    "IndexSelectDynamicModulebasic",
    "IndexTensorHackedTwinModule3dInput_basic",
    "IndexTensorHackedTwinModule_basic",
@@ -37,8 +37,8 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[
 // -----
 
-// CHECK-LABEL: func.func @test_gather
-func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
+// CHECK-LABEL: func.func @test_gather_nd
+func.func @test_gather_nd(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK: %[[AXIS:.+]] = torch.constant.int 0
   // CHECK: %[[ZERO:.+]] = torch.constant.int 0
   // CHECK: %[[ONE:.+]] = torch.constant.int 1
@@ -46,7 +46,15 @@ func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor
   // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
   // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
   // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1
-  // CHECK: %[[SZ:.+]] = torch.aten.size %[[SEL]]
+  // CHECK: %[[D0:.+]] = torch.constant.int 0
+  // CHECK: %[[SZ0:.+]] = torch.aten.size.int %[[SEL]], %[[D0]]
+  // CHECK: %[[D1:.+]] = torch.constant.int 1
+  // CHECK: %[[SZ1:.+]] = torch.aten.size.int %[[SEL]], %[[D1]]
+  // CHECK: %[[D2:.+]] = torch.constant.int 2
+  // CHECK: %[[SZ2:.+]] = torch.aten.size.int %[[SEL]], %[[D2]]
+  // CHECK: %[[D3:.+]] = torch.constant.int 3
+  // CHECK: %[[SZ3:.+]] = torch.aten.size.int %[[SEL]], %[[D3]]
+  // CHECK: %[[SZ:.+]] = torch.prim.ListConstruct %[[SZ0]], %[[SZ1]], %[[SZ2]], %[[SZ3]]
   // CHECK: %[[DIM:.+]] = torch.aten.dim %[[SEL]]
   // CHECK: %[[SUB:.+]] = torch.aten.sub.int %[[DIM]], %[[ONE]]
   // CHECK: %[[FLAT:.+]] = torch.aten.flatten.using_ints %[[SEL]], %[[ZERO]], %[[SUB]]
@@ -59,6 +67,25 @@ func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor
 // -----
 
+// CHECK-LABEL: func.func @test_gather_scalar
+func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK: %[[AXIS:.+]] = torch.constant.int 0
+  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK: %[[ONE:.+]] = torch.constant.int 1
+  // CHECK: %[[LT:.+]] = torch.aten.le.Scalar %arg1, %[[ZERO]]
+  // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
+  // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1
+  // CHECK: %[[FLAT:.+]] = torch.aten.unsqueeze %[[SEL]], %[[ZERO]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]]
+  // CHECK: %[[RES:.+]] = torch.aten.squeeze %[[ISEL]] : !torch.vtensor<[1,4,5],f32> -> !torch.vtensor<[4,5],f32>
+  // CHECK: return %[[RES]]
+  %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32>
+  return %0 : !torch.vtensor<[4,5],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_gather_elements
 func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0