mirror of https://github.com/llvm/torch-mlir
[cleanup] Make diagnostics better
Also remove some unused imports.pull/1607/head snapshot-20221117.660
parent
5f7177da35
commit
39de4d6265
|
@ -142,9 +142,12 @@ static LogicalResult mungeFunction(
|
||||||
SmallVector<Type> newArgTypes;
|
SmallVector<Type> newArgTypes;
|
||||||
for (auto arg : func.getArguments()) {
|
for (auto arg : func.getArguments()) {
|
||||||
auto type = arg.getType();
|
auto type = arg.getType();
|
||||||
if (!isArgMemRefTypeValid(type))
|
if (!isArgMemRefTypeValid(type)) {
|
||||||
return emitError(arg.getLoc(),
|
return emitError(arg.getLoc())
|
||||||
"argument must be a memref of f32, f64, i32, i64, i8, i1");
|
.append("argument must be a memref of f32, f64, i32, i64, i8, i1 but "
|
||||||
|
"got ",
|
||||||
|
type);
|
||||||
|
}
|
||||||
auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
|
auto cast = b.create<memref::CastOp>(arg.getLoc(), type, arg);
|
||||||
arg.replaceAllUsesExcept(cast, cast);
|
arg.replaceAllUsesExcept(cast, cast);
|
||||||
arg.setType(getAbiTypeForMemRef(type));
|
arg.setType(getAbiTypeForMemRef(type));
|
||||||
|
|
|
@ -3,9 +3,6 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
import copy
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
|
@ -22,7 +22,7 @@ __all__ = [
|
||||||
|
|
||||||
def assert_arg_type_is_supported(ty):
|
def assert_arg_type_is_supported(ty):
|
||||||
SUPPORTED = [np.float32, np.float64, np.uint8, np.int8, np.int32, np.int64, np.bool_]
|
SUPPORTED = [np.float32, np.float64, np.uint8, np.int8, np.int32, np.int64, np.bool_]
|
||||||
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
|
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported, but got {ty}"
|
||||||
|
|
||||||
|
|
||||||
memref_type_to_np_dtype = {
|
memref_type_to_np_dtype = {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s
|
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @f(
|
// CHECK-LABEL: func.func @f(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
|
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
|
||||||
|
@ -70,3 +70,10 @@ func.func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %a
|
||||||
func.func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) {
|
func.func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) {
|
||||||
return %arg0 ,%arg1 : memref<?xf32>, memref<?xi64>
|
return %arg0 ,%arg1 : memref<?xf32>, memref<?xi64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// expected-error-re @+1 {{argument must be a memref of {{.*}} but got 'tensor<?xf32>'}}
|
||||||
|
func.func @f(%arg0: tensor<?xf32>) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue