torch: add pass to catch non-value tensors (#1052)

This patch adds a new pass `torch-verify-conversion-to-value-semantics`,
which looks for non-value semantics tensors to catch such tensors early
during compilation.

This pass requires `torch-refine-public-return` pass to ensure that
return operations are updated to use value tensors, followed by the
canonicalize pass to remove any dead ops that may use or produce
non-value tensors.
pull/1053/head snapshot-20220714.533
Ashay Rane 2022-07-13 17:11:15 -07:00 committed by GitHub
parent 64c04bd5f6
commit 29bc48aedb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 96 additions and 0 deletions

View File

@ -77,6 +77,9 @@ createSimplifyShapeCalculationsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyConversionToValueSemanticsPass();
StringRef getShapeLibrary();
} // namespace Torch

View File

@ -253,4 +253,16 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"
}];
}
def VerifyConversionToValueSemantics
: Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> {
let summary = "Verify that all tensors have been converted to value semantics";
let constructor =
"mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()";
let description = [{
Prior passes in the pipeline may have missed converting all tensors to value
semantics and we wish to catch such failures early instead of fixing
individual cases downstream.
}];
}
#endif // TORCHMLIR_TORCH_PASSES

View File

@ -13,6 +13,7 @@ add_mlir_library(TorchMLIRTorchPasses
ReifyShapeCalculations.cpp
ShapeLibrary.cpp
SimplifyShapeCalculations.cpp
VerifyConversionToValueSemantics.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms

View File

@ -126,6 +126,13 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
// Update the return op to return value tensors and remove dead ops.
pm.addPass(Torch::createRefinePublicReturnPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Ensure that all tensors have been converted to value semantics.
pm.addPass(Torch::createVerifyConversionToValueSemanticsPass());
// Do shape refinement.
// This must be run before RefineTypes (which primarily does dtype inference),
// because Torch type promotion rules actually depend on the shape of the

View File

@ -0,0 +1,64 @@
//===- VerifyConversionToValueSemantics.cpp ----------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/IR/BuiltinOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch::Torch;
static LogicalResult checkValueType(Operation *op, Value value) {
auto isNotValueTensorType = value.getType().isa<NonValueTensorType>();
return isNotValueTensorType
? op->emitError(
"found a non-value tensor type, this is likely due to a "
"missing case in the MaximizeValueSemantics pass")
: success();
}
namespace {
class VerifyConversionToValueSemanticsPass
: public VerifyConversionToValueSemanticsBase<
VerifyConversionToValueSemanticsPass> {
void runOnOperation() override {
bool didFail = false;
auto walkResult = getOperation().walk([&](Block *block) {
for (BlockArgument arg : block->getArguments()) {
if (failed(checkValueType(block->getParentOp(), arg))) {
didFail = true;
return WalkResult::interrupt();
}
}
for (Operation &op : *block) {
for (OpResult result : op.getResults()) {
if (failed(checkValueType(&op, result))) {
didFail = true;
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
if (didFail || walkResult.wasInterrupted())
signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyConversionToValueSemanticsPass() {
return std::make_unique<VerifyConversionToValueSemanticsPass>();
}

View File

@ -0,0 +1,9 @@
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-verify-conversion-to-value-semantics
// -----
func.func @result_is_non_value_tensor(%arg: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
// @expected-error@+1 {{found a non-value tensor type, this is likely due to a missing case in the MaximizeValueSemantics pass}}
%neg = torch.aten.neg %arg : !torch.vtensor<[2],f32> -> !torch.tensor
return %arg : !torch.vtensor<[2],f32>
}