While playing with TorchDynamo on ResNet18. I notice following issues:
- `prims.convert_element_type` can’t be canonicalized even if the input
and the output share the same type
- `aten.max_pool2d_with_indices` is always used instead of
`aten.max_pool2d`, even if the second returned output (indices) has no
user
This PR fixes above issues by adding a folder to the
PrimsConvertElementTypeOp and a canonicalizer to the
AtenMaxPool2dWithIndicesOp
Lit test:
`cmake --build build --target check-torch-mlir-all`
---------
Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
This is part 1 of ~3, formatting all miscellaneous text files and CPP files matched by a first run of pre-commit. These tend to be low change-traffic and are likely not disruptive.
Subsequent patches will format Python files and remaining CPP files.
This folds small version of the tensor-scalar comparison operators as
they are commonly used for shape computations. This includes le, lt, ge,
gt, eq, and ne.
A handful of operations are commonly used in shape calculations (slice,
concat, broadcast). Added these additional folders to better propagate
simple shape computations.
A bunch of small fixes are interlinked and trigger crashes if not
addressed as a group. This includes:
- aten view when expand from a rank-0 tensor
- slice folder with negative indices
- `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
Simple folder for limited size aten tensor operations. This is primarily
useful for shape computation folding as they unfortunately can use
`aten` operators. Add, sub, mul are common examples of these folders.
Folds aten::index_select ops under the following conditions:
1. If the input and output are the same shape, the indexing operation is
a NOP, so just return the input.
2. If the input has shape <1x1x...xNx...x1> (all 1's except for one
dim), and the output shape is <1x1x...x1> (all 1's), then there is a
single index, so extract the single element value and return a tensor
with that value.
---------
Co-authored-by: Dave Liddell <dliddell@xilinx.com>
If a tensor is initialized by a list with a single constant integer,
this folder turns it into a torch.vtensor.literal
---------
Co-authored-by: Dave Liddell <dliddell@xilinx.com>
So that the CumSum Op in OPT can get the constant that it requires to be lowered to TMTensor
---------
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
We were seeing some assertion failures after some checks around folders
were tightened up in LLVM:
https://github.com/llvm/llvm-project/pull/75887 . This PR essentially
moves the logic that used to be applied at the LLVM level into the
folder, which seems to be the suggested fix.
I'm not sure if the IR that caused issues for us _should_ be valid?
```
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
```
A better fix might be to create a verifier ensuring the result of
`aten.detach` has the same type as its operand.
---------
Co-authored-by: aaron-stgeorge <aaron.stgeorge@getcruise.com>
Strict symbolic shapes allow us to assume numpy-style dynamic broadcasts
never occur. This allows us to strengthen the folder for broadcasts to
cases where the rank is the same and all shapes match (including dynamic
sentinel values).
* RecomposeComplexOps: Remove dead slice op
* lib/Dialect/Torch/IR/TorchOps.cpp: Fold slice ops even when they are on non-value tensors
* lib/Conversion/TorchToTosa/TorchToTosa.cpp: Fix slice start/end out of range/none
* lib/Dialect/Torch/IR/TorchOps.cpp: AtenSliceTensorOp::fold: Fold slices that go from 0:int_max
* More tests for aten.split.Tensor
Currently, the op `torch.tensor_static_info_cast` will not get
canonicalized away if the result type has any shape or dtype
information. This is because `isValidSubtype` only returns true when
the tensor types being compared are exactly the same or the supertype
has no shape and dtype information. Being unable to canonicalize away
the `torch.tensor_static_info_cast` gets in the way of further
optimizations, such as shape propagation.
This commit improves `isValidSubtype` by adding logic that compares
the shapes and dtypes of the two tensor types to determine of one type
is indeed a valid subtype of the other.
Fixes https://github.com/llvm/torch-mlir/issues/1926