mirror of https://github.com/llvm/torch-mlir
[cleanup] Make diagnostics better
Also remove some unused imports.pull/1607/head snapshot-20221117.660
parent
5f7177da35
commit
39de4d6265
|
@ -142,9 +142,12 @@ static LogicalResult mungeFunction(
|
|||
SmallVector<Type> newArgTypes;
|
||||
for (auto arg : func.getArguments()) {
|
||||
auto type = arg.getType();
|
||||
if (!isArgMemRefTypeValid(type))
|
||||
return emitError(arg.getLoc(),
|
||||
"argument must be a memref of f32, f64, i32, i64, i8, i1");
|
||||
if (!isArgMemRefTypeValid(type)) {
|
||||
return emitError(arg.getLoc())
|
||||
.append("argument must be a memref of f32, f64, i32, i64, i8, i1 but "
|
||||
"got ",
|
||||
type);
|
||||
}
|
||||
auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
|
||||
arg.replaceAllUsesExcept(cast, cast);
|
||||
arg.setType(getAbiTypeForMemRef(type));
|
||||
|
|
|
@ -3,9 +3,6 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
|
|
|
@ -22,7 +22,7 @@ __all__ = [
|
|||
|
||||
def assert_arg_type_is_supported(ty):
|
||||
SUPPORTED = [np.float32, np.float64, np.uint8, np.int8, np.int32, np.int64, np.bool_]
|
||||
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
|
||||
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}"
|
||||
|
||||
|
||||
memref_type_to_np_dtype = {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s
|
||||
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @f(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
|
||||
|
@ -70,3 +70,10 @@ func.func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %a
|
|||
func.func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) {
|
||||
return %arg0 ,%arg1 : memref<?xf32>, memref<?xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error-re @+1 {{argument must be a memref of {{.*}} but got 'tensor<?xf32>'}}
|
||||
func.func @f(%arg0: tensor<?xf32>) {
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue