[cleanup] Make diagnostics better

Also remove some unused imports.
pull/1607/head snapshot-20221117.660
Sean Silva 2022-11-15 15:25:39 +00:00
parent 5f7177da35
commit 39de4d6265
4 changed files with 15 additions and 8 deletions

View File

@ -142,9 +142,12 @@ static LogicalResult mungeFunction(
SmallVector<Type> newArgTypes; SmallVector<Type> newArgTypes;
for (auto arg : func.getArguments()) { for (auto arg : func.getArguments()) {
auto type = arg.getType(); auto type = arg.getType();
if (!isArgMemRefTypeValid(type)) if (!isArgMemRefTypeValid(type)) {
return emitError(arg.getLoc(), return emitError(arg.getLoc())
"argument must be a memref of f32, f64, i32, i64, i8, i1"); .append("argument must be a memref of f32, f64, i32, i64, i8, i1 but "
"got ",
type);
}
auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg); auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
arg.replaceAllUsesExcept(cast, cast); arg.replaceAllUsesExcept(cast, cast);
arg.setType(getAbiTypeForMemRef(type)); arg.setType(getAbiTypeForMemRef(type));

View File

@ -3,9 +3,6 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE. # Also available under a BSD-style license. See LICENSE.
import copy
from typing import Any
import torch import torch
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem

View File

@ -22,7 +22,7 @@ __all__ = [
def assert_arg_type_is_supported(ty): def assert_arg_type_is_supported(ty):
SUPPORTED = [np.float32, np.float64, np.uint8, np.int8, np.int32, np.int64, np.bool_] SUPPORTED = [np.float32, np.float64, np.uint8, np.int8, np.int32, np.int64, np.bool_]
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported" assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}"
memref_type_to_np_dtype = { memref_type_to_np_dtype = {

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s // RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @f( // CHECK-LABEL: func.func @f(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} { // CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
@ -70,3 +70,10 @@ func.func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %a
func.func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) { func.func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) {
return %arg0 ,%arg1 : memref<?xf32>, memref<?xi64> return %arg0 ,%arg1 : memref<?xf32>, memref<?xi64>
} }
// -----
// expected-error-re @+1 {{argument must be a memref of {{.*}} but got 'tensor<?xf32>'}}
func.func @f(%arg0: tensor<?xf32>) {
return
}