From 72b8070e57a8018e50e5c5182b9b11f498b82ca9 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Wed, 31 May 2023 14:14:14 +0800 Subject: [PATCH] [Importer] import constant tuple (#2132) * [Importer] import constant tuple * update * update * update --- e2e_testing/xfail_sets.py | 3 +++ .../importer/jit_ir/csrc/node_importer.cpp | 7 ++++--- python/torch_mlir_e2e_test/test_suite/basic.py | 18 ++++++++++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index c7def031d..eef7c4ceb 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -192,6 +192,7 @@ TORCHDYNAMO_XFAIL_SET = { 'IsFloatingPointInt_False', 'TorchPrimLoopForLikeModule_basic', 'TorchPrimLoopWhileLikeModule_basic', + "ScalarConstantTupleModule_basic", # END tests failing due to: empty graph in dynamo # ERROR due to: backend never runs because of empty frame @@ -268,6 +269,7 @@ TORCHDYNAMO_XFAIL_SET = { "ScatterValueFloatModule_basic", # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} "ScatterValueIntModule_basic", + } TORCHDYNAMO_CRASHING_SET = { @@ -543,6 +545,7 @@ STABLEHLO_PASS_SET = { "NormScalarOptDimKeepDimModule_basic", "NormScalarOptDimModule_basic", "NormalizeModule_basic", + "ScalarConstantTupleModule_basic", "SelectIntModule_basic", "SelectIntNegativeDimAndIndexStaticModule_basic", "SliceSingleIdxModule_basic", diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp index 8849bbf30..15cffedbe 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp @@ -226,12 +226,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, toMlirNamedAttribute( "value", mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)))); - } else if 
(output->type()->cast<c10::ListType>()) { + } else if (output->type()->cast<c10::ListType>() || + output->type()->cast<c10::TupleType>()) { ClassAnnotator dummyAnnotator; - MlirValue listValue = + MlirValue listOrTupleValue = importIValue(node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator, importOptions); - mapResults(node, mlirOpResultGetOwner(listValue)); + mapResults(node, mlirOpResultGetOwner(listOrTupleValue)); return; // Early return, since `importIValue` already added op to block. } else { std::stringstream msg; diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 2291ea64f..5043fcc61 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -12,6 +12,24 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== +class ScalarConstantTupleModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return (1, 2) + +@register_test_case(module_factory=lambda: ScalarConstantTupleModule()) +def ScalarConstantTupleModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4)) + +# ============================================================================== class MmModule(torch.nn.Module):