Added support for importing node prim::Constant with list type

Prior to this commit, importing a `prim::Constant` node with list type would result in an error since it was not supported. `ivalue_importer::importIValue` was modified to return the MlirValue corresponding to the root so its parent operation could be extracted.
pull/588/head
Henry Tu 2022-02-10 13:39:52 -05:00 committed by Yi Zhang
parent ce4d6d1f83
commit 73ac9a7e2e
4 changed files with 31 additions and 5 deletions

View File

@ -559,11 +559,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
}
}
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
MlirValue torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
MlirContext context, ClassAnnotator &annotator) {
// When debugging module importing, it can be useful to dump as so:
// if (ivalue.isModule())
// ivalue.toModule().dump(true, false, false);
IValueImporter importer(block, context, annotator);
importer.importIValue(ivalue);
return importer.importIValue(ivalue);
}

View File

@ -25,8 +25,8 @@ namespace torch_mlir {
/// Main entry-point for importing torch IValue's .
/// Recursively imports `ivalue`, inserting operations at the end of `block`.
void importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context,
ClassAnnotator &annotator);
MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context,
ClassAnnotator &annotator);
} // namespace torch_mlir

View File

@ -12,6 +12,8 @@
#include <unordered_map>
#include "class_annotator.h"
#include "ivalue_importer.h"
#include "mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
@ -180,6 +182,14 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
toMlirNamedAttribute(
"value",
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
} else if (output->type()->cast<c10::ListType>()) {
ClassAnnotator dummyAnnotator;
MlirValue listValue = importIValue(node->ival(c10::attr::value),
appendToBlock,
context,
dummyAnnotator);
mapResults(node, mlirOpResultGetOwner(listValue));
return; // Early return, since `importIValue` already added op to block.
} else {
std::stringstream msg;
msg << "unhandled prim::Constant node: ";

View File

@ -5,6 +5,7 @@
import typing
import torch
from torch._C import CompilationUnit
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
import typing
@ -14,12 +15,16 @@ import typing
mb = ModuleBuilder()
# Import TorchScript IR string as ScriptFunction.
def import_ts_ir(func_name, ts_ir_str):
cu = CompilationUnit()
mb.import_function(cu.create_function(func_name, torch._C.parse_ir(ts_ir_str)))
# CHECK-LABEL: func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
# CHECK: return %[[RET]] : !torch.tensor
# CHECK: }
@mb.import_function
@torch.jit.script
def prim_NumToTensor(i: int):
@ -147,5 +152,16 @@ def prim_min(x: int):
def prim_max(x: int):
return max(x), max(x,x), max(x, x, x)
# CHECK-LABEL: func @__torch__.prim_Constant_list() -> !torch.list<!torch.int> {
# CHECK: %[[A:.*]] = torch.constant.int 1
# CHECK: %[[B:.*]] = torch.constant.int 2
# CHECK: %[[C:.*]] = torch.constant.int 3
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[A]], %[[B]], %[[C]] :
# CHECK-SAME: (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: return %[[RET]] : !torch.list<!torch.int>
import_ts_ir('__torch__.prim_Constant_list', '''graph():
%list : int[] = prim::Constant[value=[1, 2, 3]]()
return (%list)''')
mb.module.operation.print()
print()