Add npcomp-verify-backend-contract pass.

This pass verifies that a given module satisfies the contract that we
have for backends. This is phrased as an "allowlist", because we want to
keep this interface tight. Also, this gives much better diagnostics than
a backend randomly crashing or failing to compile would (though they
could still be improved).

This was especially painful because if we had
`tensor<?x!numpy.any_dtype>` slip through, at some point RefBackend
would convert it to a memref type and trip the "verify type invariants"
assertion which gives no location or anything and crashed the process,
which was very unpleasant.

We implement this with the dialect conversion framework, which works
reasonably well and was quick to put together and familiar, but is still
very "op oriented". We probably want to make this hand-rolled
eventually, especially the error reporting (the most useful kind of
error for a dialect conversion user is not necessarily the best for this
use case). Also, in production, these error will go to users, and need
to be surfaced carefully such as "the compiler needs a type annotation
on this function parameter" which in general requires some special
analysis, wordsmithing, and overall awareness of the e2e use case (such
as how much we can lean into certain source locations) to provide a
meaningful user-level diagnostic.

Also, add `inline` to the current frontend lowering pass pipeline to
allow slightly more complicated programs that otherwise would fail on
shape inference.
pull/208/head
Sean Silva 2021-04-12 18:39:53 -07:00
parent f5dfa02523
commit c4123d4d4d
17 changed files with 300 additions and 15 deletions

View File

@ -19,12 +19,19 @@ logging.enable()
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
class Submodule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, lhs, rhs):
return torch.mm(lhs, rhs)
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.s = Submodule()
def forward(self, lhs, rhs):
return self.s.forward(lhs, rhs)
test_module = TestModule()
class_annotator = torch_mlir.ClassAnnotator()
recursivescriptmodule = torch.jit.script(test_module)

View File

@ -0,0 +1,6 @@
add_subdirectory(Common)
# Currently this doesn't introduce any actual dependency on IREE, so add it
# unconditionally.
# TODO: Put this behind the NPCOMP_ENABLE_IREE flag.
add_subdirectory(IREE)

View File

@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(NPCOMPCommonBackendPassIncGen)
add_mlir_doc(Passes -gen-pass-doc CommonBackendPasses ./)

View File

@ -0,0 +1,27 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_BACKEND_COMMON_PASSES_H
#define NPCOMP_BACKEND_COMMON_PASSES_H
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
namespace mlir {
namespace NPCOMP {
namespace CommonBackend {
/// Registers all CommonBackend passes.
void registerCommonBackendPasses();
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
} // namespace CommonBackend
} // namespace NPCOMP
} // namespace mlir
#endif // NPCOMP_BACKEND_COMMON_PASSES_H

View File

@ -0,0 +1,19 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_BACKEND_COMMON_PASSES
#define NPCOMP_BACKEND_COMMON_PASSES
include "mlir/Pass/PassBase.td"
def VerifyBackendContract : Pass<"npcomp-verify-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the backend contract that npcomp targets";
let constructor = "mlir::NPCOMP::CommonBackend::createVerifyBackendContractPass()";
}
#endif // NPCOMP_BACKEND_COMMON_PASSES

View File

@ -0,0 +1,5 @@
# Common backend utilities
This directory contains passes/transformations/analyses/etc. that are relevant
to the backend contract that npcomp targets, but otherwise independent of any
particular backend.

View File

@ -1,8 +1,4 @@
# Currently this doesn't introduce any actual dependency on IREE, so add it
# unconditionally.
# TODO: Put this behind the NPCOMP_ENABLE_IREE flag.
add_subdirectory(Backend/IREE)
add_subdirectory(Backend)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(RefBackend)

View File

@ -0,0 +1,11 @@
add_subdirectory(Common)
if(NPCOMP_ENABLE_REFJIT)
add_subdirectory(RefJIT)
endif()
# Currently this doesn't introduce any actual dependency on IREE, so add it
# unconditionally.
# TODO: Put this behind the NPCOMP_ENABLE_IREE flag.
add_subdirectory(IREE)

View File

@ -0,0 +1,22 @@
add_npcomp_library(NPCOMPCommonBackend
VerifyBackendContract.cpp
Passes.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SRC_DIR}/include/npcomp/Backend/Common
DEPENDS
NPCOMPCommonBackendPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLinalg
MLIRMemRef
MLIRStandard
MLIRMath
)
mlir_check_all_link_libraries(NPCOMPCommonBackend)

View File

@ -0,0 +1,25 @@
//===- PassDetail.h - Pass class details ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef BACKEND_COMMON_PASSDETAIL_H
#define BACKEND_COMMON_PASSDETAIL_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace NPCOMP {
namespace CommonBackend {
#define GEN_PASS_CLASSES
#include "npcomp/Backend/Common/Passes.h.inc"
} // namespace CommonBackend
} // namespace NPCOMP
} // end namespace mlir
#endif // BACKEND_COMMON_PASSDETAIL_H

View File

@ -0,0 +1,25 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "npcomp/Backend/Common/Passes.h"
#include "mlir/IR/BuiltinOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::CommonBackend;
namespace {
#define GEN_PASS_REGISTRATION
#include "npcomp/Backend/Common/Passes.h.inc"
} // end namespace
void mlir::NPCOMP::CommonBackend::registerCommonBackendPasses() {
::registerPasses();
}

View File

