torch-mlir/lib/Conversion/TorchToLinalg
Ramiro Leal-Cavazos 58abec5c0a
Add `reduction` support to `torch.nll_loss_forward` (#624)
This commit does a couple of things. First, it fixes a bug in the
`linalg.generic` body of the `nll_loss_forward` lowering where the
`ignoreIndex` was being compared with the loop index rather than the
current element of the `target` tensor. This was not being caught by
the tests because they were not testing the case where `ignoreIndex`
actually corresponds to a value in `target`. This has been fixed.
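
To illustrate the bug described above, here is a minimal plain-Python sketch of the per-element computation. It is not the `linalg.generic` body from the commit; the function names and the list-of-lists representation of the log-probabilities are hypothetical and chosen only for illustration.

```python
# Hypothetical reference sketch; not the actual torch-mlir lowering.

def nll_elementwise_buggy(log_probs, target, ignore_index):
    # Bug: the loop index `i` is compared with `ignore_index`.
    return [0.0 if i == ignore_index else -log_probs[i][t]
            for i, t in enumerate(target)]


def nll_elementwise_fixed(log_probs, target, ignore_index):
    # Fix: the current element `t` of `target` is compared with `ignore_index`.
    return [0.0 if t == ignore_index else -log_probs[i][t]
            for i, t in enumerate(target)]


log_probs = [
    [-1.0, -2.0, -3.0],
    [-0.5, -1.5, -2.5],
    [-0.25, -0.75, -1.25],
]

# ignore_index never appears in target and is outside the loop range:
# both versions agree, so a test like this cannot expose the bug.
unexposed = [2, 0, 1]
same = (nll_elementwise_buggy(log_probs, unexposed, -100)
        == nll_elementwise_fixed(log_probs, unexposed, -100))

# ignore_index does appear in target: the two versions disagree.
exposed = [2, 0, 2]
buggy = nll_elementwise_buggy(log_probs, exposed, 2)
fixed = nll_elementwise_fixed(log_probs, exposed, 2)
```

With the second input, the buggy version zeroes the row whose position happens to equal `ignore_index` instead of the rows whose target class equals it.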

Second, this commit adds support for the `reduction` argument in
`torch.nll_loss_forward` as well as support for 1-D inputs. In order
to simplify the lowering code, I've refactored the code that creates
the `linalg.generic` ops for elementwise and reduction ops into static
functions, to avoid having boilerplate code for indexing maps, etc.,
that can be very error-prone.
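
As a rough reference for the semantics involved, the sketch below computes an unweighted NLL loss with the three reduction modes and accepts either a 2-D input with a 1-D target or a 1-D input with a scalar target. It assumes PyTorch's documented behavior, where `mean` divides by the number of non-ignored targets; the function name `nll_loss_forward_ref` and the string-valued `reduction` parameter are hypothetical, and the code is not taken from the lowering.

```python
import math

# Hypothetical reference sketch; assumes no class weights.

def nll_loss_forward_ref(log_probs, target, reduction="mean", ignore_index=-100):
    # A 1-D input (one row of log-probabilities) comes with a scalar target.
    is_1d = isinstance(target, int)
    rows = [log_probs] if is_1d else log_probs
    targets = [target] if is_1d else list(target)

    # Elementwise step: pick -log_prob of the target class, or 0 if ignored.
    losses = [0.0 if t == ignore_index else -row[t]
              for row, t in zip(rows, targets)]

    if reduction == "none":
        return losses[0] if is_1d else losses

    # Reduction step: sum, optionally divided by the non-ignored count.
    total = sum(losses)
    if reduction == "sum":
        return total
    if reduction == "mean":
        kept = sum(1 for t in targets if t != ignore_index)
        return total / kept if kept else math.nan
    raise ValueError(f"unknown reduction: {reduction!r}")


log_probs = [
    [-1.0, -2.0, -3.0],
    [-0.5, -1.5, -2.5],
    [-0.25, -0.75, -1.25],
]
target = [2, 0, 1]

per_element = nll_loss_forward_ref(log_probs, target, "none", ignore_index=1)
summed = nll_loss_forward_ref(log_probs, target, "sum", ignore_index=1)
mean = nll_loss_forward_ref(log_probs, target, "mean", ignore_index=1)
one_d = nll_loss_forward_ref([-1.0, -2.0, -3.0], 1, "none")
```

Splitting the computation into an elementwise step followed by a reduction step loosely mirrors the refactoring into separate helpers for elementwise and reduction ops mentioned above.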

Note: The function `convertScalarToDtype` was moved to before all the
conversion patterns, but nothing in it was modified.
2022-02-28 11:01:23 -08:00
CMakeLists.txt Move external/torch-mlir to the root of the repo. 2021-09-27 17:11:08 -07:00
TorchToLinalg.cpp Add `reduction` support to `torch.nll_loss_forward` (#624) 2022-02-28 11:01:23 -08:00