mirror of https://github.com/llvm/torch-mlir
Added support for importing node prim::Constant with list type
Prior to this commit, importing a `prim::Constant` node with list type would result in an error since it was not supported. `ivalue_importer::importIValue` was modified to return the MlirValue corresponding to the root so its parent operation could be extracted.pull/588/head
parent
ce4d6d1f83
commit
73ac9a7e2e
|
@ -559,11 +559,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
|||
}
|
||||
}
|
||||
|
||||
void torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
|
||||
MlirValue torch_mlir::importIValue(c10::IValue ivalue, MlirBlock block,
|
||||
MlirContext context, ClassAnnotator &annotator) {
|
||||
// When debugging module importing, it can be useful to dump as so:
|
||||
// if (ivalue.isModule())
|
||||
// ivalue.toModule().dump(true, false, false);
|
||||
IValueImporter importer(block, context, annotator);
|
||||
importer.importIValue(ivalue);
|
||||
return importer.importIValue(ivalue);
|
||||
}
|
||||
|
|
|
@ -25,8 +25,8 @@ namespace torch_mlir {
|
|||
|
||||
/// Main entry-point for importing torch IValue's .
|
||||
/// Recursively imports `ivalue`, inserting operations at the end of `block`.
|
||||
void importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context,
|
||||
ClassAnnotator &annotator);
|
||||
MlirValue importIValue(c10::IValue ivalue, MlirBlock block, MlirContext context,
|
||||
ClassAnnotator &annotator);
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "class_annotator.h"
|
||||
#include "ivalue_importer.h"
|
||||
#include "mlir_utils.h"
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
|
@ -180,6 +182,14 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
toMlirNamedAttribute(
|
||||
"value",
|
||||
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
|
||||
} else if (output->type()->cast<c10::ListType>()) {
|
||||
ClassAnnotator dummyAnnotator;
|
||||
MlirValue listValue = importIValue(node->ival(c10::attr::value),
|
||||
appendToBlock,
|
||||
context,
|
||||
dummyAnnotator);
|
||||
mapResults(node, mlirOpResultGetOwner(listValue));
|
||||
return; // Early return, since `importIValue` already added op to block.
|
||||
} else {
|
||||
std::stringstream msg;
|
||||
msg << "unhandled prim::Constant node: ";
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
import typing
|
||||
|
||||
import torch
|
||||
from torch._C import CompilationUnit
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
|
||||
import typing
|
||||
|
@ -14,12 +15,16 @@ import typing
|
|||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# Import TorchScript IR string as ScriptFunction.
|
||||
def import_ts_ir(func_name, ts_ir_str):
|
||||
cu = CompilationUnit()
|
||||
mb.import_function(cu.create_function(func_name, torch._C.parse_ir(ts_ir_str)))
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_NumToTensor(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
|
||||
# CHECK: return %[[RET]] : !torch.tensor
|
||||
# CHECK: }
|
||||
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_NumToTensor(i: int):
|
||||
|
@ -147,5 +152,16 @@ def prim_min(x: int):
|
|||
def prim_max(x: int):
|
||||
return max(x), max(x,x), max(x, x, x)
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_Constant_list() -> !torch.list<!torch.int> {
|
||||
# CHECK: %[[A:.*]] = torch.constant.int 1
|
||||
# CHECK: %[[B:.*]] = torch.constant.int 2
|
||||
# CHECK: %[[C:.*]] = torch.constant.int 3
|
||||
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[A]], %[[B]], %[[C]] :
|
||||
# CHECK-SAME: (!torch.int, !torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
# CHECK: return %[[RET]] : !torch.list<!torch.int>
|
||||
import_ts_ir('__torch__.prim_Constant_list', '''graph():
|
||||
%list : int[] = prim::Constant[value=[1, 2, 3]]()
|
||||
return (%list)''')
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
Loading…
Reference in New Issue