Clean up verification of calling conventions.

The implementation at this place was a remnent of the times the pipeline was
run only once.
Rely instead on the backend verification, after optimizations have had an
opportunity to resolve some uncertainties. (e.g. `!torch.optional`).
pull/2328/head
Alexandre Rames 2023-07-18 07:32:26 -07:00 committed by Alexandre Rames
parent 91a9baa3e7
commit 4847563bed
4 changed files with 30 additions and 63 deletions

View File

@ -187,53 +187,8 @@ public:
}; };
} // namespace } // namespace
static bool isValidNonContainerResultType(Type resultType) {
return resultType.isa<Torch::BaseTensorType>() ||
resultType.isa<Torch::FloatType>() ||
resultType.isa<Torch::IntType>() ||
resultType.isa<Torch::BoolType>() ||
resultType.isa<Torch::NoneType>();
}
static LogicalResult validateReturns(func::FuncOp func) {
if (func.getResultTypes().size() > 1) {
return func->emitError(
"Functions directly imported from Python should only ever return one "
"item. Multiple return values are returned as a tuple.");
}
// Allow returns of nothing. This shouldn't be possible from Python, but it
// can happen in IR that's been directly constructed.
if (func.getResultTypes().size() == 0)
return success();
const auto& resultType = func.getResultTypes().front();
// Allow single tensor, scalar, and bool returns
if (isValidNonContainerResultType(resultType)) {
return success();
}
// Allow multi-tensor/scalar/bool tuple returns
if (auto tuple = resultType.dyn_cast<Torch::TupleType>()) {
const auto& containedTypes = tuple.getContainedTypes();
bool containsValidTypes = llvm::all_of(
tuple.getContainedTypes(), isValidNonContainerResultType);
if (containedTypes.size() >= 2 && containsValidTypes) {
return success();
}
}
return func->emitError(
"Functions must return a single tensor-like value, multiple tensor-like "
"values, or a tuple of more than one tensor-like value. Tensor-like "
"values: tensors, scalars, bools, and Nones.");
}
static LogicalResult adjustCallingConventions(func::FuncOp func, static LogicalResult adjustCallingConventions(func::FuncOp func,
TypeBoundMap &typeBoundMap) { TypeBoundMap &typeBoundMap) {
if (failed(validateReturns(func)))
return failure();
MLIRContext *context = func.getContext(); MLIRContext *context = func.getContext();
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
TypeConverter typeConverter; TypeConverter typeConverter;

View File

@ -17,8 +17,8 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "torch-lower-to-backend-contract" #define DEBUG_TYPE "torch-lower-to-backend-contract"

View File

@ -97,20 +97,3 @@ func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vte
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor> %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
return %0 : !torch.tuple<tensor, tensor> return %0 : !torch.tuple<tensor, tensor>
} }
// -----
// Single tensor tuple return
// expected-error @+1 {{Functions must return}}
func.func @single_tensor_tuple_return(%arg0: !torch.tensor) -> !torch.tuple<tensor> {
%0 = torch.prim.TupleConstruct %arg0 : !torch.tensor -> !torch.tuple<tensor>
return %0 : !torch.tuple<tensor>
}
// -----
// Multiple, non-tuple return
// expected-error @+1 {{should only ever return one item}}
func.func @multiple_non_tuple_return(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
return %arg0, %arg0 : !torch.tensor, !torch.tensor
}

View File

@ -1,7 +1,36 @@
// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s // RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s
func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor { func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
// expected-error @below {{unsupported by backend contract: tensor with unknown rank}} // expected-error @below {{unsupported by backend contract: tensor with unknown rank}}
// expected-note @below {{this is likely due to a missing transfer function}} // expected-note @below {{this is likely due to a missing transfer function}}
%t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor %t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor
return %t : !torch.vtensor return %t : !torch.vtensor
} }
// -----
// expected-error @below {{invalid dtype 'i9'}}
func.func @bad_element_type(%arg: !torch.vtensor<[?],i9>) -> !torch.vtensor<[?],i9> {
return %arg : !torch.vtensor<[?],i9>
}
// -----
// expected-error @below {{unsupported by backend contract: non-value tensor type}}
// expected-note @below {{this is likely due to a missing case in the MaximizeValueSemantics pass}}
func.func @non_value_tensor(%arg0: !torch.tensor) -> !torch.tensor {
return %arg0 : !torch.tensor
}
// -----
func.func @valid_tuple(%arg0: !torch.vtensor<[?],f32>) -> !torch.tuple<vtensor<[?],f32>> {
%0 = torch.prim.TupleConstruct %arg0 : !torch.vtensor<[?],f32> -> !torch.tuple<vtensor<[?],f32>>
return %0 : !torch.tuple<vtensor<[?],f32>>
}
// -----
func.func @valid_multiple_ret_values(%arg0: !torch.vtensor<[?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) {
return %arg0, %arg0 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>
}