Commit Graph

928 Commits (9a6fe58a027d701eff6799e86a65535a8c2f3708)

Author SHA1 Message Date
Aart Bik f72770a725
[torch-mlir][sparse] replace ad-hoc mechanism with proper FX export (#3648)
Now that the PyDev feature request pytorch/pytorch#117188 has been
completed, we can remove all the ad-hoc code that propagates sparsity
metadata and replace it with the built-int PyDev metadata for sparse
tensors. This removes a lot of code and also ensures sparsity is
consistent with the torch.sparse package for all cases.
2024-08-20 09:56:21 -07:00
penguin_wwy 37e89828a1
[FxImporter] refactor canonicalize using table driven (#3402) 2024-08-16 22:57:18 +08:00
Vivek Khandelwal 78d0fa8998
build: manually update PyTorch version (#3568)
Set PyTorch and TorchVision version to nightly release 2024-08-04.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-08-06 21:36:39 +05:30
Jiawei Wu 6f7a5db801
[FxImporter] small fixes for fx importer compatibility issues between different pytorch versions (#3577) 2024-08-01 10:52:41 +08:00
Matthew Francis-Landau fe9db78120
Allow custom ops to return an array of tensors (#3531)
This PR adds support to `fx_importer.py` for handling custom ops that
return an array of tensors. As long as the length of the array is
consistent across runs (determined statically), then this patch will
work. This does not require that the number of tensors returned is
determined by the op's definition.

CC @sjain-stanford
2024-07-14 11:54:23 -07:00
Sambhav Jain cdbcf519f7
[NFC] Expose both raw Torch dialect and Torch dialect in backend form with Dynamo/FX (#3541)
This is a non-functional change. It merely allows intercepting the Torch
dialect during TorchDynamo export at two stages:
1. `OutputType.RAW`: This gets us the torch dialect as-imported from the
FX graph
2. `OutputType.TORCH`: This gets us the torch dialect after the raw
torch goes through DecomposeComplexOps and ReduceOpVariants.

Prior to this, there was no way of accessing the Torch dialect in
backend compliant form (right after running the
`torchdynamo-export-to-torch-backend-pipeline`) because both
[here](https://sourcegraph.com/github.com/llvm/torch-mlir@5e4f00acb13f3f849a05e5ac28ee39307a5fdbff/-/blob/python/torch_mlir/fx.py?L33)
and
[here](https://sourcegraph.com/github.com/llvm/torch-mlir@5e4f00acb13f3f849a05e5ac28ee39307a5fdbff/-/blob/python/torch_mlir/compiler_utils.py?L138)
the same `OutputType.TORCH` were used, meaning the 2nd condition would
never be reached.

Since the default behavior is unchanged, this is an NFC.
2024-07-14 10:33:47 -07:00
Yuanqiang Liu 4bb7ddf601
[Stablehlo] enable stablehlo's python extension binding (#3529) 2024-07-10 13:00:13 +08:00
Vivek Khandelwal 2f231f394e
Bump Onnx Version to 1.16.1 (#3515)
This commit adds the support for new data types: uint4, and int4 and
uint8 tensor protos. Also, it moves some tests from failing to crashing.

Fixes https://github.com/llvm/torch-mlir/issues/3507

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-07-01 22:15:45 +05:30
Aart Bik 6fece25ff3
[torch-mlir][sparse] add decomposition features to sparse compiler (#3505)
Fixes https://github.com/llvm/torch-mlir/issues/3499
2024-06-28 10:18:36 -07:00
Aart Bik 1f73895f93
[torch-mlir] bump to llvm/llvm-project@9b78ddf3b2 (#3491)
This bump triggered an upstream assert. Includes a WAR for #3506.

Also includes several things I needed to do to repro:

* When TORCH_MLIR_TEST_CONCURRENCY=1, test runs will be printed.
* Added TORCH_MLIR_TEST_VERBOSE=1 handling to enable verbose mode
(useful on CI).

---------

Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
2024-06-27 19:28:02 -07:00
Yuanqiang Liu 61f37ae8a3
[fx importer] support fx importer with lower version torch (#3486) 2024-06-24 15:39:19 +08:00
Andrea 🦈 51902ec2dc
Create MLIR functions for ONNX operators that are functions (#3409)
Resolves #3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit adds this capability to onnx_importer.py. When importing a
node, the schema for the node's operator is retrieved. If the schema
provides a function for the operator, a specialized version for the
node's types and attributes will be created and imported as an MLIR
function with private visibility. An MLIR function call will then be
emitted, instead of a normal operator node. Caching is used to avoid
generating redundant functions within the same module.

In order to avoid a disruptive change to the importer output for a
large number of operators that already have TorchOnnxToTorch support,
an allowlist strategy is used by default. With this commit, only one
operator is allowlisted for expansion, MeanVarianceNormalization.
However, many other operators can be correctly expanded by the current
code, so hopefully the allowlist can be gradually extended. It is
possible to disable the allowlist in the configuration, in which case
all functions are expanded (useful for testing).

Tools downstream of the importer may now need to do inlining when
consuming the output of the importer, e.g.:

  cat imported.mlir | torch-mlir-opt --inline --convert-onnx-to-torch

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this. This also turned
  out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so it has to be updated.
2024-06-14 10:11:26 -07:00
Wu Yuan a02e14e971
[FxImporter] Add aten._scaled_dot_product_flash_attention_for_cpu to default decomposition table (#3456) 2024-06-14 10:52:09 +08:00
zjgarvey c0eb6d89c0
[ONNX] add some args to the onnx importer to assist shape_inference (#3445)
Adds the following arguments:
- "--clear-domain": enabling this flag (default False) will delete the
domain attribute from each node in the onnx model before importing.
Shape inference does not seem to work for onnx ops in custom domains. In
the rare case when these ops have a corresponding counterpart in base
onnx, enabling this flag might allow shape inference to work properly.
- "--opset-version": allows setting the opset version manually. This
will cause the importer to attempt to update the opset_version of the
onnx model before importing. Newer opset versions sometimes have more
robust shape inference patterns.
2024-06-12 10:55:14 -05:00
Sambhav Jain 7e0e23c668
Test custom op import with symbolic shapes (#3431)
Tests the basic constructs of registering a custom op and its abstract
implementations (with FakeTensors) in python, going through TorchDynamo
export, followed by importing the shape expressions in the Torch
dialect.

Also fixes the importer were previously the symbolic bind op insertion
was not gated in one place.
2024-06-09 00:32:49 -07:00
Rob Suderman 7f188eb824
Add f8 types to fx importer (#3434)
Missing types for tracing float8 types.
2024-06-07 13:58:18 -07:00
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```ll
module {                                                                                                                                                                                                     
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {                                                                                   
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int                                                                                                                                    
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int                                                                                                                                   
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int                                                                                                                                    
    
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                                  
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                               
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>                                               
    %int1 = torch.constant.int 1                                                                                                                                                                             
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>                                                                                                          
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>                                                                            
    
    return %6 : !torch.vtensor<[?,?,3],f32>                                                                                                                                                                  
  }                                                                                                                                                                                                          
}              
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
2024-06-07 04:04:03 -07:00
Rob Suderman 0a6861b1e8
Add conversion operation for bool resolved_literal (#3410)
Resolving `bool` literals can result in a type change to uint8. This
needs to be converted back to the expected type before returning to the
wrapped `torch` operators.
2024-06-03 14:43:38 -07:00
penguin_wwy a5d3b546f8
[FxImporter] Fix embedding bag (#3387) 2024-05-29 14:46:21 +08:00
penguin_wwy d924d0047f
[FxImporter] Fix primitive type in return (#3379) 2024-05-23 09:55:33 +08:00
penguin_wwy 972d47b586
[FxImporter] Fix constant bool tensor (#3375) 2024-05-22 22:59:01 +08:00
Sambhav Jain 6e485574e5
[Pipeline] Use dedicated simplification pipeline for TorchDynamo frontend (#3376)
Discord Thread:
https://discord.com/channels/636084430946959380/1238330633328005243

## Context: 

[This](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/fx.py#L61)
was updated to support e2e tests for the TorchDynamo frontend in
Torch-MLIR, where we run FX decompositions and import the FX IR to
generate Torch dialect, followed by
`torch-function-to-torch-backend-pipeline`, skipping only the shape/type
refinement for now. However, we should be able to skip many of the torch
simplification passes, as depicted in the [frontend
roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/images/roadmap_frontend.png).

Based on IREE's TorchDynamo
[pipeline](https://github.com/iree-org/iree/blob/main/compiler/plugins/input/Torch/InputConversion/Passes.cpp#L29),
the only two passes we seem to require are: `ReduceOpVariantsPass` and
`DecomposeComplexOpsPass`. This is inline with our findings as well
based on initial exploration.

This PR creates a dedicated frontend simplification pipeline for
TorchDynamo / FX Importer which calls only `ReduceOpVariantsPass` and
`DecomposeComplexOpsPass`. We rely on the e2e fx_importer tests to
ensure we're not regressing by removing many of the passes that were
historically needed for TorchScript.

One notable change here is that we do not call the
`LowerToBackendContractPass` anymore, which used to call
`TorchSimplificationPipeline` iteratively until VerifyBackendContract
was clean. Some of this was required for the shape/type refinement to
converge, which seems a non-issue for Dynamo frontend. Do we anticipate
this (the iterative invocation of TorchSimplificationPipeline followed
by VerifyBackendContract) to be worth retaining in the Dynamo frontend
pipeline? If so, I can make those changes, PLMK.
2024-05-22 05:23:18 -07:00
penguin_wwy c2c1c2cfa4
[FxImporter] Fix failed e2e case (#3365) 2024-05-22 00:20:54 +08:00
Stella Laurenzo 00efec0b73
[linalg] Implement strict mode lowering for aten.view. (#3319)
* Enables assume_strict_symbolic_shapes on fx_importer imported
programs, indicating strict shape semantics.
* Reworks the view->reshape lowering to take advantage of strict mode
and do one of:
  * Collapse to 0D
  * Flatten/Unflatten when there is an inferred dim.
  * Fallback to tensor.reshape
* Splits some test cases up and adds an attribute to control the old
pattern (so new corners can be tested in strict mode in isolation).
* Dynamic inferred mode needs upstream work to generalize expand_shape
(so that case is suppressed here).
* Deletes the assert from the existing tensor.reshape lowering if strict
shape mode is enabled (since the condition it is dynamically asserting
cannot happen).
2024-05-10 13:45:50 -07:00
penguin_wwy 64b59c7fc3
[FxImporter] Eliminate the dependency on the refinement pass (#3309) 2024-05-10 02:44:36 +08:00
zjgarvey 0abc5868b5
[ONNX] Enables data propogation for onnx shape inference (#3280)
This small change seems to dramatically improve shape inference for
complex models, and consequently, improves onnx importer reliability.
2024-05-08 09:29:23 -07:00
penguin_wwy c3bd850951
[FxImporter] Add backend lowering to Fx API (#3288) 2024-05-07 20:58:50 +08:00
Xida Ren (Cedar) 33eef15e42
Support onnx.If (#2825)
This is probably a decent PR for learning about blocks and regions.

If you're here to learn about that, consider also looking at
lib/Conversion/TorchToSCF/TorchToSCF.cpp

While this doesn't include an e2e test, it is tested downstream in
https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/onnx/operators/If/model.py

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-04-30 18:36:40 +00:00
penguin_wwy b2185195e8
[NFC] Update black version (#3256)
* Update black version to support 3.11/3.12
* Reformat code
2024-04-29 11:06:01 +08:00
penguin_wwy 9f64748f97
[FxImporter] Synchronize the collection of symbolic torch ops (#3236) 2024-04-29 10:09:00 +08:00
Stella Laurenzo 6877302504
[NFC reformat] Applies pre-commit formatting to Python files. (#3244)
This is a large change because prior to this point, Python files in the
project were not consistently formatted. This reformats them all with
black defaults.

Based on experience with prior projects, if you have a dev/long-term
branch with Python patches, you can minimize merge conflicts prior to
rebasing to include this commit by running `black` on your modified
Python files, squashing, and then rebasing/merging.
2024-04-27 14:16:31 -07:00
penguin_wwy 944a6df611
Extract the Python APIs in the pt1 dir back to the root (#3237) 2024-04-27 18:27:37 +08:00
penguin_wwy 3aa81f78d8
[FxImporter] Replace local_scalar_dense in fx_importer (#3180) 2024-04-17 22:45:47 +08:00
penguin_wwy e4b11a0ab4
[FxImporter] Fix fx importer test config and clean xfail set (#3176) 2024-04-16 22:36:07 -07:00
penguin_wwy 398aeeec87
[FxImporter] Fix kwarg operands in fx importer (#3166)
Remove the `kwarg_only` limitation, for example
```
torch.add(x, 3.0, alpha=2)
```
compiled to
```
%0 = torch.aten.add.Scalar %arg0, %float3.000000e00, %int1
```
fix to
```
%0 = torch.aten.add.Scalar %arg0, %float3.000000e00, %int2
```
2024-04-16 13:17:05 -07:00
penguin_wwy af5509c5d9
[FxImporter] Type conversion to resolve the mismatch between Py type and schema type (#3163) 2024-04-15 23:14:19 -07:00
Stella Laurenzo ffaaf08c31
[fx] Fix type inference for scalar/int types. (#3099)
This was discovered in a downstream test suite and was due to a control
flow nesting merge issue. In-tree test added and fixed.
2024-04-02 13:56:43 -07:00
penguin_wwy 5325d3e6e6
[fx] Fix type hint for fx importer (#3066)
Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
2024-04-01 17:31:43 -07:00
Rob Suderman ec4cb8be44
Bump LLVM to llvm/llvm-project@0030fc4ac7 (#3079)
Co-authored-by: Peiming Liu <peiming@google.com>
2024-04-01 16:34:59 -07:00
Stella Laurenzo 826786bdd0
[fx] Support ExportedProgram buffer mutation. (#3080)
In the prior state when I supported mutation of user inputs by treating
them as mutable-tensor SSA values, I had left the case of buffer
mutation only vaguely implemented until a concrete use emerged.
    
This patch reworks this buffer mutation support by assuming that buffers
must be resolved via the hooks symbolically and treated with load/store
semantics. This is implied in the structure since we have no SSA value
that represents a buffer and we already assume that reading parameters
happens via such a mechanism.
2024-04-01 14:18:12 -07:00
Stella Laurenzo 282e9b0e64
[fx] Fix type determination for multi-return ops and static `None` returns. (#3081)
In practice, this was caught by the way that AOT autograd traces
`convolution_backward`. For the unit test, we just repro it with a
custom op.
2024-04-01 09:39:38 -07:00
Stella Laurenzo e2343cf4ce
[fx] Implement auto_functionalized higher order op. (#3063)
* Also adds the basic scaffolding for handling more of these, which will
be needed for cond, while, etc.
* Refactors some of the support in the generic OpOverload emitter so it
can be shared with these other special forms.

This has been on my list for a while, but it just so happens that as
part of upgrading to PyTorch 2.3 and a pure upstream flow in Turbine, we
were using a feature that required integration with auto_functionalized.
This is perhaps the "weirdest" of the higher-order ops and a poor place
to start, but needs must. We have testing for this in Turbine.

Full support in Turbine has an entire custom ops facility. I've reduced
this down to a unit test in torch-mlir.
2024-03-26 17:06:05 -07:00
Stella Laurenzo 17eeac880a
[fx] Accept `func_visibility=` and return created func op. (#3054)
This is a partial landing of #3046 while waiting for an upstream change
for the rest of it.
2024-03-25 16:48:06 -07:00
Stella Laurenzo 6ea857c644
[fx] Make the lift_fresh_copy -> clone special form use kwargs. (#3045)
At some point, this op became kwarg-only instead of arg/kwarg.
Discovered when upgrading to PyTorch 2.3.

Also adds a test as this was untested in-tree (was caught out of tree).
2024-03-21 15:34:40 -07:00
penguin_wwy 7616d637fd
Add stateless fx graph import (#3036) 2024-03-21 14:44:54 -07:00
Aart Bik fe59f1ee0d
[torch-mlir][sparse] higher dimension COO (#3042)
Lift this from 2-dim only to n-dim for n>=2
2024-03-19 15:59:07 -07:00
penguin_wwy f34c187ac4
Normalize type hints to be compatible with multiple Python versions (#3028)
Although we provide a wheel package for Python 3.8, it may actually
throw the following exception:
`TypeError: 'type' object is not subscriptable`
2024-03-15 08:29:48 -07:00
Sambhav Jain 0b2f9c89a2
Bring back `dynamic_shapes` constraints in fx importer API (#3026)
https://github.com/llvm/torch-mlir/pull/2992 dropped `constraints` from
the fx importer API,
[breaking](https://github.com/cruise-automation/mlir-tcp/actions/runs/8284385380/job/22669774071)
downstream AOT compile tests in `mlir-tcp` that use it. This knob has
been soft-deprecated for a while now, replaced by `dynamic_shapes` - a
more ergonomic interface. This PR brings back dynamic_shapes constraints
in the new supported form. Also added a python lit test with dynamic
shaped annotations.
2024-03-14 10:26:34 -07:00
Daniel Garvey 80c7bc3f7a
fximporter: support newer torch versions (#2999)
uses version checking since attributes exist in both versions, the only
thing that changes is what we're receiving as an fx graph
2024-03-08 14:58:50 -06:00
Vivek Khandelwal 6e84752c39
build: manually update PyTorch version (#2992)
Set PyTorch and TorchVision version to nightly release 2024-03-07.
This commit also removes the deprecated constraints API:
342e7929b8

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-03-07 21:42:38 +05:30