mirror of https://github.com/llvm/torch-mlir
Clean up verification of calling conventions.
The implementation at this place was a remnent of the times the pipeline was run only once. Rely instead on the backend verification, after optimizations have had an opportunity to resolve some uncertainties. (e.g. `!torch.optional`).pull/2328/head
parent
91a9baa3e7
commit
4847563bed
|
@ -187,53 +187,8 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static bool isValidNonContainerResultType(Type resultType) {
|
|
||||||
return resultType.isa<Torch::BaseTensorType>() ||
|
|
||||||
resultType.isa<Torch::FloatType>() ||
|
|
||||||
resultType.isa<Torch::IntType>() ||
|
|
||||||
resultType.isa<Torch::BoolType>() ||
|
|
||||||
resultType.isa<Torch::NoneType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult validateReturns(func::FuncOp func) {
|
|
||||||
if (func.getResultTypes().size() > 1) {
|
|
||||||
return func->emitError(
|
|
||||||
"Functions directly imported from Python should only ever return one "
|
|
||||||
"item. Multiple return values are returned as a tuple.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allow returns of nothing. This shouldn't be possible from Python, but it
|
|
||||||
// can happen in IR that's been directly constructed.
|
|
||||||
if (func.getResultTypes().size() == 0)
|
|
||||||
return success();
|
|
||||||
|
|
||||||
const auto& resultType = func.getResultTypes().front();
|
|
||||||
|
|
||||||
// Allow single tensor, scalar, and bool returns
|
|
||||||
if (isValidNonContainerResultType(resultType)) {
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allow multi-tensor/scalar/bool tuple returns
|
|
||||||
if (auto tuple = resultType.dyn_cast<Torch::TupleType>()) {
|
|
||||||
const auto& containedTypes = tuple.getContainedTypes();
|
|
||||||
bool containsValidTypes = llvm::all_of(
|
|
||||||
tuple.getContainedTypes(), isValidNonContainerResultType);
|
|
||||||
if (containedTypes.size() >= 2 && containsValidTypes) {
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return func->emitError(
|
|
||||||
"Functions must return a single tensor-like value, multiple tensor-like "
|
|
||||||
"values, or a tuple of more than one tensor-like value. Tensor-like "
|
|
||||||
"values: tensors, scalars, bools, and Nones.");
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult adjustCallingConventions(func::FuncOp func,
|
static LogicalResult adjustCallingConventions(func::FuncOp func,
|
||||||
TypeBoundMap &typeBoundMap) {
|
TypeBoundMap &typeBoundMap) {
|
||||||
if (failed(validateReturns(func)))
|
|
||||||
return failure();
|
|
||||||
MLIRContext *context = func.getContext();
|
MLIRContext *context = func.getContext();
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
#include "llvm/ADT/StringSet.h"
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "torch-lower-to-backend-contract"
|
#define DEBUG_TYPE "torch-lower-to-backend-contract"
|
||||||
|
|
||||||
|
|
|
@ -97,20 +97,3 @@ func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vte
|
||||||
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
|
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
|
||||||
return %0 : !torch.tuple<tensor, tensor>
|
return %0 : !torch.tuple<tensor, tensor>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Single tensor tuple return
|
|
||||||
// expected-error @+1 {{Functions must return}}
|
|
||||||
func.func @single_tensor_tuple_return(%arg0: !torch.tensor) -> !torch.tuple<tensor> {
|
|
||||||
%0 = torch.prim.TupleConstruct %arg0 : !torch.tensor -> !torch.tuple<tensor>
|
|
||||||
return %0 : !torch.tuple<tensor>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Multiple, non-tuple return
|
|
||||||
// expected-error @+1 {{should only ever return one item}}
|
|
||||||
func.func @multiple_non_tuple_return(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
|
|
||||||
return %arg0, %arg0 : !torch.tensor, !torch.tensor
|
|
||||||
}
|
|
|
@ -1,7 +1,36 @@
|
||||||
// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s
|
// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s
|
||||||
|
|
||||||
func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
|
func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
|
||||||
// expected-error @below {{unsupported by backend contract: tensor with unknown rank}}
|
// expected-error @below {{unsupported by backend contract: tensor with unknown rank}}
|
||||||
// expected-note @below {{this is likely due to a missing transfer function}}
|
// expected-note @below {{this is likely due to a missing transfer function}}
|
||||||
%t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor
|
%t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor
|
||||||
return %t : !torch.vtensor
|
return %t : !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// expected-error @below {{invalid dtype 'i9'}}
|
||||||
|
func.func @bad_element_type(%arg: !torch.vtensor<[?],i9>) -> !torch.vtensor<[?],i9> {
|
||||||
|
return %arg : !torch.vtensor<[?],i9>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// expected-error @below {{unsupported by backend contract: non-value tensor type}}
|
||||||
|
// expected-note @below {{this is likely due to a missing case in the MaximizeValueSemantics pass}}
|
||||||
|
func.func @non_value_tensor(%arg0: !torch.tensor) -> !torch.tensor {
|
||||||
|
return %arg0 : !torch.tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @valid_tuple(%arg0: !torch.vtensor<[?],f32>) -> !torch.tuple<vtensor<[?],f32>> {
|
||||||
|
%0 = torch.prim.TupleConstruct %arg0 : !torch.vtensor<[?],f32> -> !torch.tuple<vtensor<[?],f32>>
|
||||||
|
return %0 : !torch.tuple<vtensor<[?],f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @valid_multiple_ret_values(%arg0: !torch.vtensor<[?],f32>) -> (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) {
|
||||||
|
return %arg0, %arg0 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue