Remove VerifyInvariantsBeforeBackendLowering

LowerToBackendContract now checks all this consistently.
pull/1292/head
Sean Silva 2022-08-25 22:18:19 +00:00
parent b1fa7a2b9d
commit 0e3ddbac91
7 changed files with 1 additions and 161 deletions

View File

@ -41,9 +41,6 @@ void createTorchBackendToMhloBackendPipeline(
const torch::Torch::TorchLoweringPipelineOptions &options);
#endif
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyInvariantsBeforeBackendLoweringPass();
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
std::unique_ptr<OperationPass<func::FuncOp>>

View File

@ -12,27 +12,6 @@
include "mlir/Pass/PassBase.td"
def VerifyInvariantsBeforeBackendLowering
: Pass<"torch-verify-invariants-before-backend-lowering", "ModuleOp"> {
let summary = "Verify invariants required by backend lowering";
let constructor =
"mlir::torch::TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()";
let description = [{
This pass checks any invariants needed by the process of lowering the
`torch` dialect to the linalg-on-tensors backend contract.
The most important invariant is that all tensors should be ranked and have
a known dtype. It is useful to catch this early because it usually
represents a simple bug in RefineTypes, but can manifest as many different
kinds of obscure symptoms during lowering.
TODO: This pass should probably be phrased as checking the
"torch backend contract" and moved to that dialect once we have more
substantial definition definition around what that layer is from an
"allowlist" perspective.
}];
}
def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> {
let summary = "Convert functions to operate on builtin tensors";
let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionPass()";

View File

@ -18,7 +18,6 @@ add_mlir_library(TorchMLIRTorchConversionPasses
BackendTypeConversion.cpp
BackendTypeConversionPasses.cpp
Passes.cpp
VerifyInvariantsBeforeBackendLowering.cpp
VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp

View File

@ -63,10 +63,6 @@ void mlir::torch::registerTorchConversionPasses() {
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Check some invariants to catch errors in a clear way.
pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
// Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model)
@ -101,10 +97,6 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Check some invariants to catch errors in a clear way.
pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
@ -130,10 +122,6 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
#ifdef TORCH_MLIR_ENABLE_MHLO
void TorchConversion::createTorchBackendToMhloBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Check some invariants to catch errors in a clear way.
pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
// Clean up any non-canonical code introduced above..

View File

@ -1,86 +0,0 @@
//===- VerifyInvariantsBeforeBackendLowering.cpp -----------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch;
static LogicalResult checkValueInvariants(Operation *errorReportOp, Value v) {
// TODO: Make this an allowlist instead of a denylist.
// TODO: Make this stricter.
auto type = v.getType();
if (auto valueTensorType = type.dyn_cast<Torch::ValueTensorType>()) {
if (!valueTensorType.hasDtype() || !valueTensorType.hasSizes())
return errorReportOp->emitError()
.append("unsupported by backend lowering: tensor with unknown rank "
"or dtype")
.attachNote()
.append("this is likely due to a missing shape transfer function in "
"shape_lib_gen.py or missing case in RefineTypes");
}
return success();
}
namespace {
class VerifyInvariantsBeforeBackendLoweringPass
: public VerifyInvariantsBeforeBackendLoweringBase<
VerifyInvariantsBeforeBackendLoweringPass> {
void runOnOperation() override {
if (getOperation()
.walk([](Torch::GlobalSlotModuleInitializerOp op) {
op.emitError()
<< "unsupported by backend lowering: module initializers";
return WalkResult::interrupt();
})
.wasInterrupted())
return signalPassFailure();
auto walkResult = getOperation().walk([&](Block *block) {
// Check invariants on all the Value's in the program.
// That is, check all BlockArgument's and OpResult's.
for (BlockArgument arg : block->getArguments())
if (failed(checkValueInvariants(block->getParentOp(), arg)))
return WalkResult::interrupt();
for (Operation &op : *block) {
if (isa<Torch::OperatorOp>(op)) {
op.emitError()
.append("unsupported by backend lowering: `torch.operator` op")
.attachNote()
.append("this is likely due to a missing op that needs to be "
"generated by torch_ods_gen.py");
return WalkResult::interrupt();
}
for (OpResult result : op.getResults())
if (failed(checkValueInvariants(&op, result)))
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
createVerifyInvariantsBeforeBackendLoweringPass() {
return std::make_unique<VerifyInvariantsBeforeBackendLoweringPass>();
}

View File

@ -1,36 +0,0 @@
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-verify-invariants-before-backend-lowering
// -----
func.func @unknown_rank(%arg0: !torch.vtensor<[],f32>) {
// expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}}
// expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}}
%0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<*,f32>
return
}
// -----
func.func @unknown_dtype(%arg0: !torch.vtensor<[],f32>) {
// expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}}
// expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}}
%0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[],unk>
return
}
// -----
func.func @unresolved_operator(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int) {
// expected-error@+2 {{unsupported by backend lowering: `torch.operator` op}}
// expected-note@+1 {{this is likely due to a missing op that needs to be generated by torch_ods_gen.py}}
torch.operator "aten.mul.Scalar"(%arg0, %arg1) : (!torch.vtensor<[],f32>, !torch.int) -> !torch.vtensor<[],f32>
return
}
// -----
// expected-error@+1 {{unsupported by backend lowering: module initializers}}
torch.global_slot.module_initializer {
torch.initialize.global_slots [
]
}

View File

@ -500,7 +500,6 @@ cc_library(
"lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp",
"lib/Dialect/TorchConversion/Transforms/PassDetail.h",
"lib/Dialect/TorchConversion/Transforms/Passes.cpp",
"lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp",
"lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp",
"lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp",
],