mirror of https://github.com/llvm/torch-mlir
Miscellaneous changes while trying to work on ResNet18

- Move frontend lowering pipelines to C++ (this helps with reproducing
  failures in npcomp-opt).
- Add debugging printouts when compilation fails on RefBackendTestConfig.

The experience when a test fails during MLIR lowering now looks like this:

```
NPCOMP TorchScript Object Graph IR -> NPCOMP Backend IR lowering failed with the following diagnostics:
failed to legalize operation 'torch.global_slot'
Module does not conform to npcomp's backend contract. See dialect conversion legality information above.

Error can be reproduced with:
$ npcomp-opt -torchscript-to-npcomp-backend-pipeline /tmp/ResNet18Module.mlir
```

And when TorchScript->MLIR import fails it looks like this:

```
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with the following diagnostics:
unhandled prim operation: %18 : int = prim::min(%17) # /usr/local/google/home/silvasean/.local/lib/python3.9/site-packages/torch/nn/functional.py:4532:4
```

Also:
- Add `--filter=<regex>` to the e2e test harness to filter tests.
- Add a few prim ops that were needed to import ResNet18.
- Fix the torch.prim.Loop.condition assemblyFormat (it previously would not
  round-trip in the case of no loop-carried variables).
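For reference, here is a minimal sketch of the filtering behavior that `--filter=<regex>` adds to the e2e harness (the helper name `filter_tests` is hypothetical and used only for illustration; `GLOBAL_TEST_REGISTRY` and `unique_name` are the names used in the harness change in the diff below):

```python
import re

# Minimal sketch (not the harness itself): keep only the tests whose
# unique_name matches the given regex. re.match anchors at the start of
# the name, so 'ResNet18' already matches 'ResNet18Module_basic'.
def filter_tests(tests, filter_regex: str = '.*'):
    return [t for t in tests if re.match(filter_regex, t.unique_name)]

# Example usage (assuming the harness's registry is imported):
#   tests = filter_tests(GLOBAL_TEST_REGISTRY, 'ResNet18Module_basic')
```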
parent 8f96901943
commit 3a890aa26c
@@ -403,6 +403,12 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
   }
   for (torch::jit::Function *function : cu->get_functions()) {
+    // Useful for debugging errors in free functions that end up being
+    // unused. These can be missing when round-tripping through the on-disk
+    // format, even though they still cause import issues when importing
+    // through the larger Python session where they originate.
+    // std::cerr << "NAME: " << function->qualname().qualifiedName() << "\n";
+    // std::cerr << *function->graph();
     MethodAnnotation *annotation =
         annotator.getMethodAnnotationForFunction(function);
     MlirOperation func = importJitFunctionAsFuncOp(
@@ -80,6 +80,9 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
   case c10::prim::Uninitialized:
   case c10::prim::RaiseException:
   case c10::prim::Print:
+  case c10::prim::min:
+  case c10::prim::max:
+  case c10::prim::layout:
   case c10::prim::NumToTensor: {
     createAndMapTrivialNode(node,
                             "torch.prim." + std::string(kind.toUnqualString()));
@@ -3,6 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 import argparse
+import re
+import sys

 from torch_mlir.torchscript.e2e_test.framework import run_tests
 from torch_mlir.torchscript.e2e_test.reporting import report_results

@@ -33,6 +35,9 @@ Meaning of options:
 "refbackend": run through npcomp's RefBackend.
 "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
 "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
 ''')
+    parser.add_argument('--filter', default='.*', help='''
+Regular expression specifying which tests to include in this run.
+''')
     args = parser.parse_args()
     if args.config == 'refbackend':

@@ -41,7 +46,20 @@ Meaning of options:
         config = NativeTorchTestConfig()
     elif args.config == 'torchscript':
         config = TorchScriptTestConfig()
-    results = run_tests(GLOBAL_TEST_REGISTRY, config)
+
+    tests = [
+        test for test in GLOBAL_TEST_REGISTRY
+        if re.match(args.filter, test.unique_name)
+    ]
+    if len(tests) == 0:
+        print(
+            f'ERROR: the provided filter {args.filter!r} does not match any tests'
+        )
+        print('The available tests are:')
+        for test in GLOBAL_TEST_REGISTRY:
+            print(test.unique_name)
+        sys.exit(1)
+    results = run_tests(tests, config)
     report_results(results)

 if __name__ == '__main__':
@@ -11,7 +11,7 @@ from torch_mlir.torchscript.annotations import annotate_args, export

 # ==============================================================================

-class Resnet18Module(torch.nn.Module):
+class ResNet18Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
         # Reset seed to make model deterministic.

@@ -25,6 +25,6 @@ class Resnet18Module(torch.nn.Module):
     def forward(self, img):
         return self.resnet.forward(img)

-@register_test_case(module_factory=lambda: Resnet18Module())
-def Resnet18Module_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: ResNet18Module())
+def ResNet18Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 3, 224, 224))
@@ -2,13 +2,18 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+import sys
 from typing import Any
+from io import StringIO
+import os
+import tempfile

 import numpy as np
 import torch
+from mlir.passmanager import PassManager

 import torch_mlir
-from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
+from npcomp.compiler.pytorch.backend import refjit
 from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
 from torch_mlir.torchscript.annotations import extract_annotations

@@ -26,9 +31,51 @@ class RefBackendTestConfig(TestConfig):

         extract_annotations(program, scripted, class_annotator)

-        mb.import_module(scripted._c, class_annotator)
-        # Lower module in place.
-        frontend_lowering.lower_object_graph(mb.module)
+        # TODO: Find a way to make each of these calls own its own
+        # "debuggable error report" situation.
+        try:
+            sys.stderr = StringIO()
+            # Import the TorchScript module to MLIR
+            mb.import_module(scripted._c, class_annotator)
+        except Exception as e:
+            raise Exception(f"""
+PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with the following diagnostics:
+{sys.stderr.getvalue()}
+""") from None
+        finally:
+            sys.stderr = sys.__stderr__
+
+        try:
+            sys.stderr = StringIO()
+            asm_for_error_report = mb.module.operation.get_asm(
+                large_elements_limit=10, enable_debug_info=True)
+            pipeline_str = "torchscript-to-npcomp-backend-pipeline"
+            # Lower module in place to make it ready for compiler backends.
+            with mb.module.context:
+                pm = PassManager.parse(pipeline_str)
+                pm.run(mb.module)
+        except Exception as e:
+            # TODO: More robust.
+            # - don't arbitrarily clutter up /tmp. When a test suite has many
+            #   tests, this can be a big disk cost (also, /tmp/ is frequently a
+            #   RAM fs, which increases worries about capacity).
+            # - don't have colliding filenames (hard to do without cluttering
+            #   up /tmp)
+            # - if we do have colliding filenames, writes should at least
+            #   avoid being racy.
+            filename = os.path.join(tempfile.gettempdir(),
+                                    scripted.original_name + '.mlir')
+            with open(filename, 'w') as f:
+                f.write(asm_for_error_report)
+            raise Exception(f"""
+NPCOMP TorchScript Object Graph IR -> NPCOMP Backend IR lowering failed with the following diagnostics:
+{sys.stderr.getvalue()}
+
+Error can be reproduced with:
+$ npcomp-opt -{pipeline_str} {filename}
+""") from None
+        finally:
+            sys.stderr = sys.__stderr__
         return self.backend.compile(mb.module)

     def run(self, artifact: Any, trace: Trace) -> Trace:
@@ -18,7 +18,7 @@ mb = torch_mlir.ModuleBuilder()
 # CHECK: %[[RESULTS:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[BOOL_TRUE]], init(%[[F_INIT]]) {
 # CHECK: ^bb0(%[[IV:.*]]: i64, %[[F_ITER:.*]]: f64):
 # CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::add" %[[F_ITER]], %[[IV]] : (f64, i64) -> f64 {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
-# CHECK: torch.prim.Loop.condition %[[BOOL_TRUE]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
+# CHECK: torch.prim.Loop.condition %[[BOOL_TRUE]], iter(%[[F_NEXT]] : f64)
 # CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
 # CHECK: return %[[RESULTS:.*]] : f64
 @mb.import_function

@@ -38,7 +38,7 @@ def prim_Loop_forlike(n: int):
 # CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
 # CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
 # CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
-# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
+# CHECK: torch.prim.Loop.condition %[[COND_ITER]], iter(%[[F_NEXT]] : f64)
 # CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
 # CHECK: return %[[RET:.*]] : f64
 @mb.import_function

@@ -57,7 +57,7 @@ def prim_Loop_whilelike(n: int):
 # CHECK: %[[RET:.*]] = torch.prim.Loop %[[ARG]], %[[TRUE]], init(%[[NONE_DEREFINED]]) {
 # CHECK: ^bb0(%[[IV:.*]]: i64, %[[X_ITER:.*]]: !torch.optional<i64>):
 # CHECK: %[[X_NEXT:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
-# CHECK: torch.prim.Loop.condition %[[TRUE]] iter(%[[X_NEXT]]) : !basicpy.BoolType, (!torch.optional<i64>)
+# CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[X_NEXT]] : !torch.optional<i64>)
 # CHECK: } : (i64, !basicpy.BoolType, !torch.optional<i64>) -> !torch.optional<i64>
 # CHECK: return %[[RET:.*]] : !torch.optional<i64>
 @mb.import_function
@@ -102,6 +102,15 @@ def prim_ListUnpack(l: typing.List[int]):
 def prim_dtype(x):
     return x.dtype

+# CHECK-LABEL: func @__torch__.prim_layout(
+# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> i64 {
+# CHECK: %[[RET:.*]] = torch.prim.layout %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> i64
+# CHECK: return %[[RET]] : i64
+@mb.import_function
+@torch.jit.script
+def prim_layout(x):
+    return x.layout
+
 # CHECK-LABEL: func @__torch__.prim_device(
 # CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !torch.Device {
 # CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !numpy.ndarray<*:!numpy.any_dtype> -> !torch.Device

@@ -111,5 +120,33 @@ def prim_dtype(x):
 def prim_device(x):
     return x.device

+# CHECK-LABEL: func @__torch__.prim_min(
+# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
+# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
+# CHECK: %[[MIN1:.*]] = torch.prim.min %[[SINGLETON]] : !basicpy.ListType -> i64
+# CHECK: %[[MIN2:.*]] = torch.prim.min %[[ARG]], %[[ARG]] : i64, i64 -> i64
+# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
+# CHECK: %[[MIN3:.*]] = torch.prim.min %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
+# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MIN1]], %[[MIN2]], %[[MIN3]] : (i64, i64, i64) -> !basicpy.TupleType
+# CHECK: return %[[RET]] : !basicpy.TupleType
+@mb.import_function
+@torch.jit.script
+def prim_min(x: int):
+    return min(x), min(x,x), min(x, x, x)
+
+# CHECK-LABEL: func @__torch__.prim_max(
+# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
+# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
+# CHECK: %[[MAX1:.*]] = torch.prim.max %[[SINGLETON]] : !basicpy.ListType -> i64
+# CHECK: %[[MAX2:.*]] = torch.prim.max %[[ARG]], %[[ARG]] : i64, i64 -> i64
+# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
+# CHECK: %[[MAX3:.*]] = torch.prim.max %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
+# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MAX1]], %[[MAX2]], %[[MAX3]] : (i64, i64, i64) -> !basicpy.TupleType
+# CHECK: return %[[RET]] : !basicpy.TupleType
+@mb.import_function
+@torch.jit.script
+def prim_max(x: int):
+    return max(x), max(x,x), max(x, x, x)
+
 mb.module.operation.print()
 print()
@@ -420,8 +420,8 @@ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
   let results = (outs);

   let assemblyFormat = [{
-    $shouldContinue `iter` `(` $iterArgs `)`
-    attr-dict `:` type($shouldContinue) `,` `(` type($iterArgs) `)`
+    $shouldContinue `,`
+    `iter` `(` ($iterArgs^ `:` type($iterArgs))? `)` attr-dict
   }];
 }

@@ -564,6 +564,15 @@ def Torch_PrimdtypeOp : Torch_Op<"prim.dtype", []> {
   }];
 }

+def Torch_PrimlayoutOp : Torch_Op<"prim.layout", []> {
+  let summary = "TorchScript prim::layout op";
+  let arguments = (ins AnyTorchTensorType:$tensor);
+  let results = (outs AnyTorchNumberType:$result);
+  let assemblyFormat = [{
+    $tensor attr-dict `:` type($tensor) `->` type($result)
+  }];
+}
+
 def Torch_PrimdeviceOp : Torch_Op<"prim.device", []> {
   let summary = "TorchScript prim::device op";
   let arguments = (ins AnyTorchTensorType:$tensor);

@@ -573,4 +582,26 @@ def Torch_PrimdeviceOp : Torch_Op<"prim.device", []> {
   }];
 }

+def Torch_PrimminOp : Torch_Op<"prim.min", []> {
+  let summary = "TorchScript prim::min op";
+  // TODO: Separate this along all the different possible signatures.
+  // At the time of this writing, there are 11.
+  // In particular, there are binary scalar versions and various list versions.
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs AnyTorchScalarType:$result);
+  let assemblyFormat = [{
+    $operands attr-dict `:` type($operands) `->` type($result)
+  }];
+}
+
+def Torch_PrimmaxOp : Torch_Op<"prim.max", []> {
+  let summary = "TorchScript prim::max op";
+  // TODO: Separate this along all the different possible signatures.
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs AnyTorchScalarType:$result);
+  let assemblyFormat = [{
+    $operands attr-dict `:` type($operands) `->` type($result)
+  }];
+}
+
 #endif // TORCH_OPS
@@ -26,6 +26,17 @@ createPrepareForGlobalizeObjectGraphPass();
 /// See the documentation on torch-globalize-object-graph for more details.
 void createGlobalizePipeline(OpPassManager &pm);

+/// Creates a pipeline that lowers the object graph IR that is produced by
+/// TorchScript import into the form expected by npcomp-verify-backend-contract.
+void createLowerObjectGraphPipeline(OpPassManager &pm);
+
+/// Creates a pipeline that lowers a flat list of funcs and global slots
+/// with the torch and aten dialects and mutable arrays and converts it to
+/// the form required by npcomp-verify-backend-contract, in particular
+/// lowering most arrays to ranked tensors of known dtype, lowering aten ops to
+/// linalg, converting torch.prim.* ops to elementary math operations.
+void createLowerToNpcompBackendPipeline(OpPassManager &pm);
+
 std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();

 std::unique_ptr<OperationPass<FuncOp>> createRefineTypesPass();
@@ -56,8 +56,9 @@ class VerifyBackendContractPass
     target.addDynamicallyLegalDialect<StandardOpsDialect>(isLegalScalarOp);
     target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);

-    // Tensor operations should go through linalg.
+    // Tensor operations should go through linalg and the tensor dialect.
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
+    target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
     // DimOp is used to query tensor sizes.
     target.addDynamicallyLegalOp<memref::DimOp>(opHasLegalTypes);
@@ -19,4 +19,9 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
   MLIRPass
   NPCOMPTorchDialect
   NPCOMPBasicpyDialect
+  NPCOMPATenPasses
+  NPCOMPNumpyPasses
+  NPCOMPATenToLinalg
+  NPCOMPATenToTCF
+  NPCOMPTCFToStd
 )
@@ -7,7 +7,15 @@
 //===----------------------------------------------------------------------===//

 #include "npcomp/Dialect/Torch/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+#include "npcomp/Backend/Common/Passes.h"
+#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
+#include "npcomp/Conversion/ATenToTCF/Passes.h"
+#include "npcomp/Conversion/TCFToStd/TCFToStd.h"
+#include "npcomp/Dialect/ATen/Transforms/Passes.h"
+#include "npcomp/Dialect/Numpy/Transforms/Passes.h"

 //===----------------------------------------------------------------------===//
 // Pass registration

@@ -23,9 +31,86 @@ void mlir::NPCOMP::registerTorchPasses() {
   mlir::PassPipelineRegistration<>(
       "torch-globalize-pipeline", "Globalization pipeline.",
       mlir::NPCOMP::Torch::createGlobalizePipeline);
+  mlir::PassPipelineRegistration<>(
+      "torchscript-to-npcomp-backend-pipeline",
+      "Pipeline lowering torch object graph to npcomp backend format.",
+      mlir::NPCOMP::Torch::createLowerObjectGraphPipeline);
+  mlir::PassPipelineRegistration<>(
+      "torch-globalized-module-to-npcomp-backend-pipeline",
+      "Pipeline lowering to npcomp backend form.",
+      mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline);
 }

 void mlir::NPCOMP::Torch::createGlobalizePipeline(OpPassManager &pm) {
   pm.addPass(createPrepareForGlobalizeObjectGraphPass());
   pm.addPass(createGlobalizeObjectGraphPass());
 }
+
+void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(OpPassManager &pm) {
+  // When we import TorchScript IR, we import their entire "compilation unit",
+  // which can contain numerous functions unrelated to the current program,
+  // which breaks torch-globalization-pipeline; for example, there can be
+  // random functions referencing types that haven't been imported
+  // as part of the root `torch.nn.Module` we imported. Those will
+  // be unreferenced private functions which symbol-dce will clean up nicely.
+  pm.addPass(createSymbolDCEPass());
+  // Globalize the program. The rest of the compiler assumes a globalized
+  // program, which makes all analyses and transforms significantly easier
+  // to write.
+  pm.addPass(createPrepareForGlobalizeObjectGraphPass());
+  pm.addPass(createGlobalizeObjectGraphPass());
+  // "lower" `torch.global_slot` ops by deleting them if unused, which we
+  // currently require because we don't have a lowering path for backends to
+  // handle them.
+  // Torch usually inserts a few unused global slots so this ends up hitting
+  // every single module even if it doesn't have any explicit slots.
+  // TODO: Support global slots in backends.
+  pm.addPass(createSymbolDCEPass());
+  // Currently, our shape inference is not powerful enough to deal with
+  // calls, so inline everything.
+  // TODO: Improve shape inference.
+  pm.addPass(createInlinerPass());
+  // Incorporate user annotations and remove signature Python-isms.
+  pm.addPass(createAdjustCallingConventionsPass());
+
+  createLowerToNpcompBackendPipeline(pm);
+}
+
+void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(OpPassManager &pm) {
+  // Recognize ATen kernels.
+  pm.addNestedPass<FuncOp>(aten::createRecognizeKernelsPass());
+
+  // Convert the bulk of the program to ranked tensors with known dtype.
+  // This is the input to the backend layer that we are aiming for.
+
+  // First, unilaterally convert public functions to tensor.
+  // The way this pass is currently written, this implies that
+  // as pipeline authors, we are restricting our users to not be able to see
+  // updates to "out params" on their public functions.
+  // This is deemed ok for now.
+  pm.addPass(Numpy::createPublicFunctionsToTensorPass());
+  // Convert the bulk of non-ABI-visible arrays to tensors.
+  pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
+  // Do shape and dtype refinement.
+  // We could do it sooner, but the pass currently doesn't have transfer
+  // functions for array ops.
+  pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
+  // Propagate to ABI return types the shape/dtype information discovered by
+  // the previous pass. Doing this is ABI-compatible for our backends.
+  pm.addPass(Numpy::createRefinePublicReturnPass());
+  // Clean up a few stray array/tensor conversion remnants.
+  pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
+
+  // Lower to TCP (+ guards) which is the input to codegen backends.
+  // Most of this should be subsumed by aten->linalg+guards conversions.
+  // (the guard generation will be automated from the linalg Op DSL).
+  pm.addNestedPass<FuncOp>(createConvertATenToLinalgPass());
+  pm.addNestedPass<FuncOp>(createConvertATenToTCFPass());
+  pm.addNestedPass<FuncOp>(createConvertTCFToStdPass());
+  pm.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
+
+  // Verify that we have lowered to the form that backends expect.
+  // This fails compilation (signalPassFailure) if the IR is not in the
+  // correct form.
+  pm.addPass(CommonBackend::createVerifyBackendContractPass());
+}
@@ -15,71 +15,6 @@ __all__ = [
     "lower_module",
 ]

-# The set of passes that lowers from a TorchScript object graph representation
-# to a module semantics where symbols correspond to dotted paths into the
-# module.
-OBJECT_GRAPH_LOWERING_PASSES = (
-    # When we import TorchScript IR, we import their entire "compilation unit",
-    # which can contain numerous functions unrelated to the current program,
-    # which breaks torch-globalization-pipeline; for example, there can be
-    # random functions referencing types that haven't been imported
-    # as part of the root `torch.nn.Module` we imported. Those will
-    # be unreferenced private functions which symbol-dce will clean up nicely.
-    "symbol-dce",
-    # Globalize the program. The rest of the compiler assumes a globalized
-    # program, which makes all analyses and transforms significantly easier
-    # to write.
-    "torch-globalize-pipeline",
-    # "lower" `torch.global_slot` ops by deleting them if unused, which we
-    # currently require because we don't have a lowering path for backends to
-    # handle them.
-    # Torch usually inserts a few unused global slots so this ends up hitting
-    # every single module even if it doesn't have any explicit slots.
-    # TODO: Support global slots in backends.
-    "symbol-dce",
-    # Currently, our shape inference is not powerful enough to deal with
-    # calls, so inline everything.
-    # TODO: Improve shape inference.
-    "inline",
-    # Incorporate user annotations and remove signature Python-isms.
-    "torch-adjust-calling-conventions",
-)
-
-TORCH_TO_TCP_PASSES = (
-    # Recognize ATen kernels.
-    "func(aten-recognize-kernels)",
-
-    # Convert the bulk of the program to ranked tensors with known dtype.
-    # This is the input to the backend layer that we are aiming for.
-
-    # First, unilaterally convert public functions to tensor.
-    # The way this pass is currently written, this implies that
-    # as pipeline authors, we are restricting our users to not be able to see
-    # updates to "out params" on their public functions.
-    # This is deemed ok for now.
-    "numpy-public-functions-to-tensor",
-    # Convert the bulk of non-ABI-visible arrays to tensors.
-    "func(numpy-array-to-tensor)",
-    # Do shape and dtype refinement.
-    # We could do it sooner, but the pass currently doesn't have transfer
-    # functions for array ops.
-    "func(torch-refine-types)",
-    # Propagate to ABI return types the shape/dtype information discovered by
-    # the previous pass. Doing this is ABI-compatible for our backends.
-    "numpy-refine-public-return",
-    # Clean up a few stray array/tensor conversion remnants.
-    "func(numpy-array-to-tensor)",
-
-    # Lower to TCP (+ guards) which is the input to codegen backends.
-    # Most of this should be subsumed by aten->linalg+guards conversions.
-    # (the guard generation will be automated from the linalg Op DSL)
-    "func(convert-aten-to-linalg)",
-    "func(convert-aten-to-tcf)",
-    "func(convert-tcf-to-std)",
-    "func(convert-elementwise-to-linalg)",
-    "npcomp-verify-backend-contract",
-)
-
 def lower_module(imported_module: Module):
     """Compiles an imported module, with a flat list of functions.

@@ -93,7 +28,7 @@ def lower_module(imported_module: Module):
     if logging.debug_enabled():
         logging.debug("Initial PyTorch IR:\n{}", imported_module)
     # Frontend.
-    pipeline_str = ",".join(TORCH_TO_TCP_PASSES)
+    pipeline_str = "torch-globalized-module-to-npcomp-backend-pipeline"
     if logging.debug_enabled():
         logging.debug("Running Torch->TCP pipeline '{}'", pipeline_str)
     pm = PassManager.parse(pipeline_str)

@@ -116,10 +51,10 @@ def lower_object_graph(imported_module: Module):
         logging.debug("Initial PyTorch object graph IR:\n{}", imported_module)

     # Object graph lowering.
-    pipeline_str = ",".join(OBJECT_GRAPH_LOWERING_PASSES)
+    pipeline_str = "torchscript-to-npcomp-backend-pipeline"
     if logging.debug_enabled():
         logging.debug(
             "Running Torch object graph lowering pipeline '{}'", pipeline_str)
     pm = PassManager.parse(pipeline_str)
     pm.run(imported_module)
-    return lower_module(imported_module)
+    return imported_module