mirror of https://github.com/llvm/torch-mlir
torch: add pass to catch non-value tensors (#1052)
This patch adds a new pass, `torch-verify-conversion-to-value-semantics`, which scans for tensors that still have non-value semantics so that such failures are caught early in compilation. The pass requires the `torch-refine-public-return` pass to run first, ensuring that return operations are updated to use value tensors, followed by the canonicalize pass to remove any dead ops that may still use or produce non-value tensors.

parent 64c04bd5f6
commit 29bc48aedb

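For readers unfamiliar with the two tensor types involved, the check boils down to a type test: `!torch.tensor` corresponds to `Torch::NonValueTensorType`, while `!torch.vtensor` corresponds to `Torch::ValueTensorType`. The snippet below is a minimal illustrative sketch (not part of the patch; the helper name is hypothetical) of that distinction, using the same type API the new pass includes.

// Illustrative sketch only -- not part of this patch. The helper name is
// hypothetical; the type API comes from TorchTypes.h, which the new pass
// also includes.
#include "mlir/IR/Value.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

// Returns true when `v` already has value semantics (!torch.vtensor).
// The new pass rejects the complementary case, NonValueTensorType.
static bool hasValueSemantics(mlir::Value v) {
  return v.getType().isa<mlir::torch::Torch::ValueTensorType>();
}
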
@@ -77,6 +77,9 @@ createSimplifyShapeCalculationsPass();
 
 std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
 
+std::unique_ptr<OperationPass<ModuleOp>>
+createVerifyConversionToValueSemanticsPass();
+
 StringRef getShapeLibrary();
 
 } // namespace Torch

@@ -253,4 +253,16 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"
   }];
 }
 
+def VerifyConversionToValueSemantics
+    : Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> {
+  let summary = "Verify that all tensors have been converted to value semantics";
+  let constructor =
+      "mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()";
+  let description = [{
+    Prior passes in the pipeline may have missed converting all tensors to value
+    semantics and we wish to catch such failures early instead of fixing
+    individual cases downstream.
+  }];
+}
+
 #endif // TORCHMLIR_TORCH_PASSES

@@ -13,6 +13,7 @@ add_mlir_library(TorchMLIRTorchPasses
   ReifyShapeCalculations.cpp
   ShapeLibrary.cpp
   SimplifyShapeCalculations.cpp
+  VerifyConversionToValueSemantics.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms

@@ -126,6 +126,13 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
   // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
   pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
 
+  // Update the return op to return value tensors and remove dead ops.
+  pm.addPass(Torch::createRefinePublicReturnPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+
+  // Ensure that all tensors have been converted to value semantics.
+  pm.addPass(Torch::createVerifyConversionToValueSemanticsPass());
+
   // Do shape refinement.
   // This must be run before RefineTypes (which primarily does dtype inference),
   // because Torch type promotion rules actually depend on the shape of the

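Downstream users who assemble their own pass pipelines rather than calling `createTorchFunctionToTorchBackendPipeline` can reproduce the same ordering directly on a `PassManager`. A minimal sketch, assuming the standard MLIR pass-manager API and the factory functions declared earlier in this patch; the driver function itself is hypothetical.

// Hypothetical driver -- a sketch of scheduling the same sequence as the
// pipeline change above on a standalone PassManager.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

static mlir::LogicalResult verifyValueSemantics(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Convert as many !torch.tensor values as possible to !torch.vtensor.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::torch::Torch::createMaximizeValueSemanticsPass());
  // Rewrite public returns to value tensors, then clean up dead ops.
  pm.addPass(mlir::torch::Torch::createRefinePublicReturnPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  // Fail compilation if any non-value tensor survived.
  pm.addPass(mlir::torch::Torch::createVerifyConversionToValueSemanticsPass());
  return pm.run(module);
}
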
@@ -0,0 +1,64 @@
+//===- VerifyConversionToValueSemantics.cpp ----------------------*- C++-*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/IR/BuiltinOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::torch::Torch;
+
+static LogicalResult checkValueType(Operation *op, Value value) {
+  auto isNotValueTensorType = value.getType().isa<NonValueTensorType>();
+  return isNotValueTensorType
+             ? op->emitError(
+                   "found a non-value tensor type, this is likely due to a "
+                   "missing case in the MaximizeValueSemantics pass")
+             : success();
+}
+
+namespace {
+class VerifyConversionToValueSemanticsPass
+    : public VerifyConversionToValueSemanticsBase<
+          VerifyConversionToValueSemanticsPass> {
+  void runOnOperation() override {
+    bool didFail = false;
+    auto walkResult = getOperation().walk([&](Block *block) {
+      for (BlockArgument arg : block->getArguments()) {
+        if (failed(checkValueType(block->getParentOp(), arg))) {
+          didFail = true;
+          return WalkResult::interrupt();
+        }
+      }
+
+      for (Operation &op : *block) {
+        for (OpResult result : op.getResults()) {
+          if (failed(checkValueType(&op, result))) {
+            didFail = true;
+            return WalkResult::interrupt();
+          }
+        }
+      }
+
+      return WalkResult::advance();
+    });
+
+    if (didFail || walkResult.wasInterrupted())
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::torch::Torch::createVerifyConversionToValueSemanticsPass() {
+  return std::make_unique<VerifyConversionToValueSemanticsPass>();
+}

@@ -0,0 +1,9 @@
+// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-verify-conversion-to-value-semantics
+
+// -----
+
+func.func @result_is_non_value_tensor(%arg: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
+  // expected-error@+1 {{found a non-value tensor type, this is likely due to a missing case in the MaximizeValueSemantics pass}}
+  %neg = torch.aten.neg %arg : !torch.vtensor<[2],f32> -> !torch.tensor
+  return %arg : !torch.vtensor<[2],f32>
+}

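The lit test above exercises the pass through torch-mlir-opt with -verify-diagnostics; the same behavior can also be checked programmatically. A rough sketch under the assumptions that the Torch and func dialects are registered and that the 2022-era parser API (`mlir::parseSourceString`) is available; none of this code is part of the patch and the function name is hypothetical.

// Hypothetical sketch: parse IR that still produces a !torch.tensor and
// confirm that the verification pass rejects it.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

static bool rejectsNonValueTensors() {
  mlir::MLIRContext context;
  context.loadDialect<mlir::func::FuncDialect>();
  context.loadDialect<mlir::torch::Torch::TorchDialect>();

  // Same shape of IR as the lit test above: the negation result is a plain
  // !torch.tensor, so verification should fail.
  constexpr const char *ir = R"mlir(
    func.func @f(%arg: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
      %neg = torch.aten.neg %arg : !torch.vtensor<[2],f32> -> !torch.tensor
      return %arg : !torch.vtensor<[2],f32>
    }
  )mlir";

  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(ir, &context);
  if (!module)
    return false;

  mlir::PassManager pm(&context);
  pm.addPass(mlir::torch::Torch::createVerifyConversionToValueSemanticsPass());
  return mlir::failed(pm.run(*module));
}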