diff --git a/frontends/pytorch/csrc/builder/module_builder.cpp b/frontends/pytorch/csrc/builder/module_builder.cpp
index 2114b379b..d941c844a 100644
--- a/frontends/pytorch/csrc/builder/module_builder.cpp
+++ b/frontends/pytorch/csrc/builder/module_builder.cpp
@@ -95,7 +95,8 @@ ModuleBuilder::startCaptureFunction(std::string &name,
   return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
 }
 
-void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
+torch::jit::StrongFunctionPtr
+ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
   auto inserter = createInserter();
   GraphImporter::MlirMappingOptions mappingOptions{
       context,
@@ -107,6 +108,7 @@ void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
       function.function_, std::move(mappingOptions));
   graphImporter->initialize();
   graphImporter->importGenericFunc();
+  return function;
 }
 
 FuncBuilder::Inserter ModuleBuilder::createInserter() {
diff --git a/frontends/pytorch/csrc/builder/module_builder.h b/frontends/pytorch/csrc/builder/module_builder.h
index 56bbcf935..bff96a2ab 100644
--- a/frontends/pytorch/csrc/builder/module_builder.h
+++ b/frontends/pytorch/csrc/builder/module_builder.h
@@ -40,7 +40,9 @@ public:
   // Imports a traced function. Note that the python type
   // torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
   // Just a bit of naming cruft.
-  void importFunction(torch::jit::StrongFunctionPtr function);
+  // Returns the same function, making it suitable as a nested decorator.
+  torch::jit::StrongFunctionPtr
+  importFunction(torch::jit::StrongFunctionPtr function);
 
 private:
   FuncBuilder::Inserter createInserter();
diff --git a/frontends/pytorch/test/graph_export/test_errors.py b/frontends/pytorch/test/graph_export/test_errors.py
new file mode 100644
index 000000000..2f817f74e
--- /dev/null
+++ b/frontends/pytorch/test/graph_export/test_errors.py
@@ -0,0 +1,29 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+import torch
+import torch_mlir
+
+# RUN: %PYTHON %s
+
+@torch.jit.script
+class ExampleClass:
+  def __init__(self, x):
+    self.x = x
+
+
+mb = torch_mlir.ModuleBuilder()
+
+# For now, TorchScript classes are wholly unsupported, so use one to test
+# type conversion errors.
+try:
+  @mb.import_function
+  @torch.jit.script
+  def import_class(c: ExampleClass):
+    return c.x
+except RuntimeError as e:
+  # TODO: Once diagnostics are enabled, verify the actual error emitted.
+  assert str(e) == "could not convert function input type"
+else:
+  assert False, "Expected exception"
diff --git a/frontends/pytorch/test/graph_export/test_script_add3.py b/frontends/pytorch/test/graph_export/test_script_add3.py
index c490df9d3..7d448d4f8 100644
--- a/frontends/pytorch/test/graph_export/test_script_add3.py
+++ b/frontends/pytorch/test/graph_export/test_script_add3.py
@@ -7,12 +7,7 @@ import torch_mlir
 
 # RUN: %PYTHON %s | npcomp-opt | FileCheck %s
 
-@torch.jit.script
-def add3(t0, t1, t2):
-  return t0 + t1 + t2
-
 mb = torch_mlir.ModuleBuilder()
-mb.import_function(add3)
 
 # Verify without debug info.
 # CHECK-LABEL: func @add3$generic
@@ -21,5 +16,11 @@ mb.import_function(add3)
 # CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
 # CHECK: %[[A1:.*]] = torch.kernel_call "aten::add" %[[A0]], %arg2, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
 # CHECK: return %[[A1]] : !numpy.ndarray<*:!numpy.any_dtype>
+@mb.import_function
+@torch.jit.script
+def add3(t0, t1, t2):
+  return t0 + t1 + t2
+
+assert isinstance(add3, torch.jit.ScriptFunction)
 mb.module.operation.print()
 print()
diff --git a/frontends/pytorch/test/graph_export/test_script_debug_info.py b/frontends/pytorch/test/graph_export/test_script_debug_info.py
index f9a4e2a4e..e44830009 100644
--- a/frontends/pytorch/test/graph_export/test_script_debug_info.py
+++ b/frontends/pytorch/test/graph_export/test_script_debug_info.py
@@ -7,20 +7,25 @@ import torch_mlir
 
 # RUN: %PYTHON %s | FileCheck %s
 
+mb = torch_mlir.ModuleBuilder()
+
+# CHECK-LABEL: func @add3$generic
+# Note that line-level debug information for parts unannotated in the Torch
+# graph is ascribed to the first op that carries source information. Presently
+# this includes naked constants, return and the function itself. This heuristic
+# likely needs to be improved and this test should be reworked when it is.
+@mb.import_function
 @torch.jit.script
 def add3(t0, t1, t2):
+  # CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 2]]
+  # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
   intermediate = t0 + t1
+  # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
   final = intermediate + t2
+  # CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 3]]
   return final
-
-mb = torch_mlir.ModuleBuilder()
-mb.import_function(add3)
+  # CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 5]]
 
 # Verify again with debug info present. Just checking that it makes it in there.
-# CHECK-LABEL: func @add3$generic
-# CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py
-# CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py
-# CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py
-# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py
 mb.module.operation.print(enable_debug_info=True)
 print()
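
With importFunction now returning the function it receives, the Python-exposed
import_function can be applied directly as a decorator stacked over
torch.jit.script, which is the pattern the updated tests exercise. A minimal
usage sketch of that pattern, assuming a built torch_mlir extension is on the
Python path (the function name add2 is only an illustrative placeholder):

    import torch
    import torch_mlir

    mb = torch_mlir.ModuleBuilder()

    # import_function returns the ScriptFunction it imported, so the decorated
    # name is still a callable torch.jit.ScriptFunction after import.
    @mb.import_function
    @torch.jit.script
    def add2(t0, t1):
      return t0 + t1

    assert isinstance(add2, torch.jit.ScriptFunction)
    mb.module.operation.print()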