//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTraits.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "npcomp/Interfaces/Traits.h"
#define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc"
namespace mlir {
namespace NPCOMP {
namespace Torch {

/// Create code to copy `tensor` to type `newType`.
///
/// This involves two independent steps, which we keep orthogonal in our
/// IR representation.
/// 1. Adding/removing static information about sizes/dtype.
/// 2. Performing the copy, which allows us to add/remove value semantics.
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
                       Value tensor);

} // namespace Torch
} // namespace NPCOMP
} // namespace mlir
// DenseMapInfo specializations so that these ops can be used directly as keys
// in llvm::DenseMap / llvm::DenseSet. Each op is identified by its underlying
// operation pointer: the empty/tombstone sentinels are derived from the
// `void *` sentinels, and hashing/equality go through the opaque pointer.
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
  using SlotOp = ::mlir::NPCOMP::Torch::SlotOp;
  static SlotOp getEmptyKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
    return SlotOp::getFromOpaquePointer(pointer);
  }
  static SlotOp getTombstoneKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return SlotOp::getFromOpaquePointer(pointer);
  }
  static unsigned getHashValue(SlotOp val) {
    return hash_value(val.getAsOpaquePointer());
  }
  static bool isEqual(SlotOp lhs, SlotOp rhs) { return lhs == rhs; }
};

template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> {
  using NnModuleOp = ::mlir::NPCOMP::Torch::NnModuleOp;
  static NnModuleOp getEmptyKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
    return NnModuleOp::getFromOpaquePointer(pointer);
  }
  static NnModuleOp getTombstoneKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return NnModuleOp::getFromOpaquePointer(pointer);
  }
  static unsigned getHashValue(NnModuleOp val) {
    return hash_value(val.getAsOpaquePointer());
  }
  static bool isEqual(NnModuleOp lhs, NnModuleOp rhs) { return lhs == rhs; }
};

template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> {
  using ClassTypeOp = ::mlir::NPCOMP::Torch::ClassTypeOp;
  static ClassTypeOp getEmptyKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
    return ClassTypeOp::getFromOpaquePointer(pointer);
  }
  static ClassTypeOp getTombstoneKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return ClassTypeOp::getFromOpaquePointer(pointer);
  }
  static unsigned getHashValue(ClassTypeOp val) {
    return hash_value(val.getAsOpaquePointer());
  }
  static bool isEqual(ClassTypeOp lhs, ClassTypeOp rhs) { return lhs == rhs; }
};

template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::GlobalSlotOp> {
  using OpTy = ::mlir::NPCOMP::Torch::GlobalSlotOp;
  static OpTy getEmptyKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
    return OpTy::getFromOpaquePointer(pointer);
  }
  static OpTy getTombstoneKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return OpTy::getFromOpaquePointer(pointer);
  }
  static unsigned getHashValue(OpTy val) {
    return hash_value(val.getAsOpaquePointer());
  }
  static bool isEqual(OpTy lhs, OpTy rhs) { return lhs == rhs; }
};
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
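The comment on `copyTensorToType` describes two independent steps. The standalone C++ sketch below models only that decision. It is not the npcomp implementation and emits no MLIR. `TensorTypeModel`, `planCopyToType`, and the op labels `static_info_cast` and `copy` are hypothetical placeholders. The sketch also assumes that the copy can be skipped when both types already have value semantics.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of a Torch tensor type. An empty optional stands for
// missing static information (unknown sizes or unknown dtype).
struct TensorTypeModel {
  std::optional<std::vector<std::int64_t>> sizes; // nullopt: sizes unknown
  std::optional<std::string> dtype;               // nullopt: dtype unknown
  bool hasValueSemantics;                         // true: like !torch.vtensor
};

// Lists the (hypothetical) ops needed to turn a value of type `from` into a
// value of type `to`, treating the two steps independently.
std::vector<std::string> planCopyToType(const TensorTypeModel &from,
                                        const TensorTypeModel &to) {
  std::vector<std::string> ops;
  // Step 1: add/remove static information about sizes/dtype.
  if (from.sizes != to.sizes || from.dtype != to.dtype)
    ops.push_back("static_info_cast");
  // Step 2: copy to add/remove value semantics. Assumption: the copy is
  // skipped only when both sides already have value semantics.
  if (!(from.hasValueSemantics && to.hasValueSemantics))
    ops.push_back("copy");
  return ops;
}
```

In this model the two checks never interact. Refining or erasing sizes/dtype does not by itself force a copy, and changing value semantics does not require touching the static information. For example, going from an unranked value-semantic tensor to a statically shaped non-value-semantic one plans both steps.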
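The `DenseMapInfo` specializations in this header provide the four hooks that LLVM's open-addressing `DenseMap`/`DenseSet` rely on: an empty key, a tombstone key, a hash, and equality. The standalone sketch below shows how a hash set might use such hooks. It is an illustration, not LLVM's implementation. `OpHandle`, `KeyInfo`, and `OpenAddressingSet` are hypothetical names, and the sentinel addresses were chosen only for illustration. The header hashes through `hash_value(...)`; the sketch substitutes `std::hash<void *>`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for an MLIR op class: a thin handle around a pointer
// that round-trips through an opaque `void *`.
class OpHandle {
  void *ptr = nullptr;
public:
  static OpHandle getFromOpaquePointer(const void *p) {
    OpHandle h;
    h.ptr = const_cast<void *>(p);
    return h;
  }
  void *getAsOpaquePointer() const { return ptr; }
  bool operator==(OpHandle other) const { return ptr == other.ptr; }
};

// Traits in the shape of llvm::DenseMapInfo (hypothetical name `KeyInfo`).
template <typename T> struct KeyInfo;
template <> struct KeyInfo<OpHandle> {
  // Illustrative sentinels: addresses assumed never to hold a real object.
  static OpHandle sentinel(std::uintptr_t n) {
    return OpHandle::getFromOpaquePointer(reinterpret_cast<void *>(n << 12));
  }
  static OpHandle getEmptyKey() { return sentinel(~std::uintptr_t{0}); }
  static OpHandle getTombstoneKey() { return sentinel(~std::uintptr_t{1}); }
  static std::size_t getHashValue(OpHandle v) {
    return std::hash<void *>()(v.getAsOpaquePointer());
  }
  static bool isEqual(OpHandle a, OpHandle b) { return a == b; }
};

// Minimal open-addressing hash set with linear probing, driven by KeyInfo.
template <typename K, typename Info = KeyInfo<K>> class OpenAddressingSet {
public:
  explicit OpenAddressingSet(std::size_t capacity = 8) // must be a power of 2
      : slots(capacity, Info::getEmptyKey()) {}
  bool contains(K key) const { return find(key) != npos; }
  std::size_t size() const { return live; }

  bool insert(K key) {
    assert(!isSentinel(key) && "sentinel keys cannot be stored");
    if (contains(key))
      return false;
    // Keep live entries + tombstones under 3/4 of the slots so every probe
    // sequence is guaranteed to reach an empty slot.
    if ((live + tombstones + 1) * 4 > slots.size() * 3)
      rehash(slots.size() * 2);
    std::size_t mask = slots.size() - 1, i = Info::getHashValue(key) & mask;
    while (!isSentinel(slots[i])) // reuse the first empty or tombstone slot
      i = (i + 1) & mask;
    if (Info::isEqual(slots[i], Info::getTombstoneKey()))
      --tombstones;
    slots[i] = key;
    ++live;
    return true;
  }

  bool erase(K key) {
    std::size_t i = find(key);
    if (i == npos)
      return false;
    // A tombstone, unlike an empty slot, does not end later lookups.
    slots[i] = Info::getTombstoneKey();
    --live;
    ++tombstones;
    return true;
  }

private:
  static constexpr std::size_t npos = static_cast<std::size_t>(-1);
  static bool isSentinel(K k) {
    return Info::isEqual(k, Info::getEmptyKey()) ||
           Info::isEqual(k, Info::getTombstoneKey());
  }
  std::size_t find(K key) const {
    std::size_t mask = slots.size() - 1, i = Info::getHashValue(key) & mask;
    // An empty slot ends the probe; tombstones are skipped over.
    for (; !Info::isEqual(slots[i], Info::getEmptyKey()); i = (i + 1) & mask)
      if (Info::isEqual(slots[i], key))
        return i;
    return npos;
  }
  void rehash(std::size_t newCapacity) {
    std::vector<K> old = std::move(slots);
    slots.assign(newCapacity, Info::getEmptyKey());
    live = tombstones = 0;
    for (const K &k : old)
      if (!isSentinel(k))
        insert(k);
  }
  std::vector<K> slots;
  std::size_t live = 0;
  std::size_t tombstones = 0;
};
```

Because an empty slot ends a lookup and a tombstone does not, both sentinels must differ from every real key. The pointer-derived sentinels used by the specializations in this header are intended to meet that requirement.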