Commit Graph

4 Commits (89d4931324589bf75ed088e98888ff21fe7cd41e)

Author SHA1 Message Date
Sean Silva f49ebf1690 Add `!torch.int` type.
This replaces the ad-hoc use of `i64` throughout the Torch layer, and
helps to keep it crystal clear the distinction between `!torch.int`
(which is modeling the Python `int` type) and the various types that
serve as dtypes of tensors, which are a totally different type universe.

Changes:
- `!torch.int` type and C bindings.
- Change `torch.constant.int` parser to not need the `: i64` at the end.
- `m_TorchConstantInt` matcher to aid with matching constants.
- BackendTypeConversion changes for `!torch.int` -> `i64` type
  conversion.
- Refactor finalizing patterns in FinalizingBackendTypeConversionPass
  (they were getting very repetitive).
- Mechanical rewriting of `!torch.int` to `i64` in all the tests, and
  `AnyTorchIntType` to `Torch_IntType` in the `.td` files.
2021-06-17 07:28:23 -07:00
Sean Silva db282fd1b4 Introduce native `!torch.none` type.
- Add `torch.constant.none` op to construct it (naming is chosen to be
  analogous to Torch's representation of a prim::Constant with
  NoneType, rather than using the "singleton" terminology of Basicpy).
2021-06-14 13:30:58 -07:00
Sean Silva 179105ca3e Add basic MLP's to the e2e curriculum.
These tests pass on the reference backend.

- Add aten.linear op + shape xfer function + ATen->Linalg lowering.
 - Note: this needs to be more automated, and needs to cover more cases.
 - Current not implemented caveats:
  - size-1 broadcasting for bias vector (either static-size-1 or ? case)
  - higher-rank aten.linear ops (not produced by torch.nn.Linear though)
  - type promotion (still don't even know the exact rules here)
- Add folder for torch.derefine op. Now the inliner can clean it up as
  it inlines. (call boundaries are a main place we need to insert
  torch.derefine) This is brittle -- the other important case is control
  flow which will need to be handled via an extension to
  RefineTypes.cpp (as will more robust call handling). River has an
  in-flight patch to update it to the new dataflow framework so I didn't
  want to do anything intrusive here.
    - Also adjust torch.derefine syntax to use the keyword `to` instead of
      `->`, as most type-only, cast-like ops do.
2021-04-27 12:18:54 -07:00
Sean Silva 43dba03afd Properly model "derefinement".
In terms of IR structure, TorchScript allows types to vary in many
circumstances where MLIR requires pointer-identical types. In particular,
it is valid to pass any subtype in place of a type. For example, if an
`Optional[int]` is required somewhere in the IR, it is legal to pass a
value of just `int` (but not the other way around; see
`torch.prim.unchecked_cast`). In effect, every *use* can have a different
type.

We introduce a new op `torch.derefine` that models that impedance
mismatch. This op allows casting a value from one type to a type that it
is a subtype of to model this behavior.

Recommended review order:
- TorchOps.td for new torch.derefine (and updated docs for
  `torch.prim.unchecked_cast`)
- new test code in if.py, loop.py, function-derefine.py
- new code in node_importer.cpp for handling derefinement insertion
- function_importer.cpp and utils changes in torch_to_mlir_utils.cpp

Properly handling derefinement on function boundaries required
relayering the code so that graph_importer.cpp/.h is now
function_importer.cpp/.h because only the `torch::jit::Function`
(actually the `c10::FunctionSchema` it holds) knows the derefined types that are
actually needed at the boundary (see `function-derefine.py` for a test).

Annoyingly, this churns all the functions which are now prefixed with
`__torch__.` but that is more correct anyway (that is their linkage name
in the `torch::jit::CompilationUnit`; the previous `mb.import_function`
was actually buggy in the case of functions calling each other as it
would reference their unqualified name).

With this change, we can import `resnet18` from `torchvision` :)
IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704
2021-03-03 15:09:44 -08:00