// mirror of https://github.com/llvm/torch-mlir
//===- MaximizeValueSemantics.cpp --------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Returns true if `op` is one of the torch ops whose result may be a view
// (alias) of its tensor operand rather than a fresh value.
//
// AtenContiguousOp might return a view, so this is conservatively
// correct. We could potentially be more precise and identify the cases
// that it does not return a view and treat those as having value
// semantics.
static bool isViewLikeOp(Operation *op) {
  return isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandOp,
             AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
             AtenSelectIntOp, AtenSliceTensorOp, AtenSqueezeDimOp,
             AtenSqueezeOp, AtenTOp, AtenToDtypeOp, AtenTransposeIntOp,
             AtenUnsqueezeOp, AtenViewOp, TensorStaticInfoCastOp>(op);
}
namespace {
|
|
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
|
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
struct InterpretedOps {
|
|
SmallVector<Operation *> copyLikeOps;
|
|
SmallVector<Operation *> viewLikeOps;
|
|
SmallVector<OverwriteTensorContentsOp> overwriteTensorContentsOps;
|
|
};
|
|
|
|
// Check that graph rewriting is possible by doing an abstract
|
|
// interpretation within a single basic block. If rewriting is
|
|
// possible, the interpreted ops are returned split into their
|
|
// respective categories.
|
|
static FailureOr<InterpretedOps>
|
|
abstractlyInterpretSlice(CopyToNonValueTensorOp copyToNonValueTensor,
|
|
SmallVector<Operation *> nonValueTensorUsers,
|
|
PatternRewriter &rewriter) {
|
|
// Sort by order in the block, so we can abstractly interpret the ops.
|
|
llvm::sort(nonValueTensorUsers, [](Operation *lhs, Operation *rhs) {
|
|
return lhs->isBeforeInBlock(rhs);
|
|
});
|
|
|
|
// We track the available aliases at each point as well as split the
|
|
// users into view-like, copy-to-value, and overwrite ops as we walk
|
|
// forward.
|
|
InterpretedOps result;
|
|
result.copyLikeOps.push_back(copyToNonValueTensor);
|
|
DenseSet<Value> availableAliases{copyToNonValueTensor.result()};
|
|
for (Operation *user : nonValueTensorUsers) {
|
|
if (isViewLikeOp(user)) {
|
|
Value operand = user->getOperand(0);
|
|
if (!availableAliases.contains(operand)) {
|
|
return rewriter.notifyMatchFailure(
|
|
copyToNonValueTensor,
|
|
"operand of view-like op is not a valid tensor alias");
|
|
}
|
|
|
|
// View-like ops produce a new alias available to later ops.
|
|
availableAliases.insert(user->getResult(0));
|
|
result.viewLikeOps.push_back(user);
|
|
} else if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
|
|
if (!availableAliases.contains(copyToValueTensor.operand())) {
|
|
return rewriter.notifyMatchFailure(
|
|
copyToNonValueTensor,
|
|
"operand of copyToValueTensorOp is not a valid tensor alias");
|
|
}
|
|
result.copyLikeOps.push_back(copyToValueTensor);
|
|
} else if (auto overwrite = dyn_cast<OverwriteTensorContentsOp>(user)) {
|
|
Value overwritten = overwrite.overwritten();
|
|
if (!availableAliases.contains(overwritten)) {
|
|
return rewriter.notifyMatchFailure(
|
|
copyToNonValueTensor, "overwritten tensor is not a valid alias");
|
|
}
|
|
|
|
// To simplify the analysis, we only support the case where the
|
|
// only aliases used after an overwrite are the aliases generated
|
|
// after plus the alias being overwritten.
|
|
availableAliases.clear();
|
|
availableAliases.insert(overwritten);
|
|
result.overwriteTensorContentsOps.push_back(overwrite);
|
|
} else {
|
|
return rewriter.notifyMatchFailure(
|
|
copyToNonValueTensor,
|
|
"unsupported op encountered during abstract analysis");
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// Rewrite slice composed of the interpreted ops so that the slice uses
|
|
// value semantics everywhere.
|
|
static void rewriteSlice(const InterpretedOps &ops,
|
|
PatternRewriter &rewriter) {
|
|
// The rewriting for the overwrite op involves replacing all uses of its
|
|
// non-value tensor operand with its value tensor operand. Since the
|
|
// rewriting of other ops can potentially change the non-value tensor
|
|
// operand to a value tensor, this rewriting MUST happen first to avoid
|
|
// wrongly replacing operands that were previously not a view of the
|
|
// overwritten tensor.
|
|
for (OverwriteTensorContentsOp overwrite :
|
|
llvm::reverse(ops.overwriteTensorContentsOps)) {
|
|
Value overwritten = overwrite.overwritten();
|
|
assert(overwritten.getType().dyn_cast<NonValueTensorType>() &&
|
|
"the analysis assumes that overwritten remains a nonValueTensor "
|
|
"throughout the rewriting");
|
|
overwritten.replaceUsesWithIf(
|
|
overwrite.value(), [&](const OpOperand &operand) {
|
|
return !operand.getOwner()->isBeforeInBlock(overwrite);
|
|
});
|
|
rewriter.eraseOp(overwrite);
|
|
}
|
|
|
|
for (Operation *copyLikeOp : ops.copyLikeOps)
|
|
rewriter.replaceOp(copyLikeOp, copyLikeOp->getOperand(0));
|
|
|
|
// Replace return type of view-like ops with value-semantics type variant.
|
|
for (Operation *viewLikeOp : ops.viewLikeOps) {
|
|
rewriter.updateRootInPlace(viewLikeOp, [&] {
|
|
Value result = viewLikeOp->getResult(0);
|
|
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
|
|
assert(resultType && "all view-like ops considered must have result of "
|
|
"type `NonValueTensorType` before rewriting");
|
|
result.setType(resultType.getWithValueSemantics());
|
|
});
|
|
}
|
|
}
|
|
|
|
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
|
|
PatternRewriter &rewriter) const override {
|
|
// Find a subgraph starting with this CopyToNonValueTensorOp, and
|
|
// terminating at CopyToValueTensorOp's, possibly with intervening view-like
|
|
// ops and overwrites. This also catches the special case of a
|
|
// CopyToNonValueTensorOp that trivially feeds into CopyToValueTensorOp's.
|
|
SmallVector<Operation *> nonValueTensorUsers;
|
|
auto workList = llvm::to_vector(copy.result().getUsers());
|
|
while (!workList.empty()) {
|
|
Operation *op = workList.pop_back_val();
|
|
if (op->getBlock() != copy->getBlock()) {
|
|
return rewriter.notifyMatchFailure(
|
|
copy, "can only analyze within a single basic block");
|
|
}
|
|
nonValueTensorUsers.push_back(op);
|
|
|
|
if (isViewLikeOp(op)) {
|
|
auto isTensor = [](const Value operand) {
|
|
return operand.getType().isa<BaseTensorType>();
|
|
};
|
|
|
|
// We currently only support view-like ops with one tensor input and one
|
|
// tensor output, meaning that the tensor use-def chains form a tree.
|
|
// This will not be the case for an op like `torch.aten.view_as`, so
|
|
// we will need to add a set to prune duplicate visitation.
|
|
if (llvm::count_if(op->getOperands(), isTensor) != 1 ||
|
|
llvm::count_if(op->getResults(), isTensor) != 1 ||
|
|
!isTensor(op->getOperand(0)) || !isTensor(op->getResult(0))) {
|
|
return rewriter.notifyMatchFailure(
|
|
copy, "unsupported: view-like ops must have one tensor input and "
|
|
"one tensor output, and the tensor input/output must be "
|
|
"the first operand/result");
|
|
}
|
|
|
|
llvm::append_range(workList, op->getResult(0).getUsers());
|
|
}
|
|
}
|
|
|
|
FailureOr<InterpretedOps> interpretedOps = abstractlyInterpretSlice(
|
|
copy, std::move(nonValueTensorUsers), rewriter);
|
|
if (failed(LogicalResult(interpretedOps)))
|
|
return failure();
|
|
rewriteSlice(*interpretedOps, rewriter);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
// Calculate a forward slice starting from a CopyToNonValueTensorOp
|
|
// and ending at CopyToValueTensorOp's. If all intervening ops
|
|
// are just view-like operations (i.e. no mutation), then we can trivially
|
|
// convert them all to value semantics.
|
|
// This pattern handles the case where views span multiple basic blocks,
|
|
// which is currently not supported by
|
|
// `AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock`.
|
|
class RewriteViewLikeSubgraph
|
|
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
|
|
PatternRewriter &rewriter) const override {
|
|
// Find a subgraph starting with this CopyToNonValueTensorOp, and
|
|
// terminating at CopyToValueTensorOp's, possibly with intervening view-like
|
|
// ops.
|
|
// This also catches the special case of a CopyToNonValueTensorOp that
|
|
// trivially feeds into CopyToValueTensorOp's.
|
|
SmallVector<Operation *> viewLikeOps;
|
|
SmallVector<CopyToValueTensorOp> copyToValueTensorOps;
|
|
auto workList = llvm::to_vector<6>(copy.getResult().getUsers());
|
|
// We currently only support view-like ops with one tensor input and one
|
|
// tensor output, meaning that the tensor use-def chains form a tree.
|
|
// This will not be the case for an op like `torch.aten.view_as`, so
|
|
// we will need to add a set to prune duplicate visitation.
|
|
while (!workList.empty()) {
|
|
Operation *op = workList.pop_back_val();
|
|
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
|
copyToValueTensorOps.push_back(copyToValueTensor);
|
|
} else if (isViewLikeOp(op)) {
|
|
viewLikeOps.push_back(op);
|
|
llvm::append_range(workList, op->getResult(0).getUsers());
|
|
} else {
|
|
return rewriter.notifyMatchFailure(
|
|
copy, "can only handle these transitive user ops");
|
|
}
|
|
}
|
|
|
|
copy.replaceAllUsesWith(copy.getOperand());
|
|
for (CopyToValueTensorOp op : copyToValueTensorOps)
|
|
rewriter.replaceOp(op, op.getOperand());
|
|
for (Operation *op : viewLikeOps) {
|
|
rewriter.updateRootInPlace(op, [&]() {
|
|
if (auto nonValueTensorType =
|
|
op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
|
|
op->getResult(0).setType(nonValueTensorType.getWithValueSemantics());
|
|
}
|
|
});
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
class MaximizeValueSemanticsPass
|
|
: public MaximizeValueSemanticsBase<MaximizeValueSemanticsPass> {
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &getContext();
|
|
auto func = getOperation();
|
|
|
|
RewritePatternSet patterns(context);
|
|
patterns.insert<AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock,
|
|
RewriteViewLikeSubgraph>(context);
|
|
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
// Factory for the pass; the declaration lives in
// torch-mlir/Dialect/Torch/Transforms/Passes.h.
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::Torch::createMaximizeValueSemanticsPass() {
  return std::make_unique<MaximizeValueSemanticsPass>();
}