mirror of https://github.com/llvm/torch-mlir
[Importer] import constant tuple (#2132)
* [Importer] import constant tuple * update * update * updatepull/2188/head snapshot-20230531.855
parent
479b2175ef
commit
72b8070e57
|
@ -192,6 +192,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
'IsFloatingPointInt_False',
|
||||
'TorchPrimLoopForLikeModule_basic',
|
||||
'TorchPrimLoopWhileLikeModule_basic',
|
||||
"ScalarConstantTupleModule_basic",
|
||||
# END tests failing due to: empty graph in dynamo
|
||||
|
||||
# ERROR due to: backend never runs because of empty frame
|
||||
|
@ -268,6 +269,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"ScatterValueFloatModule_basic",
|
||||
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||
"ScatterValueIntModule_basic",
|
||||
|
||||
}
|
||||
|
||||
TORCHDYNAMO_CRASHING_SET = {
|
||||
|
@ -543,6 +545,7 @@ STABLEHLO_PASS_SET = {
|
|||
"NormScalarOptDimKeepDimModule_basic",
|
||||
"NormScalarOptDimModule_basic",
|
||||
"NormalizeModule_basic",
|
||||
"ScalarConstantTupleModule_basic",
|
||||
"SelectIntModule_basic",
|
||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||
"SliceSingleIdxModule_basic",
|
||||
|
|
|
@ -226,12 +226,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
|||
toMlirNamedAttribute(
|
||||
"value",
|
||||
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
|
||||
} else if (output->type()->cast<c10::ListType>()) {
|
||||
} else if (output->type()->cast<c10::ListType>() ||
|
||||
output->type()->cast<c10::TupleType>()) {
|
||||
ClassAnnotator dummyAnnotator;
|
||||
MlirValue listValue =
|
||||
MlirValue listOrTupleValue =
|
||||
importIValue(node->ival(c10::attr::value), appendToBlock, context,
|
||||
dummyAnnotator, importOptions);
|
||||
mapResults(node, mlirOpResultGetOwner(listValue));
|
||||
mapResults(node, mlirOpResultGetOwner(listOrTupleValue));
|
||||
return; // Early return, since `importIValue` already added op to block.
|
||||
} else {
|
||||
std::stringstream msg;
|
||||
|
|
|
@ -12,6 +12,24 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ScalarConstantTupleModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return (1, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: ScalarConstantTupleModule())
|
||||
def ScalarConstantTupleModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
|
||||
|
|
Loading…
Reference in New Issue