diff --git a/frontends/pytorch/csrc/builder/ivalue_importer.cpp b/frontends/pytorch/csrc/builder/ivalue_importer.cpp index 175ee1e72..f80f0017c 100644 --- a/frontends/pytorch/csrc/builder/ivalue_importer.cpp +++ b/frontends/pytorch/csrc/builder/ivalue_importer.cpp @@ -293,6 +293,16 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { if (ivalue.isModule()) { return importModule(ivalue.toModule()); } + if (ivalue.isString()) { + MlirType type = npcompBytesTypeGet(context); + MlirOperation operation = createMlirOperationAtEnd( + importBlock, "basicpy.bytes_constant", loc, type, + toMlirNamedAttribute( + "value", + mlirStringAttrGet(context, + toMlirStringRef(ivalue.toString()->string())))); + return mlirOperationGetResult(operation, 0); + } if (ivalue.isNone()) { MlirOperation operation = createMlirOperationAtEnd( importBlock, "basicpy.singleton", loc, npcompNoneTypeGet(context)); diff --git a/frontends/pytorch/test/ivalue_import/strings.py b/frontends/pytorch/test/ivalue_import/strings.py new file mode 100644 index 000000000..d24408550 --- /dev/null +++ b/frontends/pytorch/test/ivalue_import/strings.py @@ -0,0 +1,32 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import typing + +import torch +import torch_mlir + +# RUN: %PYTHON %s | npcomp-opt | FileCheck %s + +mb = torch_mlir.ModuleBuilder() + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.s = "foo" +# CHECK: torch.class_type @[[CLASSTYPE:.*]] { +# TODO: Don't lose element type. +# CHECK: torch.attr "s" : !basicpy.BytesType +# CHECK: } +# CHECK: %[[BYTES:.*]] = basicpy.bytes_constant "foo" +# CHECK: torch.nn_module { +# CHECK: torch.slot "s", %[[BYTES]] : !basicpy.BytesType +# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]"> + + +test_module = TestModule() +recursivescriptmodule = torch.jit.script(test_module) +# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. +mb.import_module(recursivescriptmodule._c) +mb.module.operation.print()