@ -0,0 +1,83 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Backend/Common/Passes.h"
#include "mlir/IR/BuiltinOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::CommonBackend;
namespace {
class VerifyBackendContractPass
: public VerifyBackendContractBase<VerifyBackendContractPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto module = getOperation();
TypeConverter converter;
converter.addConversion([](RankedTensorType type) -> Type {
if (BaseMemRefType::isValidElementType(type.getElementType()))
return type;
return nullptr;
});
TypeConverter scalarConverter;
for (TypeConverter *c : {&converter, &scalarConverter}) {
c->addConversion([](FloatType type) { return type; });
c->addConversion([](IntegerType type) { return type; });
c->addConversion([](IndexType type) { return type; });
}
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
auto isLegalScalarOp = [&](Operation *op) {
// We recognize basic scalar ops by them having the trait "Elementwise",
// even though we don't expect them to operate on tensors.
return scalarConverter.isLegal(op) &&
op->hasTrait<OpTrait::Elementwise>();
};
ConversionTarget target(*context);
// Structural operations.
target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
// Basic scalar operations.
target.addDynamicallyLegalDialect<StandardOpsDialect>(isLegalScalarOp);
target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
// Tensor operations should go through linalg.
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
// DimOp is used to query tensor sizes.
target.addDynamicallyLegalOp<memref::DimOp>(opHasLegalTypes);
// AssertOp is used to terminate the program for error guards.
target.addLegalOp<AssertOp>();
// ConstantOp is used for tensors and for scalars.
target.addDynamicallyLegalOp<ConstantOp>(opHasLegalTypes);
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
module.emitError()
<< "Module does not conform to npcomp's backend contract. See "
"dialect conversion legality information above.";
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::CommonBackend::createVerifyBackendContractPass() {
return std::make_unique<VerifyBackendContractPass>();
}

View File

@ -27,6 +27,7 @@ void npcompRegisterAllPasses() {
// Upstream passes we depend on.
::mlir::registerSymbolDCEPass();
::mlir::registerInlinerPass();
::mlir::registerCanonicalizerPass();
::mlir::registerSCFToStandardPass();
::mlir::registerConvertElementwiseToLinalgPass();

View File

@ -1,18 +1,10 @@
add_subdirectory(Backend)
add_subdirectory(CAPI)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(RefBackend)
add_subdirectory(Typing)
if(NPCOMP_ENABLE_REFJIT)
add_subdirectory(Backend/RefJIT)
endif()
# Currently this doesn't introduce any actual dependency on IREE, so add it
# unconditionally.
# TODO: Put this behind the NPCOMP_ENABLE_IREE flag.
add_subdirectory(Backend/IREE)
################################################################################
# Setup the initialization target.
# This includes conditional dependencies based on whether features are enabled.
@ -34,6 +26,7 @@ add_npcomp_library(NPCOMPInitAll
PUBLIC
# Local depends
NPCOMPCommonBackend
NPCOMPIREEBackend
NPCOMPRefBackend
NPCOMPRefbackDialect

View File

@ -9,6 +9,7 @@
#include "npcomp/InitAll.h"
#include "mlir/IR/Dialect.h"
#include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Backend/IREE/Passes.h"
#include "npcomp/Conversion/Passes.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
@ -52,4 +53,5 @@ void mlir::NPCOMP::registerAllPasses() {
mlir::NPCOMP::registerTorchPasses();
mlir::NPCOMP::registerTypingPasses();
mlir::NPCOMP::IREEBackend::registerIREEBackendPasses();
mlir::NPCOMP::CommonBackend::registerCommonBackendPasses();
}

View File

@ -29,6 +29,10 @@ OBJECT_GRAPH_LOWERING_PASSES = (
# bothersome because we don't currently have a lowering for them.
# TODO: Support global slots in backends.
"symbol-dce",
# Currently, our shape inference is not powerful enough to deal with
# calls, so inline everything.
# TODO: Improve shape inference.
"inline",
# Incorporate user annotations and remove signature Python-isms.
"torch-adjust-calling-conventions",
)
@ -65,6 +69,7 @@ TORCH_TO_TCP_PASSES = (
"func(convert-aten-to-tcf)",
"func(convert-tcf-to-std)",
"func(convert-elementwise-to-linalg)",
"npcomp-verify-backend-contract",
)
def lower_module(imported_module: Module):

View File

@ -0,0 +1,53 @@
// RUN: npcomp-opt -npcomp-verify-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
// CHECK: func @mm
func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> attributes {iree.module.export} {
%c0 = constant 0 : index
%c1 = constant 1 : index
%cst = constant 0.000000e+00 : f32
%0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
%1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
%2 = memref.dim %arg1, %c0 : tensor<?x?xf32>
%3 = memref.dim %arg1, %c1 : tensor<?x?xf32>
%4 = cmpi eq, %1, %2 : index
assert %4, "mismatching contracting dimension for aten.mm"
%5 = linalg.init_tensor [%0, %3] : tensor<?x?xf32>
%6 = linalg.fill(%5, %cst) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
%7 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%6 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %7 : tensor<?x?xf32>
}
// -----
// Basic check of error reporting.
// expected-error@+1 {{Module does not conform to npcomp's backend contract.}}
module {
func @disallowed() {
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
"unknown_dialect.unknown_op"() : () -> ()
return
}
}
// -----
// TODO: Improve these errors to give more exact reporting.
//
// The reporting we inherit from dialect conversion is not precise.
// For example, here we want it to explicitly call out that
// `tensor<?x!numpy.any_dtype>` is the problem here, which suggests
// that type inference didn't succeed, or insufficient type information
// was available.
//
// Ultimately, the output of this pass needs to be conveyed to the user
// in an understandable way, such as suggesting a particular place where
// a shape annotation is needed.
// expected-error@+1 {{Module does not conform to npcomp's backend contract.}}
module {
func @disallowed(%arg0: tensor<?x!numpy.any_dtype>) -> tensor<?x!numpy.any_dtype> {
// expected-error@+1 {{failed to legalize operation 'std.return'}}
return %arg0 : tensor<?x!numpy.any_dtype>
}
}