From 8519ecc4d7e5da0db6e4dca6b02307ae46422feb Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Thu, 7 Nov 2024 15:26:07 -0600
Subject: [PATCH] Generalize `aten.view` pattern in scalarize shapes (#3856)

Extends the existing pattern to allow finding matching dims from the
back as well as the front.
---
 .../Torch/Transforms/ScalarizeShapes.cpp | 165 ++++++++++--------
 test/Dialect/Torch/scalarize-shapes.mlir |  19 ++
 2 files changed, 109 insertions(+), 75 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
index 9a85fbaa8..3d1a54de2 100644
--- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
+++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -1099,97 +1099,112 @@ public:
     int64_t outRank = resultTy.getSizes().size();
     SmallVector<int64_t> sizes(selfTy.getSizes());
 
-    int64_t endMatchingDim = -1;
-    // input sizes vs. provided view sizes comparison loop
-    for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
+    int64_t leftMatchEnd = 0;
+    // compare input sizes with provided dims from left
+    for (; leftMatchEnd < std::min(outRank, inRank); leftMatchEnd++) {
       int64_t providedSize;
-      bool providedStatic =
-          matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
-      // if sizes[i] is static, it must match a constant in viewSizes[i]
-      if (sizes[i] != Torch::kUnknownSize) {
-        if (!providedStatic)
-          return rewriter.notifyMatchFailure(
-              op, "unsupported: found static input dim, but unable to match "
-                  "provided view size on a constant. See position : " +
-                      std::to_string(i));
-        if (providedSize != sizes[i]) {
-          endMatchingDim = i;
+      bool providedStatic = matchPattern(viewSizes[leftMatchEnd],
+                                         m_TorchConstantInt(&providedSize));
+      // static dim case
+      if (sizes[leftMatchEnd] != Torch::kUnknownSize) {
+        // if we can't infer equality of the dims, break
+        if (!providedStatic || providedSize != sizes[leftMatchEnd])
           break;
-        }
         continue;
       }
-      // the remaining assumes sizes[i] is dynamic
-      // if provided dim is static, we can't verify it is a flatten/unflatten
-      // unless -1
-      if (i == outRank - 1 && providedStatic && providedSize == -1) {
-        endMatchingDim = i;
-        break;
-      }
+      // the remaining assumes sizes[leftMatchEnd] is dynamic
+      // if provided dim is static, we can't match.
       if (providedStatic)
-        return rewriter.notifyMatchFailure(
-            op, "unexpected static view dim corresponding to dynamic input dim "
-                "at position : " +
-                    std::to_string(i));
-      auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
-      // if we don't have a size int op on self, fail
+        break;
+      auto sizeIntOp = viewSizes[leftMatchEnd].getDefiningOp<AtenSizeIntOp>();
+      // if we don't have a size int op on self, break
       if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
-        return rewriter.notifyMatchFailure(
-            op, "expected dynamic view dim to come from a corresponding "
-                "size.int op. See position : " +
-                    std::to_string(i));
+        break;
       int64_t dim;
       // if the dim of the size int op doesn't match, fail
       if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
-          dim != i)
-        return rewriter.notifyMatchFailure(
-            op,
-            "size int op dim cannot be matched to current dim at position : " +
-                std::to_string(i));
-      // passing the previous checks means viewSizes[i] = aten.size.int(self,
-      // i), so continue
+          dim != leftMatchEnd)
+        break;
     }
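+    // At this point, input dims [0, leftMatchEnd) are known to equal the
+    // first leftMatchEnd provided view sizes.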
See position : " + - std::to_string(i)); + break; int64_t dim; // if the dim of the size int op doesn't match, fail if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) || - dim != i) - return rewriter.notifyMatchFailure( - op, - "size int op dim cannot be matched to current dim at position : " + - std::to_string(i)); - // passing the previous checks means viewSizes[i] = aten.size.int(self, - // i), so continue + dim != leftMatchEnd) + break; } - // if all dims match and the ranks are equal, fold - if (endMatchingDim == -1 && inRank == outRank) { - rewriter.replaceOp(op, op.getSelf()); + + int64_t rightMatchEnd = 0; + // compare input sizes with provided dims from right + for (; rightMatchEnd < std::min(outRank, inRank) - leftMatchEnd; + rightMatchEnd++) { + int64_t providedSize; + bool providedStatic = matchPattern(viewSizes[outRank - 1 - rightMatchEnd], + m_TorchConstantInt(&providedSize)); + // static dim case + if (sizes[inRank - 1 - rightMatchEnd] != Torch::kUnknownSize) { + // if can't infer equality of dims, set end index and break + if (!providedStatic || + providedSize != sizes[inRank - 1 - rightMatchEnd]) + break; + continue; + } + // the remaining assumes sizes[inRank - 1 - rightMatchEnd] is dynamic + // if provided dim is static, we can't match. + if (providedStatic) + break; + auto sizeIntOp = + viewSizes[outRank - 1 - rightMatchEnd].getDefiningOp(); + // if we don't have a size int op on self, break + if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf()) + break; + int64_t dim; + // if the dim of the size int op doesn't match, break + if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) || + dim != inRank - 1 - rightMatchEnd) + break; + } + // the unmatched input dims start at leftMatchEnd, and end before inRank - + // rightMatchEnd + int64_t inputUnmatched = (inRank - rightMatchEnd) - leftMatchEnd; + int64_t outputUnmatched = (outRank - rightMatchEnd) - leftMatchEnd; + // if too many dims are unmatched in input/output, cannot canonicalize. + if (inputUnmatched > 1 && outputUnmatched > 1) + return rewriter.notifyMatchFailure( + op, + "View op is not simple enough to canonicalize.\n# Unmatched Input " + "dims = " + + std::to_string(inputUnmatched) + + "\n# Unmatched Output Dims = " + std::to_string(outputUnmatched) + + "\nStarting unmatched index = " + std::to_string(leftMatchEnd)); + + // if all dims match, return self. 
+    if (inputUnmatched == outputUnmatched &&
+        (inputUnmatched == 1 || inputUnmatched == 0)) {
+      rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
+          op, op.getType(), op.getSelf());
       return success();
     }
-    if (endMatchingDim > -1 && inRank > outRank) {
-      // only support flattening last dim
-      if (endMatchingDim != outRank - 1)
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: output has more than back dim mismatching");
-      // flatten
-      Value start =
-          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
-      Value end =
-          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
-      rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
-          op, resultTy, op.getSelf(), start, end);
-      return success();
-    }
-    if (endMatchingDim > -1 && inRank < outRank) {
-      // only support unflattening last dim
-      if (endMatchingDim != inRank - 1)
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: input has more than back dim mismatching");
-      // unflatten
-      Value dim =
-          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
-      Value primList = rewriter.create<Torch::PrimListConstructOp>(
-          op.getLoc(), op.getSize().getType(),
-          ArrayRef<Value>(viewSizes.begin() + endMatchingDim, viewSizes.end()));
+    // if input has 1 unmatched dim, and output has multiple, unflatten
+    if (inputUnmatched == 1 && outputUnmatched > 1) {
+      Value dimVal =
+          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
+      ArrayRef<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
+                                     viewSizes.end() - rightMatchEnd);
+      Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
+          op.getLoc(), op.getSize().getType(), unflattenSizes);
       rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
-          op, resultTy, op.getSelf(), dim, primList);
+          op, op.getType(), op.getSelf(), dimVal, unflattenList);
       return success();
     }
-    // examples that might reach this:
-    // input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
-    // input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
-    // input shape = [dim0, dim1, 1, 1] view sizes = [dim0, dim1] (squeezes)
+    // if multiple unmatched input dims map to one output dim, flatten
+    if (inputUnmatched > 1 && outputUnmatched == 1) {
+      Value startDim =
+          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
+      // note: the end dim of aten.flatten.using_ints is inclusive.
+      int64_t endInt = inRank - rightMatchEnd - 1;
+      Value endDim = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endInt);
+      rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
+          op, op.getType(), op.getSelf(), startDim, endDim);
+      return success();
+    }
+    // the remaining cases involve maximal matching dims, but mismatched ranks.
+    // This could only occur if squeezing or unsqueezing.
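+    // e.g., input shape = [dim0, dim1] with view sizes = [dim0, dim1, 1]
+    // (unsqueeze), or input shape = [dim0, 1, dim1] with view sizes =
+    // [dim0, dim1] (squeeze).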
     return rewriter.notifyMatchFailure(
-        op, "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
-                ", inRank=" + std::to_string(inRank) +
-                ", outRank=" + std::to_string(outRank));
+        op, "unhandled view op canonicalization to squeeze/unsqueeze.");
   }
 };
 } // namespace
diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir
index f5193b701..5ea715735 100644
--- a/test/Dialect/Torch/scalarize-shapes.mlir
+++ b/test/Dialect/Torch/scalarize-shapes.mlir
@@ -255,6 +255,25 @@ func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !tor
   return %3 : !torch.vtensor<[?,?,?],f32>
 }
 
+// -----
+
+// CHECK-LABEL: @view_as_flatten_mid
+func.func @view_as_flatten_mid(%arg0: !torch.vtensor<[?,?,?,?,2,4],f32>) -> !torch.vtensor<[?,?,?,4],f32> {
+  // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[FOUR:.*]] = torch.constant.int 4
+  // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[FOUR]] : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,4],f32>
+  // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?,4],f32>
+  %int-1 = torch.constant.int -1
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %int4 = torch.constant.int 4
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int
+  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int
+  %2 = torch.prim.ListConstruct %0, %1, %int-1, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,4],f32>
+  return %3 : !torch.vtensor<[?,?,?,4],f32>
+}
+
 // -----