Add TorchScript import tests missed in previous change.

pull/129/head
Stella Laurenzo 2020-11-23 14:41:30 -08:00 committed by Stella Laurenzo
parent 78a3c90758
commit 9ffd2556ab
5 changed files with 54 additions and 15 deletions


@@ -95,7 +95,8 @@ ModuleBuilder::startCaptureFunction(std::string &name,
   return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
 }
 
-void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
+torch::jit::StrongFunctionPtr
+ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
   auto inserter = createInserter();
   GraphImporter::MlirMappingOptions mappingOptions{
       context,
@@ -107,6 +108,7 @@ void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
       function.function_, std::move(mappingOptions));
   graphImporter->initialize();
   graphImporter->importGenericFunc();
+  return function;
 }
 
 FuncBuilder::Inserter ModuleBuilder::createInserter() {


@@ -40,7 +40,9 @@ public:
   // Imports a traced function. Note that the python type
   // torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
   // Just a bit of naming cruft.
-  void importFunction(torch::jit::StrongFunctionPtr function);
+  // Returns the same function, making it suitable as a nested decorator.
+  torch::jit::StrongFunctionPtr
+  importFunction(torch::jit::StrongFunctionPtr function);
 
  private:
   FuncBuilder::Inserter createInserter();
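The new header comment states the key behavioral change: a Python decorator rebinds the decorated name to whatever the decorator returns, so the old void-returning importFunction would have left the name bound to None. A minimal sketch of the identity-decorator pattern (plain Python, not the actual pybind11 binding):

    def import_function(f):
        # Stand-in for ModuleBuilder.import_function (an assumption, not the
        # real binding): do the import work as a side effect, then hand the
        # same object back so the decorated name keeps pointing at it.
        print(f"importing {f.__name__}")
        return f

    @import_function
    def add(a, b):
        return a + b

    # Before this change, `add` would be None here and this call would raise
    # TypeError: 'NoneType' object is not callable.
    assert add(1, 2) == 3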


@@ -0,0 +1,29 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+import torch
+import torch_mlir
+
+# RUN: %PYTHON %s
+
+@torch.jit.script
+class ExampleClass:
+  def __init__(self, x):
+    self.x = x
+
+mb = torch_mlir.ModuleBuilder()
+
+# For now, TorchScript classes are wholly unsupported, so use it to test
+# type conversion errors.
+try:
+  @mb.import_function
+  @torch.jit.script
+  def import_class(c: ExampleClass):
+    return c.x
+except RuntimeError as e:
+  # TODO: Once diagnostics are enabled, verify the actual error emitted.
+  assert str(e) == "could not convert function input type"
+else:
+  assert False, "Expected exception"
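This works because decorator application runs at function-definition time, so the failing import raises inside the try: body, and the else: branch guards against the import silently succeeding. A sketch of the desugared equivalent (same names as the test above):

    def import_class(c: ExampleClass):
        return c.x

    # torch.jit.script compiles the function; mb.import_function then
    # attempts the MLIR import and raises RuntimeError when it cannot
    # convert ExampleClass as an input type.
    import_class = mb.import_function(torch.jit.script(import_class))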


@@ -7,12 +7,7 @@ import torch_mlir
 # RUN: %PYTHON %s | npcomp-opt | FileCheck %s
 
-@torch.jit.script
-def add3(t0, t1, t2):
-  return t0 + t1 + t2
-
 mb = torch_mlir.ModuleBuilder()
-mb.import_function(add3)
 
 # Verify without debug info.
 # CHECK-LABEL: func @add3$generic
 
@@ -21,5 +16,11 @@ mb.import_function(add3)
 # CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
 # CHECK: %[[A1:.*]] = torch.kernel_call "aten::add" %[[A0]], %arg2, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
 # CHECK: return %[[A1]] : !numpy.ndarray<*:!numpy.any_dtype>
+@mb.import_function
+@torch.jit.script
+def add3(t0, t1, t2):
+  return t0 + t1 + t2
+
+assert isinstance(add3, torch.jit.ScriptFunction)
 mb.module.operation.print()
 print()
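The added isinstance assertion pins down the contract: after both decorators run, add3 still refers to the compiled torch.jit.ScriptFunction rather than None. A hedged usage sketch (not part of the test) of what that preserves:

    import torch

    t = torch.ones(2, 3)
    # import_function handed back the ScriptFunction unchanged, so the
    # imported add3 remains directly callable.
    assert torch.allclose(add3(t, t, t), 3 * t)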


@@ -7,20 +7,25 @@ import torch_mlir
 # RUN: %PYTHON %s | FileCheck %s
 
+mb = torch_mlir.ModuleBuilder()
+
+# CHECK-LABEL: func @add3$generic
+# Note that line-level debug information for parts unannotated in the Torch
+# graph are ascribed to the first op that carries source information. Presently
+# this includes naked constants, return and the function itself. This heuristic
+# likely needs to be improved and this test should be reworked when it is.
+@mb.import_function
 @torch.jit.script
 def add3(t0, t1, t2):
+  # CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 2]]
+  # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
   intermediate = t0 + t1
+  # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
   final = intermediate + t2
+  # CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 3]]
   return final
+# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 5]]
 
-mb = torch_mlir.ModuleBuilder()
-mb.import_function(add3)
-
 # Verify again with debug info present. Just checking that it makes it in there.
-# CHECK-LABEL: func @add3$generic
-# CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py
-# CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py
-# CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py
-# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py
 
 mb.module.operation.print(enable_debug_info=True)
 print()
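The reworked CHECK lines sit directly next to the statements they verify and use FileCheck's numeric expression [[# @LINE + n]], which expands to the line number of the CHECK directive itself plus the offset, so the expected source locations keep tracking the file as it is edited. A minimal illustration of the idiom (hypothetical file name and op):

    # CHECK: aten::add{{.*}}example.py":[[# @LINE + 1]]
    y = t0 + t1  # the op imported from this line should carry this line number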