Miscellaneous changes while trying to work on ResNet18

- Move frontend lowering pipelines to C++ (this helps with reproducing
  failures in npcomp-opt)
- Add debugging printouts when compilation fails in RefBackendTestConfig

The experience when a test fails during MLIR lowering now looks like this:
```
NPCOMP TorchScript Object Graph IR -> NPCOMP Backend IR lowering failed with the following diagnostics:
failed to legalize operation 'torch.global_slot'
Module does not conform to npcomp's backend contract. See dialect conversion legality information above.

Error can be reproduced with:
$ npcomp-opt -torchscript-to-npcomp-backend-pipeline /tmp/ResNet18Module.mlir
```

And when TorchScript->MLIR import fails, it looks like this:
```
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with the following diagnostics:
unhandled prim operation: %18 : int = prim::min(%17) # /usr/local/google/home/silvasean/.local/lib/python3.9/site-packages/torch/nn/functional.py:4532:4
```

Also,
- Add `--filter=<regex>` to the e2e test harness to filter tests (see the
  example invocation below).
- Add a few prim ops that were needed to import ResNet18.
- Fix torch.prim.Loop.condition assemblyFormat (it previously would not
  round-trip in the case of no loop-carried variables).
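
For reference, a sample invocation of the harness with the new flag. The path to the harness entry point below is an assumption about the tree layout; `--filter` takes a Python regex that is matched with `re.match` against each test's `unique_name`:
```
$ python frontends/pytorch/e2e_testing/torchscript/main.py \
    --config=refbackend --filter='ResNet18Module.*'
```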
pull/209/head
Sean Silva 2021-04-21 15:07:15 -07:00
parent 8f96901943
commit 3a890aa26c
13 changed files with 261 additions and 82 deletions


@ -403,6 +403,12 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
}
for (torch::jit::Function *function : cu->get_functions()) {
// Useful for debugging errors in free functions that end up being
// unused. These can be missing when round-tripping through the on-disk
// format, even though they still cause import issues when importing
// through the larger Python session where they originate.
// std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
// std::cerr << *function->graph();
MethodAnnotation *annotation =
annotator.getMethodAnnotationForFunction(function);
MlirOperation func = importJitFunctionAsFuncOp(


@ -80,6 +80,9 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
case c10::prim::Uninitialized:
case c10::prim::RaiseException:
case c10::prim::Print:
case c10::prim::min:
case c10::prim::max:
case c10::prim::layout:
case c10::prim::NumToTensor: {
createAndMapTrivialNode(node,
"torch.prim." + std::string(kind.toUnqualString()));


@ -3,6 +3,8 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import argparse
import re
import sys
from torch_mlir.torchscript.e2e_test.framework import run_tests
from torch_mlir.torchscript.e2e_test.reporting import report_results
@ -33,6 +35,9 @@ Meaning of options:
"refbackend": run through npcomp's RefBackend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
''')
parser.add_argument('--filter', default='.*', help='''
Regular expression specifying which tests to include in this run.
''')
args = parser.parse_args()
if args.config == 'refbackend':
@ -41,7 +46,20 @@ Meaning of options:
config = NativeTorchTestConfig()
elif args.config == 'torchscript':
config = TorchScriptTestConfig()
-results = run_tests(GLOBAL_TEST_REGISTRY, config)
tests = [
test for test in GLOBAL_TEST_REGISTRY
if re.match(args.filter, test.unique_name)
]
if len(tests) == 0:
print(
f'ERROR: the provided filter {args.filter!r} does not match any tests'
)
print('The available tests are:')
for test in GLOBAL_TEST_REGISTRY:
print(test.unique_name)
sys.exit(1)
results = run_tests(tests, config)
report_results(results)
if __name__ == '__main__':


@ -11,7 +11,7 @@ from torch_mlir.torchscript.annotations import annotate_args, export
# ==============================================================================
-class Resnet18Module(torch.nn.Module):
+class ResNet18Module(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
@ -25,6 +25,6 @@ class Resnet18Module(torch.nn.Module):
def forward(self, img):
return self.resnet.forward(img)
-@register_test_case(module_factory=lambda: Resnet18Module())
-def Resnet18Module_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: ResNet18Module())
+def ResNet18Module_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 3, 224, 224))


@ -2,13 +2,18 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import sys
from typing import Any
from io import StringIO
import os
import tempfile
import numpy as np
import torch
from mlir.passmanager import PassManager
import torch_mlir
-from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
+from npcomp.compiler.pytorch.backend import refjit
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
from torch_mlir.torchscript.annotations import extract_annotations
@ -26,9 +31,51 @@ class RefBackendTestConfig(TestConfig):
extract_annotations(program, scripted, class_annotator)
-mb.import_module(scripted._c, class_annotator)
-# Lower module in place.
-frontend_lowering.lower_object_graph(mb.module)
# TODO: Find a way to make each of these calls own its own
# "debuggable error report" situation.
try:
sys.stderr = StringIO()
# Import the TorchScript module to MLIR
mb.import_module(scripted._c, class_annotator)
except Exception as e:
raise Exception(f"""
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with the following diagnostics:
{sys.stderr.getvalue()}
""") from None
finally:
sys.stderr = sys.__stderr__
try:
sys.stderr = StringIO()
asm_for_error_report = mb.module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True)
pipeline_str = "torchscript-to-npcomp-backend-pipeline"
# Lower module in place to make it ready for compiler backends.
with mb.module.context:
pm = PassManager.parse(pipeline_str)
pm.run(mb.module)
except Exception as e:
# TODO: More robust.
# - don't arbitrarily clutter up /tmp. When a test suite has many
# tests, this can be a big disk cost (also, /tmp/ is frequently a
# RAM fs, which increases worries about capacity).
# - don't have colliding filenames (hard to do without cluttering
# up /tmp)
# - if we do have colliding filenames, writes should at least
# avoid being racy.
filename = os.path.join(tempfile.gettempdir(),
scripted.original_name + '.mlir')
with open(filename, 'w') as f:
f.write(asm_for_error_report)
raise Exception(f"""
NPCOMP TorchScript Object Graph IR -> NPCOMP Backend IR lowering failed with the following diagnostics:
{sys.stderr.getvalue()}
Error can be reproduced with:
$ npcomp-opt -{pipeline_str} {filename}
""") from None
finally:
sys.stderr = sys.__stderr__
return self.backend.compile(mb.module)
def run(self, artifact: Any, trace: Trace) -> Trace:
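
As an aside, the error-reporting pattern used in `compile()` above boils down to a small standalone sketch (the helper name and message wording here are illustrative, not part of the actual change):
```
# Capture MLIR diagnostics (which are printed to stderr) while a compilation
# step runs, and surface them in the raised Python exception.
import sys
from io import StringIO

def run_with_diagnostics(step_name, step):
    saved_stderr = sys.stderr
    sys.stderr = StringIO()
    try:
        return step()
    except Exception:
        diagnostics = sys.stderr.getvalue()
        raise Exception(
            f"{step_name} failed with the following diagnostics:\n{diagnostics}"
        ) from None
    finally:
        sys.stderr = saved_stderr

# Hypothetical usage: run_with_diagnostics("Lowering", lambda: pm.run(mb.module))
```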


@ -18,7 +18,7 @@ mb = torch_mlir.ModuleBuilder()
# CHECK: %[[RESULTS:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[BOOL_TRUE]], init(%[[F_INIT]]) {
# CHECK: ^bb0(%[[IV:.*]]: i64, %[[F_ITER:.*]]: f64):
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::add" %[[F_ITER]], %[[IV]] : (f64, i64) -> f64 {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
-# CHECK: torch.prim.Loop.condition %[[BOOL_TRUE]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
+# CHECK: torch.prim.Loop.condition %[[BOOL_TRUE]], iter(%[[F_NEXT]] : f64)
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
# CHECK: return %[[RESULTS:.*]] : f64
@mb.import_function
@ -38,7 +38,7 @@ def prim_Loop_forlike(n: int):
# CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
# CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
-# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
+# CHECK: torch.prim.Loop.condition %[[COND_ITER]], iter(%[[F_NEXT]] : f64)
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
# CHECK: return %[[RET:.*]] : f64
@mb.import_function
@ -57,7 +57,7 @@ def prim_Loop_whilelike(n: int):
# CHECK: %[[RET:.*]] = torch.prim.Loop %[[ARG]], %[[TRUE]], init(%[[NONE_DEREFINED]]) {
# CHECK: ^bb0(%[[IV:.*]]: i64, %[[X_ITER:.*]]: !torch.optional<i64>):
# CHECK: %[[X_NEXT:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
-# CHECK: torch.prim.Loop.condition %[[TRUE]] iter(%[[X_NEXT]]) : !basicpy.BoolType, (!torch.optional<i64>)
+# CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[X_NEXT]] : !torch.optional<i64>)
# CHECK: } : (i64, !basicpy.BoolType, !torch.optional<i64>) -> !torch.optional<i64>
# CHECK: return %[[RET:.*]] : !torch.optional<i64>
@mb.import_function


@ -102,6 +102,15 @@ def prim_ListUnpack(l: typing.List[int]):
def prim_dtype(x):
return x.dtype
# CHECK-LABEL: func @__torch__.prim_layout(
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> i64 {
# CHECK: %[[RET:.*]] = torch.prim.layout %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> i64
# CHECK: return %[[RET]] : i64
@mb.import_function
@torch.jit.script
def prim_layout(x):
return x.layout
# CHECK-LABEL: func @__torch__.prim_device(
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !torch.Device {
# CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> !torch.Device
@ -111,5 +120,33 @@ def prim_dtype(x):
def prim_device(x):
return x.device
# CHECK-LABEL: func @__torch__.prim_min(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
# CHECK: %[[MIN1:.*]] = torch.prim.min %[[SINGLETON]] : !basicpy.ListType -> i64
# CHECK: %[[MIN2:.*]] = torch.prim.min %[[ARG]], %[[ARG]] : i64, i64 -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
# CHECK: %[[MIN3:.*]] = torch.prim.min %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MIN1]], %[[MIN2]], %[[MIN3]] : (i64, i64, i64) -> !basicpy.TupleType
# CHECK: return %[[RET]] : !basicpy.TupleType
@mb.import_function
@torch.jit.script
def prim_min(x: int):
return min(x), min(x,x), min(x, x, x)
# CHECK-LABEL: func @__torch__.prim_max(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
# CHECK: %[[MAX1:.*]] = torch.prim.max %[[SINGLETON]] : !basicpy.ListType -> i64
# CHECK: %[[MAX2:.*]] = torch.prim.max %[[ARG]], %[[ARG]] : i64, i64 -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
# CHECK: %[[MAX3:.*]] = torch.prim.max %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MAX1]], %[[MAX2]], %[[MAX3]] : (i64, i64, i64) -> !basicpy.TupleType
# CHECK: return %[[RET]] : !basicpy.TupleType
@mb.import_function
@torch.jit.script
def prim_max(x: int):
return max(x), max(x,x), max(x, x, x)
mb.module.operation.print()
print()


@ -420,8 +420,8 @@ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
let results = (outs);
let assemblyFormat = [{
-$shouldContinue `iter` `(` $iterArgs `)`
-attr-dict `:` type($shouldContinue) `,` `(` type($iterArgs) `)`
+$shouldContinue `,`
+`iter` `(` ($iterArgs^ `:` type($iterArgs))? `)` attr-dict
}];
}
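
For illustration, a hand-written sketch of how the op prints under the new format (value names are placeholders); the second line is the case with no loop-carried variables, which previously failed to round-trip:
```
torch.prim.Loop.condition %cond, iter(%next : f64)
torch.prim.Loop.condition %cond, iter()
```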
@ -564,6 +564,15 @@ def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> {
}];
}
def Torch_PrimlayoutOp : Torch_Op<"prim.layout", []> {
let summary = "TorchScript prim::layout op";
let arguments = (ins AnyTorchTensorType:$tensor);
let results = (outs AnyTorchNumberType:$result);
let assemblyFormat = [{
$tensor attr-dict `:` type($tensor) `->` type($result)
}];
}
def Torch_PrimdeviceOp : Torch_Op<"prim.device", []> {
let summary = "TorchScript prim::device op";
let arguments = (ins AnyTorchTensorType:$tensor);
@ -573,4 +582,26 @@ def Torch_PrimdeviceOp : Torch_Op<"prim.device", []> {
}];
}
def Torch_PrimminOp : Torch_Op<"prim.min", []> {
let summary = "TorchScript prim::min op";
// TODO: Separate this along all the different possible signatures.
// At the time of this writing, there are 11.
// In particular, there are binary scalar versions and various list versions.
let arguments = (ins Variadic<AnyType>:$operands);
let results = (outs AnyTorchScalarType:$result);
let assemblyFormat = [{
$operands attr-dict `:` type($operands) `->` type($result)
}];
}
def Torch_PrimmaxOp : Torch_Op<"prim.max", []> {
let summary = "TorchScript prim::max op";
// TODO: Separate this along all the different possible signatures.
let arguments = (ins Variadic<AnyType>:$operands);
let results = (outs AnyTorchScalarType:$result);
let assemblyFormat = [{
$operands attr-dict `:` type($operands) `->` type($result)
}];
}
#endif // TORCH_OPS


@ -26,6 +26,17 @@ createPrepareForGlobalizeObjectGraphPass();
/// See the documentation on torch-globalize-object-graph for more details.
void createGlobalizePipeline(OpPassManager &pm);
/// Creates a pipeline that lowers the object graph IR that is produced by
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
void createLowerObjectGraphPipeline(OpPassManager &pm);
/// Creates a pipeline that lowers a flat list of funcs and global slots,
/// expressed with the torch and aten dialects and mutable arrays, into the
/// form required by npcomp-verify-backend-contract: in particular, lowering
/// most arrays to ranked tensors of known dtype, lowering aten ops to linalg,
/// and converting torch.prim.* ops to elementary math operations.
void createLowerToNpcompBackendPipeline(OpPassManager &pm);
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
std::unique_ptr<OperationPass<FuncOp>> createRefineTypesPass();


@ -56,8 +56,9 @@ class VerifyBackendContractPass
target.addDynamicallyLegalDialect<StandardOpsDialect>(isLegalScalarOp);
target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
-// Tensor operations should go through linalg.
+// Tensor operations should go through linalg and the tensor dialect.
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
// DimOp is used to query tensor sizes.
target.addDynamicallyLegalOp<memref::DimOp>(opHasLegalTypes);


@ -19,4 +19,9 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
MLIRPass
NPCOMPTorchDialect
NPCOMPBasicpyDialect
NPCOMPATenPasses
NPCOMPNumpyPasses
NPCOMPATenToLinalg
NPCOMPATenToTCF
NPCOMPTCFToStd
)


@ -7,7 +7,15 @@
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
#include "npcomp/Conversion/ATenToTCF/Passes.h"
#include "npcomp/Conversion/TCFToStd/TCFToStd.h"
#include "npcomp/Dialect/ATen/Transforms/Passes.h"
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
//===----------------------------------------------------------------------===//
// Pass registration
@ -23,9 +31,86 @@ void mlir::NPCOMP::registerTorchPasses() {
mlir::PassPipelineRegistration<>(
"torch-globalize-pipeline", "Globalization pipeline.",
mlir::NPCOMP::Torch::createGlobalizePipeline);
mlir::PassPipelineRegistration<>(
"torchscript-to-npcomp-backend-pipeline",
"Pipeline lowering torch object graph to npcomp backend format.",
mlir::NPCOMP::Torch::createLowerObjectGraphPipeline);
mlir::PassPipelineRegistration<>(
"torch-globalized-module-to-npcomp-backend-pipeline",
"Pipeline lowering to npcomp backend form.",
mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline);
}
void mlir::NPCOMP::Torch::createGlobalizePipeline(OpPassManager &pm) {
pm.addPass(createPrepareForGlobalizeObjectGraphPass());
pm.addPass(createGlobalizeObjectGraphPass());
}
void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(OpPassManager &pm) {
// When we import TorchScript IR, we import their entire "compilation unit",
// which can contain numerous functions unrelated to the current program,
// which breaks torch-globalization-pipeline; for example, there can be
// random functions referencing types that haven't been imported
// as part of the root `torch.nn.Module` we imported. Those will
// be unreferenced private functions which symbol-dce will clean up nicely.
pm.addPass(createSymbolDCEPass());
// Globalize the program. The rest of the compiler assumes a globalized
// program, which makes all analyses and transforms significantly easier
// to write.
pm.addPass(createPrepareForGlobalizeObjectGraphPass());
pm.addPass(createGlobalizeObjectGraphPass());
// "lower" `torch.global_slot` ops by deleting them if unused, which we
// currently require because we don't have a lowering path for backends to
// handle them.
// Torch usually inserts a few unused global slots so this ends up hitting
// every single module even if it doesn't have any explicit slots.
// TODO: Support global slots in backends.
pm.addPass(createSymbolDCEPass());
// Currently, our shape inference is not powerful enough to deal with
// calls, so inline everything.
// TODO: Improve shape inference.
pm.addPass(createInlinerPass());
// Incorporate user annotations and remove signature Python-isms.
pm.addPass(createAdjustCallingConventionsPass());
createLowerToNpcompBackendPipeline(pm);
}
void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(OpPassManager &pm) {
// Recognize ATen kernels.
pm.addNestedPass<FuncOp>(aten::createRecognizeKernelsPass());
// Convert the bulk of the program to ranked tensors with known dtype.
// This is the input to the backend layer that we are aiming for.
// First, unilaterally convert public functions to tensor.
// The way this pass is currently written, this implies that
// as pipeline authors, we are restricting our users to not be able to see
// updates to "out params" on their public functions.
// This is deemed ok for now.
pm.addPass(Numpy::createPublicFunctionsToTensorPass());
// Convert the bulk of non-ABI-visible arrays to tensors.
pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
// Do shape and dtype refinement.
// We could do it sooner, but the pass currently doesn't have transfer
// functions for array ops.
pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends.
pm.addPass(Numpy::createRefinePublicReturnPass());
// Clean up a few stray array/tensor conversion remnants.
pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
// Lower to TCP (+ guards) which is the input to codegen backends.
// Most of this should be subsumed by aten->linalg+guards conversions.
// (the guard generation will be automated from the linalg Op DSL).
pm.addNestedPass<FuncOp>(createConvertATenToLinalgPass());
pm.addNestedPass<FuncOp>(createConvertATenToTCFPass());
pm.addNestedPass<FuncOp>(createConvertTCFToStdPass());
pm.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
// Verify that we have lowered to the form that backends expect.
// This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(CommonBackend::createVerifyBackendContractPass());
}
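
Both newly registered pipelines can be exercised directly from npcomp-opt, for example (input file names are placeholders):
```
$ npcomp-opt -torchscript-to-npcomp-backend-pipeline module.mlir
$ npcomp-opt -torch-globalized-module-to-npcomp-backend-pipeline globalized.mlir
```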


@ -15,71 +15,6 @@ __all__ = [
"lower_module",
]
-# The set of passes that lowers from a TorchScript object graph representation
-# to a module semantics where symbols correspond to dotted paths into the
-# module.
-OBJECT_GRAPH_LOWERING_PASSES = (
-# When we import TorchScript IR, we import their entire "compilation unit",
-# which can contain numerous functions unrelated to the current program,
-# which breaks torch-globalization-pipeline; for example, there can be
-# random functions referencing types that haven't been imported
-# as part of the root `torch.nn.Module` we imported. Those will
-# be unreferenced private functions which symbol-dce will clean up nicely.
-"symbol-dce",
-# Globalize the program. The rest of the compiler assumes a globalized
-# program, which makes all analyses and transforms significantly easier
-# to write.
-"torch-globalize-pipeline",
-# "lower" `torch.global_slot` ops by deleting them if unused, which we
-# currently require because we don't have a lowering path for backends to
-# handle them.
-# Torch usually inserts a few unused global slots so this ends up hitting
-# every single module even if it doesn't have any explicit slots.
-# TODO: Support global slots in backends.
-"symbol-dce",
-# Currently, our shape inference is not powerful enough to deal with
-# calls, so inline everything.
-# TODO: Improve shape inference.
-"inline",
-# Incorporate user annotations and remove signature Python-isms.
-"torch-adjust-calling-conventions",
-)
-TORCH_TO_TCP_PASSES = (
-# Recognize ATen kernels.
-"func(aten-recognize-kernels)",
-# Convert the bulk of the program to ranked tensors with known dtype.
-# This is the input to the backend layer that we are aiming for.
-# First, unilaterally convert public functions to tensor.
-# The way this pass is currently written, this implies that
-# as pipeline authors, we are restricting our users to not be able to see
-# updates to "out params" on their public functions.
-# This is deemed ok for now.
-"numpy-public-functions-to-tensor",
-# Convert the bulk of non-ABI-visible arrays to tensors.
-"func(numpy-array-to-tensor)",
-# Do shape and dtype refinement.
-# We could do it sooner, but the pass currently doesn't have transfer
-# functions for array ops.
-"func(torch-refine-types)",
-# Propagate to ABI return types the shape/dtype information discovered by
-# the previous pass. Doing this is ABI-compatible for our backends.
-"numpy-refine-public-return",
-# Clean up a few stray array/tensor conversion remnants.
-"func(numpy-array-to-tensor)",
-# Lower to TCP (+ guards) which is the input to codegen backends.
-# Most of this should be subsumed by aten->linalg+guards conversions.
-# (the guard generation will be automated from the linalg Op DSL)
-"func(convert-aten-to-linalg)",
-"func(convert-aten-to-tcf)",
-"func(convert-tcf-to-std)",
-"func(convert-elementwise-to-linalg)",
-"npcomp-verify-backend-contract",
-)
def lower_module(imported_module: Module):
"""Compiles an imported module, with a flat list of functions.
@ -93,7 +28,7 @@ def lower_module(imported_module: Module):
if logging.debug_enabled():
logging.debug("Initial PyTorch IR:\n{}", imported_module)
# Frontend.
pipeline_str = ",".join(TORCH_TO_TCP_PASSES)
pipeline_str = "torch-globalized-module-to-npcomp-backend-pipeline"
if logging.debug_enabled():
logging.debug("Running Torch->TCP pipeline '{}'", pipeline_str)
pm = PassManager.parse(pipeline_str)
@ -116,10 +51,10 @@ def lower_object_graph(imported_module: Module):
logging.debug("Initial PyTorch object graph IR:\n{}", imported_module)
# Object graph lowering.
pipeline_str = ",".join(OBJECT_GRAPH_LOWERING_PASSES)
pipeline_str = "torchscript-to-npcomp-backend-pipeline"
if logging.debug_enabled():
logging.debug(
"Running Torch object graph lowering pipeline '{}'", pipeline_str)
pm = PassManager.parse(pipeline_str)
pm.run(imported_module)
-return lower_module(imported_module)
+return imported_module