Generalize `aten.view` pattern in scalarize shapes (#3856)

Extends the existing pattern to allow finding matching dims from the
back as well as the front.
pull/3860/head
zjgarvey 2024-11-07 15:26:07 -06:00 committed by GitHub
parent 7058f456b8
commit 8519ecc4d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 109 additions and 75 deletions

View File

@ -1099,97 +1099,112 @@ public:
int64_t outRank = resultTy.getSizes().size();
SmallVector<int64_t> sizes(selfTy.getSizes());
int64_t endMatchingDim = -1;
// input sizes vs. provided view sizes comparison loop
for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
int64_t leftMatchEnd = 0;
// compare input sizes with provided dims from left
for (; leftMatchEnd < std::min(outRank, inRank); leftMatchEnd++) {
int64_t providedSize;
bool providedStatic =
matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
// if sizes[i] is static, it must match a constant in viewSizes[i]
if (sizes[i] != Torch::kUnknownSize) {
if (!providedStatic)
return rewriter.notifyMatchFailure(
op, "unsupported: found static input dim, but unable to match "
"provided view size on a constant. See position : " +
std::to_string(i));
if (providedSize != sizes[i]) {
endMatchingDim = i;
bool providedStatic = matchPattern(viewSizes[leftMatchEnd],
m_TorchConstantInt(&providedSize));
// static dim case
if (sizes[leftMatchEnd] != Torch::kUnknownSize) {
// if can't infer equality of dims, set end index and break
if (!providedStatic || providedSize != sizes[leftMatchEnd])
break;
}
continue;
}
// the remaining assumes sizes[i] is dynamic
// if provided dim is static, we can't verify it is a flatten/unflatten
// unless -1
if (i == outRank - 1 && providedStatic && providedSize == -1) {
endMatchingDim = i;
break;
}
// the remaining assumes sizes[leftMatchEnd] is dynamic
// if provided dim is static, we can't match.
if (providedStatic)
return rewriter.notifyMatchFailure(
op, "unexpected static view dim corresponding to dynamic input dim "
"at position : " +
std::to_string(i));
auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
// if we don't have a size int op on self, fail
break;
auto sizeIntOp = viewSizes[leftMatchEnd].getDefiningOp<AtenSizeIntOp>();
// if we don't have a size int op on self, break
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
return rewriter.notifyMatchFailure(
op, "expected dynamic view dim to come from a corresponding "
"size.int op. See position : " +
std::to_string(i));
break;
int64_t dim;
// if the dim of the size int op doesn't match, fail
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
dim != i)
dim != leftMatchEnd)
break;
}
int64_t rightMatchEnd = 0;
// compare input sizes with provided dims from right
for (; rightMatchEnd < std::min(outRank, inRank) - leftMatchEnd;
rightMatchEnd++) {
int64_t providedSize;
bool providedStatic = matchPattern(viewSizes[outRank - 1 - rightMatchEnd],
m_TorchConstantInt(&providedSize));
// static dim case
if (sizes[inRank - 1 - rightMatchEnd] != Torch::kUnknownSize) {
// if can't infer equality of dims, set end index and break
if (!providedStatic ||
providedSize != sizes[inRank - 1 - rightMatchEnd])
break;
continue;
}
// the remaining assumes sizes[inRank - 1 - rightMatchEnd] is dynamic
// if provided dim is static, we can't match.
if (providedStatic)
break;
auto sizeIntOp =
viewSizes[outRank - 1 - rightMatchEnd].getDefiningOp<AtenSizeIntOp>();
// if we don't have a size int op on self, break
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
break;
int64_t dim;
// if the dim of the size int op doesn't match, break
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
dim != inRank - 1 - rightMatchEnd)
break;
}
// the unmatched input dims start at leftMatchEnd, and end before inRank -
// rightMatchEnd
int64_t inputUnmatched = (inRank - rightMatchEnd) - leftMatchEnd;
int64_t outputUnmatched = (outRank - rightMatchEnd) - leftMatchEnd;
// if too many dims are unmatched in input/output, cannot canonicalize.
if (inputUnmatched > 1 && outputUnmatched > 1)
return rewriter.notifyMatchFailure(
op,
"size int op dim cannot be matched to current dim at position : " +
std::to_string(i));
// passing the previous checks means viewSizes[i] = aten.size.int(self,
// i), so continue
}
// if all dims match and the ranks are equal, fold
if (endMatchingDim == -1 && inRank == outRank) {
rewriter.replaceOp(op, op.getSelf());
"View op is not simple enough to canonicalize.\n# Unmatched Input "
"dims = " +
std::to_string(inputUnmatched) +
"\n# Unmatched Output Dims = " + std::to_string(outputUnmatched) +
"\nStarting unmatched index = " + std::to_string(leftMatchEnd));
// if all dims match, return self.
if (inputUnmatched == outputUnmatched &&
(inputUnmatched == 1 || inputUnmatched == 0)) {
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
op, op.getType(), op.getSelf());
return success();
}
if (endMatchingDim > -1 && inRank > outRank) {
// only support flattening last dim
if (endMatchingDim != outRank - 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: output has more than back dim mismatching");
// flatten
Value start =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
Value end =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
op, resultTy, op.getSelf(), start, end);
return success();
}
if (endMatchingDim > -1 && inRank < outRank) {
// only support unflattening last dim
if (endMatchingDim != inRank - 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: input has more than back dim mismatching");
// unflatten
Value dim =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
Value primList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), op.getSize().getType(),
ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
// if input has 1 unmatched dim, and output has multiple, unflatten
if (inputUnmatched == 1 && outputUnmatched > 1) {
Value dimVal =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
viewSizes.end() - rightMatchEnd);
Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), op.getSize().getType(), unflattenSizes);
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
op, resultTy, op.getSelf(), dim, primList);
op, op.getType(), op.getSelf(), dimVal, unflattenList);
return success();
}
// examples that might reach this:
// input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
// input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
// input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
// if multiple unmatched input dims map to one output dim, flatten
if (inputUnmatched > 1 && outputUnmatched == 1) {
Value startDim =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
// note: flatten end is inclusive for some reason.
int64_t endInt = inRank - rightMatchEnd - 1;
Value endDim = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endInt);
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
op, op.getType(), op.getSelf(), startDim, endDim);
return success();
}
// the remaining cases involve maximal matching dims, but mismatched ranks.
// This could only occur if squeezing or unsqueezing.
return rewriter.notifyMatchFailure(
op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
", inRank=" + std::to_string(inRank) +
", outRank=" + std::to_string(outRank));
op, "unhandled view op canonicalization to squeeze/unsqueeze.");
}
};
} // namespace

View File

@ -255,6 +255,25 @@ func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !tor
return %3 : !torch.vtensor<[?,?,?],f32>
}
// -----
// CHECK-LABEL: @view_as_flatten_mid
func.func @view_as_flatten_mid(%arg0: !torch.vtensor<[?,?,?,?,2,4],f32>) -> !torch.vtensor<[?,?,?,4],f32> {
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
// CHECK-DAG: %[[FOUR:.*]] = torch.constant.int 4
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[FOUR]] : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,4],f32>
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?,4],f32>
%int-1 = torch.constant.int -1
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int4 = torch.constant.int 4
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1, %int-1, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,4],f32>
return %3 : !torch.vtensor<[?,?,?,4],f32>
}
// -----