Constant fold through basicpy.bool_cast.

This is the start of a push toward getting ResNet running.

This involves throwing in the towel on an O0 pipeline for now. See the note
in the code. We keep an options struct with an `optimize` flag, but it
defaults to true for now.
pull/213/head
Sean Silva 2021-04-26 14:22:50 -07:00
parent fb5f149e04
commit 7eb36b4ae7
8 changed files with 101 additions and 14 deletions

View File

@@ -17,6 +17,7 @@ class ResNet18Module(torch.nn.Module):
# Reset seed to make model deterministic.
torch.manual_seed(0)
self.resnet = models.resnet18()
self.train(False)
@export
@annotate_args([
None,

View File

@@ -295,6 +295,7 @@ def Basicpy_BoolCastOp : Basicpy_Op<"bool_cast", [NoSideEffect]> {
let arguments = (ins BoolOrI1Type:$operand);
let results = (outs BoolOrI1Type:$result);
let assemblyFormat = "$operand attr-dict `:` type(operands) `->` type(results)";
let hasFolder = 1;
}
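Setting `hasFolder = 1` makes ODS declare a fold hook on the generated op class, which BasicpyOps.cpp (below in this commit) then defines. Roughly, the emitted declaration looks like this (a sketch, not verbatim tblgen output):

OpFoldResult fold(ArrayRef<Attribute> operands);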
def Basicpy_UnknownCastOp : Basicpy_Op<"unknown_cast", [NoSideEffect]> {

View File

@@ -22,16 +22,26 @@ std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
std::unique_ptr<OperationPass<ModuleOp>>
createPrepareForGlobalizeObjectGraphPass();
struct TorchLoweringPipelineOptions
: public PassPipelineOptions<TorchLoweringPipelineOptions> {
// If this option is true, then perform optimizations.
// If this option is false, only do the bare minimum for correctness.
Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
llvm::cl::init(true)};
};
/// Creates a pipeline that lowers the object graph IR that is produced by
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
void createLowerObjectGraphPipeline(OpPassManager &pm);
void createLowerObjectGraphPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
/// Creates a pipeline that lowers a flat list of funcs and global slots
/// with the torch and aten dialects and mutable arrays and converts it to
/// the form required by npcomp-verify-backend-contract, in particular
/// lowering most arrays to ranked tensors of known dtype, lowering aten ops to
/// linalg, converting torch.prim.* ops to elementary math operations.
void createLowerToNpcompBackendPipeline(OpPassManager &pm);
void createLowerToNpcompBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
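For illustration, a minimal sketch of driving the new overloads from C++ with optimizations disabled; the include path and the helper function are assumptions for the example, not part of this commit:

#include "mlir/Pass/PassManager.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" // Assumed header path.

static void buildNoOptPipeline(mlir::MLIRContext &context) { // Hypothetical helper.
  mlir::PassManager pm(&context);
  mlir::NPCOMP::Torch::TorchLoweringPipelineOptions options;
  options.optimize = false; // Only the bare minimum needed for correctness.
  mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(pm, options);
  // pm.run(module) would then apply the pipeline to a ModuleOp.
}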

View File

@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
@@ -46,6 +47,7 @@ void BasicpyDialect::initialize() {
addTypes<BoolType, BytesType, DictType, EllipsisType, ListType, NoneType,
SlotObjectType, StrType, TupleType, UnknownType>();
addInterfaces<BasicpyInlinerInterface>();
getContext()->getOrLoadDialect<StandardOpsDialect>();
// TODO: Make real ops for everything we need.
allowUnknownOperations();
@@ -54,6 +56,11 @@
Operation *BasicpyDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
// std.constant is used for literal i1 types (not !basicpy.BoolType).
if (auto integerType = type.dyn_cast<IntegerType>()) {
if (integerType.getWidth() == 1)
return builder.create<ConstantOp>(loc, value);
}
// NumericConstantOp.
// Supports IntegerType (any signedness), FloatType and ComplexType.
if (type.isa<IntegerType>() || type.isa<FloatType>() ||

View File

@@ -21,6 +21,14 @@ using namespace mlir::NPCOMP::Basicpy;
// Fallback verifier for ops that don't have a dedicated one.
template <typename T> static LogicalResult verify(T op) { return success(); }
//===----------------------------------------------------------------------===//
// BoolCastOp
//===----------------------------------------------------------------------===//
OpFoldResult BoolCastOp::fold(ArrayRef<Attribute> operands) {
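// The cast is a no-op at the value level: if the operand is constant, its
// attribute is also a valid constant value for the result. A null attribute
// (non-constant operand) yields a null OpFoldResult, meaning "no fold".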
return operands[0];
}
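When canonicalization applies this fold and gets back an Attribute, it asks the op's dialect to materialize the constant, which is what the new std.constant branch in BasicpyDialect::materializeConstant (above) serves; for i1 results it produces the `constant true` checked in the new test below. A simplified sketch of that flow, with the helper name and setup assumed:

#include "mlir/IR/Builders.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" // Assumed header path.

using namespace mlir;
using namespace mlir::NPCOMP::Basicpy;

// Hypothetical helper mirroring what the folding machinery does; `operands`
// holds each operand's constant attribute, or null if it is not constant.
static Operation *tryFoldBoolCast(OpBuilder &builder, BoolCastOp op,
                                  ArrayRef<Attribute> operands) {
  OpFoldResult result = op.fold(operands);
  if (auto attr = result.dyn_cast<Attribute>())
    // Dispatches to BasicpyDialect::materializeConstant, which now emits a
    // std.constant for i1 result types.
    return op->getDialect()->materializeConstant(builder, attr, op.getType(),
                                                 op.getLoc());
  return nullptr; // Not constant; leave the op in place.
}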
//===----------------------------------------------------------------------===//
// BoolConstantOp
//===----------------------------------------------------------------------===//

View File

@@ -11,4 +11,5 @@ add_npcomp_dialect_library(NPCOMPBasicpyDialect
LINK_LIBS PUBLIC
NPCOMPTypingCPA
MLIRIR
MLIRStandard
)

View File

@@ -28,17 +28,18 @@ namespace {
void mlir::NPCOMP::registerTorchPasses() {
::registerPasses();
mlir::PassPipelineRegistration<>(
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-to-npcomp-backend-pipeline",
"Pipeline lowering torch object graph to npcomp backend format.",
mlir::NPCOMP::Torch::createLowerObjectGraphPipeline);
mlir::PassPipelineRegistration<>(
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-globalized-module-to-npcomp-backend-pipeline",
"Pipeline lowering to npcomp backend form.",
mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline);
}
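With the registrations now carrying `TorchLoweringPipelineOptions`, the `optimize` flag becomes settable from the npcomp-opt command line; assuming MLIR's usual registered-pipeline option syntax, something like `npcomp-opt --torchscript-to-npcomp-backend-pipeline='optimize=false' module.mlir` would run the bare-minimum variant.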
void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(OpPassManager &pm) {
void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// When we import TorchScript IR, we import their entire "compilation unit",
// which can contain numerous functions unrelated to the current program,
// which breaks torch-globalization-pipeline; for example, there can be
@@ -65,14 +66,59 @@ void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(OpPassManager &pm) {
// Incorporate user annotations and remove signature Python-isms.
pm.addPass(createAdjustCallingConventionsPass());
createLowerToNpcompBackendPipeline(pm);
createLowerToNpcompBackendPipeline(pm, options);
}
void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
OpPassManager &pm) {
// Recognize ATen kernels.
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// General considerations: As a matter of bring-up, we are simultaneously
// building out the frontend pipeline and co-developing the backend support
// story. This means that sometimes the most expedient way to
// support a given program is to "optimize hard enough" that the parts of the
// program that touch unimplemented backend support go away (constant folded,
// dead-code-eliminated, etc.). In the fullness of time, most of that
// optimization should not be necessary, and we should have an "O0" pipeline
// that runs practically no optimizations.
// However, as a matter of expediency, at the moment we do run those
// optimizations. We guard those passes under the `options.optimize` option
// (which defaults to true, currently). We leave notes with the `OPT-ONLY` tag
// explaining why we currently need each such pass for correctness.
// We should eventually remove those passes from the default pipeline once
// backends have enough support.
// In particular, the following features are needed in some form from backends:
// - Error handling (RaiseException + error string formatting)
// - First-class list type
// - torch.global_slot lowering
// - ...
// Please try to keep this list somewhat up to date when adding
// "optimize hard enough that it works" transformations.
if (options.optimize) {
// Inline global slots, which for most inference scenarios deletes them.
// This also exposes more information to intraprocedural transformations
// below like ArrayToTensor and RefineTypes.
// OPT-ONLY: Don't rely on this pass to "lower" global slots by deleting.
// Also don't rely on this pass to expose constants into the program to
// simplify handling of "optional".
pm.addPass(createInlineGlobalSlotsPass());
}
// Recognize ATen kernels. This is a totally local transformation that
// we want to run as soon as possible.
pm.addNestedPass<FuncOp>(aten::createRecognizeKernelsPass());
if (options.optimize) {
// OPT-ONLY: Right now we rely on this to eliminate certain branches that
// guard unreachable code that backends can't handle yet, such as lists,
// RaiseException, unimplemented aten ops, and only-used-in-training
// operations on `torch.global_slot`'s.
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
// OPT-ONLY: We may have deleted some `torch.global_slot.get` /
// `torch.global_slot.set` ops, which may have left more
// `torch.global_slot`'s unused.
pm.addPass(createSymbolDCEPass());
}
// Convert the bulk of the program to ranked tensors with known dtype.
// This is the input to the backend layer that we are aiming for.
@@ -82,12 +128,6 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
// updates to "out params" on their public functions.
// This is deemed ok for now.
pm.addPass(Numpy::createPublicFunctionsToTensorPass());
// Inline global slots, which for most inference scenarios deletes them.
// This also exposes more information to intraprocedural transformations
// below like ArrayToTensor and RefineTypes.
// TODO: Don't rely on this pass to "lower" global slots by deleting.
// This pass should eventually be "just an optimization".
pm.addPass(createInlineGlobalSlotsPass());
// Convert the bulk of non-ABI-visible arrays to tensors.
pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
// Do shape and dtype refinement.
@@ -100,6 +140,15 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
// Clean up a few stray array/tensor conversion remnants.
pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
if (options.optimize) {
// RefineTypes has exposed new type information that allows folding away
// more stuff. OPT-ONLY: Right now we rely on this to eliminate certain
// branches that guard unreachable code that backends can't handle yet, such
// as lists, RaiseException, unimplemented aten ops, and
// only-used-in-training operations on `torch.global_slot`'s.
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
}
// Lower to TCP (+ guards) which is the input to codegen backends.
// Most of this should be subsumed by aten->linalg+guards conversions.
// (the guard generation will be automated from the linalg Op DSL).

View File

@@ -70,3 +70,13 @@ func @str_constant() -> !basicpy.StrType {
%0 = basicpy.str_constant "foobar"
return %0 : !basicpy.StrType
}
// -----
// CHECK-LABEL: @bool_cast
func @bool_cast() -> i1 {
// CHECK: %[[CTRUE:.*]] = constant true
%0 = basicpy.bool_constant true
%1 = basicpy.bool_cast %0 : !basicpy.BoolType -> i1
// CHECK: return %[[CTRUE]] : i1
return %1 : i1
}