mirror of https://github.com/llvm/torch-mlir
parent
f74164161a
commit
7714bebfe3
|
@ -53,7 +53,7 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
||||||
// dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8
|
// dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8
|
||||||
// dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2
|
// dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2
|
||||||
// dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4
|
// dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4
|
||||||
for (int i = 0; i < indexRank; i++) {
|
for (int i = 0; i < static_cast<int>(indexRank); i++) {
|
||||||
if (i == dim) {
|
if (i == dim) {
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
|
@ -61,7 +61,7 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dim != indexShape.size() - 1) {
|
if (dim != static_cast<int>(indexShape.size()) - 1) {
|
||||||
// Create one dim indices for index except for last dim
|
// Create one dim indices for index except for last dim
|
||||||
// Create indices raw vector.
|
// Create indices raw vector.
|
||||||
// torch.stack(torch.meshgrid)
|
// torch.stack(torch.meshgrid)
|
||||||
|
|
Loading…
Reference in New Issue