Commit Graph

22 Commits (ead5c42a534b853b835aab837996d9ec7da7d957)

Author SHA1 Message Date
Gleb Kazantaev 708fa346a6
Fix Base Lazy Backend Type Conversion (#1412)
* Fix c10::prim::Constant conversion; Added CAPI for passes; Added passes to base lazy backend

* Update ivalue_importer to use ImportOptions; Added tests for non-value/value tensor types

* Added tests for scalar Constant import; Updated MB::importFunction to use ImportOptions

* Test updates

* Move back module variable name

* Remove RefineTypes from TorchMlirLoweringContext::Build()

* Rename pass; Remove passes from base lazy backend

* Rename pass to VerifyBackendContractPass

* Aligned cmd pass name; Fixed TorchConversion passes registration
2022-10-04 15:53:28 -07:00
AmosLewis 940959589b [MLIR][TORCH] Add Byte and Char Dtype support 2022-09-30 13:19:31 +05:30
Ashay Rane 0b46462528
Miscellaneous fixes for Windows builds (#1376)
* test: allow spaces in path to Python executable

On Windows, the path to the Python binary may contain spaces, so this
patch adds quotes around the path to the python executable.

Thanks to @sstamenova for suggesting the fix!

* python: remove header file that causes Windows build failures

Similar to https://reviews.llvm.org/D125284, we can safely remove this
header file without affecting the build on either Linux.  It is
necessary to remove this header file on Windows builds since otherwise
it causes build errors.

* python: drop `TORCH_API` from function defined in Torch-MLIR

`TORCH_API` should apply to functions that are either exported by
libtorch.so or ones that are imported from libtorch.so by its downstream
consumers (like Torch-MLIR).  Neither case applies to the
`importJitFunctionAsFuncOp()` function, since it is defined in
Torch-MLIR (and thus outside libtorch.so).  This patch fixes the problem
by dropping `TORCH_API` from that function's declaration.

* python: make output of class anotations deterministic

The `class-annotator-repr.py` test checks for class annotations in a
specific order, but prior to this patch, the order was
non-deterministic, since the code iterated on an _unordered_ map.

This patch makes the iteration order deterministic through two changes:
1. using a sorted map
2. using the class qualified name instead of the address of the class in
memory

* test: use Python3_EXECUTABLE as interpreter path for consistency

This ensures that tests use the Python3 version that was detected using
CMake, instead of whichever python version that happens to be in the
PATH variable when invoking the test.

* test: fix RUN string

The parenthesis syntax does not run on Windows (the shell interprets the
`(` character as part of the path).  Moreover, the ODR violation in the
comment no longer seems to apply.

* python: port parallel test framework to Windows

Since Windows does not support `fork` natively, Python's
`multiprocessing` module needs to use `spawn` on Windows.  However, to
use `spawn`, the multiprocessing module serializes (or pickles) the
worker function and its arguments.  Sadly, the multiprocessing module
(both the default one in Python and the one that is extended in PyTorch)
is unable to serialize lambda functions (see
https://stackoverflow.com/a/19985580) for detals.

Unfortunately, given how our tests are structured, we require that the
function under test is passed as an argument to another function, so we
cannot sidestep our use of lambda functions.

To resolve this problem, this patch makes use of the `multiprocess` and
`dill` Python modules, which together offers a multiprocessing mechanism
that can serialize lambda functions.  The multiprocess module also
offers a process pool, which simplifies the code for our parallel
testing framework.
2022-09-29 12:07:43 -05:00
Ashay Rane a9e1014fc7
python: add `CHECK-LABEL` statements to localize checks (#1363)
It seems as though an upstream change in PyTorch has caused the module
dump to include not just the module being tested, but also several
seemingly unrelated functions in the `torch._decom.decompositions`
namespace.  The presence of these new functions caused lit to match
variables against incorrect statements (i.e.  statements in the
unrelated functions instead of the module under test).

This patch inserts `CHECK-LABEL` statements in the failing tests so that
lit ignores these unrelated functions and only checks the statements at
or after the test module definition.
2022-09-13 15:44:13 -05:00
Tanyo Kwok 290d7755fb
importer: add initial support for loading Float16 tensors (#1169)
follow up #761:

    This patch updates the `torch_mlir::convertTensorToMlirElementsAttr()`
    method to enable the creation of tensors whose base type is Float16.
    This patch also adds a test to validate the IR generation, and it
    updates the test for importing tensors of various types.
2022-08-08 12:37:31 +08:00
Jae Hoon (Antonio) Kim 425362263b Clean up Autogen (#1112)
* Remove unnecessary sed in autogen

* Remove .pyc files frrom VCS
2022-07-30 09:40:02 -04:00
Jae Hoon (Antonio) Kim 1bde00c73d Fix LTC Decoupling (#815)
* Initial changes

* Fix up native functions

* Further fix decoupling

* Remove unnecessary ops

* Formatting and copyright banners:

* Add pytorch submodule
2022-07-30 09:40:02 -04:00
powderluv 31fd812acf
Add linux and macOS source builds in CI (#1070)
This enables building Pytorch from source in the CI.
The build should mostly hit the ccache.
Release builds will follow once we have some runtime on the CI.
2022-07-21 14:16:03 -07:00
Ashay Rane 72dd04cdb3
Revert "python: trim registration and loading of dialects and passes" (#1093)
This reverts commit ad283c1043, since it's
causing nightly build failures for all platforms.
2022-07-21 09:35:42 -07:00
Ashay Rane ad283c1043
python: trim registration and loading of dialects and passes (#1084)
In the interest of merging upstream LLVM quickly, a previous patch
(7f08169) updated the torch-mlir build to register all dialects and
passes through Python bindings.  This patch limits the dialects and
passes to only those that are used in torch-mlir.

Key to this change are the removal of
`MLIRPythonExtension.RegisterEverything` and the introduction of a new
Python module (`_mlir_libs/_site_initialize_0.py`), where we register
the dialects and passes used by torch-mlir.
2022-07-20 18:34:17 -07:00
Ashay Rane f947443f98
python: lower `prim::{Load,Store,Enter,Exit}` nodes to torch dialect (#983)
TorchScript nodes like `prim::Load` and `prim::Store` aren't supported
in torch-mlir because they can't be lowered to backends, but such nodes
can occur in the TorchScript IR.

This patch adds a rudimentary translation from such nodes to
corresponding ops in the Torch dialect.  Since we expected such nodes to
go away during lowering because of the SymbolDCE pass, this patch does
not add code to lower these ops beyond the Torch dialect.
2022-06-30 13:17:35 -07:00
Sean Silva 227dea7b2e Add support for ScalarType::QUInt8
I ran into this while poking around at
https://github.com/llvm/torch-mlir/issues/959
2022-06-29 15:33:28 -07:00
Ashay Rane bb52a460cb
mlir: bump llvm tag to 5380e3 (#856)
In addition to updating the llvm-project submodule, this patch also:

1. updates shape functions and tests so that `func` and `call`
   operations refer to the `func` dialect
2. avoid duplicate registration of dialects
2022-05-16 12:54:35 -07:00
Sean Silva ab5ad7af09 Add tracing suport to `torch_mlir.compile`.
This also has a fix for the adjustment of types of TupleConstruct
inputs, which I found when using this new functionality on a model.

Some scenarios in tracing create situations where the output of
TupleConstruct has a more refined type than the inputs.

This introduces a helper `adjustStaticInformationForValues` which
subsumes the `derefineValues` helper and the tensor static information
adjustment we were doing.
2022-05-03 09:08:40 -07:00
Ashay Rane 809f240f01
importer: add initial support for loading BFloat16 tensors (#761)
This patch updates the `torch_mlir::convertTensorToMlirElementsAttr()`
method to enable the creation of tensors whose base type is BFloat16.
This patch also adds a test to validate the IR generation, and it
updates the test for importing tensors of various types.
2022-04-29 09:01:49 -07:00
Sean Silva 73cc2ac152 Ensure that imported function input type and block arg types are consistent.
I wasn't able to find exactly what frontend situation created it, but
`torch.jit.trace` will sometimes create functions where the
`jit::Block`'s param node has refined tensor types. So we need to adjust
the function's formal param types to those refined types.
2022-04-27 08:01:23 -07:00
Sean Silva 140babd952 Add minimal support for Union types.
A recent PyTorch commit made ConstantPad2d call a helper function with a
`Union[int, float]` type annotated. This commit adds minimal support for
representing and dealing with that.
https://github.com/pytorch/pytorch/pull/73287

Changes:
- Adding support for `!torch.union<T1, T2, T3>`/`Torch::UnionType`,
  along with the importer and CAPI code.
- Add support in isValidSubtype for union types.
- Adding a canonicalizer for `torch.derefine` to help simplify some code
  that derefines to a UnionType (this also fixes #664).

There is still more work to do for really supporting UnionType well,
such as canonicalizing UnionType's so that they can be compared with
pointer equality.
2022-03-29 17:45:48 -07:00
Sean Silva 84a9693006 Elide `!torch.` prefix in nested dialect types.
This leads to much more succinct types in many cases:

```
!torch.list<!torch.int>
!torch.list<int>

!torch.tuple<!torch.list<!torch.int>, !torch.list<!torch.int>>
!torch.tuple<list<int>, list<int>>

!torch.optional<!torch.list<!torch.int>>
!torch.optional<list<int>>

!torch.list<list<list<tensor>>>
!torch.list<!torch.list<!torch.list<!torch.tensor>>>
```

I would like to take this further and allow omitting the `!torch.`
prefix in all cases, but that's harder -- for example, we currently use
`FuncOp` for functions, and so I don't think we can customize the
printing there. It seems like it will be a longer road to getting that
level of customization.
2022-03-15 17:24:08 -07:00
Sean Silva a5fe0cf063 Introduce new shape library design.
See the documentation in `docs/shape_lib.md` and
`docs/adding_a_shape_function.md` for an overview of the system.

This completely overhauls how we represent shape functions. In
particular, RefineTypes does not infer shapes anymore (only dtypes).
Shape functions are now written in (TorchScript'able) Python.

Recommended review order:

1. Read `docs/shape_lib.md` and `docs/adding_a_shape_function.md`.
1. Code and tests for ReifyShapeCalculations, DropShapeCalculations.
1. Code and tests for SimplifyShapeCalculations.
1. shape_lib_gen.py
1. Code and tests for new RefineTypes pass.
1. Random folders/canonicalizers in TorchOps.cpp and associated test in
   `canonicalize.mlir`.
1. New ReadOnly trait inferred from the registry.
1. Any miscellaneous remaining stuff.

Example `-print-ir-after-all` for ElementwiseUnaryModule:
[IR lowering dump](https://gist.github.com/silvasean/e4dc8cbc8d00aac7819602e3cbd8e212).

Example `-print-ir-after-all` for ElementwiseBinaryModule:
[IR lowering dump](https://gist.github.com/silvasean/daf6860ecced732af3568af6b1899113).
2022-03-15 12:41:58 -07:00
Henry Tu 73ac9a7e2e Added support for importing node prim::Constant with list type
Prior to this commit, importing a `prim::Constant` node with list type would result in an error since it was not supported. `ivalue_importer::importIValue` was modified to return the MlirValue corresponding to the root so its parent operation could be extracted.
2022-02-11 20:54:06 -05:00
Sean Silva 19e9fc4ef1 Bring some more order to the e2e error reporting situation.
- Move `run_pipeline_with_repro_report` to a more common place, and use it
  consistently
- Attach a `torch.debug_module_name` to the enclosing `builtin.module`
  op to allow for self-contained error reporting (not needing to pass
  the names around.
- Remove redundant error reporting in linalg_on_tensors_backend.py and
  tosa_backend.py (their respective backend abstract base classes now
  take care of the error reports themselves)
- Save off original value of sys.stderr, rather than always resetting to
  `sys.__stderr__`. This is just more hygienic, and allows nesting if
  desired.
2021-10-08 13:00:12 -07:00
Sean Silva 4fad753073 Move external/torch-mlir to the root of the repo. 2021-09-27 17:11:08 -07:00