mirror of https://github.com/llvm/torch-mlir
Add npcomp-verify-backend-contract pass.
This pass verifies that a given module satisfies the contract that we have for backends. This is phrased as an "allowlist", because we want to keep this interface tight. Also, this gives much better diagnostics than a backend randomly crashing or failing to compile would (though they could still be improved). This was especially painful because if we had `tensor<?x!numpy.any_dtype>` slip through, at some point RefBackend would convert it to a memref type and trip the "verify type invariants" assertion which gives no location or anything and crashed the process, which was very unpleasant. We implement this with the dialect conversion framework, which works reasonably well and was quick to put together and familiar, but is still very "op oriented". We probably want to make this hand-rolled eventually, especially the error reporting (the most useful kind of error for a dialect conversion user is not necessarily the best for this use case). Also, in production, these error will go to users, and need to be surfaced carefully such as "the compiler needs a type annotation on this function parameter" which in general requires some special analysis, wordsmithing, and overall awareness of the e2e use case (such as how much we can lean into certain source locations) to provide a meaningful user-level diagnostic. Also, add `inline` to the current frontend lowering pass pipeline to allow slightly more complicated programs that otherwise would fail on shape inference.pull/208/head
parent
f5dfa02523
commit
c4123d4d4d
|
@ -19,12 +19,19 @@ logging.enable()
|
|||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
class Submodule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.s = Submodule()
|
||||
def forward(self, lhs, rhs):
|
||||
return self.s.forward(lhs, rhs)
|
||||
|
||||
test_module = TestModule()
|
||||
class_annotator = torch_mlir.ClassAnnotator()
|
||||
recursivescriptmodule = torch.jit.script(test_module)
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
add_subdirectory(Common)
|
||||
|
||||
# Currently this doesn't introduce any actual dependency on IREE, so add it
|
||||
# unconditionally.
|
||||
# TODO: Put this behind the NPCOMP_ENABLE_IREE flag.
|
||||
add_subdirectory(IREE)
|
|
@ -0,0 +1,5 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(NPCOMPCommonBackendPassIncGen)
|
||||
|
||||
add_mlir_doc(Passes -gen-pass-doc CommonBackendPasses ./)
|
|
@ -0,0 +1,27 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_BACKEND_COMMON_PASSES_H
|
||||
#define NPCOMP_BACKEND_COMMON_PASSES_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace CommonBackend {
|
||||
/// Registers all CommonBackend passes.
|
||||
void registerCommonBackendPasses();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
|
||||
|
||||
} // namespace CommonBackend
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_BACKEND_COMMON_PASSES_H
|
|
@ -0,0 +1,19 @@
|
|||
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_BACKEND_COMMON_PASSES
|
||||
#define NPCOMP_BACKEND_COMMON_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def VerifyBackendContract : Pass<"npcomp-verify-backend-contract", "ModuleOp"> {
|
||||
let summary = "Verifies conformity to the backend contract that npcomp targets";
|
||||
let constructor = "mlir::NPCOMP::CommonBackend::createVerifyBackendContractPass()";
|
||||
}
|
||||
|
||||
#endif // NPCOMP_BACKEND_COMMON_PASSES
|
|
@ -0,0 +1,5 @@
|
|||
# Common backend utilities
|
||||
|
||||
This directory contains passes/transformations/analyses/etc. that are relevant
|
||||
to the backend contract that npcomp targets, but otherwise independent of any
|
||||
particular backend.
|
|
@ -1,8 +1,4 @@
|
|||
# Currently this doesn't introduce any actual dependency on IREE, so add it
|
||||
# unconditionally.
|
||||
# TODO: Put this behind the NPCOMP_ENABLE_IREE flag.
|
||||
add_subdirectory(Backend/IREE)
|
||||
|
||||
add_subdirectory(Backend)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(RefBackend)
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
add_subdirectory(Common)
|
||||
|
||||
if(NPCOMP_ENABLE_REFJIT)
|
||||
add_subdirectory(RefJIT)
|
||||
endif()
|
||||
|
||||
# Currently this doesn't introduce any actual dependency on IREE, so add it
|
||||
# unconditionally.
|
||||
# TODO: Put this behind the NPCOMP_ENABLE_IREE flag.
|
||||
add_subdirectory(IREE)
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
add_npcomp_library(NPCOMPCommonBackend
|
||||
VerifyBackendContract.cpp
|
||||
Passes.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SRC_DIR}/include/npcomp/Backend/Common
|
||||
|
||||
DEPENDS
|
||||
NPCOMPCommonBackendPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLinalg
|
||||
MLIRMemRef
|
||||
MLIRStandard
|
||||
MLIRMath
|
||||
)
|
||||
|
||||
mlir_check_all_link_libraries(NPCOMPCommonBackend)
|
|
@ -0,0 +1,25 @@
|
|||
//===- PassDetail.h - Pass class details ------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef BACKEND_COMMON_PASSDETAIL_H
|
||||
#define BACKEND_COMMON_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace CommonBackend {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "npcomp/Backend/Common/Passes.h.inc"
|
||||
|
||||
} // namespace CommonBackend
|
||||
} // namespace NPCOMP
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // BACKEND_COMMON_PASSDETAIL_H
|
|
@ -0,0 +1,25 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "npcomp/Backend/Common/Passes.h"
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::CommonBackend;
|
||||
|
||||
namespace {
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "npcomp/Backend/Common/Passes.h.inc"
|
||||
} // end namespace
|
||||
|
||||
void mlir::NPCOMP::CommonBackend::registerCommonBackendPasses() {
|
||||
::registerPasses();
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Backend/Common/Passes.h"
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::CommonBackend;
|
||||
|
||||
namespace {
|
||||
class VerifyBackendContractPass
|
||||
: public VerifyBackendContractBase<VerifyBackendContractPass> {
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
auto module = getOperation();
|
||||
TypeConverter converter;
|
||||
converter.addConversion([](RankedTensorType type) -> Type {
|
||||
if (BaseMemRefType::isValidElementType(type.getElementType()))
|
||||
return type;
|
||||
return nullptr;
|
||||
});
|
||||
TypeConverter scalarConverter;
|
||||
for (TypeConverter *c : {&converter, &scalarConverter}) {
|
||||
c->addConversion([](FloatType type) { return type; });
|
||||
c->addConversion([](IntegerType type) { return type; });
|
||||
c->addConversion([](IndexType type) { return type; });
|
||||
}
|
||||
|
||||
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
|
||||
auto isLegalScalarOp = [&](Operation *op) {
|
||||
// We recognize basic scalar ops by them having the trait "Elementwise",
|
||||
// even though we don't expect them to operate on tensors.
|
||||
return scalarConverter.isLegal(op) &&
|
||||
op->hasTrait<OpTrait::Elementwise>();
|
||||
};
|
||||
|
||||
ConversionTarget target(*context);
|
||||
|
||||
// Structural operations.
|
||||
target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
|
||||
|
||||
// Basic scalar operations.
|
||||
target.addDynamicallyLegalDialect<StandardOpsDialect>(isLegalScalarOp);
|
||||
target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
|
||||
|
||||
// Tensor operations should go through linalg.
|
||||
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
|
||||
// DimOp is used to query tensor sizes.
|
||||
target.addDynamicallyLegalOp<memref::DimOp>(opHasLegalTypes);
|
||||
|
||||
// AssertOp is used to terminate the program for error guards.
|
||||
target.addLegalOp<AssertOp>();
|
||||
// ConstantOp is used for tensors and for scalars.
|
||||
target.addDynamicallyLegalOp<ConstantOp>(opHasLegalTypes);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||
module.emitError()
|
||||
<< "Module does not conform to npcomp's backend contract. See "
|
||||
"dialect conversion legality information above.";
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::CommonBackend::createVerifyBackendContractPass() {
|
||||
return std::make_unique<VerifyBackendContractPass>();
|
||||
}
|
|
@ -27,6 +27,7 @@ void npcompRegisterAllPasses() {
|
|||
|
||||
// Upstream passes we depend on.
|
||||
::mlir::registerSymbolDCEPass();
|
||||
::mlir::registerInlinerPass();
|
||||
::mlir::registerCanonicalizerPass();
|
||||
::mlir::registerSCFToStandardPass();
|
||||
::mlir::registerConvertElementwiseToLinalgPass();
|
||||
|
|
|
@ -1,18 +1,10 @@
|
|||
add_subdirectory(Backend)
|
||||
add_subdirectory(CAPI)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(RefBackend)
|
||||
add_subdirectory(Typing)
|
||||
|
||||
if(NPCOMP_ENABLE_REFJIT)
|
||||
add_subdirectory(Backend/RefJIT)
|
||||
endif()
|
||||
|
||||
# Currently this doesn't introduce any actual dependency on IREE, so add it
|
||||
# unconditionally.
|
||||
# TODO: Put this behind the NPCOMP_ENABLE_IREE flag.
|
||||
add_subdirectory(Backend/IREE)
|
||||
|
||||
################################################################################
|
||||
# Setup the initialization target.
|
||||
# This includes conditional dependencies based on whether features are enabled.
|
||||
|
@ -34,6 +26,7 @@ add_npcomp_library(NPCOMPInitAll
|
|||
|
||||
PUBLIC
|
||||
# Local depends
|
||||
NPCOMPCommonBackend
|
||||
NPCOMPIREEBackend
|
||||
NPCOMPRefBackend
|
||||
NPCOMPRefbackDialect
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "npcomp/InitAll.h"
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "npcomp/Backend/Common/Passes.h"
|
||||
#include "npcomp/Backend/IREE/Passes.h"
|
||||
#include "npcomp/Conversion/Passes.h"
|
||||
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||
|
@ -52,4 +53,5 @@ void mlir::NPCOMP::registerAllPasses() {
|
|||
mlir::NPCOMP::registerTorchPasses();
|
||||
mlir::NPCOMP::registerTypingPasses();
|
||||
mlir::NPCOMP::IREEBackend::registerIREEBackendPasses();
|
||||
mlir::NPCOMP::CommonBackend::registerCommonBackendPasses();
|
||||
}
|
||||
|
|
|
@ -29,6 +29,10 @@ OBJECT_GRAPH_LOWERING_PASSES = (
|
|||
# bothersome because we don't currently have a lowering for them.
|
||||
# TODO: Support global slots in backends.
|
||||
"symbol-dce",
|
||||
# Currently, our shape inference is not powerful enough to deal with
|
||||
# calls, so inline everything.
|
||||
# TODO: Improve shape inference.
|
||||
"inline",
|
||||
# Incorporate user annotations and remove signature Python-isms.
|
||||
"torch-adjust-calling-conventions",
|
||||
)
|
||||
|
@ -65,6 +69,7 @@ TORCH_TO_TCP_PASSES = (
|
|||
"func(convert-aten-to-tcf)",
|
||||
"func(convert-tcf-to-std)",
|
||||
"func(convert-elementwise-to-linalg)",
|
||||
"npcomp-verify-backend-contract",
|
||||
)
|
||||
|
||||
def lower_module(imported_module: Module):
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
// RUN: npcomp-opt -npcomp-verify-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
|
||||
|
||||
// CHECK: func @mm
|
||||
func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> attributes {iree.module.export} {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
|
||||
%2 = memref.dim %arg1, %c0 : tensor<?x?xf32>
|
||||
%3 = memref.dim %arg1, %c1 : tensor<?x?xf32>
|
||||
%4 = cmpi eq, %1, %2 : index
|
||||
assert %4, "mismatching contracting dimension for aten.mm"
|
||||
%5 = linalg.init_tensor [%0, %3] : tensor<?x?xf32>
|
||||
%6 = linalg.fill(%5, %cst) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
|
||||
%7 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%6 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %7 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Basic check of error reporting.
|
||||
|
||||
// expected-error@+1 {{Module does not conform to npcomp's backend contract.}}
|
||||
module {
|
||||
func @disallowed() {
|
||||
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
|
||||
"unknown_dialect.unknown_op"() : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// TODO: Improve these errors to give more exact reporting.
|
||||
//
|
||||
// The reporting we inherit from dialect conversion is not precise.
|
||||
// For example, here we want it to explicitly call out that
|
||||
// `tensor<?x!numpy.any_dtype>` is the problem here, which suggests
|
||||
// that type inference didn't succeed, or insufficient type information
|
||||
// was available.
|
||||
//
|
||||
// Ultimately, the output of this pass needs to be conveyed to the user
|
||||
// in an understandable way, such as suggesting a particular place where
|
||||
// a shape annotation is needed.
|
||||
|
||||
// expected-error@+1 {{Module does not conform to npcomp's backend contract.}}
|
||||
module {
|
||||
func @disallowed(%arg0: tensor<?x!numpy.any_dtype>) -> tensor<?x!numpy.any_dtype> {
|
||||
// expected-error@+1 {{failed to legalize operation 'std.return'}}
|
||||
return %arg0 : tensor<?x!numpy.any_dtype>
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue