mirror of https://github.com/llvm/torch-mlir
Generalize `aten.view` pattern in scalarize shapes (#3856)
Extends the existing pattern to allow finding matching dims from the back as well as the front.pull/3860/head
parent
7058f456b8
commit
8519ecc4d7
|
@ -1099,97 +1099,112 @@ public:
|
|||
int64_t outRank = resultTy.getSizes().size();
|
||||
|
||||
SmallVector<int64_t> sizes(selfTy.getSizes());
|
||||
int64_t endMatchingDim = -1;
|
||||
// input sizes vs. provided view sizes comparison loop
|
||||
for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
|
||||
int64_t leftMatchEnd = 0;
|
||||
// compare input sizes with provided dims from left
|
||||
for (; leftMatchEnd < std::min(outRank, inRank); leftMatchEnd++) {
|
||||
int64_t providedSize;
|
||||
bool providedStatic =
|
||||
matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
|
||||
// if sizes[i] is static, it must match a constant in viewSizes[i]
|
||||
if (sizes[i] != Torch::kUnknownSize) {
|
||||
if (!providedStatic)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unsupported: found static input dim, but unable to match "
|
||||
"provided view size on a constant. See position : " +
|
||||
std::to_string(i));
|
||||
if (providedSize != sizes[i]) {
|
||||
endMatchingDim = i;
|
||||
bool providedStatic = matchPattern(viewSizes[leftMatchEnd],
|
||||
m_TorchConstantInt(&providedSize));
|
||||
// static dim case
|
||||
if (sizes[leftMatchEnd] != Torch::kUnknownSize) {
|
||||
// if can't infer equality of dims, set end index and break
|
||||
if (!providedStatic || providedSize != sizes[leftMatchEnd])
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// the remaining assumes sizes[i] is dynamic
|
||||
// if provided dim is static, we can't verify it is a flatten/unflatten
|
||||
// unless -1
|
||||
if (i == outRank - 1 && providedStatic && providedSize == -1) {
|
||||
endMatchingDim = i;
|
||||
break;
|
||||
}
|
||||
// the remaining assumes sizes[leftMatchEnd] is dynamic
|
||||
// if provided dim is static, we can't match.
|
||||
if (providedStatic)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unexpected static view dim corresponding to dynamic input dim "
|
||||
"at position : " +
|
||||
std::to_string(i));
|
||||
auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
|
||||
// if we don't have a size int op on self, fail
|
||||
break;
|
||||
auto sizeIntOp = viewSizes[leftMatchEnd].getDefiningOp<AtenSizeIntOp>();
|
||||
// if we don't have a size int op on self, break
|
||||
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected dynamic view dim to come from a corresponding "
|
||||
"size.int op. See position : " +
|
||||
std::to_string(i));
|
||||
break;
|
||||
int64_t dim;
|
||||
// if the dim of the size int op doesn't match, fail
|
||||
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
|
||||
dim != i)
|
||||
dim != leftMatchEnd)
|
||||
break;
|
||||
}
|
||||
|
||||
int64_t rightMatchEnd = 0;
|
||||
// compare input sizes with provided dims from right
|
||||
for (; rightMatchEnd < std::min(outRank, inRank) - leftMatchEnd;
|
||||
rightMatchEnd++) {
|
||||
int64_t providedSize;
|
||||
bool providedStatic = matchPattern(viewSizes[outRank - 1 - rightMatchEnd],
|
||||
m_TorchConstantInt(&providedSize));
|
||||
// static dim case
|
||||
if (sizes[inRank - 1 - rightMatchEnd] != Torch::kUnknownSize) {
|
||||
// if can't infer equality of dims, set end index and break
|
||||
if (!providedStatic ||
|
||||
providedSize != sizes[inRank - 1 - rightMatchEnd])
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
// the remaining assumes sizes[inRank - 1 - rightMatchEnd] is dynamic
|
||||
// if provided dim is static, we can't match.
|
||||
if (providedStatic)
|
||||
break;
|
||||
auto sizeIntOp =
|
||||
viewSizes[outRank - 1 - rightMatchEnd].getDefiningOp<AtenSizeIntOp>();
|
||||
// if we don't have a size int op on self, break
|
||||
if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
|
||||
break;
|
||||
int64_t dim;
|
||||
// if the dim of the size int op doesn't match, break
|
||||
if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
|
||||
dim != inRank - 1 - rightMatchEnd)
|
||||
break;
|
||||
}
|
||||
// the unmatched input dims start at leftMatchEnd, and end before inRank -
|
||||
// rightMatchEnd
|
||||
int64_t inputUnmatched = (inRank - rightMatchEnd) - leftMatchEnd;
|
||||
int64_t outputUnmatched = (outRank - rightMatchEnd) - leftMatchEnd;
|
||||
// if too many dims are unmatched in input/output, cannot canonicalize.
|
||||
if (inputUnmatched > 1 && outputUnmatched > 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"size int op dim cannot be matched to current dim at position : " +
|
||||
std::to_string(i));
|
||||
// passing the previous checks means viewSizes[i] = aten.size.int(self,
|
||||
// i), so continue
|
||||
}
|
||||
// if all dims match and the ranks are equal, fold
|
||||
if (endMatchingDim == -1 && inRank == outRank) {
|
||||
rewriter.replaceOp(op, op.getSelf());
|
||||
"View op is not simple enough to canonicalize.\n# Unmatched Input "
|
||||
"dims = " +
|
||||
std::to_string(inputUnmatched) +
|
||||
"\n# Unmatched Output Dims = " + std::to_string(outputUnmatched) +
|
||||
"\nStarting unmatched index = " + std::to_string(leftMatchEnd));
|
||||
|
||||
// if all dims match, return self.
|
||||
if (inputUnmatched == outputUnmatched &&
|
||||
(inputUnmatched == 1 || inputUnmatched == 0)) {
|
||||
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
|
||||
op, op.getType(), op.getSelf());
|
||||
return success();
|
||||
}
|
||||
if (endMatchingDim > -1 && inRank > outRank) {
|
||||
// only support flattening last dim
|
||||
if (endMatchingDim != outRank - 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: output has more than back dim mismatching");
|
||||
// flatten
|
||||
Value start =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
|
||||
Value end =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
|
||||
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
|
||||
op, resultTy, op.getSelf(), start, end);
|
||||
return success();
|
||||
}
|
||||
if (endMatchingDim > -1 && inRank < outRank) {
|
||||
// only support unflattening last dim
|
||||
if (endMatchingDim != inRank - 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: input has more than back dim mismatching");
|
||||
// unflatten
|
||||
Value dim =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
|
||||
Value primList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
op.getLoc(), op.getSize().getType(),
|
||||
ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
|
||||
// if input has 1 unmatched dim, and output has multiple, unflatten
|
||||
if (inputUnmatched == 1 && outputUnmatched > 1) {
|
||||
Value dimVal =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
|
||||
ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
|
||||
viewSizes.end() - rightMatchEnd);
|
||||
Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
op.getLoc(), op.getSize().getType(), unflattenSizes);
|
||||
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
|
||||
op, resultTy, op.getSelf(), dim, primList);
|
||||
op, op.getType(), op.getSelf(), dimVal, unflattenList);
|
||||
return success();
|
||||
}
|
||||
// examples that might reach this:
|
||||
// input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
|
||||
// input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
|
||||
// input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
|
||||
// if multiple unmatched input dims map to one output dim, flatten
|
||||
if (inputUnmatched > 1 && outputUnmatched == 1) {
|
||||
Value startDim =
|
||||
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
|
||||
// note: flatten end is inclusive for some reason.
|
||||
int64_t endInt = inRank - rightMatchEnd - 1;
|
||||
Value endDim = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endInt);
|
||||
rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
|
||||
op, op.getType(), op.getSelf(), startDim, endDim);
|
||||
return success();
|
||||
}
|
||||
// the remaining cases involve maximal matching dims, but mismatched ranks.
|
||||
// This could only occur if squeezing or unsqueezing.
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
|
||||
", inRank=" + std::to_string(inRank) +
|
||||
", outRank=" + std::to_string(outRank));
|
||||
op, "unhandled view op canonicalization to squeeze/unsqueeze.");
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -255,6 +255,25 @@ func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !tor
|
|||
return %3 : !torch.vtensor<[?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @view_as_flatten_mid
|
||||
func.func @view_as_flatten_mid(%arg0: !torch.vtensor<[?,?,?,?,2,4],f32>) -> !torch.vtensor<[?,?,?,4],f32> {
|
||||
// CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[FOUR:.*]] = torch.constant.int 4
|
||||
// CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[FOUR]] : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,4],f32>
|
||||
// CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?,4],f32>
|
||||
%int-1 = torch.constant.int -1
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%int4 = torch.constant.int 4
|
||||
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int
|
||||
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int
|
||||
%2 = torch.prim.ListConstruct %0, %1, %int-1, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,4],f32>
|
||||
return %3 : !torch.vtensor<[?,?,?,4],f32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
|
Loading…
Reference in New Issue