Iteratively run the main simplification pipeline.

This introduces a new pass LowerToBackendContract (better name very
welcome) which performs the bulk of the simplifications that we do,
such as
- shape refinement
- dtype refinement
- maximizing value semantics
- inlining global slots
- decomposing complex ops

The key difference from before is that it iterates the set of
transformations, which can help to break a number of "catch-22" issues
where one simplification depends on another, the latest example being
here:
https://github.com/llvm/torch-mlir/issues/1131

This also exposed that RefineTypes was sometimes crashing/asserting for
certain inputs. This commit hardens it a bit.
pull/1241/head
Sean Silva 2022-08-04 18:39:21 +00:00
parent 9c8b962720
commit 57681f7947
14 changed files with 518 additions and 251 deletions

View File

@ -28,14 +28,20 @@ createPrepareForGlobalizeObjectGraphPass();
struct TorchLoweringPipelineOptions
: public PassPipelineOptions<TorchLoweringPipelineOptions> {
// If this option is true, then perform optimizations.
// If this option is false, only do the bare minimum for correctness.
Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
llvm::cl::init(true)};
// The maximum number of invocations of the simplification pipeline in
// LowerToBackendContract.
Option<int> maxIterations{
*this, "max-iterations",
llvm::cl::desc(
"Maximum number of invocations of the simplification pipeline."),
llvm::cl::init(10)};
// If this option is false, decompose complex operations.
// If this option is true, skip decomposition of complex operations.
Option<bool> decompose{*this, "decompose-complex-ops", llvm::cl::desc("Decompose complex operations."),
// TODO: This should be replaced with a list of operations to decompose.
// (or some other way to specify the set of allowed ops in the backend
// contract)
Option<bool> decompose{*this, "decompose-complex-ops",
llvm::cl::desc("Decompose complex operations."),
llvm::cl::init(true)};
};
@ -50,10 +56,16 @@ void createTorchScriptModuleToTorchBackendPipeline(
void createTorchFunctionToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
/// Creates a pipeline that refines shapes of tensor operations in the program.
void createTorchShapeRefinementPipeline(
/// Creates a pipeline that simplifies the computations in the program.
/// This pass does not do any global program restructuring -- it works entirely
/// within a single semantic model of a `builtin.module` with
/// `torch.global_slot` ops and `func.func` ops.
void createTorchSimplificationPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
/// Creates a pipeline that refines shapes of tensor operations in the program.
void createTorchShapeRefinementPipeline(OpPassManager &pm);
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
@ -78,10 +90,10 @@ createSimplifyShapeCalculationsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyConversionToValueSemanticsPass();
createEraseModuleInitializerPass();
std::unique_ptr<OperationPass<ModuleOp>>
createEraseModuleInitializerPass();
createLowerToBackendContractPass(int maxIterations, bool decompose);
StringRef getShapeLibrary();

View File

@ -253,18 +253,6 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"
}];
}
def VerifyConversionToValueSemantics
: Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> {
let summary = "Verify that all tensors have been converted to value semantics";
let constructor =
"mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()";
let description = [{
Prior passes in the pipeline may have missed converting all tensors to value
semantics and we wish to catch such failures early instead of fixing
individual cases downstream.
}];
}
def EraseModuleInitializer
: Pass<"torch-erase-module-initializer", "ModuleOp"> {
let summary = "Erase the `torch.global_slot.module_initializer` op.";
@ -273,9 +261,64 @@ def EraseModuleInitializer
let description = [{
Backends cannot currently handle module initializers, so we omit them from
our backend contract. This pass removes the
`torch.global_slot.module_initializer` op from the module if legal, or
raises an error.
`torch.global_slot.module_initializer` op from the module if legal.
}];
}
def LowerToBackendContract
: Pass<"torch-lower-to-backend-contract", "ModuleOp"> {
let summary = "Perform simplifications until the backend contract is satisfied.";
let constructor = [{
mlir::torch::Torch::createLowerToBackendContractPass(
/*maxIterations=*/10, /*decompose=*/true)
}];
let description = [{
This pass performs the bulk of the lowering of the program's computations
to the backend contract. This pass does not do any global program
restructuring -- it works entirely within a single semantic model
of a `builtin.module` with `torch.global_slot` ops and `func.func` ops.
This pass runs a set of simplifications within that semantic model until
the backend contract is satisfied, and fails if it cannot be satisfied.
In particular, the backend contract consists of:
- Tensors
- Have been converted to value semantics.
- Have at least a known rank, though ideally a maximally inferred shape.
- Have a known dtype.
- `torch.global_slot`'s have been eliminated from the program.
- Ops have been decomposed.
This particular choice of backend contract was born out of a common set of
requirements from backends, along with aligning with long-term PyTorch
direction of being more tracing-based. The set of simplifications performed
here can be thought of as simulating the kinds of simplifications that
happen naturally as part of tracing, but in a way that is applicable
to our TorchScript frontend. For the LazyTensorCore frontend, the backend
contract trivially holds (except for certain decompositions).
Generally it is not desirable to have a compiler where successful
compilation depends on "optimizing hard enough", but in this case, there
seems to be enough alignment and recognition in the industry that the
Python-based programming model in the source program is too dynamic
to feasibly handle in totality without a tracing approach that has access
to the source program to re-trace in the face of dynamism (e.g. the ability
to do what TorchDynamo calls "graph break"). We are attempting to maintain
a practical compiler that works well given the current set of constraints
of the TorchScript frontend that PyTorch provides us, and are working to
co-design PyTorch's direction so that we land in a place where most of this
"optimizing hard enough" is not necessary.
}];
let options = [
Option<"maxIterations", "max-iterations", "int", /*default=*/"10",
"Maximum number of invocations of the simplification pipeline.">,
// TODO: Make this a configurable set of ops.
Option<"decompose", "decompose", "bool", /*default=*/"true",
"Decompose ops.">
];
// TODO: Debug why this is needed, even though the input program has func.func
// ops in it.
let dependentDialects = ["func::FuncDialect"];
}
#endif // TORCHMLIR_TORCH_PASSES

