Commit Graph

568 Commits (18035021f067379e000f1b7fad18ad02be6d116c)

Author SHA1 Message Date
武家伟 99fb4c8637
Add folder for ToF64Op and FromF64Op (#1257) 2022-08-22 09:49:39 +08:00
Sean Silva 1a7fc3915c [docs] Add architecture doc.
This attempts to get out of my head most of the critical layering and
project structure decisions for Torch-MLIR.
2022-08-18 13:29:49 -07:00
Sean Silva 283e0f141a Add a concept of "backend legal ops".
This is a first step towards formalizing the set of ops in our backend
contract. The goal is to eventually formalize `torch` dialect ops into 3
categories:
1. Legal in backend contract
2. Illegal in backend contract
3. Conditionally legal in backend contract

The "conditionally legal" set are the ops that we can optionally
decompose for backends.

This patch adds relevant pass options for this throughout the compiler,
in preparation for a new set of traits which will formalize this
classification.
2022-08-18 11:46:50 -07:00
Sean Silva 57681f7947 Iteratively run the main simplification pipeline.
This introduces a new pass LowerToBackendContract (better name very
welcome) which performs the bulk of the simplifications that we do,
such as
- shape refinement
- dtype refinement
- maximizing value semantics
- inlining global slots
- decomposing complex ops

The key difference from before is that it iterates the set of
transformations, which can help to break a number of "catch-22" issues
where one simplification depends on another, the latest example being
here:
https://github.com/llvm/torch-mlir/issues/1131

This also exposed that RefineTypes was sometimes crashing/asserting for
certain inputs. This commit hardens it a bit.
2022-08-17 14:54:33 -07:00
Yan Xu 9be8997536
Revert "add native_dropout and related ops pattern (#1211)" (#1230)
This reverts commit c935795086.
2022-08-17 13:48:10 +08:00
武家伟 3b3cb99ef8
Generalize canonicalization pattern for more aten.sub/div/mul/add op (#1209)
Generalize canonicalization pattern for more sub/div/mul/add op, but for AtenDivTensorModeOp in 'trunc' rounding mode, we try to fold it.
2022-08-16 13:24:08 +08:00
Yan Xu c935795086
add native_dropout and related ops pattern (#1211) 2022-08-15 09:28:47 +08:00
Prashant Kumar b1a506624c Add decomposition of `aten.masked.tensor` op.
`aten.masked.tensor` op has been decomposed to `aten.masked.scalar` op.
2022-08-11 07:48:04 +05:30
Vidush Singhal dd2da5a038
E2E support for AtenRemainderScalarOp (#1200) 2022-08-10 20:02:06 -04:00
gpetters94 79b9cf9468
Add lowering for aten.to.device (#1107) 2022-08-10 19:24:02 -04:00
Ramana Radhakrishnan 738f4fe96a
Rename TorchToStd pass as TorchToArith (#1163)
All the converters in this pass appear to create ops from the
arith dialect. Hence the full rename.

Fix GH Issue #409.
2022-08-10 20:12:51 +01:00
powderluv e55fc4deb5
Revert "E2E support for AtenRemainderScalarOp (#1119)" (#1190)
This reverts commit 34e207eeb5.
2022-08-08 22:59:57 -07:00
Ashay Rane bb47c166a0
llvm: update tag to 061e0189 (#1180)
Summary of changes:
 - Switch to C++17 (similar to https://reviews.llvm.org/D131348)
 - Update MHLO to build with LLVM commit hash 061e0189
 - Replace deprecated `hasValue()` and `getValue()` with `has_value()`
   and `value()` respectively (https://reviews.llvm.org/D131349)
 - Use `TypedAttr` (https://reviews.llvm.org/D130092)
 - Use updated assembly format of `mhlo.compare` op (commit
   d03ef01e70fbf9afd0fa1976fbb7ed31838929b3 in MHLO repo)
2022-08-08 20:17:35 -07:00
Sean Silva 504de5e701 Rework how global slot initializers work.
Rather than a per-global-slot initializer region, we now have one for
the whole module. For example, it might look like this:

```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
  %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
  %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
  torch.initialize.global_slots [
    @tensor(%0 : !torch.tensor)
    @list(%1 : !torch.list<tensor>)
  ]
}
```

This new structure allows GlobalizeObjectGraph to create the initializer in a
much simpler way, avoiding the need to reason about whether different slots
alias each other. Reasoning about whether slots alias each other now is the
responsibility of InlineGlobalSlots, which has to do a much more complicated
analysis, implemented using MLIR's dataflow analysis framework.

Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to GlobalizeObjectGraph pass.
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
  - Moving torchMlirAdjustStaticInformation for sharing with C++ code.
  - EraseModuleInitializer pass

To make this a bit nicer, it would be good to have a `torch.module` op
with an initializer region attached. That would be more invasive though.

This change has highlighted certain aspects of our project layering
which are worth calling out. None of our backends can handle global
slots, so we enforce that there are no global slots before backend
lowering. At an earlier stage in the project, we had aspirations of
transparently handling mutable global state and such, but for reasons
described below, that is no longer a goal. So really global slots should
be seen as a progressive lowering step as part of inlining all the
IValue's in the original program (GlobalizeObjectGraph is also one such
step).

Over time, with insights from work like IREE-JAX, it has become clear
that there isn't a reliable programming model we can compile for users
where we just transparently handle mutable global state (and some other
things, like lists and dictionaries). There is a need for an "outer
program" that orchestrates more restricted subroutines of the kind we
can handle in our compile flow here. The benefit of that is that it
decouples considerations like shapes, dtypes, etc. from the program
constructs used in the outer program. As long as the outer program can
efficiently invoke (pipelining/async/etc.) high-performance
data-parallel numerical subroutines of the kind we compile in our flow
here, then there is a complete programming model. This is also
consistent with the direction of upstream PyTorch which is becoming more
tracing-based (which inherently loses a lot of program structure, which
then has to be applied back with an "outer program" orchestrating the
traced subroutines).
2022-08-08 18:12:06 -07:00
Vidush Singhal 34e207eeb5
E2E support for AtenRemainderScalarOp (#1119)
* E2E support for AtenRemainderScalarOp
2022-08-08 20:02:52 -04:00
Vidush Singhal b70548edff
Add decomposition and E2E support for Aten_EmbeddingBag (#1137)
* Add decomposition and E2E support for Aten_EmbeddingBag
2022-08-08 18:56:49 -04:00
Vivek Khandelwal c129a6de93 [MLIR][TORCH] Add support for dim=None to Aten[Var|Std]DimOp
PyTorch recently added support for `dim=None` in the `torch.var`
(5ca9b2b6fa)
and `torch.std`op (eb0e30e0bc).
This commit adds the corresponding support in torch-mlir.

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
2022-08-05 20:28:56 +05:30
Ramiro Leal-Cavazos a7af1fd873
Add support for `dim=None` to `AtenMeanDimOp` (#1129)
PyTorch recently added support for `dim=None` in the `torch.mean`
op (2bfae07a79). This
commit adds the corresponding support in torch-mlir.
2022-08-02 16:08:06 +00:00
Quinn Dawkins 38d8498b21
add e2e support for aten.atan2 (#1117)
- Includes math-to-libm pass in refbackend for math::atan2 support
2022-08-02 11:39:41 -04:00
Vidush Singhal ed13ebfd8d
E2E support for AtenEmbeddingBagPaddingIdxOp SUM Mode (#1066) 2022-08-01 16:44:11 -04:00
Alec 554570f3ab Implemented a decomposition of aten::narrow 2022-08-01 18:32:14 +05:30
Henry Tu cec74b8d37 Blacklist _convolution op (#1048)
* Blacklist _convolution op in LTC

* Removed duplicate Torch_AtenSelectScatterOp instance from autogen .td

* Removed duplicate Torch_AtenSliceScatterOp instance from autogen .td
2022-07-30 09:40:02 -04:00
Henry Tu f5acad8512 Prune xfail e2e LTC tests & fix bugs from functionalization pass (#1044)
- Pruned number of xfailed e2e LTC tests from 305 to 134
  - Reviewed every failure to ensure the error genuinely warrants an xfail
- Fixed bug where non-tensor outputs of LTC computation had `.to('cpu')` called, which caused a failure and inflated the xfail count
- Fixed bug with `HBC_basic` test where a constant tensor was created in its constructor without being declared as a buffer, which prevented the device from being updated when the parent `torch.nn.Module` got moved to the `lazy` device
  - Note that this test is still xfail'd due to some unsupported ops. Left a comment about some potential issues that may arise if it gets reenabled in the future
- Updated autogen `GeneratedTorchOps.td` to reflect the latest set of supported ops
- Renamed `aten.zero.functionalization` to `aten.zero` to reflect upstream PyTorch changes
2022-07-30 09:40:02 -04:00
Jae Hoon (Antonio) Kim fb21c9e6cb Integrate Functionalization Pass (#998)
* Fix autogen build dir issue

* Got functionalization pass to compile

* Add slice/diagonal backwards functionalization

* Fix codegen invocation in CMakeLists.txt

* Add functionalization view ops

* Fix logsumexp out functionalization

* Fix ComputationPtr

* Blacklist new_empty op

* Add op comparison

* Remove unnecessary ops

Co-authored-by: Henry Tu <henry.tu@cerebras.net>
2022-07-30 09:40:02 -04:00
Henry Tu 0c35e607b3 Add static shape for scalar tensors (#833)
* Assume zero rank tensors are scalar

* Run RefineTypes pass on JIT Graph

* Rollback assumption that zero rank tensors are scalar

* Set numSizes to -1 for non-ranked tensors

* Rename RefineTypes to RefineTupleTypes
2022-07-30 09:40:02 -04:00
PhaneeshB 8b5631d4c5 [MLIR][TORCH] Add decomposition for aten.std.dim Op
Signed-Off By: Phaneesh Barwaria <phaneesh@nod-labs.com>
2022-07-29 23:52:54 +05:30
Vivek Khandelwal d386b8f9e5 [MLIR][TORCH] Add decomposition for aten.var.correction op
This commit adds the decomposition for `aten.var.correction` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com
2022-07-29 11:08:57 +05:30
Quinn Dawkins 11a8901078
[MLIR][TORCH] Add support for multiple indexing tensors for aten.index.Tensor (#1097)
- Includes a canonicalizer for `aten.add.t`needed for successfully lowering the shape function
 - Only offers support for statically sized index tensors when there is more than one
 - Dynamic shape support remains for single indexing tensors
2022-07-28 19:00:02 -04:00
Quinn Dawkins 3c9addf19c Add e2e support for aten.expm1 2022-07-27 12:31:35 +05:30
Kevin Kiningham e8f327cc00 Add lowering to linalg for softplus and log1p
Follows existing conventions for unary operators.
2022-07-25 21:25:57 +05:30
Ramiro Leal-Cavazos f271e6a88c
Add verifiers for ToBuiltinTensorOp and FromBuiltinTensorOp (#1089)
This commit adds verifiers to the ops `ToBuiltinTensorOp` and
`FromBuiltinTensorOp` that make sure that the input and output have
the same shape and data type.
2022-07-21 21:41:45 +00:00
Ashay Rane 72dd04cdb3
Revert "python: trim registration and loading of dialects and passes" (#1093)
This reverts commit ad283c1043, since it's
causing nightly build failures for all platforms.
2022-07-21 09:35:42 -07:00
Ashay Rane ad283c1043
python: trim registration and loading of dialects and passes (#1084)
In the interest of merging upstream LLVM quickly, a previous patch
(7f08169) updated the torch-mlir build to register all dialects and
passes through Python bindings.  This patch limits the dialects and
passes to only those that are used in torch-mlir.

Key to this change are the removal of
`MLIRPythonExtension.RegisterEverything` and the introduction of a new
Python module (`_mlir_libs/_site_initialize_0.py`), where we register
the dialects and passes used by torch-mlir.
2022-07-20 18:34:17 -07:00
Ziheng Jiang c61c99e887
[MHLO] Init MHLO integration. (#1083)
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
2022-07-20 16:18:16 -07:00
Kevin Kiningham 21f905afbe
Emit underscore version of aten.sqrt (#1072) 2022-07-18 23:57:47 -07:00
Ashay Rane 7f08169380
bump llvm tag to 3580daa (#1078)
This patch makes some rudimentary changes to torch-mlir's use of MLIR
Python bindings to work with the most recent LLVM code.  We can perhaps
do better by being more selective in what we link against, instead of
using `MLIRPythonExtension.RegisterEverything`.
2022-07-18 16:49:03 -07:00
Vivek Khandelwal 4c25878e64 [MLIR][TORCH] Add canonicalization pattern for prim.ListUnpack op
This commit adds the canonicalization pattern for the `prim.ListUnpack` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-07-18 13:51:25 +05:30
Sean Silva 795479a88d Remove HasValueSemantics from `is` ops. 2022-07-15 17:03:17 -07:00
Vivek Khandelwal 3589134d31 [MLIR][TORCH] Add decomposition for aten.var.dim op
This commit adds the decomposition for `aten.var.dim` op.
This commit also make changes in the decomposition for `aten.var` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-07-15 09:53:42 +05:30
Ashay Rane 29bc48aedb
torch: add pass to catch non-value tensors (#1052)
This patch adds a new pass `torch-verify-conversion-to-value-semantics`,
which looks for non-value semantics tensors to catch such tensors early
during compilation.

This pass requires `torch-refine-public-return` pass to ensure that
return operations are updated to use value tensors, followed by the
canonicalize pass to remove any dead ops that may use or produce
non-value tensors.
2022-07-13 17:11:15 -07:00
Ramiro Leal-Cavazos 11148e60d6
Undo shape lib changes + update function signature of sum + zero (#1035)
This commit does three things:
  1. Reverts some of the shape lib changes merged in
  https://github.com/llvm/torch-mlir/pull/844
  2. Updates the signature of `aten.sum_dim_IntList` that was recently
  updated in
  23bdb570cf
  3. Replaces `aten.zero.functional` with `aten.zero`, updated in 960758b0b7
2022-07-11 10:56:12 -07:00
Prateek Gupta 2d75654b2c [TORCH][MLIR] Add lowering of `aten.slice_scatter` and
`aten.select_scatter` op.

This commit adds:
1.  Lowering of `aten.slice_scatter` op into `tensor.insert_slice`
op.
2. Decomposes the `aten.select_scatter` op into `aten.slice_scater`
op.

Signed-Off-By: Prateek Gupta <gprateek93@gmail.com>
2022-07-11 14:07:21 +05:30
George Petterson a08ff0d7f2 Add lowering for _convolution 2022-07-11 11:03:03 +05:30
Quinn Dawkins f0c3b5a7ed
Add E2E support for aten.len.str (#969) 2022-07-07 10:41:55 -07:00
Ashay Rane f947443f98
python: lower `prim::{Load,Store,Enter,Exit}` nodes to torch dialect (#983)
TorchScript nodes like `prim::Load` and `prim::Store` aren't supported
in torch-mlir because they can't be lowered to backends, but such nodes
can occur in the TorchScript IR.

This patch adds a rudimentary translation from such nodes to
corresponding ops in the Torch dialect.  Since we expected such nodes to
go away during lowering because of the SymbolDCE pass, this patch does
not add code to lower these ops beyond the Torch dialect.
2022-06-30 13:17:35 -07:00
Sean Silva 227dea7b2e Add support for ScalarType::QUInt8
I ran into this while poking around at
https://github.com/llvm/torch-mlir/issues/959
2022-06-29 15:33:28 -07:00
JakopinA 5888c4f7dc Added E2E support for torch::aten.__contains__int_list 2022-06-27 19:30:00 +05:30
Tanyo Kwok 143a7bcb76
[MLIR][TORCH] Add folder for torch_c.from_i64 & torch_c.to_i64 (#933)
* [MLIR][TORCH] Add folder for torch_c.from_i64 & torch_c.to_i64

* add unit tests for each individual fold

* fix failure of NumelZeroRankModule & TestMultipleTensorAndPrimitiveTypesReturn
2022-06-24 09:34:39 +08:00
erman-gurses 5cff40c88a Add canonicalization for aten.add.tensor op 2022-06-23 17:24:59 -04:00
Vivek Khandelwal 77ab31641f [MLIR][TORCH] Add decomposition of aten.numpy_T op
This commit adds the decomposition of `aten.numpy_T` op into
`aten.t` or `aten.permute` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-06-16 00:01:22 +05:30
Albert Sandru 708a51ae2e Add E2E support for aten.is_floating_point 2022-06-15 11:54:00 -05:00
Bob Adolf b90837ee24
Temporarily revert support for custom op extensions. (#944)
The MacOS builders are having linking trouble with the extension library.
Until it's fixed, all support for op extensions is disabled. It should be
easy to restore once the issue is resolved.
2022-06-14 18:24:40 -07:00
Vivek Khandelwal 33fa8e7761 [MLIR][TORCH] Add decomposition of aten.floor_divide op
This commit adds the decomposition of `aten.floor_divide` op into
`aten.div.Tensor_mode` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-06-14 08:56:25 +05:30
Bob Adolf 0a7ba62438
Allow torch-mlir to support PyTorch extensions. (#895)
PyTorch allows new operators to be registered dynamically in modules.
Torch-mlir already makes it fairly straightforward to add support for
new operators, and this commit just extends that support to allow new
PyTorch ops to come from a external module.

This does *not* allow ops to be dynamically loaded into torch-mlir.
Torch-mlir must still be compiled with support built-in.

Add a `_torch_mlir_custom_op_example` subpackage to `torch_mlir` which
registers an demonstration op. It will not be imported by default when
importing torch_mlir. It's strictly for testing and documentation.

Adds an end-to-end test for the `torch_mlir_custom_op_example::identity` op.

With all these changes, we should now be actively testing PyTorch extension
support with all future patches.
2022-06-13 14:51:30 -07:00
Henry Tu c1da9edcf0
Generate underscore variant of functional ops (#915)
* Generate underscore variant of functional ops

* Do not apply `IsTrailingUnderscoreInplaceVariant` trait to underscore variant of functional op
2022-06-08 14:27:36 -04:00
Vivek Khandelwal b95b3d844d [MLIR][TORCH] Add E2E support for aten.div.Tensor_mode op
This commit adds lowering of `aten.div.Tensor_mode` op.
This commit also fixes formatting for the test file elementwise.py.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-06-07 22:26:44 +05:30
Vivek Khandelwal a11ef674a7 [MLIR][TORCH] Add E2E support for aten.baddbmm op
This commit decomposes `aten.baddbmm` op into `aten.bmm`,
`aten.mul.Scalar`, and `aten.add.Tensor` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-06-07 22:26:28 +05:30
Jae Hoon (Antonio) Kim fe784fd900
Add Support for aten::scatter_add (#906) 2022-06-06 15:02:45 -04:00
Jae Hoon (Antonio) Kim 8a1839a17e
Add support for aten::arange.start_out (#905) 2022-06-06 15:02:27 -04:00
Vivek Khandelwal 2718b4d838 [MLIR][TORCH] Add E2E support for aten.clamp_[min|max] op
This commit decomposes `aten.clamp_min` and `aten.clamp_max` op
into `aten.clamp` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-06-06 11:52:29 +05:30
Vidush Singhal fc419b1e7d
Add E2E support for AtenLogicalOrOp. (#883) 2022-06-03 16:21:03 -07:00
Henry Tu abf5c94a1b
Replace valsem.aten.zero with aten.zero.functional (#893) 2022-06-03 16:27:31 -04:00
Henry Tu 650f5a5008
Added support for native_dropout_backward (#892) 2022-06-03 14:08:51 -04:00
Henry Tu b7082a8d4e
Added support for native_dropout (#891) 2022-06-03 14:05:57 -04:00
Henry Tu a635fd2287
Added support for native_batch_norm_backward (#890) 2022-06-03 13:49:02 -04:00
Henry Tu bfe8ff4b42
Added support for embedding_dense_backward (#889) 2022-06-03 13:33:43 -04:00
Henry Tu a29903dfc8
Added support for native_layer_norm_backward (#888) 2022-06-03 13:15:23 -04:00
Vidush Singhal 0a913bc904
Add E2E support for AtenAllBoolOp (#874) 2022-06-01 18:20:25 -07:00
Vivek Khandelwal 6f548fc3ad [MLIR][TORCH] Add decomposition of aten.adaptive_avg_pool2d op
This commit adds the decomposition of `aten.adaptive_avg_pool2d` op into
`aten.avg_pool2d` op. The current decomposition only supports cases where
input size is equal to the output size.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-05-27 07:56:37 +05:30
Ramiro Leal-Cavazos b76c8c82dc
Emit `aten.unsqueeze` with mutating variants (#873)
The op `aten.unsqueeze` has a mutating variant. This commit adds
support for that variant.
2022-05-26 19:19:38 -05:00
Vivek Khandelwal 56e77d4213 [MLIR][TORCH] Add E2E support for aten.Bool.[float|int] op
This commit adds lowering of `aten.Bool.float` and `aten.Bool.int` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-05-24 21:18:34 +05:30
Vivek Khandelwal 014a6d16c7 [MLIR][TORCH] Add E2E support for aten.any.bool op
This commit adds lowering of `aten.any.bool` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-05-24 17:24:28 +05:30
Vivek Khandelwal bc9b2156e3 [MLIR][TORCH] Add E2E support for aten.sqrt.int op
This commit adds lowering of `aten.sqrt.int` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-05-24 16:50:39 +05:30
Ashay Rane f18b2be911
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function.  It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.

There exist several opportunities to make this lowering optimal and
robust.  For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf.  For L1 norms, we don't need to raise
each element to the power 1.0.  Similarly, L2 norms could benefit from
strength reduction.  Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-19 15:48:15 -07:00
Vivek Khandelwal c69a1e5688 [MLIR][TORCH] Add E2E support for ScalarImplicit, Int.Scalar op
This commit adds lowering of `aten.ScalarImplicit` and `aten.Int.Scalar` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-05-10 22:40:49 +05:30
Prashant Kumar 12b3af70d3 [TORCH] Add folding of aten.detach op.
`aten.detach` op is folded and returns the first operand since it's an
identity function(kind of identity just remove the has_grad attribute).
2022-05-10 21:54:45 +05:30
Yi Zhang 28be6511d2 Fix type promotion code for scalar only operations
Fix the type promotion code for scalar only operation to return
TorchType which is the type tracked in ValueKnowledge.scalarType.

- Fix `getPromotedResultScalarType` to return Torch type.
- Add `getBuiltInTypeForTorchScalar` helper to convert scalar type
to builtin type before passing to the next level type promotion
helper `updateResultTypeState`.
- Add `setScalarType` helper to make setting ValueKnowledge.scalarType
  easier.
2022-05-07 10:37:21 -04:00
Vivek Khandelwal b20679e1b8 [MLIR][TORCH] Modify aten::dropout op description
Signed-Off By: Vivek Khandelwal vivek@nod-labs.com
2022-05-06 11:15:52 +05:30
Vivek Khandelwal 96fabc0036 [MLIR][TORCH] E2E support for [ge|ceil].float, [ge|ne|gt].float_int op
This commit adds lowering of `aten.ge.float`, `aten.ge.float_int`,
`aten.ne.float_int`, `aten.gt.float_int` and `aten.ceil.float` op.
This commit also fixes formatting for the file scalar.py and scalar_comparison.py.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-05-05 21:48:35 +05:30
Kristof Denolf e682b1d0f3 changed name option to decompose-complex-ops 2022-05-05 00:38:51 -07:00
Kristof Denolf 5243638e33 add no decompose option 2022-05-05 00:38:51 -07:00
Yi Zhang 9f7264a7a4 Add support for scalar type propagation
The main changes are:
- Added `ValueKnowledge.scalarType` to track scalar type information.
- Added `ValueKnowledge.kind` to indicate the value kind.
- Modified the meet and join helper functions. The ValueKnowledge has
slightly more complicated state now so the meet and join function need
to look at the `kind` field in addition to just the type field.
2022-05-04 16:57:56 -04:00
Sean Silva ab5ad7af09 Add tracing suport to `torch_mlir.compile`.
This also has a fix for the adjustment of types of TupleConstruct
inputs, which I found when using this new functionality on a model.

Some scenarios in tracing create situations where the output of
TupleConstruct has a more refined type than the inputs.

This introduces a helper `adjustStaticInformationForValues` which
subsumes the `derefineValues` helper and the tensor static information
adjustment we were doing.
2022-05-03 09:08:40 -07:00
Vivek Khandelwal c0634bc996 [MLIR][TORCH] Add E2E support for aten.to.dtype_layout op
This commit decomposes `aten.to.dtype_layout` op into `aten.to.dtype` op.
This commit also fixes the formatting for the file type_conversion.py.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-05-03 12:48:58 +05:30
gpetters94 c4dcdd1e34
Add aten.flip (#817) 2022-05-02 16:01:15 -04:00
Vivek Khandelwal 4b11284440 [MLIR][TORCH] Add E2E support for aten.avg_pool2d op
This commit adds lowering of `aten.avg_pool2d` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-05-02 12:31:44 +05:30
Prateek Gupta 81ee5bb58c [TORCH][MLIR] Fix ConstantPad2dStaticModule test.
This commit fixes the `ConstantPad2dStaticModule` test case by adding
the lowering of `aten.pad` operation. Previously the test case
mapped to `aten.constant_pad_nd` operation.
The `aten.pad` now decomposes into `aten.constant_pad_nd` operation.

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>
2022-04-29 21:57:01 +05:30
Sean Silva 44c7b181d3 Revert "[MLIR][TORCH] Add E2E support for aten.ge.float op"
This reverts commit 564734b2d7.
2022-04-28 07:49:58 -07:00
Sean Silva eff144c0b7 Revert "[MLIR][TORCH] Add E2E support for aten.ge.float_int op"
This reverts commit 1f102cc400.
2022-04-28 07:49:58 -07:00
Sean Silva 7669ee4e4a Revert "[MLIR][TORCH] Add E2E support for aten.ne.float_int op"
This reverts commit 51dd462592.
2022-04-28 07:49:58 -07:00
Sean Silva 5ef9f501fa Revert "[MLIR][TORCH] Add E2E support for aten.ceil.float op"
This reverts commit 78f5747568.
2022-04-28 07:49:58 -07:00
Vivek Khandelwal e57e1968bc [MLIR][TORCH] Add E2E support for aten.index_put.hacked_twin op
This commit decomposes `aten.index_put.hacked_twin` op into
`valsem.aten.index_put_impl` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-28 13:41:47 +05:30
Vivek Khandelwal 78f5747568 [MLIR][TORCH] Add E2E support for aten.ceil.float op
This commit adds lowering of `aten.ceil.float` op.
This commit also fixes formatting for the file scalar.py.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-28 11:49:35 +05:30
Yi Zhang 86eb493a44 Change to AnyTorch* except for Torch_X ones 2022-04-27 14:18:52 -04:00
Vivek Khandelwal 51dd462592 [MLIR][TORCH] Add E2E support for aten.ne.float_int op
This commit adds lowering of `aten.ne.float_int` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-27 21:16:48 +05:30
Vivek Khandelwal 1f102cc400 [MLIR][TORCH] Add E2E support for aten.ge.float_int op
This commit adds lowering of `aten.ge.float_int` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-27 21:16:48 +05:30
Vivek Khandelwal 564734b2d7 [MLIR][TORCH] Add E2E support for aten.ge.float op
This commit adds lowering of `aten.ge.float` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-27 21:16:48 +05:30
Vivek Khandelwal f5b6c4b601 [MLIR][TORCH] Add E2E support for aten.div.float op
This commit adds lowering of `aten.div.float` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-27 21:16:48 +05:30
Ashay Rane 9208bf0eb6
llvm: bump tag to e1318078 (#781)
The updated LLVM code includes a patch to create bfloat16 array
attributes, thus enabling a different patch to torch-mlir to flesh out
support for the bfloat16 type.
2022-04-26 12:27:51 -07:00
Prashant Kumar 5cdef0213d [LINALG] Bug fix i64 vs i32 type comparison.
Comparing index type instead of integer types solves the problem.
2022-04-22 08:09:58 +05:30
Vivek Khandelwal 769f3a8870 [MLIR][TORCH] Add E2E support for max_pool2d_with_indices op
This commit adds lowering of `max_pool2d_with_indices` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-18 21:05:19 +05:30
Vivek Khandelwal 1bccb4fc8a [MLIR][TORCH] Add E2E support for aten::max_pool2d_with_indices_backward op
This commit adds lowering of `aten::max_pool2d_with_indices_backward` op.

This commit also fixes formatting issues in basic.py.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-04-14 21:46:47 +05:30
gpetters94 9ec0683e92
Add 2D case for convolution (#693) 2022-04-08 00:47:57 -04:00
gpetters94 fa0b24a73c
Rename optional list types (#643) 2022-04-07 18:15:51 -04:00
Prashant Kumar 1d5b5a89e8 [LINALG] Add torch.layout information
torch.layout information has been added.
2022-04-07 20:47:49 +05:30
Ramiro Leal-Cavazos 51d4d55f8a
Add support for multi-dim input to `index_put_impl` (#722)
This commit adds support for multi-dimensional tensors as input to the
`_index_put_impl_` op. The support was to some degree already there,
since `ScatterOp` already supports multi-dimensional tensors. This
commit also adds a bit more error checking to `index_put` and
refactors the code for creating `ScatterOp`s to mimic the way one
would make a `Linalg::GenericOp`.
2022-03-31 09:27:21 -07:00
Gaurav Shukla 969785d1b6 [LINALG] Add E2E support for `aten.where.[Scalar|ScalarSelf|ScalarOther]` ops
This commit decomposes different variants of `aten.where.*` op into
`aten.where.Self` op. It covers `aten.where.Scalar`,
`aten.where.ScalarSelf` and `aten.where.ScalarOther` ops.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-03-30 20:36:48 +05:30
Vivek Khandelwal 2597c481f6 [MLIR][TORCH] Add E2E support for aten.new_empty op
This commit decomposes `aten.new_empty` op into `aten.empty.memory_format` op.

This commit also made a dtype fix to the constant tensor allocation like ops.
Earlier the dtype for the result was inferred from the result type; now, it's
being evaluated as per the original definition of the op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-03-30 13:21:01 +05:30
Sean Silva 140babd952 Add minimal support for Union types.
A recent PyTorch commit made ConstantPad2d call a helper function with a
`Union[int, float]` type annotated. This commit adds minimal support for
representing and dealing with that.
https://github.com/pytorch/pytorch/pull/73287

Changes:
- Adding support for `!torch.union<T1, T2, T3>`/`Torch::UnionType`,
  along with the importer and CAPI code.
- Add support in isValidSubtype for union types.
- Adding a canonicalizer for `torch.derefine` to help simplify some code
  that derefines to a UnionType (this also fixes #664).

There is still more work to do for really supporting UnionType well,
such as canonicalizing UnionType's so that they can be compared with
pointer equality.
2022-03-29 17:45:48 -07:00
Liam Fitzpatrick f2269ced80
Improve list index normalization SimplifyShapeCalculations. (#710)
The reified code to compute the shape of torch.aten.constant_pad_nd
uses negative indices when setting list elements. This was not
converted to a positive offset in one place in SimplifyShapeCalculations
which prevented computation of the static shape.
2022-03-29 22:21:47 +02:00
Maksim Levental 25ba51b2af
This commit decomposes aten._reshape_alias op into aten.view op. (#690) 2022-03-28 23:54:28 -05:00
Gaurav Shukla 02b6d04eb4 [LINALG] Add E2E support for `aten.zero_` op
This commit adds decomposition of `aten.zero_` op.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-03-25 12:46:50 +05:30
Gaurav Shukla 7c3ba25238 [LINALG] Add decomposition of `aten.dropout` op
- This commit adds decomposition of `aten.dropout` op. It also covers the
  training mode of the same op.
- It also adds lowering of `aten.sub.float` op.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-03-22 13:14:49 +05:30
Sean Silva 729402c3f4 Reduce compilation time for TorchOps.cpp.inc
The `assemblyFormat` stuff (which generates unrolled, per-op C++ code)
was taking up a lot of compile time, and all the ops are essentially
printed with the same logic. So this PR makes them all call the same
helper function. This is done by using
`let hasCustomAssemblyFormat = 1` and then implementing `FooOp::parse`
and `FooOp::print`.

Additionally, the `Generated*Ops.td` files are all collapsed into just
`GeneratedTorchOps.td` (there is no reason to have the files separate,
since the files are very large anyway so one is always having to search
within them -- editors don't care that the file to search is now a bit
bigger :) ).

This reduces TorchOpsODSGenerated.cpp compile time (which is now
GeneratedTorchOps.cpp) from 39 to 31 seconds on my machine. This is
actually less than I expected, but this PR is an overall cleanup to the
code anyway. The next step will be to introduce (better) functionality
upstream for sharding the TorchOps.cpp.inc file, so that we can truly
parallelize the O(#ops) costs. This is also necessary, because after
this PR, TorchDialect.cpp is now the slowest file to compile, due to the
`addOperations<... all the ops ...>` call, which needs to be shareded
too.
2022-03-21 14:42:26 -07:00
Vivek Khandelwal 5b9bdfaf3f [MLIR][TORCH] Add E2E support for aten._to_copy op
This commit decomposes `aten._to_copy` op into
`valsem.aten.copy` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-03-21 19:12:37 +05:30
Vivek Khandelwal 13383b03b8 [MLIR][TORCH] Add value tensor variant to aten::copy_ op
This commit adds the op `ValsemVariantAtenCopyOp` that represents
`AtenCopy_Op` without the underscore. This is needed to make sure
that the `ReduceOpVariants` pass turns the in-place op into an op
that takes value tensors as inputs, otherwise the
`MaximizeValueSemantics` pass will not be able to add value
semantics correctly.

This commit also adds the lowering of `ValsemVariantAtenCopyOp`.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-03-21 19:12:37 +05:30
Vigilans 63fb1e5aad Bump LLVM at 8361c5da30588d3d4a48eae648f53be1feb5cfad 2022-03-18 13:16:14 -04:00
Sean Silva 3b66b4925a Make TorchOps.cpp faster to iterate on.
The ODS-generated code included via the `TorchOps.cpp.inc` file takes a
very long time to compile. This PR isolates it into its own file so that
the build system can cache it.

This PR creates a new file `TorchOpsODSGenerated.cpp` just to include
the `TorchOps.cpp.inc` file. Doing so required moving to the "new" way
to define verifiers, since the static `verify` free functions in
TorchOps.cpp weren't accessible from the .inc file after it was moved to
`TorchOpsODSGenerated.cpp`.

On my machine, this drops the build time of TorchOps.cpp (such as when
iterating on a canonicalizer) from >40 seconds to <10 seconds.
10 seconds still isn't great though, but at least it isn't "go get a
coffee" type of waiting.
2022-03-16 09:33:12 -07:00
Vivek Khandelwal 3d95c3d6c9 [MLIR][TORCH] Add value tensor variant to aten::_index_put_impl_
This commit adds the op `ValsemVariantAtenIndexPutImplOp` that represents
`Aten_IndexPutImpl_Op` without the underscore. This is needed to
make sure that the `ReduceOpVariants` pass turns the in-place op
into an op that takes value tensors as inputs, otherwise the
`MaximizeValueSemantics` pass will not be able to add value
semantics correctly.

This commit also adds the lowering of `ValsemVariantAtenIndexPutImplOp` op.

This commit also updates the `torch.bincount` op test cases.
2022-03-16 22:02:02 +05:30
Ramiro Leal-Cavazos 0bcc6d1075
Add maximize-value-semantics support for multiple non-value tensor inputs (#659)
This commit adds value semantics support for ops such as
`aten.view_as` and `aten.expand_as` that take two non-value 
tensors as input.
2022-03-15 18:13:45 -07:00
Sean Silva 92da4988f0 Improve "pseudo" op terminology.
The term "pseudo" is very vague and was getting confusing (I felt I had
to explain it in every comment referencing it). Instead, rework the
"pseudo" ops to instead be named:

- MLIR Syntax: `torch.valsem.*`
- C++ / ODS: `ValsemVariant*Op`

This makes it clear what the concept is, and avoids confusion with other
things that might be called "pseudo", since these are very specific and
should be 100% consistently named w.r.t. the non-valsem-variant ops that
they correspond to.
2022-03-15 17:57:52 -07:00
Sean Silva 7ea50a537a Avoid `using` the `torch_upstream` namespace.
This is code that we always want to treat as "foreign" and not get too
comfortable using in many functions. One way to accomplish that is to
make it a bit clunkier to use.

Also, fix Utils.cpp to match the LLVM/MLIR coding conventions (don't
define functions inside namespaces -- prefer `using` and explicit
qualification).
2022-03-15 17:24:17 -07:00
Sean Silva 84a9693006 Elide `!torch.` prefix in nested dialect types.
This leads to much more succinct types in many cases:

```
!torch.list<!torch.int>
!torch.list<int>

!torch.tuple<!torch.list<!torch.int>, !torch.list<!torch.int>>
!torch.tuple<list<int>, list<int>>

!torch.optional<!torch.list<!torch.int>>
!torch.optional<list<int>>

!torch.list<list<list<tensor>>>
!torch.list<!torch.list<!torch.list<!torch.tensor>>>
```

I would like to take this further and allow omitting the `!torch.`
prefix in all cases, but that's harder -- for example, we currently use
`FuncOp` for functions, and so I don't think we can customize the
printing there. It seems like it will be a longer road to getting that
level of customization.
2022-03-15 17:24:08 -07:00
Sean Silva a5fe0cf063 Introduce new shape library design.
See the documentation in `docs/shape_lib.md` and
`docs/adding_a_shape_function.md` for an overview of the system.

This completely overhauls how we represent shape functions. In
particular, RefineTypes does not infer shapes anymore (only dtypes).
Shape functions are now written in (TorchScript'able) Python.

Recommended review order:

1. Read `docs/shape_lib.md` and `docs/adding_a_shape_function.md`.
1. Code and tests for ReifyShapeCalculations, DropShapeCalculations.
1. Code and tests for SimplifyShapeCalculations.
1. shape_lib_gen.py
1. Code and tests for new RefineTypes pass.
1. Random folders/canonicalizers in TorchOps.cpp and associated test in
   `canonicalize.mlir`.
1. New ReadOnly trait inferred from the registry.
1. Any miscellaneous remaining stuff.

Example `-print-ir-after-all` for ElementwiseUnaryModule:
[IR lowering dump](https://gist.github.com/silvasean/e4dc8cbc8d00aac7819602e3cbd8e212).

Example `-print-ir-after-all` for ElementwiseBinaryModule:
[IR lowering dump](https://gist.github.com/silvasean/daf6860ecced732af3568af6b1899113).
2022-03-15 12:41:58 -07:00
Prateek Gupta 3d9ba5e525 [MLIR][TORCH] Add E2E support for aten.erf op.
Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>
2022-03-09 22:22:03 +05:30
Vivek Khandelwal 1a2a9e066f [MLIR][TORCH] Add TorchToTMTensor pass
This pass is added to lower ops, which can not be lowered
via the TorchToLinalg pass, such as `torch.bincount` op.
This pass also uses torch-mlir's TMTensor Dialect to lower the
complex ops.

Also add torch.bincount op lowering with the help of TMTensor dialect

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-03-08 22:52:34 +05:30
Vivek Khandelwal b2952b12dd [MLIR][TORCH] Move common helper functions to Utils.cpp
This commit moves the helper function which are common across
different torch-mlir conversion passes into a common directory
Utils.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-03-08 22:52:34 +05:30
Gaurav Shukla e57d3f9774 [LINALG] Fix `aten.bernoulli` op lowering
- This commit adds E2E support for `aten.rand_like` and
  `aten.bernoulli_.Tensor` ops.
- The `aten.bernoulli(x)` was implemented as:
  `aten.bernoulli(x) = rand_like(x) < 0.5`, assuming 0.5 as default
  probability, whereas according to the pytorch documentation:
  https://pytorch.org/docs/stable/generated/torch.bernoulli.html#torch.bernoulli
  the input x in `aten.bernoulli(x)` is itself a tensor containing
  probabilities to be used for drawing the binary random number.
- So this commit fixes the `aten.bernoulli(x)` implementation as:
  `aten.bernoulli(x) = rand_like(x) < x`.
- It also fixes the case where the input to `aten.bernoulli_.float` is
  an integer tensor. In this case the input must be casted to float type
  before passing it as operand to `aten.rand_like` op.
  `aten.bernoulli_.float(x, p) = rand_like(float(x)) < p`.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-03-05 09:38:22 +05:30
Vivek Khandelwal af551bd9cd [MLIR][TORCH] Add E2E support for aten.full_like op
This commit decomposes `aten.full_like` op into `aten.empty_like`
and `aten.fill` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-03-04 21:58:23 +05:30
Vivek Khandelwal d61ae92eee [MLIR][TORCH] Add E2E support for aten.full op
This commit decomposes `aten.full` op into `aten.empty` and
`aten.fill` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-03-04 21:58:23 +05:30
Yi Zhang 1d285f0153 Add aten.hardtanh e2e support. 2022-03-02 12:28:06 -05:00
Prashant Kumar 819f29316f Decompose aten.silu op
Decomposition of aten.silu.op is added as silu(x) = x * sigmoid(x).
2022-03-01 23:24:19 +05:30
Vivek Khandelwal ddd45d6068 [MLIR][TORCH] Add E2E support for aten.new_zeros, aten.new_ones op
This commit adds lowering of `aten.new_zeros` and `aten.new_ones` op

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
2022-03-01 22:09:47 +05:30
Ramiro Leal-Cavazos 1dba4fcbd7
[LINALG] Support for contiguous memory format in `clone` and `empty` (#628)
This commit adds support for the contiguous memory format for the ops
`AtenCloneOp` and `AtenEmptyMemoryFormatOp`.
2022-02-28 13:58:04 -08:00
Prashant Kumar 7c637eebc3 [LINALG] Decompose aten_hardswish op.
`aten.hardswish` op is decomposed into (x/6) * Relu6(x+3).
2022-02-25 21:59:27 +05:30
Gaurav Shukla 056cd2078d Revert "[LINALG] Decompose `aten.batch_norm` into `aten.native_batch_norm`"
This reverts commit 442ff4605c.
2022-02-25 15:46:55 +05:30
Ramiro Leal-Cavazos ba29d4f250
Add operand type invariant to `torch.overwrite.tensor.contents` (#606)
This commit adds the invariant to the op `torch.overwrite.tensor.contents` that
both of its operands have the same shape and size. In order to
maintain the invariant, special handling of this op is added to the
`RefineTypes` pass.
2022-02-22 11:41:46 -08:00
Prashant Kumar abbde7d439 [TORCH] The torch definition related to aten.gelu has changed.
New str argument approximation is added.
2022-02-18 21:57:46 +05:30
Nirvedh f8cb32faf0 LLVM bump
Major changes: opTrait changed to Trait, selectOp moved to arith dialect
assertOp moved to cf dialect
2022-02-16 15:28:13 -05:00
Gaurav Shukla 442ff4605c [LINALG] Decompose `aten.batch_norm` into `aten.native_batch_norm`
- This commit decomposes the `aten.batch_norm` op into the
  `aten.native_batch_norm` op, instead of lowering it to the
  `linalg.generic` op.
- It also adds run-time asserts in the `aten.native_batch_norm` lowering
  to make sure that the shape of the weight, bias, running_mean, and
  running_var must match the num of features.
- Since the `aten.native_batch_norm` op is not supported at TOSA backend,
  all the modules that are dependent on the `aten.native_batch_norm` op
  will fail and therefore they should be removed from the TOSA `passing`
  set.
- It also moves `checkNotNone` to utility.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-02-16 23:41:38 +05:30
Gaurav Shukla cd21dda867 [LINALG] Add E2E support for `aten.Hardsigmoid` op
This commit adds lowering of `aten.Hardsigmoid` op.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-02-16 02:35:18 +05:30
Ramiro Leal-Cavazos 00a6e9c1bb
[LINALG] Add value tensor variant to `fill_.Scalar` (#600)
This commit adds the op `PseudoAtenFillScalarOp` that represents
`AtenFill_ScalarOp` without the underscore. The approach is the same
as in commit dd998fa4d4.

Adding this op allows for a simpler and more consistent version of the
`empty` and `empty_like` op e2e tests.
2022-02-15 11:58:03 -08:00
Gaurav Shukla 41acde599b [LINALG] Add E2E support for `aten.[le|ge].Scalar` ops
- This commit adds lowering of `aten.le.Scalar` and `aten.ge.Scalar` ops
  as a part of `convert-torch-to-linalg` pass.
- It also creates a new test script `elementwise_comparison.py` for all
  element-wise comparison ops.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-02-15 12:21:09 +05:30
Ramiro Leal-Cavazos 413e6000d2
[LINALG] Add value tensor variant to `bernoulli_.float` (#597)
This commit adds the op `PseudoAtenBernoulliFloatOp` that represents
`AtenBernoulli_FloatOp` without the underscore. This is needed to make
sure that the `ReduceOpVariants` pass turns the in-place op into an op
that takes value tensors as inputs, otherwise the
`MaximizeValueSemantics` pass will not be able to add value semantics
correctly.
2022-02-14 18:58:48 -08:00
Gaurav Shukla f00d1686c8 [LINALG] Add E2E support for `aten.[Bool.Tensor|Float.Tensor]` op
- This commit adds lowering of `aten.Bool.Tensor` and
  `aten.Float.Tensor` op as a part of `convert-torch-to-linalg` pass.
- It also adds support for returning bool types.
- It also fixes lowering of the `aten.Int.Tensor` op for non-zero rank
  input tensors.
- If a scalar number is converted to a 0-d tensor and passed on to the
  `aten.Float.Tensor` op, it folds to the scalar number.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2022-02-14 23:09:20 +05:30
Yi Zhang 9e7b6cab08 Add folder for aten.gt/lt.float 2022-02-14 12:34:01 -05:00
Prashant Kumar 258660deb6 Add aten.bernoulli decomposition.
aten.bernoulli is decomposed to aten.gtTensor(aten.uniform(x), x).
2022-02-11 00:35:33 +05:30
Prashant Kumar 102c497c4c Add decomposition of _log_softmax op.
Decompose _log_softmax into log(softmax(x)).
2022-02-10 23:17:26 +05:30
Prateek Gupta 318946a650 [TORCH][MLIR] Add E2E support for `aten._unsafe_view` op.
This commit adds decomposition of `aten._unsafe_view` op into
`aten.view` op.

Signed-Off-By: Prateek Gupta<prateek@nod-labs.com>
2022-02-10 22:28:58 +05:30
Ramiro Leal-Cavazos 9b89f8eb3f
[TORCH][MLIR] Add E2E support for aten.clone (#571)
This commit adds support for the aten.clone op.
2022-02-09 19:31:03 -08:00