mirror of https://github.com/llvm/torch-mlir
torch: add pass to catch non-value tensors (#1052)
This patch adds a new pass `torch-verify-conversion-to-value-semantics`, which looks for non-value semantics tensors to catch such tensors early during compilation. This pass requires `torch-refine-public-return` pass to ensure that return operations are updated to use value tensors, followed by the canonicalize pass to remove any dead ops that may use or produce non-value tensors.pull/1053/head snapshot-20220714.533
parent
64c04bd5f6
commit
29bc48aedb
|
@ -77,6 +77,9 @@ createSimplifyShapeCalculationsPass();
|
|||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createVerifyConversionToValueSemanticsPass();
|
||||
|
||||
StringRef getShapeLibrary();
|
||||
|
||||
} // namespace Torch
|
||||
|
|
|
@ -253,4 +253,16 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"
|
|||
}];
|
||||
}
|
||||
|
||||
def VerifyConversionToValueSemantics
|
||||
: Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> {
|
||||
let summary = "Verify that all tensors have been converted to value semantics";
|
||||
let constructor =
|
||||
"mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()";
|
||||
let description = [{
|
||||
Prior passes in the pipeline may have missed converting all tensors to value
|
||||
semantics and we wish to catch such failures early instead of fixing
|
||||
individual cases downstream.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TORCHMLIR_TORCH_PASSES
|
||||
|
|
|
@ -13,6 +13,7 @@ add_mlir_library(TorchMLIRTorchPasses
|
|||
ReifyShapeCalculations.cpp
|
||||
ShapeLibrary.cpp
|
||||
SimplifyShapeCalculations.cpp
|
||||
VerifyConversionToValueSemantics.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms
|
||||
|
|
|
@ -126,6 +126,13 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
|||
// Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
|
||||
|
||||
// Update the return op to return value tensors and remove dead ops.
|
||||
pm.addPass(Torch::createRefinePublicReturnPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
|
||||
// Ensure that all tensors have been converted to value semantics.
|
||||
pm.addPass(Torch::createVerifyConversionToValueSemanticsPass());
|
||||
|
||||
// Do shape refinement.
|
||||
// This must be run before RefineTypes (which primarily does dtype inference),
|
||||
// because Torch type promotion rules actually depend on the shape of the
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
//===- VerifyConversionToValueSemantics.cpp ----------------------*- C++-*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
static LogicalResult checkValueType(Operation *op, Value value) {
|
||||
auto isNotValueTensorType = value.getType().isa<NonValueTensorType>();
|
||||
return isNotValueTensorType
|
||||
? op->emitError(
|
||||
"found a non-value tensor type, this is likely due to a "
|
||||
"missing case in the MaximizeValueSemantics pass")
|
||||
: success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
class VerifyConversionToValueSemanticsPass
|
||||
: public VerifyConversionToValueSemanticsBase<
|
||||
VerifyConversionToValueSemanticsPass> {
|
||||
void runOnOperation() override {
|
||||
bool didFail = false;
|
||||
auto walkResult = getOperation().walk([&](Block *block) {
|
||||
for (BlockArgument arg : block->getArguments()) {
|
||||
if (failed(checkValueType(block->getParentOp(), arg))) {
|
||||
didFail = true;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
for (Operation &op : *block) {
|
||||
for (OpResult result : op.getResults()) {
|
||||
if (failed(checkValueType(&op, result))) {
|
||||
didFail = true;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (didFail || walkResult.wasInterrupted())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::Torch::createVerifyConversionToValueSemanticsPass() {
|
||||
return std::make_unique<VerifyConversionToValueSemanticsPass>();
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-verify-conversion-to-value-semantics
|
||||
|
||||
// -----
|
||||
|
||||
func.func @result_is_non_value_tensor(%arg: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
|
||||
// @expected-error@+1 {{found a non-value tensor type, this is likely due to a missing case in the MaximizeValueSemantics pass}}
|
||||
%neg = torch.aten.neg %arg : !torch.vtensor<[2],f32> -> !torch.tensor
|
||||
return %arg : !torch.vtensor<[2],f32>
|
||||
}
|
Loading…
Reference in New Issue