[torch] Rework `aten.repeat` to use flatten and unsqueeze (#2984)

Current implementation depends on `aten.view`, which has issues
inferring tensor collapse/expand operations during the lowering to
`linalg`. Using flatten and unsqueeze makes the later reshape
behavior easier to infer.
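A rough Python model of the new decomposition, for reference only (the helper name `repeat_via_unsqueeze_broadcast_flatten` and the eager `torch` calls are illustrative and not part of the patch):

```python
import torch

def repeat_via_unsqueeze_broadcast_flatten(self: torch.Tensor, repeats):
    rank = self.dim()
    batch = len(repeats) - rank  # leading repeat dims with no matching input dim
    assert batch >= 0, "repeats must cover every input dim"

    # 1. Unsqueeze: insert a unit dim for each leading repeat and before each input dim.
    for i in range(len(repeats)):
        dim = i if i < batch else i * 2 - batch
        self = self.unsqueeze(dim)

    # 2. Broadcast: grow each unit dim up to its repeat count.
    lengths = []
    for i, m in enumerate(repeats):
        if i < batch:
            lengths.append(m)
        else:
            lengths.extend([m, self.shape[i * 2 + 1 - batch]])
    self = self.broadcast_to(lengths)

    # 3. Flatten: collapse every (repeat, size) pair back into a single dim.
    for i in range(rank):
        self = self.flatten(i + batch, i + batch + 1)
    return self

x = torch.arange(6).reshape(2, 3)
assert torch.equal(repeat_via_unsqueeze_broadcast_flatten(x, [4, 2, 3]), x.repeat(4, 2, 3))
```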
Rob Suderman 2024-03-06 10:19:18 -08:00 committed by GitHub
parent aa7c9a9653
commit 06292d9429
2 changed files with 85 additions and 92 deletions

@@ -2398,31 +2398,9 @@ public:
};
} // namespace
// Decompose aten.repeat into aten.expand and aten.view ops.
// Decompose aten.repeat into aten.squeeze, aten.unsqueeze, and aten.broadcast.
//
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
//
// For shape [S1, S2, S3] and repeats [M0, M1, M2, M3]
// MS0 = M0; MS1 = M1 * S1; MS2 = M2 * S2; MS3 = M3 * S3
//
// def aten_repeat(self, repeats):
// sizes = self.size()
// unsqueezed_sizes = []
// expanded_sizes = []
// reshape_sizes = []
// leading_rank = repeats.size() - sizes.size()
// for r in range(leading_rank):
// unsqueezed_sizes.append(1)
// expanded_sizes.append(repeats[r])
// reshaped_sizes.append(repeats[r])
//
// for s, m in zip(sizes, repeats[leading_rank:]):
// unsqueezed_sizes += [1, s]
// expanded_sizes += [m, s]
// reshaped_sizes += [m * s]
// return
// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
//
namespace {
class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
public:
@@ -2431,94 +2409,110 @@ public:
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
MLIRContext *context = op.getContext();
std::optional<unsigned> maybeRank = getTensorRank(self);
if (!maybeRank)
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
unsigned rank = *maybeRank;
auto selfTy = cast<BaseTensorType>(self.getType());
if (!selfTy.hasSizes())
return rewriter.notifyMatchFailure(
op, "Unimplemented: no implementation for rankless tensor");
SmallVector<Value> repeats;
if (!getListConstructElements(op.getRepeats(), repeats))
return rewriter.notifyMatchFailure(
op, "Unimplemented: repeats not list of Scalar");
if (rank > repeats.size()) {
int64_t rank = selfTy.getSizes().size();
if (rank > static_cast<int64_t>(repeats.size())) {
return rewriter.notifyMatchFailure(
op, "repeats are not matched with self's rank");
}
auto insertDimSizes = [](SmallVector<Value> &dimSizes,
SmallVector<int64_t> &shape,
const ArrayRef<Value> &vals) {
dimSizes.insert(dimSizes.end(), vals.begin(), vals.end());
std::transform(vals.begin(), vals.end(), std::back_inserter(shape),
[&](Value val) -> int64_t {
int64_t cst_val;
if (matchPattern(val, m_TorchConstantInt(&cst_val))) {
return cst_val;
} else {
return kUnknownSize;
}
});
};
int64_t repeatSz = repeats.size();
int64_t batch = repeatSz - rank;
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
if (!selfTy.hasSizes())
return rewriter.notifyMatchFailure(op, "input sizes unknown");
SmallVector<Value> unsqueezedSizes, expandedSizes, reshapedSizes;
SmallVector<int64_t> unsqueezedIntSizes, expandedIntSizes;
assert(repeats.size() >= rank && "leadingRank should greater than 0");
auto leadingRank = repeats.size() - rank;
for (size_t i = 0; i < leadingRank; ++i) {
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one});
insertDimSizes(expandedSizes, expandedIntSizes,
ArrayRef<Value>{repeats[i]});
reshapedSizes.push_back(repeats[i]);
// Materialize out 1 dimensions to broadcast along. This includes
// materializing out preceding batch dimensions:
for (int i = 0; i < repeatSz; ++i) {
auto oldSizes = selfTy.getSizes();
llvm::SmallVector<int64_t> sizes;
int64_t squeezeDim = i < batch ? i : i * 2 - batch;
for (int j = 0; j < squeezeDim; ++j)
sizes.push_back(oldSizes[j]);
sizes.push_back(1);
for (int j = squeezeDim, s = oldSizes.size(); j < s; j++)
sizes.push_back(oldSizes[j]);
Value dim = rewriter.create<Torch::ConstantIntOp>(loc, squeezeDim);
selfTy =
rewriter.getType<ValueTensorType>(sizes, selfTy.getOptionalDtype());
self = rewriter.create<AtenUnsqueezeOp>(loc, selfTy, self, dim);
}
auto selfType = self.getType().dyn_cast<BaseTensorType>();
auto selfShape = selfType.getSizes();
for (unsigned i = 0; i < rank; i++) {
auto scale = repeats[i + leadingRank];
Value dimSize;
if (selfShape[i] == kUnknownSize) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
} else {
dimSize = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(selfShape[i]));
llvm::SmallVector<Value> lengths;
for (int i = 0; i < repeatSz; ++i) {
if (i < batch) {
lengths.push_back(repeats[i]);
continue;
}
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes,
ArrayRef<Value>{one, dimSize});
insertDimSizes(expandedSizes, expandedIntSizes,
ArrayRef<Value>{scale, dimSize});
Value scaledSize = rewriter.create<AtenMulIntOp>(loc, dimSize, scale);
reshapedSizes.push_back(scaledSize);
Value iv = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch));
Value dim = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
lengths.push_back(repeats[i]);
lengths.push_back(dim);
}
Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
Type unsqueezedType = ValueTensorType::get(
context, llvm::ArrayRef(unsqueezedIntSizes), dtype);
Type expandedType =
ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype);
Value lengthv = rewriter.create<PrimListConstructOp>(
loc, ListType::get(rewriter.getType<IntType>()), lengths);
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value unsqueezedDims =
rewriter.create<PrimListConstructOp>(loc, listType, unsqueezedSizes);
Value expandedDims =
rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
Value reshapedDims =
rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType,
op.getSelf(), unsqueezedDims);
auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
reshaped, expandedDims);
llvm::SmallVector<int64_t> expandShape(selfTy.getSizes());
for (int i = 0; i < repeatSz; ++i) {
int64_t repeatDim = i < batch ? i : i * 2 - batch;
int64_t repeat;
if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
repeat = Torch::kUnknownSize;
expandShape[repeatDim] = repeat;
}
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), expanded,
reshapedDims);
auto mulDim = [](int64_t lhs, int64_t rhs) {
if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
return Torch::kUnknownSize;
return lhs * rhs;
};
BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
expandShape, selfTy.getOptionalDtype());
Value expand =
rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, lengthv);
for (int i = 0; i < rank; ++i) {
auto oldShape = expandTy.getSizes();
llvm::SmallVector<int64_t> newShape;
int64_t flattenDim = i + batch;
for (int j = 0; j < flattenDim; ++j)
newShape.push_back(oldShape[j]);
newShape.push_back(
mulDim(oldShape[flattenDim], oldShape[flattenDim + 1]));
for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j)
newShape.push_back(oldShape[j]);
expandTy = rewriter.getType<ValueTensorType>(newShape,
expandTy.getOptionalDtype());
// Used to keep the return type the same on the last flatten:
expandTy = i < rank - 1 ? expandTy : cast<BaseTensorType>(op.getType());
Value start = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(flattenDim));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(flattenDim + 1));
expand = rewriter.create<AtenFlattenUsingIntsOp>(loc, expandTy, expand,
start, end);
}
rewriter.replaceOp(op, expand);
return success();
}
};
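For reference, the shape arithmetic above can be checked against eager PyTorch (an illustrative sanity check, not part of the patch): for a (2, 3) input with repeats (4, 2, 3), the unsqueezes yield (1, 1, 2, 1, 3), the broadcast yields (4, 2, 2, 3, 3), and the two flattens yield (4, 4, 9), matching `Tensor.repeat`.

```python
import torch

x = torch.arange(6).reshape(2, 3)             # input: (2, 3), repeats: (4, 2, 3), batch = 1
y = x.unsqueeze(0).unsqueeze(1).unsqueeze(3)  # (1, 1, 2, 1, 3): a unit dim per repeat
y = y.broadcast_to((4, 2, 2, 3, 3))           # repeats interleaved with the original sizes
y = y.flatten(1, 2).flatten(2, 3)             # (4, 4, 9): each (repeat, size) pair collapsed
assert torch.equal(y, x.repeat(4, 2, 3))
```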

@@ -2142,7 +2142,6 @@ ONNX_XFAIL_SET = {
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"RepeatModule_basic",
"SelectIntModule_basic",
"SliceSingleIdxModule_basic",
"ViewFlattenAndExpandModule_basic",