mirror of https://github.com/llvm/torch-mlir
Move single-tensor-tuple-return test to mlir unit test.
Also, add multiple return test.pull/1571/head
parent
4f173c6e0f
commit
a7ac0def45
|
@ -14,7 +14,6 @@ from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
|
|||
|
||||
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
||||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"SingleTensorTupleReturn_basic",
|
||||
}
|
||||
|
||||
EAGER_MODE_XFAIL_SET = {
|
||||
|
|
|
@ -3009,23 +3009,6 @@ def AtenToDeviceModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class SingleTensorTupleReturn(torch.nn.Module):
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1 , -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return (x,)
|
||||
|
||||
@register_test_case(module_factory=lambda: SingleTensorTupleReturn())
|
||||
def SingleTensorTupleReturn_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UpSampleNearest2dBackwardVec(torch.nn.Module):
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||
|
@ -97,3 +97,20 @@ func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vte
|
|||
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
|
||||
return %0 : !torch.tuple<tensor, tensor>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Single tensor tuple return
|
||||
// expected-error @+1 {{Functions must return}}
|
||||
func.func @single_tensor_tuple_return(%arg0: !torch.tensor) -> !torch.tuple<tensor> {
|
||||
%0 = torch.prim.TupleConstruct %arg0 : !torch.tensor -> !torch.tuple<tensor>
|
||||
return %0 : !torch.tuple<tensor>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Multiple, non-tuple return
|
||||
// expected-error @+1 {{should only ever return one item}}
|
||||
func.func @multiple_non_tuple_return(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
|
||||
return %arg0, %arg0 : !torch.tensor, !torch.tensor
|
||||
}
|
Loading…
Reference in New Issue