mirror of https://github.com/llvm/torch-mlir
Add TorchScript import tests missed in previous change.
parent 78a3c90758
commit 9ffd2556ab
@@ -95,7 +95,8 @@ ModuleBuilder::startCaptureFunction(std::string &name,
   return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
 }
 
-void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
+torch::jit::StrongFunctionPtr
+ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
   auto inserter = createInserter();
   GraphImporter::MlirMappingOptions mappingOptions{
       context,
@@ -107,6 +108,7 @@ void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
       function.function_, std::move(mappingOptions));
   graphImporter->initialize();
   graphImporter->importGenericFunc();
+  return function;
 }
 
 FuncBuilder::Inserter ModuleBuilder::createInserter() {
@@ -40,7 +40,9 @@ public:
   // Imports a traced function. Note that the python type
   // torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
   // Just a bit of naming cruft.
-  void importFunction(torch::jit::StrongFunctionPtr function);
+  // Returns the same function, making it suitable as a nested decorator.
+  torch::jit::StrongFunctionPtr
+  importFunction(torch::jit::StrongFunctionPtr function);
 
 private:
   FuncBuilder::Inserter createInserter();
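The header comment above is the key to the rest of this change: a Python decorator rebinds the decorated name to whatever the decorator returns, so importFunction must hand its argument back or applying @mb.import_function would leave the name bound to None. A minimal sketch of the equivalence, mirroring the tests below (the function name add2 is illustrative only and assumes the torch_mlir extension is importable):

import torch
import torch_mlir

mb = torch_mlir.ModuleBuilder()

# Decorator form: import_function receives the ScriptFunction produced by
# torch.jit.script, imports it into mb.module, and returns it unchanged.
@mb.import_function
@torch.jit.script
def add2(a, b):
  return a + b

# Equivalent explicit spelling of the same thing:
#   add2 = mb.import_function(torch.jit.script(add2))
assert isinstance(add2, torch.jit.ScriptFunction)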
@@ -0,0 +1,29 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+import torch
+import torch_mlir
+
+# RUN: %PYTHON %s
+
+@torch.jit.script
+class ExampleClass:
+  def __init__(self, x):
+    self.x = x
+
+
+mb = torch_mlir.ModuleBuilder()
+
+# For now, TorchScript classes are wholly unsupported, so use it to test
+# type conversion errors.
+try:
+  @mb.import_function
+  @torch.jit.script
+  def import_class(c: ExampleClass):
+    return c.x
+except RuntimeError as e:
+  # TODO: Once diagnostics are enabled, verify the actual error emitted.
+  assert str(e) == "could not convert function input type"
+else:
+  assert False, "Expected exception"
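The try/except/else shape is doing real work in this new test: the decorator application itself must raise, a matching exception passes, and a silent success falls through to the else branch and fails loudly. The same structure as a standalone helper (purely illustrative, not part of the change):

def check_raises(thunk, exc_type, expected_msg):
  # Run thunk and require that it raises exc_type with exactly expected_msg;
  # a normal return falls through to else and fails the check explicitly.
  try:
    thunk()
  except exc_type as e:
    assert str(e) == expected_msg
  else:
    assert False, "Expected exception"

def boom():
  raise RuntimeError("boom")

check_raises(boom, RuntimeError, "boom")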
@@ -7,12 +7,7 @@ import torch_mlir
 
 # RUN: %PYTHON %s | npcomp-opt | FileCheck %s
 
-@torch.jit.script
-def add3(t0, t1, t2):
-  return t0 + t1 + t2
-
 mb = torch_mlir.ModuleBuilder()
-mb.import_function(add3)
 
 # Verify without debug info.
 # CHECK-LABEL: func @add3$generic
@@ -21,5 +16,11 @@ mb.import_function(add3)
 # CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
 # CHECK: %[[A1:.*]] = torch.kernel_call "aten::add" %[[A0]], %arg2, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
 # CHECK: return %[[A1]] : !numpy.ndarray<*:!numpy.any_dtype>
+@mb.import_function
+@torch.jit.script
+def add3(t0, t1, t2):
+  return t0 + t1 + t2
+
+assert isinstance(add3, torch.jit.ScriptFunction)
 mb.module.operation.print()
 print()
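The new assert isinstance line pins down the pass-through contract: after import, the decorated name is still the torch.jit.ScriptFunction itself, so it remains usable in the same process that built the MLIR module. A small sketch of what that buys (tensor values are illustrative; assumes a built torch_mlir extension):

import torch
import torch_mlir

mb = torch_mlir.ModuleBuilder()

@mb.import_function
@torch.jit.script
def add3(t0, t1, t2):
  return t0 + t1 + t2

# Still an ordinary scripted function, so it can be executed eagerly while
# mb.module holds the imported MLIR for the same computation.
t = torch.ones(2)
print(add3(t, t, t))         # tensor([3., 3.])
mb.module.operation.print()  # contains the func @add3$generic checked above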
@@ -7,20 +7,25 @@ import torch_mlir
 
 # RUN: %PYTHON %s | FileCheck %s
 
+mb = torch_mlir.ModuleBuilder()
+
+# CHECK-LABEL: func @add3$generic
+# Note that line-level debug information for parts unannotated in the Torch
+# graph are ascribed to the first op that carries source information. Presently
+# this includes naked constants, return and the function itself. This heuristic
+# likely needs to be improved and this test should be reworked when it is.
+@mb.import_function
 @torch.jit.script
 def add3(t0, t1, t2):
+  # CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 2]]
+  # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
   intermediate = t0 + t1
+  # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
   final = intermediate + t2
+  # CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 3]]
   return final
 
-mb = torch_mlir.ModuleBuilder()
-mb.import_function(add3)
+# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 5]]
 
-# Verify again with debug info present. Just checking that it makes it in there.
-# CHECK-LABEL: func @add3$generic
-# CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py
-# CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py
-# CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py
-# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py
 mb.module.operation.print(enable_debug_info=True)
 print()
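The [[# @LINE + 1]] style expressions above are FileCheck numeric substitutions: each evaluates to the line number of the CHECK directive itself plus the offset, so the test can require that an op's loc() points at the right source line without hardcoding absolute line numbers that would break on any edit. A reduced, hypothetical sketch of the pattern (the file and op names are placeholders, not one of the tests added here):

import torch
import torch_mlir

mb = torch_mlir.ModuleBuilder()

@mb.import_function
@torch.jit.script
def mul2(t0, t1):
  # The directive sits one line above the multiply, so the expected location
  # is "this directive's line + 1".
  # CHECK: aten::mul{{.*}}loc({{.*}}my_test.py":[[# @LINE + 1]]
  return t0 * t1

mb.module.operation.print(enable_debug_info=True)
print()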