From dff3405d5a1b7912aad702cb7f4d581de2d4a695 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Thu, 25 May 2023 10:05:41 -0700 Subject: [PATCH] Add alias analysis for cast-like ops to maximize-value-semantics (#2160) When `use_tracing=True` is used to import a model into Torch-MLIR, several casts get inserted in the IR to bridge the untyped inputs and outputs with the typed body of the computation. These casts create extra aliases of tensors that cause the current analysis in `maximize-value-semantics` to fail. In particular, the `maximize-value-semantics` analysis assumes that the only valid alias right after an overwrite is the overwritten alias. So, if there is a use of a casted version of the overwritten alias after the overwrite, the analysis fails. This commit improves the analysis by identifying all cast-like aliases of the overwritten alias and allowing such aliases to be used after an overwrite. Because this issue only arises when using tracing, it cannot be currently tested e2e, so only lit test is added. --- .../Transforms/MaximizeValueSemantics.cpp | 47 ++++++++++++++++--- .../Torch/maximize-value-semantics.mlir | 16 +++++++ 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 121c759a6..cd76275a7 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -28,6 +28,28 @@ static Value assertNonValueTensor(Value tensor) { return tensor; } +// A cast-like op is an op that does not modify the contents, shape, and dtype +// of the input tensor. In other words, it is an op that only serves to encode +// compile time information, but at runtime the op behaves like a no-op. +static bool isCastLikeOp(Operation *op) { + return isa(op); +} + +// Given a `value`, this function goes up the use-def chain and finds the +// largest sequence of consecutive cast-like ops. The returned set contains all +// the aliases that are identical to `value`, and have only been transformed by +// cast-like ops. +static DenseSet getCastLikeAliasesOf(Value value) { + Operation *currentOp = value.getDefiningOp(); + DenseSet result; + while (isCastLikeOp(currentOp)) { + Value operand = assertNonValueTensor(currentOp->getOperand(0)); + result.insert(operand); + currentOp = operand.getDefiningOp(); + } + return result; +} + namespace { class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock : public OpRewritePattern { @@ -88,9 +110,13 @@ public: } else if (auto overwrite = dyn_cast(user)) { // To simplify the analysis, we only support the case where the // only aliases used after an overwrite are the aliases generated - // after plus the alias being overwritten. + // after plus the alias being overwritten and any aliases that are + // simply a cast of the overwritten alias. availableAliases.clear(); - availableAliases.insert(assertNonValueTensor(overwrite.getOverwritten())); + Value overwritten = overwrite.getOverwritten(); + availableAliases.insert(assertNonValueTensor(overwritten)); + DenseSet castLikeAliases = getCastLikeAliasesOf(overwritten); + availableAliases.insert(castLikeAliases.begin(), castLikeAliases.end()); result.overwriteTensorContentsOps.push_back(overwrite); } else if (auto returnOp = dyn_cast(user)) { result.returnOp = returnOp; @@ -128,10 +154,19 @@ public: for (OverwriteTensorContentsOp overwrite : llvm::reverse(ops.overwriteTensorContentsOps)) { Value overwritten = assertNonValueTensor(overwrite.getOverwritten()); - overwritten.replaceUsesWithIf( - overwrite.getValue(), [&](const OpOperand &operand) { - return !operand.getOwner()->isBeforeInBlock(overwrite); - }); + // Cast-like aliases represent the exact same tensor at runtime as the + // overwritten alias, since casts only encode compile time information. + // Therefore, here we replace the overwritten value and any cast-like + // aliases of it with the overwrite value. + DenseSet overwrittenAliases = getCastLikeAliasesOf(overwritten); + overwrittenAliases.insert(overwritten); + + for (Value alias : overwrittenAliases) { + alias.replaceUsesWithIf( + overwrite.getValue(), [&](const OpOperand &operand) { + return !operand.getOwner()->isBeforeInBlock(overwrite); + }); + } rewriter.eraseOp(overwrite); } diff --git a/test/Dialect/Torch/maximize-value-semantics.mlir b/test/Dialect/Torch/maximize-value-semantics.mlir index 795d90045..6643e1e6f 100644 --- a/test/Dialect/Torch/maximize-value-semantics.mlir +++ b/test/Dialect/Torch/maximize-value-semantics.mlir @@ -261,3 +261,19 @@ func.func @viewlike$two_inputs_two_copies(%arg0: !torch.vtensor, %arg1: !torch.v %3 = torch.copy.to_vtensor %2 : !torch.vtensor return %3 : !torch.vtensor } + +// CHECK-LABEL: func.func @castlike( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[5,4],f32>) -> !torch.tensor { +// CHECK: %[[CAST1:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[5,4],f32> to !torch.vtensor +// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST1]] : !torch.vtensor to !torch.vtensor<[5,4],f32> +// CHECK: %[[CAST3:.*]] = torch.tensor_static_info_cast %[[CAST2]] : !torch.vtensor<[5,4],f32> to !torch.vtensor +// CHECK: %[[COPY:.*]] = torch.copy.to_tensor %[[CAST3]] : !torch.tensor +// CHECK: return %[[COPY]] : !torch.tensor +func.func @castlike(%arg0: !torch.vtensor<[5,4],f32>) -> !torch.tensor { + %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[5,4],f32> to !torch.vtensor + %1 = torch.copy.to_tensor %0 : !torch.tensor + %2 = torch.tensor_static_info_cast %1 : !torch.tensor to !torch.tensor<[5,4],f32> + %3 = torch.copy.to_vtensor %2 : !torch.vtensor<[5,4],f32> + torch.overwrite.tensor.contents %3 overwrites %2 : !torch.vtensor<[5,4],f32>, !torch.tensor<[5,4],f32> + return %1 : !torch.tensor +}