From b94a859e03b4f42be85e4a8901e384023facb7cf Mon Sep 17 00:00:00 2001 From: Bryce Arden Date: Thu, 4 Mar 2021 15:08:50 -0600 Subject: [PATCH] [torch] Add import support for IValue string Type(s) (#179) * [torch] Add import support for IValue string Type(s) * [test] Add test for Strings import --- .../pytorch/csrc/builder/ivalue_importer.cpp | 10 ++++++ .../pytorch/test/ivalue_import/strings.py | 32 +++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 frontends/pytorch/test/ivalue_import/strings.py diff --git a/frontends/pytorch/csrc/builder/ivalue_importer.cpp b/frontends/pytorch/csrc/builder/ivalue_importer.cpp index 175ee1e72..f80f0017c 100644 --- a/frontends/pytorch/csrc/builder/ivalue_importer.cpp +++ b/frontends/pytorch/csrc/builder/ivalue_importer.cpp @@ -293,6 +293,16 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { if (ivalue.isModule()) { return importModule(ivalue.toModule()); } + if (ivalue.isString()) { + MlirType type = npcompBytesTypeGet(context); + MlirOperation operation = createMlirOperationAtEnd( + importBlock, "basicpy.bytes_constant", loc, type, + toMlirNamedAttribute( + "value", + mlirStringAttrGet(context, + toMlirStringRef(ivalue.toString()->string())))); + return mlirOperationGetResult(operation, 0); + } if (ivalue.isNone()) { MlirOperation operation = createMlirOperationAtEnd( importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context)); diff --git a/frontends/pytorch/test/ivalue_import/strings.py b/frontends/pytorch/test/ivalue_import/strings.py new file mode 100644 index 000000000..d24408550 --- /dev/null +++ b/frontends/pytorch/test/ivalue_import/strings.py @@ -0,0 +1,32 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import typing + +import torch +import torch_mlir + +# RUN: %PYTHON %s | npcomp-opt | FileCheck %s + +mb = torch_mlir.ModuleBuilder() + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.s = "foo" +# CHECK: torch.class_type @[[CLASSTYPE:.*]] { +# TODO: Don't lose element type. +# CHECK: torch.attr "s" : !basicpy.BytesType +# CHECK: } +# CHECK: %[[BYTES:.*]] = basicpy.bytes_constant "foo" +# CHECK: torch.nn_module { +# CHECK: torch.slot "s", %[[BYTES]] : !basicpy.BytesType +# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]"> + + +test_module = TestModule() +recursivescriptmodule = torch.jit.script(test_module) +# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. +mb.import_module(recursivescriptmodule._c) +mb.module.operation.print()