[torch] Add import support for IValue string Type(s) (#179)

* [torch] Add import support for IValue string Type(s)

* [test] Add test for Strings import
pull/183/head
Bryce Arden 2021-03-04 15:08:50 -06:00 committed by GitHub
parent a36113e586
commit b94a859e03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 0 deletions

View File

@ -293,6 +293,16 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
if (ivalue.isModule()) {
return importModule(ivalue.toModule());
}
if (ivalue.isString()) {
MlirType type = npcompBytesTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.bytes_constant", loc, type,
toMlirNamedAttribute(
"value",
mlirStringAttrGet(context,
toMlirStringRef(ivalue.toString()->string()))));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isNone()) {
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context));

View File

@ -0,0 +1,32 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.s = "foo"
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
# TODO: Don't lose element type.
# CHECK: torch.attr "s" : !basicpy.BytesType
# CHECK: }
# CHECK: %[[BYTES:.*]] = basicpy.bytes_constant "foo"
# CHECK: torch.nn_module {
# CHECK: torch.slot "s", %[[BYTES]] : !basicpy.BytesType
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c)
mb.module.operation.print()