mirror of https://github.com/llvm/torch-mlir
Add TorchScript import tests missed in previous change.
parent
78a3c90758
commit
9ffd2556ab
|
@ -95,7 +95,8 @@ ModuleBuilder::startCaptureFunction(std::string &name,
|
||||||
return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
|
return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
torch::jit::StrongFunctionPtr
|
||||||
|
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
||||||
auto inserter = createInserter();
|
auto inserter = createInserter();
|
||||||
GraphImporter::MlirMappingOptions mappingOptions{
|
GraphImporter::MlirMappingOptions mappingOptions{
|
||||||
context,
|
context,
|
||||||
|
@ -107,6 +108,7 @@ void ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
||||||
function.function_, std::move(mappingOptions));
|
function.function_, std::move(mappingOptions));
|
||||||
graphImporter->initialize();
|
graphImporter->initialize();
|
||||||
graphImporter->importGenericFunc();
|
graphImporter->importGenericFunc();
|
||||||
|
return function;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncBuilder::Inserter ModuleBuilder::createInserter() {
|
FuncBuilder::Inserter ModuleBuilder::createInserter() {
|
||||||
|
|
|
@ -40,7 +40,9 @@ public:
|
||||||
// Imports a traced function. Note that the python type
|
// Imports a traced function. Note that the python type
|
||||||
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
|
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
|
||||||
// Just a bit of naming cruft.
|
// Just a bit of naming cruft.
|
||||||
void importFunction(torch::jit::StrongFunctionPtr function);
|
// Returns the same function, making it suitable as a nested decorator.
|
||||||
|
torch::jit::StrongFunctionPtr
|
||||||
|
importFunction(torch::jit::StrongFunctionPtr function);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FuncBuilder::Inserter createInserter();
|
FuncBuilder::Inserter createInserter();
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
class ExampleClass:
|
||||||
|
def __init__(self, x):
|
||||||
|
self.x = x
|
||||||
|
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
# For now, TorchScript classes are wholly unsupported, so use it to test
|
||||||
|
# type conversion errors.
|
||||||
|
try:
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def import_class(c: ExampleClass):
|
||||||
|
return c.x
|
||||||
|
except RuntimeError as e:
|
||||||
|
# TODO: Once diagnostics are enabled, verify the actual error emitted.
|
||||||
|
assert str(e) == "could not convert function input type"
|
||||||
|
else:
|
||||||
|
assert False, "Expected exception"
|
|
@ -7,12 +7,7 @@ import torch_mlir
|
||||||
|
|
||||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def add3(t0, t1, t2):
|
|
||||||
return t0 + t1 + t2
|
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
mb.import_function(add3)
|
|
||||||
|
|
||||||
# Verify without debug info.
|
# Verify without debug info.
|
||||||
# CHECK-LABEL: func @add3$generic
|
# CHECK-LABEL: func @add3$generic
|
||||||
|
@ -21,5 +16,11 @@ mb.import_function(add3)
|
||||||
# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||||
# CHECK: %[[A1:.*]] = torch.kernel_call "aten::add" %[[A0]], %arg2, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
# CHECK: %[[A1:.*]] = torch.kernel_call "aten::add" %[[A0]], %arg2, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||||
# CHECK: return %[[A1]] : !numpy.ndarray<*:!numpy.any_dtype>
|
# CHECK: return %[[A1]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def add3(t0, t1, t2):
|
||||||
|
return t0 + t1 + t2
|
||||||
|
|
||||||
|
assert isinstance(add3, torch.jit.ScriptFunction)
|
||||||
mb.module.operation.print()
|
mb.module.operation.print()
|
||||||
print()
|
print()
|
||||||
|
|
|
@ -7,20 +7,25 @@ import torch_mlir
|
||||||
|
|
||||||
# RUN: %PYTHON %s | FileCheck %s
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @add3$generic
|
||||||
|
# Note that line-level debug information for parts unannotated in the Torch
|
||||||
|
# graph are ascribed to the first op that carries source information. Presently
|
||||||
|
# this includes naked constants, return and the function itself. This heuristic
|
||||||
|
# likely needs to be improved and this test should be reworked when it is.
|
||||||
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def add3(t0, t1, t2):
|
def add3(t0, t1, t2):
|
||||||
|
# CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 2]]
|
||||||
|
# CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
|
||||||
intermediate = t0 + t1
|
intermediate = t0 + t1
|
||||||
|
# CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
|
||||||
final = intermediate + t2
|
final = intermediate + t2
|
||||||
|
# CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 3]]
|
||||||
return final
|
return final
|
||||||
|
# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 5]]
|
||||||
mb = torch_mlir.ModuleBuilder()
|
|
||||||
mb.import_function(add3)
|
|
||||||
|
|
||||||
# Verify again with debug info present. Just checking that it makes it in there.
|
# Verify again with debug info present. Just checking that it makes it in there.
|
||||||
# CHECK-LABEL: func @add3$generic
|
|
||||||
# CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py
|
|
||||||
# CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py
|
|
||||||
# CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py
|
|
||||||
# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py
|
|
||||||
mb.module.operation.print(enable_debug_info=True)
|
mb.module.operation.print(enable_debug_info=True)
|
||||||
print()
|
print()
|
||||||
|
|
Loading…
Reference in New Issue