mirror of https://github.com/llvm/torch-mlir
[torch] Add import support for IValue string Type(s) (#179)
* [torch] Add import support for IValue string Type(s) * [test] Add test for Strings importpull/183/head
parent
a36113e586
commit
b94a859e03
|
@ -293,6 +293,16 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
if (ivalue.isModule()) {
|
||||
return importModule(ivalue.toModule());
|
||||
}
|
||||
if (ivalue.isString()) {
|
||||
MlirType type = npcompBytesTypeGet(context);
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "basicpy.bytes_constant", loc, type,
|
||||
toMlirNamedAttribute(
|
||||
"value",
|
||||
mlirStringAttrGet(context,
|
||||
toMlirStringRef(ivalue.toString()->string()))));
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isNone()) {
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context));
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.s = "foo"
|
||||
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
|
||||
# TODO: Don't lose element type.
|
||||
# CHECK: torch.attr "s" : !basicpy.BytesType
|
||||
# CHECK: }
|
||||
# CHECK: %[[BYTES:.*]] = basicpy.bytes_constant "foo"
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.slot "s", %[[BYTES]] : !basicpy.BytesType
|
||||
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
|
||||
|
||||
|
||||
test_module = TestModule()
|
||||
recursivescriptmodule = torch.jit.script(test_module)
|
||||
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
mb.import_module(recursivescriptmodule._c)
|
||||
mb.module.operation.print()
|
Loading…
Reference in New Issue