Commit Graph

46 Commits (8cad02f87e3cc6a80b9b058d86d4d66ec2cb9c25)

Author SHA1 Message Date
武家伟 99fb4c8637
Add folder for ToF64Op and FromF64Op (#1257) 2022-08-22 09:49:39 +08:00
Sean Silva 57681f7947 Iteratively run the main simplification pipeline.
This introduces a new pass LowerToBackendContract (better name very
welcome) which performs the bulk of the simplifications that we do,
such as
- shape refinement
- dtype refinement
- maximizing value semantics
- inlining global slots
- decomposing complex ops

The key difference from before is that it iterates the set of
transformations, which can help to break a number of "catch-22" issues
where one simplification depends on another, the latest example being
here:
https://github.com/llvm/torch-mlir/issues/1131

This also exposed that RefineTypes was sometimes crashing/asserting for
certain inputs. This commit hardens it a bit.
2022-08-17 14:54:33 -07:00
Yan Xu 9be8997536
Revert "add native_dropout and related ops pattern (#1211)" (#1230)
This reverts commit c935795086.
2022-08-17 13:48:10 +08:00
Yan Xu c935795086
add native_dropout and related ops pattern (#1211) 2022-08-15 09:28:47 +08:00
Ramana Radhakrishnan 738f4fe96a
Rename TorchToStd pass as TorchToArith (#1163)
All the converters in this pass appear to create ops from the
arith dialect. Hence the full rename.

Fix GH Issue #409.
2022-08-10 20:12:51 +01:00
Sean Silva 504de5e701 Rework how global slot initializers work.
Rather than a per-global-slot initializer region, we now have one for
the whole module. For example, it might look like this:

```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
  %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
  %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
  torch.initialize.global_slots [
    @tensor(%0 : !torch.tensor)
    @list(%1 : !torch.list<tensor>)
  ]
}
```

This new structure allows GlobalizeObjectGraph to create the initializer in a
much simpler way, avoiding the need to reason about whether different slots
alias each other. Reasoning about whether slots alias each other now is the
responsibility of InlineGlobalSlots, which has to do a much more complicated
analysis, implemented using MLIR's dataflow analysis framework.

Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to GlobalizeObjectGraph pass.
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
  - Moving torchMlirAdjustStaticInformation for sharing with C++ code.
  - EraseModuleInitializer pass

To make this a bit nicer, it would be good to have a `torch.module` op
with an initializer region attached. That would be more invasive though.

This change has highlighted certain aspects of our project layering
which are worth calling out. None of our backends can handle global
slots, so we enforce that there are no global slots before backend
lowering. At an earlier stage in the project, we had aspirations of
transparently handling mutable global state and such, but for reasons
described below, that is no longer a goal. So really global slots should
be seen as a progressive lowering step as part of inlining all the
IValue's in the original program (GlobalizeObjectGraph is also one such
step).

Over time, with insights from work like IREE-JAX, it has become clear
that there isn't a reliable programming model we can compile for users
where we just transparently handle mutable global state (and some other
things, like lists and dictionaries). There is a need for an "outer
program" that orchestrates more restricted subroutines of the kind we
can handle in our compile flow here. The benefit of that is that it
decouples considerations like shapes, dtypes, etc. from the program
constructs used in the outer program. As long as the outer program can
efficiently invoke (pipelining/async/etc.) high-performance
data-parallel numerical subroutines of the kind we compile in our flow
here, then there is a complete programming model. This is also
consistent with the direction of upstream PyTorch which is becoming more
tracing-based (which inherently loses a lot of program structure, which
then has to be applied back with an "outer program" orchestrating the
traced subroutines).
2022-08-08 18:12:06 -07:00
武家伟 76c976682c
[MHLO] Support for dynamic shape in basic op conversion by introducing CHLO dialect (#1123)
* [MHLO] Support for dynamic shape in basic op conversion by introducing CHLO dialect
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>

* [MHLO] Support I32 as shape tensor dtype

* [NFC] Add a 'TODO' annotation
2022-08-02 12:53:24 +08:00
Ramiro Leal-Cavazos f271e6a88c
Add verifiers for ToBuiltinTensorOp and FromBuiltinTensorOp (#1089)
This commit adds verifiers to the ops `ToBuiltinTensorOp` and
`FromBuiltinTensorOp` that make sure that the input and output have
the same shape and data type.
2022-07-21 21:41:45 +00:00
Sean Silva c0ef192865
Improve error message
The unknown dtype case can come from RefineTypes.
2022-07-21 13:52:24 -07:00
Ziheng Jiang c61c99e887
[MHLO] Init MHLO integration. (#1083)
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
2022-07-20 16:18:16 -07:00
Ashay Rane e06ee08506
torch: [nfc] use `WalkResult::isInterrupted()` instead of booleans (#1081)
An upstream MLIR bug (that was recently fixed) caused the result to be
ignored for Region- and Block-visitor functions.  Now that the bug is
fixed, we don't need an auxiliary boolean to track whether the visitor
function has succeeded.
2022-07-19 10:17:57 -07:00
Tanyo Kwok 143a7bcb76
[MLIR][TORCH] Add folder for torch_c.from_i64 & torch_c.to_i64 (#933)
* [MLIR][TORCH] Add folder for torch_c.from_i64 & torch_c.to_i64

* add unit tests for each individual fold

* fix failure of NumelZeroRankModule & TestMultipleTensorAndPrimitiveTypesReturn
2022-06-24 09:34:39 +08:00
Maksim Levental 829717c96e
Bump LLVM (#958) 2022-06-22 22:23:46 -05:00
Prateek Gupta e1db318a3c [TORCH][MLIR]Add lowering for control flow operations.
1. This commit adds lowering of "while-like" prim loop to scf.while
operation.
2. Adds lowering of "for-like" prim loops to scf.for operation.

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>
2022-04-29 16:25:58 +05:30
Ashay Rane 9208bf0eb6
llvm: bump tag to e1318078 (#781)
The updated LLVM code includes a patch to create bfloat16 array
attributes, thus enabling a different patch to torch-mlir to flesh out
support for the bfloat16 type.
2022-04-26 12:27:51 -07:00
Sean Silva e7721fb784 Fix error message.
RefineTypes doesn't handle shape refinement anymore.
2022-04-07 14:46:44 -07:00
Ahmed Taei f9d34596e8 [NFC] Split BackendTypeConversion -> (BackendTypeConversion, BackendTypeConversionPasses) 2022-03-22 13:56:18 -07:00
Vigilans 63fb1e5aad Bump LLVM at 8361c5da30588d3d4a48eae648f53be1feb5cfad 2022-03-18 13:16:14 -04:00
Sean Silva a5fe0cf063 Introduce new shape library design.
See the documentation in `docs/shape_lib.md` and
`docs/adding_a_shape_function.md` for an overview of the system.

This completely overhauls how we represent shape functions. In
particular, RefineTypes does not infer shapes anymore (only dtypes).
Shape functions are now written in (TorchScript'able) Python.

Recommended review order:

1. Read `docs/shape_lib.md` and `docs/adding_a_shape_function.md`.
1. Code and tests for ReifyShapeCalculations, DropShapeCalculations.
1. Code and tests for SimplifyShapeCalculations.
1. shape_lib_gen.py
1. Code and tests for new RefineTypes pass.
1. Random folders/canonicalizers in TorchOps.cpp and associated test in
   `canonicalize.mlir`.
1. New ReadOnly trait inferred from the registry.
1. Any miscellaneous remaining stuff.

Example `-print-ir-after-all` for ElementwiseUnaryModule:
[IR lowering dump](https://gist.github.com/silvasean/e4dc8cbc8d00aac7819602e3cbd8e212).

Example `-print-ir-after-all` for ElementwiseBinaryModule:
[IR lowering dump](https://gist.github.com/silvasean/daf6860ecced732af3568af6b1899113).
2022-03-15 12:41:58 -07:00
Vivek Khandelwal 1a2a9e066f [MLIR][TORCH] Add TorchToTMTensor pass
This pass is added to lower ops, which can not be lowered
via the TorchToLinalg pass, such as `torch.bincount` op.
This pass also uses torch-mlir's TMTensor Dialect to lower the
complex ops.

Also add torch.bincount op lowering with the help of TMTensor dialect

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-03-08 22:52:34 +05:30
Nirvedh f8cb32faf0 LLVM bump
Major changes: opTrait changed to Trait, selectOp moved to arith dialect
assertOp moved to cf dialect
2022-02-16 15:28:13 -05:00
Yi Zhang 0cb216a1ad [Torch][Linalg] Add basic support for RNG
This PR include the following pieces:
- Add torch `Generator` type. `Generator` type is converted to i64 in
refbackend type converter.
- Add seed managment support for the default global generator.
`torch_c.getNextSeed` op is used to get the seed. On refbackend, the
`torch_c.getNextSeed` is lowered to load/store from [0] of global
variable `default_generator` memref<i64> in `InsertRngGlobals` pass.
- Add `aten.uniform_` and testing as an example op for RNG ops. Add
`torch.pseudo.aten.uniform` op. It has the same operands and return as
the `aten.uniform_` from the op registry except for value semantics.
2022-01-31 18:56:42 -05:00
stephenneuendorffer 52ed3313b4
Bump LLVM to 84fe34a0b7fdd7bbf179981d1583693d5d5ec68b (#544)
* external/llvm-project 881ff4e4ebe8...84fe34a0b7fd (466):
  > [MLIR] Workaround for python detection problems.
2022-01-27 17:21:09 -08:00
stephenneuendorffer 3fd9b7789e
Bump LLVM to 881ff4e4ebe8cc0cc045c7c167cffb01f94f27f8 (#539) 2022-01-25 22:16:30 -08:00
dan 3745f54489 Update external/llvm-project
- Add `qualified` to ods because of
https://reviews.llvm.org/D113873 and https://reviews.llvm.org/D116905
- Needed to revert https://github.com/llvm/torch-mlir/pull/520 as it
was based on an old torch version.
https://github.com/llvm/torch-mlir/pull/527 will bring this back with
a better design.
- Change ConvertAtenCatOp to use more accurate tensor shape info and
as much static info as possible to pass `tensor.insert_slice`
verification code added by https://reviews.llvm.org/D114715
- Other minor fixes
2022-01-18 13:25:42 -05:00
Anup Gangwar abd61b4974 * Workaround for Issue 521, remove createTosaToStandard from Passes.cpp and
disable ElementwisePowModule_basic
* Update nll_loss_forward to align to the change in PyTorch

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>
2022-01-12 14:30:58 -06:00
Anup Gangwar d69d29b7a6 * [tosa] Support for AtenPowTensorScalarOp with constant Scalar as input
Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>
2022-01-11 22:55:54 -05:00
Yi Zhang 7cf7b91664 [MLIR][TORCH] Fix tensor literal int elem type to be signless
The element type of tensor literal should be signless when converted to
builtin tensor types.
2022-01-07 16:34:24 -05:00
Suraj Sudhir 829cf8afc3
[tosa] Implement Argmax support (#485)
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
2021-12-15 11:01:01 -08:00
Suraj Sudhir 1251c186b5 [tosa] Add TosaMakeBroadcastable pass to torch-to-tosa pipeline.
Fixes broken e2e test ElementwiseAddModule_basic

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
2021-11-30 13:26:57 -08:00
Yi Zhang 53733933a4 Update llvm upstream to 0b17336f793108a7b10c3fa913039144ef1d0f61
Update AsmPrinter/Parser and MatchAndRewrite
2021-11-16 13:04:51 -05:00
Yi Zhang 0902438882 Update llvm-project to a54f4eae0e1d0ef5adccdcf9f6c2b518dc1101aa
This brings in https://reviews.llvm.org/D110797. PRs that are in
progress will need to use scripts provided by
https://llvm.discourse.group/t/psa-removed-arithmetic-ops-from-standard/4455.
2021-10-18 13:36:42 -04:00
Sean Silva 0c5c84d63d Add a basic TOSA E2E backend.
We lower through linalg-on-tensors and use RefBackend to run it.
This adds enough support for a "tanh" op. Adding more ops should be
fairly mechanical now that things are wired up. Run with:
```
./tools/torchscript_e2e_test.sh -c tosa
```

The backend structure is very similar to linalg-on-tensors based E2E
backends and is a nice parallel (see `tosa_backend.py`). Actually, this
forced a nice refactoring to the layering here. We removed
`torchscript-module-to-linalg-on-tensors-backend-pipeline` and instead
require separately running
```
torchscript-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline
```
This highlights the step that lowers to the "torch backend contract"
of cleaned up `torch` dialect ops is a critical step in the lowering.
Going forward, that is the key load-bearing contract of the torch-mlir
project, not the linalg-on-tensors backend contract.

Recommended review order:
- `TorchToTosa.cpp` / `TorchToTosa/basic.mlir`
- `python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py` and
  the new `utils.py` file there.
- `python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py` and
  `abc.py` in that directory for the TOSA backend e2e interface.
- other misc mechanical changes
2021-10-08 09:59:45 -07:00
Sean Silva 5b6902e31c Dual license the torch-mlir project.
This commit (with approval from all contributors) dual licenses
the torch-mlir project under both the standard LLVM license and the
standard PyTorch license. This will facilitate moving code between
torch-mlir and the two upstream projects.

The standard file comment is now:

```
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
```

See `LICENSE` in the project root for the terms of both licenses.
2021-10-01 10:46:08 -07:00
Ramiro Leal-Cavazos b59f2cb673
Implement the lazytensor package (#331)
Implement the `lazytensor` python package for converting
lazy computations captured by the Lazy Tensor Core into MLIR.
This PR also fixes a few things with `torchfx` and its example
2021-09-28 17:25:06 -07:00
Sean Silva 4fad753073 Move external/torch-mlir to the root of the repo. 2021-09-27 17:11:08 -07:00
Sean Silva a99cbeeb7e Move TorchConversion dialect and TorchTo* into torch-mlir 2021-09-23 21:39:31 -07:00
Sean Silva 2213584c4f VerifyBackendContract -> VerifyLinalgOnTensorsBackendContract
This moves it into TorchConversion since it is only needed there.

This removes the Backend/ directory.
2021-09-23 21:39:31 -07:00
Yi Zhang 603e068e45 E2e implementation for `aten.cat`,`aten.gather`, `aten.bmm`
Also contains the following changes:
- Remove derefineOp canonicalizer because it's not safe.
- Support for optional tensor and list tensors in reduceOpVariant. This
only works for some special detected and easy to handle cases. For list,
it covers the case list is got from a `ListConstruct`. For optional, it
covers the case optional is constructed from a `DerefineOp`.
- Remove the `inferReturnTypes` for `FromBuiltinTensorOp` because it's
not safe to deduce types from the input. For example, a built-in tensor
of i8 could be converted to si8 or ui8. It's better to let the user
specify the return type explicitly.
2021-09-22 19:15:01 -04:00
Sean Silva 1a0b953ea7 Eliminate almost all mentions of IREE.
A few remain in examples/docs that will be naturally be updated in due
time.

This regresses the list support and the general direction of more widely
supported control flow, lists/dicts/globals that we were going for with
the TorchScript path. The idea is that we are deferring that work to
make torch-mlir a very clean standalone thing. We will reboot it,
probably using some of the tools of iree_pydm to make it simpler, and in
a more natural place (such as an iree-torch repo that depends on IREE and
torch-mlir to build a working PyTorch frontend solution for IREE -- it
was really weird that npcomp depended on IREE).
2021-09-22 16:06:38 -07:00
Sean Silva f9c48d0b89 Bring up new RefBackend.
`tools/torchscript_e2e_test.sh` is all green.

This needs a few passes I put into torch-mlir/lib/RefBackend (not to be
confused with `npcomp/lib/RefBackend`, which will soon be deleted).

For the sake of review, since this brings together a lot of things, I
split this into its own commit. I temporarily commented out some "list"
stuff that we are going to remove as part of the torch-mlir refocus.
2021-09-22 14:20:22 -07:00
Sean Silva 28a7738189 [torch-mlir earthmoving (1/N)] C/C++ code movement.
This creates the `external/torch-mlir` directory as an
LLVM_EXTERNAL_PROJECTS-compatible project (analogous to
`iree-dialects`) and completes movement/rename of all pure MLIR C/C++
compiler code into there. The next step will be to move all the Python
code / code that links/includes PyTorch C++ code (which currently lives
in `frontends/pytorch`) into a subdirectory here.

I call this "earthmoving" because it is mostly mechanical changes and
renames. As a quick summary (we can change this down the road easily)
- C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch`
- CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet`
- preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_`
- CMake `NPCOMPFoo -> TorchMLIRFoo`

The goal of this is to create a standalone project creating a center of
mass for entry into the MLIR ecosystem from PyTorch, suitable in scope
for eventual inclusion/ownership in PyTorch. The idea is that
`external/torch-mlir` will some day be pulled out into its own
repository, and then npcomp will simply pull it in as a submodule.

Layering-wise, what lives in `torch-mlir` lowers code from PyTorch
(currently TorchScript, but TorchFX or pytorch/xla-style tracing are
possible extensions) down to what we have been calling the "Torch
backend contract" which is cleaned up IR (inlining, simplifcation,
conversion to value tensors, ...) entirely in the `torch` dialect. This
is the branching off point for further lowering, of which npcomp takes
one opinion (outside `torch-mlir` of course!), namely the
`TorchConversion` dialect/transforms which lower to IR suitable for IREE
and other linalg-on-tensors based lower-level compilers.

Summary of changes:
- move `{include,lib,test}/Dialect/Torch` into `torch-mlir`
- move relevant parts of CAPI into `torch-mlir`.
- leave a few things related to the `torch-mlir` Python build commented
  out, which should be resolved in a subsequent change.
2021-09-10 21:44:37 -07:00
Sean Silva a7252f9a06 Add basic support for lists.
This plumbs through a vertical slice of support for lists.

The main chunk of new code here is AnnotateABIPass which captures the
program signature at the Torch backend contract layer, right before we
start `TorchConversion`. The `TorchConversion` lowering process is lossy
w.r.t. types, so it's necessary to do this for all targets in general.
Like using `!iree.list` directly, we use IREE's ABI annotation
representation for this, although there is nothing very IREE-specific
about it (see
https://github.com/google/iree/blob/main/docs/developers/design_docs/function_abi.md)

We change `ListLiteralModule_basic` to use `!torch.int` because IREE
doesn't support f64 yet (and we don't yet have a way for users to say
that they want `!torch.float` to lower as f32).

Recommended review order:
- AnnotateABIPass and tests
- Arg marshaling in npcomp_backend.py and `iree.py`
- Updates to `list_programs.py` / `xfail_sets.py`
- Moving DeleteDeadIREEListsPass to Backend/Common, so that backends
  that don't support lists can use it. RefBackend uses that pass, for
  example.
2021-09-09 20:48:55 -07:00
Sean Silva 1dec561cfd Update llvm-project to 830c0b9023cd0cf91955900e0d96283e7a8c3711
- builder.getSymbolRefAttr is gone.
- OpAsmOpInterface's getAsmResultNames method needs explicit override
- a bunch of churn for builtin.func needing to be made explicit (and
  sometimes implicit?)
- operation printers no longer need to print the operation name
  themselves.
- snuck in beneficial trivial addition to TmpDeleteDeadIREEListsPass to
  test a particular upstream change e2e with my local patchset.
2021-09-03 14:16:38 -07:00
Stella Laurenzo 80ff744c56 Add a few missing deps exposed by stricter linking with BFD. 2021-08-22 11:56:48 -07:00
Sean Silva cab8d922ec Add TorchToIREE and factor out TorchConversion dialect.
This converts a basic list op (torch.prim.ListConstruct) to the IREE
dialect.

```
    def forward(self, x: float):
            return [x, x]
```

turns into:

```
builtin.func @forward(%arg0: !torch.float) -> !torch.list<!torch.float> {
  %0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.float, !torch.float) -> !torch.list<!torch.float>
  return %0 : !torch.list<!torch.float>
}
```

which turns into:

```
builtin.func @forward(%arg0: f64) -> !iree.list<f64> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %c2 = constant 2 : index
  %0 = iree.list.create %c2 : !iree.list<f64>
  iree.list.set %0[%c0], %arg0 : !iree.list<f64>, f64
  iree.list.set %0[%c1], %arg0 : !iree.list<f64>, f64
  return %0 : !iree.list<f64>
}
```

As part of doing this, I realized that it was time to formalize the IR
form that we reach right before running TorchTo{Linalg,Std,...}. We now
call it the "Torch backend contract". We then lower the "Torch backend
contract" to the "npcomp backend contract", which involves the new
TorchConversion (`torch_c`) dialect, which holds ops that need to
operate on both the npcomp backend types (e.g. builtin tensors, i1, IREE
list, etc.) and the `!torch` types.

This made more sense, as I realized that if I didn't factor out
`torch_c` then the Torch dialect would have a dependency on IREE
dialect (we previously didn't notice this was an issue because we only
depended on `builtin` types), which seemed wrong to me.

Recommended review order:
- TorchToIREE.cpp / `TorchToIREE/basic.mlir`
- Look at the new structure of createTorchScriptToNpcompBackendPipeline.
  It now lives in TorchConversion/Transforms/Passes.cpp and cleanly
  calls into `Torch::createTorchScriptToTorchBackendPipeline` for the
  frontend lowering to the Torch backend contract.
- Mechanical change extracting
  `torch_c.{to,from}_{i1,i64,f64,builtin_tensor,iree_list}` into a new
  TorchConversion dialect, and a few passes specific to the lowering
  from the Torch backend contract to the npcomp backend contract.
- Minor fixes to TorchToLinalg.cpp to use unconverted operands (now that
  we convert lists as part of operand materialization, we need to use
  the original operands). Also added test for AtenMaxPool2dOp and fixed
  m_TorchConstantIntList.
- TmpDeleteDeadIREELists pass. Temporary pass for deleting dead IREE lists that
  are created as part of operand materialization for conv/max pool/avg pool ops
  in TorchToLinalg.
2021-08-16 15:01:58 -07:00