Move single-tensor-tuple-return test to mlir unit test.

Also, add multiple return test.
pull/1571/head
Daniel Ellis 2022-11-10 09:23:53 -05:00 committed by GitHub
parent 4f173c6e0f
commit a7ac0def45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 19 deletions

View File

@ -14,7 +14,6 @@ from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"UpSampleNearest2dDynamicFactor_basic",
"SingleTensorTupleReturn_basic",
}
EAGER_MODE_XFAIL_SET = {

View File

@ -3009,23 +3009,6 @@ def AtenToDeviceModule_basic(module, tu: TestUtils):
# ==============================================================================
class SingleTensorTupleReturn(torch.nn.Module):
@export
@annotate_args([
None,
([-1 , -1], torch.float32, True),
])
def forward(self, x):
return (x,)
@register_test_case(module_factory=lambda: SingleTensorTupleReturn())
def SingleTensorTupleReturn_basic(module, tu: TestUtils):
module.forward(torch.randn(2, 4))
# ==============================================================================
class UpSampleNearest2dBackwardVec(torch.nn.Module):

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
@ -97,3 +97,20 @@ func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vte
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
return %0 : !torch.tuple<tensor, tensor>
}
// -----
// Single tensor tuple return
// expected-error @+1 {{Functions must return}}
func.func @single_tensor_tuple_return(%arg0: !torch.tensor) -> !torch.tuple<tensor> {
%0 = torch.prim.TupleConstruct %arg0 : !torch.tensor -> !torch.tuple<tensor>
return %0 : !torch.tuple<tensor>
}
// -----
// Multiple, non-tuple return
// expected-error @+1 {{should only ever return one item}}
func.func @multiple_non_tuple_return(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
return %arg0, %arg0 : !torch.tensor, !torch.tensor
}