Support multiple instances of a class in GlobalizeObjectGraph.
This happens in practice with e.g. ResNet from torchvision (multiple
instances of the same BatchNorm class).
The key observation is that for this program, and the expected set of
programs, we can convert the program to the same globalized form with a
bit more static analysis and effort to suitably monomorphize the
program. Though what we are doing here is fairly annoying to implement,
it saves any nontrivial later pass from having to do similar analyses
(or worse). E.g. shape inference would need to be object-graph aware,
mutation/lifetime analyses would have to be aware, etc. Additionally, it
would make us front-load what it means to have a !torch.nn.Module type
on an ABI boundary, which we are just not ready to handle.
I'm really, really hoping that in practice we can get away with
this, otherwise it's going to be really rough designing a representation
(and implementing everything to back it) that is convenient to transform
and gracefully scales from full object graph (in the most dynamic case)
down to a fixed set of global slots like we have here (in the most
static case, which we presume a lot of practical programs fall into).
This also involved introducing a
`torch-prepare-for-globalize-object-graph` pass that does a minimal set of
lowerings to simplify the IR into a more orthogonal and analyzable form,
and a `torch-globalize-pipeline` helper.
Recommended review order:
- updated documentation in Passes.td
- new tests in `globalize-object-graph-multiple-instances*.mlir`
- implementation of GlobalizeObjectGraph.cpp
- PrepareForGlobalizeObjectGraph.cpp + prepare-for-globalize-object-graph.mlir
- misc stuff like torch-globalize-pipeline pipeline definition.
With this, we can import, globalize, and inline resnet18 from
torchvision:
https://gist.github.com/silvasean/821586afc19b67d9fb72030b2e0adeb8
2021-03-10 12:33:21 +08:00
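To make the "fixed set of global slots" idea concrete, here is a minimal sketch in plain C++. It is not the pass itself and uses no MLIR APIs: `Module`, `globalize`, and `globalizeExample` are hypothetical names invented for illustration, and the naming of globals by attribute path is an assumption about the general shape of the result. It only shows why two instances of the same class can still end up as distinct globals.

```cpp
#include <map>
#include <memory>
#include <string>

// Hypothetical toy object graph; these types are not part of torch-mlir.
struct Module {
  std::string className;
  std::map<std::string, double> slots;                      // leaf attributes
  std::map<std::string, std::shared_ptr<Module>> children;  // submodules
};

// Flatten every slot of every instance into one global per attribute path.
// Two instances of the same class get distinct globals because their paths
// from the root differ.
void globalize(const Module &m, const std::string &prefix,
               std::map<std::string, double> &globals) {
  for (const auto &slot : m.slots)
    globals[prefix + slot.first] = slot.second;
  for (const auto &child : m.children)
    globalize(*child.second, prefix + child.first + ".", globals);
}

// Builds a root with two instances of the same (toy) BatchNorm class.
std::map<std::string, double> globalizeExample() {
  auto bn1 = std::make_shared<Module>(Module{"BatchNorm", {{"weight", 1.0}}, {}});
  auto bn2 = std::make_shared<Module>(Module{"BatchNorm", {{"weight", 2.0}}, {}});
  Module root{"Net", {}, {{"bn1", bn1}, {"bn2", bn2}}};
  std::map<std::string, double> globals;
  globalize(root, "", globals);
  return globals;
}
```

The sketch ignores the hard part described above: methods that take the instance as an argument have to be specialized per instance, and an object reachable through two different paths would be duplicated here rather than shared.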
//===- PrepareForGlobalizeObjectGraph.cpp ------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
[torch-mlir earthmoving (1/N)] C/C++ code movement.
This creates the `external/torch-mlir` directory as an
LLVM_EXTERNAL_PROJECTS-compatible project (analogous to
`iree-dialects`) and completes movement/rename of all pure MLIR C/C++
compiler code into there. The next step will be to move all the Python
code / code that links/includes PyTorch C++ code (which currently lives
in `frontends/pytorch`) into a subdirectory here.
I call this "earthmoving" because it is mostly mechanical changes and
renames. As a quick summary (we can change this down the road easily)
- C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch`
- CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet`
- preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_`
- CMake `NPCOMPFoo -> TorchMLIRFoo`
The goal of this is to create a standalone project creating a center of
mass for entry into the MLIR ecosystem from PyTorch, suitable in scope
for eventual inclusion/ownership in PyTorch. The idea is that
`external/torch-mlir` will some day be pulled out into its own
repository, and then npcomp will simply pull it in as a submodule.
Layering-wise, what lives in `torch-mlir` lowers code from PyTorch
(currently TorchScript, but TorchFX or pytorch/xla-style tracing are
possible extensions) down to what we have been calling the "Torch
backend contract" which is cleaned up IR (inlining, simplification,
conversion to value tensors, ...) entirely in the `torch` dialect. This
is the branching off point for further lowering, of which npcomp takes
one opinion (outside `torch-mlir` of course!), namely the
`TorchConversion` dialect/transforms which lower to IR suitable for IREE
and other linalg-on-tensors based lower-level compilers.
Summary of changes:
- move `{include,lib,test}/Dialect/Torch` into `torch-mlir`
- move relevant parts of CAPI into `torch-mlir`.
- leave a few things related to the `torch-mlir` Python build commented
out, which should be resolved in a subsequent change.
2021-09-10 03:24:10 +08:00
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class ConvertPrimCallMethodToCall : public OpRewritePattern<PrimCallMethodOp> {
public:
  ConvertPrimCallMethodToCall(MLIRContext *context, SymbolTable &symbolTable)
      : OpRewritePattern(context), symbolTable(symbolTable) {}
  LogicalResult matchAndRewrite(PrimCallMethodOp op,
                                PatternRewriter &rewriter) const override {
    auto classType = symbolTable.lookup<ClassTypeOp>(
        cast<NnModuleType>(op.getReceiver().getType()).getClassName());
Fix issue with unused functions in torch::jit::CompilationUnit
As described in the code comment:
```
When we import TorchScript IR, we import their entire "compilation unit",
which can contain numerous functions unrelated to the current program,
which breaks torch-globalization-pipeline; for example, there can be
random functions referencing types that haven't been imported
as part of the root `torch.nn.Module` we imported. Those will
be unreferenced private functions which symbol-dce will clean up nicely.
```
This situation is really easy to hit in jupyter notebooks, where the
same cell is evaluated multiple times. That results in the same
class name (at the Python level, e.g. class `Foo` in the top-level
main module). Internally to PyTorch, it handles this situation by
mangling in a unique number to the names of ClassType's and such. When
we import the new ClassType's, we see not just the new
torch::jit::Function's in the CompilationUnit, but, also all the old
ones, which reference ClassType's that are not reachable from the
`torch.nn.Module` that we imported.
Note: there is no way to avoid importing the whole CompilationUnit
(including these old remnants) without doing a fairly complicated call
graph reachability analysis of which functions are reachable from the
methods of the ClassType's we imported. It turns out that once we are
inside MLIR, we model visibility correctly so that `symbol-dce`
"Just Works" for this use case. That is to say, this is not a quick
hack, but rather seems like a totally palatable long-term solution.
2021-04-14 05:50:58 +08:00
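The cleanup that the message above relies on can be pictured with a small plain C++ sketch. It is only an illustration of the reachability idea behind dead-symbol elimination, not MLIR's implementation: `Symbol` and `liveSymbols` are hypothetical names, and the real pass works on MLIR symbol tables and visibility attributes rather than string maps.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy symbol table entry (not an MLIR type).
struct Symbol {
  bool isPublic;
  std::vector<std::string> uses; // names of symbols this one references
};

// Returns the names that survive: everything reachable from a public symbol.
// Private symbols that nothing live refers to are absent from the result.
std::set<std::string> liveSymbols(const std::map<std::string, Symbol> &table) {
  std::set<std::string> live;
  std::vector<std::string> worklist;
  for (const auto &entry : table)
    if (entry.second.isPublic && live.insert(entry.first).second)
      worklist.push_back(entry.first);
  while (!worklist.empty()) {
    std::string name = worklist.back();
    worklist.pop_back();
    auto it = table.find(name);
    if (it == table.end())
      continue; // referenced name has no definition in this toy table
    for (const auto &used : it->second.uses)
      if (live.insert(used).second)
        worklist.push_back(used);
  }
  return live;
}
```

In this picture, the stale functions left over from an earlier notebook cell are private and unreferenced, so they never enter the live set, while the methods of the imported module and anything they call are kept.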
    assert(classType && "malformed module -- missing ClassTypeOp");
    func::FuncOp func;
    for (auto method : classType.getOps<MethodOp>()) {
      if (method.getName() == op.getName()) {
        func = symbolTable.lookup<func::FuncOp>(method.getFunction());
        break;
      }
    }
    assert(func);
    rewriter.replaceOpWithNewOp<func::CallOp>(op, func, op->getOperands());
    return success();
  }

private:
  SymbolTable &symbolTable;
};
} // namespace

namespace {
class EraseUnusedConstantOp : public OpRewritePattern<func::ConstantOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(func::ConstantOp op,
                                PatternRewriter &rewriter) const override {
    if (op.use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};
} // namespace

namespace {
class PrepareForGlobalizeObjectGraphPass
    : public PrepareForGlobalizeObjectGraphBase<
          PrepareForGlobalizeObjectGraphPass> {
  void runOnOperation() override {

    SymbolTable symbolTable(getOperation());

    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.add<ConvertPrimCallMethodToCall>(context, symbolTable);
    func::CallIndirectOp::getCanonicalizationPatterns(patterns, context);
    patterns.add<EraseUnusedConstantOp>(context);

    // Use applyPatternsAndFoldGreedily because the CallIndirectOp folding
    // makes the ConstantOp unused, which does not work with the visitation
    // order of the dialect conversion infrastructure.
    // TODO: Do this with the dialect conversion infrastructure to avoid doing
    // folding as part of this. Or avoid folding during greedy pattern
    // application. See: https://llvm.org/PR49502
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }

    // Do a dummy full conversion to ensure that the program has been converted
    // to the form we want.
    ConversionTarget target(*context);
    target.addIllegalOp<PrimCallMethodOp>();
    target.addDynamicallyLegalOp<func::ConstantOp>(
        [](func::ConstantOp op) { return !isa<FunctionType>(op.getType()); });
    target.addIllegalOp<func::CallIndirectOp>();
    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

    RewritePatternSet dummyPatterns(context);

    if (failed(applyFullConversion(getOperation(), target,
[torch-mlir earthmoving (1/N)] C/C++ code movement.
This creates the `external/torch-mlir` directory as an
LLVM_EXTERNAL_PROJECTS-compatible project (analogous to
`iree-dialects`) and completes movement/rename of all pure MLIR C/C++
compiler code into there. The next step will be to move all the Python
code / code that links/includes PyTorch C++ code (which currently lives
in `frontends/pytorch`) into a subdirectory here.
I call this "earthmoving" because it is mostly mechanical changes and
renames. As a quick summary (we can change this down the road easily)
- C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch`
- CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet`
- preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_`
- CMake `NPCOMPFoo -> TorchMLIRFoo`
The goal of this is to create a standalone project creating a center of
mass for entry into the MLIR ecosystem from PyTorch, suitable in scope
for eventual inclusion/ownership in PyTorch. The idea is that
`external/torch-mlir` will some day be pulled out into its own
repository, and then npcomp will simply pull it in as a submodule.
Layering-wise, what lives in `torch-mlir` lowers code from PyTorch
(currently TorchScript, but TorchFX or pytorch/xla-style tracing are
possible extensions) down to what we have been calling the "Torch
backend contract" which is cleaned up IR (inlining, simplification,
conversion to value tensors, ...) entirely in the `torch` dialect. This
is the branching off point for further lowering, of which npcomp takes
one opinion (outside `torch-mlir` of course!), namely the
`TorchConversion` dialect/transforms which lower to IR suitable for IREE
and other linalg-on-tensors based lower-level compilers.
Summary of changes:
- move `{include,lib,test}/Dialect/Torch` into `torch-mlir`
- move relevant parts of CAPI into `torch-mlir`.
- leave a few things related to the `torch-mlir` Python build commented
out, which should be resolved in a subsequent change.
2021-09-10 03:24:10 +08:00
|
|
|
std::move(dummyPatterns)))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass() {
return std::make_unique<PrepareForGlobalizeObjectGraphPass>();
}