View File

@ -6,6 +6,7 @@ add_mlir_library(TorchMLIRTorchPasses
Passes.cpp
GlobalizeObjectGraph.cpp
InlineGlobalSlots.cpp
LowerToBackendContract.cpp
MaximizeValueSemantics.cpp
PrepareForGlobalizeObjectGraph.cpp
ReduceOpVariants.cpp
@ -14,7 +15,6 @@ add_mlir_library(TorchMLIRTorchPasses
ReifyShapeCalculations.cpp
ShapeLibrary.cpp
SimplifyShapeCalculations.cpp
VerifyConversionToValueSemantics.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms

View File

@ -27,18 +27,15 @@ namespace {
class EraseModuleInitializerPass
: public EraseModuleInitializerBase<EraseModuleInitializerPass> {
void runOnOperation() override {
auto walkResult = getOperation().walk([](GlobalSlotModuleInitializerOp op) {
for (auto initializer :
getOperation().getOps<GlobalSlotModuleInitializerOp>()) {
auto intialize =
cast<InitializeGlobalSlotsOp>(op.getBody()->getTerminator());
if (intialize.getNumOperands() != 0) {
op.emitError("could not erase non-empty module initializer");
return WalkResult::interrupt();
cast<InitializeGlobalSlotsOp>(initializer.getBody()->getTerminator());
if (intialize.getNumOperands() == 0) {
initializer.erase();
}
op.erase();
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) {
return signalPassFailure();
// The verifier ensures there is only one GlobalSlotModuleInitializerOp.
break;
}
}
};

View File

@ -0,0 +1,247 @@
//===- LowerToBackendContract.cpp --------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "torch-lower-to-backend-contract"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
//===----------------------------------------------------------------------===//
// Checking the backend contract.
//===----------------------------------------------------------------------===//
static LogicalResult checkType(Operation *op, Type type,
bool actuallyEmitDiagnostics) {
// Allow various scalar types that backends are expected to be able to handle.
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType>())
return success();
// Backends are not expected to support dynamic computations on these types,
// but they frequently appear as parameters to ops which backends
// can statically pattern match and eliminate from the program.
// For example, a tensor operand might be optional, and the backend
// will pattern-match statically whether it is passed as a tensor or None.
if (type.isa<Torch::NoneType, Torch::StringType>())
return success();
// We blanket prohibit non-value-semantic tensors.
// All of our backends are currently based on value-semantic tensors, so
// we consider it our responsibility to lower all non-value-semantic tensors
// to value-semantic tensors.
if (type.isa<NonValueTensorType>()) {
if (actuallyEmitDiagnostics) {
return op
->emitError("unsupported by backend contract: non-value tensor type")
.attachNote()
.append("this is likely due to a missing case in the "
"MaximizeValueSemantics pass");
} else {
return failure();
}
}
// For value-semantic tensors, we require at least a known rank and dtype.
// We are not aware of a situation where our backends can handle an unranked
// tensor type or a tensor with a dynamic dtype.
//
// There are somewhat fundamental reasons for this. In particular, the problem
// of unranked codegen is completely different from the problem of ranked
// codegen (since ranked corresponds to a fixed loop nest structure). For all
// codegen systems we are aware of, the program must be reduced to operate
// on ranked tensors at some point in compilation, and we are not aware of
// any backend with a general solution to this problem before it reaches
// codegen. So we consider it our responsibility to eliminate unranked tensor
// from the program.
//
// We aren't aware of any backend with any infrastructure to represent dynamic
// dtypes, let alone transform and optimize them. Additionally, it is unlikely
// that any backend, even if it supports dynamic dtypes in some form, will
// have an sufficiently rich system for representing PyTorch type promotion
// rules. So we consider it our responsibility to ensure that all dtypes are
// statically known.
if (auto tensorType = type.dyn_cast<ValueTensorType>()) {
if (!tensorType.hasSizes()) {
if (actuallyEmitDiagnostics) {
return op
->emitError(
"unsupported by backend contract: tensor with unknown rank")
.attachNote()
.append("this is likely due to a missing shape transfer function "
"in shape_lib_gen.py");
} else {
return failure();
}
}
if (!tensorType.hasDtype()) {
if (actuallyEmitDiagnostics) {
return op
->emitError(
"unsupported by backend contract: tensor with unknown dtype")
.attachNote()
.append("this is likely due to a missing case in RefineTypes");
} else {
return failure();
}
}
return success();
}
// Optional types are also in the category of types which we don't expect
// backends to dynamically compute with, but they can be pattern matched
// in many cases that are practically necessary.
if (auto optionalType = type.dyn_cast<OptionalType>()) {
// TODO: Be stricter about tensor types.
// See comment below for ListType.
if (optionalType.getContainedType().isa<ValueTensorType>())
return success();
return checkType(op, optionalType.getContainedType(),
actuallyEmitDiagnostics);
}
// List types are also in the category of types which we don't expect
// backends to dynamically compute with, but they can be pattern matched
// in many cases that are practically necessary. For example, the
// strides of a convolution op are represented as a list.
if (auto listType = type.dyn_cast<ListType>()) {
// TODO: Be stricter about tensor types.
// For the moment, there are cases (such as for torch.cat) where we end
// up with `!torch.list<vtensor>` which doesn't have shape or dtype in
// the contained type information. Somehow this slips through and works.
// We should be stricter about this and properly infer the contained type
// and shape.
if (listType.getContainedType().isa<ValueTensorType>())
return success();
return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics);
}
// Tuple types are also in the category of types which we don't expect
// backends to dynamically compute with, but they can be pattern matched
// in many cases that are practically necessary.
if (auto tupleType = type.dyn_cast<Torch::TupleType>()) {
for (auto containedType : tupleType.getContainedTypes()) {
if (failed(checkType(op, containedType, actuallyEmitDiagnostics)))
return failure();
}
return success();
}
// Unsupported type.
if (actuallyEmitDiagnostics) {
return op->emitError("unsupported by backend contract: type ") << type;
} else {
return failure();
}
}
static bool satisfiesBackendContract(ModuleOp module,
bool actuallyEmitDiagnostics = false) {
// We do not permit `torch.global_slot`'s in the backend contract, since
// support for them is not widespread, and this does not align with PyTorch's
// more tracing-based direction.
//
// We just check for the GlobalSlotModuleInitializerOp since its verifier
// ensures that the set of global slots matches those initialized by the
// module initializer.
auto walkResult0 = module.walk([&](Torch::GlobalSlotModuleInitializerOp op) {
if (actuallyEmitDiagnostics) {
// Report the error on the terminator to avoid dumping the whole
// initializer itself, which can have pages of ops in it.
op.getBody()
->getTerminator()
->emitError("unsupported by backend contract: module initializers")
.attachNote()
.append("this is likely due to InlineGlobalSlots being unable to "
"inline a global slot");
}
return WalkResult::interrupt();
});
if (walkResult0.wasInterrupted())
return false;
// Check all the type of all Value's in the program.
//
// A pre-order walk gives a more intuitive "first error".
// TODO: Should we report more than the first error?
// How do we avoid making it too spammy?
auto walkResult1 = module.walk<WalkOrder::PreOrder>([&](Block *block) {
for (BlockArgument arg : block->getArguments())
if (failed(checkType(block->getParentOp(), arg.getType(),
actuallyEmitDiagnostics))) {
return WalkResult::interrupt();
}
for (Operation &op : *block)
for (OpResult result : op.getResults())
if (failed(checkType(&op, result.getType(), actuallyEmitDiagnostics)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (walkResult1.wasInterrupted())
return false;
return true;
}
namespace {
class LowerToBackendContractPass
: public LowerToBackendContractBase<LowerToBackendContractPass> {
public:
LowerToBackendContractPass() = default;
LowerToBackendContractPass(int maxIterations, bool decompose) {
this->maxIterations = maxIterations;
this->decompose = decompose;
}
void runOnOperation() override {
ModuleOp module = getOperation();
OpPassManager pm(module.getOperationName());
TorchLoweringPipelineOptions options;
options.decompose = decompose;
createTorchSimplificationPipeline(pm, options);
int i = 0;
do {
if (i++ == maxIterations) {
LLVM_DEBUG({
llvm::dbgs() << "LowerToBackendContractPass: "
<< "failed to satisfy backend contract after "
<< maxIterations
<< " iterations of the simplification pipeline\n";
});
// Show the diagnostics.
(void)satisfiesBackendContract(module,
/*actuallyEmitDiagnostics=*/true);
return signalPassFailure();
}
if (failed(runPipeline(pm, module)))
return signalPassFailure();
} while (!satisfiesBackendContract(module));
LLVM_DEBUG({
llvm::dbgs() << "LowerToBackendContractPass: "
<< "succeeded after " << i
<< " iterations of the simplification pipeline\n";
});
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createLowerToBackendContractPass(int maxIterations,
bool decompose) {
return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose);
}

View File

@ -31,6 +31,10 @@ void mlir::torch::registerTorchPasses() {
"Pipeline lowering a Torch function to Torch backend form.",
mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-simplification-pipeline",
"Pipeline simplifying computations in the program.",
mlir::torch::Torch::createTorchSimplificationPipeline);
mlir::PassPipelineRegistration<>(
"torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.",
mlir::torch::Torch::createTorchShapeRefinementPipeline);
}
@ -66,131 +70,82 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// General considerations: As a matter of bring-up, we are simultaneously
// building out the frontend pipeline and also co-developing the backend
// support story as well. This means that sometimes the most expedient way to
// support a given program is to "optimize hard enough" that the parts of the
// program that touch unimplemented backend support go away (constant folded,
// dead-code-eliminated, etc.). In the fullness of time, most of that
// optimization should not be necessary, and we should have an "O0" pipeline
// that runs practically no optimizations.
// However, as a matter of expediency, at the moment we do run those
// optimizations. We guard those passes under the `options.optimize` option
// (which default to true, currently). We leave notes with the `OPT-ONLY` tag
// why we currently need that pass for correctness.
// We should eventually remove those passes from the default pipeline once
// backends have enough support.
// In particular the following features are needed in some form from backends:
// - Error handling (RaiseException + error string formatting)
// - First-class list type
// - torch.global_slot lowering
// - ...
// Please try to keep this list somewhat up to date when adding
// "optimize hard enough that it works" transformations.
// Incorporate user annotations and remove signature Python-isms.
pm.addPass(createAdjustCallingConventionsPass());
// Perform the bulk of lowering to the backend contract.
// See the pass documentation for more information.
pm.addPass(createLowerToBackendContractPass(options.maxIterations,
options.decompose));
}
// TODO: Remove options.optimize and this OPT-ONLY stuff -- we are already way
// past the point of no return for it being necessary for functional
// correctness.
if (options.optimize) {
// Eliminate the PrimTupleIndexOp generated from the
// adjustCallingConventions
// A simplification pipeline to establish the invariants of the backend
// contract (see `satisfiedBackendContract` in `LowerToBackendContract`).
//
// We structure this so that a single run of this pipeline is enough for
// most models, but it is possible for it to take multiple runs to fully
// clean things up when there are cyclic dependencies between certain
// simplifications, such as a decomposition relying on shape refinement which
// depends on another decomposition.
//
// Although technically this pipeline is an implementation detail of
// LowerToBackendContract, we expose it here to help debugging.
//
// LowerToBackendContract will run this pipeline as many times as necessary, but
// in general, it is costly to re-run this pipeline, since all the passes do
// O(module size) work. We want the number of iterations of this pipeline
// to be bounded by meaningful "always in practice small" program properties,
// such as loop nesting depth, number of sequentially dependent steps of
// constant global slots proving that other global slots are dead, etc.
//
// It is generally always possible to construct a pathological input that will
// exceed the number of iterations. If we do find practical cases with
// O(module size) number of iterations of this simplification pipeline, then
// we may need to adjust the approach, such as to do some of the transformations
// together at finer granularity.
void mlir::torch::Torch::createTorchSimplificationPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// General cleanup.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Inline global slots, which for most inference scenarios deletes them.
// This also exposes more information to intraprocedural transformations
// below like MaximizeValueSemantics and RefineTypes.
// OPT-ONLY: Don't rely on this pass to "lower" global slots by deleting.
// Also don't rely on this pass to expose constants into the program to
// simplify handling of "optional".
// Inline global slots to expose a bunch of simplification opportunities
// from constant hyperparameters, weights, etc.
pm.addPass(createInlineGlobalSlotsPass());
// After doing a first round of inlining global slots, canonicalize again to
// take advantage of optimization opportunities exposed by the inlined
// global slots. In particular, this is functionally necessary now because
// large amounts of control flow are guarded by an "is training" flag, so
// inlining removes certain mutating operations done on the slots enabling
// them to be deleted.
// TODO: In full generality, we need to do a fixed-point iteration of
// shape inference, maximizing value semantics, decomposition, inling global
// slots, and canonicalization.
// Erase the module initializer if we have proven that all the global slots
// are gone.
pm.addPass(createEraseModuleInitializerPass());
// Clean up again to avoid needing to to back around the fixed-point
// iteration.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Inline again, cleaning up any remaining global slots that might be dead
// now.
pm.addPass(createInlineGlobalSlotsPass());
// Erase the module initializers (or fail compilation), since they aren't
// permitted in our backend contract at the moment.
pm.addPass(Torch::createEraseModuleInitializerPass());
}
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
if (options.optimize) {
// OPT-ONLY: Right now we rely on this to eliminate certain branches that
// guard unreachable code that backends can't handle yet, such as lists,
// RaiseException, unimplemented tensor ops, and only-used-in-training
// operations on `torch.global_slot`'s.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// OPT-ONLY: We may have deleted some `torch.global_slot.get` /
// `torch.global_slot.get` ops, which may have left more
// `torch.global_slot`'s unused.
// Remove dead global slots.
pm.addPass(createSymbolDCEPass());
}
//===--------------------------------------------------------------------===//
// Lowering to ranked !torch.vtensors of known dtype.
//===--------------------------------------------------------------------===//
// Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
// Update the return op to return value tensors and remove dead ops.
// Update the return op to return value tensors.
pm.addPass(Torch::createRefinePublicReturnPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Ensure that all tensors have been converted to value semantics.
pm.addPass(Torch::createVerifyConversionToValueSemanticsPass());
// Do shape refinement.
// This must be run before RefineTypes (which primarily does dtype inference),
// because Torch type promotion rules actually depend on the shape of the
// operand.
createTorchShapeRefinementPipeline(pm, options);
// This should be run before RefineTypes (which primarily does dtype
// inference), because Torch type promotion rules actually depend on the shape
// of the operand.
createTorchShapeRefinementPipeline(pm);
// Refine types in the program, which mainly means inferring dtypes of ops.
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends.
pm.addPass(Torch::createRefinePublicReturnPass());
if (options.optimize) {
// This can fold away some branches given the information got from
// RefineTypes before doing maximize value sematics which only works with
// basic blocks.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
if (options.optimize) {
// All the type refinement we've done above has exposed new information
// that allows folding away more stuff.
// OPT-ONLY: Right now we rely on this to eliminate certain
// branches that guard unreachable code that backends can't handle yet, such
// as lists, RaiseException, unimplemented aten ops, and
// only-used-in-training operations on `torch.global_slot`'s.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
if (options.decompose) {
pm.addNestedPass<func::FuncOp>(Torch::createDecomposeComplexOpsPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
// TODO: VerifyTorchBackendContractPass.
}
void mlir::torch::Torch::createTorchShapeRefinementPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
// Reify the shape functions for each op that is present in the shape library.
pm.addPass(Torch::createReifyShapeCalculationsPass());

View File

@ -1152,7 +1152,7 @@ void TypeAnalysis::visitAtenEmbeddingBagOp(Operation *op) {
resultIntKnowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
for (int64_t i = 1; i < 4; i++) {
for (int64_t i = 1, e = op->getNumResults(); i < e; i++) {
incorporateKnowledge(op->getResult(i), resultIntKnowledge);
}
return;
@ -1259,6 +1259,12 @@ void TypeAnalysis::visitAtenTensorOp(AtenTensorOp op) {
while (auto listType = type.dyn_cast<ListType>()) {
type = listType.getContainedType();
}
// TODO: Support tensor as the contained type of the list.
// These are the only types handled by fillInDTypeGivenDTypeAndDataType below.
if (!type.isa<Torch::FloatType, Torch::IntType, Torch::BoolType>()) {
incorporateKnowledge(op.getResult(), knowledge);
return;
}
fillInDTypeGivenDTypeAndDataType(knowledge, dtype, type);
incorporateKnowledge(op.getResult(), knowledge);
}
@ -1418,13 +1424,13 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
};
if (auto tensorType = v.getType().dyn_cast<BaseTensorType>()) {
const ValueState *latticeElement = solver.lookupState<ValueState>(v);
if (!latticeElement)
if (!latticeElement || latticeElement->isUninitialized())
return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue();
return getRefinedTensorType(tensorType, knowledge);
} else if (auto optionalType = v.getType().dyn_cast<OptionalType>()) {
const ValueState *latticeElement = solver.lookupState<ValueState>(v);
if (!latticeElement)
if (!latticeElement || latticeElement->isUninitialized())
return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue();
if (knowledge.optional == OptionalKnowledge::isNone)
@ -1438,7 +1444,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
}
} else if (auto scalarType = v.getType().dyn_cast<NumberType>()) {
const ValueState *latticeElement = solver.lookupState<ValueState>(v);
if (!latticeElement)
if (!latticeElement || latticeElement->isUninitialized())
return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue();
if (knowledge.kind == torch_upstream::TypeKind::IntType)

View File

@ -1,56 +0,0 @@
//===- VerifyConversionToValueSemantics.cpp ----------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/IR/BuiltinOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch::Torch;
static LogicalResult checkValueType(Operation *op, Value value) {
auto isNotValueTensorType = value.getType().isa<NonValueTensorType>();
return isNotValueTensorType
? op->emitError(
"found a non-value tensor type, this is likely due to a "
"missing case in the MaximizeValueSemantics pass")
: success();
}
namespace {
class VerifyConversionToValueSemanticsPass
: public VerifyConversionToValueSemanticsBase<
VerifyConversionToValueSemanticsPass> {
void runOnOperation() override {
auto walkResult = getOperation().walk([&](Block *block) {
for (BlockArgument arg : block->getArguments())
if (failed(checkValueType(block->getParentOp(), arg)))
return WalkResult::interrupt();
for (Operation &op : *block)
for (OpResult result : op.getResults())
if (failed(checkValueType(&op, result)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyConversionToValueSemanticsPass() {
return std::make_unique<VerifyConversionToValueSemanticsPass>();
}

View File

@ -77,7 +77,6 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
if (options.optimize) {
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Resolve `dim` ops on tensors (which currently live in the `memref`
@ -86,7 +85,6 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
memref::createResolveShapedTypeResultDimsPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
}
// Finish the type conversion from `torch` types to the types of the
// linalg-on-tensors backend contract.
@ -111,12 +109,10 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
if (options.optimize) {
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
}
// Finish the type conversion from `torch` types to the types of the
// TOSA backend contract.
@ -140,21 +136,17 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
if (options.optimize) {
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
}
// Convert CHLO ops to MHLO ops
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
if (options.optimize) {
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
}
// Finish the type conversion from `torch` types to the types of the
// MHLO backend contract.

View File

@ -6,15 +6,3 @@ torch.global_slot.module_initializer {
torch.initialize.global_slots [
]
}
// -----
torch.global_slot @slot0 : !torch.int
// expected-error@+1 {{could not erase non-empty module initializer}}
torch.global_slot.module_initializer {
%0 = torch.constant.int 0
torch.initialize.global_slots [
@slot0(%0: !torch.int)
]
}

View File

@ -0,0 +1,61 @@
// RUN: torch-mlir-opt -torch-lower-to-backend-contract -split-input-file -verify-diagnostics %s
torch.global_slot.module_initializer {
%0 = torch.constant.int 1
// expected-error @+2 {{unsupported by backend contract: module initializers}}
// expected-note @+1 {{this is likely due to}}
torch.initialize.global_slots [
@slot0(%0 : !torch.int)
]
}
torch.global_slot @slot0 : !torch.int
// -----
// expected-error @+2 {{unsupported by backend contract: non-value tensor type}}
// expected-note @+1 {{this is likely due to}}
func.func @f(%arg0: !torch.tensor) {
return
}
// -----
// expected-error @+2 {{unsupported by backend contract: tensor with unknown rank}}
// expected-note @+1 {{this is likely due to}}
func.func @f(%arg0: !torch.vtensor<*,f32>) {
return
}
// -----
// expected-error @+2 {{unsupported by backend contract: tensor with unknown dtype}}
// expected-note @+1 {{this is likely due to}}
func.func @f(%arg0: !torch.vtensor<[],unk>) {
return
}
// -----
// expected-error @+1 {{unsupported by backend contract: type '!torch.any'}}
func.func @f(%arg0: !torch.any) {
return
}
// -----
// Test case: checking of op results.
// TODO: In theory we could diagnose every single value, but for now we bail out on the first one.
func.func @f(%arg0: !torch.bool, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[7],f32>) -> !torch.vtensor<*,f32> {
// expected-error @+2 {{unsupported by backend contract: tensor with unknown rank}}
// expected-note @+1 {{this is likely due to}}
%0 = torch.prim.If %arg0 -> (!torch.vtensor<*,f32>) {
%1 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[],f32> to !torch.vtensor<*,f32>
torch.prim.If.yield %1 : !torch.vtensor<*,f32>
} else {
%2 = torch.tensor_static_info_cast %arg2 : !torch.vtensor<[7],f32> to !torch.vtensor<*,f32>
torch.prim.If.yield %2 : !torch.vtensor<*,f32>
}
return %0 : !torch.vtensor<*,f32>
}

View File

@ -240,3 +240,35 @@ func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
return %result2 : !torch.vtensor<*,unk>
}
// -----
// Check that we don't crash on this input.
// CHECK-LABEL: func.func @forward
func.func @forward() -> !torch.vtensor {
%false = torch.constant.bool false
%none = torch.constant.none
%0 = torch.prim.ListConstruct : () -> !torch.list<tensor>
// CHECK: torch.aten.tensor
%1 = torch.aten.tensor %0, %none, %none, %false : !torch.list<tensor>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// Check that we don't crash on this input.
// TODO: This appears to result in aten.mul.Tensor not being visited.
// We should investigate why that happens.
// CHECK-LABEL: func.func @forward
func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) {
%0 = torch.prim.If %arg0 -> (!torch.tensor) {
torch.prim.If.yield %arg1 : !torch.tensor
} else {
torch.prim.If.yield %arg1 : !torch.tensor
}
%1 = torch.copy.to_vtensor %0 : !torch.vtensor
%2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
return
}

View File

@ -1,9 +0,0 @@
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-verify-conversion-to-value-semantics
// -----
func.func @result_is_non_value_tensor(%arg: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
// @expected-error@+1 {{found a non-value tensor type, this is likely due to a missing case in the MaximizeValueSemantics pass}}
%neg = torch.aten.neg %arg : !torch.vtensor<[2],f32> -> !torch.tensor
return %arg : !torch.vtensor<[2],f32>
}

View File

@ -182,7 +182,6 @@ cc_library(
"lib/Dialect/Torch/Transforms/ShapeLibrary.cpp",
"lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp",
"lib/Dialect/Torch/Transforms/PassDetail.h",
"lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp",
],
hdrs = [
"include/torch-mlir/Dialect/Torch/Transforms/Passes.h",