//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

[torch-mlir earthmoving (1/N)] C/C++ code movement.
This creates the `external/torch-mlir` directory as an
LLVM_EXTERNAL_PROJECTS-compatible project (analogous to
`iree-dialects`) and completes movement/rename of all pure MLIR C/C++
compiler code into there. The next step will be to move all the Python
code / code that links/includes PyTorch C++ code (which currently lives
in `frontends/pytorch`) into a subdirectory here.
I call this "earthmoving" because it is mostly mechanical changes and
renames. As a quick summary (we can change this down the road easily)
- C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch`
- CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet`
- preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_`
- CMake `NPCOMPFoo -> TorchMLIRFoo`
The goal of this is to create a standalone project creating a center of
mass for entry into the MLIR ecosystem from PyTorch, suitable in scope
for eventual inclusion/ownership in PyTorch. The idea is that
`external/torch-mlir` will some day be pulled out into its own
repository, and then npcomp will simply pull it in as a submodule.
Layering-wise, what lives in `torch-mlir` lowers code from PyTorch
(currently TorchScript, but TorchFX or pytorch/xla-style tracing are
possible extensions) down to what we have been calling the "Torch
backend contract" which is cleaned up IR (inlining, simplification,
conversion to value tensors, ...) entirely in the `torch` dialect. This
is the branching off point for further lowering, of which npcomp takes
one opinion (outside `torch-mlir` of course!), namely the
`TorchConversion` dialect/transforms which lower to IR suitable for IREE
and other linalg-on-tensors based lower-level compilers.
Summary of changes:
- move `{include,lib,test}/Dialect/Torch` into `torch-mlir`
- move relevant parts of CAPI into `torch-mlir`.
- leave a few things related to the `torch-mlir` Python build commented
out, which should be resolved in a subsequent change.

#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tensor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
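
The type syntax listed above can be mimicked outside MLIR with a small standalone C++ sketch. `TensorTypeSketch`, `kUnknownDim`, and `render` are hypothetical names used only for illustration and are not part of the torch dialect's C++ API. The sketch assumes that a missing size list prints as `*`, a missing dtype prints as `unk`, an unknown dimension prints as `?`, and that the short form is printed when neither sizes nor dtype are known.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a torch tensor type; not the real MLIR class.
struct TensorTypeSketch {
  bool hasValueSemantics;                     // true -> !torch.vtensor
  std::optional<std::vector<int64_t>> sizes;  // nullopt -> unranked ("*")
  std::optional<std::string> dtype;           // nullopt -> unknown ("unk")
};

// Sentinel chosen for this sketch to mark a dimension of unknown extent.
constexpr int64_t kUnknownDim = -1;

// Print the type in the textual form used by the examples above.
std::string render(const TensorTypeSketch &t) {
  std::string out = t.hasValueSemantics ? "!torch.vtensor" : "!torch.tensor";
  // Least-static-information variant: use the short form.
  if (!t.sizes && !t.dtype)
    return out;
  out += "<";
  if (!t.sizes) {
    out += "*";
  } else {
    out += "[";
    for (std::size_t i = 0; i < t.sizes->size(); ++i) {
      int64_t dim = (*t.sizes)[i];
      assert(dim >= kUnknownDim && "dimension must be non-negative or unknown");
      if (i != 0)
        out += ",";
      out += dim == kUnknownDim ? "?" : std::to_string(dim);
    }
    out += "]";
  }
  out += ",";
  out += t.dtype ? *t.dtype : "unk";
  out += ">";
  return out;
}
```

Keeping sizes and dtype as independent optionals is what lets the sketch express partially known types such as `!torch.tensor<*,si32>` without inventing a placeholder shape.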

#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

Rework how global slot initializers work.
Rather than a per-global-slot initializer region, we now have one for
the whole module. For example, it might look like this:
```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
torch.initialize.global_slots [
@tensor(%0 : !torch.tensor)
@list(%1 : !torch.list<tensor>)
]
}
```
This new structure allows GlobalizeObjectGraph to create the initializer in a
much simpler way, avoiding the need to reason about whether different slots
alias each other. Reasoning about whether slots alias each other now is the
responsibility of InlineGlobalSlots, which has to do a much more complicated
analysis, implemented using MLIR's dataflow analysis framework.
Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to GlobalizeObjectGraph pass.
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
- Moving torchMlirAdjustStaticInformation for sharing with C++ code.
- EraseModuleInitializer pass
To make this a bit nicer, it would be good to have a `torch.module` op
with an initializer region attached. That would be more invasive though.
This change has highlighted certain aspects of our project layering
which are worth calling out. None of our backends can handle global
slots, so we enforce that there are no global slots before backend
lowering. At an earlier stage in the project, we had aspirations of
transparently handling mutable global state and such, but for reasons
described below, that is no longer a goal. So really global slots should
be seen as a progressive lowering step as part of inlining all the
IValue's in the original program (GlobalizeObjectGraph is also one such
step).
Over time, with insights from work like IREE-JAX, it has become clear
that there isn't a reliable programming model we can compile for users
where we just transparently handle mutable global state (and some other
things, like lists and dictionaries). There is a need for an "outer
program" that orchestrates more restricted subroutines of the kind we
can handle in our compile flow here. The benefit of that is that it
decouples considerations like shapes, dtypes, etc. from the program
constructs used in the outer program. As long as the outer program can
efficiently invoke (pipelining/async/etc.) high-performance
data-parallel numerical subroutines of the kind we compile in our flow
here, then there is a complete programming model. This is also
consistent with the direction of upstream PyTorch which is becoming more
tracing-based (which inherently loses a lot of program structure, which
then has to be applied back with an "outer program" orchestrating the
traced subroutines).
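
As a rough illustration of why a single module-level initializer simplifies aliasing, here is a standalone C++ sketch. It is not torch-mlir code: `TensorSketch`, `SlotTable`, `runModuleInitializer`, and `slotsAlias` are hypothetical names, and the sketch assumes that two slots alias exactly when the initializer binds them to the same value.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins; none of these names exist in torch-mlir.
using TensorSketch = std::vector<float>;
using SlotTable = std::map<std::string, std::shared_ptr<TensorSketch>>;

// One initializer for the whole module: build each value once, then bind
// every slot in a single step (loosely mirroring the role of
// torch.initialize.global_slots in the IR example above).
SlotTable runModuleInitializer() {
  auto shared = std::make_shared<TensorSketch>(TensorSketch{0.0f});
  auto separate = std::make_shared<TensorSketch>(TensorSketch{0.0f});
  // Slots "a" and "b" are bound to the same value, so they alias.
  return SlotTable{{"a", shared}, {"b", shared}, {"c", separate}};
}

// In this sketch, aliasing is simply identity of the bound value.
bool slotsAlias(const SlotTable &slots, const std::string &x,
                const std::string &y) {
  assert(slots.count(x) && slots.count(y) && "both slots must exist");
  return slots.at(x) == slots.at(y);
}
```

Because every binding happens in one place, the code that builds the initializer never has to decide whether two slots alias; a later analysis can read that off the bindings.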

Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
                                                  Location loc, Value value,
                                                  Type desiredType,
                                                  bool userAllowsRefinement) {
  Type type = value.getType();

  // If the value is already of the desired type, we're done.
  if (type == desiredType)
    return value;

  // If the type is a tensor, then adjust the static information.
  if ((type.isa<ValueTensorType>() && desiredType.isa<ValueTensorType>()) ||
      (type.isa<NonValueTensorType>() &&
       desiredType.isa<NonValueTensorType>())) {
    Value adjusted = builder.create<TensorStaticInfoCastOp>(value.getLoc(),
                                                            desiredType, value);
    return adjusted;
  }

  // If the type is a subtype of desiredType, then we need to derefine it to
  // desiredType, unless the user allows refinement.
  if (isValidSubtype(type, desiredType)) {
    if (!userAllowsRefinement) {
      Value adjusted =
          builder.create<DerefineOp>(value.getLoc(), desiredType, value);
      return adjusted;
    } else {
      return value;
    }
  }

  // If the desiredType is subtype of type, then we assume that the desiredType
  // is dynamically valid, so we do an unchecked cast.
  if (isValidSubtype(desiredType, type)) {
    Value adjusted =
        builder.create<PrimUncheckedCastOp>(value.getLoc(), desiredType, value);
    return adjusted;
  }

  // No known adjustment.
  return Value();
}

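The order of checks in `adjustStaticInformation` can be summarized with a standalone C++ sketch that replaces the MLIR type queries with plain booleans. `TypeRelation`, `Adjustment`, `chooseAdjustment`, and `toString` are hypothetical names introduced only for this illustration; the sketch models which branch is taken and creates no ops.

```cpp
#include <cassert>

// Which rewrite the helper would perform; hypothetical enum for this sketch.
enum class Adjustment { Reuse, StaticInfoCast, Derefine, UncheckedCast, None };

// Precomputed answers to the type queries the real function makes.
struct TypeRelation {
  bool identical;         // type == desiredType
  bool sameTensorKind;    // both value tensors, or both non-value tensors
  bool valueIsSubtype;    // isValidSubtype(type, desiredType)
  bool desiredIsSubtype;  // isValidSubtype(desiredType, type)
};

// Mirrors the branch order of the function above: first match wins.
Adjustment chooseAdjustment(const TypeRelation &rel,
                            bool userAllowsRefinement) {
  if (rel.identical)
    return Adjustment::Reuse;
  if (rel.sameTensorKind)
    return Adjustment::StaticInfoCast;
  if (rel.valueIsSubtype)
    return userAllowsRefinement ? Adjustment::Reuse : Adjustment::Derefine;
  if (rel.desiredIsSubtype)
    return Adjustment::UncheckedCast;
  return Adjustment::None;  // corresponds to returning a null Value
}

const char *toString(Adjustment a) {
  switch (a) {
  case Adjustment::Reuse:
    return "reuse value";
  case Adjustment::StaticInfoCast:
    return "tensor static info cast";
  case Adjustment::Derefine:
    return "derefine";
  case Adjustment::UncheckedCast:
    return "unchecked cast";
  case Adjustment::None:
    return "no known adjustment";
  }
  assert(false && "unhandled Adjustment");
  return "";
}
```

Note that the tensor check comes before the subtype checks, so two tensor types of the same kind always go through the static-info cast regardless of `userAllowsRefinement`.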
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
                                           BaseTensorType newType,
                                           Value tensor) {
  auto originalType = tensor.getType().cast<BaseTensorType>();
  // Adjust the static information in the type to match between the original and
  // new types.
  if (!originalType.hasSameSizesAndDtype(newType)) {
    tensor = builder.create<TensorStaticInfoCastOp>(
        loc, originalType.getWithSizesAndDtypeFrom(newType), tensor);
  }

  // Unless both the original and new types are both value tensors, we end
  // up creating one op that converts between the value and non-value tensor
  // domains. If both the original and new types are both non-value tensors,
  // then we do the copy by going to a value tensor and back.
  if (tensor.getType().isa<NonValueTensorType>())
    tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
  if (newType.isa<NonValueTensorType>())
    tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);

  return tensor;
}

bool mlir::torch::Torch::isListPotentiallyMutated(Value list) {
  assert(list.getType().isa<Torch::ListType>());
  return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands);
}

bool mlir::torch::Torch::potentiallyMutatesListOperands(Operation *op) {
  // TODO: Find a better place to put this assertion.
  assert((!op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
          op->hasTrait<OpTrait::ReadOnly>()) &&
         "HasValueSemantics should imply ReadOnly!");
  // ReadOnly ops trivially do not mutate any list operands.
  if (op->hasTrait<Torch::OpTrait::ReadOnly>())
    return false;

  // Ops with no MemoryEffectOpInterface effects also do not mutate any list
  // operands.
  if (auto effects = dyn_cast<MemoryEffectOpInterface>(op)) {
    if (effects.hasNoEffect())
      return false;
  }

  // Conservatively assume that an op might mutate any list operands.
  return true;
}

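The conservative reasoning in `potentiallyMutatesListOperands` can be sketched without MLIR as follows. `OpSketch`, `mayMutateListOperands`, and `anyUserMayMutate` are hypothetical names for this illustration; the three booleans stand in for the trait and interface queries, and `declaresNoMemoryEffects` is assumed to be false both when an op reports effects and when it does not implement the effect interface at all.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical summary of what the real code asks of an Operation.
struct OpSketch {
  bool hasValueSemantics;        // stands in for the HasValueSemantics trait
  bool readOnly;                 // stands in for the ReadOnly trait
  bool declaresNoMemoryEffects;  // implements the effect interface, no effects
};

// An op is cleared only with positive evidence; otherwise assume mutation.
bool mayMutateListOperands(const OpSketch &op) {
  assert((!op.hasValueSemantics || op.readOnly) &&
         "HasValueSemantics should imply ReadOnly!");
  if (op.readOnly)
    return false;
  if (op.declaresNoMemoryEffects)
    return false;
  return true;  // conservative default
}

// A list counts as potentially mutated if any one of its users might mutate it.
bool anyUserMayMutate(const std::vector<OpSketch> &users) {
  return std::any_of(users.begin(), users.end(), mayMutateListOperands);
}
```

The default of `true` means an op about which nothing is known blocks any optimization that relies on the list being immutable, which is the safe direction for the analysis to err in.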
static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
  return IntegerAttr::get(IntegerType::get(context, 64), value);
}

static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
  return FloatAttr::get(Float64Type::get(context), value);
}

static Value getScalarIntValue(Value input, Location loc,
                               PatternRewriter &rewriter) {
  auto inputType = input.getType();
  if (inputType.isa<Torch::IntType>()) {
    return input;
  }

  auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
  if (!inputTensorType)
    return nullptr;

  Type inputDtype = inputTensorType.getOptionalDtype();
  if (!inputDtype || !inputDtype.isInteger(64))
    return nullptr;

  std::optional<unsigned> inputRank = getTensorRank(input);
  if (!inputRank || *inputRank != 0)
    return nullptr;

  if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
    auto val = valueTensorLiteralOp.getValue()
                   .cast<DenseElementsAttr>()
                   .getSplatValue<int64_t>();
    return rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(val));
  } else if (auto primNumToTensorScalarOp =
                 input.getDefiningOp<PrimNumToTensorScalarOp>()) {
    return primNumToTensorScalarOp.getA();
  } else if (auto tensorIntOp = input.getDefiningOp<AtenTensorIntOp>()) {
    return tensorIntOp.getT();
  }
  return nullptr;
}

static Value getScalarFloatValue(Value input, Location loc,
                                 PatternRewriter &rewriter) {
  auto inputType = input.getType();
  if (inputType.isa<Torch::FloatType>()) {
    return input;
  }

  auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
  if (!inputTensorType)
    return nullptr;

  Type inputDtype = inputTensorType.getOptionalDtype();
  if (!inputDtype ||
      (!inputDtype.isF16() && !inputDtype.isF32() && !inputDtype.isF64()))
    return nullptr;

  std::optional<unsigned> inputRank = getTensorRank(input);
  if (!inputRank || *inputRank != 0)
    return nullptr;

  if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
    auto val = valueTensorLiteralOp.getValue()
                   .cast<DenseFPElementsAttr>()
                   .getSplatValue<FloatAttr>()
                   .getValueAsDouble();
    return rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(val));
  } else if (auto primNumToTensorScalarOp =
                 input.getDefiningOp<PrimNumToTensorScalarOp>()) {
    return primNumToTensorScalarOp.getA();
  } else if (auto tensorFloatOp = input.getDefiningOp<AtenTensorFloatOp>()) {
    return tensorFloatOp.getT();
  }
  return nullptr;
}

2021-01-28 08:35:44 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// MethodOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
2024-01-30 01:59:33 +08:00
|
|
|
auto func = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
|
|
|
|
*this, getFunctionAttr());
|
2021-01-28 08:35:44 +08:00
|
|
|
if (!func)
|
2022-12-08 04:20:41 +08:00
|
|
|
return emitError() << "'@" << getFunction()
|
2021-01-28 08:35:44 +08:00
|
|
|
<< "' does not reference a valid function";
|
2021-02-18 03:28:51 +08:00
|
|
|
if (func.getVisibility() != SymbolTable::Visibility::Private)
|
2022-12-08 04:20:41 +08:00
|
|
|
return emitError() << "'@" << getFunction()
|
2021-02-18 03:28:51 +08:00
|
|
|
<< "' must reference a private function";
|
|
|
|
if (func.isDeclaration())
|
2022-12-08 04:20:41 +08:00
|
|
|
return emitError() << "'@" << getFunction()
|
2021-02-18 03:28:51 +08:00
|
|
|
<< "' must reference a function that is defined (not "
|
|
|
|
"merely declared)";
|
|
|
|
auto expectedReceiverArgType = NnModuleType::get(
|
|
|
|
getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
|
2022-04-27 03:27:51 +08:00
|
|
|
if (func.getFunctionType().getNumInputs() == 0 ||
|
|
|
|
func.getFunctionType().getInput(0) != expectedReceiverArgType) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return emitError() << "the referenced function '" << getFunction()
|
2021-02-18 03:28:51 +08:00
|
|
|
<< "' must have a first argument of type "
|
|
|
|
<< expectedReceiverArgType;
|
|
|
|
}
|
2021-01-28 08:35:44 +08:00
|
|
|
return success();
|
|
|
|
}

//===----------------------------------------------------------------------===//
// NnModuleOp
//===----------------------------------------------------------------------===//

LogicalResult NnModuleOp::verify() {
  for (Operation &child : *getBody())
    if (!isa<SlotOp, NnModuleTerminatorOp>(&child))
      return child.emitOpError() << "is not allowed inside 'torch.nn_module'";
  return success();
}

LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto classType = symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(
      *this, SymbolRefAttr::get(getContext(), getClassName()));
  if (!classType)
    return emitError() << "'" << getClassName()
                       << "' does not reference a valid class type";

  auto attrs = llvm::to_vector<6>(getBody()->getOps<SlotOp>());
  auto attrDefs = llvm::to_vector<6>(classType.getBody()->getOps<AttrOp>());
  if (attrs.size() != attrDefs.size())
    return emitError() << "number of 'torch.slot's in a 'torch.nn_module' must "
                          "match number of 'torch.attr's in "
                          "the corresponding 'torch.class_type'";
  for (int i = 0, e = attrs.size(); i != e; i++) {
    SlotOp attr = attrs[i];
    AttrOp attrDef = attrDefs[i];
    if (!isValidSubtype(attr.getValue().getType(), attrDef.getType()) ||
        attr.getName() != attrDef.getName()) {
      return attr.emitOpError()
          .append("is expected to match type and name of '",
                  attrDef.getOperation(), "'")
          .attachNote(attrDef.getLoc())
          .append("see torch.attr at corresponding index ", i, " here");
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PrimListConstructOp
//===----------------------------------------------------------------------===//

LogicalResult PrimListConstructOp::verify() {
  auto resultType = getResult().getType();
  auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
  auto matchResultElementType = [&](Type type) {
    return isValidSubtype(type, resultElementType);
  };
  if (!llvm::all_of(getOperandTypes(), matchResultElementType)) {
    return emitError() << "operand types should have the same type as the "
                          "list contained type";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PrimDictConstructOp
//===----------------------------------------------------------------------===//

LogicalResult PrimDictConstructOp::verify() {
  auto isValidSubTypeOf = [](Type expectedType) {
    return [=](Type type) { return isValidSubtype(type, expectedType); };
  };

  if (!llvm::all_of(getKeys().getTypes(), isValidSubTypeOf(getKeyType())))
    return emitError() << "keys should be of Dict key type";

  if (!llvm::all_of(getValues().getTypes(), isValidSubTypeOf(getValueType())))
    return emitError() << "values should be of Dict value type";

  return success();
}

//===----------------------------------------------------------------------===//
// ClassTypeOp
//===----------------------------------------------------------------------===//

LogicalResult ClassTypeOp::verify() {
  llvm::StringMap<Operation *> namesToOps;
  for (Operation &child : getBody()->without_terminator()) {
    if (!isa<AttrOp, MethodOp>(&child))
      return child.emitOpError() << "is not allowed inside `torch.class_type`";
    StringRef name;
    if (auto attr = dyn_cast<AttrOp>(child))
      name = attr.getName();
    else
      name = cast<MethodOp>(child).getName();
    auto itAndWasInserted = namesToOps.insert({name, &child});
    auto it = itAndWasInserted.first;
    bool wasInserted = itAndWasInserted.second;
    if (!wasInserted) {
      auto diag = emitOpError().append("has duplicate attr/method with name '",
                                       name, "'");
      diag.attachNote(it->second->getLoc())
          .append("see first conflicting attr/method here");
      diag.attachNote(child.getLoc())
          .append("see second conflicting attr/method here");
      return failure();
    }
  }

  return success();
}
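
// Illustrative sketch only, not part of the dialect: the duplicate-name check
// in ClassTypeOp::verify() relies on map insertion reporting whether the key
// was already present. The standalone example below shows the same idea with
// std::unordered_map standing in for llvm::StringMap. `Member` and
// `findDuplicate` are hypothetical names introduced for illustration, and an
// integer line number stands in for an MLIR location.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for an attr/method declared inside a class type.
struct Member {
  std::string name;
  int line; // Stand-in for a source location.
};

// Returns the lines of the first conflicting pair as (first, second), or
// (-1, -1) if every member name is unique.
std::pair<int, int> findDuplicate(const std::vector<Member> &members) {
  std::unordered_map<std::string, const Member *> namesToMembers;
  for (const Member &member : members) {
    // insert() never overwrites; `.second` is false if the key already existed
    // and `.first` then points at the previously inserted entry.
    auto itAndWasInserted = namesToMembers.insert({member.name, &member});
    if (!itAndWasInserted.second)
      return std::make_pair(itAndWasInserted.first->second->line, member.line);
  }
  return std::make_pair(-1, -1);
}
```

// Keeping the iterator from the failed insertion is what lets the verifier
// attach notes to both the first and the second conflicting declaration.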

//===----------------------------------------------------------------------===//
// PrimLoopOp
//===----------------------------------------------------------------------===//

OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getRegion());
  return getIterArgsInit();
}

void PrimLoopOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  Region &region = getRegion();
  if (!point.getRegionOrNull()) {
    regions.emplace_back(&region, region.getArguments().slice(1));
    return;
  }
  assert(point == region);
  regions.emplace_back(&region, region.getArguments().slice(1));
  regions.emplace_back(getResults());
}

bool PrimLoopOp::isForLike() {
  bool b;
  return matchPattern(getInitialCondition(), m_TorchConstantBool(&b)) && b;
}

//===----------------------------------------------------------------------===//
// PrimLoopConditionOp
//===----------------------------------------------------------------------===//

MutableOperandRange
PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
  // Pass all operands except the condition to the successor which is the
  // parent loop op.
  return getIterArgsMutable();
}

//===----------------------------------------------------------------------===//
// PrimIfOp
//===----------------------------------------------------------------------===//

ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  Type boolType = builder.getType<Torch::BoolType>();
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, boolType, result.operands))
    return failure();
  // Parse results type list.
  if (parser.parseArrowTypeList(result.types))
    return failure();
  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Parse the 'else' region.
  if (parser.parseKeyword("else"))
    return failure();
  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void PrimIfOp::print(OpAsmPrinter &p) {
  p << " " << getCondition();
  p << " -> (" << getResultTypes() << ") ";
  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false);
  p << " else ";
  p.printRegion(getElseRegion(), /*printEntryBlockArgs=*/false);

  p.printOptionalAttrDict((*this)->getAttrs());
}

void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
  // The `then` and the `else` region branch back to the parent operation.
  if (point.getRegionOrNull()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // If the condition is constant, we can give a more precise answer.
  bool condition;
  if (matchPattern(getCondition(), m_TorchConstantBool(&condition))) {
    Region *executedRegion = condition ? &getThenRegion() : &getElseRegion();
    regions.push_back(RegionSuccessor(executedRegion));
    return;
  }

  // If the condition isn't constant, both regions may be executed.
  regions.push_back(RegionSuccessor(&getThenRegion()));
  regions.push_back(RegionSuccessor(&getElseRegion()));
  return;
}
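
// Illustrative sketch only, not part of the dialect: the decision made by
// PrimIfOp::getSuccessorRegions() can be modelled without any MLIR types.
// `Successor`, `Condition`, and `ifSuccessors` are hypothetical names; the
// tri-state `Condition` stands in for "the condition matched a constant bool
// or it did not".

```cpp
#include <cassert>
#include <vector>

enum class Successor { Parent, Then, Else };
enum class Condition { Unknown, True, False };

// `fromRegion` is true when control leaves one of the two regions, and false
// when control enters the op from its parent.
std::vector<Successor> ifSuccessors(bool fromRegion, Condition cond) {
  // Both regions branch back to the parent operation.
  if (fromRegion)
    return {Successor::Parent};
  // A constant condition selects exactly one region.
  if (cond == Condition::True)
    return {Successor::Then};
  if (cond == Condition::False)
    return {Successor::Else};
  // Otherwise either region may execute.
  return {Successor::Then, Successor::Else};
}
```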

/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block *block = &region.front();
  Operation *terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.inlineBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
}

void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If the condition is constant, delete the dead branch and inline the live
  // branch.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    auto constantBool =
        op.getCondition().getDefiningOp<Torch::ConstantBoolOp>();
    if (!constantBool)
      return rewriter.notifyMatchFailure(op, "non-constant condition");
    replaceOpWithRegion(rewriter, op,
                        constantBool.getValue() ? op.getThenRegion()
                                                : op.getElseRegion());
    return success();
  });
  // If the thenRegion and elseRegion yield the same Values, then use those
  // directly.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    auto trueTerminator = op.getThenRegion().front().getTerminator();
    auto falseTerminator = op.getElseRegion().front().getTerminator();
    bool madeChange = false;
    SmallVector<int> resultsToErase;
    for (auto t : llvm::zip(trueTerminator->getOperands(),
                            falseTerminator->getOperands(), op->getResults())) {
      auto trueVal = std::get<0>(t);
      auto falseVal = std::get<1>(t);
      auto resultToBeReplaced = std::get<2>(t);
      if (trueVal == falseVal) {
        madeChange |= !resultToBeReplaced.use_empty();
        resultToBeReplaced.replaceAllUsesWith(trueVal);
      }
    }
    // We leave it up to a separate pattern (registered below) to erase the
    // results that are now dead. That transformation is independently useful,
    // and also pretty tricky to implement because it changes the number of
    // results.
    return success(madeChange);
  });
  // Erase any dead results.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    llvm::BitVector resultsToErase(op.getNumResults());
    for (auto result : llvm::enumerate(op->getResults())) {
      if (result.value().use_empty())
        resultsToErase.set(result.index());
    }

    // If no results have uses and there are no side effects, just erase the op.
    // Approximate the body having no side effects by checking if it is just a
    // terminator.
    // Note: We don't want to make this logic too fancy, because in general,
    // checking for recursive side effects can result in a quadratic amount of
    // work (N nested If's each resulting in O(N) work). It should probably be
    // split into its own pattern if we want to make it fancier.
    if (resultsToErase.all() &&
        llvm::hasSingleElement(op.getThenRegion().front()) &&
        llvm::hasSingleElement(op.getElseRegion().front())) {
      rewriter.eraseOp(op);
      return success();
    }

    // If there are no results to erase, we're done.
    if (!resultsToErase.any())
      return failure();

    SmallVector<Type> newResultTypes;
    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
      if (resultsToErase[i])
        continue;
      newResultTypes.push_back(op->getResult(i).getType());
    }
    auto newIf = rewriter.create<PrimIfOp>(op->getLoc(), newResultTypes,
                                           op.getCondition());
    rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(),
                                newIf.getThenRegion().end());
    rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(),
                                newIf.getElseRegion().end());
    newIf.getThenRegion().front().getTerminator()->eraseOperands(
        resultsToErase);
    newIf.getElseRegion().front().getTerminator()->eraseOperands(
        resultsToErase);
    SmallVector<Value> replacementValues;
    for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) {
      if (resultsToErase[i])
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(newIf->getResult(nextNewValue++));
    }
    rewriter.replaceOp(op, replacementValues);

    return success();
  });
}
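
// Illustrative sketch only, not part of the dialect: the last pattern above
// rebuilds the op with fewer results and must map each old result to its new
// position. The standalone function below shows that index bookkeeping with
// std::vector<bool> standing in for llvm::BitVector and -1 standing in for a
// null replacement Value. `remapResults` is a hypothetical name.

```cpp
#include <cassert>
#include <vector>

// For each old result index, returns the index of the matching result on the
// rebuilt op, or -1 if that result is dead and gets no replacement.
std::vector<int> remapResults(const std::vector<bool> &resultsToErase) {
  std::vector<int> replacement;
  int nextNewValue = 0;
  for (bool erase : resultsToErase)
    replacement.push_back(erase ? -1 : nextNewValue++);
  return replacement;
}
```

// Surviving results keep their relative order, which is why the same bit
// vector can also be used to erase the matching terminator operands.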

//===----------------------------------------------------------------------===//
// RuntimeAssertOp
//===----------------------------------------------------------------------===//

void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](RuntimeAssertOp op, PatternRewriter &rewriter) {
    bool value;
    if (!matchPattern(op.getCondition(), m_TorchConstantBool(&value)))
      return failure();

    if (value) {
      rewriter.eraseOp(op);
      return success();
    }
    // Even if the condition is statically false, the assert might never be
    // executed.
    return failure();
  });
}

//===----------------------------------------------------------------------===//
// DerefineOp
//===----------------------------------------------------------------------===//

bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
                                   mlir::TypeRange outputs) {
  return isValidSubtype(inputs[0], outputs[0]);
}

OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) {
  auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>();
  if (!uncheckedCast)
    return nullptr;
  if (uncheckedCast.getOperand().getType() == getType())
    return uncheckedCast.getOperand();
  return nullptr;
}
Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical to do it in sane smaller
pieces.
This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).
Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
imported as `torch.somens.someunqualname.someoverloadname` (skip the
last dotted part if the overload name is empty), OR, if we don't have
such an op registered, it is imported as
`torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
- The addition of the "overload name" is a critical element here, as
the `(ns,unqual,overload)` triple is unique, which solves a lot of
problems we were having.
- This involves having separate MLIR ops for the `trailing_` and
`.out` variants and all the different overloads. This seemed
necessary, because the set of overloads is so wild and varied and
unstructured. The previous design was leaning into some underlying
structure that just isn't there -- the default situation is
the "random overload that we want to manage on the MLIR side",
rather than that being an exception. E.g. `aten::ne` (not-equal)
has 21 overloads, only 4 of which are c10 dispatcher ops (see
[gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1)),
and the "out" variant is really called `.Tensor_out` instead of
`.out` as it frequently is for other ops.
- Rationale for all being in `torch` namespace: the set of operators
are so varied and unstructured that "dialect per namespace"
doesn't result in anything resembling the typical MLIR dialect
boundary expectations. We could maybe draw the boundary at
dispatcher ops vs non-dispatcher ops, but that doesn't seem to
really result in very much useful structure at this point in time.
- Note: within the torch operator registry, we effectively have a
mini-basicpy subdialect (already type-resolved), which is reasonably
structured.
- The existing Torch op interfaces are also removed -- now that we
track the overload name, we can losslessly find the original
operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
`ReduceOpVariantsPass` that keys off certain traits (and perhaps
eventually interfaces) to reduce variants of ops to a smaller set,
ideally operating on immutable tensors and using surrounding ops to
model the mutability/aliasing aspects.
- Note: `torch.ns.unqual.overload` ops allow both immutable and
mutable tensors (unlike the previous hard distinction in the common
case). This is a premonition for a future change that will introduce a
bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supersede the existing
"ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supersedes `torch_signature_ods_gen.py`.
It should look somewhat familiar, but the benefit of hindsight has
allowed a lot of simplifications.
The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like as a natural result of
various future changes we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).
Recommended review order:
- Start at some of the new import IR, e.g. in
`frontends/pytorch/test/node_import/prim.py`,
`frontends/pytorch/test/acap_export/test_export_add3.py`, and other
tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
and associated generated files:
- `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
- `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
`frontends/pytorch/csrc/builder`. Probably most interesting is the new
code in `torch_to_mlir_utils.cpp` that has the logic to create the
`torch.operator` ops or `torch.ns.unqual.overload` ops.
This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
2021-05-05 05:42:50 +08:00

void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) {
    bool madeChange = false;
    for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) {
      if (use.getOwner()->hasTrait<OpTrait::AllowsTypeRefinement>()) {
        use.set(op.getOperand());
        madeChange = true;
      }
    }
    return success(madeChange);
  });
}

static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
  Value lhs = op->getOperand(0);
  Value rhs = op->getOperand(1);
  // Look through DerefineOp's to get more refined static information.
  if (auto derefine = lhs.getDefiningOp<DerefineOp>())
    lhs = derefine.getOperand();
  if (auto derefine = rhs.getDefiningOp<DerefineOp>())
    rhs = derefine.getOperand();
  Type lhsType = lhs.getType();
  Type rhsType = rhs.getType();

  // If either type is a NoneType, make it be the lhsType.
  if (rhsType.isa<Torch::NoneType>()) {
    std::swap(lhsType, rhsType);
    std::swap(lhs, rhs);
  }

  // For now, check a few specific cases.

  // If both types are the singleton `!torch.none` type, then we don't even need
  // to look at the values.
  if (lhsType.isa<Torch::NoneType>() && rhsType.isa<Torch::NoneType>())
    return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue);

  // If neither type is a subtype of the other, then the result is false.
  // TODO: Implement and use subtype infra for this.
  // For now, check a specific case.
  // If the rhs is not OptionalType, then we know it cannot be None.
  if (lhsType.isa<Torch::NoneType>() && !rhsType.isa<Torch::OptionalType>()) {
    return IntegerAttr::get(IntegerType::get(op->getContext(), 1),
                            !equalIsTrue);
  }

  return nullptr;
}
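
// Illustrative sketch only, not part of the dialect: the static reasoning in
// atenIsOrIsNotFoldHelper() depends only on which kind of type each operand
// has. The standalone function below reproduces that decision table.
// `TypeKind`, `FoldResult`, and `foldIsOrIsNot` are hypothetical names, and
// `FoldResult::Unknown` stands in for returning a null OpFoldResult.

```cpp
#include <cassert>
#include <utility>

enum class TypeKind { None, Optional, Other };
enum class FoldResult { True, False, Unknown };

FoldResult foldIsOrIsNot(TypeKind lhs, TypeKind rhs, bool equalIsTrue) {
  // Canonicalize so that a None type, if present, ends up on the lhs.
  if (rhs == TypeKind::None)
    std::swap(lhs, rhs);
  auto toResult = [](bool b) { return b ? FoldResult::True : FoldResult::False; };
  // None compared with None: the operands are always identical.
  if (lhs == TypeKind::None && rhs == TypeKind::None)
    return toResult(equalIsTrue);
  // None compared with a non-optional type: the operands are never identical.
  if (lhs == TypeKind::None && rhs != TypeKind::Optional)
    return toResult(!equalIsTrue);
  // Anything else cannot be decided from the types alone.
  return FoldResult::Unknown;
}
```

// With equalIsTrue set, this models `is`; with it cleared, it models `is not`.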

//===----------------------------------------------------------------------===//
// Aten__RangeLengthOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
  auto lo = adaptor.getLo();
  auto hi = adaptor.getHi();
  auto step = adaptor.getStep();
  if (!lo || !hi || !step)
    return nullptr;
  auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
  auto hiInt = hi.dyn_cast_or_null<IntegerAttr>().getValue();
  auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
  // TODO: Implement folding for negative steps.
  if (stepInt.isNegative())
    return nullptr;
  // From Python language spec:
  // r[i] = lo + step*i such that i >= 0 and r[i] < hi
  // The length is the number of such `i`, i.e. the smallest `n` such that
  // lo + step * n >= hi
  // ==> n == ceildiv(hi - lo, step)
  return IntegerAttr::get(lo.cast<TypedAttr>().getType(),
                          llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
                                                       APInt::Rounding::UP));
}

//===----------------------------------------------------------------------===//
// Aten__DeriveIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
  auto index = adaptor.getIndex();
  auto start = adaptor.getStart();
  auto step = adaptor.getStep();
  if (!index || !start || !step)
    return nullptr;
  auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
  auto startInt = start.dyn_cast_or_null<IntegerAttr>().getValue();
  auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
  return IntegerAttr::get(index.cast<TypedAttr>().getType(),
                          startInt + stepInt * indexInt);
}
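
// Illustrative sketch (not part of the upstream file): once all three operands
// are constants, the fold above reduces to `start + step * index`. The
// standalone helper below restates that with plain `int64_t`; the name
// `deriveIndex` is hypothetical.

```cpp
#include <cstdint>

// Element number `index` of the sequence start, start + step, start + 2 * step, ...
constexpr int64_t deriveIndex(int64_t index, int64_t start, int64_t step) {
  return start + step * index;
}

static_assert(deriveIndex(0, 5, 2) == 5, "the first element is start");
static_assert(deriveIndex(3, 5, 2) == 11, "5 + 2 * 3");
static_assert(deriveIndex(2, 10, -3) == 4, "negative steps count down");
```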

//===----------------------------------------------------------------------===//
// Aten__Is__Op
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) {
  return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true);
}

//===----------------------------------------------------------------------===//
// Aten__Isnot__Op
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) {
  return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false);
}

//===----------------------------------------------------------------------===//
// Aten__Not__Op
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
  bool value;
  if (!matchPattern(getOperand(), m_TorchConstantBool(&value)))
    return nullptr;
  return IntegerAttr::get(IntegerType::get(getContext(), 1), !value);
}

Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical to do it in sane smaller
pieces.
This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).
Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
imported as `torch.somens.someunqualname.someoverloadname` (skip the
last dotted part if the overload name is empty), OR, if we don't have
such an op registered, it is imported as
`torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
- The addition of the "overload name" is a critical element here, as
the `(ns,unqual,overload)` triple is unique, which solves a lot of
problems we were having.
- This involves having separate MLIR ops for the `trailing_` and
`.out` variants and all the different overloads. This seemed
necessary, because the set of overloads is so wild and varied and
unstructured. The previous design was leaning into some underlying
structure that just isn't there -- the default situation is
the "random overload that we want to manage on the MLIR side",
rather than that being an exception. E.g. `aten::ne` (not-equal)
has 21 overloads, only 4 of which are c10 dispatcher ops (see this
[gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1)),
and the "out" variant is really called `.Tensor_out` instead of
`.out` as it frequently is for other ops.
- Rationale for all being in `torch` namespace: the set of operators
is so varied and unstructured that "dialect per namespace"
doesn't result in anything resembling the typical MLIR dialect
boundary expectations. We could maybe draw the boundary at
dispatcher ops vs non-dispatcher ops, but that doesn't seem to
really result in very much useful structure at this point in time.
- Note: within the torch operator registry, we effectively have a
mini-basicpy subdialect (already type-resolved), which is reasonably
structured.
- The existing Torch op interfaces are also removed -- now that we
track the overload name, we can losslessly find the original
operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
`ReduceOpVariantsPass` that keys off certain traits (and perhaps
eventually interfaces) to reduce variants of ops to a smaller set,
ideally operating on immutable tensors and using surrounding ops to
model the mutability/aliasing aspects.
- Note: `torch.ns.unqual.overload` ops allow both immutable and
mutable tensors (unlike the previous hard distinction in the common
case). This is a premonition for a future change that will introduce a
bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supersede the existing
"ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supersedes `torch_signature_ods_gen.py`.
It should look somewhat familiar, but the benefit of hindsight has
allowed a lot of simplifications.
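As a minimal sketch of the naming rule in the first bullet above (not code from this change; the helper names `importedOpName` and `checkImportedOpName` are hypothetical), the imported op name for a registered operator could be assembled like this:

```cpp
#include <cassert>
#include <string>

// Builds `torch.<ns>.<unqual>[.<overload>]` from the (ns, unqual, overload)
// triple. Hypothetical helper for illustration only.
std::string importedOpName(const std::string &ns, const std::string &unqual,
                           const std::string &overload) {
  std::string name = "torch." + ns + "." + unqual;
  // Skip the last dotted part if the overload name is empty.
  if (!overload.empty())
    name += "." + overload;
  return name;
}

// Small self-check of the convention.
inline void checkImportedOpName() {
  assert(importedOpName("aten", "ne", "Tensor_out") ==
         "torch.aten.ne.Tensor_out");
  assert(importedOpName("aten", "tanh", "") == "torch.aten.tanh");
}
```

Operators without a registered op would instead use the generic `torch.operator` form with the same dotted name as a string; that fallback is not modeled here.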
The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like as a natural result of
various future changes we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).
Recommended review order:
- Start at some of the new import IR, e.g. in
`frontends/pytorch/test/node_import/prim.py`,
`frontends/pytorch/test/acap_export/test_export_add3.py`, and other
tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
and associated generated files:
- `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
- `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
`frontends/pytorch/csrc/builder`. Probably most interesting is the new
code in `torch_to_mlir_utils.cpp` that has the logic to create the
`torch.operator` ops or `torch.ns.unqual.overload` ops.
This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
2021-05-05 05:42:50 +08:00

//===----------------------------------------------------------------------===//
// AtenNeBoolOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
  if (getOperand(0) == getOperand(1))
    return IntegerAttr::get(IntegerType::get(getContext(), 1), false);

  bool a, b;
  if (!matchPattern(getOperand(0), m_TorchConstantBool(&a)))
    return nullptr;
  if (!matchPattern(getOperand(1), m_TorchConstantBool(&b)))
    return nullptr;
  return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
}
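
// Illustrative sketch (not part of the upstream file): the fold above has two
// cases. Identical operands always compare not-equal as false, even when they
// are not constants; otherwise both operands must be constant bools. The
// types `BoolValue` and `Fold` and the function `foldNeBool` are hypothetical
// stand-ins for SSA values and `OpFoldResult`.

```cpp
// `id` models SSA value identity; `isConstant`/`constant` model the result of
// matching a constant bool.
struct BoolValue {
  int id;
  bool isConstant;
  bool constant;
};

enum class Fold { None, False, True };

constexpr Fold foldNeBool(BoolValue a, BoolValue b) {
  // x != x is false even when x is unknown at compile time.
  if (a.id == b.id)
    return Fold::False;
  // Otherwise both operands must be constants.
  if (!a.isConstant || !b.isConstant)
    return Fold::None;
  return a.constant != b.constant ? Fold::True : Fold::False;
}

static_assert(foldNeBool(BoolValue{0, false, false},
                         BoolValue{0, false, false}) == Fold::False,
              "same non-constant value folds to false");
```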

//===----------------------------------------------------------------------===//
// AtenSqueezeOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
  if (getOperand().getType() != getResult().getType())
    return nullptr;
  if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
      return getOperand();
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenSqueezeDimOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
  if (getOperand(0).getType() != getResult().getType())
    return nullptr;
  if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
      return getOperand(0);
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenRoundOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
  if (getSelf().getType() != getResult().getType())
    return nullptr;
  if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
    if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
      return getSelf();
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenToDtypeOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
  bool nonBlocking, copyArg;
  // The non_blocking arg must be `False`.
  if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
      nonBlocking)
    return nullptr;
  // The copy arg must be `False`.
  if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
    return nullptr;
  // The memory_format arg must be `none`.
  if (!getMemoryFormat().getType().isa<Torch::NoneType>())
    return nullptr;

  auto inputType = getSelf().getType().cast<BaseTensorType>();
  auto resType = getType().cast<BaseTensorType>();
  // If the types aren't equal, then we can't fold.
  if (inputType != resType)
    return nullptr;
  // If the type does not have a statically known dtype, then we cannot fold.
  // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
  // since the `unk` could be dynamically different for the operand and result.
  if (!inputType.hasDtype())
    return nullptr;
  // Fold when both the input tensor and result are of the same type.
  return getOperand(0);
}

//===----------------------------------------------------------------------===//
// AtenToDtypeLayoutOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
  // The pin_memory arg should be either constant `False` or `none`.
  if (!getPinMemory().getType().isa<Torch::NoneType>()) {
    bool pinMemory;
    if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory)))
      return nullptr;
    else if (pinMemory)
      return nullptr;
  }

  // The non_blocking arg should be constant `False`.
  bool nonBlocking;
  if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)))
    return nullptr;
  else if (nonBlocking)
    return nullptr;

  // The copy arg should be constant `False`.
  bool copyArg;
  if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)))
    return nullptr;
  else if (copyArg)
    return nullptr;

  // The device arg must be `none`.
  if (!getDevice().getType().isa<Torch::NoneType>())
    return nullptr;

  // The memory_format arg must be `none`.
  if (!getMemoryFormat().getType().isa<Torch::NoneType>())
    return nullptr;

  auto inputType = getSelf().getType().cast<BaseTensorType>();
  auto resType = getType().cast<BaseTensorType>();
  // If the types aren't equal, then we can't fold.
  if (inputType != resType)
    return nullptr;
  // If the type does not have a statically known dtype, then we cannot fold.
  // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
  // since the `unk` could be dynamically different for the operand and result.
  if (!inputType.hasDtype())
    return nullptr;

  // The layout arg should be either `none` or `0` i.e. strided.
  if (!getLayout().getType().isa<Torch::NoneType>()) {
    int64_t tensorLayout;
    if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout)))
      return nullptr;
    else if (tensorLayout != torch_upstream::Layout::Strided)
      return nullptr;
  }

  // Fold when both the input tensor and result are of the same type and the
  // layout arg is strided.
  return getOperand(0);
}

void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // `to.dtype_layout` -> `to.device/to.dtype` if layout is none and pin memory
  // is false
  patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) {
    // The pin_memory arg should be either constant `False` or `none`.
    if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
      bool pinMemory;
      if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
        return failure();
      else if (pinMemory)
        return failure();
    }

    // The layout arg should be either `none` or `0` i.e. strided.
    if (!op.getLayout().getType().isa<Torch::NoneType>()) {
      int64_t tensorLayout;
      if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
        return failure();
      else if (tensorLayout != torch_upstream::Layout::Strided)
        return failure();
    }

    if (op.getDevice().getType().isa<Torch::NoneType>()) {
      // The device arg is `none`. Rewrite to to.dtype.
      AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>(
          op.getLoc(), op.getType(), op.getSelf(), op.getDtype(),
          op.getNonBlocking(), op.getCopy(), op.getMemoryFormat());
      rewriter.replaceOp(op, toDtype->getResults());
    } else {
      // The device arg is not `none`. Rewrite to to.device.
      AtenToDeviceOp toDevice = rewriter.create<AtenToDeviceOp>(
          op.getLoc(), op.getType(), op.getSelf(), op.getDevice(),
          op.getDtype(), op.getNonBlocking(), op.getCopy(),
          op.getMemoryFormat());
      rewriter.replaceOp(op, toDevice->getResults());
    }

    return success();
  });
}
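
// Illustrative sketch (not part of the upstream file): the canonicalization
// above is a three-way decision. It bails out unless pin_memory is `none` or
// constant false and layout is `none` or constant strided, then picks the
// target op based on whether device is `none`. `OptionalConst`, `Rewrite`,
// `kStridedLayout`, and `chooseRewrite` are hypothetical names; the value 0
// for the strided layout follows the comment in the pattern.

```cpp
#include <cstdint>

enum class Rewrite { None, ToDtype, ToDevice };

// An argument is either `none`, a matched constant, or (both flags false) a
// value that is not known at compile time.
struct OptionalConst {
  bool isNone;
  bool isConstant;
  int64_t value; // Only meaningful when isConstant is true.
};

constexpr int64_t kStridedLayout = 0;

constexpr Rewrite chooseRewrite(OptionalConst pinMemory, OptionalConst layout,
                                bool deviceIsNone) {
  // pin_memory must be `none` or constant false.
  if (!pinMemory.isNone && (!pinMemory.isConstant || pinMemory.value != 0))
    return Rewrite::None;
  // layout must be `none` or constant strided.
  if (!layout.isNone && (!layout.isConstant || layout.value != kStridedLayout))
    return Rewrite::None;
  return deviceIsNone ? Rewrite::ToDtype : Rewrite::ToDevice;
}

static_assert(chooseRewrite({true, false, 0}, {true, false, 0}, true) ==
                  Rewrite::ToDtype,
              "all-none arguments rewrite to to.dtype");
```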

//===----------------------------------------------------------------------===//
// AtenToOtherOp
//===----------------------------------------------------------------------===//

void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Canonicalize `aten.to.other` to `aten.to.device`
  patterns.add(+[](AtenToOtherOp op, PatternRewriter &rewriter) {
    auto lhs = op.getSelf();
    auto rhs = op.getOther();
    auto getRhsDevice = rewriter.create<PrimDeviceOp>(op.getLoc(), rhs);
    auto getRhsDtype = rewriter.create<PrimDtypeOp>(op.getLoc(), rhs);
    rewriter.replaceOpWithNewOp<AtenToDeviceOp>(
        op, op.getType(), lhs, getRhsDevice.getResult(),
        getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(),
        op.getMemoryFormat());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenViewOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
  auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
  if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
    return nullptr;
  auto resType = getType().dyn_cast<BaseTensorType>();
  if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
    return nullptr;
  if (inputType != resType)
    return nullptr;
  // Fold when both the input tensor and result are unity rank tensors.
  return getOperand(0);
}

//===----------------------------------------------------------------------===//
// PrimsViewOfOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimsViewOfOp::fold(FoldAdaptor adaptor) {
  // Always fold the op with its only input operand.
  return getOperand();
}

//===----------------------------------------------------------------------===//
// AtenDimOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tensor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
2021-05-21 08:07:18 +08:00
  if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
    if (tensorType.hasSizes())
      return IntegerAttr::get(IntegerType::get(getContext(), 64),
                              tensorType.getSizes().size());
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenLenTOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) {
  // `len([1,1,1])` -> `3`, if it is not mutated.
  if (auto listConstruct =
          getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
    if (!isListPotentiallyMutated(listConstruct)) {
      return IntegerAttr::get(IntegerType::get(getContext(), 64),
                              listConstruct.getNumOperands());
    }
  }
  return nullptr;
}
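
// Illustrative sketch (not part of the upstream file): the fold above replaces
// `len` of a list literal with its operand count, but only when the list
// cannot have been mutated between construction and the `len` call. The types
// `ListConstruct` and `LenFold` and the function `foldLen` are hypothetical
// stand-ins for the MLIR op and `OpFoldResult`.

```cpp
#include <cstdint>

// Models a list-construct op: how many elements it was built from, and
// whether any use might mutate it (e.g. an append).
struct ListConstruct {
  int64_t numOperands;
  bool potentiallyMutated;
};

struct LenFold {
  bool folded;
  int64_t length; // Only meaningful when folded is true.
};

constexpr LenFold foldLen(ListConstruct list) {
  return list.potentiallyMutated ? LenFold{false, 0}
                                 : LenFold{true, list.numOperands};
}

// `len([1,1,1])` folds to 3 when the list is never mutated.
static_assert(foldLen({3, false}).folded && foldLen({3, false}).length == 3,
              "unmutated three-element list has length 3");
```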
|
|
|
|
|
Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical do do it in sane smaller
pieces.
This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).
Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
imported as `torch.somens.someunqualname.someoverloadname` (skip the
last dotted part if the overload name is empty), OR, if we don't have
such an op registered, it is imported as
`torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
- The addition of the "overload name" is a critical element here, as
the `(ns,unqual,overload)` triple is unique, which solves a lot of
problems we were having.
- This involves having separate MLIR ops for the `trailing_` and
`.out` variants and all the different overloads. This seemed
necessary, because the set of overloads is so wild and varied and
unstructured. The previous design was leaning into some underlying
structure that just isn't there -- the default situation is
the "random overload that we want to manage on the MLIR side",
rather than that being an exception. E.g. `aten::ne` (not-equal)
has 21 overloads, only 4 of which are c10 dispatcher ops (see
[gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1)),
and the "out" variant is really called `.Tensor_out` instead of
`.out` as it frequently is for other ops.
- Rationale for all being in `torch` namespace: the set of operators
is so varied and unstructured that "dialect per namespace"
doesn't result in anything resembling the typical MLIR dialect
boundary expectations. We could maybe draw the boundary at
dispatcher ops vs non-dispatcher ops, but that doesn't seem to
really result in very much useful structure at this point in time.
- Note: within the torch operator registry, we effectively have a
mini-basicpy subdialect (already type-resolved), which is reasonably
structured.
- The existing Torch op interfaces are also removed -- now that we
track the overload name, we can losslessly find the original
operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
`ReduceOpVariantsPass` that keys off certain traits (and perhaps
eventually interfaces) to reduce variants of ops to a smaller set,
ideally operating on immutable tensors and using surrounding ops to
model the mutability/aliasing aspects.
- Note: `torch.ns.unqual.overload` ops allow both immutable and
mutable tensors (unlike the previous hard distinction in the common
case). This foreshadows a future change that will introduce a
bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supersede the existing
"ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supersedes `torch_signature_ods_gen.py`.
It should look somewhat familiar, but the benefit of hindsight has
allowed a lot of simplifications.
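The naming rule in the overview above can be sketched in a few lines of standalone C++. This is only an illustration of the rule as described in this commit message, not the importer's actual code: `registeredOpName` and `fallbackOperatorText` are hypothetical helper names, and applying the "skip the empty overload" rule to the `torch.operator` fallback string is an assumption.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: joins the (namespace, unqualified name, overload name)
// triple into a dotted name, skipping the last part if the overload is empty.
std::string dottedName(const std::string &ns, const std::string &unqual,
                       const std::string &overload) {
  std::string name = ns + "." + unqual;
  if (!overload.empty())
    name += "." + overload;
  return name;
}

// Name of the dedicated MLIR op, used when such an op is registered.
std::string registeredOpName(const std::string &ns, const std::string &unqual,
                             const std::string &overload) {
  return "torch." + dottedName(ns, unqual, overload);
}

// Leading text of the generic fallback form, used when no op is registered.
// Assumption: an empty overload name is skipped here as well.
std::string fallbackOperatorText(const std::string &ns,
                                 const std::string &unqual,
                                 const std::string &overload) {
  return "torch.operator \"" + dottedName(ns, unqual, overload) + "\"";
}
```

For example, the `aten::ne` "out" variant mentioned above would map to `torch.aten.ne.Tensor_out` under this rule.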
The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like as a natural result of
various future changes we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).
Recommended review order:
- Start at some of the new import IR, e.g. in
`frontends/pytorch/test/node_import/prim.py`,
`frontends/pytorch/test/acap_export/test_export_add3.py`, and other
tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
and associated generated files:
- `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
- `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
`frontends/pytorch/csrc/builder`. Probably most interesting is the new
code in `torch_to_mlir_utils.cpp` that has the logic to create the
`torch.operator` ops or `torch.ns.unqual.overload` ops.
This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
2021-05-05 05:42:50 +08:00
void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tensor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there are `ValueTensorType` and `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
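To make the type syntax listed above concrete, here is a small standalone sketch that prints it from optional shape and dtype information. It does not use the real `BaseTensorType` API: the struct `TensorTypeSketch`, its fields, and `printType` are hypothetical, and printing the bare `!torch.tensor` shorthand whenever nothing is known statically is an assumption based on the examples in this message.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for a torch tensor type (not the real C++ classes).
struct TensorTypeSketch {
  bool hasValueSemantics;     // true -> !torch.vtensor, false -> !torch.tensor
  bool isRanked;              // false models the unranked `*` form
  std::vector<int64_t> sizes; // a negative entry models an unknown `?` dim
  std::string dtype;          // an empty string models the unknown `unk` dtype
};

std::string printType(const TensorTypeSketch &t) {
  std::string out = t.hasValueSemantics ? "!torch.vtensor" : "!torch.tensor";
  // Least-static-information variant: print the shorthand form.
  if (!t.isRanked && t.dtype.empty())
    return out;
  out += "<";
  if (!t.isRanked) {
    out += "*";
  } else {
    out += "[";
    for (std::size_t i = 0; i < t.sizes.size(); ++i) {
      if (i != 0)
        out += ",";
      out += t.sizes[i] < 0 ? std::string("?") : std::to_string(t.sizes[i]);
    }
    out += "]";
  }
  out += ",";
  out += t.dtype.empty() ? std::string("unk") : t.dtype;
  out += ">";
  return out;
}
```

Keeping "unranked" and "unknown dtype" as separate pieces of state mirrors the point of the new types: both kinds of missing information are first-class rather than encoded through the builtin tensor type.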
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
2021-05-21 08:07:18 +08:00
  // `len(t.size())` -> `t.ndim`
2021-05-05 05:42:50 +08:00
  patterns.add(+[](AtenLenTOp op, PatternRewriter &rewriter) {
2021-05-21 08:07:18 +08:00
    auto size = op.getOperand().getDefiningOp<AtenSizeOp>();
    if (!size)
      return rewriter.notifyMatchFailure(op, "operand not AtenSizeOp");
2021-06-17 06:53:15 +08:00
    rewriter.replaceOpWithNewOp<AtenDimOp>(op, size.getOperand());
2021-05-05 05:42:50 +08:00
    return success();
  });
}
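The `patterns.add(+[](...) {...})` idiom used in `AtenLenTOp::getCanonicalizationPatterns` relies on unary `+` converting a captureless lambda into a plain function pointer. The sketch below shows the same shape on a toy expression tree and performs the `len(t.size())` -> `t.ndim` style rewrite. All names here (`Expr`, `Pattern`, `PatternList`, `makeLenPatterns`) are hypothetical stand-ins, not MLIR classes.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy expression node; a hypothetical stand-in for an MLIR operation.
struct Expr {
  std::string op; // e.g. "len", "size", "dim", "tensor"
  Expr *operand;  // single operand, or nullptr for a leaf
};

// A pattern is a plain function pointer; it returns true if it rewrote `e`.
using Pattern = bool (*)(Expr &e);

struct PatternList {
  std::vector<Pattern> patterns;
  void add(Pattern p) { patterns.push_back(p); }
  // Applies the first pattern that matches, if any.
  bool apply(Expr &e) const {
    for (Pattern p : patterns)
      if (p(e))
        return true;
    return false;
  }
};

PatternList makeLenPatterns() {
  PatternList list;
  // Unary `+` turns the captureless lambda into a function pointer.
  list.add(+[](Expr &e) {
    // `len(size(t))` -> `dim(t)`
    if (e.op != "len" || !e.operand || e.operand->op != "size")
      return false; // match failure: operand is not a `size` node
    e.op = "dim";
    e.operand = e.operand->operand;
    return true;
  });
  return list;
}
```

As in the real pattern, the rewrite only fires when the operand is produced by the expected op; otherwise it reports a match failure and leaves the input untouched.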
2023-09-11 17:28:22 +08:00
//===----------------------------------------------------------------------===//
// AtenMinOtherOp
//===----------------------------------------------------------------------===//

void AtenMinOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  // `aten.min.other` -> `aten.minimum`
  patterns.add(+[](AtenMinOtherOp op, PatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<AtenMinimumOp>(op, op.getType(), op.getSelf(),
                                               op.getOther());
    return success();
  });
}
2021-05-05 05:42:50 +08:00
2023-09-05 10:52:32 +08:00
//===----------------------------------------------------------------------===//
// AtenMaxOtherOp
//===----------------------------------------------------------------------===//

void AtenMaxOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  // `aten.max.other` -> `aten.maximum`
  patterns.add(+[](AtenMaxOtherOp op, PatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), op.getSelf(),
2024-01-30 01:59:33 +08:00
                                               op.getOther());
2023-09-05 10:52:32 +08:00
    return success();
  });
}
2022-07-08 01:41:55 +08:00
//===----------------------------------------------------------------------===//
// AtenLenStrOp
//===----------------------------------------------------------------------===//
2023-01-25 09:29:42 +08:00
OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) {
2022-12-08 04:20:41 +08:00
  if (auto stringConstruct = getS().getDefiningOp<ConstantStrOp>())
2022-08-16 13:24:08 +08:00
    return getI64IntegerAttr(getContext(),
2022-12-08 04:20:41 +08:00
                             stringConstruct.getValueAttr().getValue().size());
2022-07-08 01:41:55 +08:00
  return nullptr;
}
2022-08-16 13:24:08 +08:00
LogicalResult rewrite0DBinaryTensorOp(Operation *op,
                                      PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  // This canonicalization pattern also includes aten div/mul/add/sub ops
  // between tensor and scalar, like aten.add.Scalar op
  if (op->getNumOperands() < 2) {
    return failure();
  }
2023-03-07 02:12:58 +08:00
  auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter);
  auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter);
2022-08-16 13:24:08 +08:00
  auto outType = op->getResult(0).getType();

  if (!lhs || !rhs) {
    return rewriter.notifyMatchFailure(
        op, "only int scalar lhs or rhs is supported");
  }
2023-03-07 09:38:27 +08:00
  if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenRsubScalarOp, AtenAddTensorOp,
          AtenAddScalarOp>(op)) {
2023-03-07 02:12:58 +08:00
    Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
2022-08-16 13:24:08 +08:00
    if (!alpha) {
      return rewriter.notifyMatchFailure(op,
                                         "only int scalar alpha is supported");
    }
2023-03-07 09:38:27 +08:00
    if (isa<AtenRsubScalarOp>(op))
      lhs = rewriter.create<AtenMulIntOp>(loc, lhs, alpha);
    else
      rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
2022-08-16 13:24:08 +08:00
  }

  if (isa<AtenDivTensorModeOp>(op)) {
    // None rounding mode
    if (op->getOperand(2).getType().isa<Torch::NoneType>()) {
      Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
                                                           quotient);
      return success();
    }
    std::string roundingMode;
    if (!matchPattern(op->getOperand(2), m_TorchConstantStr(roundingMode))) {
      return rewriter.notifyMatchFailure(
          op, "only None, 'floor' or 'trunc' rounding mode is supported");
    }
    if (roundingMode == "floor") {
      Value quotient = rewriter.create<AtenFloordivIntOp>(loc, lhs, rhs);
      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
                                                           quotient);
      return success();
    }
    // For "trunc" rounding mode, instead of canonicalizing it into
    // aten.abs, aten.floor, aten.sign and aten.mul.int ops, which adds
    // complexity but helps little in optimization (such as constant folding),
    // we are trying to fold it.
    if (roundingMode == "trunc") {
      int64_t lhsInt;
      int64_t rhsInt;
      if (!matchPattern(lhs, m_TorchConstantInt(&lhsInt))) {
        return failure();
      }
      if (!matchPattern(rhs, m_TorchConstantInt(&rhsInt))) {
        return failure();
      }

      int64_t result = (int64_t)std::trunc((double)lhsInt / rhsInt);
      Value resultScalar = rewriter.create<ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(result));
      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
                                                           resultScalar);
      return success();
    }

    return failure();
  }

  Value result;
  // Other Add/Sub/Mul ops
  if (isa<AtenAddTensorOp, AtenAddScalarOp>(op)) {
    result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs);
  } else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) {
    result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs);
2023-03-07 09:38:27 +08:00
  } else if (isa<AtenRsubScalarOp>(op)) {
    result = rewriter.create<AtenSubIntOp>(loc, rhs, lhs);
2022-08-16 13:24:08 +08:00
  } else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) {
    result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs);
  }
  rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType, result);
  return success();
}
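The scalar arithmetic that `rewrite0DBinaryTensorOp` lowers to can be checked in isolation. The sketch below restates, on plain `int64_t` values, how `alpha` is folded into one operand before the add/sub/rsub and how the "floor" and "trunc" rounding modes differ. The function names are hypothetical and nothing here touches MLIR. It uses integer division directly instead of going through `double`, and it assumes the caller guarantees a nonzero divisor.

```cpp
#include <cassert>
#include <cstdint>

// add(lhs, rhs, alpha) = lhs + rhs * alpha
int64_t addWithAlpha(int64_t lhs, int64_t rhs, int64_t alpha) {
  return lhs + rhs * alpha;
}

// sub(lhs, rhs, alpha) = lhs - rhs * alpha
int64_t subWithAlpha(int64_t lhs, int64_t rhs, int64_t alpha) {
  return lhs - rhs * alpha;
}

// rsub(lhs, rhs, alpha) = rhs - lhs * alpha (alpha scales the lhs instead)
int64_t rsubWithAlpha(int64_t lhs, int64_t rhs, int64_t alpha) {
  return rhs - lhs * alpha;
}

// "trunc" rounds toward zero, which is what C++ integer division does.
// Precondition (assumed): rhs != 0.
int64_t divTrunc(int64_t lhs, int64_t rhs) { return lhs / rhs; }

// "floor" rounds toward negative infinity.
// Precondition (assumed): rhs != 0.
int64_t divFloor(int64_t lhs, int64_t rhs) {
  int64_t quotient = lhs / rhs;
  // Adjust when the division is inexact and the operands differ in sign.
  if (lhs % rhs != 0 && ((lhs < 0) != (rhs < 0)))
    --quotient;
  return quotient;
}
```

The two rounding modes only disagree for inexact quotients with operands of opposite sign, for example -7 divided by 2.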
2022-06-18 02:49:36 +08:00
//===----------------------------------------------------------------------===//
// AtenAddTensorOp
//===----------------------------------------------------------------------===//
void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) {
2022-08-16 13:24:08 +08:00
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// AtenAddScalarOp
//===----------------------------------------------------------------------===//
void AtenAddScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenAddScalarOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// AtenSubTensorOp
//===----------------------------------------------------------------------===//
void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenSubTensorOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// AtenSubScalarOp
//===----------------------------------------------------------------------===//
void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenSubScalarOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}
//===----------------------------------------------------------------------===//
// AtenRsubScalarOp
//===----------------------------------------------------------------------===//
void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// AtenMulTensorOp
//===----------------------------------------------------------------------===//
void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenMulTensorOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// AtenFloorOp
//===----------------------------------------------------------------------===//
void AtenFloorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add(+[](AtenFloorOp op, PatternRewriter &rewriter) {
    auto outputTy = op.getType().dyn_cast<ValueTensorType>();
    if (outputTy && outputTy.hasDtype() &&
        outputTy.getDtype().isa<mlir::IntegerType>()) {
      rewriter.replaceOp(op, op.getSelf());
      return success();
    }
    return failure();
  });
}

//===----------------------------------------------------------------------===//
// AtenMulScalarOp
//===----------------------------------------------------------------------===//
void AtenMulScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenMulScalarOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// AtenDivTensorModeOp
//===----------------------------------------------------------------------===//
void AtenDivTensorModeOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenDivTensorModeOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// AtenNumelOp
//===----------------------------------------------------------------------===//
void AtenNumelOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add(+[](AtenNumelOp op, PatternRewriter &rewriter) {
    auto inputType = op.getSelf().getType().dyn_cast<BaseTensorType>();
    if (!inputType || !inputType.areAllSizesKnown()) {
      return failure();
    }
    auto sizes = inputType.getSizes();
    int64_t numel = 1;
    for (int64_t d : sizes) {
      numel *= d;
    }
    rewriter.replaceOpWithNewOp<ConstantIntOp>(
        op, rewriter.getI64IntegerAttr(numel));
    return success();
  });
}

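The `AtenNumelOp` pattern folds to a constant only when every dimension is statically known, and the constant is the product of the sizes. The following standalone sketch shows the same arithmetic outside of MLIR. `staticNumelSketch` and the sentinel `kUnknownSizeSketch` are hypothetical names introduced for illustration; the sentinel value `-1` is an assumption and may differ from the project's `kUnknownSize`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Assumed sentinel marking a dynamic dimension (illustrative value only).
constexpr int64_t kUnknownSizeSketch = -1;

// Returns the element count if every dimension is known, otherwise
// std::nullopt (the analogue of the pattern returning failure()).
inline std::optional<int64_t>
staticNumelSketch(const std::vector<int64_t> &sizes) {
  int64_t numel = 1;
  for (int64_t d : sizes) {
    if (d == kUnknownSizeSketch)
      return std::nullopt;
    numel *= d;
  }
  return numel;
}
```

An empty shape (a 0-d tensor) yields 1, since the product over no dimensions is the multiplicative identity.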
//===----------------------------------------------------------------------===//
// Aten__Or__TensorOp
//===----------------------------------------------------------------------===//

void Aten__Or__TensorOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](Aten__Or__TensorOp op, PatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<AtenBitwiseOrTensorOp>(
        op, op.getType(), op.getSelf(), op.getOther());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenScalarImplicitOp
//===----------------------------------------------------------------------===//
void AtenScalarImplicitOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) {
    Location loc = op.getLoc();
    Value a = op.getA();
    auto outType = op.getResult().getType();
    Value scalarValue = getScalarIntValue(a, loc, rewriter);
    if (!scalarValue)
      return failure();
    rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
    return success();
  });
}

Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical to do it in sane smaller
pieces.
This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).
Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
imported as `torch.somens.someunqualname.someoverloadname` (skip the
last dotted part if the overload name is empty), OR, if we don't have
such an op registered, it is imported as
`torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
- The addition of the "overload name" is a critical element here, as
the `(ns,unqual,overload)` triple is unique, which solves a lot of
problems we were having.
- This involves having separate MLIR ops for the `trailing_` and
`.out` variants and all the different overloads. This seemed
necessary, because the set of overloads is so wild and varied and
unstructured. The previous design was leaning into some underlying
structure that just isn't there -- the default situation is
the "random overload that we want to manage on the MLIR side",
rather than that being an exception. E.g. `aten::ne` (not-equal)
has 21 overloads, only 4 of which are c10 dispatcher ops (see
[gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1)),
and the "out" variant is really called `.Tensor_out` instead of
`.out` as it frequently is for other ops.
- Rationale for all being in `torch` namespace: the set of operators
is so varied and unstructured that "dialect per namespace"
doesn't result in anything resembling the typical MLIR dialect
boundary expectations. We could maybe draw the boundary at
dispatcher ops vs non-dispatcher ops, but that doesn't seem to
really result in very much useful structure at this point in time.
- Note: within the torch operator registry, we effectively have a
mini-basicpy subdialect (already type-resolved), which is reasonably
structured.
- The existing Torch op interfaces are also removed -- now that we
track the overload name, we can losslessly find the original
operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
`ReduceOpVariantsPass` that keys off certain traits (and perhaps
eventually interfaces) to reduce variants of ops to a smaller set,
ideally operating on immutable tensors and using surrounding ops to
model the mutability/aliasing aspects.
- Note: `torch.ns.unqual.overload` ops allow both immutable and
mutable tensors (unlike the previous hard distinction in the common
case). This is a premonition for a future change that will introduce a
bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supersede the existing
"ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supersedes `torch_signature_ods_gen.py`.
It should look somewhat familiar, but the benefit of hindsight has
allowed a lot of simplifications.
The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like as a natural result of
various future changes we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).
Recommended review order:
- Start at some of the new import IR, e.g. in
`frontends/pytorch/test/node_import/prim.py`,
`frontends/pytorch/test/acap_export/test_export_add3.py`, and other
tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
and associated generated files:
- `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
- `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
`frontends/pytorch/csrc/builder`. Probably most interesting is the new
code in `torch_to_mlir_utils.cpp` that has the logic to create the
`torch.operator` ops or `torch.ns.unqual.overload` ops.
This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
2021-05-05 05:42:50 +08:00
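The naming rule from the design overview above (the `(ns, unqual, overload)` triple maps to `torch.ns.unqual.overload`, with the last part dropped when the overload name is empty) can be sketched as a small standalone helper. `torchOpNameSketch` is a hypothetical name used only for illustration; it is not the importer's actual function, and it does not model the `torch.operator` fallback for unregistered ops.

```cpp
#include <cassert>
#include <string>

// Builds the MLIR op name for a torch operator triple.
// The overload component is omitted when it is empty.
inline std::string torchOpNameSketch(const std::string &ns,
                                     const std::string &unqual,
                                     const std::string &overload) {
  std::string name = "torch." + ns + "." + unqual;
  if (!overload.empty())
    name += "." + overload;
  return name;
}
```

With this rule, `aten::ne.Tensor_out` becomes `torch.aten.ne.Tensor_out`, and an operator with an empty overload name keeps only the namespace and unqualified name.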
//===----------------------------------------------------------------------===//
// AtenSizeOp
//===----------------------------------------------------------------------===//

// Walks up the def chain of `value`, inspecting at most 6 values (`value`
// itself and up to 5 ancestors), to find a tensor type with a known dimension
// size. Returns failure if no such type is found. If `dim` is `std::nullopt`,
// then all dimensions' sizes must be known.
static FailureOr<BaseTensorType>
traceKnownSizeTensorType(Value value, std::optional<int64_t> dim) {
  // Function to check if we found a type that contains the queried information.
  auto foundType = [](BaseTensorType tensorType, std::optional<int64_t> dim) {
    if (!tensorType.hasSizes())
      return false;

    if (dim == std::nullopt)
      return tensorType.areAllSizesKnown();

    // If the dimension value is negative, then convert it to a positive value.
    ArrayRef<int64_t> sizes = tensorType.getSizes();
    *dim = toPositiveDim(*dim, sizes.size());
    return isValidDim(*dim, sizes.size()) && sizes[*dim] != kUnknownSize;
  };

  // Limit the loop count to 6 to avoid indefinite compilation times from
  // unbounded IR traversals.
  for (auto idx = 0; idx < 6; ++idx) {
    if (!value || !value.getType().isa<BaseTensorType>())
      return failure();

    auto tensorType = value.getType().cast<BaseTensorType>();
    if (foundType(tensorType, dim))
      return tensorType;

    auto op = value.getDefiningOp();
    if (!op || !isa<CopyToValueTensorOp, CopyToNonValueTensorOp,
                    TensorStaticInfoCastOp>(op))
      return failure();

    // In all ops of interest to us, the source tensor is operand #0.
    value = op->getOperand(0);
  }

  return failure();
}

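The `foundType` check above combines three steps: normalize a possibly negative `dim`, validate it against the rank, and test whether that dimension is static. A standalone sketch of that logic on a plain size vector follows. All names ending in `Sketch` are hypothetical; the helper bodies are assumptions about what `toPositiveDim` and `isValidDim` do, and the sentinel value `-1` is assumed rather than taken from the project.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Assumed sentinel marking a dynamic dimension (illustrative value only).
constexpr int64_t kUnknownSizeSketch = -1;

// Assumed behaviour: negative dims count from the end, Python-style.
inline int64_t toPositiveDimSketch(int64_t dim, int64_t rank) {
  return dim >= 0 ? dim : dim + rank;
}

// Assumed behaviour: a normalized dim must lie in [0, rank).
inline bool isValidDimSketch(int64_t dim, int64_t rank) {
  return dim >= 0 && dim < rank;
}

// With no `dim`, every size must be known; otherwise only the queried one.
inline bool hasKnownSizeSketch(const std::vector<int64_t> &sizes,
                               std::optional<int64_t> dim) {
  if (!dim) {
    for (int64_t s : sizes)
      if (s == kUnknownSizeSketch)
        return false;
    return true;
  }
  const int64_t rank = static_cast<int64_t>(sizes.size());
  const int64_t d = toPositiveDimSketch(*dim, rank);
  return isValidDimSketch(d, rank) && sizes[d] != kUnknownSizeSketch;
}
```

For a shape written `[?,2,?,4]`, a query for `dim = -1` resolves to index 3 and succeeds, while a query for `dim = 0` or for all dimensions fails.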
void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
    auto type = traceKnownSizeTensorType(op.getOperand(), std::nullopt);
    if (failed(type))
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tensor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
2021-05-21 08:07:18 +08:00
      return rewriter.notifyMatchFailure(op, "all sizes not known");
    SmallVector<Value> listElements;
    for (int64_t size : type->getSizes()) {
      listElements.push_back(rewriter.create<Torch::ConstantIntOp>(
          op->getLoc(), rewriter.getI64IntegerAttr(size)));
    }
    rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
        op, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
        listElements);
Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical do do it in sane smaller
pieces.
This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).
Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
imported as `torch.somens.someunqualname.someoverloadname` (skip the
last dotted part if the overload name is empty), OR, if we don't have
such an op registered, it is imported as
`torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
- The addition of the "overload name" is a critical element here, as
the `(ns,unqual,overload)` triple is unique, which solves a lot of
problems we were having.
- This involves having separate MLIR ops for the `trailing_` and
`.out` variants and all the different overloads. This seemed
necessary, because the set of overloads is so wild and varied and
unstructured. The previous design was leaning into some underlying
structure that just isn't there -- the default situation is
the "random overload that we want to manage on the MLIR side",
rather than that being an exception. E.g. `aten::ne` (not-equal)
has 21 overloads, only 4 of which are c10 dispatcher ops (see this
[gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1)),
and the "out" variant is really called `.Tensor_out` instead of
`.out` as it frequently is for other ops.
- Rationale for all being in `torch` namespace: the set of operators
are so varied and unstructured that "dialect per namespace"
doesn't result in anything resembling the typical MLIR dialect
boundary expectations. We could maybe draw the boundary at
dispatcher ops vs non-dispatcher ops, but that doesn't seem to
really result in very much useful structure at this point in time.
- Note: within the torch operator registry, we effectively have a
mini-basicpy subdialect (already type-resolved), which is reasonably
structured.
- The existing Torch op interfaces are also removed -- now that we
track the overload name, we can losslessly find the original
operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
`ReduceOpVariantsPass` that keys off certain traits (and perhaps
eventually interfaces) to reduce variants of ops to a smaller set,
ideally operating on immutable tensors and using surrounding ops to
model the mutability/aliasing aspects.
- Note: `torch.ns.unqual.overload` ops allow both immutable and
mutable tensors (unlike the previous hard distinction in the common
case). This is a premonition for a future change that will introduce a
bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supersede the existing
"ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supersedes `torch_signature_ods_gen.py`.
It should look somewhat familiar, but the benefit of hindsight has
allowed a lot of simplifications.
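The naming rule from the first bullet of the overview can be sketched as a small standalone function. This is only an illustration of the `(ns, unqual, overload)` to op-name mapping described above; `importedOpName` is a hypothetical helper name, not something taken from the importer, and it does not model the `torch.operator` fallback for unregistered ops.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not from the source tree): builds the MLIR op name for
// a torch operator identified by its (ns, unqual, overload) triple.
std::string importedOpName(const std::string &ns, const std::string &unqual,
                           const std::string &overload) {
  assert(!ns.empty() && !unqual.empty() && "namespace and name are required");
  std::string name = "torch." + ns + "." + unqual;
  // Skip the last dotted part if the overload name is empty.
  if (!overload.empty())
    name += "." + overload;
  return name;
}
```

With this sketch, `importedOpName("aten", "ne", "Tensor_out")` yields `torch.aten.ne.Tensor_out`, while an empty overload name simply drops the trailing component.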
The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like as a natural result of
various future changes we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).
Recommended review order:
- Start at some of the new import IR, e.g. in
`frontends/pytorch/test/node_import/prim.py`,
`frontends/pytorch/test/acap_export/test_export_add3.py`, and other
tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
and associated generated files:
- `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
- `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
`frontends/pytorch/csrc/builder`. Probably most interesting is the new
code in `torch_to_mlir_utils.cpp` that has the logic to create the
`torch.operator` ops or `torch.ns.unqual.overload` ops.
This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
2021-05-05 05:42:50 +08:00
|
|
|
return success();
|
|
|
|
});
|
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tensor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
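To make the textual syntax above concrete, here is a minimal standalone sketch that prints a type string from optional shape and dtype information. It is not the dialect's printer: `formatTensorType` and `kUnknownDim` are hypothetical names, and the use of `-1` as the unknown-dimension marker is an assumption made for this example only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Assumption for this sketch: -1 marks an unknown dimension.
constexpr int64_t kUnknownDim = -1;

// Hypothetical formatter: absent sizes print as `*`, absent dtype as `unk`.
std::string formatTensorType(bool valueSemantic,
                             const std::optional<std::vector<int64_t>> &sizes,
                             const std::optional<std::string> &dtype) {
  std::string out = valueSemantic ? "!torch.vtensor" : "!torch.tensor";
  // The least-static-information variant uses the short form.
  if (!sizes && !dtype)
    return out;
  out += "<";
  if (!sizes) {
    out += "*";
  } else {
    out += "[";
    for (std::size_t i = 0; i < sizes->size(); ++i) {
      int64_t d = (*sizes)[i];
      assert(d >= kUnknownDim && "sizes must be non-negative or unknown");
      if (i != 0)
        out += ",";
      out += d == kUnknownDim ? std::string("?") : std::to_string(d);
    }
    out += "]";
  }
  out += ",";
  out += dtype ? *dtype : "unk";
  out += ">";
  return out;
}
```

The sketch always prefers the short form `!torch.tensor` over the equivalent explicit form `!torch.tensor<*,unk>`.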
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
2021-05-21 08:07:18 +08:00
|
|
|
// One-off pattern to erase if dead.
|
|
|
|
// TODO: Use the effects infra to express the semantics of this op and enable
|
|
|
|
// a centralized "erase if dead" canonicalization.
|
|
|
|
// Specifically, we need to mark the op as only MemoryEffects::Allocate
|
|
|
|
// so that `mlir::wouldOpBeTriviallyDead` does the right thing.
|
|
|
|
patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
|
|
|
|
if (!op.use_empty())
|
|
|
|
return failure();
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2021-10-16 06:23:59 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSizeIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) {
|
2022-07-13 03:38:37 +08:00
|
|
|
int64_t dim;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim)))
|
2021-10-16 06:23:59 +08:00
|
|
|
return nullptr;
|
2022-12-08 04:20:41 +08:00
|
|
|
auto type = traceKnownSizeTensorType(this->getSelf(), dim);
|
2022-07-13 03:38:37 +08:00
|
|
|
if (failed(type))
|
2021-10-16 06:23:59 +08:00
|
|
|
return nullptr;
|
2022-07-13 03:38:37 +08:00
|
|
|
ArrayRef<int64_t> sizes = type->getSizes();
|
|
|
|
dim = toPositiveDim(dim, sizes.size());
|
2023-04-07 19:49:35 +08:00
|
|
|
if (!isValidDim(dim, sizes.size()))
|
|
|
|
return nullptr;
|
2022-07-13 03:38:37 +08:00
|
|
|
return IntegerAttr::get(IntegerType::get(getContext(), 64), sizes[dim]);
|
2021-10-16 06:23:59 +08:00
|
|
|
}
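The dimension handling in the fold above (normalize a possibly negative `dim`, then bounds-check it) can be sketched without any MLIR dependencies. The two helpers below are stand-ins with assumed behavior, written for this example; they are not copied from the project's utility code, and `normalizeDim` is a hypothetical wrapper name.

```cpp
#include <cstdint>
#include <optional>

// Assumed behavior: negative dims count from the end, as in Python indexing.
constexpr int64_t toPositiveDim(int64_t dim, int64_t rank) {
  return dim >= 0 ? dim : dim + rank;
}

// Assumed behavior: a dim is valid when it lies in [0, rank).
constexpr bool isValidDim(int64_t dim, int64_t rank) {
  return dim >= 0 && dim < rank;
}

// Hypothetical wrapper: returns the normalized dim, or nothing when the
// request is out of range (the fold bails out with nullptr in that case).
constexpr std::optional<int64_t> normalizeDim(int64_t dim, int64_t rank) {
  int64_t positive = toPositiveDim(dim, rank);
  if (!isValidDim(positive, rank))
    return std::nullopt;
  return positive;
}
```

For a rank-3 tensor, `normalizeDim(-1, 3)` gives 2, while `normalizeDim(3, 3)` and `normalizeDim(-4, 3)` give no value.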
|
|
|
|
|
2021-06-17 02:05:08 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenGtIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) {
|
|
|
|
return IntegerAttr::get(IntegerType::get(context, 1),
|
|
|
|
static_cast<int64_t>(value));
|
|
|
|
}
|
|
|
|
|
2022-02-11 05:25:25 +08:00
|
|
|
using ConstantFloatComparator = std::function<bool(double, double)>;
|
2021-08-11 09:28:50 +08:00
|
|
|
template <typename OpTy>
|
2022-02-11 05:25:25 +08:00
|
|
|
static OpFoldResult
|
|
|
|
floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
|
2021-08-11 09:28:50 +08:00
|
|
|
if (op.getOperand(0) == op.getOperand(1))
|
|
|
|
return getI1IntegerAttr(op.getContext(), comparator(0, 0));
|
|
|
|
|
2022-02-11 05:25:25 +08:00
|
|
|
double lhs, rhs;
|
|
|
|
if (!matchPattern(op.getOperand(0), m_TorchConstantFloat(&lhs)) ||
|
|
|
|
!matchPattern(op.getOperand(1), m_TorchConstantFloat(&rhs)))
|
2021-08-11 09:28:50 +08:00
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs));
|
2021-06-17 02:05:08 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-02-11 05:25:25 +08:00
|
|
|
// AtenLtFloatOp
|
2021-06-17 02:05:08 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) {
|
2022-02-11 05:25:25 +08:00
|
|
|
return floatComparatorFoldHelper(*this,
|
|
|
|
[](double a, double b) { return a < b; });
|
2021-08-11 09:28:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-02-11 05:25:25 +08:00
|
|
|
// AtenGtFloatOp
|
2021-08-11 09:28:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) {
|
2022-02-11 05:25:25 +08:00
|
|
|
return floatComparatorFoldHelper(*this,
|
|
|
|
[](double a, double b) { return a > b; });
|
2021-08-11 09:28:50 +08:00
|
|
|
}
|
|
|
|
|
2022-04-25 21:12:45 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenGeFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) {
|
2022-04-25 21:12:45 +08:00
|
|
|
return floatComparatorFoldHelper(*this,
|
|
|
|
[](double a, double b) { return a >= b; });
|
|
|
|
}
|
|
|
|
|
2022-01-11 15:42:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenEqFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) {
|
2022-02-11 05:25:25 +08:00
|
|
|
return floatComparatorFoldHelper(*this,
|
|
|
|
[](double a, double b) { return a == b; });
|
|
|
|
}
|
|
|
|
|
|
|
|
using ConstantIntComparator = std::function<bool(int64_t, int64_t)>;
|
|
|
|
template <typename OpTy>
|
|
|
|
static OpFoldResult intComparatorFoldHelper(OpTy op,
|
|
|
|
ConstantIntComparator comparator) {
|
2022-03-10 08:44:22 +08:00
|
|
|
|
|
|
|
Value lhsValue = op->getOperand(0);
|
|
|
|
Value rhsValue = op->getOperand(1);
|
|
|
|
if (lhsValue == rhsValue)
|
2022-02-11 05:25:25 +08:00
|
|
|
return getI1IntegerAttr(op.getContext(), comparator(0, 0));
|
2022-01-11 15:42:53 +08:00
|
|
|
|
2022-02-11 05:25:25 +08:00
|
|
|
int64_t lhs, rhs;
|
2022-03-10 08:44:22 +08:00
|
|
|
bool lhsIsConstant = matchPattern(lhsValue, m_TorchConstantInt(&lhs));
|
|
|
|
bool rhsIsConstant = matchPattern(rhsValue, m_TorchConstantInt(&rhs));
|
|
|
|
if (lhsIsConstant && rhsIsConstant)
|
|
|
|
return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs));
|
2022-01-11 15:42:53 +08:00
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
// Ensure that if there is a constant, it is on the right.
|
|
|
|
if (lhsIsConstant && !rhsIsConstant) {
|
|
|
|
std::swap(lhs, rhs);
|
|
|
|
std::swap(lhsValue, rhsValue);
|
|
|
|
std::swap(lhsIsConstant, rhsIsConstant);
|
|
|
|
auto newComparator = [comparator](int64_t lhs, int64_t rhs) {
|
|
|
|
return comparator(rhs, lhs);
|
|
|
|
};
|
|
|
|
comparator = newComparator;
|
|
|
|
}
|
2022-03-31 05:10:51 +08:00
|
|
|
|
|
|
|
// Fold comparisons of AtenSizeIntOp against negative values.
|
|
|
|
// AtenSizeIntOp is known to always be non-negative.
|
2022-03-10 08:44:22 +08:00
|
|
|
if (rhsIsConstant && rhs < 0) {
|
|
|
|
// We can return `comparator(0, -1)` here because of the property:
|
|
|
|
// If x >= 0 && y < 0, then:
|
|
|
|
// - cmp(x, y) == cmp(x + 1, y)
|
|
|
|
// - cmp(x, y) == cmp(x, y - 1)
|
|
|
|
// By induction all cases here are covered.
|
|
|
|
if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>())
|
|
|
|
return getI1IntegerAttr(op->getContext(), comparator(0, -1));
|
|
|
|
}
|
2022-03-31 05:10:51 +08:00
|
|
|
|
|
|
|
// Fold comparisons of AtenSizeIntOp against 0:
|
|
|
|
// - torch.aten.size.int >= 0 ==> True.
|
|
|
|
// - torch.aten.size.int < 0 ==> False.
|
|
|
|
// (and the operand-swapped versions of the above)
|
|
|
|
if (rhsIsConstant && rhs == 0) {
|
|
|
|
if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>()) {
|
|
|
|
// >= 0 comparison.
|
|
|
|
if (comparator(0, 0) && comparator(1, 0))
|
|
|
|
return getI1IntegerAttr(op->getContext(), true);
|
|
|
|
// < 0 comparison.
|
|
|
|
if (!comparator(0, 0) && comparator(-1, 0) && !comparator(1, 0))
|
|
|
|
return getI1IntegerAttr(op->getContext(), false);
|
|
|
|
}
|
2022-03-10 08:44:22 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
return nullptr;
|
2022-02-11 05:25:25 +08:00
|
|
|
}
|
|
|
|
|
2023-04-18 23:59:14 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenDetachOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2024-01-31 01:45:51 +08:00
|
|
|
OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) {
|
|
|
|
if (getSelf().getType() != getResult().getType())
|
|
|
|
return {};
|
|
|
|
return getSelf();
|
|
|
|
}
|
2023-04-18 23:59:14 +08:00
|
|
|
|
2022-02-11 05:25:25 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenNeIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) {
|
2022-02-11 05:25:25 +08:00
|
|
|
return intComparatorFoldHelper(*this,
|
|
|
|
[](int64_t a, int64_t b) { return a != b; });
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenEqIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) {
|
2022-02-11 05:25:25 +08:00
|
|
|
return intComparatorFoldHelper(*this,
|
|
|
|
[](int64_t a, int64_t b) { return a == b; });
|
2022-01-11 15:42:53 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenEqStrOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
|
2022-01-11 15:42:53 +08:00
|
|
|
if (getOperand(0) == getOperand(1))
|
|
|
|
return getI1IntegerAttr(getContext(), true);
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
auto aStr = getA().getDefiningOp<ConstantStrOp>();
|
|
|
|
auto bStr = getB().getDefiningOp<ConstantStrOp>();
|
2022-01-11 15:42:53 +08:00
|
|
|
|
|
|
|
if (aStr && bStr)
|
|
|
|
return getI1IntegerAttr(getContext(), aStr.getValue() == bStr.getValue());
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2021-08-11 09:28:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenLtIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) {
|
2022-02-11 05:25:25 +08:00
|
|
|
return intComparatorFoldHelper(*this,
|
|
|
|
[](int64_t a, int64_t b) { return a < b; });
|
2021-08-11 09:28:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenLeIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) {
|
2022-02-11 05:25:25 +08:00
|
|
|
return intComparatorFoldHelper(*this,
|
|
|
|
[](int64_t a, int64_t b) { return a <= b; });
|
2021-08-11 09:28:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenGtIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) {
|
2022-02-11 05:25:25 +08:00
|
|
|
return intComparatorFoldHelper(*this,
|
|
|
|
[](int64_t a, int64_t b) { return a > b; });
|
2021-08-11 09:28:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenGeIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) {
|
2022-02-11 05:25:25 +08:00
|
|
|
return intComparatorFoldHelper(*this,
|
|
|
|
[](int64_t a, int64_t b) { return a >= b; });
|
2021-06-17 02:05:08 +08:00
|
|
|
}
|
|
|
|
|
2022-05-20 16:26:52 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenBoolFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) {
|
2022-05-20 16:26:52 +08:00
|
|
|
double c;
|
|
|
|
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
|
|
|
|
return getI1IntegerAttr(getContext(), c != 0.0);
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenBoolIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
|
2022-05-20 16:26:52 +08:00
|
|
|
int64_t c;
|
|
|
|
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
|
|
|
return getI1IntegerAttr(getContext(), c != 0);
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2023-08-30 17:29:03 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenAnyBoolOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
|
|
|
|
auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
|
|
|
|
if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
|
|
|
|
return nullptr;
|
|
|
|
// If any operand is a constant true, return true.
|
|
|
|
for (auto operand : inputConstruct.getOperands()) {
|
2023-09-04 09:59:26 +08:00
|
|
|
bool b = false;
|
2023-08-30 17:29:03 +08:00
|
|
|
if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
|
|
|
|
return getI1IntegerAttr(getContext(), true);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
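The one-sided nature of the `AtenAnyBoolOp` fold is easy to miss: it folds to true as soon as one list element is a constant true, but it never folds to false. A minimal standalone sketch of that rule follows, with each list element modeled as an optional constant. `foldAnyBool` is a hypothetical name, and the check for a potentially mutated list is not modeled here.

```cpp
#include <initializer_list>
#include <optional>

// Hypothetical model: each element is either a known constant bool or
// unknown (std::nullopt). Returns true when the fold applies, otherwise
// std::nullopt to signal that nothing can be folded.
constexpr std::optional<bool>
foldAnyBool(std::initializer_list<std::optional<bool>> elements) {
  for (const std::optional<bool> &element : elements) {
    if (element.has_value() && *element)
      return true;
  }
  // Even a list of all-constant false values is left unfolded, mirroring
  // the original code, which only returns a result for the true case.
  return std::nullopt;
}
```

So `foldAnyBool({false, std::nullopt, true})` produces true, while `foldAnyBool({false, false})` produces no result.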
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenFloatScalarOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
|
2022-03-10 08:44:22 +08:00
|
|
|
// Constant fold int -> float conversion.
|
2023-01-25 09:29:42 +08:00
|
|
|
if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
|
2022-03-10 08:44:22 +08:00
|
|
|
return FloatAttr::get(
|
|
|
|
mlir::Float64Type::get(getContext()),
|
|
|
|
static_cast<double>(integerAttr.getValue().getSExtValue()));
|
|
|
|
}
|
|
|
|
// If the input is float type already, the op is an identity.
|
|
|
|
if (getType() == getOperand().getType())
|
|
|
|
return getOperand();
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2022-04-26 20:15:30 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2023-02-11 05:59:03 +08:00
|
|
|
// AtenIntFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
|
|
|
|
// Constant fold float -> int conversion.
|
|
|
|
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
|
|
|
|
return IntegerAttr::get(
|
2023-05-23 08:21:34 +08:00
|
|
|
mlir::IntegerType::get(getContext(), 64),
|
2023-02-11 05:59:03 +08:00
|
|
|
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
|
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-04-26 20:15:30 +08:00
|
|
|
// AtenIntScalarOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
|
2022-04-26 20:15:30 +08:00
|
|
|
// Constant fold float -> int conversion.
|
2023-01-25 09:29:42 +08:00
|
|
|
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
|
2022-04-26 20:15:30 +08:00
|
|
|
return IntegerAttr::get(
|
2023-05-23 08:21:34 +08:00
|
|
|
mlir::IntegerType::get(getContext(), 64),
|
2022-04-26 20:15:30 +08:00
|
|
|
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
|
|
|
|
}
|
|
|
|
// If the input is int type already, the op is an identity.
|
|
|
|
if (getType() == getOperand().getType())
|
|
|
|
return getOperand();
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2023-01-18 02:14:14 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenIntBoolOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
|
2023-01-18 02:14:14 +08:00
|
|
|
bool b;
|
|
|
|
if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
|
|
|
|
return getI64IntegerAttr(getContext(), static_cast<int64_t>(b));
|
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2023-11-21 13:26:17 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenMaskedFillTensorOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Fold 0d fill tensor to scalar
|
|
|
|
void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenMaskedFillTensorOp op, PatternRewriter &rewriter) {
|
|
|
|
auto scalarIntVal =
|
|
|
|
getScalarIntValue(op.getValue(), op->getLoc(), rewriter);
|
|
|
|
auto scalarFloatVal =
|
|
|
|
getScalarFloatValue(op.getValue(), op->getLoc(), rewriter);
|
|
|
|
if (!scalarIntVal && !scalarFloatVal)
|
|
|
|
return failure();
|
|
|
|
Value scalarVal = scalarIntVal ? scalarIntVal : scalarFloatVal;
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMaskedFillScalarOp>(
|
|
|
|
op, op.getType(), op.getSelf(), op.getMask(), scalarVal);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2022-11-18 19:47:07 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSortIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenSortIntOp op, PatternRewriter &rewriter) {
|
|
|
|
SmallVector<int64_t> listElements;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getSelf(), m_TorchListOfConstantInts(listElements)))
|
2022-11-18 19:47:07 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "all input list elements must be constant ints");
|
|
|
|
bool reverse;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getReverse(), m_TorchConstantBool(&reverse)))
|
2022-11-18 19:47:07 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Expected reverse arg to be constant bool.");
|
|
|
|
|
|
|
|
std::sort(listElements.begin(), listElements.end());
|
|
|
|
if (reverse)
|
|
|
|
std::reverse(listElements.begin(), listElements.end());
|
|
|
|
|
|
|
|
SmallVector<Value> sortedListElements;
|
|
|
|
for (int64_t elem : listElements)
|
|
|
|
sortedListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
op->getLoc(), rewriter.getI64IntegerAttr(elem)));
|
|
|
|
Value result = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
op->getLoc(), Torch::ListType::get(rewriter.getType<Torch::IntType>()),
|
|
|
|
sortedListElements);
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getSelf().replaceAllUsesWith(result);
|
2022-11-18 19:47:07 +08:00
|
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2021-05-21 08:07:18 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-06-17 23:52:13 +08:00
|
|
|
// NonValueTensorLiteralOp
|
2021-05-21 08:07:18 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2021-06-17 23:52:13 +08:00
|
|
|
LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
|
2022-12-20 18:17:27 +08:00
|
|
|
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
2023-05-09 06:33:24 +08:00
|
|
|
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
2021-06-17 23:52:13 +08:00
|
|
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
2023-09-13 06:09:57 +08:00
|
|
|
auto attr = properties.as<Properties *>()
|
|
|
|
->getValue()
|
|
|
|
.dyn_cast_or_null<ElementsAttr>();
|
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tensor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there are `ValueTensorType` and `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
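To make the type syntax above concrete, here is a minimal standalone sketch that prints the explicit forms. It does not use MLIR: `TensorTypeSketch`, `kUnknownDim`, `str()` and `withoutValueSemantics()` are hypothetical names for illustration only (the last one loosely mirrors the `getWithoutValueSemantics()` call that appears in `TorchOps.cpp`), and dtypes are modeled as plain strings.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical marker for an unknown dimension (printed as `?`).
constexpr int64_t kUnknownDim = -1;

// Hypothetical stand-in for ValueTensorType / NonValueTensorType.
struct TensorTypeSketch {
  bool valueSemantics;                        // true models !torch.vtensor
  std::optional<std::vector<int64_t>> sizes;  // nullopt models `*`
  std::optional<std::string> dtype;           // nullopt models `unk`

  // Same sizes and dtype, but as the non-value-semantic type.
  TensorTypeSketch withoutValueSemantics() const {
    TensorTypeSketch t = *this;
    t.valueSemantics = false;
    return t;
  }

  // Prints the explicit form, e.g. !torch.tensor<[?,2,?,4],unk>.
  std::string str() const {
    std::string out = valueSemantics ? "!torch.vtensor<" : "!torch.tensor<";
    if (!sizes) {
      out += "*";
    } else {
      out += "[";
      for (std::size_t i = 0; i < sizes->size(); ++i) {
        int64_t d = (*sizes)[i];
        assert(d >= 0 || d == kUnknownDim);
        if (i != 0)
          out += ",";
        out += (d == kUnknownDim) ? std::string("?") : std::to_string(d);
      }
      out += "]";
    }
    out += ",";
    out += dtype ? *dtype : std::string("unk");
    out += ">";
    return out;
  }
};
```

In this sketch the bare forms `!torch.tensor` and `!torch.vtensor` are always printed in their explicit `<*,unk>` spelling; a real printer may prefer the short form.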
2021-05-21 08:07:18 +08:00
|
|
|
if (!attr)
|
|
|
|
return failure();
|
2021-09-14 08:57:59 +08:00
|
|
|
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
|
|
|
|
NonValueTensorType returnType =
|
|
|
|
NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
|
|
|
|
tensorType.getElementType());
|
|
|
|
inferredReturnTypes.push_back(returnType);
|
2021-05-21 08:07:18 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
|
|
|
|
if (a.hasSizes() && b.hasSizes()) {
|
2022-11-29 20:33:31 +08:00
|
|
|
if (failed(verifyCompatibleShape(makeShapeLLVMCompatible(a.getSizes()),
|
|
|
|
makeShapeLLVMCompatible(b.getSizes()))))
|
2021-05-21 08:07:18 +08:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
if (a.hasDtype() && b.hasDtype()) {
|
|
|
|
if (a.getDtype() != b.getDtype())
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
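The compatibility rule implemented by `areSizesAndDtypesCompatible` can be sketched without MLIR as follows. This is only an illustration under stated assumptions: `TensorInfo`, `kDynamic`, `shapesCompatible` and `sizesAndDtypesCompatible` are hypothetical names, dtypes are modeled as strings, and the shape check assumes that two ranked shapes are compatible when their ranks match and each pair of dimensions is either equal or has at least one unknown side.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical marker for an unknown dimension (`?`).
constexpr int64_t kDynamic = -1;

// Hypothetical stand-in for BaseTensorType: both pieces of static
// information are optional.
struct TensorInfo {
  std::optional<std::vector<int64_t>> sizes;  // nullopt: sizes unknown (`*`)
  std::optional<std::string> dtype;           // nullopt: dtype unknown (`unk`)
};

// Assumed shape rule: equal rank, and each dimension pair is either equal
// or has at least one unknown side.
bool shapesCompatible(const std::vector<int64_t> &a,
                      const std::vector<int64_t> &b) {
  if (a.size() != b.size())
    return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    assert(a[i] >= kDynamic && b[i] >= kDynamic);
    if (a[i] != kDynamic && b[i] != kDynamic && a[i] != b[i])
      return false;
  }
  return true;
}

// Static information only conflicts when both sides actually carry it.
bool sizesAndDtypesCompatible(const TensorInfo &a, const TensorInfo &b) {
  if (a.sizes && b.sizes && !shapesCompatible(*a.sizes, *b.sizes))
    return false;
  if (a.dtype && b.dtype && *a.dtype != *b.dtype)
    return false;
  return true;
}
```

The point of the design is that missing information never causes a mismatch: a type with unknown sizes or unknown dtype is compatible with any refinement of itself.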
|
|
|
|
|
2021-06-17 23:52:13 +08:00
|
|
|
bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
|
|
|
|
TypeRange actual) {
|
2021-05-21 08:07:18 +08:00
|
|
|
if (!actual[0].isa<BaseTensorType>())
|
|
|
|
return false;
|
|
|
|
return areSizesAndDtypesCompatible(inferred[0].cast<BaseTensorType>(),
|
|
|
|
actual[0].cast<BaseTensorType>());
|
|
|
|
}
|
|
|
|
|
2021-06-17 23:52:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ValueTensorLiteralOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
LogicalResult ValueTensorLiteralOp::inferReturnTypes(
|
2022-12-20 18:17:27 +08:00
|
|
|
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
2023-05-09 06:33:24 +08:00
|
|
|
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
2021-06-17 23:52:13 +08:00
|
|
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
2023-09-13 06:09:57 +08:00
|
|
|
auto attr = properties.as<Properties *>()
|
|
|
|
->getValue()
|
|
|
|
.dyn_cast_or_null<ElementsAttr>();
|
2021-06-17 23:52:13 +08:00
|
|
|
if (!attr)
|
|
|
|
return failure();
|
2021-09-14 08:57:59 +08:00
|
|
|
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
|
|
|
|
ValueTensorType returnType =
|
|
|
|
ValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
|
|
|
|
tensorType.getElementType());
|
|
|
|
inferredReturnTypes.push_back(returnType);
|
2021-06-17 23:52:13 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getValueAttr();
|
2021-06-17 23:52:13 +08:00
|
|
|
}
|
|
|
|
|
2021-05-21 08:07:18 +08:00
|
|
|
//----------------------------------------------------------------------------//
|
|
|
|
// TensorStaticInfoCast
|
|
|
|
//----------------------------------------------------------------------------//
|
|
|
|
|
|
|
|
bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
|
|
|
|
mlir::TypeRange outputs) {
|
|
|
|
return areSizesAndDtypesCompatible(inputs[0].cast<BaseTensorType>(),
|
|
|
|
outputs[0].cast<BaseTensorType>());
|
|
|
|
}
|
|
|
|
|
2021-10-16 06:23:59 +08:00
|
|
|
void TensorStaticInfoCastOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
|
|
|
patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
|
|
|
|
auto reverseCast =
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getOperand().getDefiningOp<Torch::TensorStaticInfoCastOp>();
|
|
|
|
if (!reverseCast || reverseCast.getOperand().getType() != op.getType())
|
2021-10-16 06:23:59 +08:00
|
|
|
return failure();
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOp(op, reverseCast.getOperand());
|
2021-10-16 06:23:59 +08:00
|
|
|
return success();
|
|
|
|
});
|
2022-03-10 08:44:22 +08:00
|
|
|
patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
|
2022-03-19 00:10:12 +08:00
|
|
|
if (isValidSubtype(op.getOperand().getType(), op.getType())) {
|
|
|
|
SmallVector<std::reference_wrapper<OpOperand>> usesToChange(
|
|
|
|
llvm::make_filter_range(op->getUses(), [](OpOperand &operand) {
|
|
|
|
return operand.getOwner()
|
|
|
|
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
|
|
|
|
}));
|
|
|
|
|
|
|
|
if (usesToChange.empty())
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
for (OpOperand &use : usesToChange) {
|
|
|
|
Operation *user = use.getOwner();
|
2022-12-08 04:20:41 +08:00
|
|
|
user->setOperand(use.getOperandNumber(), op.getOperand());
|
2022-03-19 00:10:12 +08:00
|
|
|
}
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
return failure();
|
|
|
|
});
|
2021-10-16 06:23:59 +08:00
|
|
|
}
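The first canonicalization pattern above removes a cast whose operand is itself a cast from the original type. A minimal sketch of that rewrite on a toy IR, assuming types can be compared as strings; `Node` and `foldRoundTripCast` are hypothetical names and not part of torch-mlir:

```cpp
#include <cassert>
#include <string>

// Hypothetical toy IR node: a plain value when castOperand is null,
// otherwise a static-info cast of castOperand to `type`.
struct Node {
  std::string type;
  const Node *castOperand;
};

// If `op` is cast(cast(x)) and x already has op's type, return x.
// Otherwise return `op` unchanged, meaning the pattern did not match.
const Node *foldRoundTripCast(const Node *op) {
  assert(op != nullptr);
  const Node *inner = op->castOperand;
  if (!inner || !inner->castOperand)
    return op;  // op is not a cast, or its operand is not a cast
  const Node *source = inner->castOperand;
  if (source->type != op->type)
    return op;  // the two casts do not round-trip to the same type
  return source;
}
```

The second pattern (rewiring uses that allow type refinement to the cast's operand) depends on `isValidSubtype` and the `AllowsTypeRefinement` trait, so it is not modeled here.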
|
|
|
|
|
2021-05-21 08:07:18 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-06-19 04:47:47 +08:00
|
|
|
// CopyToNonValueTensorOp
|
2021-05-21 08:07:18 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-03-16 08:54:57 +08:00
|
|
|
LogicalResult CopyToNonValueTensorOp::verify() {
|
|
|
|
auto resultType = getResult().getType().cast<BaseTensorType>();
|
|
|
|
auto operandType = getOperand().getType().cast<BaseTensorType>();
|
|
|
|
if (!resultType.hasSameSizesAndDtype(operandType))
|
|
|
|
return emitError() << "operand and result must have same sizes and dtype";
|
2021-05-21 08:07:18 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2021-06-19 04:47:47 +08:00
|
|
|
LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
|
2022-12-20 18:17:27 +08:00
|
|
|
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
2023-05-09 06:33:24 +08:00
|
|
|
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
2021-06-19 04:47:47 +08:00
|
|
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
|
|
|
auto resultType = operands[0].getType().cast<ValueTensorType>();
|
|
|
|
inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
void CopyToNonValueTensorOp::getEffects(
|
|
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
|
|
&effects) {
|
|
|
|
effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// CopyToValueTensorOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-03-16 08:54:57 +08:00
|
|
|
LogicalResult CopyToValueTensorOp::verify() {
|
|
|
|
auto resultType = getResult().getType().cast<BaseTensorType>();
|
|
|
|
auto operandType = getOperand().getType().cast<BaseTensorType>();
|
|
|
|
if (!resultType.hasSameSizesAndDtype(operandType))
|
|
|
|
return emitError() << "operand and result must have same sizes and dtype";
|
2021-06-19 04:47:47 +08:00
|
|
|
return success();
|
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tensor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
2021-05-21 08:07:18 +08:00
|
|
|
}
|
|
|
|
|
2021-06-19 04:47:47 +08:00
|
|
|
LogicalResult CopyToValueTensorOp::inferReturnTypes(
|
2022-12-20 18:17:27 +08:00
|
|
|
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
2023-05-09 06:33:24 +08:00
|
|
|
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
2021-06-19 04:47:47 +08:00
|
|
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
|
|
|
auto resultType = operands[0].getType().cast<NonValueTensorType>();
|
|
|
|
inferredReturnTypes.push_back(resultType.getWithValueSemantics());
|
|
|
|
return success();
|
2021-05-21 08:07:18 +08:00
|
|
|
}
|
|
|
|
|
2021-06-19 04:47:47 +08:00
|
|
|
void CopyToValueTensorOp::getEffects(
|
|
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
2021-06-18 07:29:20 +08:00
|
|
|
&effects) {
|
2021-06-19 04:47:47 +08:00
|
|
|
effects.emplace_back(MemoryEffects::Read::get(), getOperand());
|
2021-06-18 07:29:20 +08:00
|
|
|
}
|
|
|
|
|
2021-06-15 02:36:10 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantNoneOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) {
|
2021-06-15 02:36:10 +08:00
|
|
|
return TypeAttr::get(Torch::NoneType::get(getContext()));
|
|
|
|
}
|
|
|
|
|
2021-06-16 03:42:51 +08:00
|
|
|
void ConstantNoneOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
|
|
setNameFn(getResult(), "none");
|
|
|
|
}
|
|
|
|
|
2021-06-15 23:29:06 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantStrOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
|
2021-06-15 23:29:06 +08:00
|
|
|
|
2021-06-16 03:42:51 +08:00
|
|
|
void ConstantStrOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
|
|
setNameFn(getResult(), "str");
|
|
|
|
}
|
|
|
|
|
2021-09-28 22:56:08 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantDeviceOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void ConstantDeviceOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
2022-12-08 04:20:41 +08:00
|
|
|
setNameFn(getResult(), getValue());
|
2021-09-28 22:56:08 +08:00
|
|
|
}
|
|
|
|
|
2021-06-16 03:42:51 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-13 02:47:12 +08:00
|
|
|
ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
2021-06-17 06:53:15 +08:00
|
|
|
Builder builder(result.getContext());
|
|
|
|
result.addTypes(builder.getType<Torch::IntType>());
|
|
|
|
if (parser.parseOptionalAttrDict(result.attributes))
|
|
|
|
return failure();
|
|
|
|
int64_t value;
|
|
|
|
if (parser.parseInteger(value))
|
|
|
|
return failure();
|
|
|
|
result.addAttribute("value", builder.getI64IntegerAttr(value));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-02-13 02:47:12 +08:00
|
|
|
void ConstantIntOp::print(OpAsmPrinter &p) {
|
2021-09-04 02:38:00 +08:00
|
|
|
p << " ";
|
2023-05-23 08:21:34 +08:00
|
|
|
p << getValueAttr().getInt();
|
2022-02-13 02:47:12 +08:00
|
|
|
p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
|
2021-06-17 06:53:15 +08:00
|
|
|
}
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getValueAttr();
|
2021-06-16 03:42:51 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void Torch::ConstantIntOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
|
|
SmallVector<char> buf;
|
|
|
|
llvm::raw_svector_ostream os(buf);
|
2023-05-23 08:21:34 +08:00
|
|
|
os << "int" << getValueAttr().getInt();
|
2021-06-16 03:42:51 +08:00
|
|
|
setNameFn(getResult(), os.str());
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getValueAttr();
|
2021-06-16 03:42:51 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void Torch::ConstantFloatOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
2021-06-23 05:25:16 +08:00
|
|
|
// Calculate a stringified version of the number, compatible with MLIR
|
|
|
|
// identifier syntax. (in practice, this just removes the '+' from 'e+' in
|
|
|
|
// float string representation).
|
|
|
|
SmallVector<char> buf;
|
2022-12-08 04:20:41 +08:00
|
|
|
getValue().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
|
2024-01-30 01:59:33 +08:00
|
|
|
/*TruncateZero=*/false);
|
2021-06-23 05:25:16 +08:00
|
|
|
auto isValidMLIRIdentifierChar = [](char c) {
|
|
|
|
return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' ||
|
|
|
|
c == '-';
|
|
|
|
};
|
|
|
|
auto numberStr = llvm::to_vector<16>(
|
|
|
|
llvm::make_filter_range(buf, isValidMLIRIdentifierChar));
|
|
|
|
|
|
|
|
// Construct the identifier string.
|
|
|
|
buf.clear();
|
|
|
|
llvm::append_range(buf, StringRef("float"));
|
|
|
|
llvm::append_range(buf, numberStr);
|
|
|
|
setNameFn(getResult(), StringRef(buf.data(), buf.size()));
|
2021-06-16 03:42:51 +08:00
|
|
|
}
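The name construction in `ConstantFloatOp::getAsmResultNames` can be sketched outside of MLIR as follows. This is a minimal standalone sketch using only the standard library; `makeFloatResultName` is a hypothetical helper, not a torch-mlir function, and its input string is a hand-written stand-in for whatever the `toString` call above produces.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical helper (not part of torch-mlir): keeps only the characters
// that the lambda above accepts as MLIR identifier characters and prepends
// the "float" prefix used for the SSA result name.
std::string makeFloatResultName(const std::string &formatted) {
  std::string name = "float";
  for (char c : formatted) {
    // Cast first: passing a negative char to isalpha/isdigit is undefined.
    auto uc = static_cast<unsigned char>(c);
    if (std::isalpha(uc) || std::isdigit(uc) || c == '_' || c == '$' ||
        c == '.' || c == '-')
      name.push_back(c);
  }
  return name;
}
```

With this filter, an input such as `"1.5e+10"` loses only its `+`, which matches the effect described in the comment inside the function.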
|
|
|
|
|
2022-09-20 12:40:19 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantNumberOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getValueAttr();
|
2022-09-20 12:40:19 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void Torch::ConstantNumberOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
|
|
|
patterns.add(+[](Torch::ConstantNumberOp op, PatternRewriter &rewriter) {
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
|
|
|
Value constValue;
|
2022-12-08 04:20:41 +08:00
|
|
|
Attribute value = op.getValueAttr();
|
2022-09-20 12:40:19 +08:00
|
|
|
if (auto floatValue = value.dyn_cast<mlir::FloatAttr>()) {
|
|
|
|
constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
|
|
|
|
} else if (auto intValue = value.dyn_cast<mlir::IntegerAttr>()) {
|
|
|
|
constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
|
|
|
|
} else {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, op.getType(),
|
|
|
|
constValue);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2021-06-16 07:47:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantBoolOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getValueAttr();
|
2021-06-16 07:47:53 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void Torch::ConstantBoolOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
2022-12-08 04:20:41 +08:00
|
|
|
setNameFn(getResult(), getValue() ? "true" : "false");
|
2021-06-16 07:47:53 +08:00
|
|
|
}
|
|
|
|
|
2021-06-05 06:57:21 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-06-23 04:56:12 +08:00
|
|
|
// PrimUncheckedCastOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
|
|
|
|
mlir::TypeRange outputs) {
|
|
|
|
return isValidSubtype(outputs[0], inputs[0]);
|
|
|
|
}
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
if (auto derefineOp = getX().getDefiningOp<Torch::DerefineOp>()) {
|
|
|
|
if (derefineOp.getOperand().getType() == getType())
|
|
|
|
return derefineOp.getOperand();
|
2022-01-28 21:35:40 +08:00
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2021-06-23 04:56:12 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-06-05 06:57:21 +08:00
|
|
|
// Aten__Getitem__TOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2021-06-19 10:33:14 +08:00
|
|
|
void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
2021-06-05 06:57:21 +08:00
|
|
|
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
|
|
|
|
auto torchList = op.getOperand(0);
|
2022-03-10 08:44:22 +08:00
|
|
|
if (isListPotentiallyMutated(torchList))
|
2021-06-05 06:57:21 +08:00
|
|
|
return failure();
|
|
|
|
|
|
|
|
auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
|
|
|
|
if (!listConstruct)
|
|
|
|
return failure();
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
// Get the index, but be careful because it might be statically invalid.
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
|
2022-03-30 04:21:47 +08:00
|
|
|
op.getOperand(1), listConstruct.getNumOperands());
|
|
|
|
if (!indexOpt)
|
2022-03-10 08:44:22 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "statically invalid index");
|
2021-06-05 06:57:21 +08:00
|
|
|
|
2022-03-30 04:21:47 +08:00
|
|
|
rewriter.replaceOp(op, {listConstruct.getOperand(*indexOpt)});
|
2021-06-05 06:57:21 +08:00
|
|
|
return success();
|
|
|
|
});
|
2022-03-10 08:44:22 +08:00
|
|
|
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto sizeOp = op.getList().getDefiningOp<AtenSizeOp>();
|
2022-03-10 08:44:22 +08:00
|
|
|
if (!sizeOp)
|
|
|
|
return failure();
|
|
|
|
// This assumes that the size doesn't change between the
|
|
|
|
// AtenSizeOp and the Aten__Getitem__TOp.
|
|
|
|
// `t_` is the only op I can find that changes the shape in-place. It seems
|
|
|
|
// like otherwise we can treat the size of a tensor as having value
|
|
|
|
// semantics. The other view-like ops don't have in-place variants --
|
|
|
|
// they always return a new SSA value that is aliased to the input.
|
|
|
|
// Can we have a pass to normalize the `t_` case and then elsewhere in the
|
|
|
|
// compiler treat the size as having value semantics?
|
|
|
|
// There's a small number of such ops, and they are marked as `inplace_view`
|
|
|
|
// in PyTorch's `native_functions.yaml` file.
|
2024-01-30 01:59:33 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenSizeIntOp>(op, sizeOp.getSelf(),
|
|
|
|
op.getIdx());
|
2022-03-10 08:44:22 +08:00
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
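The first `Aten__Getitem__TOp` pattern above only rewrites when the index is statically legal for the list. The sketch below shows one plausible reading of that check. It assumes that `matchLegalConstantIndexIntoListOfSize` wraps negative indices Python-style and rejects anything still out of range; the helper's definition is not shown here, so `legalIndex` is a hypothetical stand-in rather than the real implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the legality check: normalizes a possibly
// negative constant index against a list of `size` elements and returns
// std::nullopt when the index is statically out of range.
std::optional<int64_t> legalIndex(int64_t index, int64_t size) {
  if (index < 0)
    index += size; // Assumed Python-style wrap-around: -1 is the last element.
  if (index < 0 || index >= size)
    return std::nullopt;
  return index;
}
```

Under these assumptions, `legalIndex(-1, 3)` yields 2, while `legalIndex(3, 3)` yields no value, which corresponds to the "statically invalid index" match failure.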
|
|
|
|
|
2023-06-07 17:05:31 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenIsFloatingPointOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
|
|
|
|
auto operandType = getSelf().getType().dyn_cast<BaseTensorType>();
|
|
|
|
if (!operandType)
|
|
|
|
return nullptr;
|
|
|
|
if (operandType.hasDtype()) {
|
|
|
|
bool isFloatType = operandType.getDtype().isa<mlir::FloatType>();
|
|
|
|
return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType);
|
|
|
|
}
|
|
|
|
// The operand type has no known dtype, so there is nothing to fold.
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2022-07-29 07:00:02 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenAddTOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
|
2024-01-30 01:59:33 +08:00
|
|
|
auto lhsListConstruct =
|
|
|
|
op.getA().getDefiningOp<Torch::PrimListConstructOp>();
|
2022-07-29 07:00:02 +08:00
|
|
|
if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
|
|
|
|
return failure();
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
auto rhsListConstruct =
|
|
|
|
op.getB().getDefiningOp<Torch::PrimListConstructOp>();
|
2022-07-29 07:00:02 +08:00
|
|
|
if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
SmallVector<Value> concatenatedList;
|
|
|
|
for (auto a : lhsListConstruct.getOperands()) {
|
|
|
|
concatenatedList.push_back(a);
|
|
|
|
}
|
|
|
|
for (auto b : rhsListConstruct.getOperands()) {
|
|
|
|
concatenatedList.push_back(b);
|
|
|
|
}
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(op, op.getType(),
|
|
|
|
concatenatedList);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2022-09-27 05:35:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSliceTOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenSliceTOp op, PatternRewriter &rewriter) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto valueList = op.getL();
|
2022-09-27 05:35:50 +08:00
|
|
|
auto listConstructOp = valueList.getDefiningOp<PrimListConstructOp>();
|
|
|
|
if (!listConstructOp || isListPotentiallyMutated(listConstructOp)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<Value> listElements =
|
2022-12-08 04:20:41 +08:00
|
|
|
llvm::to_vector<4>(listConstructOp.getElements());
|
2022-09-27 05:35:50 +08:00
|
|
|
int64_t size = static_cast<int64_t>(listElements.size());
|
|
|
|
|
|
|
|
int64_t start;
|
|
|
|
int64_t end;
|
|
|
|
int64_t step;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (op.getStart().getType().isa<Torch::NoneType>()) {
|
2022-09-27 05:35:50 +08:00
|
|
|
start = 0;
|
2022-12-08 04:20:41 +08:00
|
|
|
} else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
|
2022-09-27 05:35:50 +08:00
|
|
|
return failure();
|
|
|
|
}
|
2022-12-08 04:20:41 +08:00
|
|
|
if (op.getEnd().getType().isa<Torch::NoneType>()) {
|
2022-09-27 05:35:50 +08:00
|
|
|
end = listElements.size();
|
2022-12-08 04:20:41 +08:00
|
|
|
} else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
|
2022-09-27 05:35:50 +08:00
|
|
|
return failure();
|
|
|
|
}
|
2022-12-08 04:20:41 +08:00
|
|
|
// A non-positive step would keep the loop below from terminating safely.
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)) || step <= 0) {
|
2022-09-27 05:35:50 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
start = start >= 0 ? start : start + size;
|
|
|
|
start = start >= 0 ? start : 0;
|
|
|
|
end = end >= 0 ? end : end + size;
|
|
|
|
end = end < size ? end : size;
|
|
|
|
SmallVector<Value> newListElements;
|
|
|
|
|
|
|
|
for (int64_t i = start; i < end; i += step) {
|
|
|
|
newListElements.push_back(listElements[i]);
|
|
|
|
}
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<PrimListConstructOp>(
|
|
|
|
op, Torch::ListType::get(listElements[0].getType()), newListElements);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
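The start/end normalization in the `AtenSliceTOp` pattern can be exercised on its own. The sketch below copies the index arithmetic into a hypothetical free function, `sliceIndices`, that returns the selected positions instead of building a new list op. Rejecting a non-positive step is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns the positions that list[start:end:step]
// selects from a list of `size` elements, using the same clamping as the
// pattern above.
std::vector<int64_t> sliceIndices(int64_t size, int64_t start, int64_t end,
                                  int64_t step) {
  std::vector<int64_t> indices;
  if (step <= 0) // Assumption of this sketch: only forward slices are handled.
    return indices;
  start = start >= 0 ? start : start + size; // Wrap a negative start.
  start = start >= 0 ? start : 0;            // Clamp to the list front.
  end = end >= 0 ? end : end + size;         // Wrap a negative end.
  end = end < size ? end : size;             // Clamp to the list back.
  for (int64_t i = start; i < end; i += step)
    indices.push_back(i);
  return indices;
}
```

For a five-element list, `sliceIndices(5, -2, 5, 1)` selects positions 3 and 4, and an end that stays negative after wrapping yields an empty selection.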
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenEqIntListOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto lhsLiteral = getA().getDefiningOp<Torch::PrimListConstructOp>();
|
2022-03-10 08:44:22 +08:00
|
|
|
if (!lhsLiteral)
|
|
|
|
return nullptr;
|
2022-12-08 04:20:41 +08:00
|
|
|
auto rhsLiteral = getB().getDefiningOp<Torch::PrimListConstructOp>();
|
2022-03-10 08:44:22 +08:00
|
|
|
if (!rhsLiteral)
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
// If the sizes don't match, then we know the lists aren't equal.
|
|
|
|
if (lhsLiteral.getNumOperands() != rhsLiteral.getNumOperands())
|
|
|
|
return getI1IntegerAttr(getContext(), false);
|
|
|
|
|
|
|
|
// If the sizes match and all corresponding list elements are the same Value,
|
|
|
|
// then we know the lists are equal.
|
|
|
|
// Note that we can't prove that the lists are not-equal with this method,
|
|
|
|
// since two different Values might still be equal at runtime.
|
|
|
|
if (llvm::all_of(
|
|
|
|
llvm::zip(lhsLiteral.getOperands(), rhsLiteral.getOperands()),
|
|
|
|
[](const auto &pair) {
|
|
|
|
return std::get<0>(pair) == std::get<1>(pair);
|
|
|
|
}))
|
|
|
|
return getI1IntegerAttr(getContext(), true);
|
|
|
|
return nullptr;
|
2021-06-05 06:57:21 +08:00
|
|
|
}
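The fold in `AtenEqIntListOp::fold` has three outcomes: known false, known true, and unknown. The sketch below models that with `std::optional<bool>`. It is a simplified model, not MLIR code: each list element is represented by an opaque integer id standing in for an SSA `Value`, and `foldListEq` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical model of the fold: the ints are SSA value identities, not
// runtime integers. Equal ids mean "the same Value"; different ids may or
// may not hold equal numbers at runtime.
std::optional<bool> foldListEq(const std::vector<int> &lhs,
                               const std::vector<int> &rhs) {
  if (lhs.size() != rhs.size())
    return false; // Different lengths: the lists cannot be equal.
  for (std::size_t i = 0; i < lhs.size(); ++i)
    if (lhs[i] != rhs[i])
      return std::nullopt; // Distinct Values: equality is unknown statically.
  return true; // Same length and identical Values in every position.
}
```

The asymmetry is deliberate: identity of operands proves equality, but a difference in identity proves nothing, so that case does not fold.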
|
|
|
|
|
2023-01-24 08:34:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimTupleConstructOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
LogicalResult PrimTupleConstructOp::verify() {
|
|
|
|
if (!(isValidSubtype(
|
|
|
|
Torch::TupleType::get(getContext(),
|
|
|
|
llvm::to_vector<6>(getElements().getType())),
|
|
|
|
getResult().getType())))
|
|
|
|
return emitOpError(
|
|
|
|
"failed to verify that contained types correspond to operand types");
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2021-11-08 23:56:40 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimTupleIndexOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
|
2024-01-30 01:59:33 +08:00
|
|
|
auto tupleConstruct =
|
|
|
|
op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
|
2021-11-08 23:56:40 +08:00
|
|
|
if (!tupleConstruct)
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
int64_t i;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getI(), m_TorchConstantInt(&i)))
|
2021-11-08 23:56:40 +08:00
|
|
|
return failure();
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
if (i < 0 || i >= (int64_t)tupleConstruct.getElements().size())
|
2021-11-08 23:56:40 +08:00
|
|
|
return failure();
|
|
|
|
|
2022-05-03 17:12:09 +08:00
|
|
|
// TODO: We should have a clear picture of whether we want to consistently
|
|
|
|
// allow refinement, and where. It seems desirable to require precise
|
|
|
|
// type equality for TupleConstruct / TupleIndex, but that might break
|
|
|
|
// things.
|
2022-12-08 04:20:41 +08:00
|
|
|
Value replacement = tupleConstruct.getElements()[i];
|
2022-05-19 21:12:58 +08:00
|
|
|
if (replacement.getType() != op.getType()) {
|
|
|
|
if (op.getType().isa<BaseTensorType>()) {
|
|
|
|
replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
|
|
|
|
op.getLoc(), op.getType(), replacement);
|
|
|
|
} else {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, replacement);
|
2021-11-08 23:56:40 +08:00
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimUninitializedOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void PrimUninitializedOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
|
|
|
patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
|
|
|
|
if (!op.use_empty())
|
|
|
|
return failure();
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2021-08-18 01:59:47 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimTupleUnpackOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
|
2024-01-30 01:59:33 +08:00
|
|
|
auto tupleConstruct =
|
|
|
|
op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
|
2021-08-18 01:59:47 +08:00
|
|
|
if (!tupleConstruct)
|
|
|
|
return failure();
|
|
|
|
|
2023-07-18 22:32:26 +08:00
|
|
|
llvm::SmallVector<Value> derefinedElements;
|
|
|
|
// The result types may be supertypes of the tuple element types.
|
|
|
|
// Ensure we maintain the exact type, with identity `derefine`s being
|
|
|
|
// folded.
|
|
|
|
for (auto [type, element] :
|
|
|
|
llvm::zip(op.getResultTypes(), tupleConstruct.getElements())) {
|
|
|
|
derefinedElements.push_back(
|
|
|
|
rewriter.createOrFold<DerefineOp>(op.getLoc(), type, element));
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, derefinedElements);
|
2021-08-18 01:59:47 +08:00
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2022-07-15 20:10:23 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimListUnpackOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](PrimListUnpackOp op, PatternRewriter &rewriter) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto torchList = op.getOperand();
|
2022-07-15 20:10:23 +08:00
|
|
|
if (isListPotentiallyMutated(torchList)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
|
|
|
|
if (!listConstruct)
|
|
|
|
return failure();
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOp(op, listConstruct.getElements());
|
2022-07-15 20:10:23 +08:00
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2021-08-18 01:59:47 +08:00
|
|
|
static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
|
|
|
|
if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) {
|
|
|
|
return isa<Aten__Getitem__DictStrOp, Aten__Contains__StrOp,
|
2021-08-28 05:18:29 +08:00
|
|
|
AtenKeysStrOp, AtenGetDefaultStrOp, PrimDictConstructOp>(op);
|
2021-08-18 01:59:47 +08:00
|
|
|
}))
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
return torchDict.getDefiningOp<Torch::PrimDictConstructOp>();
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Aten__Getitem__DictStrOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto dictConstruct = getDictConstructIfNotModified(getSelf());
|
2021-08-18 01:59:47 +08:00
|
|
|
if (!dictConstruct)
|
|
|
|
return nullptr;
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
auto targetKey = getKey();
|
|
|
|
for (auto i : llvm::zip(dictConstruct.getKeys(), dictConstruct.getValues())) {
|
2021-08-18 01:59:47 +08:00
|
|
|
auto k = std::get<0>(i);
|
|
|
|
if (k == targetKey)
|
|
|
|
return std::get<1>(i);
|
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Aten__Contains__StrOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto dictConstruct = getDictConstructIfNotModified(getDict());
|
2021-08-18 01:59:47 +08:00
|
|
|
if (!dictConstruct)
|
|
|
|
return nullptr;
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
auto targetKey = getKey();
|
|
|
|
for (auto key : dictConstruct.getKeys()) {
|
2021-08-18 01:59:47 +08:00
|
|
|
if (key == targetKey)
|
|
|
|
return getI1IntegerAttr(getContext(), true);
|
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2022-06-23 23:16:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Aten__Contains__IntListOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
static bool isListConstructNotModified(Value torchList) {
|
|
|
|
return llvm::all_of(torchList.getUsers(), [](Operation *op) {
|
2022-08-16 13:24:08 +08:00
|
|
|
return isa<Aten__Contains__IntListOp>(op);
|
|
|
|
});
|
2022-06-23 23:16:09 +08:00
|
|
|
}
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto itemConstruct = getItem();
|
|
|
|
if (!isListConstructNotModified(getL()))
|
2022-06-23 23:16:09 +08:00
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
int64_t item;
|
|
|
|
SmallVector<int64_t> list;
|
|
|
|
|
|
|
|
if (!matchPattern(itemConstruct, m_TorchConstantInt(&item)))
|
|
|
|
return nullptr;
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(getL(), m_TorchListOfConstantInts(list)))
|
2022-06-23 23:16:09 +08:00
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
for (auto elem : list) {
|
|
|
|
if (elem == item)
|
|
|
|
return getI1IntegerAttr(getContext(), true);
|
|
|
|
}
|
|
|
|
return getI1IntegerAttr(getContext(), false);
|
|
|
|
}
|
|
|
|
|
2021-08-18 01:59:47 +08:00
|
|
|
using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
|
2022-09-20 12:40:19 +08:00
|
|
|
static OpFoldResult
|
|
|
|
atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
|
|
|
|
BinaryIntOperatorFn f) {
|
|
|
|
auto intLhs = operands[0].dyn_cast_or_null<IntegerAttr>();
|
|
|
|
auto intRhs = operands[1].dyn_cast_or_null<IntegerAttr>();
|
|
|
|
if (!intLhs || !intRhs) {
|
2021-08-18 01:59:47 +08:00
|
|
|
return nullptr;
|
2022-09-20 12:40:19 +08:00
|
|
|
}
|
|
|
|
return IntegerAttr::get(
|
|
|
|
intLhs.getType(),
|
|
|
|
f(intLhs.getValue().getSExtValue(), intRhs.getValue().getSExtValue()));
|
|
|
|
}
|
2021-08-18 01:59:47 +08:00
|
|
|
|
2022-09-20 12:40:19 +08:00
|
|
|
using BinaryFloatOperatorFn = std::function<double(double, double)>;
|
|
|
|
static OpFoldResult
|
|
|
|
atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
|
|
|
|
BinaryFloatOperatorFn f) {
|
|
|
|
double lhs, rhs;
|
|
|
|
auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
|
|
|
|
if (auto intLhs = attr.dyn_cast_or_null<IntegerAttr>()) {
|
|
|
|
value = static_cast<double>(intLhs.getValue().getSExtValue());
|
|
|
|
} else if (auto floatLhs = attr.dyn_cast_or_null<FloatAttr>()) {
|
|
|
|
value = floatLhs.getValue().convertToDouble();
|
|
|
|
} else {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
};
|
|
|
|
if (!parseDoubleAttribute(operands[0], lhs) ||
|
|
|
|
!parseDoubleAttribute(operands[1], rhs)) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs));
|
2021-08-18 01:59:47 +08:00
|
|
|
}
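The helper above promotes integer constants to `double` and declines to fold when either operand is not a constant. A minimal standalone sketch of that promotion rule follows, assuming plain C++17 in place of MLIR attributes. The names `Scalar`, `toDouble` and `foldBinaryFloat` are hypothetical and are not part of torch-mlir.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <variant>

// Hypothetical stand-in for a constant operand: an integer or a float.
// std::monostate models "not a constant", i.e. a null attribute.
using Scalar = std::variant<std::monostate, int64_t, double>;

// Promote a constant operand to double; return nullopt for non-constants.
inline std::optional<double> toDouble(const Scalar &s) {
  if (const auto *i = std::get_if<int64_t>(&s))
    return static_cast<double>(*i);
  if (const auto *d = std::get_if<double>(&s))
    return *d;
  return std::nullopt;
}

// Fold only when both operands are constants; otherwise report "no fold".
inline std::optional<double>
foldBinaryFloat(const Scalar &lhs, const Scalar &rhs,
                const std::function<double(double, double)> &f) {
  std::optional<double> l = toDouble(lhs);
  std::optional<double> r = toDouble(rhs);
  if (!l || !r)
    return std::nullopt;
  return f(*l, *r);
}
```

In this sketch, a mixed integer/float pair folds to a float result, which mirrors how the helper always returns an f64 attribute.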
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2023-06-21 01:14:09 +08:00
|
|
|
// AtenAliasOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); }
|
2023-06-21 01:14:09 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2021-08-18 01:59:47 +08:00
|
|
|
// AtenFloordivIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
|
2021-08-18 01:59:47 +08:00
|
|
|
return atenBinaryIntOperatorFoldHelper(
|
2023-01-25 09:29:42 +08:00
|
|
|
adaptor.getOperands(),
|
|
|
|
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
2021-08-18 01:59:47 +08:00
|
|
|
}
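The fold above computes floor division by converting to `double`, which cannot represent every `int64_t` value above 2^53 exactly. A minimal sketch of an integer-only alternative follows, assuming standard C++14 or later. `floorDivInt64` is a hypothetical name, not a torch-mlir function, and the sketch assumes `b != 0` and excludes the overflowing pair `(INT64_MIN, -1)`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not from torch-mlir): floor division on int64_t
// without converting to double. Preconditions: b != 0, and the pair
// (INT64_MIN, -1) is excluded because its quotient overflows.
constexpr int64_t floorDivInt64(int64_t a, int64_t b) {
  int64_t q = a / b; // C++ integer division truncates toward zero.
  int64_t r = a % b; // The remainder carries the sign of the dividend.
  // Step down by one when the division is inexact and the exact quotient
  // is negative, i.e. the remainder and divisor have opposite signs.
  if (r != 0 && ((r < 0) != (b < 0)))
    --q;
  return q;
}

// Compile-time checks of the rounding direction.
static_assert(floorDivInt64(7, 2) == 3, "positive operands");
static_assert(floorDivInt64(-7, 2) == -4, "rounds toward negative infinity");
static_assert(floorDivInt64(7, -2) == -4, "negative divisor");
static_assert(floorDivInt64(-7, -2) == 3, "both operands negative");
```

The design choice here is to correct the truncated quotient rather than call `std::floor`, so large operands never pass through a floating-point type.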
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenRemainderIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
|
2021-08-18 01:59:47 +08:00
|
|
|
return atenBinaryIntOperatorFoldHelper(
|
2023-01-25 09:29:42 +08:00
|
|
|
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
|
2021-08-18 01:59:47 +08:00
|
|
|
}
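The lambda above uses the C++ `%` operator, whose result takes the sign of the dividend, whereas Python's `%` on integers takes the sign of the divisor. If the op is meant to follow Python semantics (an assumption here, not something the source states), the fold would need an adjustment along these lines. `pythonRemainder` is a hypothetical name, and the sketch assumes `b != 0`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not from torch-mlir): remainder whose sign follows
// the divisor, as Python's % does for integers. Precondition: b != 0.
constexpr int64_t pythonRemainder(int64_t a, int64_t b) {
  int64_t r = a % b; // C++ result has the sign of the dividend a.
  // Shift into the divisor's sign when the two signs disagree.
  if (r != 0 && ((r < 0) != (b < 0)))
    r += b;
  return r;
}

// Cases where the two conventions differ, and one where they agree.
static_assert(-7 % 3 == -1, "C++ truncating remainder");
static_assert(pythonRemainder(-7, 3) == 2, "sign follows the divisor");
static_assert(pythonRemainder(7, -3) == -2, "negative divisor");
static_assert(pythonRemainder(-7, -3) == -1, "matching signs are unchanged");
```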
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenAddIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
|
2021-08-18 01:59:47 +08:00
|
|
|
return atenBinaryIntOperatorFoldHelper(
|
2023-01-25 09:29:42 +08:00
|
|
|
adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
|
2021-08-18 01:59:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSubIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
|
2021-08-18 01:59:47 +08:00
|
|
|
return atenBinaryIntOperatorFoldHelper(
|
2023-01-25 09:29:42 +08:00
|
|
|
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
|
2021-08-18 01:59:47 +08:00
|
|
|
}
|
|
|
|
|
2022-12-14 05:02:47 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenCatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
|
2022-12-14 05:02:47 +08:00
|
|
|
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
|
|
|
|
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
|
|
|
|
return nullptr;
|
2024-01-05 06:33:41 +08:00
|
|
|
if (list.getElements()[0].getType() != getResult().getType())
|
|
|
|
return nullptr;
|
2022-12-14 05:02:47 +08:00
|
|
|
return list.getElements()[0];
|
|
|
|
}
|
|
|
|
|
2023-09-02 02:50:34 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenBroadcastToOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
|
|
|
|
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
|
|
|
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
|
2024-01-05 06:33:41 +08:00
|
|
|
if (inType != outType)
|
|
|
|
return nullptr;
|
2023-09-02 02:50:34 +08:00
|
|
|
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
|
|
|
|
return nullptr;
|
|
|
|
if (inType.getSizes().size() != outType.getSizes().size() ||
|
2023-10-05 21:02:10 +08:00
|
|
|
(!isAssumingStrictSymbolicShapes((*this)->getBlock()) &&
|
|
|
|
(!inType.areAllSizesKnown() || !outType.areAllSizesKnown())))
|
2023-09-02 02:50:34 +08:00
|
|
|
return nullptr;
|
|
|
|
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
|
|
|
|
if (inType.getSizes()[i] != outType.getSizes()[i])
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
return getOperand(0);
|
|
|
|
}
|
|
|
|
|
2022-12-14 05:02:47 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSliceTensorOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
2024-01-30 01:59:33 +08:00
|
|
|
int64_t start, end, step;
|
|
|
|
if (matchPattern(getStart(), m_TorchConstantInt(&start)) &&
|
|
|
|
matchPattern(getEnd(), m_TorchConstantInt(&end)) &&
|
|
|
|
matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 &&
|
|
|
|
start == 0 && end == std::numeric_limits<int64_t>::max())
|
|
|
|
return getOperand(0);
|
2023-07-20 15:53:54 +08:00
|
|
|
|
|
|
|
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
|
|
|
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
|
2024-01-05 06:33:41 +08:00
|
|
|
if (inType != outType)
|
|
|
|
return nullptr;
|
2022-12-14 05:02:47 +08:00
|
|
|
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
|
|
|
|
return nullptr;
|
|
|
|
if (inType.getSizes().size() != outType.getSizes().size() ||
|
|
|
|
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
|
|
|
|
return nullptr;
|
|
|
|
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
|
|
|
|
if (inType.getSizes()[i] != outType.getSizes()[i])
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
return getOperand(0);
|
|
|
|
}
|
|
|
|
|
2021-08-18 01:59:47 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenMulIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
|
2021-08-18 01:59:47 +08:00
|
|
|
int64_t lhs, rhs;
|
|
|
|
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
|
|
|
|
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
|
|
|
|
if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
|
|
|
|
return getI64IntegerAttr(getContext(), 0);
|
|
|
|
if (lConstant && rConstant)
|
|
|
|
return getI64IntegerAttr(getContext(), lhs * rhs);
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2023-06-29 10:37:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenMulFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenMulFloatOp::fold(FoldAdaptor adaptor) {
|
|
|
|
return atenBinaryFloatOperatorFoldHelper(
|
|
|
|
adaptor.getOperands(), [](double a, double b) { return a * b; });
|
|
|
|
}
|
|
|
|
|
2023-03-03 01:07:33 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSubFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
|
|
|
|
return atenBinaryFloatOperatorFoldHelper(
|
|
|
|
adaptor.getOperands(), [](double a, double b) { return a - b; });
|
|
|
|
}
|
|
|
|
|
2023-06-27 10:55:28 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenAddOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) {
|
|
|
|
if (!adaptor.getA() || !adaptor.getB()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
|
|
|
|
return atenBinaryIntOperatorFoldHelper(
|
|
|
|
adaptor.getOperands(),
|
|
|
|
[](int64_t a, int64_t b) -> int64_t { return a + b; });
|
|
|
|
}
|
|
|
|
return atenBinaryFloatOperatorFoldHelper(
|
|
|
|
adaptor.getOperands(),
|
|
|
|
[](double a, double b) -> double { return a + b; });
|
|
|
|
}
|
|
|
|
|
2022-09-20 12:40:19 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSubOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
|
|
|
|
if (!adaptor.getA() || !adaptor.getB()) {
|
2022-09-20 12:40:19 +08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
|
2022-09-20 12:40:19 +08:00
|
|
|
return atenBinaryIntOperatorFoldHelper(
|
2023-01-25 09:29:42 +08:00
|
|
|
adaptor.getOperands(),
|
|
|
|
[](int64_t a, int64_t b) -> int64_t { return a - b; });
|
2022-09-20 12:40:19 +08:00
|
|
|
}
|
|
|
|
return atenBinaryFloatOperatorFoldHelper(
|
2023-01-25 09:29:42 +08:00
|
|
|
adaptor.getOperands(),
|
|
|
|
[](double a, double b) -> double { return a - b; });
|
2022-09-20 12:40:19 +08:00
|
|
|
}
|
|
|
|
|
2022-09-20 22:31:24 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenDivOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
|
|
|
|
if (!adaptor.getA() || !adaptor.getB()) {
|
2022-09-20 22:31:24 +08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
// Since AtenDivOp always returns float value, we don't need to deal with the
|
|
|
|
// case where the operands are both integers separately.
|
|
|
|
return atenBinaryFloatOperatorFoldHelper(
|
2023-01-25 09:29:42 +08:00
|
|
|
adaptor.getOperands(),
|
|
|
|
[](double a, double b) -> double { return a / b; });
|
2022-09-20 22:31:24 +08:00
|
|
|
}
|
|
|
|
|
2023-06-29 10:37:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenAddFloatIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) {
|
|
|
|
if (!adaptor.getA() || !adaptor.getB()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
return atenBinaryFloatOperatorFoldHelper(
|
|
|
|
adaptor.getOperands(), [](double a, double b) { return a + b; });
|
|
|
|
}
|
|
|
|
|
2023-03-01 01:36:05 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenPowIntFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenPowIntFloatOp::fold(FoldAdaptor adaptor) {
|
|
|
|
if (!adaptor.getA() || !adaptor.getB()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
return atenBinaryFloatOperatorFoldHelper(
|
|
|
|
adaptor.getOperands(), [](double a, double b) { return std::pow(a, b); });
|
|
|
|
}
|
|
|
|
|
2022-09-20 12:40:19 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenCeilScalarOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
|
|
|
|
if (!adaptor.getA()) {
|
2022-09-20 12:40:19 +08:00
|
|
|
return nullptr;
|
|
|
|
}
|
2023-01-25 09:29:42 +08:00
|
|
|
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>();
|
2022-09-20 12:40:19 +08:00
|
|
|
if (!floatValue) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
return getI64IntegerAttr(
|
|
|
|
getContext(),
|
|
|
|
static_cast<int64_t>(std::ceil(floatValue.getValue().convertToDouble())));
|
|
|
|
}
|
|
|
|
|
2022-01-11 15:42:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenNegIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
|
2022-01-11 15:42:53 +08:00
|
|
|
int64_t c;
|
|
|
|
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
|
|
|
return getI64IntegerAttr(getContext(), -c);
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2023-06-29 10:37:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenNegFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
|
|
|
|
if (!adaptor.getA()) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
auto value = adaptor.getA().dyn_cast_or_null<FloatAttr>();
|
|
|
|
if (!value) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
return getF64FloatAttr(getContext(), -value.getValue().convertToDouble());
|
|
|
|
}
|
|
|
|
|
2022-05-19 22:54:16 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSqrtIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
|
2022-05-19 22:54:16 +08:00
|
|
|
int64_t c;
|
|
|
|
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
|
|
|
return getF64FloatAttr(getContext(), std::sqrt(c));
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2021-09-02 03:53:52 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimDtypeOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
BaseTensorType tensorType = getA().getType().cast<BaseTensorType>();
|
2021-09-02 03:53:52 +08:00
|
|
|
if (tensorType.hasDtype()) {
|
2022-07-08 05:21:05 +08:00
|
|
|
torch_upstream::ScalarType scalarType =
|
|
|
|
Torch::getScalarTypeForType(tensorType.getDtype());
|
|
|
|
return getI64IntegerAttr(getContext(), static_cast<int64_t>(scalarType));
|
2021-09-02 03:53:52 +08:00
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
|
2021-11-30 02:39:37 +08:00
|
|
|
|
2023-05-03 11:05:46 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimDeviceOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void PrimDeviceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](PrimDeviceOp op, PatternRewriter &rewriter) {
|
|
|
|
// Device information isn't relevant to torch-mlir, just replace it with
|
|
|
|
// "cpu".
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::ConstantDeviceOp>(op, "cpu");
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2023-06-14 09:56:39 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenCudaOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenCudaOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenCudaOp op, PatternRewriter &rewriter) {
|
|
|
|
// Device information isn't relevant to torch-mlir
|
|
|
|
auto inputTensor = op.getSelf();
|
|
|
|
rewriter.replaceOp(op, inputTensor);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2023-06-23 01:07:14 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenDeviceWithIndexOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenDeviceWithIndexOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenDeviceWithIndexOp op, PatternRewriter &rewriter) {
|
|
|
|
std::string type;
|
|
|
|
int64_t index;
|
|
|
|
if (!matchPattern(op.getType(), m_TorchConstantStr(type))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: type must be a constant string");
|
|
|
|
}
|
|
|
|
if (!matchPattern(op.getIndex(), m_TorchConstantInt(&index))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: index must be a constant integer");
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::ConstantDeviceOp>(
|
|
|
|
op, type + ":" + std::to_string(index));
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2022-02-09 19:55:14 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenIntTensorOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
|
2022-02-09 19:55:14 +08:00
|
|
|
// If a scalar number is converted to a 0-d tensor and passed on to
|
2021-11-30 02:39:37 +08:00
|
|
|
// aten.Int.Tensor, fold to the scalar number.
|
2022-12-08 04:20:41 +08:00
|
|
|
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
|
|
|
|
return numToTensorScalar.getA();
|
2024-01-30 01:59:33 +08:00
|
|
|
if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
|
2023-07-20 16:46:44 +08:00
|
|
|
return tensorIntOp.getT();
|
2021-11-30 02:39:37 +08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2022-02-09 19:55:14 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenFloatTensorOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) {
|
2022-02-09 19:55:14 +08:00
|
|
|
// If a scalar number is converted to a 0-d tensor and passed on to
|
|
|
|
// aten.Float.Tensor, fold to the scalar number.
|
2022-12-08 04:20:41 +08:00
|
|
|
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
|
|
|
|
return numToTensorScalar.getA();
|
2022-02-09 19:55:14 +08:00
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-04-25 20:06:41 +08:00
|
|
|
// AtenDivFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) {
|
2022-04-25 20:06:41 +08:00
|
|
|
double lhs, rhs;
|
|
|
|
bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
|
|
|
|
bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
|
|
|
|
if (lConstant && lhs == 0.0)
|
|
|
|
return getF64FloatAttr(getContext(), 0.0);
|
|
|
|
if (lConstant && rConstant && rhs == 1.0)
|
|
|
|
return getF64FloatAttr(getContext(), lhs);
|
|
|
|
if (lConstant && rConstant)
|
|
|
|
return getF64FloatAttr(getContext(), lhs / rhs);
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2022-10-06 21:11:52 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenDivIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
|
2022-10-06 21:11:52 +08:00
|
|
|
int64_t lhs, rhs;
|
|
|
|
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
|
|
|
|
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
|
|
|
|
if (lConstant && rConstant)
|
|
|
|
return getF64FloatAttr(getContext(), double(lhs) / rhs);
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-04-25 21:12:45 +08:00
|
|
|
// AtenCeilFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) {
|
2022-04-25 21:12:45 +08:00
|
|
|
double c;
|
|
|
|
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
|
|
|
|
return getI64IntegerAttr(getContext(), std::ceil(c));
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimMaxIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
|
2022-03-10 08:44:22 +08:00
|
|
|
// If both operands are the same, then the operation is an identity.
|
2022-12-08 04:20:41 +08:00
|
|
|
if (getA() == getB())
|
|
|
|
return getA();
|
2022-03-10 08:44:22 +08:00
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
|
|
|
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
2022-03-10 08:44:22 +08:00
|
|
|
if (!lhs || !rhs)
|
|
|
|
return nullptr;
|
|
|
|
// Torch semantics are that !torch.int is 64-bit signed.
|
|
|
|
return IntegerAttr::get(
|
|
|
|
lhs.getType(),
|
|
|
|
std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimMinSelfIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
|
2022-03-10 08:44:22 +08:00
|
|
|
auto list = getOperand().getDefiningOp<PrimListConstructOp>();
|
|
|
|
if (!list)
|
|
|
|
return nullptr;
|
|
|
|
// TODO: What does it return for an empty list?
|
|
|
|
if (list->getNumOperands() == 0)
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
SmallVector<int64_t> values;
|
|
|
|
for (auto operand : list->getOperands()) {
|
|
|
|
int64_t value;
|
|
|
|
if (!matchPattern(operand, m_TorchConstantInt(&value)))
|
|
|
|
return nullptr;
|
|
|
|
values.push_back(value);
|
|
|
|
}
|
|
|
|
return getI64IntegerAttr(getContext(),
|
|
|
|
*std::min_element(values.begin(), values.end()));
|
|
|
|
}
|
|
|
|
|
2023-02-11 05:58:15 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimMinIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
|
|
|
|
// If both operands are the same, then the operation is an identity.
|
|
|
|
if (getA() == getB())
|
|
|
|
return getA();
|
|
|
|
|
|
|
|
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
|
|
|
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
|
|
|
if (!lhs || !rhs)
|
|
|
|
return nullptr;
|
|
|
|
// Torch semantics are that !torch.int is 64-bit signed.
|
|
|
|
return IntegerAttr::get(
|
|
|
|
lhs.getType(),
|
|
|
|
std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
|
|
|
|
}
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ShapeCalculateOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-12-14 00:25:41 +08:00
|
|
|
template <typename CalculateOp>
|
|
|
|
static void
|
2023-09-13 06:09:57 +08:00
|
|
|
getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
|
2022-12-14 00:25:41 +08:00
|
|
|
SmallVectorImpl<RegionSuccessor> &regions) {
|
2023-09-13 06:09:57 +08:00
|
|
|
if (!point.getRegionOrNull()) {
|
2022-12-14 00:25:41 +08:00
|
|
|
// First thing the op does is branch into the calculation.
|
|
|
|
regions.emplace_back(&op.getCalculation());
|
2022-03-10 08:44:22 +08:00
|
|
|
return;
|
|
|
|
}
|
2023-09-13 06:09:57 +08:00
|
|
|
if (point == op.getBody()) {
|
2022-03-10 08:44:22 +08:00
|
|
|
// Body returns control to the outer op, passing through results.
|
2022-12-14 00:25:41 +08:00
|
|
|
regions.emplace_back(op.getResults());
|
2022-03-10 08:44:22 +08:00
|
|
|
return;
|
|
|
|
}
|
2023-09-13 06:09:57 +08:00
|
|
|
assert(point == op.getCalculation());
|
2022-12-14 00:25:41 +08:00
|
|
|
// Calculation branches to the body.
|
|
|
|
regions.emplace_back(&op.getBody());
|
|
|
|
}
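The control flow encoded above is a fixed cycle: the parent op enters the calculation region, the calculation region branches to the body, and the body returns to the parent. A minimal sketch of that cycle as a plain state machine follows, assuming no MLIR types. `Point` and `successorOf` are hypothetical names used only for illustration.

```cpp
#include <cassert>

// Hypothetical model of the three branch points of a calculate-style op.
enum class Point { Parent, Calculation, Body };

// Return where control goes next from a given point, mirroring the three
// cases handled in getSuccessorRegionsForCalculateOp.
constexpr Point successorOf(Point from) {
  switch (from) {
  case Point::Parent:
    return Point::Calculation; // The op first branches into the calculation.
  case Point::Calculation:
    return Point::Body; // The calculation branches to the body.
  case Point::Body:
    return Point::Parent; // The body returns control to the outer op.
  }
  return Point::Parent; // Not reached for valid enum values.
}

// Following three successors from the parent returns to the parent.
static_assert(successorOf(successorOf(successorOf(Point::Parent))) ==
                  Point::Parent,
              "the cycle closes after three steps");
```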
|
|
|
|
|
|
|
|
void ShapeCalculateOp::getSuccessorRegions(
|
2023-09-13 06:09:57 +08:00
|
|
|
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
|
|
|
|
getSuccessorRegionsForCalculateOp(*this, point, regions);
|
2022-12-14 00:25:41 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DtypeCalculateOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void DtypeCalculateOp::getSuccessorRegions(
|
2023-09-13 06:09:57 +08:00
|
|
|
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
|
|
|
|
getSuccessorRegionsForCalculateOp(*this, point, regions);
|
2022-03-10 08:44:22 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ShapeCalculateYieldShapesOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
|
2023-09-13 06:09:57 +08:00
|
|
|
RegionBranchPoint point) {
|
2022-03-10 08:44:22 +08:00
|
|
|
// The shape operands don't get forwarded to the body.
|
|
|
|
// MutableOperandRange always has an owning operation, even if empty, so
|
|
|
|
// create a 0-length range.
|
|
|
|
return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
|
|
|
|
}
|
|
|
|
|
2022-03-16 08:54:57 +08:00
|
|
|
LogicalResult ShapeCalculateYieldShapesOp::verify() {
|
|
|
|
auto parent = cast<ShapeCalculateOp>(getOperation()->getParentOp());
|
|
|
|
if (parent.getNumResults() != getNumOperands())
|
|
|
|
return emitOpError("expected number of shapes to match number of results");
|
2022-03-10 08:44:22 +08:00
|
|
|
return success();
|
|
|
|
}
|
Rework how global slot initializers work.
Rather than a per-global-slot initializer region, we now have one for
the whole module. For example, it might look like this:
```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
torch.initialize.global_slots [
@tensor(%0 : !torch.tensor)
@list(%1 : !torch.list<tensor>)
]
}
```
This new structure allows GlobalizeObjectGraph to create the initializer in a
much simpler way, avoiding the need to reason about whether different slots
alias each other. Reasoning about whether slots alias each other is now the
responsibility of InlineGlobalSlots, which has to do a much more complicated
analysis, implemented using MLIR's dataflow analysis framework.
Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to GlobalizeObjectGraph pass.
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
- Moving torchMlirAdjustStaticInformation for sharing with C++ code.
- EraseModuleInitializer pass
To make this a bit nicer, it would be good to have a `torch.module` op
with an initializer region attached. That would be more invasive though.
This change has highlighted certain aspects of our project layering
which are worth calling out. None of our backends can handle global
slots, so we enforce that there are no global slots before backend
lowering. At an earlier stage in the project, we had aspirations of
transparently handling mutable global state and such, but for reasons
described below, that is no longer a goal. So really global slots should
be seen as a progressive lowering step as part of inlining all the
IValues in the original program (GlobalizeObjectGraph is also one such
step).
Over time, with insights from work like IREE-JAX, it has become clear
that there isn't a reliable programming model we can compile for users
where we just transparently handle mutable global state (and some other
things, like lists and dictionaries). There is a need for an "outer
program" that orchestrates more restricted subroutines of the kind we
can handle in our compile flow here. The benefit of that is that it
decouples considerations like shapes, dtypes, etc. from the program
constructs used in the outer program. As long as the outer program can
efficiently invoke (pipelining/async/etc.) high-performance
data-parallel numerical subroutines of the kind we compile in our flow
here, then there is a complete programming model. This is also
consistent with the direction of upstream PyTorch which is becoming more
tracing-based (which inherently loses a lot of program structure, which
then has to be applied back with an "outer program" orchestrating the
traced subroutines).
2022-07-14 02:45:56 +08:00
|
|
|
|
2023-11-16 03:47:54 +08:00
LogicalResult AtenPermuteOp::verify() {

  // Verification of the permute op for input & output dimensions with
  // statically known sizes.

  SmallVector<Value> permutation;
  auto permutationObtained = getListConstructElements(getDims(), permutation);
  if (!permutationObtained) {
    return success();
  }

  auto outType = getResult().getType().cast<BaseTensorType>();
  auto inType = getSelf().getType().cast<BaseTensorType>();

  if (!outType.hasSizes() || !inType.hasSizes()) {
    return success();
  }

  auto outShape = outType.getSizes();
  auto inShape = inType.getSizes();

  auto outRank = outShape.size();

  if (outRank != inShape.size()) {
    return emitOpError(
               "expected input and output tensors to have same rank, but ")
           << inShape.size() << " != " << outRank << '.';
  }

  if (outRank != permutation.size()) {
    return emitOpError() << "expected permutation to have size equal result "
                            "tensor rank. The permutation has "
                         << permutation.size()
                         << " elements, the output has rank " << outRank << '.';
  }

  // Initialization of the reverse permutation. -1 denotes an unknown
  // permutation index.
  SmallVector<int64_t> reversePermutation(outRank, -1);

  // In this loop:
  //  (1) check that the permutation indices are in bounds, and not duplicated.
  //  (2) populate reversePermutation (to check for duplicates).
  //  (3) check that the input and output shapes agree with the permutation. For
  //  example, if the permutation is (1,2,0) and the input shape is (2,3,5),
  //  then the output shape must be (3,5,2).

  for (uint64_t to = 0; to < outRank; ++to) {
    int64_t from;

    auto fromIsSet = matchPattern(permutation[to], m_TorchConstantInt(&from));

    if (!fromIsSet) {
      continue;
    }

    // If 'from' is the unknown index, continue.
    if (from == -1) {
      continue;
    }

    if (!isValidDim(from, outRank)) {
      return emitError("observed invalid index in permutation (")
             << from << ") for input tensor of rank " << outRank << '.';
    }

    if (reversePermutation[from] != -1) {
      return emitOpError("has a duplicate dimension (")
             << from << ") in its permutation " << getDims() << '.';
    }
    reversePermutation[from] = to;

    auto dimSizesDefined =
        inShape[from] != kUnknownSize && outShape[to] != kUnknownSize;
    auto dimSizesDifferent = inShape[from] != outShape[to];

    if (dimSizesDefined && dimSizesDifferent) {
      return emitOpError("has a permutation which is not compatible with the "
                         "input and output shapes. ")
             << "The input shape in dimension " << from << " is "
             << inShape[from] << ", and the output shape in dimension " << to
             << " is " << outShape[to]
             << " : they should be the same with this permutation. ";
    }
  }

  return success();
}

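The three checks performed by the loop in AtenPermuteOp::verify (bounds, duplicates, shape agreement) can be sketched outside of MLIR. This is a minimal standalone illustration, not the torch-mlir implementation: `checkPermute` and `kUnknown` are hypothetical names, shapes and the permutation are plain vectors of known integers, and, as a simplification, negative indices are rejected rather than handled the way `isValidDim` or the `-1` "unknown index" case are in the verifier.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for torch-mlir's kUnknownSize (assumption).
constexpr int64_t kUnknown = -1;

// Returns std::nullopt when `perm` is consistent with both shapes,
// otherwise a short description of the first problem found.
std::optional<std::string> checkPermute(const std::vector<int64_t> &inShape,
                                        const std::vector<int64_t> &outShape,
                                        const std::vector<int64_t> &perm) {
  const std::size_t rank = outShape.size();
  if (inShape.size() != rank)
    return "rank mismatch between input and output";
  if (perm.size() != rank)
    return "permutation size differs from rank";

  // seenAt[from] holds the output dimension that already uses input
  // dimension `from`, or -1 if it has not been used yet.
  std::vector<int64_t> seenAt(rank, -1);
  for (std::size_t to = 0; to < rank; ++to) {
    const int64_t from = perm[to];
    // (1) Bounds check (simplified: non-negative indices only).
    if (from < 0 || from >= static_cast<int64_t>(rank))
      return "index out of bounds: " + std::to_string(from);
    // (2) Duplicate check via the reverse mapping.
    if (seenAt[from] != -1)
      return "duplicate index: " + std::to_string(from);
    seenAt[from] = static_cast<int64_t>(to);
    // (3) Shape agreement, skipped when either size is unknown.
    const bool bothKnown =
        inShape[from] != kUnknown && outShape[to] != kUnknown;
    if (bothKnown && inShape[from] != outShape[to])
      return "shape mismatch at output dimension " + std::to_string(to);
  }
  return std::nullopt;
}
```

With the example from the comment above, `checkPermute({2, 3, 5}, {3, 5, 2}, {1, 2, 0})` reports no problem, while swapping the output shape to `{5, 3, 2}` yields a shape mismatch at output dimension 0.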
2022-12-14 00:25:41 +08:00
//===----------------------------------------------------------------------===//
// DtypeCalculateYieldDtypesOp
//===----------------------------------------------------------------------===//

MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
2023-09-13 06:09:57 +08:00
    RegionBranchPoint point) {
2022-12-14 00:25:41 +08:00
  // The dtype operands don't get forwarded to the body.
  // MutableOperandRange always has an owning operation, even if empty, so
  // create a 0-length range.
  return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
}

LogicalResult DtypeCalculateYieldDtypesOp::verify() {
  auto parent = cast<DtypeCalculateOp>(getOperation()->getParentOp());
  if (parent.getNumResults() != getNumOperands())
    return emitOpError("expected number of dtypes to match number of results");
  return success();
}

Rework how global slot initializers work.
Rather than a per-global-slot initializer region, we now have one for
the whole module. For example, it might look like this:
```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
torch.initialize.global_slots [
@tensor(%0 : !torch.tensor)
@list(%1 : !torch.list<tensor>)
]
}
```
This new structure allows GlobalizeObjectGraph to create the initializer in a
much simpler way, avoiding the need to reason about whether different slots
alias each other. Reasoning about whether slots alias each other now is the
responsibility of InlineGlobalSlots, which has to do a much more complicated
analysis, implemented using MLIR's dataflow analysis framework.
Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to GlobalizeObjectGraph pass.
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
- Moving torchMlirAdjustStaticInformation for sharing with C++ code.
- EraseModuleInitializer pass
To make this a bit nicer, it would be good to have a `torch.module` op
with an initializer region attached. That would be more invasive though.
This change has highlighted certain aspects of our project layering
which are worth calling out. None of our backends can handle global
slots, so we enforce that there are no global slots before backend
lowering. At an earlier stage in the project, we had aspirations of
transparently handling mutable global state and such, but for reasons
described below, that is no longer a goal. So really global slots should
be seen as a progressive lowering step as part of inlining all the
IValue's in the original program (GlobalizeObjectGraph is also one such
step).
Over time, with insights from work like IREE-JAX, it has become clear
that there isn't a reliable programming model we can compile for users
where we just transparently handle mutable global state (and some other
things, like lists and dictionaries). There is a need for an "outer
program" that orchestrates more restricted subroutines of the kind we
can handle in our compile flow here. The benefit of that is that it
decouples considerations like shapes, dtypes, etc. from the program
constructs used in the outer program. As long as the outer program can
efficiently invoke (pipelining/async/etc.) high-performance
data-parallel numerical subroutines of the kind we compile in our flow
here, then there is a complete programming model. This is also
consistent with the direction of upstream PyTorch which is becoming more
tracing-based (which inherently loses a lot of program structure, which
then has to be applied back with an "outer program" orchestrating the
traced subroutines).
2022-07-14 02:45:56 +08:00
//===----------------------------------------------------------------------===//
// GlobalSlotModuleInitializerOp
//===----------------------------------------------------------------------===//

LogicalResult GlobalSlotModuleInitializerOp::verify() {
  // We centralize all verification of the global slots and the
  // InitializeGlobalSlotsOp into here, since it requires processing the whole
  // module.

  // TODO: We should really have a `torch.module` and have this initializer be
  // a region attached to it.

  ModuleOp module = cast<ModuleOp>(getOperation()->getParentOp());
  for (auto op : module.getOps<GlobalSlotModuleInitializerOp>()) {
    if (op.getOperation() != getOperation())
      return op.emitError("there must be only one global slot initializer");
  }

  // Collect the relevant symbol names we will verify.
  DenseSet</*StringAttr*/ Attribute> knownGlobalSlots;
  for (auto op : module.getOps<GlobalSlotOp>())
2022-12-08 04:20:41 +08:00
    knownGlobalSlots.insert(op.getSymNameAttr());
2022-07-14 02:45:56 +08:00
  DenseSet</*StringAttr*/ Attribute> initializedGlobalSlots;
  auto initialize = cast<InitializeGlobalSlotsOp>(getBody()->getTerminator());
2022-12-08 04:20:41 +08:00
  for (Attribute symName : initialize.getSlotSymNames()) {
2022-07-14 02:45:56 +08:00
    auto wasInserted = initializedGlobalSlots
                           .insert(symName.cast<FlatSymbolRefAttr>().getAttr())
                           .second;
    if (!wasInserted)
      return initialize.emitError("duplicate initialization of global slot: ")
             << symName;
  }
  auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) {
    return lhs.cast<StringAttr>().getValue() <
           rhs.cast<StringAttr>().getValue();
  };
  auto known = llvm::to_vector(knownGlobalSlots);
  llvm::sort(known, lessThanByStringValue);
  auto initialized = llvm::to_vector(initializedGlobalSlots);
  llvm::sort(initialized, lessThanByStringValue);

  // Check that the global slots in the module are all initialized.
  SymbolTable symbolTable(module);
  if (initializedGlobalSlots != knownGlobalSlots) {
    InFlightDiagnostic diag = initialize.emitOpError(
        "must have one initializer for each global slot in the module");
    for (auto knownGlobalSlot : known) {
      auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast<StringAttr>());
      if (!initializedGlobalSlots.count(knownGlobalSlot)) {
        diag.attachNote(
                symbolTable.lookup<GlobalSlotOp>(symName.getAttr()).getLoc())
            .append("missing global slot initializer for ", symName);
      }
    }
    for (auto initializedGlobalSlot : initialized) {
      if (!knownGlobalSlots.count(initializedGlobalSlot)) {
        diag.attachNote().append(
            "unexpected global slot initializer for non-existent global slot ",
            FlatSymbolRefAttr::get(initializedGlobalSlot.cast<StringAttr>()));
      }
    }
    return diag;
  }

  // Check that initial values satisfy type bounds.
  for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) {
2022-12-08 04:20:41 +08:00
    auto symName = initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
2022-07-14 02:45:56 +08:00
    auto initialValue = initialize.getOperand(i);
    auto globalSlotOp = symbolTable.lookup<GlobalSlotOp>(symName.getValue());
2022-12-08 04:20:41 +08:00
    if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) {
2022-07-14 02:45:56 +08:00
      return initialize.emitOpError().append(
          "initial value for global slot ", symName, " has type ",
          initialValue.getType(), " which is not within the bound ",
2022-12-08 04:20:41 +08:00
          globalSlotOp.getTypeBound());
2022-07-14 02:45:56 +08:00
    }
  }

  auto walkResult = getOperation()->walk([](Operation *op) {
    // We only permit a small set of ops in the module initializer.
    // These ops are essentially those which can be produced by the IValue
    // importer.
2022-09-20 05:56:35 +08:00
    if (op->hasTrait<mlir::torch::Torch::OpTrait::AllowedInModuleInitializer>())
2022-07-14 02:45:56 +08:00
      return WalkResult::advance();
    op->emitOpError() << "is not allowed in a module initializer";
    return WalkResult::interrupt();
  });
  if (walkResult.wasInterrupted())
    return failure();

  return success();
}

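The core of GlobalSlotModuleInitializerOp::verify is a comparison between the set of declared global slots and the set of initialized ones, with diagnostics emitted in sorted order so they are deterministic. The sketch below shows that comparison on plain strings, under stated assumptions: `SlotCheck` and `checkSlots` are hypothetical names, slots are identified by `std::string` rather than `StringAttr`, and duplicates are collected instead of causing an immediate error as they do in the verifier.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical result type for this sketch (not part of torch-mlir).
struct SlotCheck {
  std::vector<std::string> missing;    // declared but never initialized
  std::vector<std::string> unexpected; // initialized but never declared
  std::vector<std::string> duplicated; // initialized more than once
  bool ok() const {
    return missing.empty() && unexpected.empty() && duplicated.empty();
  }
};

SlotCheck checkSlots(const std::vector<std::string> &declared,
                     const std::vector<std::string> &initialized) {
  SlotCheck result;
  std::set<std::string> known(declared.begin(), declared.end());
  std::set<std::string> seen;
  // A failed insert means the slot name was already initialized once.
  for (const std::string &name : initialized) {
    if (!seen.insert(name).second)
      result.duplicated.push_back(name);
  }
  // std::set iterates in sorted order, so `missing` and `unexpected`
  // come out sorted by name without an explicit sort step.
  for (const std::string &name : known) {
    if (!seen.count(name))
      result.missing.push_back(name);
  }
  for (const std::string &name : seen) {
    if (!known.count(name))
      result.unexpected.push_back(name);
  }
  return result;
}
```

The verifier achieves the same ordering differently: it keeps unordered `DenseSet`s for membership tests and sorts copies of them with `llvm::sort` only for reporting.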
//===----------------------------------------------------------------------===//
// InitializeGlobalSlotsOp
//===----------------------------------------------------------------------===//

ParseResult InitializeGlobalSlotsOp::parse(OpAsmParser &parser,
                                           OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseLSquare())
    return failure();
  SmallVector<Attribute> slotSymNames;
  while (!succeeded(parser.parseOptionalRSquare())) {
    NamedAttrList dummy;
    StringAttr slotSymName;
    if (parser.parseSymbolName(slotSymName, "dummy", dummy))
      return failure();
    slotSymNames.push_back(FlatSymbolRefAttr::get(slotSymName));
    if (parser.parseLParen())
      return failure();
    OpAsmParser::UnresolvedOperand initialValue;
    if (parser.parseOperand(initialValue))
      return failure();
    Type initialValueType;
    if (parser.parseColonType(initialValueType))
      return failure();
    if (parser.parseRParen())
      return failure();
    if (parser.resolveOperand(initialValue, initialValueType, result.operands))
      return failure();
  }
  result.addAttribute("slotSymNames",
                      ArrayAttr::get(parser.getContext(), slotSymNames));
  return success();
}
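The grammar accepted by InitializeGlobalSlotsOp::parse is a bracketed list of `@slot(%value : type)` entries. As a rough illustration of that loop structure (try to consume `]`, otherwise parse one entry), here is a standalone text-level sketch. It does not use the MLIR `OpAsmParser` API: `Cursor`, `SlotInit`, and `parseInitializerList` are hypothetical names, the optional attribute dictionary is omitted, and types are assumed to contain no parentheses.

```cpp
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// One parsed entry of the form `@slot(%value : type)` (hypothetical type).
struct SlotInit {
  std::string slot, value, type;
};

// Minimal character cursor over the input text (hypothetical helper).
class Cursor {
public:
  explicit Cursor(std::string text) : text(std::move(text)) {}

  // Skips whitespace, then consumes `c` if it is the next character.
  bool consume(char c) {
    skipSpace();
    if (pos < text.size() && text[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }

  // Reads up to (not including) the first character found in `stops`,
  // with surrounding whitespace trimmed.
  std::string readUntil(const std::string &stops) {
    skipSpace();
    const std::size_t start = pos;
    while (pos < text.size() && stops.find(text[pos]) == std::string::npos)
      ++pos;
    std::size_t end = pos;
    while (end > start && std::isspace(static_cast<unsigned char>(text[end - 1])))
      --end;
    return text.substr(start, end - start);
  }

private:
  void skipSpace() {
    while (pos < text.size() && std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  }
  std::string text;
  std::size_t pos = 0;
};

// Returns the parsed entries, or std::nullopt on malformed input.
std::optional<std::vector<SlotInit>> parseInitializerList(const std::string &text) {
  Cursor cur(text);
  if (!cur.consume('['))
    return std::nullopt;
  std::vector<SlotInit> entries;
  // Same shape as the op parser: loop until the closing bracket is seen.
  while (!cur.consume(']')) {
    SlotInit entry;
    if (!cur.consume('@'))
      return std::nullopt;
    entry.slot = cur.readUntil("(");
    if (entry.slot.empty() || !cur.consume('('))
      return std::nullopt;
    entry.value = cur.readUntil(":");
    if (entry.value.empty() || !cur.consume(':'))
      return std::nullopt;
    entry.type = cur.readUntil(")");
    if (entry.type.empty() || !cur.consume(')'))
      return std::nullopt;
    entries.push_back(entry);
  }
  return entries;
}
```

Applied to `[ @tensor(%0 : !torch.tensor) @list(%1 : !torch.list<tensor>) ]`, the sketch returns two entries in source order, which mirrors how the op parser keeps `slotSymNames` and the resolved operands index-aligned.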

void InitializeGlobalSlotsOp::print(OpAsmPrinter &p) {
  p.printOptionalAttrDict(getOperation()->getAttrs(),
                          /*elidedAttrs=*/{"slotSymNames"});
  p << " [";
  p.printNewline();
  for (int i = 0, e = getNumOperands(); i < e; ++i) {
2022-12-08 04:20:41 +08:00
    p << " " << getSlotSymNames()[i] << "(" << getInitialValues()[i] << " : "
      << getInitialValues()[i].getType() << ")";
Rework how global slot initializers work.
Rather than a per-global-slot initializer region, we now have one for
the whole module. For example, it might look like this:
```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
torch.initialize.global_slots [
@tensor(%0 : !torch.tensor)
@list(%1 : !torch.list<tensor>)
]
}
```
This new structure allows GlobalizeObjectGraph to create the initializer in a
much simpler way, avoiding the need to reason about whether different slots
alias each other. Reasoning about whether slots alias each other now is the
responsibility of InlineGlobalSlots, which has to do a much more complicated
analysis, implemented using MLIR's dataflow analysis framework.
Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to GlobalizeObjectGraph pass.
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
- Moving torchMlirAdjustStaticInformation for sharing with C++ code.
- EraseModuleInitializer pass
To make this a bit nicer, it would be good to have a `torch.module` op
with an initializer region attached. That would be more invasive though.
This change has highlighted certain aspects of our project layering
which are worth calling out. None of our backends can handle global
slots, so we enforce that there are no global slots before backend
lowering. At an earlier stage in the project, we had aspirations of
transparently handling mutable global state and such, but for reasons
described below, that is no longer a goal. So really global slots should
be seen as a progressive lowering step as part of inlining all the
IValues in the original program (GlobalizeObjectGraph is also one such
step).
Over time, with insights from work like IREE-JAX, it has become clear
that there isn't a reliable programming model we can compile for users
where we just transparently handle mutable global state (and some other
things, like lists and dictionaries). There is a need for an "outer
program" that orchestrates more restricted subroutines of the kind we
can handle in our compile flow here. The benefit of that is that it
decouples considerations like shapes, dtypes, etc. from the program
constructs used in the outer program. As long as the outer program can
efficiently invoke (pipelining/async/etc.) high-performance
data-parallel numerical subroutines of the kind we compile in our flow
here, then there is a complete programming model. This is also
consistent with the direction of upstream PyTorch, which is becoming more
tracing-based (which inherently loses a lot of program structure, which
then has to be applied back with an "outer program" orchestrating the
traced subroutines).
2022-07-14 02:45:56 +08:00
|
|
|
p.printNewline();
|
|
|
|
}
|
|
|
|
p << "]";
|
|
|
|
}
|
|
|
|
|
|
|
|
LogicalResult InitializeGlobalSlotsOp::verify() {
|
2022-12-08 04:20:41 +08:00
|
|
|
if (getInitialValues().size() != getSlotSymNames().size())
|
|
|
|
return emitOpError("expected number of operands to match number of slots");
|
|
|
|
return success();
|
|
|
|
}
|