Commit Graph

1258 Commits (f85ae9c685d6c9f2d43fe60fdad2ad5d3183ec52)
 

Author SHA1 Message Date
Sean Silva 784156a998 Add `!torch.bool` type.
This finishes removing the dependence on the basicpy dialect!

Changes:
- Add `!torch.bool` type and replace use of `!basicpy.BoolType` in
  Torch-related code.
- Rename BuiltinTensorize to BackendTypeConversion since now it handles
  bool conversions (and, when we add !torch.int and !torch.float, it
  will handle those as well), and generalize the related utilities (I
  also moved them to Torch/Transforms since they aren't really part of
  Torch/IR).
  - Add `torch.to_i1` and `torch.from_i1` ops for materializations
- [cleanup] Reorganize `torch.constant.*` ops in TorchOps.td
- Remove dependency of `torch` dialect on `basicpy` dialect and also
  `std` dialect. For `std`, we use some call related ops, but the
  `torch` dialect itself never produces them (we have passes that do
  though).

This is fairly mechanical. Recommended review order:
- New stuff in Torch/IR
- New BuiltinTypeConversion files.
- Mechnical fixups elsewhere.
2021-06-16 13:22:00 -07:00
Yi Zhang 7b7c9c5d3d Add aten.relu Linalg lowering support 2021-06-16 08:18:14 -07:00
Sean Silva 3ccf6002af Add `torch.constant.int` and `torch.constant.float`.
- This removes reliance on basicpy.numeric_constant.
- Also, add OpAsmOpInterface to the `torch.constant.none` and
  `torch.constant.str` ops.
2021-06-15 15:29:42 -07:00
Sean Silva 2e850ecb72 Add !torch.str type.
- Remove dependence on `!basicpy.BytesType`.
- Add `torch.constant.str "s"` analogous to `torch.constant.none`.
2021-06-15 10:10:59 -07:00
Sean Silva 31c15cab2b Add 2021Q3 roadmap.
This also restructures the docs to a "roadmap" directory, to preserve
previous roadmaps / allow retrospective "grading" of how we did.
2021-06-15 10:05:25 -07:00
Sean Silva 92ee0fa98f Add `!torch.tuple<T1, T2>` type.
This further eliminates the need for the `basicpy` dependency.

This required adding `torch.prim.TupleConstruct` to replace
`basicpy.build_tuple`.
2021-06-15 08:15:22 -07:00
Sean Silva ea1dd1cd90 Remove a few more comments I missed in the last commit. 2021-06-14 18:18:43 -07:00
Sean Silva 6b2424512b Make C API files more consistent
- Make consistent with MLIR Core
  - Use `//` or `///` comments.
  - Use `bool` type for booleans
  - No duplicated comments in .cpp files
- Split types into separate files `{Basicpy,Numpy,Torch}Types.h`
- Add dialect prefix consistently to C API symbols. We have lots of
  similarly named types (e.g. "list" type in basicpy and torch).
2021-06-14 15:34:43 -07:00
Sean Silva db282fd1b4 Introduce native `!torch.none` type.
- Add `torch.constant.none` op to construct it (naming is chosen to be
  analogous to Torch's representation of a prim::Constant with
  NoneType, rather than using the "singleton" terminology of Basicpy).
2021-06-14 13:30:58 -07:00
Sean Silva 6b293b695d Use new "MLIR_ENABLE_BINDINGS_PYTHON" in the CI. 2021-06-10 18:06:46 -07:00
Sean Silva 81bcd7fb12 Move Torch type implementation code into TorchTypes.cpp 2021-06-10 16:46:47 -07:00
Sean Silva 0b6516c7cc Bump llvm-project to cbd0054b9eb17ec48f0702e3828209646c8f5ebd
Changes:
- MLIR_BINDINGS_PYTHON_ENABLED -> MLIR_ENABLE_BINDINGS_PYTHON
- canonicalizer constant insertion order
- EDSC is gone now
2021-06-10 16:26:45 -07:00
Yi Zhang e0ff5248fb Add TorchList type and prim::ListConstruct #218 2021-06-10 14:31:35 -07:00
Sean Silva 370e3270ab Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tnesor type for modeling unknown dtypes.  The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:

```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```

This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".

