Add alias analysis for cast-like ops to maximize-value-semantics (#2160)

When `use_tracing=True` is used to import a model into Torch-MLIR,
several casts get inserted in the IR to bridge the untyped inputs and
outputs with the typed body of the computation. These casts create
extra aliases of tensors that cause the current analysis in
`maximize-value-semantics` to fail.

In particular, the `maximize-value-semantics` analysis assumes that the
only valid alias right after an overwrite is the overwritten
alias. So, if there is a use of a casted version of the overwritten
alias after the overwrite, the analysis fails.

This commit improves the analysis by identifying all cast-like aliases
of the overwritten alias and allowing such aliases to be used after an
overwrite.

Because this issue only arises when using tracing, it cannot be
currently tested e2e, so only lit test is added.
pull/2171/head
Ramiro Leal-Cavazos 2023-05-25 10:05:41 -07:00 committed by GitHub
parent 9f65a8a961
commit dff3405d5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 6 deletions

View File

@ -28,6 +28,28 @@ static Value assertNonValueTensor(Value tensor) {
return tensor; return tensor;
} }
// A cast-like op is an op that does not modify the contents, shape, and dtype
// of the input tensor. In other words, it is an op that only serves to encode
// compile time information, but at runtime the op behaves like a no-op.
static bool isCastLikeOp(Operation *op) {
return isa<TensorStaticInfoCastOp>(op);
}
// Given a `value`, this function goes up the use-def chain and finds the
// largest sequence of consecutive cast-like ops. The returned set contains all
// the aliases that are identical to `value`, and have only been transformed by
// cast-like ops.
static DenseSet<Value> getCastLikeAliasesOf(Value value) {
Operation *currentOp = value.getDefiningOp();
DenseSet<Value> result;
while (isCastLikeOp(currentOp)) {
Value operand = assertNonValueTensor(currentOp->getOperand(0));
result.insert(operand);
currentOp = operand.getDefiningOp();
}
return result;
}
namespace { namespace {
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
: public OpRewritePattern<CopyToNonValueTensorOp> { : public OpRewritePattern<CopyToNonValueTensorOp> {
@ -88,9 +110,13 @@ public:
} else if (auto overwrite = dyn_cast<OverwriteTensorContentsOp>(user)) { } else if (auto overwrite = dyn_cast<OverwriteTensorContentsOp>(user)) {
// To simplify the analysis, we only support the case where the // To simplify the analysis, we only support the case where the
// only aliases used after an overwrite are the aliases generated // only aliases used after an overwrite are the aliases generated
// after plus the alias being overwritten. // after plus the alias being overwritten and any aliases that are
// simply a cast of the overwritten alias.
availableAliases.clear(); availableAliases.clear();
availableAliases.insert(assertNonValueTensor(overwrite.getOverwritten())); Value overwritten = overwrite.getOverwritten();
availableAliases.insert(assertNonValueTensor(overwritten));
DenseSet<Value> castLikeAliases = getCastLikeAliasesOf(overwritten);
availableAliases.insert(castLikeAliases.begin(), castLikeAliases.end());
result.overwriteTensorContentsOps.push_back(overwrite); result.overwriteTensorContentsOps.push_back(overwrite);
} else if (auto returnOp = dyn_cast<mlir::func::ReturnOp>(user)) { } else if (auto returnOp = dyn_cast<mlir::func::ReturnOp>(user)) {
result.returnOp = returnOp; result.returnOp = returnOp;
@ -128,10 +154,19 @@ public:
for (OverwriteTensorContentsOp overwrite : for (OverwriteTensorContentsOp overwrite :
llvm::reverse(ops.overwriteTensorContentsOps)) { llvm::reverse(ops.overwriteTensorContentsOps)) {
Value overwritten = assertNonValueTensor(overwrite.getOverwritten()); Value overwritten = assertNonValueTensor(overwrite.getOverwritten());
overwritten.replaceUsesWithIf( // Cast-like aliases represent the exact same tensor at runtime as the
// overwritten alias, since casts only encode compile time information.
// Therefore, here we replace the overwritten value and any cast-like
// aliases of it with the overwrite value.
DenseSet<Value> overwrittenAliases = getCastLikeAliasesOf(overwritten);
overwrittenAliases.insert(overwritten);
for (Value alias : overwrittenAliases) {
alias.replaceUsesWithIf(
overwrite.getValue(), [&](const OpOperand &operand) { overwrite.getValue(), [&](const OpOperand &operand) {
return !operand.getOwner()->isBeforeInBlock(overwrite); return !operand.getOwner()->isBeforeInBlock(overwrite);
}); });
}
rewriter.eraseOp(overwrite); rewriter.eraseOp(overwrite);
} }

View File

@ -261,3 +261,19 @@ func.func @viewlike$two_inputs_two_copies(%arg0: !torch.vtensor, %arg1: !torch.v
%3 = torch.copy.to_vtensor %2 : !torch.vtensor %3 = torch.copy.to_vtensor %2 : !torch.vtensor
return %3 : !torch.vtensor return %3 : !torch.vtensor
} }
// CHECK-LABEL: func.func @castlike(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[5,4],f32>) -> !torch.tensor {
// CHECK: %[[CAST1:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[5,4],f32> to !torch.vtensor
// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST1]] : !torch.vtensor to !torch.vtensor<[5,4],f32>
// CHECK: %[[CAST3:.*]] = torch.tensor_static_info_cast %[[CAST2]] : !torch.vtensor<[5,4],f32> to !torch.vtensor
// CHECK: %[[COPY:.*]] = torch.copy.to_tensor %[[CAST3]] : !torch.tensor
// CHECK: return %[[COPY]] : !torch.tensor
func.func @castlike(%arg0: !torch.vtensor<[5,4],f32>) -> !torch.tensor {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[5,4],f32> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
%2 = torch.tensor_static_info_cast %1 : !torch.tensor to !torch.tensor<[5,4],f32>
%3 = torch.copy.to_vtensor %2 : !torch.vtensor<[5,4],f32>
torch.overwrite.tensor.contents %3 overwrites %2 : !torch.vtensor<[5,4],f32>, !torch.tensor<[5,4],f32>
return %1 : !torch.tensor
}