mirror of https://github.com/llvm/torch-mlir
[torch] Rework `aten.repeat` to use flatten and unsqueeze (#2984)
The current implementation depends on `aten.view`, which has issues inferring tensor collapse/expand operations during the lowering to `linalg`. Using flatten and unsqueeze makes the later reshape behavior easier to infer.
pull/2992/head
parent
aa7c9a9653
commit
06292d9429
|
@ -2398,31 +2398,9 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Decompose aten.repeat into aten.expand and aten.view ops.
|
// Decompose aten.repeat into unsqueeze, broadcast_to, and flatten ops.
|
||||||
//
|
//
|
||||||
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
|
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
|
||||||
//
|
|
||||||
// For shape [S1, S2, S3] and repeats [M0, M1, M2, M3]
|
|
||||||
// MS0 = M0; MS1 = M1 * S1; MS2 = M2 * S2; MS3 = M3 * S3
|
|
||||||
//
|
|
||||||
// def aten_repeat(self, repeats):
|
|
||||||
// sizes = self.size()
|
|
||||||
// unsqueezed_sizes = []
|
|
||||||
// expanded_sizes = []
|
|
||||||
// reshape_sizes = []
|
|
||||||
// leading_rank = repeats.size() - sizes.size()
|
|
||||||
// for r in range(leading_rank):
|
|
||||||
// unsqueezed_sizes.append(1)
|
|
||||||
// expanded_sizes.append(repeats[r])
|
|
||||||
// reshaped_sizes.append(repeats[r])
|
|
||||||
//
|
|
||||||
// for s, m in zip(sizes, repeats[leading_rank:]):
|
|
||||||
// unsqueezed_sizes += [1, s]
|
|
||||||
// expanded_sizes += [m, s]
|
|
||||||
// reshaped_sizes += [m * s]
|
|
||||||
// return
|
|
||||||
// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
|
|
||||||
//
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
|
class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -2431,94 +2409,110 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
MLIRContext *context = op.getContext();
|
auto selfTy = cast<BaseTensorType>(self.getType());
|
||||||
std::optional<unsigned> maybeRank = getTensorRank(self);
|
if (!selfTy.hasSizes())
|
||||||
if (!maybeRank)
|
return rewriter.notifyMatchFailure(
|
||||||
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
|
op, "Unimplemented: no implementation for rankless tensor");
|
||||||
unsigned rank = *maybeRank;
|
|
||||||
|
|
||||||
SmallVector<Value> repeats;
|
SmallVector<Value> repeats;
|
||||||
if (!getListConstructElements(op.getRepeats(), repeats))
|
if (!getListConstructElements(op.getRepeats(), repeats))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Unimplemented: repeats not list of Scalar");
|
op, "Unimplemented: repeats not list of Scalar");
|
||||||
|
|
||||||
if (rank > repeats.size()) {
|
int64_t rank = selfTy.getSizes().size();
|
||||||
|
if (rank > static_cast<int64_t>(repeats.size())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "repeats are not matched with self's rank");
|
op, "repeats are not matched with self's rank");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto insertDimSizes = [](SmallVector<Value> &dimSizes,
|
int64_t repeatSz = repeats.size();
|
||||||
SmallVector<int64_t> &shape,
|
int64_t batch = repeatSz - rank;
|
||||||
const ArrayRef<Value> &vals) {
|
|
||||||
dimSizes.insert(dimSizes.end(), vals.begin(), vals.end());
|
if (!selfTy.hasSizes())
|
||||||
std::transform(vals.begin(), vals.end(), std::back_inserter(shape),
|
return rewriter.notifyMatchFailure(op, "input sizes unknown");
|
||||||
[&](Value val) -> int64_t {
|
|
||||||
int64_t cst_val;
|
// Materialize out 1 dimensions to broadcast along. This includes
|
||||||
if (matchPattern(val, m_TorchConstantInt(&cst_val))) {
|
// materializing out preceding batch dimensions:
|
||||||
return cst_val;
|
for (int i = 0; i < repeatSz; ++i) {
|
||||||
} else {
|
auto oldSizes = selfTy.getSizes();
|
||||||
return kUnknownSize;
|
llvm::SmallVector<int64_t> sizes;
|
||||||
|
int64_t squeezeDim = i < batch ? i : i * 2 - batch;
|
||||||
|
|
||||||
|
for (int j = 0; j < squeezeDim; ++j)
|
||||||
|
sizes.push_back(oldSizes[j]);
|
||||||
|
sizes.push_back(1);
|
||||||
|
for (int j = squeezeDim, s = oldSizes.size(); j < s; j++)
|
||||||
|
sizes.push_back(oldSizes[j]);
|
||||||
|
|
||||||
|
Value dim = rewriter.create<Torch::ConstantIntOp>(loc, squeezeDim);
|
||||||
|
selfTy =
|
||||||
|
rewriter.getType<ValueTensorType>(sizes, selfTy.getOptionalDtype());
|
||||||
|
self = rewriter.create<AtenUnsqueezeOp>(loc, selfTy, self, dim);
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
llvm::SmallVector<Value> lengths;
|
||||||
|
for (int i = 0; i < repeatSz; ++i) {
|
||||||
|
if (i < batch) {
|
||||||
|
lengths.push_back(repeats[i]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value iv = rewriter.create<ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch));
|
||||||
|
Value dim = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
|
||||||
|
lengths.push_back(repeats[i]);
|
||||||
|
lengths.push_back(dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value lengthv = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, ListType::get(rewriter.getType<IntType>()), lengths);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> expandShape(selfTy.getSizes());
|
||||||
|
for (int i = 0; i < repeatSz; ++i) {
|
||||||
|
int64_t repeatDim = i < batch ? i : i * 2 - batch;
|
||||||
|
int64_t repeat;
|
||||||
|
if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
|
||||||
|
repeat = Torch::kUnknownSize;
|
||||||
|
expandShape[repeatDim] = repeat;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto mulDim = [](int64_t lhs, int64_t rhs) {
|
||||||
|
if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
|
||||||
|
return Torch::kUnknownSize;
|
||||||
|
return lhs * rhs;
|
||||||
};
|
};
|
||||||
|
|
||||||
Value one = rewriter.create<Torch::ConstantIntOp>(
|
BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
|
||||||
loc, rewriter.getI64IntegerAttr(1));
|
expandShape, selfTy.getOptionalDtype());
|
||||||
|
Value expand =
|
||||||
|
rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, lengthv);
|
||||||
|
|
||||||
SmallVector<Value> unsqueezedSizes, expandedSizes, reshapedSizes;
|
for (int i = 0; i < rank; ++i) {
|
||||||
SmallVector<int64_t> unsqueezedIntSizes, expandedIntSizes;
|
auto oldShape = expandTy.getSizes();
|
||||||
assert(repeats.size() >= rank && "leadingRank should greater than 0");
|
llvm::SmallVector<int64_t> newShape;
|
||||||
auto leadingRank = repeats.size() - rank;
|
int64_t flattenDim = i + batch;
|
||||||
for (size_t i = 0; i < leadingRank; ++i) {
|
for (int j = 0; j < flattenDim; ++j)
|
||||||
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one});
|
newShape.push_back(oldShape[j]);
|
||||||
insertDimSizes(expandedSizes, expandedIntSizes,
|
newShape.push_back(
|
||||||
ArrayRef<Value>{repeats[i]});
|
mulDim(oldShape[flattenDim], oldShape[flattenDim + 1]));
|
||||||
reshapedSizes.push_back(repeats[i]);
|
for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j)
|
||||||
|
newShape.push_back(oldShape[j]);
|
||||||
|
|
||||||
|
expandTy = rewriter.getType<ValueTensorType>(newShape,
|
||||||
|
expandTy.getOptionalDtype());
|
||||||
|
|
||||||
|
// Used to keep the return type the same on the last flatten:
|
||||||
|
expandTy = i < rank - 1 ? expandTy : cast<BaseTensorType>(op.getType());
|
||||||
|
|
||||||
|
Value start = rewriter.create<ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(flattenDim));
|
||||||
|
Value end = rewriter.create<ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(flattenDim + 1));
|
||||||
|
expand = rewriter.create<AtenFlattenUsingIntsOp>(loc, expandTy, expand,
|
||||||
|
start, end);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto selfType = self.getType().dyn_cast<BaseTensorType>();
|
rewriter.replaceOp(op, expand);
|
||||||
auto selfShape = selfType.getSizes();
|
|
||||||
for (unsigned i = 0; i < rank; i++) {
|
|
||||||
auto scale = repeats[i + leadingRank];
|
|
||||||
Value dimSize;
|
|
||||||
if (selfShape[i] == kUnknownSize) {
|
|
||||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(i));
|
|
||||||
dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
|
|
||||||
} else {
|
|
||||||
dimSize = rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(selfShape[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes,
|
|
||||||
ArrayRef<Value>{one, dimSize});
|
|
||||||
insertDimSizes(expandedSizes, expandedIntSizes,
|
|
||||||
ArrayRef<Value>{scale, dimSize});
|
|
||||||
|
|
||||||
Value scaledSize = rewriter.create<AtenMulIntOp>(loc, dimSize, scale);
|
|
||||||
reshapedSizes.push_back(scaledSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
|
|
||||||
Type unsqueezedType = ValueTensorType::get(
|
|
||||||
context, llvm::ArrayRef(unsqueezedIntSizes), dtype);
|
|
||||||
Type expandedType =
|
|
||||||
ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype);
|
|
||||||
|
|
||||||
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
|
||||||
Value unsqueezedDims =
|
|
||||||
rewriter.create<PrimListConstructOp>(loc, listType, unsqueezedSizes);
|
|
||||||
Value expandedDims =
|
|
||||||
rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
|
|
||||||
Value reshapedDims =
|
|
||||||
rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
|
|
||||||
auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType,
|
|
||||||
op.getSelf(), unsqueezedDims);
|
|
||||||
auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
|
|
||||||
reshaped, expandedDims);
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), expanded,
|
|
||||||
reshapedDims);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -2142,7 +2142,6 @@ ONNX_XFAIL_SET = {
|
||||||
"IndexTensorMultiInputThreeIndexers_basic",
|
"IndexTensorMultiInputThreeIndexers_basic",
|
||||||
"IndexTensorMultiInput_basic",
|
"IndexTensorMultiInput_basic",
|
||||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||||
"RepeatModule_basic",
|
|
||||||
"SelectIntModule_basic",
|
"SelectIntModule_basic",
|
||||||
"SliceSingleIdxModule_basic",
|
"SliceSingleIdxModule_basic",
|
||||||
"ViewFlattenAndExpandModule_basic",
|
"ViewFlattenAndExpandModule_basic",
|
||||||
|
|
Loading…
Reference in New Issue