At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.

Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
  creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
  touch -- we need to sort out the situation with !basicpy.BoolType
  there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
  semantics. We currently require this, as our backend contract does not
  currently allow us to even model the non-value-semantic case. Before,
  the value-semantic assumption was randomly injected in the middle of
  the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
  RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
  `!torch.vtensor` to `tensor` and use the dialect conversion infra.
  The overall conversion pipeline is set up following the best practices
  of the "Type Conversions the Not-So-Hard Way" talk. This required
  introducing `torch-func-builtin-tensorize` and
  `torch-finalizing-builtin-tensorize` passes analogous to the upstream
  bufferization passes with the corresponding names (mostly just
  copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
  lowering to std later in the pipeline, so we are gradually lessening
  our reliance on random std constant folding before we get to that
  point.

Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
  - Frontend changes.
  - Pass changes/additions in `Torch/Transforms` and `Conversion/`
2021-06-10 10:56:48 -07:00
Sean Silva b7b7fd4959 Rewrite error reporting of e2e tests.
This now gives [much nicer output](https://gist.github.com/silvasean/f048e0f37b04542dae6469b86802bb3e).
Embarrassingly, we previously couldn't even report failures for two
different tests, and weren't able to report on compilation failures
(besides just crashing).
2021-05-20 11:28:20 -07:00
Sean Silva d66e8fe1f8 Get simple quantized model importing.
This is enough to import the program and get it through the compilation
pipeline. It of course fails at the VerifyBackendContract pass since
there is a lot missing, but the final IR for a simple quantized MLP is
looking pretty decent already:
[IR](https://gist.github.com/silvasean/f76bccd76e9b193d396cfb2f9a11f54d)

Main changes:
- Add support for importing torch quantized tensors, including
  `torch.per_tensor_affine.create` op and `!torch.qint8` element type.
- Add support for importing `LinearPackedParamsBase` (basically a weight
  + optional bias, but requires `torch.linear_params.create` op +
  `!torch.LinearParams` type to model it). This was less painful than I
  expected, as it has the necessary methods to opaquely unpack itself. I
  factored things so it should be easy to extend to other custom classes
  like `ConvPackedParamsBase`.
- Add minimal boilerplate for importing `quantized::*` ops, with
  `quantized::linear` being a motivating example.
- Add e2e test with simple quantized MLP (courtesy of @phoenix-meadowlark).

This is somewhat of an abuse of `!numpy.ndarray` / `tensor`, as
really the proper semantics of `!torch.qint8` dtype on a Torch tensor is
"check the quantizer object of the tensor for side data (scale/offset,
possibly per-channel) that defines the full semantics of the tensor". We
don't have any such notion of "side data" for `!numpy.ndarray` /
`tensor`, let alone anything that would have the associated behavior of
keying off the dtype to determine if the side data is present.
This will be fixed by a proper `!torch.tensor` type.
2021-05-20 11:28:20 -07:00
Sean Silva 0c89296075 Shore up error reporting for TorchScript import.
This code was not exception safe -- it would leave an operation
unattached to anything, which breaks MLIR's C++ data structure
invariants (e.g. it cannot safely erase ops).

Also, print out both the exception and any diagnostics, since they can
both contain useful information.
2021-05-20 11:28:20 -07:00
Sean Silva d50ea8d31e Improve diagnostic handler
It wasn't printing notes or putting the "error:" in front.
2021-05-20 11:28:20 -07:00
Sean Silva 2453805f7f Bump llvm-project to 35454268cf93f5561439980d6baeb27a874a380c 2021-05-19 14:00:38 -07:00
Sean Silva 2efda323ff Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical do do it in sane smaller
pieces.

This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).

Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
  imported as `torch.somens.someunqualname.someoverloadname` (skip the
  last dotted part if the overload name is empty), OR, if we don't have
  such an op registered, it is imported as
  `torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
  - The addition of the "overload name" is a critical element here, as
    the `(ns,unqual,overload)` triple is unique, which solves a lot of
    problems we were having.
  - This involves having separate MLIR ops for the `trailing_` and
    `.out` variants and all the different overloads. This seemed
    necessary, because the set of overloads is so wild and varied and
    unstructured. The previous design was leaning into some underlying
    structure that just isn't there -- the default situation is
    the "random overload that we want to manage on the MLIR side",
    rather than that being an exception. E.g.  `aten::ne` (not-equal)
    has 21 overloads, only 4 of which are c10 dispatcher ops see
    [gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1),
    and the "out" variant is really called `.Tensor_out` instead of
    `.out` as it frequently is for other ops.
  - Rationale for all being in `torch` namespace: the set of operators
    are so varied and unstructured that "dialect per namespace"
    doesn't result in anything resembling the typical MLIR dialect
    boundary expectations. We could maybe draw the boundary at
    dispatcher ops vs non-dispatcher ops, but that doesn't seem to
    really result in very much useful structure at this point in time.
  - Note: within the torch operator registry, we effectively have a
    mini-basicpy subdialect (already type-resolved), which is reasonably
    structured.
  - The existing Torch op interfaces are also removed -- now that we
    track the overload name, we can losslessly find the original
    operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
  `ReduceOpVariantsPass` that keys off certain traits (and perhaps
  eventually interfaces) to reduce variants of ops to a smaller set,
  ideally operating on immutable tensors and using surrounding ops to
  model the mutability/aliasing aspects.
  - Note: `torch.ns.unqual.overload` ops allow both immutable and
    mutable tensors (unlike the previous hard distinction in the common
    case). This is a premonition for a future change that will introduce a
    bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supercede the existing
  "ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supercedes `torch_signature_ods_gen.py`.
  It should look somewhat familiar, but the benefit of hindsight has
  allowed a lot of simplifications.

The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like as a natural result of
various future changes we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).

Recommended review order:
- Start at some of the new import IR, e.g. in
  `frontends/pytorch/test/node_import/prim.py`,
  `frontends/pytorch/test/acap_export/test_export_add3.py`, and other
  tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
  and associated generated files:
  - `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
  - `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
  traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
  `frontends/pytorch/csrc/builder`. Probably most interesting is the new
  code in `torch_to_mlir_utils.cpp` that has the logic to create the
  `torch.operator` ops or `torch.ns.unqual.overload` ops.

This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
2021-05-19 13:37:39 -07:00
Sean Silva 01baa39781 Bump llvm-project to 12874e93a15219ccfaff42a0536b2b5368c6f304 2021-05-17 16:25:56 -07:00
Sean Silva 099bb7e29b Add -pass-pipeline-crash-reproducer to npcomp-opt alias. 2021-05-10 18:06:16 -07:00
Sean Silva 45ba5fac6c Bump llvm-project to 6d263b6f1c97fe6c45c75443e7daf6cd0c1c4222
Changes:
- representation of arg attributes on functions changed
2021-05-10 18:06:15 -07:00
Sean Silva 133bdf4b31 [cleanup] Add materializer for basicpy.singleton
This allows the canonicalizer to coalesce it like other constants.
2021-05-03 09:54:44 -07:00
Sean Silva 3d08c83580 Add flatten op recognition + shape refinement.
This op has complex aliasing semantics, so it is kept mutable for now.

With this, we reduce ResNet18 to a single BB with all aten operators
having rank + dtype:
https://gist.github.com/silvasean/2fcb1c6e4d4ae27461204a43ae9c5031
2021-05-03 09:54:44 -07:00
Sean Silva 122cae2ee3 Add aten::len.t, aten::size, and aten::gt.int primitive ops
Also add some canonicalizations that finally reduce ResNet down to a
single block.
2021-04-30 10:57:02 -07:00
Sean Silva ec6d06aa86 Add some more ResNet ops.
- aten::relu_, aten::max_pool2d, aten::adaptive_avg_pool2d, aten::batch_norm, aten::conv2d

No aten-to-linalg conversion for the latter ones, as they are fairly
substantial. At this point, I'm trying to get shape inference and stuff
working for them and the IR cleaned up.
2021-04-30 10:57:02 -07:00
Sean Silva 9257457d8a Add AllowsTypeRefinement trait and use it to improve RefineTypes
This trait lets us model the semantics of various aten/torch/numpy ops
that are insensitive to type refinements. This replaces
hardcoded/inconsistent checks for this property.

To show usage of this new trait, we fix up some old uses, and improve
RefineTypes to be smarter about rewriting with this trait.
2021-04-30 10:57:02 -07:00
Sean Silva 1c832604d2 Remove old aten-to-std / ATenLowering pass.
It was confusing now that we have `convert-aten-to-std`.
2021-04-30 10:57:02 -07:00
Sean Silva 55c3cc6624 Add recognition/folder/lowering for aten::__is__, aten::ne.int, and aten::dim
Interestingly, TorchScript has its own op (`torch::jit::Operator`)
registry separate from the dispatcher (it is a superset of the
dispatcher).

This is where the "prim" ops and some "aten" ops (that should probably
be renamed to "prim") live. In particular, `aten::__is__` is in that
latter category of "aten but really prim". This registry is also the
source of truth for what the TorchScript interpreter calls into when it
executes.

The bulk of the "not part of the dispatcher" ops live in
09feb5f579/torch/csrc/jit/runtime/register_prim_ops.cpp (L82)

And the registry itself lives in:
09feb5f579/torch/csrc/jit/runtime/operator.cpp (L196)

This fold further reduces the IR of ResNet by folding away some
more not-taken branches. These not-taken branches in ResNet require
first-class handling of the list type which we don't yet have on any
backend.
2021-04-30 10:57:02 -07:00
Sean Silva 7eb36b4ae7 Constant fold through basicpy.bool_cast.
This is the start of a push to getting ResNet running.

This involves throwing in the towel on an O0 pipelinie for now. See note
in the code. We keep an options struct with `optimize` flag, but it
default to true for now.
2021-04-30 10:57:02 -07:00
Sean Silva fb5f149e04 Reformat Passes.cpp and remove torch-globalize-pipeline.
The pipeline is subsumed by our lowering pipelines.
2021-04-30 10:57:02 -07:00
River Riddle 4678a7fedd Refactor RefineTypes to use the upstream ForwardDataFlowAnalysis engine
This removes the need for defining all of the custom propagation logic,
and also adds support for propagating value knowledge across branches,
through regions, and across calls.
2021-04-27 13:17:56 -07:00
Sean Silva 642482429c Bump llvm-project to 12011b5217929ef8a56c2099c6f3233934ea4fbc
- Rename FrozenRewritePatternList -> FrozenRewritePatternSet
2021-04-27 13:12:33 -07:00
Sean Silva 179105ca3e Add basic MLP's to the e2e curriculum.
These tests pass on the reference backend.

- Add aten.linear op + shape xfer function + ATen->Linalg lowering.
 - Note: this needs to be more automated, and needs to cover more cases.
 - Current not implemented caveats:
  - size-1 broadcasting for bias vector (either static-size-1 or ? case)
  - higher-rank aten.linear ops (not produced by torch.nn.Linear though)
  - type promotion (still don't even know the exact rules here)
- Add folder for torch.derefine op. Now the inliner can clean it up as
  it inlines. (call boundaries are a main place we need to insert
  torch.derefine) This is brittle -- the other important case is control
  flow which will need to be handled via an extension to
  RefineTypes.cpp (as will more robust call handling). River has an
  in-flight patch to update it to the new dataflow framework so I didn't
  want to do anything intrusive here.
    - Also adjust torch.derefine syntax to use the keyword `to` instead of
      `->`, as most type-only, cast-like ops do.
2021-04-27 12:18:54 -07:00
Sean Silva 9ba77c6e13 Add InlineGlobalSlots pass.
This inlines global slots if possible. This allows them to participate
in folding, canonicalization, shape inference, etc.

Example use cases:
- inlining weights and biases that are readonly during inference
- inlining the "training" bool to allow stuff to fold away

For training use cases (especially internal training loop), we will need
something smarter to get good performance. That would look like an "SSA
formation" which promotes the global slots to tensors in the program,
flushing them back to the slots at the minimal number of necessary
places. We might want to let backends do that transformation though.
This also interacts with shape inference (type bounds on the slots to
even lower them to backends in the first place).
2021-04-27 12:18:54 -07:00
Sean Silva b1c49ae648 Move GlobalizeObjectGraph tests to their own directory 2021-04-27 12:18:54 -07:00
Sean Silva 3a890aa26c Miscellaneous changes while trying to work on ResNet18
- Move frontend lowering pipelines to c++ (this helps with reproducing
  failures in npcomp-opt)
- Add debugging printouts when compilation fails on RefBackendTestConfig

The experience now when a test fails during MLIR lowering is now like this:
```
NPCOMP TorchScript Object Graph IR -> NPCOMP Backend IR lowering failed with the following diagnostics:
failed to legalize operation 'torch.global_slot'
Module does not conform to npcomp's backend contract. See dialect conversion legality information above.

Error can be reproduced with:
$ npcomp-opt -torchscript-to-npcomp-backend-pipeline /tmp/ResNet18Module.mlir
```

And when TorchScript->MLIR import fails it looks like this:
```
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with the following diagnostics:
unhandled prim operation: %18 : int = prim::min(%17) # /usr/local/google/home/silvasean/.local/lib/python3.9/site-packages/torch/nn/functional.py:4532:4
```

Also,
- Add `--filter=<regex>` to e2e test harness to filter tests.
- Add a few prim ops that were needed to import ResNet18
- Fix torch.prim.Loop.condition assemblyFormat (it previously would not
  round-trip in the case of no loop-carried variables)
2021-04-27 11:51:11 -07:00
Sean Silva 8f96901943 Add vision models (resnet18 to start).
Also,
- improve error reporting of e2e framework.
2021-04-27 11:51:11 -07:00
Sean Silva 390d39a96c
README: mention installing nightly pytorch. 2021-04-27 11:37:09 -07:00
Sean Silva 544cb4ef54 Bump llvm-project to 484b6648fdd4b104eaf7a2504dd07b60af2c9f8d
- add_mlir_doc arg order
- fix some dependent dialects on passes that were now causing errors
- "encoding" attribute on mlirRankedTensorTypeGetChecked
2021-04-22 18:12:55 -07:00
Sean Silva b8ad0189ac Add comment about relative import.
This file needs to adopt best practices for how to structure python
"main"'s, but I don't know how to do that yet.
2021-04-20 12:02:34 -07:00
Sean Silva 39d50ccf0d Add end-to-end testing framework for TorchScript.
The E2E tests can be run with
```
npcpy frontends/pytorch/e2e_testing/torchscript/main.py
```

This commit adds a couple items supporting that end, including new sugar
for annotations (no more raw use of ClassAnnotator!).

Recommended review order:

1. `frontends/pytorch/e2e_testing/torchscript/main.py` for
   the harness + `basic.py` in that directory for examples of tests.
2. Annotation sugar in `frontends/pytorch/python/torch_mlir/torchscript/annotations.py`
   and unittest in `frontends/pytorch/test/ivalue_import/annotations/sugar.py`
3. Global test registry / sugar in
   `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/registry.py`
4. `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py`
   for the meat of the testing framework (start at `run_tests`), and
   looking at the backend configs in
   `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs`
   for examples of backends. This is likely the bulk of review time.
5. Unit tests of the framework logic in `frontends/pytorch/test/torchscript_e2e_test`

There's TODO's scattered throughout, but this seems functional enough to
start pulling stuff into and kicking the tires. A few missing pieces:

1. Marking test expected pass/fail per backend.
2. Figuring out how best to fit this into dev workflows.
3. IREE TestConfig.

Also, forgive this Python newbie... Any advice on Python code structure
/ library design would be much appreciated.
2021-04-20 12:00:35 -07:00
Sean Silva fef1733e12 Fix issue with unused functions in torch::jit::CompilationUnit
As described in the code comment:

```
When we import TorchScript IR, we import their entire "compilation unit",
which can contain numerous functions unrelated to the current program,
which breaks torch-globalization-pipeline; for example, there can be
random functions referencing types that haven't been imported
as part of the root `torch.nn.Module` we imported. Those will
be unreferenced private functions which symbol-dce will clean up nicely.
```

This situation is really easy to hit in jupyter notebooks, where the
same cell is evaluated multiple times. That results in the same
class name (at the Python level, e.g. class `Foo` in the top-level
main module). Internally to PyTorch, it handles this situation by
mangling in a unique number to the names of ClassType's and such. When
we import the new ClassType's, we see not just the new
torch::jit::Function's in the CompilationUnit, but, also all the old
ones, which reference ClassType's that are not reachable from the
`torch.nn.Module` that we imported.

Note: there is no way to avoid importing the whole CompilationUnit
(including these old remnants) without doing a fairly complicated call
graph reachability analysis of which functions are reachable from the
methods of the ClassType's we imported. It turns out that once we are
inside MLIR, we model visibility correctly so that `symbol-dce`
"Just Works" for this use case. That is to say, this is not a quick
hack, but rather seems like a totally palatable long-term solution.
2021-04-20 12:00:35 -07:00
Sean Silva c4123d4d4d Add npcomp-verify-backend-contract pass.
This pass verifies that a given module satisfies the contract that we
have for backends. This is phrased as an "allowlist", because we want to
keep this interface tight. Also, this gives much better diagnostics than
a backend randomly crashing or failing to compile would (though they
could still be improved).

This was especially painful because if we had
`tensor<?x!numpy.any_dtype>` slip through, at some point RefBackend
would convert it to a memref type and trip the "verify type invariants"
assertion which gives no location or anything and crashed the process,
which was very unpleasant.

We implement this with the dialect conversion framework, which works
reasonably well and was quick to put together and familiar, but is still
very "op oriented". We probably want to make this hand-rolled
eventually, especially the error reporting (the most useful kind of
error for a dialect conversion user is not necessarily the best for this
use case). Also, in production, these error will go to users, and need
to be surfaced carefully such as "the compiler needs a type annotation
on this function parameter" which in general requires some special
analysis, wordsmithing, and overall awareness of the e2e use case (such
as how much we can lean into certain source locations) to provide a
meaningful user-level diagnostic.

Also, add `inline` to the current frontend lowering pass pipeline to
allow slightly more complicated programs that otherwise would fail on
shape inference.
2021-04-20 12:00:35 -07:00
Sean Silva f5dfa02523 Add `aten.mm` to linalg lowering.
This is our first op with error semantics, and stresses the system.

There are a few design notes of special interest:
- RefineTypes.cpp's note about shape inference in the presence of code
  that dynamically produces and error, and it is provable statically.
- ATenToLinalg.cpp's notes about future automation of the ATen->linalg
  path.
- The notes in Passes.td about using low-tech `std.assert` ops instead
  of `shape.assuming`.

Note: Doesn't work on IREE yet due to the `std.assert` op (needs to be
lowered to `vm.fail` on the IREE side).
2021-04-16 12:03:31 -07:00
Sean Silva 28a0f02746 Add support for compiling through IREE.
Recommended review order:
- Changes in frontends/pytorch/examples/
- Changes in python/npcomp/compiler/pytorch/backend/
- Boilerplate for the `npcomp-iree-backend-lower-linkage` pass.

This change separates out a
`npcomp.compiler.pytorch.backend.frontend_lowering` module that does the
common lowering for all backends. The individual compiler backends
`npcomp.compiler.pytorch.backend.{refjit,iree}` now accept a loosely
defined "TCP + scalar code" IR mix that will be formalized in the
future as the interface to codegen backends.

This also required adding a small pass
`npcomp-iree-backend-lower-linkage` which adds `iree.module.export` onto
functions, and layering that into the frontend flow. The pass doesn't
require a C++-level dependency on IREE, which is nice for now. TBD how
we are going to handle lists (we hope we can get away with sneakerneting
some td files and relying on loose IR compatibility).

Running through IREE requires the ability to import `iree.compiler` and
`iree.runtime`, which can be obtained as follows:
```
python3 -m pip install iree-compiler-snapshot iree-runtime-snapshot -f https://github.com/google/iree/releases/tag/snapshot-20210406.200
PYTHONPATH="${PYTHONPATH}:${MY_IREE_BUILD}/bindings/python/"
```

This patch makes it painfully clear that we don't have any e2e testing
harness to really plug into, and also don't have a usable Python API to
our compiler stack (something usable in a jupyter notebook).
That will be addressed in subsequent commits. We've been flying by the
seat of our pants with this `examples` directory that isn't subject to
any kind of testing or real usability concerns.
2021-04-09 13:15:07 -07:00
Aaron J Arthurs f9d9518f6e Declare TCP dialect dependency in TCFToTCP conversion 2021-04-07 14:23:56 -07:00
Sean Silva 2ab62aec12 MILESTONE: TorchScript unary tanh runs on RefBackend
This revamps the TORCH_TO_TCF_PASSES to reflect the new layering that we
are doing in the compiler. See comments there for the layering.

Also adds `frontends/pytorch/examples/torchscript_tanh_e2e.py` as an
"example". E2E testing story TBD (want to get IREE working first).
2021-04-07 11:06:34 -07:00
Sean Silva 927546b3c5 Add RefinePublicReturn pass.
This pass allows shape information to be propagated to return types,
which is nontrivial and cannot be cleanly put anywhere else as it
changes the public ABI, which is a concern that we want to keep
concentrated in one place.
2021-04-07 11:06:34 -07:00