mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Fix onnx.gather_nd implementation (#3070)
The indices should be expanded before the torch.gather operation. Signed-off-by: Gaurav Shukla <gaurav@amd.com>pull/3098/head
parent
da88efad89
commit
129a79417a
|
@ -671,7 +671,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
if (!indicesTy || !indicesTy.hasSizes())
|
||||
return failure();
|
||||
|
||||
// step1. Get shapes and ranks of data and indices. The last dimension
|
||||
// step 1. Get shapes and ranks of data and indices. The last dimension
|
||||
// of indices is expected to be static.
|
||||
ArrayRef<int64_t> dataShape = dataTy.getSizes();
|
||||
int64_t dataRank = dataShape.size();
|
||||
|
@ -693,25 +693,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "expected last dimension of indices to be static");
|
||||
|
||||
// step2. Get dimension list of data and indices.
|
||||
// step 2. Get dimension list of data.
|
||||
SmallVector<int64_t> batchShape;
|
||||
SmallVector<Value> batchDims;
|
||||
SmallVector<Value> dataDims;
|
||||
for (int64_t i = 0; i < dataRank; ++i) {
|
||||
Value k = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i);
|
||||
Value dataDim = rewriter.create<Torch::AtenSizeIntOp>(loc, data, k);
|
||||
dataDims.push_back(dataDim);
|
||||
if (i < batchDimCount)
|
||||
if (i < batchDimCount) {
|
||||
batchShape.push_back(dataShape[i]);
|
||||
batchDims.push_back(dataDim);
|
||||
}
|
||||
}
|
||||
|
||||
// step 3. Get dimension list of indices.
|
||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
SmallVector<Value> indicesDimsMinusOne;
|
||||
SmallVector<Value> unflattenIndicesDims;
|
||||
Value indicesFlattenDim = constOne;
|
||||
for (int64_t i = 0; i < indicesRank - 1; ++i) {
|
||||
Value k = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i);
|
||||
Value indicesDim =
|
||||
rewriter.create<Torch::AtenSizeIntOp>(loc, indices, k);
|
||||
indicesDimsMinusOne.push_back(indicesDim);
|
||||
if (i >= batchDimCount)
|
||||
if (i >= batchDimCount) {
|
||||
unflattenIndicesDims.push_back(indicesDim);
|
||||
indicesFlattenDim = rewriter.create<Torch::AtenMulIntOp>(
|
||||
loc, indicesFlattenDim, indicesDim);
|
||||
}
|
||||
}
|
||||
ArrayRef<int64_t> indicesShapeMinusOne = indicesShape.drop_back();
|
||||
|
||||
|
@ -719,26 +732,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
// the ranks of data(`r`) and indices(`q`) to be same. So we will
|
||||
// perform collapse and reshape operations to match the ranks of data
|
||||
// and indices(making sure the semantics of the onnx.gather_nd are
|
||||
// preserved), perform torch.gather operation, later expand the gather
|
||||
// result to match onnx.gather_nd output. For example, assuming indices
|
||||
// is of shape (4, 5, 3, 2), data is (4, 10, 11, 7, 4) and
|
||||
// preserved), perform torch.gather operation, later unflatten the
|
||||
// gather result to match onnx.gather_nd output. For example, assuming
|
||||
// indices is of shape (4, 5, 3, 2), data is (4, 10, 11, 7, 4) and
|
||||
// batch_dims(`b`)=1. Firstly, modify indices to 1-D indexing as the
|
||||
// torch.gather op supports only single dimensional indexing. (this
|
||||
// algorithm would have been simpler if we can get a torch op that
|
||||
// supports indexing at multiple dimensions simultaneously). 1-D indexed
|
||||
// indices will be of shape (4, 5, 3, 1), now materialize it to
|
||||
// `r-b-indices_shape[-1]` dimension of data i.e. reshaping it to the
|
||||
// shape (4, 5, 3, 1, 1). Next step is to flatten the indices and data
|
||||
// to (4, 15, 1, 1) and (4, 110, 7, 4) shapes respectively and then
|
||||
// perform the torch.gather operation. Post the gather operation,
|
||||
// unflatten the indices dimensions of result to (4, 5, 3, 1, 1) and
|
||||
// then expand it to get the final output of shape (4, 5, 3, 7, 4).
|
||||
// shape (4, 5, 3, 1, 1). Next step is to flatten+expand the indices and
|
||||
// flatten the data to (4, 15, 7, 4) and (4, 110, 7, 4) shapes
|
||||
// respectively and then perform the torch.gather operation. Post the
|
||||
// gather operation, unflatten the indices dimensions of result to (4,
|
||||
// 5, 3, 7, 4) which is our required result.
|
||||
|
||||
// step3. Convert indices_shape[-1] dimensional indexing to 1D indexing.
|
||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
// step 4. Convert indices_shape[-1] dimensional indexing to 1D
|
||||
// indexing.
|
||||
Value sliceDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(indicesRank - 1));
|
||||
SmallVector<int64_t> indicesSliceShape(indicesShapeMinusOne);
|
||||
|
@ -774,7 +784,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
dataDims[batchDimCount + i]);
|
||||
}
|
||||
|
||||
// step4. Compute all the required result types here.
|
||||
// step 5. Compute all the required result types here.
|
||||
SmallVector<int64_t> reshapeIndicesShape(indicesShapeMinusOne);
|
||||
SmallVector<Value> reshapeIndicesDims(indicesDimsMinusOne);
|
||||
// Determine the collapsed dim size of indices(index_shape[-1] is not
|
||||
|
@ -801,19 +811,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
dataCt *= sz;
|
||||
}
|
||||
flattenDataShape.push_back(dataCt);
|
||||
// Compute result size of the final expand op.
|
||||
SmallVector<Value> expandResultDims(indicesDimsMinusOne);
|
||||
// Compute the shape of expand op.
|
||||
SmallVector<Value> expandIndicesDims(batchDims);
|
||||
expandIndicesDims.push_back(indicesFlattenDim);
|
||||
SmallVector<int64_t> expandIndicesShape(batchShape);
|
||||
expandIndicesShape.push_back(indicesCt);
|
||||
// Append `r-b-indices_shape[-1]` unit or data dims appropriately to all
|
||||
// result types.
|
||||
for (int64_t i = batchDimCount + indicesLastDim; i < dataRank; ++i) {
|
||||
reshapeIndicesShape.push_back(1);
|
||||
flattenIndicesShape.push_back(1);
|
||||
flattenDataShape.push_back(dataShape[i]);
|
||||
expandIndicesShape.push_back(dataShape[i]);
|
||||
reshapeIndicesDims.push_back(constOne);
|
||||
expandResultDims.push_back(dataDims[i]);
|
||||
expandIndicesDims.push_back(dataDims[i]);
|
||||
}
|
||||
|
||||
// step5. Reshape 1-D indexed indices to match the rank of flattened
|
||||
// step 6. Reshape 1-D indexed indices to match the rank of flattened
|
||||
// data by inserting unit dimensions.
|
||||
auto intListTy = rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>());
|
||||
|
@ -825,7 +839,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Value reshapedIndices = rewriter.create<Torch::AtenViewOp>(
|
||||
loc, reshapeIndicesTy, updatedIndices, reshapeIndicesSizeList);
|
||||
|
||||
// step6. Flatten `q-b-1` dimensions of the indices.
|
||||
// step 7. Flatten `q-b-1` dimensions of the indices.
|
||||
auto flattenIndicesTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
flattenIndicesShape, indicesTy.getOptionalDtype());
|
||||
Value batchDimCountVal = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
@ -836,12 +850,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
loc, flattenIndicesTy, reshapedIndices, constZero);
|
||||
} else if (indicesRank > 1) {
|
||||
Value endDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(indicesRank - 1));
|
||||
loc, rewriter.getI64IntegerAttr(indicesRank - 2));
|
||||
flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||
loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, endDim);
|
||||
}
|
||||
|
||||
// step7. Flatten indices_shape[-1] dimensions of data.
|
||||
// step 8. Expand `r-b-indices_shape[-1]` dims of flattened indices.
|
||||
auto expandIndicesTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
expandIndicesShape, indicesTy.getOptionalDtype());
|
||||
Value expandIndicesSizeList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(loc, intListTy,
|
||||
expandIndicesDims);
|
||||
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||
loc, rewriter.getType<Torch::BoolType>(),
|
||||
rewriter.getBoolAttr(false));
|
||||
Value expandedIndices = rewriter.create<Torch::AtenExpandOp>(
|
||||
loc, expandIndicesTy, flattenedIndices, expandIndicesSizeList,
|
||||
/*implicit=*/constFalse);
|
||||
|
||||
// step 9. Flatten indices_shape[-1] dimensions of data.
|
||||
auto flattenDataTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
flattenDataShape, dataTy.getOptionalDtype());
|
||||
Value endDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
@ -850,39 +877,24 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Value flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||
loc, flattenDataTy, data, batchDimCountVal, endDim);
|
||||
|
||||
// step8. Now we have flattenedData and flattenedIndices of same rank to
|
||||
// perform gather operation.
|
||||
// step 10. Now we have flattenedData and expandedIndices of same rank
|
||||
// to perform gather operation.
|
||||
auto gatherTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
flattenIndicesShape, dataTy.getOptionalDtype());
|
||||
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
|
||||
rewriter.getBoolAttr(false));
|
||||
expandIndicesShape, dataTy.getOptionalDtype());
|
||||
Value gather = rewriter.create<Torch::AtenGatherOp>(
|
||||
loc, gatherTy, flattenedData, batchDimCountVal, flattenedIndices,
|
||||
loc, gatherTy, flattenedData, batchDimCountVal, expandedIndices,
|
||||
/*sparseGrad=*/constFalse);
|
||||
|
||||
// step9. Unflatten the collapsed indices dims of gather result.
|
||||
auto unflattenTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
reshapeIndicesShape, dataTy.getOptionalDtype());
|
||||
Value unflattenedGather = gather;
|
||||
// step 11. Unflatten the collapsed indices dims of gather result.
|
||||
if (indicesRank == 1) {
|
||||
unflattenedGather = rewriter.create<Torch::AtenSqueezeDimOp>(
|
||||
loc, unflattenTy, gather, /*dim=*/constZero);
|
||||
} else if (indicesRank > 1) {
|
||||
Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, intListTy, unflattenIndicesDims);
|
||||
unflattenedGather = rewriter.create<Torch::AtenUnflattenIntOp>(
|
||||
loc, unflattenTy, gather, batchDimCountVal, unflattenSizeList);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
|
||||
binder.op, resultType, gather, /*dim=*/constZero);
|
||||
return success();
|
||||
}
|
||||
|
||||
// step10. Expand `r-b-indices_shape[-1]` dims of unflattenedGather
|
||||
// result.
|
||||
Value expandResultSizeList =
|
||||
rewriter.create<Torch::PrimListConstructOp>(loc, intListTy,
|
||||
expandResultDims);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenExpandOp>(
|
||||
binder.op, resultType, unflattenedGather, expandResultSizeList,
|
||||
/*implicit=*/constFalse);
|
||||
Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, intListTy, unflattenIndicesDims);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
|
||||
binder.op, resultType, gather, batchDimCountVal, unflattenSizeList);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
|
|
|
@ -96,14 +96,16 @@ func.func @test_gather_nd(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1: !torch.vt
|
|||
// CHECK: %[[SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[SIZE3:.+]] = torch.aten.size.int %arg0, %[[INT3]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT0_3:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1_4:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INDSIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0_0]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INDSIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_1]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1_4]], %[[INDSIZE1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT2_2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[INDSIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2_2]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT0_3:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1_4:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[INDSIZE2]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT3_5:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[INT1_6:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SLICE0:.+]] = torch.aten.slice.Tensor %arg1, %[[INT3_5]], %[[INT0_3]], %[[INT1_6]], %[[INT1_4]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,4,3,1],si64>
|
||||
|
@ -119,17 +121,17 @@ func.func @test_gather_nd(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1: !torch.vt
|
|||
// CHECK: %[[LIST0:.+]] = torch.prim.ListConstruct %[[INDSIZE0]], %[[INDSIZE1]], %[[INDSIZE2]], %[[INT1_4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VIEW:.+]] = torch.aten.view %[[NEWIND]], %[[LIST0]] : !torch.vtensor<[2,4,3,1],si64>, !torch.list<int> -> !torch.vtensor<[2,4,3,1],si64>
|
||||
// CHECK: %[[INT1_8:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT3_9:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[FLATIND:.+]] = torch.aten.flatten.using_ints %[[VIEW]], %[[INT1_8]], %[[INT3_9]] : !torch.vtensor<[2,4,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,12,1],si64>
|
||||
// CHECK: %[[INT2_9:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[FLATIND:.+]] = torch.aten.flatten.using_ints %[[VIEW]], %[[INT1_8]], %[[INT2_9]] : !torch.vtensor<[2,4,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,12,1],si64>
|
||||
// CHECK: %[[EXLIST:.+]] = torch.prim.ListConstruct %[[SIZE0]], %[[MUL2]], %[[SIZE3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[EXPANDIND:.+]] = torch.aten.expand %[[FLATIND]], %[[EXLIST]], %[[FALSE]] : !torch.vtensor<[2,12,1],si64>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,12,5],si64>
|
||||
// CHECK: %[[INT2_10:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[FLATDATA:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT1_8]], %[[INT2_10]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,48,5],f32>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT1_8]], %[[FLATIND]], %[[FALSE]]
|
||||
// CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT1_8]], %[[EXPANDIND]], %[[FALSE]]
|
||||
// CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[INDSIZE1]], %[[INDSIZE2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[UNFLAT:.+]] = torch.aten.unflatten.int %[[GATHER]], %[[INT1_8]], %[[LIST1]] : !torch.vtensor<[2,12,1],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,4,3,1],f32>
|
||||
// CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[INDSIZE0]], %[[INDSIZE1]], %[[INDSIZE2]], %[[SIZE3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RES:.+]] = torch.aten.expand %[[UNFLAT]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[2,4,3,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,3,5],f32>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[2,4,3,5],f32>
|
||||
// CHECK: %[[UNFLAT:.+]] = torch.aten.unflatten.int %[[GATHER]], %[[INT1_8]], %[[LIST1]] : !torch.vtensor<[2,12,5],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,4,3,5],f32>
|
||||
// CHECK: return %[[UNFLAT]] : !torch.vtensor<[2,4,3,5],f32>
|
||||
%0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 1 : si64} : (!torch.vtensor<[2,6,8,5],f32>, !torch.vtensor<[2,4,3,2], si64>) -> !torch.vtensor<[2,4,3,5],f32>
|
||||
return %0 : !torch.vtensor<[2,4,3,5],f32>
|
||||
}
|
||||
|
@ -164,14 +166,14 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1
|
|||
// CHECK: %[[VIEW:.+]] = torch.aten.view %[[NEWIND]], %[[LIST0]] : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[1,1],si64>
|
||||
// CHECK: %[[INT0_5:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[FLATIND:.+]] = torch.aten.unsqueeze %[[VIEW]], %[[INT0_0]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64>
|
||||
// CHECK: %[[EXLIST:.+]] = torch.prim.ListConstruct %[[INT1_1]], %[[SIZE2]], %[[SIZE3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[EXPANDIND:.+]] = torch.aten.expand %[[FLATIND]], %[[EXLIST]], %[[FALSE]] : !torch.vtensor<[1,1,1],si64>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,8,5],si64>
|
||||
// CHECK: %[[INT1_6:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[FLATDATA:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT0_5]], %[[INT1_6]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[12,8,5],f32>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT0_5]], %[[FLATIND]], %[[FALSE]]
|
||||
// CHECK: %[[SQUEEZE:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT0_0]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
|
||||
// CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[SIZE2]], %[[SIZE3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RES:.+]] = torch.aten.expand %[[SQUEEZE]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[1,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[8,5],f32>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[8,5],f32>
|
||||
// CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT0_5]], %[[EXPANDIND]], %[[FALSE]]
|
||||
// CHECK: %[[SQUEEZE:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT0_0]] : !torch.vtensor<[1,8,5],f32>, !torch.int -> !torch.vtensor<[8,5],f32>
|
||||
// CHECK: return %[[SQUEEZE]] : !torch.vtensor<[8,5],f32>
|
||||
%0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 0 : si64} : (!torch.vtensor<[2,6,8,5],f32>, !torch.vtensor<[2], si64>) -> !torch.vtensor<[8,5],f32>
|
||||
return %0 : !torch.vtensor<[8,5],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue