[Importer] import constant tuple (#2132)

* [Importer] import constant tuple

* update

* update

* update
pull/2188/head snapshot-20230531.855
Yuanqiang Liu 2023-05-31 14:14:14 +08:00 committed by GitHub
parent 479b2175ef
commit 72b8070e57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 3 deletions

View File

@ -192,6 +192,7 @@ TORCHDYNAMO_XFAIL_SET = {
'IsFloatingPointInt_False',
'TorchPrimLoopForLikeModule_basic',
'TorchPrimLoopWhileLikeModule_basic',
"ScalarConstantTupleModule_basic",
# END tests failing due to: empty graph in dynamo
# ERROR due to: backend never runs because of empty frame
@ -268,6 +269,7 @@ TORCHDYNAMO_XFAIL_SET = {
"ScatterValueFloatModule_basic",
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
"ScatterValueIntModule_basic",
}
TORCHDYNAMO_CRASHING_SET = {
@ -543,6 +545,7 @@ STABLEHLO_PASS_SET = {
"NormScalarOptDimKeepDimModule_basic",
"NormScalarOptDimModule_basic",
"NormalizeModule_basic",
"ScalarConstantTupleModule_basic",
"SelectIntModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SliceSingleIdxModule_basic",

View File

@ -226,12 +226,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
toMlirNamedAttribute(
"value",
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
} else if (output->type()->cast<c10::ListType>()) {
} else if (output->type()->cast<c10::ListType>() ||
output->type()->cast<c10::TupleType>()) {
ClassAnnotator dummyAnnotator;
MlirValue listValue =
MlirValue listOrTupleValue =
importIValue(node->ival(c10::attr::value), appendToBlock, context,
dummyAnnotator, importOptions);
mapResults(node, mlirOpResultGetOwner(listValue));
mapResults(node, mlirOpResultGetOwner(listOrTupleValue));
return; // Early return, since `importIValue` already added op to block.
} else {
std::stringstream msg;

View File

@ -12,6 +12,24 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class ScalarConstantTupleModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return (1, 2)
@register_test_case(module_factory=lambda: ScalarConstantTupleModule())
def ScalarConstantTupleModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4))
# ==============================================================================
class MmModule(torch.nn.Module):