mirror of https://github.com/llvm/torch-mlir
Add alias analysis for cast-like ops to maximize-value-semantics (#2160)
When `use_tracing=True` is used to import a model into Torch-MLIR, several casts get inserted in the IR to bridge the untyped inputs and outputs with the typed body of the computation. These casts create extra aliases of tensors that cause the current analysis in `maximize-value-semantics` to fail. In particular, the `maximize-value-semantics` analysis assumes that the only valid alias right after an overwrite is the overwritten alias. So, if there is a use of a casted version of the overwritten alias after the overwrite, the analysis fails. This commit improves the analysis by identifying all cast-like aliases of the overwritten alias and allowing such aliases to be used after an overwrite. Because this issue only arises when using tracing, it cannot be currently tested e2e, so only lit test is added.pull/2171/head
parent
9f65a8a961
commit
dff3405d5a
|
@ -28,6 +28,28 @@ static Value assertNonValueTensor(Value tensor) {
|
||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A cast-like op is an op that does not modify the contents, shape, and dtype
|
||||||
|
// of the input tensor. In other words, it is an op that only serves to encode
|
||||||
|
// compile time information, but at runtime the op behaves like a no-op.
|
||||||
|
static bool isCastLikeOp(Operation *op) {
|
||||||
|
return isa<TensorStaticInfoCastOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a `value`, this function goes up the use-def chain and finds the
|
||||||
|
// largest sequence of consecutive cast-like ops. The returned set contains all
|
||||||
|
// the aliases that are identical to `value`, and have only been transformed by
|
||||||
|
// cast-like ops.
|
||||||
|
static DenseSet<Value> getCastLikeAliasesOf(Value value) {
|
||||||
|
Operation *currentOp = value.getDefiningOp();
|
||||||
|
DenseSet<Value> result;
|
||||||
|
while (isCastLikeOp(currentOp)) {
|
||||||
|
Value operand = assertNonValueTensor(currentOp->getOperand(0));
|
||||||
|
result.insert(operand);
|
||||||
|
currentOp = operand.getDefiningOp();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
||||||
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
||||||
|
@ -88,9 +110,13 @@ public:
|
||||||
} else if (auto overwrite = dyn_cast<OverwriteTensorContentsOp>(user)) {
|
} else if (auto overwrite = dyn_cast<OverwriteTensorContentsOp>(user)) {
|
||||||
// To simplify the analysis, we only support the case where the
|
// To simplify the analysis, we only support the case where the
|
||||||
// only aliases used after an overwrite are the aliases generated
|
// only aliases used after an overwrite are the aliases generated
|
||||||
// after plus the alias being overwritten.
|
// after plus the alias being overwritten and any aliases that are
|
||||||
|
// simply a cast of the overwritten alias.
|
||||||
availableAliases.clear();
|
availableAliases.clear();
|
||||||
availableAliases.insert(assertNonValueTensor(overwrite.getOverwritten()));
|
Value overwritten = overwrite.getOverwritten();
|
||||||
|
availableAliases.insert(assertNonValueTensor(overwritten));
|
||||||
|
DenseSet<Value> castLikeAliases = getCastLikeAliasesOf(overwritten);
|
||||||
|
availableAliases.insert(castLikeAliases.begin(), castLikeAliases.end());
|
||||||
result.overwriteTensorContentsOps.push_back(overwrite);
|
result.overwriteTensorContentsOps.push_back(overwrite);
|
||||||
} else if (auto returnOp = dyn_cast<mlir::func::ReturnOp>(user)) {
|
} else if (auto returnOp = dyn_cast<mlir::func::ReturnOp>(user)) {
|
||||||
result.returnOp = returnOp;
|
result.returnOp = returnOp;
|
||||||
|
@ -128,10 +154,19 @@ public:
|
||||||
for (OverwriteTensorContentsOp overwrite :
|
for (OverwriteTensorContentsOp overwrite :
|
||||||
llvm::reverse(ops.overwriteTensorContentsOps)) {
|
llvm::reverse(ops.overwriteTensorContentsOps)) {
|
||||||
Value overwritten = assertNonValueTensor(overwrite.getOverwritten());
|
Value overwritten = assertNonValueTensor(overwrite.getOverwritten());
|
||||||
overwritten.replaceUsesWithIf(
|
// Cast-like aliases represent the exact same tensor at runtime as the
|
||||||
|
// overwritten alias, since casts only encode compile time information.
|
||||||
|
// Therefore, here we replace the overwritten value and any cast-like
|
||||||
|
// aliases of it with the overwrite value.
|
||||||
|
DenseSet<Value> overwrittenAliases = getCastLikeAliasesOf(overwritten);
|
||||||
|
overwrittenAliases.insert(overwritten);
|
||||||
|
|
||||||
|
for (Value alias : overwrittenAliases) {
|
||||||
|
alias.replaceUsesWithIf(
|
||||||
overwrite.getValue(), [&](const OpOperand &operand) {
|
overwrite.getValue(), [&](const OpOperand &operand) {
|
||||||
return !operand.getOwner()->isBeforeInBlock(overwrite);
|
return !operand.getOwner()->isBeforeInBlock(overwrite);
|
||||||
});
|
});
|
||||||
|
}
|
||||||
rewriter.eraseOp(overwrite);
|
rewriter.eraseOp(overwrite);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -261,3 +261,19 @@ func.func @viewlike$two_inputs_two_copies(%arg0: !torch.vtensor, %arg1: !torch.v
|
||||||
%3 = torch.copy.to_vtensor %2 : !torch.vtensor
|
%3 = torch.copy.to_vtensor %2 : !torch.vtensor
|
||||||
return %3 : !torch.vtensor
|
return %3 : !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @castlike(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[5,4],f32>) -> !torch.tensor {
|
||||||
|
// CHECK: %[[CAST1:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[5,4],f32> to !torch.vtensor
|
||||||
|
// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST1]] : !torch.vtensor to !torch.vtensor<[5,4],f32>
|
||||||
|
// CHECK: %[[CAST3:.*]] = torch.tensor_static_info_cast %[[CAST2]] : !torch.vtensor<[5,4],f32> to !torch.vtensor
|
||||||
|
// CHECK: %[[COPY:.*]] = torch.copy.to_tensor %[[CAST3]] : !torch.tensor
|
||||||
|
// CHECK: return %[[COPY]] : !torch.tensor
|
||||||
|
func.func @castlike(%arg0: !torch.vtensor<[5,4],f32>) -> !torch.tensor {
|
||||||
|
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[5,4],f32> to !torch.vtensor
|
||||||
|
%1 = torch.copy.to_tensor %0 : !torch.tensor
|
||||||
|
%2 = torch.tensor_static_info_cast %1 : !torch.tensor to !torch.tensor<[5,4],f32>
|
||||||
|
%3 = torch.copy.to_vtensor %2 : !torch.vtensor<[5,4],f32>
|
||||||
|
torch.overwrite.tensor.contents %3 overwrites %2 : !torch.vtensor<[5,4],f32>, !torch.tensor<[5,4],f32>
|
||||||
|
return %1 : !torch.tensor
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue