mirror of https://github.com/llvm/torch-mlir
Iteratively run the main simplification pipeline.
This introduces a new pass, LowerToBackendContract (better name very welcome), which performs the bulk of the simplifications that we do, such as:
- shape refinement
- dtype refinement
- maximizing value semantics
- inlining global slots
- decomposing complex ops

The key difference from before is that it iterates the set of transformations, which helps to break a number of "catch-22" issues where one simplification depends on another; the latest example is https://github.com/llvm/torch-mlir/issues/1131

This also exposed that RefineTypes was sometimes crashing/asserting for certain inputs. This commit hardens it a bit.
parent 9c8b962720
commit 57681f7947
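For orientation, the new pass is exposed as a standard pass flag; an illustrative invocation (option names come from the Passes.td changes below; exact quoting may vary by shell):

    torch-mlir-opt --torch-lower-to-backend-contract="max-iterations=10 decompose=true" input.mlir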
@@ -28,14 +28,20 @@ createPrepareForGlobalizeObjectGraphPass();
 struct TorchLoweringPipelineOptions
     : public PassPipelineOptions<TorchLoweringPipelineOptions> {
   // If this option is true, then perform optimizations.
   // If this option is false, only do the bare minimum for correctness.
   Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
                         llvm::cl::init(true)};
 
+  // The maximum number of invocations of the simplification pipeline in
+  // LowerToBackendContract.
+  Option<int> maxIterations{
+      *this, "max-iterations",
+      llvm::cl::desc(
+          "Maximum number of invocations of the simplification pipeline."),
+      llvm::cl::init(10)};
+
-  // If this option is false, decompose complex operations.
-  // If this option is true, skip decomposition of complex operations.
-  Option<bool> decompose{*this, "decompose-complex-ops", llvm::cl::desc("Decompose complex operations."),
+  // TODO: This should be replaced with a list of operations to decompose.
+  // (or some other way to specify the set of allowed ops in the backend
+  // contract)
+  Option<bool> decompose{*this, "decompose-complex-ops",
+                         llvm::cl::desc("Decompose complex operations."),
                          llvm::cl::init(true)};
 };
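A minimal sketch (not part of this diff) of how a client might set these options when building the pipeline from C++; `context` and `module` are assumed to already exist:

    PassManager pm(&context);
    TorchLoweringPipelineOptions options;
    options.maxIterations = 5; // bound the number of simplification runs
    options.decompose = true;  // decompose complex ops
    Torch::createTorchFunctionToTorchBackendPipeline(pm, options);
    LogicalResult result = pm.run(module); // fails if the contract is not met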
@@ -50,10 +56,16 @@ void createTorchScriptModuleToTorchBackendPipeline(
 void createTorchFunctionToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);
 
-/// Creates a pipeline that refines shapes of tensor operations in the program.
-void createTorchShapeRefinementPipeline(
+/// Creates a pipeline that simplifies the computations in the program.
+/// This pass does not do any global program restructuring -- it works entirely
+/// within a single semantic model of a `builtin.module` with
+/// `torch.global_slot` ops and `func.func` ops.
+void createTorchSimplificationPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);
+
+/// Creates a pipeline that refines shapes of tensor operations in the program.
+void createTorchShapeRefinementPipeline(OpPassManager &pm);
 
 std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();

@@ -78,10 +90,10 @@ createSimplifyShapeCalculationsPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
 
 std::unique_ptr<OperationPass<ModuleOp>>
-createVerifyConversionToValueSemanticsPass();
+createEraseModuleInitializerPass();
 
 std::unique_ptr<OperationPass<ModuleOp>>
-createEraseModuleInitializerPass();
+createLowerToBackendContractPass(int maxIterations, bool decompose);
 
 StringRef getShapeLibrary();
@@ -253,18 +253,6 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"> {
   }];
 }
 
-def VerifyConversionToValueSemantics
-    : Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> {
-  let summary = "Verify that all tensors have been converted to value semantics";
-  let constructor =
-      "mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()";
-  let description = [{
-    Prior passes in the pipeline may have missed converting all tensors to value
-    semantics and we wish to catch such failures early instead of fixing
-    individual cases downstream.
-  }];
-}
-
 def EraseModuleInitializer
     : Pass<"torch-erase-module-initializer", "ModuleOp"> {
   let summary = "Erase the `torch.global_slot.module_initializer` op.";

@@ -273,9 +261,64 @@ def EraseModuleInitializer
   let description = [{
     Backends cannot currently handle module initializers, so we omit them from
     our backend contract. This pass removes the
-    `torch.global_slot.module_initializer` op from the module if legal, or
-    raises an error.
+    `torch.global_slot.module_initializer` op from the module if legal.
   }];
 }
 
+def LowerToBackendContract
+    : Pass<"torch-lower-to-backend-contract", "ModuleOp"> {
+  let summary = "Perform simplifications until the backend contract is satisfied.";
+  let constructor = [{
+    mlir::torch::Torch::createLowerToBackendContractPass(
+      /*maxIterations=*/10, /*decompose=*/true)
+  }];
+  let description = [{
+    This pass performs the bulk of the lowering of the program's computations
+    to the backend contract. This pass does not do any global program
+    restructuring -- it works entirely within a single semantic model
+    of a `builtin.module` with `torch.global_slot` ops and `func.func` ops.
+
+    This pass runs a set of simplifications within that semantic model until
+    the backend contract is satisfied, and fails if it cannot be satisfied.
+    In particular, the backend contract consists of:
+    - Tensors
+      - Have been converted to value semantics.
+      - Have at least a known rank, though ideally a maximally inferred shape.
+      - Have a known dtype.
+    - `torch.global_slot`'s have been eliminated from the program.
+    - Ops have been decomposed.
+
+    This particular choice of backend contract was born out of a common set of
+    requirements from backends, along with aligning with the long-term PyTorch
+    direction of being more tracing-based. The set of simplifications performed
+    here can be thought of as simulating the kinds of simplifications that
+    happen naturally as part of tracing, but in a way that is applicable
+    to our TorchScript frontend. For the LazyTensorCore frontend, the backend
+    contract trivially holds (except for certain decompositions).
+
+    Generally it is not desirable to have a compiler where successful
+    compilation depends on "optimizing hard enough", but in this case, there
+    seems to be enough alignment and recognition in the industry that the
+    Python-based programming model in the source program is too dynamic
+    to feasibly handle in totality without a tracing approach that has access
+    to the source program to re-trace in the face of dynamism (e.g. the ability
+    to do what TorchDynamo calls a "graph break"). We are attempting to maintain
+    a practical compiler that works well given the current set of constraints
+    of the TorchScript frontend that PyTorch provides us, and are working to
+    co-design PyTorch's direction so that we land in a place where most of this
+    "optimizing hard enough" is not necessary.
+  }];
+  let options = [
+    Option<"maxIterations", "max-iterations", "int", /*default=*/"10",
+           "Maximum number of invocations of the simplification pipeline.">,
+    // TODO: Make this a configurable set of ops.
+    Option<"decompose", "decompose", "bool", /*default=*/"true",
+           "Decompose ops.">
+  ];
+  // TODO: Debug why this is needed, even though the input program has func.func
+  // ops in it.
+  let dependentDialects = ["func::FuncDialect"];
+}
+
 #endif // TORCHMLIR_TORCH_PASSES
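To make the contract concrete, here is a small hypothetical example of IR that already satisfies it: value-semantic `!torch.vtensor` types of known rank and dtype, no `torch.global_slot` ops, and no complex ops left to decompose (the function name and shapes are illustrative, not taken from this diff):

    func.func @forward(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
      %int1 = torch.constant.int 1
      %0 = torch.aten.add.Tensor %arg0, %arg0, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      return %0 : !torch.vtensor<[2,3],f32>
    }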
@@ -6,6 +6,7 @@ add_mlir_library(TorchMLIRTorchPasses
   Passes.cpp
   GlobalizeObjectGraph.cpp
   InlineGlobalSlots.cpp
+  LowerToBackendContract.cpp
   MaximizeValueSemantics.cpp
   PrepareForGlobalizeObjectGraph.cpp
   ReduceOpVariants.cpp

@@ -14,7 +15,6 @@ add_mlir_library(TorchMLIRTorchPasses
   ReifyShapeCalculations.cpp
   ShapeLibrary.cpp
   SimplifyShapeCalculations.cpp
-  VerifyConversionToValueSemantics.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms
@@ -27,18 +27,15 @@ namespace {
 class EraseModuleInitializerPass
     : public EraseModuleInitializerBase<EraseModuleInitializerPass> {
   void runOnOperation() override {
-    auto walkResult = getOperation().walk([](GlobalSlotModuleInitializerOp op) {
+    for (auto initializer :
+         getOperation().getOps<GlobalSlotModuleInitializerOp>()) {
       auto initialize =
-          cast<InitializeGlobalSlotsOp>(op.getBody()->getTerminator());
-      if (initialize.getNumOperands() != 0) {
-        op.emitError("could not erase non-empty module initializer");
-        return WalkResult::interrupt();
+          cast<InitializeGlobalSlotsOp>(initializer.getBody()->getTerminator());
+      if (initialize.getNumOperands() == 0) {
+        initializer.erase();
       }
-      op.erase();
-      return WalkResult::advance();
-    });
-    if (walkResult.wasInterrupted()) {
-      return signalPassFailure();
+      // The verifier ensures there is only one GlobalSlotModuleInitializerOp.
+      break;
     }
   }
 };
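For illustration (mirroring the updated test below), the rewritten pass now erases an initializer only when it initializes nothing, and otherwise leaves it in place for LowerToBackendContract to diagnose:

    // Erased by torch-erase-module-initializer:
    torch.global_slot.module_initializer {
      torch.initialize.global_slots [
      ]
    }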
@@ -0,0 +1,247 @@
//===- LowerToBackendContract.cpp --------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "torch-lower-to-backend-contract"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

//===----------------------------------------------------------------------===//
// Checking the backend contract.
//===----------------------------------------------------------------------===//

static LogicalResult checkType(Operation *op, Type type,
                               bool actuallyEmitDiagnostics) {
  // Allow various scalar types that backends are expected to be able to
  // handle.
  if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType>())
    return success();

  // Backends are not expected to support dynamic computations on these types,
  // but they frequently appear as parameters to ops which backends
  // can statically pattern match and eliminate from the program.
  // For example, a tensor operand might be optional, and the backend
  // will pattern-match statically whether it is passed as a tensor or None.
  if (type.isa<Torch::NoneType, Torch::StringType>())
    return success();

  // We blanket prohibit non-value-semantic tensors.
  // All of our backends are currently based on value-semantic tensors, so
  // we consider it our responsibility to lower all non-value-semantic tensors
  // to value-semantic tensors.
  if (type.isa<NonValueTensorType>()) {
    if (actuallyEmitDiagnostics) {
      return op
          ->emitError("unsupported by backend contract: non-value tensor type")
          .attachNote()
          .append("this is likely due to a missing case in the "
                  "MaximizeValueSemantics pass");
    } else {
      return failure();
    }
  }

  // For value-semantic tensors, we require at least a known rank and dtype.
  // We are not aware of a situation where our backends can handle an unranked
  // tensor type or a tensor with a dynamic dtype.
  //
  // There are somewhat fundamental reasons for this. In particular, the
  // problem of unranked codegen is completely different from the problem of
  // ranked codegen (since ranked corresponds to a fixed loop nest structure).
  // For all codegen systems we are aware of, the program must be reduced to
  // operate on ranked tensors at some point in compilation, and we are not
  // aware of any backend with a general solution to this problem before it
  // reaches codegen. So we consider it our responsibility to eliminate
  // unranked tensors from the program.
  //
  // We aren't aware of any backend with any infrastructure to represent
  // dynamic dtypes, let alone transform and optimize them. Additionally, it is
  // unlikely that any backend, even if it supports dynamic dtypes in some
  // form, will have a sufficiently rich system for representing PyTorch type
  // promotion rules. So we consider it our responsibility to ensure that all
  // dtypes are statically known.
  if (auto tensorType = type.dyn_cast<ValueTensorType>()) {
    if (!tensorType.hasSizes()) {
      if (actuallyEmitDiagnostics) {
        return op
            ->emitError(
                "unsupported by backend contract: tensor with unknown rank")
            .attachNote()
            .append("this is likely due to a missing shape transfer function "
                    "in shape_lib_gen.py");
      } else {
        return failure();
      }
    }
    if (!tensorType.hasDtype()) {
      if (actuallyEmitDiagnostics) {
        return op
            ->emitError(
                "unsupported by backend contract: tensor with unknown dtype")
            .attachNote()
            .append("this is likely due to a missing case in RefineTypes");
      } else {
        return failure();
      }
    }
    return success();
  }

  // Optional types are also in the category of types which we don't expect
  // backends to dynamically compute with, but they can be pattern matched
  // in many cases that are practically necessary.
  if (auto optionalType = type.dyn_cast<OptionalType>()) {
    // TODO: Be stricter about tensor types.
    // See comment below for ListType.
    if (optionalType.getContainedType().isa<ValueTensorType>())
      return success();
    return checkType(op, optionalType.getContainedType(),
                     actuallyEmitDiagnostics);
  }
  // List types are also in the category of types which we don't expect
  // backends to dynamically compute with, but they can be pattern matched
  // in many cases that are practically necessary. For example, the
  // strides of a convolution op are represented as a list.
  if (auto listType = type.dyn_cast<ListType>()) {
    // TODO: Be stricter about tensor types.
    // For the moment, there are cases (such as for torch.cat) where we end
    // up with `!torch.list<vtensor>` which doesn't have shape or dtype in
    // the contained type information. Somehow this slips through and works.
    // We should be stricter about this and properly infer the contained type
    // and shape.
    if (listType.getContainedType().isa<ValueTensorType>())
      return success();
    return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics);
  }
  // Tuple types are also in the category of types which we don't expect
  // backends to dynamically compute with, but they can be pattern matched
  // in many cases that are practically necessary.
  if (auto tupleType = type.dyn_cast<Torch::TupleType>()) {
    for (auto containedType : tupleType.getContainedTypes()) {
      if (failed(checkType(op, containedType, actuallyEmitDiagnostics)))
        return failure();
    }
    return success();
  }

  // Unsupported type.
  if (actuallyEmitDiagnostics) {
    return op->emitError("unsupported by backend contract: type ") << type;
  } else {
    return failure();
  }
}

static bool satisfiesBackendContract(ModuleOp module,
                                     bool actuallyEmitDiagnostics = false) {
  // We do not permit `torch.global_slot`'s in the backend contract, since
  // support for them is not widespread, and this does not align with PyTorch's
  // more tracing-based direction.
  //
  // We just check for the GlobalSlotModuleInitializerOp since its verifier
  // ensures that the set of global slots matches those initialized by the
  // module initializer.
  auto walkResult0 = module.walk([&](Torch::GlobalSlotModuleInitializerOp op) {
    if (actuallyEmitDiagnostics) {
      // Report the error on the terminator to avoid dumping the whole
      // initializer itself, which can have pages of ops in it.
      op.getBody()
          ->getTerminator()
          ->emitError("unsupported by backend contract: module initializers")
          .attachNote()
          .append("this is likely due to InlineGlobalSlots being unable to "
                  "inline a global slot");
    }
    return WalkResult::interrupt();
  });
  if (walkResult0.wasInterrupted())
    return false;

  // Check the types of all Values in the program.
  //
  // A pre-order walk gives a more intuitive "first error".
  // TODO: Should we report more than the first error?
  // How do we avoid making it too spammy?
  auto walkResult1 = module.walk<WalkOrder::PreOrder>([&](Block *block) {
    for (BlockArgument arg : block->getArguments())
      if (failed(checkType(block->getParentOp(), arg.getType(),
                           actuallyEmitDiagnostics))) {
        return WalkResult::interrupt();
      }
    for (Operation &op : *block)
      for (OpResult result : op.getResults())
        if (failed(checkType(&op, result.getType(), actuallyEmitDiagnostics)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });
  if (walkResult1.wasInterrupted())
    return false;
  return true;
}

namespace {
class LowerToBackendContractPass
    : public LowerToBackendContractBase<LowerToBackendContractPass> {
public:
  LowerToBackendContractPass() = default;
  LowerToBackendContractPass(int maxIterations, bool decompose) {
    this->maxIterations = maxIterations;
    this->decompose = decompose;
  }
  void runOnOperation() override {
    ModuleOp module = getOperation();

    OpPassManager pm(module.getOperationName());
    TorchLoweringPipelineOptions options;
    options.decompose = decompose;
    createTorchSimplificationPipeline(pm, options);

    int i = 0;
    do {
      if (i++ == maxIterations) {
        LLVM_DEBUG({
          llvm::dbgs() << "LowerToBackendContractPass: "
                       << "failed to satisfy backend contract after "
                       << maxIterations
                       << " iterations of the simplification pipeline\n";
        });
        // Show the diagnostics.
        (void)satisfiesBackendContract(module,
                                       /*actuallyEmitDiagnostics=*/true);
        return signalPassFailure();
      }

      if (failed(runPipeline(pm, module)))
        return signalPassFailure();
    } while (!satisfiesBackendContract(module));
    LLVM_DEBUG({
      llvm::dbgs() << "LowerToBackendContractPass: "
                   << "succeeded after " << i
                   << " iterations of the simplification pipeline\n";
    });
  }
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createLowerToBackendContractPass(int maxIterations,
                                                     bool decompose) {
  return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose);
}
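A minimal usage sketch (assumed client code, not part of this diff) for running just this pass over an already-imported module:

    PassManager pm(&context);
    pm.addPass(Torch::createLowerToBackendContractPass(
        /*maxIterations=*/10, /*decompose=*/true));
    if (failed(pm.run(module)))
      llvm::errs() << "failed to reach the backend contract\n";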
@@ -31,6 +31,10 @@ void mlir::torch::registerTorchPasses() {
       "Pipeline lowering a Torch function to Torch backend form.",
       mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
+  mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
+      "torch-simplification-pipeline",
+      "Pipeline simplifying computations in the program.",
+      mlir::torch::Torch::createTorchSimplificationPipeline);
   mlir::PassPipelineRegistration<>(
       "torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.",
       mlir::torch::Torch::createTorchShapeRefinementPipeline);
 }
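The pipelines registered above are primarily debugging aids; illustrative invocations (flag names taken from the registration strings):

    torch-mlir-opt --torch-simplification-pipeline input.mlir
    torch-mlir-opt --torch-shape-refinement-pipeline input.mlir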
@@ -66,131 +70,82 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
 
 void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
-  // General considerations: As a matter of bring-up, we are simultaneously
-  // building out the frontend pipeline and also co-developing the backend
-  // support story as well. This means that sometimes the most expedient way to
-  // support a given program is to "optimize hard enough" that the parts of the
-  // program that touch unimplemented backend support go away (constant folded,
-  // dead-code-eliminated, etc.). In the fullness of time, most of that
-  // optimization should not be necessary, and we should have an "O0" pipeline
-  // that runs practically no optimizations.
-  // However, as a matter of expediency, at the moment we do run those
-  // optimizations. We guard those passes under the `options.optimize` option
-  // (which default to true, currently). We leave notes with the `OPT-ONLY` tag
-  // why we currently need that pass for correctness.
-  // We should eventually remove those passes from the default pipeline once
-  // backends have enough support.
-  // In particular the following features are needed in some form from backends:
-  // - Error handling (RaiseException + error string formatting)
-  // - First-class list type
-  // - torch.global_slot lowering
-  // - ...
-  // Please try to keep this list somewhat up to date when adding
-  // "optimize hard enough that it works" transformations.
-
   // Incorporate user annotations and remove signature Python-isms.
   pm.addPass(createAdjustCallingConventionsPass());
+  // Perform the bulk of lowering to the backend contract.
+  // See the pass documentation for more information.
+  pm.addPass(createLowerToBackendContractPass(options.maxIterations,
+                                              options.decompose));
+}
 
-  // TODO: Remove options.optimize and this OPT-ONLY stuff -- we are already way
-  // past the point of no return for it being necessary for functional
-  // correctness.
-  if (options.optimize) {
-    // Eliminate the PrimTupleIndexOp generated from the
-    // adjustCallingConventions
+// A simplification pipeline to establish the invariants of the backend
+// contract (see `satisfiesBackendContract` in `LowerToBackendContract`).
+//
+// We structure this so that a single run of this pipeline is enough for
+// most models, but it is possible for it to take multiple runs to fully
+// clean things up when there are cyclic dependencies between certain
+// simplifications, such as a decomposition relying on shape refinement which
+// depends on another decomposition.
+//
+// Although technically this pipeline is an implementation detail of
+// LowerToBackendContract, we expose it here to help debugging.
+//
+// LowerToBackendContract will run this pipeline as many times as necessary, but
+// in general, it is costly to re-run this pipeline, since all the passes do
+// O(module size) work. We want the number of iterations of this pipeline
+// to be bounded by meaningful "always in practice small" program properties,
+// such as loop nesting depth, number of sequentially dependent steps of
+// constant global slots proving that other global slots are dead, etc.
+//
+// It is generally always possible to construct a pathological input that will
+// exceed the number of iterations. If we do find practical cases with an
+// O(module size) number of iterations of this simplification pipeline, then
+// we may need to adjust the approach, such as doing some of the
+// transformations together at finer granularity.
+void mlir::torch::Torch::createTorchSimplificationPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
+  // General cleanup.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-  // Inline global slots, which for most inference scenarios deletes them.
-  // This also exposes more information to intraprocedural transformations
-  // below like MaximizeValueSemantics and RefineTypes.
-  // OPT-ONLY: Don't rely on this pass to "lower" global slots by deleting.
-  // Also don't rely on this pass to expose constants into the program to
-  // simplify handling of "optional".
+  // Inline global slots to expose a bunch of simplification opportunities
+  // from constant hyperparameters, weights, etc.
   pm.addPass(createInlineGlobalSlotsPass());
-  // After doing a first round of inlining global slots, canonicalize again to
-  // take advantage of optimization opportunities exposed by the inlined
-  // global slots. In particular, this is functionally necessary now because
-  // large amounts of control flow are guarded by an "is training" flag, so
-  // inlining removes certain mutating operations done on the slots enabling
-  // them to be deleted.
-  // TODO: In full generality, we need to do a fixed-point iteration of
-  // shape inference, maximizing value semantics, decomposition, inlining global
-  // slots, and canonicalization.
+  // Erase the module initializer if we have proven that all the global slots
+  // are gone.
+  pm.addPass(createEraseModuleInitializerPass());
+  // Clean up again to avoid needing to go back around the fixed-point
+  // iteration.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-  // Inline again, cleaning up any remaining global slots that might be dead
-  // now.
-  pm.addPass(createInlineGlobalSlotsPass());
-  // Erase the module initializers (or fail compilation), since they aren't
-  // permitted in our backend contract at the moment.
-  pm.addPass(Torch::createEraseModuleInitializerPass());
-  }
-
   // Reduce variants of ops to a smaller set of primitives.
   pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
-
-  if (options.optimize) {
-    // OPT-ONLY: Right now we rely on this to eliminate certain branches that
-    // guard unreachable code that backends can't handle yet, such as lists,
-    // RaiseException, unimplemented tensor ops, and only-used-in-training
-    // operations on `torch.global_slot`'s.
-    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-    // OPT-ONLY: We may have deleted some `torch.global_slot.get` /
-    // `torch.global_slot.set` ops, which may have left more
-    // `torch.global_slot`'s unused.
+  // Remove dead global slots.
   pm.addPass(createSymbolDCEPass());
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Lowering to ranked !torch.vtensors of known dtype.
-  //===--------------------------------------------------------------------===//
-
   // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
   pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
-  // Update the return op to return value tensors and remove dead ops.
+  // Update the return op to return value tensors.
   pm.addPass(Torch::createRefinePublicReturnPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-  // Ensure that all tensors have been converted to value semantics.
-  pm.addPass(Torch::createVerifyConversionToValueSemanticsPass());
   // Do shape refinement.
-  // This must be run before RefineTypes (which primarily does dtype inference),
-  // because Torch type promotion rules actually depend on the shape of the
-  // operand.
-  createTorchShapeRefinementPipeline(pm, options);
+  // This should be run before RefineTypes (which primarily does dtype
+  // inference), because Torch type promotion rules actually depend on the
+  // shape of the operand.
+  createTorchShapeRefinementPipeline(pm);
   // Refine types in the program, which mainly means inferring dtypes of ops.
   pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
   // Propagate to ABI return types the shape/dtype information discovered by
   // the previous pass. Doing this is ABI-compatible for our backends.
   pm.addPass(Torch::createRefinePublicReturnPass());
-  if (options.optimize) {
   // This can fold away some branches given the information gained from
   // RefineTypes before doing maximize value semantics, which only works with
   // basic blocks.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-  }
-
-  if (options.optimize) {
-    // All the type refinement we've done above has exposed new information
-    // that allows folding away more stuff.
-    // OPT-ONLY: Right now we rely on this to eliminate certain
-    // branches that guard unreachable code that backends can't handle yet, such
-    // as lists, RaiseException, unimplemented aten ops, and
-    // only-used-in-training operations on `torch.global_slot`'s.
-    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-  }
-
   if (options.decompose) {
     pm.addNestedPass<func::FuncOp>(Torch::createDecomposeComplexOpsPass());
     pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   }
-
-  // TODO: VerifyTorchBackendContractPass.
 }
 
-void mlir::torch::Torch::createTorchShapeRefinementPipeline(
-    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
+void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
   // Reify the shape functions for each op that is present in the shape library.
   pm.addPass(Torch::createReifyShapeCalculationsPass());
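Since the driver logs under the `torch-lower-to-backend-contract` debug type defined in LowerToBackendContract.cpp, one plausible way to observe how many pipeline iterations a given model needs (only in builds with debug output enabled):

    torch-mlir-opt --torch-lower-to-backend-contract \
        --debug-only=torch-lower-to-backend-contract input.mlir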
@@ -1152,7 +1152,7 @@ void TypeAnalysis::visitAtenEmbeddingBagOp(Operation *op) {
   resultIntKnowledge.dtype =
       IntegerType::get(op->getContext(), 64, IntegerType::Signed);
 
-  for (int64_t i = 1; i < 4; i++) {
+  for (int64_t i = 1, e = op->getNumResults(); i < e; i++) {
     incorporateKnowledge(op->getResult(i), resultIntKnowledge);
   }
   return;

@@ -1259,6 +1259,12 @@ void TypeAnalysis::visitAtenTensorOp(AtenTensorOp op) {
   while (auto listType = type.dyn_cast<ListType>()) {
     type = listType.getContainedType();
   }
+  // TODO: Support tensor as the contained type of the list.
+  // These are the only types handled by fillInDTypeGivenDTypeAndDataType
+  // below.
+  if (!type.isa<Torch::FloatType, Torch::IntType, Torch::BoolType>()) {
+    incorporateKnowledge(op.getResult(), knowledge);
+    return;
+  }
   fillInDTypeGivenDTypeAndDataType(knowledge, dtype, type);
   incorporateKnowledge(op.getResult(), knowledge);
 }

@@ -1418,13 +1424,13 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
   };
   if (auto tensorType = v.getType().dyn_cast<BaseTensorType>()) {
     const ValueState *latticeElement = solver.lookupState<ValueState>(v);
-    if (!latticeElement)
+    if (!latticeElement || latticeElement->isUninitialized())
       return nullptr;
     const ValueKnowledge &knowledge = latticeElement->getValue();
     return getRefinedTensorType(tensorType, knowledge);
   } else if (auto optionalType = v.getType().dyn_cast<OptionalType>()) {
     const ValueState *latticeElement = solver.lookupState<ValueState>(v);
-    if (!latticeElement)
+    if (!latticeElement || latticeElement->isUninitialized())
       return nullptr;
     const ValueKnowledge &knowledge = latticeElement->getValue();
     if (knowledge.optional == OptionalKnowledge::isNone)

@@ -1438,7 +1444,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
   } else if (auto scalarType = v.getType().dyn_cast<NumberType>()) {
     const ValueState *latticeElement = solver.lookupState<ValueState>(v);
-    if (!latticeElement)
+    if (!latticeElement || latticeElement->isUninitialized())
       return nullptr;
     const ValueKnowledge &knowledge = latticeElement->getValue();
     if (knowledge.kind == torch_upstream::TypeKind::IntType)
@@ -1,56 +0,0 @@
//===- VerifyConversionToValueSemantics.cpp ----------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/IR/BuiltinOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch::Torch;

static LogicalResult checkValueType(Operation *op, Value value) {
  auto isNotValueTensorType = value.getType().isa<NonValueTensorType>();
  return isNotValueTensorType
             ? op->emitError(
                   "found a non-value tensor type, this is likely due to a "
                   "missing case in the MaximizeValueSemantics pass")
             : success();
}

namespace {
class VerifyConversionToValueSemanticsPass
    : public VerifyConversionToValueSemanticsBase<
          VerifyConversionToValueSemanticsPass> {
  void runOnOperation() override {
    auto walkResult = getOperation().walk([&](Block *block) {
      for (BlockArgument arg : block->getArguments())
        if (failed(checkValueType(block->getParentOp(), arg)))
          return WalkResult::interrupt();

      for (Operation &op : *block)
        for (OpResult result : op.getResults())
          if (failed(checkValueType(&op, result)))
            return WalkResult::interrupt();

      return WalkResult::advance();
    });

    if (walkResult.wasInterrupted())
      signalPassFailure();
  }
};

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyConversionToValueSemanticsPass() {
  return std::make_unique<VerifyConversionToValueSemanticsPass>();
}
@@ -77,7 +77,6 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
   pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
   pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
 
-  if (options.optimize) {
   // Clean up any non-canonical code introduced above.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // Resolve `dim` ops on tensors (which currently live in the `memref`

@@ -86,7 +85,6 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
       memref::createResolveShapedTypeResultDimsPass());
   // The resolution of `dim` ops tends to create identical ops. CSE them.
   pm.addNestedPass<func::FuncOp>(createCSEPass());
-  }
 
   // Finish the type conversion from `torch` types to the types of the
   // linalg-on-tensors backend contract.

@@ -111,12 +109,10 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
   // Perform rank broadcasting so TosaToLinalg pass works
   pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
 
-  if (options.optimize) {
   // Clean up any non-canonical code introduced above.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // The resolution of `dim` ops tends to create identical ops. CSE them.
   pm.addNestedPass<func::FuncOp>(createCSEPass());
-  }
 
   // Finish the type conversion from `torch` types to the types of the
   // TOSA backend contract.

@@ -140,21 +136,17 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
 
   pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
 
-  if (options.optimize) {
   // Clean up any non-canonical code introduced above.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // The resolution of `dim` ops tends to create identical ops. CSE them.
   pm.addNestedPass<func::FuncOp>(createCSEPass());
-  }
 
   // Convert CHLO ops to MHLO ops
   pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
-  if (options.optimize) {
   // Clean up any non-canonical code introduced above.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // The resolution of `dim` ops tends to create identical ops. CSE them.
   pm.addNestedPass<func::FuncOp>(createCSEPass());
-  }
 
   // Finish the type conversion from `torch` types to the types of the
   // MHLO backend contract.
@@ -6,15 +6,3 @@ torch.global_slot.module_initializer {
   torch.initialize.global_slots [
   ]
 }
-
-// -----
-
-torch.global_slot @slot0 : !torch.int
-
-// expected-error@+1 {{could not erase non-empty module initializer}}
-torch.global_slot.module_initializer {
-  %0 = torch.constant.int 0
-  torch.initialize.global_slots [
-    @slot0(%0: !torch.int)
-  ]
-}
@@ -0,0 +1,61 @@
// RUN: torch-mlir-opt -torch-lower-to-backend-contract -split-input-file -verify-diagnostics %s

torch.global_slot.module_initializer {
  %0 = torch.constant.int 1
  // expected-error @+2 {{unsupported by backend contract: module initializers}}
  // expected-note @+1 {{this is likely due to}}
  torch.initialize.global_slots [
    @slot0(%0 : !torch.int)
  ]
}
torch.global_slot @slot0 : !torch.int

// -----

// expected-error @+2 {{unsupported by backend contract: non-value tensor type}}
// expected-note @+1 {{this is likely due to}}
func.func @f(%arg0: !torch.tensor) {
  return
}

// -----

// expected-error @+2 {{unsupported by backend contract: tensor with unknown rank}}
// expected-note @+1 {{this is likely due to}}
func.func @f(%arg0: !torch.vtensor<*,f32>) {
  return
}

// -----

// expected-error @+2 {{unsupported by backend contract: tensor with unknown dtype}}
// expected-note @+1 {{this is likely due to}}
func.func @f(%arg0: !torch.vtensor<[],unk>) {
  return
}

// -----

// expected-error @+1 {{unsupported by backend contract: type '!torch.any'}}
func.func @f(%arg0: !torch.any) {
  return
}

// -----

// Test case: checking of op results.
// TODO: In theory we could diagnose every single value, but for now we bail out on the first one.

func.func @f(%arg0: !torch.bool, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[7],f32>) -> !torch.vtensor<*,f32> {
  // expected-error @+2 {{unsupported by backend contract: tensor with unknown rank}}
  // expected-note @+1 {{this is likely due to}}
  %0 = torch.prim.If %arg0 -> (!torch.vtensor<*,f32>) {
    %1 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[],f32> to !torch.vtensor<*,f32>
    torch.prim.If.yield %1 : !torch.vtensor<*,f32>
  } else {
    %2 = torch.tensor_static_info_cast %arg2 : !torch.vtensor<[7],f32> to !torch.vtensor<*,f32>
    torch.prim.If.yield %2 : !torch.vtensor<*,f32>
  }
  return %0 : !torch.vtensor<*,f32>
}
@@ -240,3 +240,35 @@ func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
 
   return %result2 : !torch.vtensor<*,unk>
 }
+
+// -----
+
+// Check that we don't crash on this input.
+
+// CHECK-LABEL: func.func @forward
+func.func @forward() -> !torch.vtensor {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct : () -> !torch.list<tensor>
+  // CHECK: torch.aten.tensor
+  %1 = torch.aten.tensor %0, %none, %none, %false : !torch.list<tensor>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor
+  return %1 : !torch.vtensor
+}
+
+// -----
+
+// Check that we don't crash on this input.
+// TODO: This appears to result in aten.mul.Tensor not being visited.
+// We should investigate why that happens.
+
+// CHECK-LABEL: func.func @forward
+func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) {
+  %0 = torch.prim.If %arg0 -> (!torch.tensor) {
+    torch.prim.If.yield %arg1 : !torch.tensor
+  } else {
+    torch.prim.If.yield %arg1 : !torch.tensor
+  }
+  %1 = torch.copy.to_vtensor %0 : !torch.vtensor
+  %2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
+  return
+}
@@ -1,9 +0,0 @@
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-verify-conversion-to-value-semantics

// -----

func.func @result_is_non_value_tensor(%arg: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
  // @expected-error@+1 {{found a non-value tensor type, this is likely due to a missing case in the MaximizeValueSemantics pass}}
  %neg = torch.aten.neg %arg : !torch.vtensor<[2],f32> -> !torch.tensor
  return %arg : !torch.vtensor<[2],f32>
}
@@ -182,7 +182,6 @@ cc_library(
         "lib/Dialect/Torch/Transforms/ShapeLibrary.cpp",
         "lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp",
        "lib/Dialect/Torch/Transforms/PassDetail.h",
-        "lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp",
     ],
     hdrs = [
         "include/torch-mlir/Dialect/Torch/Transforms/Passes.h",