//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "torch-mlir-torch-dialect"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/Debug.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
                                                  Location loc, Value value,
                                                  Type desiredType,
                                                  bool userAllowsRefinement) {
  Type type = value.getType();

  // If the value is already of the desired type, we're done.
  if (type == desiredType)
    return value;

  // If both types are tensors with the same value semantics, adjust the static
  // information with a cast.
  if ((isa<ValueTensorType>(type) && isa<ValueTensorType>(desiredType)) ||
      (isa<NonValueTensorType>(type) && isa<NonValueTensorType>(desiredType))) {
    Value adjusted = builder.create<TensorStaticInfoCastOp>(value.getLoc(),
                                                            desiredType, value);
    return adjusted;
  }

  // If the type is a subtype of desiredType, then we need to derefine it to
  // desiredType, unless the user allows refinement.
  if (isValidSubtype(type, desiredType)) {
    if (!userAllowsRefinement) {
      Value adjusted =
          builder.create<DerefineOp>(value.getLoc(), desiredType, value);
      return adjusted;
    } else {
      return value;
    }
  }

  // If desiredType is a subtype of type, then we assume that desiredType is
  // dynamically valid, so we do an unchecked cast.
  if (isValidSubtype(desiredType, type)) {
    Value adjusted =
        builder.create<PrimUncheckedCastOp>(value.getLoc(), desiredType, value);
    return adjusted;
  }

  // No known adjustment.
  return Value();
}
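The decision order in `adjustStaticInformation` can be sketched without MLIR as a pure function over a toy type model. Everything in the sketch is hypothetical: `Kind`, `Type`, the small subtype table in `isValidSubtype` and the `Adjustment` enum stand in for the real Torch types, the real `isValidSubtype`, and the ops the real function creates. Only the order of the checks is meant to mirror the code above.

```cpp
#include <cassert>
#include <string>

// Hypothetical, heavily simplified stand-in for Torch types.
enum class Kind { ValueTensor, NonValueTensor, Int, Float, Number, OptionalInt, NoneType };

struct Type {
  Kind kind;
  std::string staticInfo; // e.g. "[2,3],f32" for tensors; empty otherwise
  bool operator==(const Type &other) const {
    return kind == other.kind && staticInfo == other.staticInfo;
  }
};

// Which op the real function would create; None means "return value as is".
enum class Adjustment { None, TensorStaticInfoCast, Derefine, UncheckedCast, Unsupported };

// Toy subtype relation (an assumption): reflexive, Int/Float <: Number,
// Int/NoneType <: OptionalInt. Tensors are only subtypes of identical types.
bool isValidSubtype(const Type &sub, const Type &super) {
  if (sub == super)
    return true;
  if (super.kind == Kind::Number)
    return sub.kind == Kind::Int || sub.kind == Kind::Float;
  if (super.kind == Kind::OptionalInt)
    return sub.kind == Kind::Int || sub.kind == Kind::NoneType;
  return false;
}

Adjustment chooseAdjustment(const Type &type, const Type &desired,
                            bool userAllowsRefinement) {
  // Already the desired type: nothing to do.
  if (type == desired)
    return Adjustment::None;
  // Same tensor kind, different static information: static info cast.
  bool bothValue =
      type.kind == Kind::ValueTensor && desired.kind == Kind::ValueTensor;
  bool bothNonValue =
      type.kind == Kind::NonValueTensor && desired.kind == Kind::NonValueTensor;
  if (bothValue || bothNonValue)
    return Adjustment::TensorStaticInfoCast;
  // Widening: derefine unless the caller accepts the more refined type.
  if (isValidSubtype(type, desired))
    return userAllowsRefinement ? Adjustment::None : Adjustment::Derefine;
  // Narrowing: trust that it holds at runtime and cast without a check.
  if (isValidSubtype(desired, type))
    return Adjustment::UncheckedCast;
  // The real function signals this case by returning a null Value.
  return Adjustment::Unsupported;
}
```

`Unsupported` corresponds to the null `Value()` that the real function returns, so callers are expected to check for it.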

Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
                                           BaseTensorType newType,
                                           Value tensor) {
  auto originalType = cast<BaseTensorType>(tensor.getType());
  // Adjust the static information in the type to match between the original and
  // new types.
  if (!originalType.hasSameSizesAndDtype(newType)) {
    tensor = builder.create<TensorStaticInfoCastOp>(
        loc, originalType.getWithSizesAndDtypeFrom(newType), tensor);
  }

  // Unless the original and new types are both value tensors, we end up
  // creating one op that converts between the value and non-value tensor
  // domains. If the original and new types are both non-value tensors, then
  // we do the copy by going to a value tensor and back.
  if (isa<NonValueTensorType>(tensor.getType()))
    tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
  if (isa<NonValueTensorType>(newType))
    tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);

  return tensor;
}
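The ops that `copyTensorToType` emits depend only on the value semantics of the two types and on whether their sizes and dtype already agree. The sketch below models that with a hypothetical `TensorTy` struct and a hypothetical `copyOpsFor` function that returns the op class names in creation order; it does not call into MLIR.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of a Torch tensor type.
struct TensorTy {
  bool valueSemantics;       // true for !torch.vtensor, false for !torch.tensor
  std::string sizesAndDtype; // e.g. "[2],f32"
};

// Returns the op classes copyTensorToType would create, in order.
std::vector<std::string> copyOpsFor(const TensorTy &original,
                                    const TensorTy &newType) {
  std::vector<std::string> ops;
  // Step 1: fix up sizes/dtype while keeping the original value semantics.
  if (original.sizesAndDtype != newType.sizesAndDtype)
    ops.push_back("TensorStaticInfoCastOp");
  // Step 2: a non-value source is first copied into the value domain.
  if (!original.valueSemantics)
    ops.push_back("CopyToValueTensorOp");
  // Step 3: a non-value destination gets a fresh non-value copy.
  if (!newType.valueSemantics)
    ops.push_back("CopyToNonValueTensorOp");
  return ops;
}
```

A value-to-value copy with matching static information creates no ops at all, because value tensors are immutable; a non-value-to-non-value copy always round-trips through the value domain.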

bool mlir::torch::Torch::isListPotentiallyMutated(Value list) {
  assert(isa<Torch::ListType>(list.getType()));
  return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands);
}

bool mlir::torch::Torch::potentiallyMutatesListOperands(Operation *op) {
  // TODO: Find a better place to put this assertion.
  assert((!op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
          op->hasTrait<OpTrait::ReadOnly>()) &&
         "HasValueSemantics should imply ReadOnly!");
  // ReadOnly ops trivially do not mutate any list operands.
  if (op->hasTrait<Torch::OpTrait::ReadOnly>())
    return false;

  // Ops that implement MemoryEffectOpInterface and report no effects also do
  // not mutate any list operands.
  if (auto effects = dyn_cast<MemoryEffectOpInterface>(op)) {
    if (effects.hasNoEffect())
      return false;
  }

  // Conservatively assume that an op might mutate any list operands.
  return true;
}
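The mutation query is a conservative `any_of` over the list's users: an op counts as safe only if it is `ReadOnly`, or if it implements `MemoryEffectOpInterface` and reports no effects. Below is a standalone sketch in which a hypothetical `OpInfo` struct stands in for the trait and interface checks on `Operation`.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical summary of the properties the real code queries on an op.
struct OpInfo {
  bool hasValueSemantics = false;       // Torch::OpTrait::HasValueSemantics
  bool readOnly = false;                // Torch::OpTrait::ReadOnly
  bool implementsMemoryEffects = false; // isa<MemoryEffectOpInterface>
  bool hasNoEffect = false;             // MemoryEffectOpInterface::hasNoEffect
};

bool potentiallyMutatesListOperands(const OpInfo &op) {
  assert((!op.hasValueSemantics || op.readOnly) &&
         "HasValueSemantics should imply ReadOnly!");
  if (op.readOnly)
    return false;
  if (op.implementsMemoryEffects && op.hasNoEffect)
    return false;
  // Anything else is conservatively treated as a potential mutator.
  return true;
}

// `users` models the users of the list value.
bool isListPotentiallyMutated(const std::vector<OpInfo> &users) {
  return std::any_of(users.begin(), users.end(), potentiallyMutatesListOperands);
}
```

An op that passes neither check (a default-constructed `OpInfo`) is treated as a potential mutator, which errs on the side of treating the list as mutated.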

static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
  return IntegerAttr::get(IntegerType::get(context, 64), value);
}

static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
  return FloatAttr::get(Float64Type::get(context), value);
}

static DenseElementsAttr reshapeDenseElementsAttr(DenseElementsAttr attr,
                                                  ShapedType newType) {
  // TODO: DenseElementsAttr::reshape is broken for bool splats.
  // Once that ticket is fixed, we can remove this conditional.
  if (attr.isSplat() && newType.getElementType().isInteger(/*width=*/1)) {
    auto splatValue = attr.getValues<bool>()[0];
    return DenseElementsAttr::get(newType, {splatValue});
  }
  return attr.reshape(newType);
}
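The bool-splat workaround relies on the fact that a splat attribute stores a single value whatever its shape, so it can be rebuilt for the new shape from that value instead of going through `reshape`. Here is a standalone sketch of that idea using a hypothetical `BoolDenseAttr` type (not the MLIR class) and a hypothetical `reshapeBoolAttr` function. Unlike the code above, it also checks that the element count is preserved, as a reshape should.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a dense i1 (bool) elements attribute.
struct BoolDenseAttr {
  std::vector<int64_t> shape;
  bool splat;               // true: one value broadcast to every element
  std::vector<bool> values; // exactly one entry when splat
};

int64_t numElements(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

BoolDenseAttr reshapeBoolAttr(const BoolDenseAttr &attr,
                              const std::vector<int64_t> &newShape) {
  if (numElements(attr.shape) != numElements(newShape))
    throw std::invalid_argument("reshape must preserve the element count");
  // Splat: rebuild from the single stored value for the new shape.
  if (attr.splat)
    return BoolDenseAttr{newShape, true, {attr.values[0]}};
  // Non-splat: same elements in the same (row-major) order, new shape.
  return BoolDenseAttr{newShape, false, attr.values};
}
```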

static Value getScalarIntValue(Value input, Location loc,
                               PatternRewriter &rewriter) {
  auto inputType = input.getType();
  if (isa<Torch::IntType>(inputType)) {
    return input;
  }

  auto inputTensorType = dyn_cast<BaseTensorType>(inputType);
  if (!inputTensorType)
    return nullptr;

  Type inputDtype = inputTensorType.getOptionalDtype();
  if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1)))
    return nullptr;

  std::optional<unsigned> inputRank = getTensorRank(input);
  if (!inputRank || *inputRank != 0)
    return nullptr;

  if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
    if (inputDtype.isInteger(64)) {
      auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
                     .getSplatValue<int64_t>();
      return rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(val));
    } else {
      auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
                     .getSplatValue<bool>();
      return rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(val));
    }
  } else if (auto primNumToTensorScalarOp =
                 input.getDefiningOp<PrimNumToTensorScalarOp>()) {
    return primNumToTensorScalarOp.getA();
  } else if (auto tensorIntOp = input.getDefiningOp<AtenTensorIntOp>()) {
    return tensorIntOp.getT();
  }
  return nullptr;
}
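The order of checks in `getScalarIntValue` (known integer dtype, then rank 0, then a recognizable producer) can be summarized with a small plain-C++ model. This is a sketch under simplifying assumptions: `Dtype`, `TensorInfo` and `getScalarIntValueModel` are hypothetical, and only the "defined by a splat literal" producer is modeled; the real helper also forwards the scalar operand of `PrimNumToTensorScalarOp` and `AtenTensorIntOp`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical, simplified element types.
enum class Dtype { I1, I64, F32 };

// Hypothetical summary of what is statically known about a tensor value.
struct TensorInfo {
  std::optional<Dtype> dtype;   // nullopt: dtype not known statically
  std::optional<unsigned> rank; // nullopt: rank not known statically
  std::optional<int64_t> splat; // set when defined by a splat literal
};

static std::optional<int64_t> getScalarIntValueModel(const TensorInfo &t) {
  // Only si64 and i1 (bool) element types carry an integer scalar.
  if (!t.dtype || (*t.dtype != Dtype::I64 && *t.dtype != Dtype::I1))
    return std::nullopt;
  // Only a rank-0 tensor holds exactly one scalar.
  if (!t.rank || *t.rank != 0)
    return std::nullopt;
  // No recognizable producer: nothing can be extracted.
  if (!t.splat)
    return std::nullopt;
  // A bool literal is widened to the integer 0 or 1.
  if (*t.dtype == Dtype::I1)
    return *t.splat != 0 ? 1 : 0;
  return t.splat;
}
```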

static Value getScalarFloatValue(Value input, Location loc,
                                 PatternRewriter &rewriter) {
  auto inputType = input.getType();
  if (isa<Torch::FloatType>(inputType)) {
    return input;
  }

  auto inputTensorType = dyn_cast<BaseTensorType>(inputType);
  if (!inputTensorType)
    return nullptr;

  Type inputDtype = inputTensorType.getOptionalDtype();
  if (!inputDtype ||
      (!inputDtype.isF16() && !inputDtype.isF32() && !inputDtype.isF64()))
    return nullptr;

  std::optional<unsigned> inputRank = getTensorRank(input);
  if (!inputRank || *inputRank != 0)
    return nullptr;

  if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
    auto val = cast<DenseFPElementsAttr>(valueTensorLiteralOp.getValue())
                   .getSplatValue<FloatAttr>()
                   .getValueAsDouble();
    return rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(val));
  } else if (auto primNumToTensorScalarOp =
                 input.getDefiningOp<PrimNumToTensorScalarOp>()) {
    return primNumToTensorScalarOp.getA();
  } else if (auto tensorFloatOp = input.getDefiningOp<AtenTensorFloatOp>()) {
    return tensorFloatOp.getT();
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// MethodOp
//===----------------------------------------------------------------------===//

LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto func = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
      *this, getFunctionAttr());
  if (!func)
    return emitError() << "'@" << getFunction()
                       << "' does not reference a valid function";
  if (func.getVisibility() != SymbolTable::Visibility::Private)
    return emitError() << "'@" << getFunction()
                       << "' must reference a private function";
  if (func.isDeclaration())
    return emitError() << "'@" << getFunction()
                       << "' must reference a function that is defined (not "
                          "merely declared)";
  auto expectedReceiverArgType = NnModuleType::get(
      getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
  if (func.getFunctionType().getNumInputs() == 0 ||
      func.getFunctionType().getInput(0) != expectedReceiverArgType) {
    return emitError() << "the referenced function '" << getFunction()
                       << "' must have a first argument of type "
                       << expectedReceiverArgType;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// NnModuleOp
//===----------------------------------------------------------------------===//

LogicalResult NnModuleOp::verify() {
  for (Operation &child : *getBody())
    if (!isa<SlotOp, NnModuleTerminatorOp>(&child))
      return child.emitOpError() << "is not allowed inside 'torch.nn_module'";
  return success();
}

LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto classType = symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(
      *this, SymbolRefAttr::get(getContext(), getClassName()));
  if (!classType)
    return emitError() << "'" << getClassName()
                       << "' does not reference a valid class type";

  auto attrs = llvm::to_vector<6>(getBody()->getOps<SlotOp>());
  auto attrDefs = llvm::to_vector<6>(classType.getBody()->getOps<AttrOp>());
  if (attrs.size() != attrDefs.size())
    return emitError() << "number of 'torch.slot's in a 'torch.nn_module' must "
                          "match number of 'torch.attr's in "
                          "the corresponding 'torch.class_type'";
  for (int i = 0, e = attrs.size(); i != e; i++) {
    SlotOp attr = attrs[i];
    AttrOp attrDef = attrDefs[i];
    if (!isValidSubtype(attr.getValue().getType(), attrDef.getType()) ||
        attr.getName() != attrDef.getName()) {
      return attr.emitOpError()
          .append("is expected to match type and name of '",
                  attrDef.getOperation(), "'")
          .attachNote(attrDef.getLoc())
          .append("see torch.attr at corresponding index ", i, " here");
    }
  }

  return success();
}
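The verifier matches `torch.slot`s to `torch.attr`s by position, so the same set of names in a different order is rejected. A minimal plain-C++ sketch of that positional check follows; `NamedTyped`, `isSubtypeModel`, `SlotCheck` and `checkSlots` are hypothetical, and types are reduced to strings where a type is assumed to be a subtype only of itself (a stand-in for `isValidSubtype`).

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a torch.slot or torch.attr: a name and a type.
struct NamedTyped {
  std::string name;
  std::string type;
};

// Simplified stand-in for isValidSubtype: a type is a subtype only of itself.
static bool isSubtypeModel(const std::string &sub, const std::string &super) {
  return sub == super;
}

enum class SlotCheck { Ok, CountMismatch, EntryMismatch };

// Counts must agree, and slot i must match attr i in name and (sub)type.
// On EntryMismatch, badIndex receives the first offending position.
static SlotCheck checkSlots(const std::vector<NamedTyped> &slots,
                            const std::vector<NamedTyped> &attrs,
                            std::size_t &badIndex) {
  if (slots.size() != attrs.size())
    return SlotCheck::CountMismatch;
  for (std::size_t i = 0; i < slots.size(); ++i) {
    if (!isSubtypeModel(slots[i].type, attrs[i].type) ||
        slots[i].name != attrs[i].name) {
      badIndex = i;
      return SlotCheck::EntryMismatch;
    }
  }
  return SlotCheck::Ok;
}
```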

//===----------------------------------------------------------------------===//
// PrimListConstructOp
//===----------------------------------------------------------------------===//

LogicalResult PrimListConstructOp::verify() {
  auto resultType = getResult().getType();
  auto resultElementType = cast<ListType>(resultType).getContainedType();
  auto matchResultElementType = [&](Type type) {
    return isValidSubtype(type, resultElementType);
  };
  if (!llvm::all_of(getOperandTypes(), matchResultElementType)) {
    return emitError() << "operand types should have the same type as the "
                          "list contained type";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PrimDictConstructOp
//===----------------------------------------------------------------------===//

LogicalResult PrimDictConstructOp::verify() {
  auto isValidSubTypeOf = [](Type expectedType) {
    return [=](Type type) { return isValidSubtype(type, expectedType); };
  };

  if (!llvm::all_of(getKeys().getTypes(), isValidSubTypeOf(getKeyType())))
    return emitError() << "keys should be of Dict key type";

  if (!llvm::all_of(getValues().getTypes(), isValidSubTypeOf(getValueType())))
    return emitError() << "values should be of Dict value type";

  return success();
}

//===----------------------------------------------------------------------===//
// ClassTypeOp
//===----------------------------------------------------------------------===//

LogicalResult ClassTypeOp::verify() {
  llvm::StringMap<Operation *> namesToOps;
  for (Operation &child : getBody()->without_terminator()) {
    if (!isa<AttrOp, MethodOp>(&child))
      return child.emitOpError() << "is not allowed inside `torch.class_type`";
    StringRef name;
    if (auto attr = dyn_cast<AttrOp>(child))
      name = attr.getName();
    else
      name = cast<MethodOp>(child).getName();
    auto itAndWasInserted = namesToOps.insert({name, &child});
    auto it = itAndWasInserted.first;
    bool wasInserted = itAndWasInserted.second;
    if (!wasInserted) {
      auto diag = emitOpError().append("has duplicate attr/method with name '",
                                       name, "'");
      diag.attachNote(it->second->getLoc())
          .append("see first conflicting attr/method here");
      diag.attachNote(child.getLoc())
          .append("see second conflicting attr/method here");
      return failure();
    }
  }

  return success();
}
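The duplicate check relies on map insertion reporting whether the key was new: when it was not, the returned iterator points at the first member with that name, which is what the "first conflicting" note uses. The same idiom with the standard library looks like the sketch below; `findDuplicateName` is a hypothetical helper, and `std::unordered_map` stands in for `llvm::StringMap`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Returns (index of first holder, index of duplicate) for the first repeated
// name, or std::nullopt if all names are unique.
static std::optional<std::pair<std::size_t, std::size_t>>
findDuplicateName(const std::vector<std::string> &names) {
  std::unordered_map<std::string, std::size_t> namesToIndex;
  for (std::size_t i = 0; i < names.size(); ++i) {
    // insert() leaves the map unchanged when the key already exists and
    // reports that through the bool; the iterator then names the first holder.
    auto [it, wasInserted] = namesToIndex.insert({names[i], i});
    if (!wasInserted)
      return std::make_pair(it->second, i);
  }
  return std::nullopt;
}
```

A single insert both tests and records the name, so each member is looked up only once.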

//===----------------------------------------------------------------------===//
// PrimLoopOp
//===----------------------------------------------------------------------===//

OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getRegion());
  return getIterArgsInit();
}

void PrimLoopOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  Region &region = getRegion();
  if (!point.getRegionOrNull()) {
    regions.emplace_back(&region, region.getArguments().slice(1));
    return;
  }
  assert(point == region);
  regions.emplace_back(&region, region.getArguments().slice(1));
  regions.emplace_back(getResults());
}

bool PrimLoopOp::isForLike() {
  bool b;
  return matchPattern(getInitialCondition(), m_TorchConstantBool(&b)) && b;
}

//===----------------------------------------------------------------------===//
// PrimLoopConditionOp
//===----------------------------------------------------------------------===//

MutableOperandRange
PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
  // Pass all operands except the condition to the successor which is the
  // parent loop op.
  return getIterArgsMutable();
}

//===----------------------------------------------------------------------===//
// PrimIfOp
//===----------------------------------------------------------------------===//

ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  Type boolType = builder.getType<Torch::BoolType>();
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, boolType, result.operands))
    return failure();
  // Parse results type list.
  if (parser.parseArrowTypeList(result.types))
    return failure();
  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Parse the 'else' region.
  if (parser.parseKeyword("else"))
    return failure();
  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void PrimIfOp::print(OpAsmPrinter &p) {
  p << " " << getCondition();
  p << " -> (" << getResultTypes() << ") ";
  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false);
  p << " else ";
  p.printRegion(getElseRegion(), /*printEntryBlockArgs=*/false);

  p.printOptionalAttrDict((*this)->getAttrs());
}

void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
  // The `then` and the `else` region branch back to the parent operation.
  if (point.getRegionOrNull()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // If the condition is constant, we can give a more precise answer.
  bool condition;
  if (matchPattern(getCondition(), m_TorchConstantBool(&condition))) {
    Region *executedRegion = condition ? &getThenRegion() : &getElseRegion();
    regions.push_back(RegionSuccessor(executedRegion));
    return;
  }

  // If the condition isn't constant, both regions may be executed.
  regions.push_back(RegionSuccessor(&getThenRegion()));
  regions.push_back(RegionSuccessor(&getElseRegion()));
  return;
}
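The three cases handled by `PrimIfOp::getSuccessorRegions` can be read as a small decision table. The sketch below restates it in plain C++ with hypothetical names (`Successor`, `ifSuccessors`). It is not the `RegionBranchOpInterface` API. It only shows which successors are reported for each combination of branch point and condition knowledge.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical labels for the possible successors of a prim.If.
enum class Successor { ThenRegion, ElseRegion, ParentResults };

// fromRegion: true when asking about a branch out of the then/else region,
// false when asking about entry from the parent op.
// constantCondition: the condition's value if it is a known constant.
static std::vector<Successor>
ifSuccessors(bool fromRegion, std::optional<bool> constantCondition) {
  // Both regions branch back to the parent op's results.
  if (fromRegion)
    return {Successor::ParentResults};
  // A constant condition means only one region can ever execute.
  if (constantCondition)
    return {*constantCondition ? Successor::ThenRegion : Successor::ElseRegion};
  // Otherwise either region may execute.
  return {Successor::ThenRegion, Successor::ElseRegion};
}
```

Reporting a single region for a constant condition lets dataflow analyses skip the dead branch entirely.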

/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block *block = &region.front();
  Operation *terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.inlineBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
}

void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If the condition is constant, delete the dead branch and inline the live
  // branch.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    auto constantBool =
        op.getCondition().getDefiningOp<Torch::ConstantBoolOp>();
    if (!constantBool)
      return rewriter.notifyMatchFailure(op, "non-constant condition");
    replaceOpWithRegion(rewriter, op,
                        constantBool.getValue() ? op.getThenRegion()
                                                : op.getElseRegion());
    return success();
  });
  // If the thenRegion and elseRegion yield the same Values, then use those
  // directly.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    auto trueTerminator = op.getThenRegion().front().getTerminator();
    auto falseTerminator = op.getElseRegion().front().getTerminator();
    bool madeChange = false;
    for (auto t : llvm::zip(trueTerminator->getOperands(),
                            falseTerminator->getOperands(), op->getResults())) {
      auto trueVal = std::get<0>(t);
      auto falseVal = std::get<1>(t);
      auto resultToBeReplaced = std::get<2>(t);
      if (trueVal == falseVal) {
        madeChange |= !resultToBeReplaced.use_empty();
        resultToBeReplaced.replaceAllUsesWith(trueVal);
      }
    }
    // Erasing the results that are now dead is left to the separate
    // dead-result pattern below. That transformation is independently useful,
    // and also tricky to implement because it changes the number of results.
    return success(madeChange);
  });
  // Erase any dead results.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    llvm::BitVector resultsToErase(op.getNumResults());
    for (auto result : llvm::enumerate(op->getResults())) {
      if (result.value().use_empty())
        resultsToErase.set(result.index());
    }

    // If no results have uses and there are no side effects, just erase the op.
    // Approximate the body having no side effects by checking if it is just a
    // terminator.
    // Note: We don't want to make this logic too fancy, because in general,
    // checking for recursive side effects can result in a quadratic amount of
    // work (N nested ifs each resulting in O(N) work). It should probably be
    // split into its own pattern if we want to make it fancier.
    if (resultsToErase.all() &&
        llvm::hasSingleElement(op.getThenRegion().front()) &&
        llvm::hasSingleElement(op.getElseRegion().front())) {
      rewriter.eraseOp(op);
      return success();
    }

    // If there are no results to erase, we're done.
    if (!resultsToErase.any())
      return failure();

    SmallVector<Type> newResultTypes;
    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
      if (resultsToErase[i])
        continue;
      newResultTypes.push_back(op->getResult(i).getType());
    }
    auto newIf = rewriter.create<PrimIfOp>(op->getLoc(), newResultTypes,
                                           op.getCondition());
    rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(),
                                newIf.getThenRegion().end());
    rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(),
                                newIf.getElseRegion().end());
    newIf.getThenRegion().front().getTerminator()->eraseOperands(
        resultsToErase);
    newIf.getElseRegion().front().getTerminator()->eraseOperands(
        resultsToErase);
    SmallVector<Value> replacementValues;
    for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) {
      if (resultsToErase[i])
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(newIf->getResult(nextNewValue++));
    }
    rewriter.replaceOp(op, replacementValues);

    return success();
  });
}
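The last loop of the dead-result pattern maps each old result index either to a result of the rebuilt `prim.If` or to a null replacement. The sketch below isolates that bookkeeping in plain C++. `remapResults` is a hypothetical helper, and `std::nullopt` plays the role of the `nullptr` replacement used for erased results.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// For every old result index, returns the index of the corresponding result
// in the rebuilt op, or std::nullopt if that result is erased. Surviving
// results keep their relative order, numbered by a running counter.
static std::vector<std::optional<unsigned>>
remapResults(const std::vector<bool> &resultsToErase) {
  std::vector<std::optional<unsigned>> newIndexOf;
  unsigned nextNewValue = 0;
  for (bool erase : resultsToErase) {
    if (erase)
      newIndexOf.push_back(std::nullopt);
    else
      newIndexOf.push_back(nextNewValue++);
  }
  return newIndexOf;
}
```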

//===----------------------------------------------------------------------===//
// AtenDotOp
//===----------------------------------------------------------------------===//

void AtenDotOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add(+[](AtenDotOp op, PatternRewriter &rewriter) {
    auto ty = dyn_cast<ValueTensorType>(op.getResult().getType());
    if (!ty || !ty.hasSizes() || !ty.hasDtype()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getResult().getType(),
                                              op.getSelf(), op.getTensor());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// RuntimeAssertOp
//===----------------------------------------------------------------------===//

void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](RuntimeAssertOp op, PatternRewriter &rewriter) {
    bool value;
    if (!matchPattern(op.getCondition(), m_TorchConstantBool(&value)))
      return failure();

    if (value) {
      rewriter.eraseOp(op);
      return success();
    }
    // Even if the condition is statically false, the assert might never be
    // executed.
    return failure();
  });
}

//===----------------------------------------------------------------------===//
// DerefineOp
//===----------------------------------------------------------------------===//

bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
                                   mlir::TypeRange outputs) {
  return isValidSubtype(inputs[0], outputs[0]);
}

OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) {
  auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>();
  if (!uncheckedCast)
    return nullptr;
  if (uncheckedCast.getOperand().getType() == getType())
    return uncheckedCast.getOperand();
  return nullptr;
}

void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) {
    bool madeChange = false;
    for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) {
      if (use.getOwner()->hasTrait<OpTrait::AllowsTypeRefinement>()) {
        use.set(op.getOperand());
        madeChange = true;
      }
    }
    return success(madeChange);
  });
}

static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
  Value lhs = op->getOperand(0);
  Value rhs = op->getOperand(1);
  // Look through DerefineOps to get more refined static information.
  if (auto derefine = lhs.getDefiningOp<DerefineOp>())
    lhs = derefine.getOperand();
  if (auto derefine = rhs.getDefiningOp<DerefineOp>())
    rhs = derefine.getOperand();
  Type lhsType = lhs.getType();
  Type rhsType = rhs.getType();

  // If either type is a NoneType, make it be the lhsType.
  if (isa<Torch::NoneType>(rhsType)) {
    std::swap(lhsType, rhsType);
    std::swap(lhs, rhs);
  }

  // For now, check a few specific cases.

  // If both types are the singleton `!torch.none` type, then we don't even need
  // to look at the values.
  if (isa<Torch::NoneType>(lhsType) && isa<Torch::NoneType>(rhsType))
    return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue);

  // If neither type is a subtype of the other, then the result is false.
  // TODO: Implement and use subtype infra for this.
  // For now, check a specific case.
  // If the rhs is not OptionalType, then we know it cannot be None.
  if (isa<Torch::NoneType>(lhsType) && !isa<Torch::OptionalType>(rhsType)) {
    return IntegerAttr::get(IntegerType::get(op->getContext(), 1),
                            !equalIsTrue);
  }
2021-08-11 09:28:50 +08:00
|
|
|
|
Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical do do it in sane smaller
pieces.
This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).
Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
imported as `torch.somens.someunqualname.someoverloadname` (skip the
last dotted part if the overload name is empty), OR, if we don't have
such an op registered, it is imported as
`torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
- The addition of the "overload name" is a critical element here, as
the `(ns,unqual,overload)` triple is unique, which solves a lot of
problems we were having.
- This involves having separate MLIR ops for the `trailing_` and
`.out` variants and all the different overloads. This seemed
necessary, because the set of overloads is so wild and varied and
unstructured. The previous design was leaning into some underlying
structure that just isn't there -- the default situation is
the "random overload that we want to manage on the MLIR side",
rather than that being an exception. E.g. `aten::ne` (not-equal)
has 21 overloads, only 4 of which are c10 dispatcher ops see
[gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1),
and the "out" variant is really called `.Tensor_out` instead of
`.out` as it frequently is for other ops.
- Rationale for all being in `torch` namespace: the set of operators
are so varied and unstructured that "dialect per namespace"
doesn't result in anything resembling the typical MLIR dialect
boundary expectations. We could maybe draw the boundary at
dispatcher ops vs non-dispatcher ops, but that doesn't seem to
really result in very much useful structure at this point in time.
- Note: within the torch operator registry, we effectively have a
mini-basicpy subdialect (already type-resolved), which is reasonably
structured.
- The existing Torch op interfaces are also removed -- now that we
track the overload name, we can losslessly find the original
operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
`ReduceOpVariantsPass` that keys off certain traits (and perhaps
eventually interfaces) to reduce variants of ops to a smaller set,
ideally operating on immutable tensors and using surrounding ops to
model the mutability/aliasing aspects.
- Note: `torch.ns.unqual.overload` ops allow both immutable and
mutable tensors (unlike the previous hard distinction in the common
case). This is a premonition for a future change that will introduce a
bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supersede the existing
"ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supersedes `torch_signature_ods_gen.py`.
It should look somewhat familiar, but the benefit of hindsight has
allowed a lot of simplifications.
The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like as a natural result of
various future changes we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).
Recommended review order:
- Start at some of the new import IR, e.g. in
`frontends/pytorch/test/node_import/prim.py`,
`frontends/pytorch/test/acap_export/test_export_add3.py`, and other
tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
and associated generated files:
- `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
- `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
`frontends/pytorch/csrc/builder`. Probably most interesting is the new
code in `torch_to_mlir_utils.cpp` that has the logic to create the
`torch.operator` ops or `torch.ns.unqual.overload` ops.
This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Aten__RangeLengthOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
  auto lo = dyn_cast_or_null<IntegerAttr>(adaptor.getLo());
  auto hi = dyn_cast_or_null<IntegerAttr>(adaptor.getHi());
  auto step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
  // Only fold when all three operands are constant integers.
  if (!lo || !hi || !step)
    return nullptr;
  auto loInt = lo.getValue();
  auto hiInt = hi.getValue();
  auto stepInt = step.getValue();
  // TODO: Implement folding for negative steps.
  // A zero step is an error in Python (`range()` raises), so it is not folded.
  if (!stepInt.isStrictlyPositive())
    return nullptr;
  // From Python language spec:
  // r[i] = lo + step*i such that i >= 0 and r[i] < hi
  // The length is the number of valid `i`, i.e. the smallest `n` such that
  // lo + step * n >= hi, which is ceildiv(hi - lo, step). An empty range
  // (hi <= lo) has length 0.
  if (hiInt.sle(loInt))
    return IntegerAttr::get(lo.getType(), 0);
  return IntegerAttr::get(lo.getType(),
                          llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
                                                       APInt::Rounding::UP));
}
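
// Illustrative sketch, not used by the folder above: Python's
// len(range(lo, hi, step)) for a positive step, on plain 64-bit integers.
// `exampleRangeLength` is a hypothetical helper; it assumes the inputs do not
// overflow and returns std::nullopt for the non-positive steps it does not
// model.
#include <cstdint>
#include <optional>

[[maybe_unused]] static std::optional<int64_t>
exampleRangeLength(int64_t lo, int64_t hi, int64_t step) {
  // Only positive steps are handled here.
  if (step <= 0)
    return std::nullopt;
  // An empty range, e.g. range(5, 2, 1), has length 0.
  if (hi <= lo)
    return 0;
  // ceildiv(hi - lo, step) for a positive numerator and a positive divisor.
  return (hi - lo + step - 1) / step;
}
// Usage: exampleRangeLength(2, 9, 3) returns 3, matching the elements 2, 5
// and 8 of range(2, 9, 3).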

//===----------------------------------------------------------------------===//
// Aten__DeriveIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
  auto index = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
  auto start = dyn_cast_or_null<IntegerAttr>(adaptor.getStart());
  auto step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
  // Only fold when all three operands are constant integers.
  if (!index || !start || !step)
    return nullptr;
  APInt indexInt = index.getValue();
  APInt startInt = start.getValue();
  APInt stepInt = step.getValue();
  // The derived index is start + step * index.
  return IntegerAttr::get(index.getType(), startInt + stepInt * indexInt);
}
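
// Illustrative sketch, not used by the folder above: the value folded for
// `aten.__derive_index` is start + step * index, i.e. the index-th element of
// a range that begins at `start` and advances by `step`. `exampleDeriveIndex`
// is a hypothetical helper using plain int64_t arithmetic, whereas the folder
// operates on APInt values of the attribute's bit width.
#include <cstdint>

[[maybe_unused]] static int64_t exampleDeriveIndex(int64_t index,
                                                   int64_t start,
                                                   int64_t step) {
  return start + step * index;
}
// Usage: exampleDeriveIndex(2, 10, 3) returns 16, the third element of
// range(10, 100, 3) (10, 13, 16, ...).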

//===----------------------------------------------------------------------===//
// Aten__Is__Op
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) {
  return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true);
}

//===----------------------------------------------------------------------===//
// Aten__Isnot__Op
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) {
  return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false);
}

//===----------------------------------------------------------------------===//
// Aten__Not__Op
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
  bool value;
  if (!matchPattern(getOperand(), m_TorchConstantBool(&value)))
    return nullptr;
  return IntegerAttr::get(IntegerType::get(getContext(), 1), !value);
}

//===----------------------------------------------------------------------===//
// Aten__Or__BoolOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) {
  auto valueA = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
  auto valueB = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
  if (!valueA && !valueB)
    return nullptr;
  // `true or x` and `x or true` fold to `true`.
  if ((valueA && valueA.getValue() == 1) || (valueB && valueB.getValue() == 1))
    return IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
  // `false or x` and `x or false` fold to `x`.
  if (valueA && valueA.getValue() == 0)
    return getB();
  if (valueB && valueB.getValue() == 0)
    return getA();
  // Unreachable: at least one operand is a constant i1, which is 0 or 1.
  return nullptr;
}
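
// Illustrative sketch, not used by the folder above: the `or` folding rules
// as a truth table over operands that may or may not be compile-time
// constants. `std::nullopt` stands for a non-constant operand. The names
// `ExampleOrFold` and `exampleFoldOr` are hypothetical; the result says
// whether the op folds to a constant, to one of its operands, or not at all.
#include <optional>

enum class ExampleOrFold { NoFold, ConstTrue, UseA, UseB };

[[maybe_unused]] static ExampleOrFold
exampleFoldOr(std::optional<bool> a, std::optional<bool> b) {
  if (!a && !b)
    return ExampleOrFold::NoFold; // Nothing is known statically.
  if (a.value_or(false) || b.value_or(false))
    return ExampleOrFold::ConstTrue; // `true or x` / `x or true`.
  if (a && !*a)
    return ExampleOrFold::UseB; // `false or x` -> x.
  return ExampleOrFold::UseA;   // Here `b` must be a constant `false`.
}
// Usage: exampleFoldOr(std::nullopt, false) returns UseA, since `x or false`
// is just `x`.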

//===----------------------------------------------------------------------===//
// AtenNeBoolOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
  if (getOperand(0) == getOperand(1))
    return IntegerAttr::get(IntegerType::get(getContext(), 1), false);

  bool a, b;
  if (!matchPattern(getOperand(0), m_TorchConstantBool(&a)))
    return nullptr;
  if (!matchPattern(getOperand(1), m_TorchConstantBool(&b)))
    return nullptr;
  return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
}
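
// Illustrative sketch, not used by the folder above: `aten.ne.bool` folding.
// Each operand is modeled by the hypothetical `ExampleBoolOperand`, an
// SSA-value id plus an optional known constant. Identical values always
// compare equal, so `ne` folds to false even when the value is unknown;
// otherwise both operands must be constants.
#include <optional>

struct ExampleBoolOperand {
  int valueId;                  // Stand-in for the SSA value identity.
  std::optional<bool> constant; // Set when the value is a known constant.
};

[[maybe_unused]] static std::optional<bool>
exampleFoldNeBool(const ExampleBoolOperand &lhs,
                  const ExampleBoolOperand &rhs) {
  if (lhs.valueId == rhs.valueId)
    return false;
  if (!lhs.constant || !rhs.constant)
    return std::nullopt; // Cannot fold.
  return *lhs.constant != *rhs.constant;
}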

//===----------------------------------------------------------------------===//
// AtenUnsqueezeOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenUnsqueezeOp::fold(FoldAdaptor adaptor) {
  auto selfTy = dyn_cast<BaseTensorType>(getSelf().getType());
  auto rty = dyn_cast<BaseTensorType>(getType());
  if (!rty || !rty.hasDtype())
    return {};

  if (auto attr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf())) {
    auto aty = dyn_cast<RankedTensorType>(attr.getType());
    if (aty && rty.hasSizes() && rty.areAllSizesKnown() && attr.isSplat()) {
      auto naty = RankedTensorType::get(rty.getSizes(), aty.getElementType());
      return DenseElementsAttr::get(naty, attr.getSplatValue<Attribute>());
    }
  }

  if (getSelf().getType() != getResult().getType())
    return nullptr;
  if (selfTy && rty) {
    if (selfTy.hasSizes() && rty.hasSizes() &&
        selfTy.getSizes().size() == rty.getSizes().size())
      return getSelf();
  }
  return nullptr;
}
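
// Illustrative sketch, not used by the folder above: why a splat constant can
// be re-shaped freely. A splat stores one value for every element, so the
// unsqueezed result is the same value with the new shape. `ExampleSplat`,
// `exampleUnsqueezeShape` and `exampleUnsqueezeSplat` are hypothetical; the
// shape helper assumes PyTorch's `unsqueeze` rule, where `dim` may lie in
// [-rank - 1, rank] and a negative `dim` counts from the end.
#include <cstdint>
#include <vector>

struct ExampleSplat {
  std::vector<int64_t> shape;
  float value;
};

[[maybe_unused]] static std::vector<int64_t>
exampleUnsqueezeShape(std::vector<int64_t> shape, int64_t dim) {
  int64_t rank = static_cast<int64_t>(shape.size());
  if (dim < 0)
    dim += rank + 1;
  shape.insert(shape.begin() + dim, 1); // New size-1 dimension.
  return shape;
}

[[maybe_unused]] static ExampleSplat
exampleUnsqueezeSplat(const ExampleSplat &in, int64_t dim) {
  // Same scalar, new shape: a splat has no element order to preserve.
  return {exampleUnsqueezeShape(in.shape, dim), in.value};
}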

//===----------------------------------------------------------------------===//
// AtenSqueezeOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
  auto selfTy = dyn_cast<BaseTensorType>(getSelf().getType());
  auto rty = dyn_cast<BaseTensorType>(getType());
  if (!rty || !rty.hasDtype())
    return {};

  if (auto attr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf())) {
    auto aty = dyn_cast<RankedTensorType>(attr.getType());
    if (aty && rty.hasSizes() && rty.areAllSizesKnown() && attr.isSplat()) {
      auto naty = RankedTensorType::get(rty.getSizes(), aty.getElementType());
      return DenseElementsAttr::get(naty, attr.getSplatValue<Attribute>());
    }
  }

  if (getSelf().getType() != getResult().getType())
    return nullptr;
  if (selfTy && rty) {
    if (selfTy.hasSizes() && rty.hasSizes() &&
        selfTy.getSizes().size() == rty.getSizes().size())
      return getSelf();
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenSqueezeDimOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
  auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
  auto outType = dyn_cast<ValueTensorType>(getResult().getType());
  if (!inType || !outType || !inType.areAllSizesKnown() ||
      !outType.areAllSizesKnown() || !inType.hasDtype() ||
      !outType.hasDtype()) {
    return nullptr;
  }

  if (inType == outType) {
    return getOperand(0);
  }

  DenseElementsAttr input =
      dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  if (input) {
    return reshapeDenseElementsAttr(input, outType.toBuiltinTensor());
  }
  return nullptr;
}
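
// Illustrative sketch, not used by the folder above: squeezing a dimension
// never reorders elements, which is why a constant input can keep its data
// and only change its type. The hypothetical `exampleSqueezeDimShape`
// computes the result shape under PyTorch's `squeeze(dim)` rule as assumed
// here: a size-1 dimension is removed, any other size is left alone, and a
// negative `dim` counts from the end.
#include <cstdint>
#include <vector>

[[maybe_unused]] static std::vector<int64_t>
exampleSqueezeDimShape(std::vector<int64_t> shape, int64_t dim) {
  int64_t rank = static_cast<int64_t>(shape.size());
  if (dim < 0)
    dim += rank;
  // Only a size-1 dimension is removed; otherwise the shape is unchanged.
  if (dim >= 0 && dim < rank && shape[dim] == 1)
    shape.erase(shape.begin() + dim);
  return shape;
}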

//===----------------------------------------------------------------------===//
// AtenToDtypeOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
  bool nonBlocking, copyArg;
  // The non_blocking arg must be `False`.
  if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
      nonBlocking)
    return nullptr;
  // The copy arg must be `False`.
  if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
    return nullptr;
  // The memory_format arg must be `none`.
  if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
    return nullptr;

  auto inputType = cast<BaseTensorType>(getSelf().getType());
  auto resType = cast<BaseTensorType>(getType());
  // If the types aren't equal, then we can't fold.
  if (inputType != resType)
    return nullptr;
  // If the type does not have a statically known dtype, then we cannot fold.
  // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
  // since the `unk` could be dynamically different for the operand and result.
  if (!inputType.hasDtype())
    return nullptr;
  // Fold when both the input tensor and result are of the same type.
  return getOperand(0);
}
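
// Illustrative sketch, not used by the folder above: the conditions under
// which `aten.to.dtype` folds to its input, written as a plain predicate over
// the hypothetical `ExampleToDtypeQuery`. An empty std::optional models an
// argument that is not a known constant, which blocks folding. Tensor types
// are stood in for by strings purely for illustration.
#include <optional>
#include <string>

struct ExampleToDtypeQuery {
  std::optional<bool> nonBlocking; // Must be a constant `False`.
  std::optional<bool> copy;        // Must be a constant `False`.
  bool memoryFormatIsNone;         // Must be `None`.
  std::string inputType;           // Stand-in for the full tensor type.
  std::string resultType;
  bool dtypeKnown; // Whether the shared type has a static dtype.
};

[[maybe_unused]] static bool
exampleCanFoldToDtype(const ExampleToDtypeQuery &q) {
  if (!q.nonBlocking.has_value() || *q.nonBlocking)
    return false;
  if (!q.copy.has_value() || *q.copy)
    return false;
  if (!q.memoryFormatIsNone)
    return false;
  // Same type on both sides, and the dtype must be statically known: two
  // unknown dtypes could still differ at runtime.
  return q.inputType == q.resultType && q.dtypeKnown;
}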

//===----------------------------------------------------------------------===//
// AtenToDtypeLayoutOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
  // The pin_memory arg should be either constant `False` or `none`.
  if (!isa<Torch::NoneType>(getPinMemory().getType())) {
    bool pinMemory;
    if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory)))
      return nullptr;
    else if (pinMemory)
      return nullptr;
  }

  // The non_blocking arg should be constant `False`.
  bool nonBlocking;
  if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)))
    return nullptr;
  else if (nonBlocking)
    return nullptr;

  // The copy arg should be constant `False`.
  bool copyArg;
  if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)))
    return nullptr;
  else if (copyArg)
    return nullptr;

  // The device arg must be `none`.
  if (!isa<Torch::NoneType>(getDevice().getType()))
    return nullptr;

  // The memory_format arg must be `none`.
  if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
    return nullptr;

  auto inputType = cast<BaseTensorType>(getSelf().getType());
  auto resType = cast<BaseTensorType>(getType());
  // If the types aren't equal, then we can't fold.
  if (inputType != resType)
    return nullptr;
  // If the type does not have a statically known dtype, then we cannot fold.
  // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
  // since the `unk` could be dynamically different for the operand and result.
  if (!inputType.hasDtype())
    return nullptr;

  // The layout arg should be either `none` or `0` i.e. strided.
  if (!isa<Torch::NoneType>(getLayout().getType())) {
    int64_t tensorLayout;
    if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout)))
      return nullptr;
    else if (tensorLayout != torch_upstream::Layout::Strided)
      return nullptr;
  }

  // Fold when both the input tensor and result are of the same type and the
  // layout arg is strided.
  return getOperand(0);
}

void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // `to.dtype_layout` -> `to.device/to.dtype` if layout is none and pin memory
  // is false
  patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) {
    // The pin_memory arg should be either constant `False` or `none`.
    if (!isa<Torch::NoneType>(op.getPinMemory().getType())) {
      bool pinMemory;
      if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
        return failure();
      else if (pinMemory)
        return failure();
    }

    // The layout arg should be either `none` or `0` i.e. strided.
    if (!isa<Torch::NoneType>(op.getLayout().getType())) {
      int64_t tensorLayout;
      if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
        return failure();
      else if (tensorLayout != torch_upstream::Layout::Strided)
        return failure();
    }

    if (isa<Torch::NoneType>(op.getDevice().getType())) {
      // The device arg is `none`. Rewrite to to.dtype.
      AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>(
          op.getLoc(), op.getType(), op.getSelf(), op.getDtype(),
          op.getNonBlocking(), op.getCopy(), op.getMemoryFormat());
      rewriter.replaceOp(op, toDtype->getResults());
    } else {
      // The device arg is not `none`. Rewrite to to.device.
      AtenToDeviceOp toDevice = rewriter.create<AtenToDeviceOp>(
          op.getLoc(), op.getType(), op.getSelf(), op.getDevice(),
          op.getDtype(), op.getNonBlocking(), op.getCopy(),
          op.getMemoryFormat());
      rewriter.replaceOp(op, toDevice->getResults());
    }

    return success();
  });
}
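
// Illustrative sketch, not used by the pattern above: the rewrite decision of
// the `to.dtype_layout` canonicalization as a plain function. Once
// pin_memory and layout are known to be defaults, the op becomes `to.dtype`
// when no device is given and `to.device` otherwise; an unknown or
// non-default pin_memory/layout keeps the op. All names here are
// hypothetical, and the strided layout code 0 follows the "`0` i.e. strided"
// comment in the pattern.
#include <cstdint>
#include <optional>

enum class ExampleToRewrite { Keep, ToDtype, ToDevice };

struct ExampleToDtypeLayoutArgs {
  bool pinMemoryIsNone;
  std::optional<bool> pinMemory; // Checked only when not `None`.
  bool layoutIsNone;
  std::optional<int64_t> layout; // Checked only when not `None`.
  bool deviceIsNone;
};

static constexpr int64_t kExampleStridedLayout = 0;

[[maybe_unused]] static ExampleToRewrite
exampleCanonicalizeToDtypeLayout(const ExampleToDtypeLayoutArgs &a) {
  if (!a.pinMemoryIsNone && (!a.pinMemory.has_value() || *a.pinMemory))
    return ExampleToRewrite::Keep;
  if (!a.layoutIsNone &&
      (!a.layout.has_value() || *a.layout != kExampleStridedLayout))
    return ExampleToRewrite::Keep;
  return a.deviceIsNone ? ExampleToRewrite::ToDtype
                        : ExampleToRewrite::ToDevice;
}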

//===----------------------------------------------------------------------===//
// AtenToOtherOp
//===----------------------------------------------------------------------===//

void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Canonicalize `aten.to.other` to `aten.to.device`
  patterns.add(+[](AtenToOtherOp op, PatternRewriter &rewriter) {
    auto lhs = op.getSelf();
    auto rhs = op.getOther();
    auto getRhsDevice = rewriter.create<PrimDeviceOp>(op.getLoc(), rhs);
    auto getRhsDtype = rewriter.create<PrimDtypeOp>(op.getLoc(), rhs);
    rewriter.replaceOpWithNewOp<AtenToDeviceOp>(
        op, op.getType(), lhs, getRhsDevice.getResult(),
        getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(),
        op.getMemoryFormat());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// Aten_CastFloatOp
//===----------------------------------------------------------------------===//

void Aten_CastFloatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  // `aten.cast_float` -> `aten.to.dtype`
  patterns.add(+[](Aten_CastFloatOp op, PatternRewriter &rewriter) {
    auto self = op.getSelf();
    auto loc = op.getLoc();
    Value constNone = rewriter.create<ConstantNoneOp>(loc);
    Value f32Type = rewriter.create<ConstantIntOp>(
        loc, (int)torch_upstream::ScalarType::Float);
    Value constFalse = rewriter.create<ConstantBoolOp>(loc, false);
    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), self, f32Type,
                                               op.getNonBlocking(), constFalse,
                                               constNone);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// Aten_CastLongOp
//===----------------------------------------------------------------------===//

void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // `aten.cast_long` -> `aten.to.dtype`
  patterns.add(+[](Aten_CastLongOp op, PatternRewriter &rewriter) {
    auto self = op.getSelf();
    auto loc = op.getLoc();
    Value constNone = rewriter.create<ConstantNoneOp>(loc);
    Value longType = rewriter.create<ConstantIntOp>(
        loc, (int)torch_upstream::ScalarType::Long);
    Value constFalse = rewriter.create<ConstantBoolOp>(loc, false);
    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), self, longType,
                                               op.getNonBlocking(), constFalse,
                                               constNone);
    return success();
  });
}
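
// Illustrative sketch, not used by the patterns above: the `aten.to.dtype`
// operands that the `_cast_Float` / `_cast_Long` canonicalizations build.
// The dtype is passed as an integer ScalarType code; the values below assume
// PyTorch's c10 numbering (Long = 4, Float = 6). non_blocking is forwarded,
// copy is a constant `False`, and memory_format is `None`. The struct, enum
// and function names are hypothetical.
#include <cstdint>
#include <optional>

struct ExampleToDtypeOperands {
  int64_t dtype;
  bool nonBlocking;
  bool copy;
  std::optional<int64_t> memoryFormat; // std::nullopt models `None`.
};

enum class ExampleCastKind { Float, Long };

[[maybe_unused]] static ExampleToDtypeOperands
exampleCastToDtypeOperands(ExampleCastKind kind, bool nonBlocking) {
  constexpr int64_t kLong = 4, kFloat = 6; // Assumed c10::ScalarType codes.
  int64_t dtype = kind == ExampleCastKind::Float ? kFloat : kLong;
  return {dtype, nonBlocking, /*copy=*/false, /*memoryFormat=*/std::nullopt};
}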

//===----------------------------------------------------------------------===//
// AtenViewOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
  auto inputType = dyn_cast<BaseTensorType>(getOperand(0).getType());
  if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
    return nullptr;
  auto resType = dyn_cast<BaseTensorType>(getType());
  if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
    return nullptr;
  if (inputType != resType)
    return nullptr;
  // Fold when the input tensor and result are rank-1 tensors of the same type.
  return getOperand(0);
}
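
// Illustrative sketch, not used by the folder above: when `aten.view` folds
// away. Only the rank-1 case is handled: if input and result are both rank-1
// and have exactly the same type, the view is an identity. The hypothetical
// `ExampleTensorType` uses std::nullopt sizes for a tensor of unknown rank
// and an int as a stand-in for the dtype.
#include <cstdint>
#include <optional>
#include <vector>

struct ExampleTensorType {
  std::optional<std::vector<int64_t>> sizes;
  int dtype;
  bool operator==(const ExampleTensorType &o) const {
    return sizes == o.sizes && dtype == o.dtype;
  }
};

[[maybe_unused]] static bool
exampleViewFoldsToInput(const ExampleTensorType &in,
                        const ExampleTensorType &out) {
  if (!in.sizes || in.sizes->size() != 1)
    return false;
  if (!out.sizes || out.sizes->size() != 1)
    return false;
  return in == out;
}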

//===----------------------------------------------------------------------===//
// PrimsViewOfOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimsViewOfOp::fold(FoldAdaptor adaptor) {
  // Always fold the op with its only input operand.
  return getOperand();
}

//===----------------------------------------------------------------------===//
// AtenDimOp
Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical do do it in sane smaller
pieces.
This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).
Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
imported as `torch.somens.someunqualname.someoverloadname` (skip the
last dotted part if the overload name is empty), OR, if we don't have
such an op registered, it is imported as
`torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
- The addition of the "overload name" is a critical element here, as
the `(ns,unqual,overload)` triple is unique, which solves a lot of
problems we were having.
- This involves having separate MLIR ops for the `trailing_` and
`.out` variants and all the different overloads. This seemed
necessary, because the set of overloads is so wild and varied and
unstructured. The previous design was leaning into some underlying
structure that just isn't there -- the default situation is
the "random overload that we want to manage on the MLIR side",
rather than that being an exception. E.g. `aten::ne` (not-equal)
has 21 overloads, only 4 of which are c10 dispatcher ops see
[gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1),
and the "out" variant is really called `.Tensor_out` instead of
`.out` as it frequently is for other ops.
- Rationale for all being in `torch` namespace: the set of operators
are so varied and unstructured that "dialect per namespace"
doesn't result in anything resembling the typical MLIR dialect
boundary expectations. We could maybe draw the boundary at
dispatcher ops vs non-dispatcher ops, but that doesn't seem to
really result in very much useful structure at this point in time.
- Note: within the torch operator registry, we effectively have a
mini-basicpy subdialect (already type-resolved), which is reasonably
structured.
- The existing Torch op interfaces are also removed -- now that we
track the overload name, we can losslessly find the original
operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
`ReduceOpVariantsPass` that keys off certain traits (and perhaps
eventually interfaces) to reduce variants of ops to a smaller set,
ideally operating on immutable tensors and using surrounding ops to
model the mutability/aliasing aspects.
- Note: `torch.ns.unqual.overload` ops allow both immutable and
mutable tensors (unlike the previous hard distinction in the common
case). This is a premonition for a future change that will introduce a
bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supercede the existing
"ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supercedes `torch_signature_ods_gen.py`.
It should look somewhat familiar, but the benefit of hindsight has
allowed a lot of simplifications.
The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like, as a natural result of
various future changes, we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).
Recommended review order:
- Start at some of the new import IR, e.g. in
`frontends/pytorch/test/node_import/prim.py`,
`frontends/pytorch/test/acap_export/test_export_add3.py`, and other
tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
and associated generated files:
- `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
- `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
`frontends/pytorch/csrc/builder`. Probably most interesting is the new
code in `torch_to_mlir_utils.cpp` that has the logic to create the
`torch.operator` ops or `torch.ns.unqual.overload` ops.
This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
//===----------------------------------------------------------------------===//

OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
  if (auto tensorType = dyn_cast<BaseTensorType>(getOperand().getType())) {
    if (tensorType.hasSizes())
      return IntegerAttr::get(IntegerType::get(getContext(), 64),
                              tensorType.getSizes().size());
  }
  return nullptr;
}
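`AtenDimOp::fold` only needs the rank, so it folds whenever the sizes list is known, even if individual dimensions are `?`. Below is a minimal standalone model of that decision, assuming a simplified type in which a missing sizes list means "unranked". `ToyTensorType` and `foldDim` are hypothetical names, not torch-mlir APIs, and `std::nullopt` stands in for returning `nullptr` (no fold).

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for BaseTensorType. `sizes` is empty (nullopt) when the
// rank itself is unknown, e.g. `!torch.vtensor<*,f32>`. Individual dimensions
// may still be dynamic; -1 stands in for `?`.
struct ToyTensorType {
  std::optional<std::vector<int64_t>> sizes;
  bool hasSizes() const { return sizes.has_value(); }
};

// Mirrors the shape of AtenDimOp::fold: the result depends only on the rank,
// so dynamic dimensions do not block the fold.
std::optional<int64_t> foldDim(const ToyTensorType &type) {
  if (type.hasSizes())
    return static_cast<int64_t>(type.sizes->size());
  return std::nullopt; // Corresponds to `return nullptr;` (no fold).
}
```

Under this model, `!torch.vtensor<[?,2,?,4],unk>` folds to 4, a rank-0 tensor folds to 0, and an unranked tensor is left alone.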

//===----------------------------------------------------------------------===//
// AtenLenTOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) {
  // `len([1,1,1])` -> `3`, if it is not mutated.
  if (auto listConstruct =
          getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
    if (!isListPotentiallyMutated(listConstruct)) {
      return IntegerAttr::get(IntegerType::get(getContext(), 64),
                              listConstruct.getNumOperands());
    }
  }
  return nullptr;
}
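The `len` fold is only sound while nothing can change the list's length after `prim.ListConstruct`. The sketch below models that guard under a simplifying assumption: any user not marked read-only counts as a potential mutation, which errs on the conservative side. `ToyUser`, `ToyListConstruct`, `potentiallyMutated` and `foldLen` are hypothetical; the real check is `isListPotentiallyMutated`, whose exact rules are not reproduced here.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified model of a `prim.ListConstruct` result and its users.
struct ToyUser {
  std::string name;
  bool readOnly; // Assumption: users not known to be read-only may mutate.
};

struct ToyListConstruct {
  int64_t numOperands;
  std::vector<ToyUser> users;
};

// Conservative stand-in for isListPotentiallyMutated.
bool potentiallyMutated(const ToyListConstruct &list) {
  for (const ToyUser &user : list.users)
    if (!user.readOnly)
      return true;
  return false;
}

// Mirrors the shape of AtenLenTOp::fold: only fold when the element count
// cannot change between construction and the `len` query.
std::optional<int64_t> foldLen(const ToyListConstruct &list) {
  if (!potentiallyMutated(list))
    return list.numOperands;
  return std::nullopt;
}
```

The `len` op is itself a user of the list, which is why the model treats it as an ordinary read-only user rather than special-casing it.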

void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  // `len(t.size())` -> `t.ndim`
  patterns.add(+[](AtenLenTOp op, PatternRewriter &rewriter) {
    auto size = op.getOperand().getDefiningOp<AtenSizeOp>();
    if (!size)
      return rewriter.notifyMatchFailure(op, "operand not AtenSizeOp");
    rewriter.replaceOpWithNewOp<AtenDimOp>(op, size.getOperand());
    return success();
  });
}
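This pattern inspects the defining op of the `len` operand and reports a match failure otherwise. The following toy C++ model shows the same match-and-rewrite shape on a hypothetical single-operand expression tree; `Expr`, `make` and `rewriteLenOfSize` are illustrative names, not MLIR's `Operation` or `PatternRewriter`, and `nullptr` plays the role of `notifyMatchFailure`.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy expression tree; not torch-mlir's IR classes.
struct Expr {
  std::string op;                // e.g. "torch.aten.size", "torch.aten.len.t"
  std::shared_ptr<Expr> operand; // single operand, or null for leaves
};

using ExprPtr = std::shared_ptr<Expr>;

ExprPtr make(std::string op, ExprPtr operand = nullptr) {
  return std::make_shared<Expr>(Expr{std::move(op), std::move(operand)});
}

// Mirrors the pattern `len(size(t))` -> `dim(t)`.
// Returns nullptr on match failure.
ExprPtr rewriteLenOfSize(const ExprPtr &e) {
  if (!e || e->op != "torch.aten.len.t")
    return nullptr;
  const ExprPtr &size = e->operand;
  if (!size || size->op != "torch.aten.size")
    return nullptr; // "operand not AtenSizeOp"
  // The new `dim` op consumes the tensor that fed `size`.
  return make("torch.aten.dim", size->operand);
}
```

One consequence worth noting: once the rewrite has produced `aten.dim`, `AtenDimOp::fold` may turn it into a constant when the tensor's rank is known, so `len(t.size())` can end up as a plain integer.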

//===----------------------------------------------------------------------===//
// AtenMinOtherOp
//===----------------------------------------------------------------------===//

void AtenMinOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  // `aten.min.other` -> `aten.minimum`
  patterns.add(+[](AtenMinOtherOp op, PatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<AtenMinimumOp>(op, op.getType(), op.getSelf(),
                                               op.getOther());
    return success();
  });
}
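PyTorch documents the binary `torch.min(input, other)` form as equivalent to `torch.minimum`, so the pattern only swaps the op while keeping the result type and both operands in their original order. Here is a hypothetical record-based model of that rewrite; `ToyOp` and `canonicalizeMinOther` are illustrative names, not torch-mlir APIs.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy op record; operands are SSA-like value names.
struct ToyOp {
  std::string name;
  std::vector<std::string> operands;
  std::string resultType;
};

// Mirrors the canonicalization: rename `aten.min.other` to `aten.minimum`,
// leaving the result type and operand order untouched. Returns whether the
// pattern matched.
bool canonicalizeMinOther(ToyOp &op) {
  if (op.name != "torch.aten.min.other")
    return false;
  op.name = "torch.aten.minimum";
  return true;
}
```

Presumably the benefit is that later patterns and lowerings only need to handle the `aten.minimum` spelling.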
Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical to do it in sane smaller
pieces.
This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).
Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
imported as `torch.somens.someunqualname.someoverloadname` (skip the
last dotted part if the overload name is empty), OR, if we don't have
such an op registered, it is imported as
`torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
- The addition of the "overload name" is a critical element here, as
the `(ns,unqual,overload)` triple is unique, which solves a lot of
problems we were having.
- This involves having separate MLIR ops for the `trailing_` and
`.out` variants and all the different overloads. This seemed
necessary, because the set of overloads is so wild and varied and
unstructured. The previous design was leaning into some underlying
structure that just isn't there -- the default situation is
the "random overload that we want to manage on the MLIR side",
rather than that being an exception. E.g. `aten::ne` (not-equal)
has 21 overloads, only 4 of which are c10 dispatcher ops see
[gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1),
and the "out" variant is really called `.Tensor_out` instead of
`.out` as it frequently is for other ops.
- Rationale for all being in `torch` namespace: the set of operators
are so varied and unstructured that "dialect per namespace"
doesn't result in anything resembling the typical MLIR dialect
boundary expectations. We could maybe draw the boundary at
dispatcher ops vs non-dispatcher ops, but that doesn't seem to
really result in very much useful structure at this point in time.
- Note: within the torch operator registry, we effectively have a
mini-basicpy subdialect (already type-resolved), which is reasonably
structured.
- The existing Torch op interfaces are also removed -- now that we
track the overload name, we can losslessly find the original
operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
`ReduceOpVariantsPass` that keys off certain traits (and perhaps
eventually interfaces) to reduce variants of ops to a smaller set,
ideally operating on immutable tensors and using surrounding ops to
model the mutability/aliasing aspects.
- Note: `torch.ns.unqual.overload` ops allow both immutable and
mutable tensors (unlike the previous hard distinction in the common
case). This is a premonition for a future change that will introduce a
bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supercede the existing
"ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supercedes `torch_signature_ods_gen.py`.
It should look somewhat familiar, but the benefit of hindsight has
allowed a lot of simplifications.
The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like as a natural result of
various future changes we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).
Recommended review order:
- Start at some of the new import IR, e.g. in
`frontends/pytorch/test/node_import/prim.py`,
`frontends/pytorch/test/acap_export/test_export_add3.py`, and other
tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
and associated generated files:
- `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
- `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
`frontends/pytorch/csrc/builder`. Probably most interesting is the new
code in `torch_to_mlir_utils.cpp` that has the logic to create the
`torch.operator` ops or `torch.ns.unqual.overload` ops.
This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
2021-05-05 05:42:50 +08:00
|
|
|
|
//===----------------------------------------------------------------------===//
// AtenMaxOtherOp
//===----------------------------------------------------------------------===//

void AtenMaxOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  // `aten.max.other` -> `aten.maximum`
  patterns.add(+[](AtenMaxOtherOp op, PatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), op.getSelf(),
                                               op.getOther());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenLenStrOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) {
  if (auto stringConstruct = getS().getDefiningOp<ConstantStrOp>())
    return getI64IntegerAttr(getContext(),
                             stringConstruct.getValueAttr().getValue().size());

  return nullptr;
}

LogicalResult rewrite0DBinaryTensorOp(Operation *op,
                                      PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  // This canonicalization pattern also covers aten div/mul/add/sub ops
  // between a tensor and a scalar, such as aten.add.Scalar.
  if (op->getNumOperands() < 2) {
    return failure();
  }
  auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter);
  auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter);
  auto outType = op->getResult(0).getType();

  if (!lhs || !rhs) {
    return rewriter.notifyMatchFailure(
        op, "only int scalar lhs or rhs is supported");
  }
  if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenRsubScalarOp, AtenAddTensorOp,
          AtenAddScalarOp>(op)) {
    Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
    if (!alpha) {
      return rewriter.notifyMatchFailure(op,
                                         "only int scalar alpha is supported");
    }
    if (isa<AtenRsubScalarOp>(op))
      lhs = rewriter.create<AtenMulIntOp>(loc, lhs, alpha);
    else
      rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
  }

  if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
    if (isa<Torch::NoneType>(op->getOperand(2).getType())) {
      // None rounding mode
      Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
                                                           quotient);
      return success();
    }
    std::string roundingMode;
    if (!matchPattern(op->getOperand(2), m_TorchConstantStr(roundingMode))) {
      return rewriter.notifyMatchFailure(
          op, "only None, 'floor' or 'trunc' rounding mode is supported");
    }
    if (roundingMode == "floor") {
      Value quotient = rewriter.create<AtenFloordivIntOp>(loc, lhs, rhs);
      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
                                                           quotient);
      return success();
    }
    // For the "trunc" rounding mode, instead of canonicalizing into
    // aten.abs, aten.floor, aten.sign and aten.mul.int ops, which adds
    // complexity but helps little with optimization (such as constant
    // folding), we try to fold it directly.
    if (roundingMode == "trunc") {
      int64_t lhsInt;
      int64_t rhsInt;
      if (!matchPattern(lhs, m_TorchConstantInt(&lhsInt))) {
        return failure();
      }
      if (!matchPattern(rhs, m_TorchConstantInt(&rhsInt))) {
        return failure();
      }
      // Do not fold a division by zero: converting the resulting infinity
      // (or NaN) back to int64_t below would be undefined behavior.
      if (rhsInt == 0) {
        return rewriter.notifyMatchFailure(op,
                                           "division by zero is not folded");
      }

      int64_t result = (int64_t)std::trunc((double)lhsInt / rhsInt);
      Value resultScalar = rewriter.create<ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(result));
      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
                                                           resultScalar);
      return success();
    }

    return failure();
  }

  Value result;
  // Other Add/Sub/Mul ops
  if (isa<AtenAddTensorOp, AtenAddScalarOp>(op)) {
    result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs);
  } else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) {
    result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs);
  } else if (isa<AtenRsubScalarOp>(op)) {
    result = rewriter.create<AtenSubIntOp>(loc, rhs, lhs);
  } else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) {
    result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs);
  } else {
    // Avoid replacing the op with a null value for unsupported ops.
    return rewriter.notifyMatchFailure(op, "unsupported binary op");
  }
  rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType, result);
  return success();
}

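The `trunc` branch above folds only when both operands are constant ints, and it computes the quotient through a `double`, which can lose precision once the magnitude of `lhs` exceeds 2^53. As a minimal sketch (plain C++, not torch-mlir code; `truncDiv` and `floorDiv` are hypothetical helper names), both integer rounding modes can be computed exactly with integer arithmetic:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helpers illustrating the integer rounding modes accepted by
// aten.div.Tensor_mode / aten.div.Scalar_mode. Both return std::nullopt for
// a zero divisor (and for the overflowing INT64_MIN / -1) instead of
// invoking undefined behavior.

// "trunc": round the quotient toward zero. C++ integer division already
// truncates toward zero (guaranteed since C++11).
std::optional<int64_t> truncDiv(int64_t lhs, int64_t rhs) {
  if (rhs == 0 || (lhs == INT64_MIN && rhs == -1))
    return std::nullopt;
  return lhs / rhs;
}

// "floor": round the quotient toward negative infinity, like Python's `//`.
// Step the truncated quotient down by one when the division is inexact and
// the remainder and divisor have opposite signs.
std::optional<int64_t> floorDiv(int64_t lhs, int64_t rhs) {
  if (rhs == 0 || (lhs == INT64_MIN && rhs == -1))
    return std::nullopt;
  int64_t q = lhs / rhs;
  int64_t r = lhs % rhs;
  if (r != 0 && ((r < 0) != (rhs < 0)))
    --q;
  return q;
}
```

For example, -7 / 2 gives -3 under `trunc` and -4 under `floor`; the two modes only disagree when the division is inexact and the operands have opposite signs.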
//===----------------------------------------------------------------------===//
// NAry folder helpers
//===----------------------------------------------------------------------===//

static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
  for (auto attr : attrs) {
    if (auto dense = dyn_cast_or_null<ElementsAttr>(attr)) {
      if (!dense.isSplat())
        return false;
    }
  }

  return true;
}

llvm::SmallVector<double> getFoldValueAtIndexFp(llvm::ArrayRef<Attribute> attrs,
                                                int64_t idx = 0) {
  llvm::SmallVector<double> splattrs;

  // Note that i1 is signless (neither signed nor unsigned), but it must be
  // treated as unsigned: otherwise APInt(1, 1).getSExtValue() returns -1
  // (an all-ones 64-bit value). So only explicitly signed integers are
  // sign-extended here.
  auto convertAPIntToDouble = [](APInt value, bool isSigned) -> double {
    if (isSigned)
      return static_cast<double>(value.getSExtValue());
    else
      return static_cast<double>(value.getZExtValue());
  };

  for (auto attr : attrs) {
    if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
      if (dense.isSplat()) {
        splattrs.push_back(dense.getSplatValue<APFloat>().convertToDouble());
      } else {
        splattrs.push_back(dense.getValues<APFloat>()[idx].convertToDouble());
      }
    } else if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
      bool isSigned = cast<IntegerType>(dense.getElementType()).isSigned();
      if (dense.isSplat()) {
        splattrs.push_back(
            convertAPIntToDouble(dense.getSplatValue<APInt>(), isSigned));
      } else {
        splattrs.push_back(
            convertAPIntToDouble(dense.getValues<APInt>()[idx], isSigned));
      }
    } else if (auto fpattr = dyn_cast<FloatAttr>(attr)) {
      splattrs.push_back(fpattr.getValueAsDouble());
    } else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
      bool isSigned = cast<IntegerType>(intattr.getType()).isSigned();
      splattrs.push_back(convertAPIntToDouble(intattr.getValue(), isSigned));
    } else {
      return {};
    }
  }

  return splattrs;
}

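`getFoldValueAtIndexFp` (and its integer counterpart) read a splat operand's single stored value at every index and a dense operand's element at `idx`, returning an empty vector when an operand cannot be read. The sketch below models that lookup with a hypothetical `ConstOperand` struct in place of MLIR attributes; it illustrates the idea and is not the MLIR API:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a constant operand: either a splat (one value
// repeated across the whole tensor) or a dense list of element values.
struct ConstOperand {
  bool isSplat;
  std::vector<double> values; // exactly one element when isSplat is true
};

// Collect the value of every operand at flat index `idx`, reading the single
// stored value for splats. Returns std::nullopt if an operand cannot supply
// a value, playing the role of the empty vector returned by the C++ above.
std::optional<std::vector<double>>
valuesAtIndex(const std::vector<ConstOperand> &operands, int64_t idx) {
  std::vector<double> result;
  result.reserve(operands.size());
  for (const ConstOperand &op : operands) {
    if (op.isSplat) {
      if (op.values.empty())
        return std::nullopt;
      result.push_back(op.values.front());
    } else {
      if (idx < 0 || static_cast<size_t>(idx) >= op.values.size())
        return std::nullopt;
      result.push_back(op.values[idx]);
    }
  }
  return result;
}
```

Reading index 2 of a splat `2.0` and a dense `[1.0, 3.0, 5.0]` yields `{2.0, 5.0}`, which is how a splat operand is implicitly broadcast against a dense one.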
llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
                                                int64_t bitwidth,
                                                int64_t idx = 0) {
  llvm::SmallVector<APInt> splattrs;

  for (auto attr : attrs) {
    bool isSigned = false;
    if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
      isSigned = cast<IntegerType>(dense.getElementType()).isSigned();
      if (dense.isSplat()) {
        splattrs.push_back(dense.getSplatValue<APInt>());
      } else {
        splattrs.push_back(dense.getValues<APInt>()[idx]);
      }
    } else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
      isSigned = cast<IntegerType>(intattr.getType()).isSigned();
      splattrs.push_back(intattr.getValue());
    } else {
      return {};
    }

    // Note that i1 is signless (neither signed nor unsigned), but it must be
    // treated as unsigned: otherwise APInt(1, 1).getSExtValue() returns -1
    // (an all-ones 64-bit value). So only explicitly signed integers are
    // sign-extended here.
    auto &apint = splattrs.back();
    if (apint.getBitWidth() < bitwidth) {
      if (isSigned) {
        apint = apint.sextOrTrunc(bitwidth);
      } else {
        apint = apint.zextOrTrunc(bitwidth);
      }
    }
  }

  return splattrs;
}

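The i1 note above comes down to the difference between sign extension and zero extension. This standalone sketch uses hypothetical helpers `signExtend` and `zeroExtend` that mirror what `APInt::getSExtValue()` and `APInt::getZExtValue()` return for a value of a given width. It shows why a true i1 turns into -1 if it is sign-extended:

```cpp
#include <cassert>
#include <cstdint>

// Sign-extend the low `width` bits of `bits` to a 64-bit signed integer,
// like APInt::getSExtValue() on an APInt of that width.
int64_t signExtend(uint64_t bits, unsigned width) {
  assert(width >= 1 && "width must be at least one bit");
  if (width >= 64)
    return static_cast<int64_t>(bits);
  uint64_t mask = (uint64_t{1} << width) - 1;
  bits &= mask;
  uint64_t signBit = uint64_t{1} << (width - 1);
  // Flipping the sign bit and then subtracting it propagates that bit into
  // every higher bit (two's-complement sign extension).
  return static_cast<int64_t>((bits ^ signBit) - signBit);
}

// Zero-extend: keep the low `width` bits and clear everything above them,
// like APInt::getZExtValue().
uint64_t zeroExtend(uint64_t bits, unsigned width) {
  if (width >= 64)
    return bits;
  return bits & ((uint64_t{1} << width) - 1);
}
```

For an i1 holding `true`, `signExtend(1, 1)` is -1 while `zeroExtend(1, 1)` is 1, which is why the helpers above only sign-extend explicitly signed types.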
using NAryFoldFpOperator = std::function<double(ArrayRef<double>)>;
using NAryFoldIntOperator = std::function<APInt(ArrayRef<APInt>)>;

static OpFoldResult
naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
                 std::optional<NAryFoldFpOperator> fpFolder,
                 std::optional<NAryFoldIntOperator> intFolder) {
  constexpr int64_t kMaxFold = 16;
  for (auto attr : operands) {
    if (!attr)
      return nullptr;
  }

  auto resultTy = dyn_cast<ValueTensorType>(ty);
  if (!resultTy || !resultTy.hasDtype() || !resultTy.areAllSizesKnown())
    return nullptr;

  auto dty = resultTy.getDtype();
  auto resultBTy = resultTy.toBuiltinTensor();

  auto fpTy = dyn_cast<mlir::FloatType>(dty);
  auto intTy = dyn_cast<mlir::IntegerType>(dty);
  if (!fpTy && !intTy)
    return nullptr;

  bool allSplats = checkAllSplats(operands);
  if (!(allSplats || resultBTy.getNumElements() <= kMaxFold))
    return nullptr;

  // We do not support broadcasting in the non-splat case, so validate that
  // the inputs and the output have the same shape:
  if (!allSplats) {
    auto resultShape = resultBTy.getShape();
    for (int i = 0, s = operands.size(); i < s; ++i) {
      if (auto dense = dyn_cast<DenseElementsAttr>(operands[i])) {
        if (dense.isSplat())
          continue;
        auto operandShape = cast<ShapedType>(dense.getType()).getShape();
        if (operandShape.size() != resultShape.size())
          return nullptr;
        for (int j = 0, e = operandShape.size(); j < e; ++j)
          if (operandShape[j] != resultShape[j])
            return nullptr;
      }
    }
  }

  const int64_t numValues = allSplats ? 1 : resultBTy.getNumElements();

  if (fpTy) {
    if (!fpFolder.has_value())
      return nullptr;
    auto folder = fpFolder.value();
    llvm::SmallVector<APFloat> folded;
    for (int i = 0, s = numValues; i < s; ++i) {
      auto inputs = getFoldValueAtIndexFp(operands, i);
      if (inputs.size() != operands.size())
        return nullptr;
      double fold = folder(inputs);

      APFloat val(fold);
      bool unused;
      val.convert(fpTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                  &unused);
      folded.push_back(val);
    }
    return DenseElementsAttr::get(resultBTy, folded);
  }

  if (intTy) {
    if (!intFolder.has_value())
      return nullptr;
    auto folder = intFolder.value();
    llvm::SmallVector<APInt> folded;
    for (int i = 0, s = numValues; i < s; ++i) {
      auto inputs =
          getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i);
      if (inputs.size() != operands.size())
        return nullptr;
      folded.push_back(folder(inputs));
    }
    return DenseElementsAttr::get(resultBTy, folded);
  }

  return nullptr;
}

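`naryFolderHelper` folds elementwise only in two situations. If every operand is a splat, it computes one result value and stores it as a splat. Otherwise the result must have at most `kMaxFold` elements and every non-splat operand must have exactly the result shape, because broadcasting is not handled. Below is a simplified model of that control flow over `std::vector<double>`. The `Operand` struct and `naryFold` function are hypothetical, and the model skips the dtype handling and `APFloat` rounding that the real helper performs:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical model of a constant tensor operand with a static shape.
struct Operand {
  std::vector<int64_t> shape;
  std::vector<double> values; // one value if splat, else product(shape)
  bool isSplat() const { return values.size() == 1; }
};

constexpr int64_t kMaxFold = 16; // same small cap as the C++ above

int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t d : shape)
    n *= d;
  return n;
}

// Fold `fn` elementwise over `operands` into a tensor of `resultShape`.
// Returns a single splat value when all inputs are splats, the full result
// otherwise, or std::nullopt when folding is not attempted.
std::optional<std::vector<double>>
naryFold(const std::vector<Operand> &operands,
         const std::vector<int64_t> &resultShape,
         const std::function<double(const std::vector<double> &)> &fn) {
  if (operands.empty())
    return std::nullopt;
  bool allSplats = true;
  for (const Operand &op : operands)
    allSplats = allSplats && op.isSplat();

  int64_t total = numElements(resultShape);
  if (!allSplats && total > kMaxFold)
    return std::nullopt; // too large to fold element by element

  // No broadcasting: every non-splat operand must match the result exactly.
  if (!allSplats)
    for (const Operand &op : operands)
      if (!op.isSplat() &&
          (op.shape != resultShape ||
           op.values.size() != static_cast<size_t>(total)))
        return std::nullopt;

  int64_t count = allSplats ? 1 : total;
  std::vector<double> out;
  out.reserve(count);
  for (int64_t i = 0; i < count; ++i) {
    std::vector<double> args;
    for (const Operand &op : operands)
      args.push_back(op.isSplat() ? op.values[0] : op.values[i]);
    out.push_back(fn(args));
  }
  return out;
}
```

Capping the non-splat case at a small element count keeps folding from materializing large constants at compile time, while the splat case stays cheap at any size.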
//===----------------------------------------------------------------------===//
// AtenAddTensorOp
//===----------------------------------------------------------------------===//

void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

OpFoldResult AtenAddTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] + (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] + (inputs[1] * inputs[2]);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

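The integer lambda in `AtenAddTensorOp::fold` computes `self + other * alpha` with `APInt`, whose arithmetic wraps modulo 2^bitwidth. The sketch below performs the same computation at a fixed 8-bit width using a hypothetical helper, `addWithAlphaI8`. It works in unsigned arithmetic so that wraparound is well defined rather than signed-overflow undefined behavior:

```cpp
#include <cassert>
#include <cstdint>

// Model of `self + other * alpha` on int8 values with modulo-2^8
// wraparound, the behavior APInt arithmetic has at bit width 8.
int8_t addWithAlphaI8(int8_t self, int8_t other, int8_t alpha) {
  // Reinterpret the operands as raw 8-bit patterns and compute in unsigned
  // space, where wraparound is well defined.
  uint8_t product = static_cast<uint8_t>(static_cast<uint8_t>(other) *
                                         static_cast<uint8_t>(alpha));
  uint8_t sum = static_cast<uint8_t>(static_cast<uint8_t>(self) + product);
  // Reinterpret the 8-bit result as two's-complement signed.
  return static_cast<int8_t>(sum);
}
```

For example, `100 + 100 * 1` folds to -56 at 8 bits, and `-128 + (-1) * 1` wraps to 127.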
//===----------------------------------------------------------------------===//
// AtenAddScalarOp
//===----------------------------------------------------------------------===//

void AtenAddScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenAddScalarOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

OpFoldResult AtenAddScalarOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] + (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    int64_t bits = inputs[0].getBitWidth();
    APInt other(bits, inputs[1].getLimitedValue());
    APInt alpha(bits, inputs[2].getLimitedValue());
    return inputs[0] + (other * alpha);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenSubTensorOp
//===----------------------------------------------------------------------===//

void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenSubTensorOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

OpFoldResult AtenSubTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] - (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] - (inputs[1] * inputs[2]);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenSubScalarOp
//===----------------------------------------------------------------------===//

void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenSubScalarOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

OpFoldResult AtenSubScalarOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] - (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    int64_t bits = inputs[0].getBitWidth();
    APInt other(bits, inputs[1].getLimitedValue());
    APInt alpha(bits, inputs[2].getLimitedValue());
    return inputs[0] - (other * alpha);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenRsubScalarOp
//===----------------------------------------------------------------------===//

void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

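`rewrite0DBinaryTensorOp` scales `rhs` by `alpha` for add and sub. For `aten.rsub.Scalar`, it instead scales `lhs` and then emits `rhs - lhs`. That matches PyTorch's definition `rsub(input, other, alpha) = other - alpha * input`. The scalar reference semantics are written out below as hypothetical helpers; these are not torch-mlir functions:

```cpp
#include <cassert>
#include <cstdint>

// Reference semantics of the 0-d integer rewrites (hypothetical helpers):
//   add:  self + other * alpha
//   sub:  self - other * alpha
//   rsub: other - self * alpha   (alpha scales `self`, not `other`)
int64_t addRef(int64_t self, int64_t other, int64_t alpha) {
  return self + other * alpha;
}
int64_t subRef(int64_t self, int64_t other, int64_t alpha) {
  return self - other * alpha;
}
int64_t rsubRef(int64_t self, int64_t other, int64_t alpha) {
  return other - self * alpha;
}
```

With `self = 2`, `other = 10`, and `alpha = 3`, sub gives -28 but rsub gives 4. This is why the rewrite must multiply the other operand when the op is an rsub.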
//===----------------------------------------------------------------------===//
// AtenMulTensorOp
//===----------------------------------------------------------------------===//

void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenMulTensorOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 2);
    return inputs[0] * inputs[1];
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 2);
    int64_t bits = inputs[0].getBitWidth();
    APInt other(bits, inputs[1].getLimitedValue());
    return inputs[0] * other;
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenEqTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) {
  constexpr int64_t kMaxFold = 16;
  auto ty = dyn_cast<ValueTensorType>(getType());
  if (!ty || !ty.hasDtype() || !ty.hasSizes())
    return nullptr;

  auto bty = ty.toBuiltinTensor();
  if (!bty.hasStaticShape())
    return nullptr;

  // A value always compares equal to itself except for floating-point NaN
  // elements, so only take this shortcut for integer (including boolean)
  // inputs. Floating-point inputs fall through to the constant folding below.
  if (getSelf() == getOther()) {
    auto selfVTy = dyn_cast<ValueTensorType>(getSelf().getType());
    if (selfVTy && selfVTy.hasDtype() &&
        isa<mlir::IntegerType>(selfVTy.getDtype()))
      return DenseElementsAttr::get(bty,
                                    IntegerAttr::get(bty.getElementType(), 1));
  }

  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto other = dyn_cast_or_null<DenseElementsAttr>(adaptor.getOther());
  if (!self || !other)
    return nullptr;

  auto selfTy = dyn_cast<ShapedType>(self.getType());
  auto otherTy = dyn_cast<ShapedType>(other.getType());
  if (!selfTy || !otherTy ||
      selfTy.getElementType() != otherTy.getElementType())
    return nullptr;

  // If both values are splats, we can compute the output value as a splat.
  if (self.isSplat() && other.isSplat()) {
    if (isa<mlir::FloatType>(selfTy.getElementType())) {
      APFloat lhsFp = self.getSplatValue<APFloat>();
      APFloat rhsFp = other.getSplatValue<APFloat>();
      bool eq = lhsFp.compare(rhsFp) == APFloat::cmpEqual;
      return DenseElementsAttr::get(bty, eq);
    }

    if (isa<mlir::IntegerType>(selfTy.getElementType())) {
      APInt lhsInt = self.getSplatValue<APInt>();
      APInt rhsInt = other.getSplatValue<APInt>();
      bool eq = lhsInt == rhsInt;
      return DenseElementsAttr::get(bty, eq);
    }

    return nullptr;
  }

  if (selfTy != otherTy || bty.getNumElements() > kMaxFold)
    return nullptr;

  if (isa<mlir::FloatType>(selfTy.getElementType())) {
    auto extract = [bty](DenseElementsAttr attr) {
      llvm::SmallVector<APFloat> vals;
      if (attr.isSplat()) {
        vals.resize(bty.getNumElements(), attr.getSplatValue<APFloat>());
        return vals;
      }

      for (auto fp : attr.getValues<APFloat>()) {
        vals.push_back(fp);
      }
      return vals;
    };

    llvm::SmallVector<APFloat> lhsFp = extract(self);
    llvm::SmallVector<APFloat> rhsFp = extract(other);
    llvm::SmallVector<bool> vals(bty.getNumElements());
    for (int i = 0, s = bty.getNumElements(); i < s; ++i) {
      vals[i] = lhsFp[i].compare(rhsFp[i]) == APFloat::cmpEqual;
    }

    return DenseElementsAttr::get(bty, vals);
  }

  if (isa<mlir::IntegerType>(selfTy.getElementType())) {
    auto extract = [bty](DenseElementsAttr attr) {
      llvm::SmallVector<APInt> vals;
      if (attr.isSplat()) {
        vals.resize(bty.getNumElements(), attr.getSplatValue<APInt>());
        return vals;
      }

      for (auto v : attr.getValues<APInt>()) {
        vals.push_back(v);
      }
      return vals;
    };

    llvm::SmallVector<APInt> lhsInt = extract(self);
    llvm::SmallVector<APInt> rhsInt = extract(other);
    llvm::SmallVector<bool> vals(bty.getNumElements());
    for (int i = 0, s = bty.getNumElements(); i < s; ++i) {
      vals[i] = lhsInt[i] == rhsInt[i];
    }

    return DenseElementsAttr::get(bty, vals);
  }

  return nullptr;
}

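Comparing `APFloat` values against `APFloat::cmpEqual` follows IEEE-754 rules. A NaN compares unordered, so it is never equal to anything, including itself, and `+0.0` compares equal to `-0.0`. One consequence is that `x == x` is not guaranteed to be all-true for a floating-point tensor that contains NaN. The sketch below applies the same rule to `double` values through a hypothetical helper, `eqElementwise`:

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Elementwise equality with IEEE-754 semantics, the same outcome as testing
// APFloat::compare(...) == APFloat::cmpEqual: NaN is never equal to
// anything (itself included), and +0.0 equals -0.0.
std::vector<bool> eqElementwise(const std::vector<double> &lhs,
                                const std::vector<double> &rhs) {
  assert(lhs.size() == rhs.size() && "operands must have the same size");
  std::vector<bool> out(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i)
    out[i] = lhs[i] == rhs[i];
  return out;
}
```

Note that these results only hold under strict IEEE semantics; compiling with options such as `-ffast-math` allows the compiler to assume NaNs never occur.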
//===----------------------------------------------------------------------===//
// AtenLeScalarOp
//===----------------------------------------------------------------------===//

using ComparisonFoldFpOperator = std::function<bool(double, double)>;
using ComparisonFoldIntOperator = std::function<bool(APInt, APInt, bool)>;

static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
                                          ValueTensorType resultTy,
                                          ComparisonFoldFpOperator fpFolder,
                                          ComparisonFoldIntOperator intFolder) {
  constexpr int64_t kMaxFold = 16;
  if (!lhs || !rhs || !resultTy)
    return nullptr;
  if (!resultTy.areAllSizesKnown() || !resultTy.hasDtype())
    return nullptr;

  auto ctx = lhs.getContext();
  auto tensorETy = cast<RankedTensorType>(lhs.getType()).getElementType();
  if (lhs.isSplat()) {
    if (auto intAttr = dyn_cast<IntegerAttr>(rhs)) {
      auto unsign = cast<IntegerType>(tensorETy).isUnsigned();
      auto scalarAP = intAttr.getValue();
      auto tensorAP = lhs.getSplatValue<IntegerAttr>().getValue();
      tensorAP = APInt(
          scalarAP.getBitWidth(),
          unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign);
      auto resultBool = intFolder(tensorAP, scalarAP, unsign);
      auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
      return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
    }

    if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
      APFloat scalarAP = floatAttr.getValue();
      APFloat tensorAP = lhs.getSplatValue<FloatAttr>().getValue();
      auto resultBool =
          fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
      auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
      return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
    }
    return nullptr;
  }

  int64_t count = 1;
  for (auto size : resultTy.getSizes())
    count *= size;

  if (count > kMaxFold)
    return nullptr;

  if (auto intAttr = dyn_cast<IntegerAttr>(rhs)) {
    auto unsign = cast<IntegerType>(tensorETy).isUnsigned();
    llvm::SmallVector<bool> values;
    for (auto tensorAP : lhs.getValues<APInt>()) {
      auto scalarAP = intAttr.getValue();
      tensorAP = APInt(
          scalarAP.getBitWidth(),
          unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign);
      auto resultBool = intFolder(tensorAP, scalarAP, unsign);
      values.push_back(resultBool);
    }
    return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
  }

  if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
    llvm::SmallVector<bool> values;
    for (auto tensorAP : lhs.getValues<APFloat>()) {
      APFloat scalarAP = floatAttr.getValue();
      auto resultBool =
          fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
      values.push_back(resultBool);
    }
    return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
  }

  return nullptr;
}

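`comparisonScaleFolder` first rebuilds the tensor element at the scalar's bit width. It zero-extends the element when the element type is unsigned and sign-extends it otherwise. Each caller's `intFold` then chooses an unsigned predicate (`ule`, `ult`, ...) or a signed one (`sle`, `slt`, ...). The sketch below covers an 8-bit element compared with a 64-bit scalar under `<=`. The helpers `widenI8` and `lessEqualScalar` are hypothetical and only model the `APInt` calls:

```cpp
#include <cassert>
#include <cstdint>

// Widen an 8-bit element to 64 bits the way the folder rebuilds it at the
// scalar's width: zero-extend unsigned values, sign-extend everything else.
int64_t widenI8(uint8_t bits, bool isUnsigned) {
  return isUnsigned ? static_cast<int64_t>(bits)
                    : static_cast<int64_t>(static_cast<int8_t>(bits));
}

// `element <= scalar` on the widened values, analogous to choosing
// APInt::ule (unsigned) or APInt::sle (signed).
bool lessEqualScalar(uint8_t elementBits, int64_t scalar, bool isUnsigned) {
  int64_t lhs = widenI8(elementBits, isUnsigned);
  if (isUnsigned)
    return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(scalar);
  return lhs <= scalar;
}
```

The bit pattern `0xFF` widens to 255 when unsigned but to -1 when signed. As a result, `0xFF <= 1` is false in the first case and true in the second.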
OpFoldResult AtenLeScalarOp::fold(FoldAdaptor adaptor) {
  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto other = adaptor.getOther();
  auto resultTy = dyn_cast<ValueTensorType>(getType());

  auto fpFold = [](double lhs, double rhs) -> bool { return lhs <= rhs; };

  auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
    int64_t bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
    APInt lhsWiden(bits, lhs.getLimitedValue());
    APInt rhsWiden(bits, rhs.getLimitedValue());
    return unsign ? lhsWiden.ule(rhsWiden) : lhsWiden.sle(rhsWiden);
  };

  return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenLtScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLtScalarOp::fold(FoldAdaptor adaptor) {
  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto other = adaptor.getOther();
  auto resultTy = dyn_cast<ValueTensorType>(getType());

  auto fpFold = [](double lhs, double rhs) -> bool { return lhs < rhs; };

  auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
    int64_t bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
    APInt lhsWiden(bits, lhs.getLimitedValue());
    APInt rhsWiden(bits, rhs.getLimitedValue());
    return unsign ? lhsWiden.ult(rhsWiden) : lhsWiden.slt(rhsWiden);
  };

  return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
}

//===----------------------------------------------------------------------===//
|
|
|
|
// AtenGtScalarOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenGtScalarOp::fold(FoldAdaptor adaptor) {
|
|
|
|
auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
|
|
|
|
auto other = adaptor.getOther();
|
|
|
|
auto resultTy = dyn_cast<ValueTensorType>(getType());
|
|
|
|
|
|
|
|
auto fpFold = [](double lhs, double rhs) -> bool { return lhs > rhs; };
|
|
|
|
|
|
|
|
auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
|
2024-04-03 07:19:57 +08:00
|
|
|
int64_t bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
|
|
|
|
APInt lhsWiden(bits, lhs.getLimitedValue());
|
|
|
|
APInt rhsWiden(bits, rhs.getLimitedValue());
|
|
|
|
return unsign ? lhsWiden.ugt(rhsWiden) : lhsWiden.sgt(rhsWiden);
|
2024-03-09 05:44:00 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
|
|
|
|
}

//===----------------------------------------------------------------------===//
// AtenGeScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenGeScalarOp::fold(FoldAdaptor adaptor) {
  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto other = adaptor.getOther();
  auto resultTy = dyn_cast<ValueTensorType>(getType());

  auto fpFold = [](double lhs, double rhs) -> bool { return lhs >= rhs; };

  auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
    // Extend both operands to a common width according to their signedness.
    unsigned bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
    APInt lhsWiden = unsign ? lhs.zext(bits) : lhs.sext(bits);
    APInt rhsWiden = unsign ? rhs.zext(bits) : rhs.sext(bits);
    return unsign ? lhsWiden.uge(rhsWiden) : lhsWiden.sge(rhsWiden);
  };

  return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenEqScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenEqScalarOp::fold(FoldAdaptor adaptor) {
  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto other = adaptor.getOther();
  auto resultTy = dyn_cast<ValueTensorType>(getType());

  auto fpFold = [](double lhs, double rhs) -> bool { return lhs == rhs; };

  auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
    // Extend both operands to a common width according to their signedness.
    unsigned bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
    APInt lhsWiden = unsign ? lhs.zext(bits) : lhs.sext(bits);
    APInt rhsWiden = unsign ? rhs.zext(bits) : rhs.sext(bits);
    return lhsWiden.eq(rhsWiden);
  };

  return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenNeScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) {
  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto other = adaptor.getOther();
  auto resultTy = dyn_cast<ValueTensorType>(getType());

  auto fpFold = [](double lhs, double rhs) -> bool { return lhs != rhs; };

  auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
    // Extend both operands to a common width according to their signedness.
    unsigned bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
    APInt lhsWiden = unsign ? lhs.zext(bits) : lhs.sext(bits);
    APInt rhsWiden = unsign ? rhs.zext(bits) : rhs.sext(bits);
    return lhsWiden.ne(rhsWiden);
  };

  return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
}
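The `fpFold` lambdas in the six comparison folders use the built-in `double` operators, so the folded results follow IEEE 754. Every ordered comparison involving NaN is false, `==` with NaN is false, and `!=` with NaN is true. Negative and positive zero also compare equal. As far as I know, PyTorch's `lt`/`le`/`gt`/`ge`/`eq`/`ne` behave the same way on NaN, but treat that as an assumption. The sketch below restates the six lambdas as free functions so these edge cases can be checked directly.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Free-function copies of the `fpFold` lambdas used by the comparison folders.
static bool le(double l, double r) { return l <= r; }
static bool lt(double l, double r) { return l < r; }
static bool gt(double l, double r) { return l > r; }
static bool ge(double l, double r) { return l >= r; }
static bool eq(double l, double r) { return l == r; }
static bool ne(double l, double r) { return l != r; }
```

A folder built from these lambdas therefore produces `false` for NaN in every op except `ne`, which produces `true`.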

//===----------------------------------------------------------------------===//
// AtenLogOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto resultType = dyn_cast<ValueTensorType>(getType());
  if (!self || !resultType)
    return nullptr;

  auto fpFold = [](llvm::ArrayRef<double> inputs) -> double {
    assert(inputs.size() == 1);
    return std::log(inputs[0]);
  };

  return naryFolderHelper(adaptor.getOperands(), resultType, fpFold,
                          std::nullopt);
}
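`AtenLogOp::fold` passes a one-input `fpFold` to `naryFolderHelper` and `std::nullopt` for the integer callback, so only floating-point constants are folded. The sketch below is a simplified, hypothetical stand-in for that helper. `foldUnaryFp` is not a torch-mlir function, and it ignores splats, shapes, and dtypes. It applies the same callback element by element and shows the `std::log` edge cases that would be baked into the folded constant: `log(0)` is negative infinity and `log(-1)` is NaN.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

// Hypothetical stand-in for an n-ary folder specialized to one operand:
// apply `fn` to each element of a constant tensor's flattened values.
static std::vector<double>
foldUnaryFp(const std::vector<double> &values,
            const std::function<double(const std::vector<double> &)> &fn) {
  std::vector<double> result;
  result.reserve(values.size());
  for (double v : values)
    result.push_back(fn({v})); // the callback sees a one-element input list
  return result;
}

// Mirrors the `fpFold` lambda in AtenLogOp::fold.
static double logFold(const std::vector<double> &inputs) {
  assert(inputs.size() == 1);
  return std::log(inputs[0]);
}
```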

//===----------------------------------------------------------------------===//
// AtenFloorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
  auto resultType = dyn_cast<ValueTensorType>(getType());
  if (resultType && resultType.hasDtype() &&
      isa<mlir::IntegerType>(resultType.getDtype())) {
    return getSelf();
  }
  return {};
}

//===----------------------------------------------------------------------===//
// AtenCeilOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
  auto resultType = dyn_cast<ValueTensorType>(getType());
  if (resultType && resultType.hasDtype() &&
      isa<mlir::IntegerType>(resultType.getDtype())) {
    return getSelf();
  }
  return {};
}

//===----------------------------------------------------------------------===//
// AtenRoundOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
  auto resultType = dyn_cast<ValueTensorType>(getType());
  if (resultType && resultType.hasDtype() &&
      isa<mlir::IntegerType>(resultType.getDtype())) {
    return getSelf();
  }
  return {};
}

//===----------------------------------------------------------------------===//
// AtenTruncOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
  auto resultType = dyn_cast<ValueTensorType>(getType());
  if (resultType && resultType.hasDtype() &&
      isa<mlir::IntegerType>(resultType.getDtype())) {
    return getSelf();
  }
  return {};
}

//===----------------------------------------------------------------------===//
// AtenSignOp
//===----------------------------------------------------------------------===//

void AtenSignOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add(+[](AtenSignOp op, PatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<AtenSgnOp>(op, op.getType(), op.getSelf());
    return success();
  });
}
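`AtenSignOp` is canonicalized to `AtenSgnOp`. For real inputs both ops compute the same function (`-1`, `0`, or `+1`). `sgn` additionally accepts complex inputs, where it is defined as `z / |z|` (and `0` for `z == 0`). The plain C++ sketch below makes that equivalence on the real axis concrete. The function names are hypothetical, NaN handling is not modeled, and nothing here describes how torch-mlir lowers `AtenSgnOp`.

```cpp
#include <cassert>
#include <complex>

// Real sign: -1, 0, or +1.
static double signReal(double x) {
  return static_cast<double>((0.0 < x) - (x < 0.0));
}

// Complex sgn: z / |z|, with sgn(0) defined as 0.
static std::complex<double> sgnComplex(std::complex<double> z) {
  double mag = std::abs(z);
  return mag == 0.0 ? std::complex<double>(0.0, 0.0) : z / mag;
}
```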

//===----------------------------------------------------------------------===//
// AtenMulScalarOp
//===----------------------------------------------------------------------===//

void AtenMulScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenMulScalarOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

OpFoldResult AtenMulScalarOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 2);
    return inputs[0] * inputs[1];
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 2);
    return inputs[0] * inputs[1];
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}
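`AtenMulScalarOp::fold` supplies both callbacks. The integer callback multiplies `APInt`s, and `APInt::operator*` (which requires both operands to have the same width) wraps modulo 2^width like fixed-width machine arithmetic. The sketch below imitates that with 8-bit values to show what a folded product of `si8` constants would hold. It assumes the operands have already been brought to that width, and it does not reproduce `naryFolderHelper`'s splat or shape handling. `wrappingMulI8` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>

// Multiply two 8-bit two's-complement values the way a fixed-width integer
// multiply does: compute the exact product, then keep only the low 8 bits.
static int8_t wrappingMulI8(int8_t a, int8_t b) {
  int product = static_cast<int>(a) * static_cast<int>(b); // exact in int
  uint8_t low = static_cast<uint8_t>(product); // reduction modulo 256
  return static_cast<int8_t>(low);             // reinterpret as signed
}
```

For example, `100 * 3 = 300` folds to `44` in 8 bits, and `-128 * -1` folds back to `-128`.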

//===----------------------------------------------------------------------===//
// AtenDivTensorModeOp
//===----------------------------------------------------------------------===//

void AtenDivTensorModeOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenDivTensorModeOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// AtenDivScalarModeOp
//===----------------------------------------------------------------------===//

void AtenDivScalarModeOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenDivScalarModeOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  });
}
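`AtenDivTensorModeOp` and `AtenDivScalarModeOp` carry a rounding mode. In PyTorch this is `None`, `"trunc"`, or `"floor"`. Any scalar rewrite of a 0-d instance has to preserve that mode. The modes only disagree when the operands have opposite signs and the division is inexact. The plain C++ sketch below shows the three behaviours for integer operands. The function names are hypothetical, division by zero is not handled, and how `rewrite0DBinaryTensorOp` encodes the mode is not shown here.

```cpp
#include <cassert>
#include <cstdint>

// rounding_mode="trunc": round the quotient toward zero (C++ `/` on ints).
static int64_t divTrunc(int64_t a, int64_t b) { return a / b; }

// rounding_mode="floor": round the quotient toward negative infinity.
static int64_t divFloor(int64_t a, int64_t b) {
  int64_t q = a / b;
  // Truncation rounded toward zero, i.e. up, when the exact quotient is
  // negative and inexact, so step down by one in that case.
  if ((a % b != 0) && ((a < 0) != (b < 0)))
    --q;
  return q;
}

// rounding_mode=None: true division, computed in floating point.
static double divTrue(int64_t a, int64_t b) {
  return static_cast<double>(a) / static_cast<double>(b);
}
```

For `-7 / 2`, the three modes give `-3`, `-4`, and `-3.5` respectively.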

//===----------------------------------------------------------------------===//
// AtenNumelOp
//===----------------------------------------------------------------------===//

void AtenNumelOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add(+[](AtenNumelOp op, PatternRewriter &rewriter) {
    auto inputType = dyn_cast<BaseTensorType>(op.getSelf().getType());
    if (!inputType || !inputType.areAllSizesKnown()) {
      return failure();
    }
    auto sizes = inputType.getSizes();
    int64_t numel = 1;
    for (int64_t d : sizes) {
      numel *= d;
    }
    rewriter.replaceOpWithNewOp<ConstantIntOp>(
        op, rewriter.getI64IntegerAttr(numel));
    return success();
  });
}
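The `AtenNumelOp` pattern fires only when every size is static. It then replaces the op with the product of the sizes. The standalone sketch below does the same over a plain `std::vector<int64_t>`. `kUnknown = -1` is an assumed sentinel standing in for torch-mlir's unknown-size marker, and `staticNumel` is a hypothetical name. Because the empty product is 1, a rank-0 tensor folds to `1`, and any zero-sized dimension folds the count to `0`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Assumed sentinel for a dynamic dimension (hypothetical name).
constexpr int64_t kUnknown = -1;

// Returns the element count, or std::nullopt if any size is dynamic.
static std::optional<int64_t> staticNumel(const std::vector<int64_t> &sizes) {
  int64_t numel = 1; // empty product: a rank-0 tensor holds one element
  for (int64_t d : sizes) {
    if (d == kUnknown)
      return std::nullopt;
    numel *= d;
  }
  return numel;
}
```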

//===----------------------------------------------------------------------===//
// Aten__Or__TensorOp
//===----------------------------------------------------------------------===//

void Aten__Or__TensorOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](Aten__Or__TensorOp op, PatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<AtenBitwiseOrTensorOp>(
        op, op.getType(), op.getSelf(), op.getOther());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// Aten__And__ScalarOp
//===----------------------------------------------------------------------===//

void Aten__And__ScalarOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](Aten__And__ScalarOp op, PatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<AtenBitwiseAndScalarOp>(
        op, op.getType(), op.getSelf(), op.getOther());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenScalarImplicitOp
//===----------------------------------------------------------------------===//

void AtenScalarImplicitOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) {
    Location loc = op.getLoc();
    Value a = op.getA();
    auto outType = op.getResult().getType();
    Value scalarIntValue = getScalarIntValue(a, loc, rewriter);
    if (scalarIntValue) {
      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
                                                     scalarIntValue);
      return success();
    }
    Value scalarFloatValue = getScalarFloatValue(a, loc, rewriter);
    if (scalarFloatValue) {
      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
                                                     scalarFloatValue);
      return success();
    }
    return failure();
  });
}

//===----------------------------------------------------------------------===//
// AtenFloatImplicitOp
//===----------------------------------------------------------------------===//

void AtenFloatImplicitOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenFloatImplicitOp op, PatternRewriter &rewriter) {
    Location loc = op.getLoc();
    Value a = op.getA();
    Value scalarValue = getScalarFloatValue(a, loc, rewriter);
    if (!scalarValue)
      return failure();
    rewriter.replaceOp(op, scalarValue);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenIntImplicitOp
//===----------------------------------------------------------------------===//

void AtenIntImplicitOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) {
    Location loc = op.getLoc();
    Value a = op.getA();
    Value scalarValue = getScalarIntValue(a, loc, rewriter);
    if (!scalarValue)
      return failure();
    rewriter.replaceOp(op, scalarValue);
    return success();
  });
}
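The three implicit-conversion patterns try to recover the scalar stored in a constant 0-d tensor. `AtenIntImplicitOp` and `AtenFloatImplicitOp` each accept one kind of scalar. `AtenScalarImplicitOp` tries the integer path first, falls back to the float path, and wraps whichever succeeds in `Torch::DerefineOp`, presumably because its result type is the broader `Scalar` type. The sketch below models only that ordering with `std::variant`. `Scalar0D`, `tryGetInt`, `tryGetFloat`, and `scalarImplicit` are hypothetical and do not reproduce `getScalarIntValue` or `getScalarFloatValue`. `std::monostate` stands for a tensor whose value is not a known constant.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <variant>

// Hypothetical model of a 0-d tensor: unknown, an int constant, or a float.
using Scalar0D = std::variant<std::monostate, int64_t, double>;

static std::optional<int64_t> tryGetInt(const Scalar0D &t) {
  if (const auto *v = std::get_if<int64_t>(&t))
    return *v;
  return std::nullopt;
}

static std::optional<double> tryGetFloat(const Scalar0D &t) {
  if (const auto *v = std::get_if<double>(&t))
    return *v;
  return std::nullopt;
}

// Mirrors the ScalarImplicit ordering: int first, then float. The result is a
// generic "number", so either alternative is an acceptable answer.
static std::optional<Scalar0D> scalarImplicit(const Scalar0D &t) {
  if (auto i = tryGetInt(t))
    return Scalar0D{*i};
  if (auto f = tryGetFloat(t))
    return Scalar0D{*f};
  return std::nullopt; // corresponds to the pattern returning failure()
}
```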

//===----------------------------------------------------------------------===//
// AtenSizeOp
//===----------------------------------------------------------------------===//

// Traces at most 6 parents of `value` to find a tensor type whose queried
// dimension size is known, or returns failure if no such type is found. If
// `dim` is `std::nullopt`, all dimension sizes must be known.
static FailureOr<BaseTensorType>
traceKnownSizeTensorType(Value value, std::optional<int64_t> dim) {
  // Checks whether `tensorType` contains the queried size information.
  auto foundType = [](BaseTensorType tensorType, std::optional<int64_t> dim) {
    if (!tensorType.hasSizes())
      return false;

    if (dim == std::nullopt)
      return tensorType.areAllSizesKnown();

    // If the dimension value is negative, then convert it to a positive value.
    ArrayRef<int64_t> sizes = tensorType.getSizes();
    *dim = toPositiveDim(*dim, sizes.size());
    return isValidDim(*dim, sizes.size()) && sizes[*dim] != kUnknownSize;
  };

  // Limit the loop count to 6 to avoid indefinite compilation times from
  // unbounded IR traversals.
  for (auto idx = 0; idx < 6; ++idx) {
    if (!value || !isa<BaseTensorType>(value.getType()))
      return failure();

    auto tensorType = cast<BaseTensorType>(value.getType());
    if (foundType(tensorType, dim))
      return tensorType;

    auto op = value.getDefiningOp();
    if (!op || !isa<CopyToValueTensorOp, CopyToNonValueTensorOp,
                    TensorStaticInfoCastOp>(op))
      return failure();

    // In all ops of interest to us, the source tensor is operand #0.
    value = op->getOperand(0);
  }

  return failure();
}
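`traceKnownSizeTensorType` walks back through at most six values. It only follows ops that forward operand #0 (`CopyToValueTensorOp`, `CopyToNonValueTensorOp`, `TensorStaticInfoCastOp`) and stops at the first type whose queried size is static. The standalone sketch below reproduces that control flow over a hypothetical `Node` struct. `toPositiveDim` here is a simplified re-implementation that adds the rank to a negative index, and `-1` is assumed as the unknown-size sentinel.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kUnknownSize = -1; // assumed sentinel

// Hypothetical stand-in for an SSA value: its static sizes (if the rank is
// known) and the value it was forwarded from through a pass-through cast.
struct Node {
  std::optional<std::vector<int64_t>> sizes; // nullopt: rank unknown
  const Node *forwardedFrom = nullptr;
};

static int64_t toPositiveDim(int64_t dim, int64_t rank) {
  return dim >= 0 ? dim : dim + rank;
}

static bool hasQueriedSize(const Node &n, std::optional<int64_t> dim) {
  if (!n.sizes)
    return false;
  const std::vector<int64_t> &sizes = *n.sizes;
  if (!dim) { // all sizes must be known
    for (int64_t s : sizes)
      if (s == kUnknownSize)
        return false;
    return true;
  }
  int64_t rank = static_cast<int64_t>(sizes.size());
  int64_t d = toPositiveDim(*dim, rank);
  return d >= 0 && d < rank && sizes[d] != kUnknownSize;
}

// Returns the first node, visiting at most `maxSteps` nodes, that has the
// queried information, or nullptr.
static const Node *traceKnownSize(const Node *value, std::optional<int64_t> dim,
                                  int maxSteps = 6) {
  for (int i = 0; i < maxSteps && value; ++i) {
    if (hasQueriedSize(*value, dim))
      return value;
    value = value->forwardedFrom;
  }
  return nullptr;
}
```

Querying dimension `-1` on a chain whose middle node is `[?, 3]` stops at that node. Querying all sizes, or dimension `0`, has to walk up to the fully static source.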

void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
    auto type = traceKnownSizeTensorType(op.getOperand(), std::nullopt);
    if (failed(type))
      return rewriter.notifyMatchFailure(op, "all sizes not known");
|
Significantly restructure torch/aten import design.
This is a really major and invasive restructuring of the way we get
torch operators (`torch::jit::Operator` / `c10::OperatorHandle`) into
MLIR. Please forgive the challenging review, but due to the sheer
invasiveness, it wasn't really practical do do it in sane smaller
pieces.
This fully replaces everything that was already working on the
TorchScript path (actually, more -- we added tanh support to
TorchToLinalg in order to delete the older code paths). Additionally,
I've kept the lights on for the acap path too, including what little e2e
stuff was working before (for expediency I made a few tiny compromises
along the way that will be easy to undo when we give that path proper
attention).
Overview of the new design:
- The torch operator `somens::someunqualname.someoverloadname` is
imported as `torch.somens.someunqualname.someoverloadname` (skip the
last dotted part if the overload name is empty), OR, if we don't have
such an op registered, it is imported as
`torch.operator "somens.someunqualname.someoverloadname" (...) : ...`.
- The addition of the "overload name" is a critical element here, as
the `(ns,unqual,overload)` triple is unique, which solves a lot of
problems we were having.
- This involves having separate MLIR ops for the `trailing_` and
`.out` variants and all the different overloads. This seemed
necessary, because the set of overloads is so wild and varied and
unstructured. The previous design was leaning into some underlying
structure that just isn't there -- the default situation is
the "random overload that we want to manage on the MLIR side",
rather than that being an exception. E.g. `aten::ne` (not-equal)
has 21 overloads, only 4 of which are c10 dispatcher ops see
[gist](https://gist.github.com/silvasean/190ba918c550c956260e21254e1b8aa1),
and the "out" variant is really called `.Tensor_out` instead of
`.out` as it frequently is for other ops.
- Rationale for all being in `torch` namespace: the set of operators
are so varied and unstructured that "dialect per namespace"
doesn't result in anything resembling the typical MLIR dialect
boundary expectations. We could maybe draw the boundary at
dispatcher ops vs non-dispatcher ops, but that doesn't seem to
really result in very much useful structure at this point in time.
- Note: within the torch operator registry, we effectively have a
mini-basicpy subdialect (already type-resolved), which is reasonably
structured.
- The existing Torch op interfaces are also removed -- now that we
track the overload name, we can losslessly find the original
operator.
- Instead of `ATenRecognizeKernelsPass`, we now have a
`ReduceOpVariantsPass` that keys off certain traits (and perhaps
eventually interfaces) to reduce variants of ops to a smaller set,
ideally operating on immutable tensors and using surrounding ops to
model the mutability/aliasing aspects.
- Note: `torch.ns.unqual.overload` ops allow both immutable and
mutable tensors (unlike the previous hard distinction in the common
case). This is a precursor to a future change that will introduce a
bona fide `!torch.tensor` type that will clean up a bunch of stuff.
- `TorchToLinalg` / `TorchToStd` supersede the existing
"ATen->TCF->TCP->Linalg" path.
- The new `torch_ods_gen.py` supersedes `torch_signature_ods_gen.py`.
It should look somewhat familiar, but the benefit of hindsight has
allowed a lot of simplifications.
The overall trend seems to be to make the `torch` dialect a nice layer
independent of anything else. It feels like, as a natural result of
various future changes, we will be removing the reliance on basicpy+numpy
dialects and have a nice self-contained type system too that properly
models the TorchScript type system (including proper subtyping,
mutable/immutable tensors, optional dtype, etc.).
Recommended review order:
- Start at some of the new import IR, e.g. in
`frontends/pytorch/test/node_import/prim.py`,
`frontends/pytorch/test/acap_export/test_export_add3.py`, and other
tests.
- `frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py`
and associated generated files:
- `include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td`
- `include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td`
- Inspect `ReduceOpVariants.cpp` / `reduce-op-variants.mlir` and the new
traits in `include/npcomp/Dialect/Torch/IR/TorchTraits.h`
- Various code changes in the import path in
`frontends/pytorch/csrc/builder`. Probably most interesting is the new
code in `torch_to_mlir_utils.cpp` that has the logic to create the
`torch.operator` ops or `torch.ns.unqual.overload` ops.
This is the [new ResNet IR](https://gist.github.com/silvasean/5407aafb710d07612b7b5b92eabecebe),
just to be able to look at a substantial sample of IR in the new style.
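The naming scheme in the overview above can be illustrated with a small string-building sketch. `importedOpSpelling` is a hypothetical helper written for this illustration, not the importer's code (the message says that logic lives in `torch_to_mlir_utils.cpp`). It also assumes that the generic `torch.operator` string follows the same empty-overload rule as the registered form.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: spells the MLIR op produced for the torch operator
// `ns::unqual.overload`, following the scheme described above.
std::string importedOpSpelling(const std::string &ns, const std::string &unqual,
                               const std::string &overload, bool isRegistered) {
  // The (ns, unqual, overload) triple is unique, so it names the op losslessly.
  std::string qualified = ns + "." + unqual;
  // Skip the last dotted part when the overload name is empty.
  if (!overload.empty())
    qualified += "." + overload;
  if (isRegistered)
    return "torch." + qualified;
  // Assumption: unregistered operators keep the same qualified name, passed
  // as a string to the generic `torch.operator` op.
  return "torch.operator \"" + qualified + "\"";
}
```

For example, `aten::ne.Tensor_out` would map to `torch.aten.ne.Tensor_out`, so each of the many `aten::ne` overloads remains a distinct op.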
2021-05-05 05:42:50 +08:00
    SmallVector<Value> listElements;
2022-07-13 03:38:37 +08:00
    for (int64_t size : type->getSizes()) {
2021-06-16 03:42:51 +08:00
      listElements.push_back(rewriter.create<Torch::ConstantIntOp>(
2021-05-05 05:42:50 +08:00
          op->getLoc(), rewriter.getI64IntegerAttr(size)));
    }
2021-06-05 06:57:21 +08:00
    rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
2021-06-17 06:53:15 +08:00
        op, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
        listElements);
2021-05-05 05:42:50 +08:00
    return success();
  });
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tensor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
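The `!torch.tensor` / `!torch.vtensor` syntax examples above follow a regular pattern that a few lines of code can reproduce. The formatter below is hypothetical and only illustrates the textual form; it is not the dialect's printer. Its encoding choices are assumptions of this sketch: an empty optional means unranked (`*`), -1 means an unknown dimension (`?`), and an empty dtype means `unk`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical formatter (illustration only) for the textual tensor types.
std::string
formatTorchTensorType(bool valueSemantic,
                      const std::optional<std::vector<int64_t>> &sizes,
                      const std::optional<std::string> &dtype) {
  std::string result = valueSemantic ? "!torch.vtensor" : "!torch.tensor";
  // The least-static-information variant prints in its short form.
  if (!sizes && !dtype)
    return result;
  result += "<";
  if (!sizes) {
    result += "*"; // Unranked.
  } else {
    result += "[";
    for (std::size_t i = 0; i < sizes->size(); ++i) {
      if (i != 0)
        result += ",";
      int64_t size = (*sizes)[i];
      // -1 stands for an unknown dimension, printed as `?`.
      result += size == -1 ? std::string("?") : std::to_string(size);
    }
    result += "]";
  }
  result += ",";
  result += dtype ? *dtype : std::string("unk");
  result += ">";
  return result;
}
```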
2021-05-21 08:07:18 +08:00
  // One-off pattern to erase if dead.
  // TODO: Use the effects infra to express the semantics of this op and enable
  // a centralized "erase if dead" canonicalization.
  // Specifically, we need to mark the op as only MemoryEffects::Allocate
  // so that `mlir::wouldOpBeTriviallyDead` does the right thing.
  patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
    if (!op.use_empty())
      return failure();
    rewriter.eraseOp(op);
    // The IR was modified, so the pattern must report success.
    return success();
  });
}

2024-09-04 07:38:20 +08:00
//===----------------------------------------------------------------------===//
// AtenUnflattenIntOp
//===----------------------------------------------------------------------===//

void AtenUnflattenIntOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // If there are only two sizes and one of them is statically 1, then convert
  // to an unsqueeze.
  patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) {
    SmallVector<Value> sizeValues;
    if (!getListConstructElements(op.getSizes(), sizeValues))
      return rewriter.notifyMatchFailure(op,
                                         "sizes must come from list construct");
    if (sizeValues.size() != 2)
      return failure();
    int64_t dim0, dim1;
    bool dim0Constant = matchPattern(sizeValues[0], m_TorchConstantInt(&dim0));
    bool dim1Constant = matchPattern(sizeValues[1], m_TorchConstantInt(&dim1));
    // dim0/dim1 are only initialized when the corresponding match succeeded,
    // so every read is guarded by the match result.
    bool dim0IsOne = dim0Constant && dim0 == 1;
    bool dim1IsOne = dim1Constant && dim1 == 1;
    if (!dim0IsOne && !dim1IsOne)
      return failure();
    Value self = op.getSelf();
    // Normalize a negative `dim` against the input rank: passed through
    // unchanged, the unsqueeze would interpret it against the (rank + 1)-d
    // result and insert the unit dimension at the wrong position.
    int64_t dimInt;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
      return rewriter.notifyMatchFailure(op, "dim must be a constant int");
    Value unflattenDim = op.getDim();
    if (dimInt < 0) {
      auto selfType = dyn_cast<BaseTensorType>(self.getType());
      if (!selfType || !selfType.hasSizes())
        return rewriter.notifyMatchFailure(op,
                                           "negative dim requires known rank");
      dimInt = toPositiveDim(dimInt, selfType.getSizes().size());
      unflattenDim = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), dimInt);
    }
    Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
    // The runtime asserts below are introduced to catch malformed unflatten ops
    // possibly generated from ONNX IR.
    Value unsqueeze;
    if (dim0IsOne) {
      // Unsqueeze at dim.
      FailureOr<Value> maybeUnsqueeze =
          Torch::unsqueezeTensor(rewriter, op, self, unflattenDim);
      if (failed(maybeUnsqueeze))
        return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
      unsqueeze = maybeUnsqueeze.value();
      // Check that the remaining size value is either -1 or equal to the
      // original size at dim.
      Value selfSizeAtDim =
          rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
      Value isSameSize = rewriter.create<AtenEqIntOp>(
          op.getLoc(), selfSizeAtDim, sizeValues[1]);
      Value isMinusOne =
          rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[1]);
      Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
          op.getLoc(), isMinusOne, isSameSize);
      rewriter.create<Torch::RuntimeAssertOp>(
          op.getLoc(), isMOneOrSameSize,
          rewriter.getStringAttr("unflatten sizes must be compatible"));
    }
    if (dim1IsOne) {
      // Unsqueeze at dim + 1.
      Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
      Value dimPlusOne =
          rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
      FailureOr<Value> maybeUnsqueeze =
          Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
      if (failed(maybeUnsqueeze))
        return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
      unsqueeze = maybeUnsqueeze.value();
      // Check that the remaining size value is either -1 or equal to the
      // original size at dim.
      Value selfSizeAtDim =
          rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
      Value isSameSize = rewriter.create<AtenEqIntOp>(
          op.getLoc(), selfSizeAtDim, sizeValues[0]);
      Value isMinusOne =
          rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[0]);
      Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
          op.getLoc(), isMinusOne, isSameSize);
      rewriter.create<Torch::RuntimeAssertOp>(
          op.getLoc(), isMOneOrSameSize,
          rewriter.getStringAttr("unflatten sizes must be compatible"));
    }
    rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(op, op.getType(),
                                                               unsqueeze);
    return success();
  });
}
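The shape reasoning behind this canonicalization can be checked without any IR. The hypothetical `unflattenAsUnsqueeze` helper below models shapes as `std::vector<int64_t>`, using -1 for an inferred size. It mirrors which unsqueeze position is chosen and what the runtime assert enforces. As an assumption of the sketch, a negative `dim` is normalized against the input rank, matching PyTorch's `unflatten` semantics.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper modelling the rewrite on plain shapes. `unflatten(dim,
// {a, b})` splits input dimension `dim` into sizes `a` and `b`. Returns the
// unsqueezed shape, or std::nullopt when the pattern would not apply or the
// runtime assert would fail.
std::optional<std::vector<int64_t>>
unflattenAsUnsqueeze(const std::vector<int64_t> &shape, int64_t dim, int64_t a,
                     int64_t b) {
  int64_t rank = static_cast<int64_t>(shape.size());
  if (dim < 0)
    dim += rank; // Normalize against the input rank, not the output rank.
  if (dim < 0 || dim >= rank)
    return std::nullopt;
  int64_t insertAt;
  int64_t other;
  if (a == 1) {
    insertAt = dim; // [1, b]: the unit dim goes before the original one.
    other = b;
  } else if (b == 1) {
    insertAt = dim + 1; // [a, 1]: the unit dim goes after the original one.
    other = a;
  } else {
    return std::nullopt; // Neither size is statically 1.
  }
  // The remaining size must be -1 or match the original size at `dim`.
  if (other != -1 && other != shape[dim])
    return std::nullopt;
  std::vector<int64_t> result = shape;
  result.insert(result.begin() + insertAt, 1);
  return result;
}
```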

2024-02-10 06:02:54 +08:00
//===----------------------------------------------------------------------===//
// AtenSelectIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) {
  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto ty = dyn_cast<ValueTensorType>(getType());
  if (!self || !ty || !ty.hasDtype() || !ty.hasSizes())
    return nullptr;

  auto selfTy = cast<ShapedType>(self.getType());
2024-06-08 09:36:32 +08:00
  auto bty = ty.toBuiltinTensor();
2024-02-10 06:02:54 +08:00
  if (!bty.hasStaticShape())
    return nullptr;

  if (self.isSplat())
    return DenseElementsAttr::get(bty, self.getSplatValue<Attribute>());

  auto dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
  auto indexAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
  if (!dimAttr || !indexAttr || bty.getNumElements() != 1)
    return nullptr;

  // Normalize negative `dim` and `index` (PyTorch semantics) and bail out on
  // out-of-range values before indexing into the constant.
  int64_t rank = selfTy.getRank();
  int64_t dim = toPositiveDim(dimAttr.getInt(), rank);
  if (!isValidDim(dim, rank))
    return nullptr;
  int64_t dimSize = selfTy.getDimSize(dim);
  int64_t index = indexAttr.getInt();
  if (index < 0)
    index += dimSize;
  if (index < 0 || index >= dimSize)
    return nullptr;

  // With every other dimension of size 1, the flat offset equals `index`.
  for (int64_t i = 0; i < rank; ++i) {
    if (i != dim && selfTy.getDimSize(i) != 1)
      return nullptr;
  }

  auto splattr = self.getValues<Attribute>()[index];
  return DenseElementsAttr::get(bty, splattr);
}
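The fold requires every other dimension to be 1 because only then does the element at `index` along `dim` sit at flat offset `index` in the row-major constant. `selectSingleElement` is a hypothetical helper over a `std::vector<double>` that sketches this condition, including the negative-index normalization PyTorch applies to `select`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper: returns the element selected by `select(dim, index)`
// when it can be read directly from the flattened data, else std::nullopt.
std::optional<double> selectSingleElement(const std::vector<double> &data,
                                          const std::vector<int64_t> &shape,
                                          int64_t dim, int64_t index) {
  int64_t rank = static_cast<int64_t>(shape.size());
  if (dim < 0)
    dim += rank;
  if (dim < 0 || dim >= rank)
    return std::nullopt;
  if (index < 0)
    index += shape[dim];
  if (index < 0 || index >= shape[dim])
    return std::nullopt;
  for (int64_t i = 0; i < rank; ++i)
    if (i != dim && shape[i] != 1)
      return std::nullopt; // The flat offset would no longer equal `index`.
  return data[index];
}
```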

2021-10-16 06:23:59 +08:00
//===----------------------------------------------------------------------===//
// AtenSizeIntOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) {
2022-07-13 03:38:37 +08:00
  int64_t dim;
2022-12-08 04:20:41 +08:00
  if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim)))
2021-10-16 06:23:59 +08:00
    return nullptr;
2022-12-08 04:20:41 +08:00
  auto type = traceKnownSizeTensorType(this->getSelf(), dim);
2022-07-13 03:38:37 +08:00
  if (failed(type))
2021-10-16 06:23:59 +08:00
    return nullptr;
2022-07-13 03:38:37 +08:00
  ArrayRef<int64_t> sizes = type->getSizes();
  dim = toPositiveDim(dim, sizes.size());
2023-04-07 19:49:35 +08:00
  if (!isValidDim(dim, sizes.size()))
    return nullptr;
2022-07-13 03:38:37 +08:00
  return IntegerAttr::get(IntegerType::get(getContext(), 64), sizes[dim]);
2021-10-16 06:23:59 +08:00
}
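The fold above normalizes a possibly negative `dim` before indexing into the known sizes. The sketch below replays that logic on a plain `std::vector<int64_t>`. `toPositiveDimSketch` and `isValidDimSketch` are local stand-ins, assumed (not verified) to behave like the torch-mlir helpers the fold calls. Here -1 marks an unknown size, whereas the real fold relies on `traceKnownSizeTensorType` instead.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Local stand-ins for the dim helpers (assumed semantics).
static int64_t toPositiveDimSketch(int64_t dim, int64_t rank) {
  return dim >= 0 ? dim : dim + rank;
}
static bool isValidDimSketch(int64_t dim, int64_t rank) {
  return dim >= 0 && dim < rank;
}

// Folds `size(dim)` to a constant when the size at `dim` is known.
std::optional<int64_t> foldSizeInt(const std::vector<int64_t> &sizes,
                                   int64_t dim) {
  int64_t rank = static_cast<int64_t>(sizes.size());
  dim = toPositiveDimSketch(dim, rank);
  if (!isValidDimSketch(dim, rank) || sizes[dim] == -1)
    return std::nullopt;
  return sizes[dim];
}
```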

2021-06-17 02:05:08 +08:00
//===----------------------------------------------------------------------===//
// Comparison fold helpers
//===----------------------------------------------------------------------===//

static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) {
  return IntegerAttr::get(IntegerType::get(context, 1),
                          static_cast<int64_t>(value));
}

2022-02-11 05:25:25 +08:00
using ConstantFloatComparator = std::function<bool(double, double)>;
2021-08-11 09:28:50 +08:00
template <typename OpTy>
2022-02-11 05:25:25 +08:00
static OpFoldResult
floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
2021-08-11 09:28:50 +08:00
  if (op.getOperand(0) == op.getOperand(1))
    return getI1IntegerAttr(op.getContext(), comparator(0, 0));

2022-02-11 05:25:25 +08:00
  double lhs, rhs;
  if (!matchPattern(op.getOperand(0), m_TorchConstantFloat(&lhs)) ||
      !matchPattern(op.getOperand(1), m_TorchConstantFloat(&rhs)))
2021-08-11 09:28:50 +08:00
    return nullptr;

  return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs));
2021-06-17 02:05:08 +08:00
}

//===----------------------------------------------------------------------===//
2022-02-11 05:25:25 +08:00
// AtenLtFloatOp
2021-06-17 02:05:08 +08:00
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) {
2022-02-11 05:25:25 +08:00
  return floatComparatorFoldHelper(*this,
                                   [](double a, double b) { return a < b; });
2021-08-11 09:28:50 +08:00
}

//===----------------------------------------------------------------------===//
2022-02-11 05:25:25 +08:00
// AtenGtFloatOp
2021-08-11 09:28:50 +08:00
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) {
2022-02-11 05:25:25 +08:00
  return floatComparatorFoldHelper(*this,
                                   [](double a, double b) { return a > b; });
2021-08-11 09:28:50 +08:00
}

2022-04-25 21:12:45 +08:00
//===----------------------------------------------------------------------===//
// AtenGeFloatOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) {
2022-04-25 21:12:45 +08:00
  return floatComparatorFoldHelper(*this,
                                   [](double a, double b) { return a >= b; });
}

2022-01-11 15:42:53 +08:00
//===----------------------------------------------------------------------===//
// AtenEqFloatOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) {
2022-02-11 05:25:25 +08:00
  return floatComparatorFoldHelper(*this,
                                   [](double a, double b) { return a == b; });
}

using ConstantIntComparator = std::function<bool(int64_t, int64_t)>;
template <typename OpTy>
static OpFoldResult intComparatorFoldHelper(OpTy op,
                                            ConstantIntComparator comparator) {
2022-03-10 08:44:22 +08:00
  Value lhsValue = op->getOperand(0);
  Value rhsValue = op->getOperand(1);
  if (lhsValue == rhsValue)
2022-02-11 05:25:25 +08:00
    return getI1IntegerAttr(op.getContext(), comparator(0, 0));
2022-01-11 15:42:53 +08:00

2022-02-11 05:25:25 +08:00
  int64_t lhs, rhs;
2022-03-10 08:44:22 +08:00
  bool lhsIsConstant = matchPattern(lhsValue, m_TorchConstantInt(&lhs));
  bool rhsIsConstant = matchPattern(rhsValue, m_TorchConstantInt(&rhs));
  if (lhsIsConstant && rhsIsConstant)
    return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs));
2022-01-11 15:42:53 +08:00

2022-03-10 08:44:22 +08:00
  // Ensure that if there is a constant, it is on the right.
  if (lhsIsConstant && !rhsIsConstant) {
    std::swap(lhs, rhs);
    std::swap(lhsValue, rhsValue);
    std::swap(lhsIsConstant, rhsIsConstant);
    auto newComparator = [comparator](int64_t lhs, int64_t rhs) {
      return comparator(rhs, lhs);
    };
    comparator = newComparator;
  }
2022-03-31 05:10:51 +08:00

  // Fold comparisons of AtenSizeIntOp against negative values.
  // AtenSizeIntOp is known to always be non-negative.
2022-03-10 08:44:22 +08:00
  if (rhsIsConstant && rhs < 0) {
    // We can return `comparator(0, -1)` here because of the property:
    // If x >= 0 && y < 0, then:
    // - cmp(x, y) == cmp(x + 1, y)
    // - cmp(x, y) == cmp(x, y - 1)
    // By induction all cases here are covered.
    if (lhsValue.getDefiningOp<AtenSizeIntOp>())
      return getI1IntegerAttr(op->getContext(), comparator(0, -1));
  }
2022-03-31 05:10:51 +08:00

  // Fold comparisons of AtenSizeIntOp against 0:
  // - torch.aten.size.int >= 0 ==> True.
  // - torch.aten.size.int < 0 ==> False.
  // (and the operand-swapped versions of the above)
  if (rhsIsConstant && rhs == 0) {
    if (lhsValue.getDefiningOp<AtenSizeIntOp>()) {
      // >= 0 comparison.
      if (comparator(0, 0) && comparator(1, 0))
        return getI1IntegerAttr(op->getContext(), true);
      // < 0 comparison.
      if (!comparator(0, 0) && comparator(-1, 0) && !comparator(1, 0))
        return getI1IntegerAttr(op->getContext(), false);
    }
2022-03-10 08:44:22 +08:00
  }

  return nullptr;
2022-02-11 05:25:25 +08:00
}
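The `aten.size.int` folds in the helper above rest on small facts about integer comparators that can be checked directly. In the sketch below, `IntComparator`, `swapOperands`, `negativeConstantPropertyHolds` and `isGreaterEqualProbe` are illustrative names, not torch-mlir API. The brute-force check only samples a small range; it does not prove the property.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

using IntComparator = std::function<bool(int64_t, int64_t)>;

// Moving a constant from the left operand to the right flips the comparator:
// `c < x` is the same predicate as `x > c`.
IntComparator swapOperands(IntComparator cmp) {
  return [cmp](int64_t lhs, int64_t rhs) { return cmp(rhs, lhs); };
}

// Samples the property used for negative constants: for every x >= 0 and
// y < 0 in a small range, cmp(x, y) == cmp(0, -1).
bool negativeConstantPropertyHolds(const IntComparator &cmp) {
  for (int64_t x = 0; x <= 5; ++x)
    for (int64_t y = -5; y < 0; ++y)
      if (cmp(x, y) != cmp(0, -1))
        return false;
  return true;
}

// The probe used for comparisons against 0: a comparator that holds at
// (0, 0) and (1, 0) acts as `>=`, which is always true for a size.
bool isGreaterEqualProbe(const IntComparator &cmp) {
  return cmp(0, 0) && cmp(1, 0);
}
```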

2023-04-18 23:59:14 +08:00
//===----------------------------------------------------------------------===//
// AtenDetachOp
//===----------------------------------------------------------------------===//

2024-01-31 01:45:51 +08:00
OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) {
  if (getSelf().getType() != getResult().getType())
    return {};
  return getSelf();
}
2023-04-18 23:59:14 +08:00

2022-02-11 05:25:25 +08:00
//===----------------------------------------------------------------------===//
// AtenNeIntOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) {
2022-02-11 05:25:25 +08:00
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a != b; });
}

//===----------------------------------------------------------------------===//
// AtenEqIntOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) {
2022-02-11 05:25:25 +08:00
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a == b; });
2022-01-11 15:42:53 +08:00
}

//===----------------------------------------------------------------------===//
// AtenEqStrOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
2022-01-11 15:42:53 +08:00
  if (getOperand(0) == getOperand(1))
    return getI1IntegerAttr(getContext(), true);

2024-04-28 00:58:50 +08:00
  auto aStr = adaptor.getA();
  auto bStr = adaptor.getB();
2022-01-11 15:42:53 +08:00
  if (aStr && bStr)
    return getI1IntegerAttr(getContext(), aStr == bStr);
  return nullptr;
}

2024-04-28 00:58:50 +08:00
//===----------------------------------------------------------------------===//
// AtenNeStrOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
  if (getOperand(0) == getOperand(1))
    return getI1IntegerAttr(getContext(), false);

  auto aStr = adaptor.getA();
  auto bStr = adaptor.getB();
  if (aStr && bStr)
    return getI1IntegerAttr(getContext(), aStr != bStr);
  return nullptr;
}

2024-04-29 10:51:17 +08:00
//===----------------------------------------------------------------------===//
// Aten__Contains__StrListOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) {
2024-05-31 08:34:37 +08:00
  StringAttr item = dyn_cast_or_null<StringAttr>(adaptor.getItem());
2024-04-29 10:51:17 +08:00
  if (!item)
    return nullptr;

  if (auto listConstruct = getL().getDefiningOp<Torch::PrimListConstructOp>()) {
    if (isListPotentiallyMutated(listConstruct))
      return nullptr;
  }

  llvm::SmallVector<std::string> strs;
  if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) {
    for (const auto &str : strs) {
      if (item.getValue().str() == str)
        return getI1IntegerAttr(getContext(), true);
    }
    return getI1IntegerAttr(getContext(), false);
  }
  return nullptr;
}
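
// Aten__Contains__StrListOp folds only when two conditions hold. The item
// must be a constant string, and the list must be an unmutated list of
// constant strings. The answer is then a plain membership test. In the
// sketch below, std::optional stands for "known at compile time". This is
// an assumption of the model, not a torch-mlir type:

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// `list` is engaged only when every element is a constant string and the
// list cannot be mutated; `item` is engaged only when it is a constant.
std::optional<bool>
foldContains(const std::optional<std::vector<std::string>> &list,
             const std::optional<std::string> &item) {
  if (!item || !list)
    return std::nullopt; // no fold
  return std::find(list->begin(), list->end(), *item) != list->end();
}
```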

2021-08-11 09:28:50 +08:00
//===----------------------------------------------------------------------===//
// AtenLtIntOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) {
2022-02-11 05:25:25 +08:00
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a < b; });
2021-08-11 09:28:50 +08:00
}

//===----------------------------------------------------------------------===//
// AtenLeIntOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) {
2022-02-11 05:25:25 +08:00
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a <= b; });
2021-08-11 09:28:50 +08:00
}

//===----------------------------------------------------------------------===//
// AtenGtIntOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) {
2022-02-11 05:25:25 +08:00
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a > b; });
2021-08-11 09:28:50 +08:00
}

//===----------------------------------------------------------------------===//
// AtenGeIntOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) {
2022-02-11 05:25:25 +08:00
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a >= b; });
2021-06-17 02:05:08 +08:00
}
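
// The four integer comparisons above differ only in the lambda they pass to
// intComparatorFoldHelper. That helper is defined outside this excerpt. The
// sketch below is a hypothetical, simplified model of its constant-operand
// case only. The real helper may handle further cases, such as identical
// operands:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>

// Hypothetical: an integer operand is either a known constant or unknown.
using MaybeConst = std::optional<int64_t>;
using IntComparator = std::function<bool(int64_t, int64_t)>;

// Folds cmp(a, b) when both operands are constants; std::nullopt plays the
// role of returning nullptr (no fold) from an MLIR folder.
std::optional<bool> foldIntCompare(MaybeConst a, MaybeConst b,
                                   const IntComparator &cmp) {
  if (a && b)
    return cmp(*a, *b);
  return std::nullopt;
}
```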

2022-05-20 16:26:52 +08:00
//===----------------------------------------------------------------------===//
// AtenBoolFloatOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) {
2022-05-20 16:26:52 +08:00
  double c;
  if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
    return getI1IntegerAttr(getContext(), c != 0.0);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenBoolIntOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
2022-05-20 16:26:52 +08:00
  int64_t c;
  if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
    return getI1IntegerAttr(getContext(), c != 0);
  return nullptr;
}
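
// AtenBoolFloatOp and AtenBoolIntOp fold Python-style truthiness by
// comparing against zero. For floats, the IEEE 754 comparison gives the same
// answers as Python's bool(): -0.0 is false and NaN is true. A minimal
// standalone sketch of the two predicates:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Mirrors `c != 0.0` in AtenBoolFloatOp::fold. -0.0 == 0.0 under IEEE 754,
// and NaN compares unequal to everything, so NaN is truthy.
bool truthyFloat(double c) { return c != 0.0; }

// Mirrors `c != 0` in AtenBoolIntOp::fold.
bool truthyInt(int64_t c) { return c != 0; }
```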

2023-08-30 17:29:03 +08:00
//===----------------------------------------------------------------------===//
// AtenAnyBoolOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
  auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
  if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
    return nullptr;
  // If any operand is a constant true, return true.
  for (auto operand : inputConstruct.getOperands()) {
2023-09-04 09:59:26 +08:00
    bool b = false;
2023-08-30 17:29:03 +08:00
    if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
      return getI1IntegerAttr(getContext(), true);
    }
  }
  return nullptr;
}
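
// AtenAnyBoolOp::fold works in one direction only. A single constant `true`
// element decides the result, but the fold never produces `false`, even
// when every element is a constant `false`. The sketch below shows that
// rule. Here std::nullopt stands for a non-constant element (in the input)
// or "no fold" (in the result):

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Each element is a constant bool (engaged) or unknown (std::nullopt).
std::optional<bool> foldAny(const std::vector<std::optional<bool>> &elems) {
  for (const auto &e : elems)
    if (e.has_value() && *e)
      return true; // one constant true decides any()
  return std::nullopt; // like the folder above, never folds to false
}
```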

2022-03-10 08:44:22 +08:00
//===----------------------------------------------------------------------===//
// AtenFloatScalarOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
2022-03-10 08:44:22 +08:00
  // Constant fold int -> float conversion.
2024-05-31 14:45:13 +08:00
  if (auto integerAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getA())) {
2022-03-10 08:44:22 +08:00
    return FloatAttr::get(
        mlir::Float64Type::get(getContext()),
        static_cast<double>(integerAttr.getValue().getSExtValue()));
  }
  // If the input is float type already, the op is an identity.
  if (getType() == getOperand().getType())
    return getOperand();
  return nullptr;
}
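
// The int -> float fold is a plain static_cast<double> of the sign-extended
// value. As a result, integers with magnitude above 2^53 are rounded to a
// nearby double. On typical IEEE 754 targets with the default rounding
// mode, this is the nearest double, with ties going to even, which matches
// Python's float(int). A small sketch of that behavior:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the int -> float conversion in AtenFloatScalarOp::fold.
double foldIntToFloat(int64_t v) { return static_cast<double>(v); }
```

// For example, 2^53 + 1 lies exactly halfway between two doubles, and it
// rounds to 2^53 under round-to-nearest-even.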

2022-04-26 20:15:30 +08:00
//===----------------------------------------------------------------------===//
2023-02-11 05:59:03 +08:00
// AtenIntFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
  // Constant fold float -> int conversion.
2024-05-31 14:45:13 +08:00
  if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
2023-02-11 05:59:03 +08:00
    return IntegerAttr::get(
2023-05-23 08:21:34 +08:00
        mlir::IntegerType::get(getContext(), 64),
2023-02-11 05:59:03 +08:00
        static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
2022-04-26 20:15:30 +08:00
// AtenIntScalarOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
2022-04-26 20:15:30 +08:00
  // Constant fold float -> int conversion.
2024-05-31 14:45:13 +08:00
  if (auto floatAttr = dyn_cast_or_null<FloatAttr>(adaptor.getA())) {
2022-04-26 20:15:30 +08:00
    return IntegerAttr::get(
2023-05-23 08:21:34 +08:00
        mlir::IntegerType::get(getContext(), 64),
2022-04-26 20:15:30 +08:00
        static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
  }
  // If the input is int type already, the op is an identity.
  if (getType() == getOperand().getType())
    return getOperand();
  return nullptr;
}
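
// Both float -> int folders (AtenIntFloatOp and AtenIntScalarOp) truncate
// toward zero via static_cast. In C++ that cast is undefined for NaN,
// infinities, and values outside the int64_t range, and the folders above
// do not check for these inputs. The sketch below is a hypothetical guarded
// variant that declines to fold such inputs. It is not what the code above
// does:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>

// Truncates toward zero like static_cast<int64_t>, but returns std::nullopt
// (no fold) where the cast would be undefined behavior.
std::optional<int64_t> foldFloatToInt(double v) {
  if (!std::isfinite(v))
    return std::nullopt;
  // 2^63 is exactly representable as a double; values that can be converted
  // safely lie in [-2^63, 2^63).
  const double limit = 9223372036854775808.0;
  if (v >= limit || v < -limit)
    return std::nullopt;
  return static_cast<int64_t>(v);
}
```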

2023-01-18 02:14:14 +08:00
//===----------------------------------------------------------------------===//
// AtenIntBoolOp
//===----------------------------------------------------------------------===//

2023-01-25 09:29:42 +08:00
OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
2023-01-18 02:14:14 +08:00
  bool b;
  if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
    return getI64IntegerAttr(getContext(), static_cast<int64_t>(b));
  }
  return nullptr;
}

2023-11-21 13:26:17 +08:00
//===----------------------------------------------------------------------===//
// AtenMaskedFillTensorOp
//===----------------------------------------------------------------------===//

// Rewrite masked_fill.Tensor into masked_fill.Scalar when the fill value is
// a 0-d tensor whose scalar value can be extracted.
void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenMaskedFillTensorOp op, PatternRewriter &rewriter) {
    auto scalarIntVal =
        getScalarIntValue(op.getValue(), op->getLoc(), rewriter);
    auto scalarFloatVal =
        getScalarFloatValue(op.getValue(), op->getLoc(), rewriter);
    if (!scalarIntVal && !scalarFloatVal)
      return failure();
    Value scalarVal = scalarIntVal ? scalarIntVal : scalarFloatVal;
    rewriter.replaceOpWithNewOp<AtenMaskedFillScalarOp>(
        op, op.getType(), op.getSelf(), op.getMask(), scalarVal);
    // The op has been replaced, so the pattern must report success.
    return success();
  });
}

2024-01-31 09:43:21 +08:00
//===----------------------------------------------------------------------===//
// AtenCloneOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
  // Note: memory_format is ignored by this fold.
2024-05-16 00:07:45 +08:00
  if (getSelf().getType() == getResult().getType() &&
      isa<ValueTensorType>(getSelf().getType())) {
2024-01-31 09:43:21 +08:00
    // A clone of a value-semantic tensor is indistinguishable from the
    // tensor itself, so fold to `self`.
    return getSelf();
  }
  return {};
}

2022-11-18 19:47:07 +08:00
//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//

void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add(+[](AtenSortIntOp op, PatternRewriter &rewriter) {
    SmallVector<int64_t> listElements;
2022-12-08 04:20:41 +08:00
    if (!matchPattern(op.getSelf(), m_TorchListOfConstantInts(listElements)))
2022-11-18 19:47:07 +08:00
      return rewriter.notifyMatchFailure(
          op, "all input list elements must be constant ints");
    bool reverse;
2022-12-08 04:20:41 +08:00
    if (!matchPattern(op.getReverse(), m_TorchConstantBool(&reverse)))
2022-11-18 19:47:07 +08:00
      return rewriter.notifyMatchFailure(
          op, "Expected reverse arg to be constant bool.");

    std::sort(listElements.begin(), listElements.end());
    if (reverse)
      std::reverse(listElements.begin(), listElements.end());

    SmallVector<Value> sortedListElements;
    for (int64_t elem : listElements)
      sortedListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
          op->getLoc(), rewriter.getI64IntegerAttr(elem)));
    Value result = rewriter.create<Torch::PrimListConstructOp>(
        op->getLoc(), Torch::ListType::get(rewriter.getType<Torch::IntType>()),
        sortedListElements);

2022-12-08 04:20:41 +08:00
    // Replace uses through the rewriter so the pattern driver is notified.
    rewriter.replaceAllUsesWith(op.getSelf(), result);
2022-11-18 19:47:07 +08:00
    rewriter.eraseOp(op);
    return success();
  });
}
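
// The AtenSortIntOp pattern sorts the constant elements in ascending order.
// When `reverse` is true, it then reverses the result. For integers, sorting
// with std::greater produces the same list in a single pass. A small sketch
// of the value computation, kept separate from the IR rewriting:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Returns the sorted list the pattern would materialize as constants.
std::vector<int64_t> sortConstantInts(std::vector<int64_t> elems,
                                      bool reverse) {
  if (reverse)
    std::sort(elems.begin(), elems.end(), std::greater<int64_t>());
  else
    std::sort(elems.begin(), elems.end());
  return elems;
}
```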

2024-02-07 05:12:12 +08:00
//===----------------------------------------------------------------------===//
// AtenSortOp
//===----------------------------------------------------------------------===//

LogicalResult AtenSortOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  auto operand = getSelf();
  auto operandType = dyn_cast<BaseTensorType>(operand.getType());
  if (!operandType || !operandType.hasSizes())
    return failure();

  // Only ValueTensorType has toBuiltinTensor.
  auto indicesTensorType = dyn_cast<ValueTensorType>(getResult(1).getType());
  if (!indicesTensorType)
    return failure();

  if (!indicesTensorType.hasDtype())
    return failure();
2024-06-08 09:36:32 +08:00
  auto indicesType = indicesTensorType.toBuiltinTensor();
2024-02-07 05:12:12 +08:00
  if (!indicesType || !indicesType.hasStaticShape())
    return failure();

  IntegerAttr dimAttribute = dyn_cast_if_present<IntegerAttr>(adaptor.getDim());
  if (!dimAttribute)
    return failure();
  int64_t rank = static_cast<int64_t>(operandType.getSizes().size());
  int64_t dimInt = dimAttribute.getValue().getSExtValue();
  if (dimInt < 0)
    dimInt += rank;
  // Sorting along a size-1 dimension is a no-op. Bounds-check the index so
  // an out-of-range dim (or a rank-0 tensor) never reads past getSizes().
  bool unaryDim =
      dimInt >= 0 && dimInt < rank && operandType.getSizes()[dimInt] == 1;

  OpBuilder builder(getContext());
  if (unaryDim || llvm::all_of(operandType.getSizes(),
                               [](int64_t dim) { return dim == 1; })) {
    results.push_back(operand);
    results.push_back(DenseElementsAttr::get(
        indicesType, builder.getZeroAttr(indicesType.getElementType())));
    return success();
  }

  return failure();
}
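
// AtenSortOp::fold only handles sorts that cannot reorder anything. Either
// the sort dimension has size 1, or every dimension does. In that case the
// values result is the input itself and every index is 0. The sketch below
// applies the same check to a plain size vector. It assumes dynamic sizes
// are -1, so they never equal 1:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Returns true when sorting `sizes` along `dim` is provably a no-op.
// Negative dims count from the end, as in PyTorch.
bool isTrivialSort(const std::vector<int64_t> &sizes, int64_t dim) {
  const int64_t rank = static_cast<int64_t>(sizes.size());
  if (dim < 0)
    dim += rank;
  // A size-1 sort dimension leaves nothing to reorder.
  if (dim >= 0 && dim < rank && sizes[dim] == 1)
    return true;
  // So does a tensor whose dimensions are all 1 (vacuously true at rank 0).
  return std::all_of(sizes.begin(), sizes.end(),
                     [](int64_t s) { return s == 1; });
}
```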

Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tensor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
2021-05-21 08:07:18 +08:00
//===----------------------------------------------------------------------===//
2021-06-17 23:52:13 +08:00
// NonValueTensorLiteralOp
2021-05-21 08:07:18 +08:00
//===----------------------------------------------------------------------===//

2021-06-17 23:52:13 +08:00
LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
2022-12-20 18:17:27 +08:00
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
2023-05-09 06:33:24 +08:00
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
2021-06-17 23:52:13 +08:00
    SmallVectorImpl<Type> &inferredReturnTypes) {
2024-05-31 14:45:13 +08:00
  auto attr =
      dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
2021-05-21 08:07:18 +08:00
  if (!attr)
    return failure();
2024-04-28 05:00:56 +08:00
  RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
2021-09-14 08:57:59 +08:00
  NonValueTensorType returnType =
      NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
                              tensorType.getElementType());
  inferredReturnTypes.push_back(returnType);
2021-05-21 08:07:18 +08:00
  return success();
}

static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
  if (a.hasSizes() && b.hasSizes()) {
2022-11-29 20:33:31 +08:00
    if (failed(verifyCompatibleShape(makeShapeLLVMCompatible(a.getSizes()),
                                     makeShapeLLVMCompatible(b.getSizes()))))
2021-05-21 08:07:18 +08:00
      return false;
  }
  if (a.hasDtype() && b.hasDtype()) {
    if (a.getDtype() != b.getDtype())
      return false;
  }
  return true;
}
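
// areSizesAndDtypesCompatible hands the shape check to MLIR's
// verifyCompatibleShape. The sketch below is a simplified model of the rule
// for two ranked shapes. The ranks must match, and each pair of sizes must
// be equal unless at least one side is unknown. The marker kDynamic = -1 is
// an assumption standing in for an unknown (`?`) dimension. The real code
// first converts sizes with makeShapeLLVMCompatible, and it also treats
// missing sizes or dtypes as compatible:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical marker for an unknown (`?`) dimension.
constexpr int64_t kDynamic = -1;

bool areShapesCompatible(const std::vector<int64_t> &a,
                         const std::vector<int64_t> &b) {
  if (a.size() != b.size())
    return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    // Two known sizes must agree; an unknown size matches anything.
    if (a[i] != kDynamic && b[i] != kDynamic && a[i] != b[i])
      return false;
  }
  return true;
}
```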

2021-06-17 23:52:13 +08:00
bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
                                                      TypeRange actual) {
2024-05-31 14:45:13 +08:00
  if (!isa<BaseTensorType>(actual[0]))
2021-05-21 08:07:18 +08:00
    return false;
2024-05-31 14:45:13 +08:00
  return areSizesAndDtypesCompatible(cast<BaseTensorType>(inferred[0]),
                                     cast<BaseTensorType>(actual[0]));
2021-05-21 08:07:18 +08:00
}

2021-06-17 23:52:13 +08:00
//===----------------------------------------------------------------------===//
// ValueTensorLiteralOp
//===----------------------------------------------------------------------===//

LogicalResult ValueTensorLiteralOp::inferReturnTypes(
2022-12-20 18:17:27 +08:00
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
2023-05-09 06:33:24 +08:00
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
2021-06-17 23:52:13 +08:00
    SmallVectorImpl<Type> &inferredReturnTypes) {
2024-05-31 14:45:13 +08:00
  auto attr =
      dyn_cast_or_null<ElementsAttr>(properties.as<Properties *>()->getValue());
2021-06-17 23:52:13 +08:00
  if (!attr)
    return failure();
2024-04-28 05:00:56 +08:00
  RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
2021-09-14 08:57:59 +08:00
  ValueTensorType returnType =
      ValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
                           tensorType.getElementType());
  inferredReturnTypes.push_back(returnType);
2021-06-17 23:52:13 +08:00
  return success();
}
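`inferReturnTypes` builds the result type directly from the literal attribute: it copies the ranked shape and element type into a `ValueTensorType`, and fails if no attribute is present. The sketch below models that mapping and renders the result in the `!torch.vtensor<[...],dtype>` spelling. `RankedLiteralType`, `toValueTensorTypeString`, and `inferValueTensorLiteralType` are hypothetical names, and `std::nullopt` stands in for `failure()`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of a ranked literal's builtin type, e.g. tensor<2x3xf32>.
struct RankedLiteralType {
  std::vector<int64_t> shape;
  std::string elementType;
};

// Hypothetical printer for the `!torch.vtensor<[d0,d1,...],dtype>` spelling;
// a negative dimension is printed as `?`.
std::string toValueTensorTypeString(const RankedLiteralType &t) {
  std::string s = "!torch.vtensor<[";
  for (size_t i = 0; i < t.shape.size(); ++i) {
    if (i != 0)
      s += ",";
    s += t.shape[i] < 0 ? std::string("?") : std::to_string(t.shape[i]);
  }
  s += "]," + t.elementType + ">";
  return s;
}

// Mirrors the control flow of inferReturnTypes: a missing attribute fails
// (nullopt here); otherwise the shape and element type carry over unchanged.
std::optional<std::string>
inferValueTensorLiteralType(const std::optional<RankedLiteralType> &attr) {
  if (!attr)
    return std::nullopt;
  return toValueTensorTypeString(*attr);
}
```

For example, a `tensor<2x3xf32>` literal maps to `!torch.vtensor<[2,3],f32>`.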

2023-01-25 09:29:42 +08:00
OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
2022-12-08 04:20:41 +08:00
  return getValueAttr();
2021-06-17 23:52:13 +08:00
}

2021-05-21 08:07:18 +08:00
//----------------------------------------------------------------------------//
// TensorStaticInfoCast
//----------------------------------------------------------------------------//

bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
                                               mlir::TypeRange outputs) {
2024-05-31 14:45:13 +08:00
  return areSizesAndDtypesCompatible(cast<BaseTensorType>(inputs[0]),
                                     cast<BaseTensorType>(outputs[0]));
2021-05-21 08:07:18 +08:00
}

2021-10-16 06:23:59 +08:00
void TensorStaticInfoCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
    auto reverseCast =
2022-12-08 04:20:41 +08:00
        op.getOperand().getDefiningOp<Torch::TensorStaticInfoCastOp>();
    if (!reverseCast || reverseCast.getOperand().getType() != op.getType())
2021-10-16 06:23:59 +08:00
      return failure();

2022-12-08 04:20:41 +08:00
    rewriter.replaceOp(op, reverseCast.getOperand());
2021-10-16 06:23:59 +08:00
    return success();
  });
2022-03-10 08:44:22 +08:00
  patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
2022-03-19 00:10:12 +08:00
    if (isValidSubtype(op.getOperand().getType(), op.getType())) {
      SmallVector<std::reference_wrapper<OpOperand>> usesToChange(
          llvm::make_filter_range(op->getUses(), [](OpOperand &operand) {
            return operand.getOwner()
                ->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
          }));

      if (usesToChange.empty())
        return failure();

      for (OpOperand &use : usesToChange) {
        Operation *user = use.getOwner();
2022-12-08 04:20:41 +08:00
        user->setOperand(use.getOperandNumber(), op.getOperand());
2022-03-19 00:10:12 +08:00
      }

2022-03-10 08:44:22 +08:00
      return success();
    }
    return failure();
  });
2021-10-16 06:23:59 +08:00
}
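The two canonicalization patterns above do different jobs. The first erases a round trip `cast(cast(x : A) : B) : A` back to `x`. The second lets users that carry the `AllowsTypeRefinement` trait consume the cast's more-refined input directly. The toy IR below models both patterns. `Value`, `User`, `Graph`, and the functions are hypothetical. The subtype rule is a deliberately simplified assumption: only the least-static-information `!torch.vtensor` accepts other types. It is not torch-mlir's actual `isValidSubtype`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical miniature IR (not a torch-mlir API). Values are referred to by
// index into Graph::values; castInput >= 0 marks a value produced by a
// TensorStaticInfoCast-like op whose operand is values[castInput].
struct Value {
  std::string type;
  int castInput = -1;
};

// A consumer of one value; the flag models the AllowsTypeRefinement trait.
struct User {
  int operand;
  bool allowsTypeRefinement;
};

struct Graph {
  std::vector<Value> values;
  std::vector<User> users;
};

// Simplified, assumed subtype rule: `!torch.vtensor` accepts anything,
// otherwise the two types must be identical.
bool isValidSubtypeModel(const std::string &sub, const std::string &super) {
  return sub == super || super == "!torch.vtensor";
}

// Pattern 1: cast(cast(x : A) : B) : A  ==>  x.
// Returns the value that should replace `v` (or `v` if nothing applies).
int foldRoundTripCast(const Graph &g, int v) {
  int inner = g.values[v].castInput;
  if (inner < 0)
    return v;
  int source = g.values[inner].castInput;
  if (source < 0 || g.values[source].type != g.values[v].type)
    return v;
  return source;
}

// Pattern 2: if the cast's input type is a valid subtype of its result type,
// rewire only the users that allow type refinement. Returns true if any use
// changed (the "no uses to change" case corresponds to failure()).
bool refineUsesThroughCast(Graph &g, int v) {
  int input = g.values[v].castInput;
  if (input < 0 ||
      !isValidSubtypeModel(g.values[input].type, g.values[v].type))
    return false;
  bool changed = false;
  for (User &u : g.users) {
    if (u.operand == v && u.allowsTypeRefinement) {
      u.operand = input;
      changed = true;
    }
  }
  return changed;
}
```

Users without the trait keep consuming the cast result. Their operand types may be constrained, so handing them a more static type could change their verification.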

2021-05-21 08:07:18 +08:00
//===----------------------------------------------------------------------===//
2021-06-19 04:47:47 +08:00
// CopyToNonValueTensorOp
2021-05-21 08:07:18 +08:00
//===----------------------------------------------------------------------===//

2022-03-16 08:54:57 +08:00
LogicalResult CopyToNonValueTensorOp::verify() {
2024-04-28 05:00:56 +08:00
  auto resultType = cast<BaseTensorType>(getResult().getType());
  auto operandType = cast<BaseTensorType>(getOperand().getType());
2022-03-16 08:54:57 +08:00
  if (!resultType.hasSameSizesAndDtype(operandType))
    return emitError() << "operand and result must have same sizes and dtype";
2021-05-21 08:07:18 +08:00
  return success();
}

2021-06-19 04:47:47 +08:00
LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
2022-12-20 18:17:27 +08:00
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
2023-05-09 06:33:24 +08:00
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
2021-06-19 04:47:47 +08:00
    SmallVectorImpl<Type> &inferredReturnTypes) {
2024-04-28 05:00:56 +08:00
  auto resultType = cast<ValueTensorType>(operands[0].getType());
2021-06-19 04:47:47 +08:00
  inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
  return success();
}

void CopyToNonValueTensorOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
2024-06-28 10:28:02 +08:00
  effects.emplace_back(MemoryEffects::Allocate::get(),
                       getOperation()->getOpResult(0));
2021-06-19 04:47:47 +08:00
}

//===----------------------------------------------------------------------===//
// CopyToValueTensorOp
//===----------------------------------------------------------------------===//

2022-03-16 08:54:57 +08:00
LogicalResult CopyToValueTensorOp::verify() {
2024-04-28 05:00:56 +08:00
  auto resultType = cast<BaseTensorType>(getResult().getType());
  auto operandType = cast<BaseTensorType>(getOperand().getType());
2022-03-16 08:54:57 +08:00
  if (!resultType.hasSameSizesAndDtype(operandType))
    return emitError() << "operand and result must have same sizes and dtype";
2021-06-19 04:47:47 +08:00
  return success();
2021-05-21 08:07:18 +08:00
}

2021-06-19 04:47:47 +08:00
LogicalResult CopyToValueTensorOp::inferReturnTypes(
2022-12-20 18:17:27 +08:00
    MLIRContext *context, std::optional<Location> location, ValueRange operands,
2023-05-09 06:33:24 +08:00
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
2021-06-19 04:47:47 +08:00
    SmallVectorImpl<Type> &inferredReturnTypes) {
2024-04-28 05:00:56 +08:00
  auto resultType = cast<NonValueTensorType>(operands[0].getType());
2021-06-19 04:47:47 +08:00
  inferredReturnTypes.push_back(resultType.getWithValueSemantics());
  return success();
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tnesor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
allow us to even model the non-value-semantic case. Previously,
the value-semantic assumption was injected arbitrarily in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly copied
from there).
- Misc Torch-level canonicalizations: we now cleanly layer the
lowering to std later in the pipeline, so we are gradually reducing
our reliance on ad hoc std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
2021-05-21 08:07:18 +08:00
|
|
|
}
|
|
|
|
|
2021-06-19 04:47:47 +08:00
|
|
|
void CopyToValueTensorOp::getEffects(
|
|
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
2021-06-18 07:29:20 +08:00
|
|
|
&effects) {
|
2024-06-28 10:28:02 +08:00
|
|
|
effects.emplace_back(MemoryEffects::Read::get(),
|
|
|
|
&getOperation()->getOpOperand(0));
|
2021-06-18 07:29:20 +08:00
|
|
|
}
|
|
|
|
|
2021-06-15 02:36:10 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantNoneOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) {
|
2021-06-15 02:36:10 +08:00
|
|
|
return TypeAttr::get(Torch::NoneType::get(getContext()));
|
|
|
|
}
|
|
|
|
|
2021-06-16 03:42:51 +08:00
|
|
|
void ConstantNoneOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
|
|
setNameFn(getResult(), "none");
|
|
|
|
}
|
|
|
|
|
2021-06-15 23:29:06 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantStrOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
|
2021-06-15 23:29:06 +08:00
|
|
|
|
2021-06-16 03:42:51 +08:00
|
|
|
void ConstantStrOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
|
|
setNameFn(getResult(), "str");
|
|
|
|
}
|
|
|
|
|
2021-09-28 22:56:08 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantDeviceOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void ConstantDeviceOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
2022-12-08 04:20:41 +08:00
|
|
|
setNameFn(getResult(), getValue());
|
2021-09-28 22:56:08 +08:00
|
|
|
}
|
|
|
|
|
2021-06-16 03:42:51 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-02-13 02:47:12 +08:00
|
|
|
ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
2021-06-17 06:53:15 +08:00
|
|
|
Builder builder(result.getContext());
|
|
|
|
result.addTypes(builder.getType<Torch::IntType>());
|
|
|
|
int64_t value;
|
|
|
|
if (parser.parseInteger(value))
|
|
|
|
return failure();
|
2024-06-28 22:06:52 +08:00
|
|
|
if (parser.parseOptionalAttrDict(result.attributes))
|
|
|
|
return failure();
|
2021-06-17 06:53:15 +08:00
|
|
|
result.addAttribute("value", builder.getI64IntegerAttr(value));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-02-13 02:47:12 +08:00
|
|
|
void ConstantIntOp::print(OpAsmPrinter &p) {
|
2021-09-04 02:38:00 +08:00
|
|
|
p << " ";
|
2023-05-23 08:21:34 +08:00
|
|
|
p << getValueAttr().getInt();
|
2022-02-13 02:47:12 +08:00
|
|
|
p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
|
2021-06-17 06:53:15 +08:00
|
|
|
}
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getValueAttr();
|
2021-06-16 03:42:51 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void Torch::ConstantIntOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
|
|
SmallVector<char> buf;
|
|
|
|
llvm::raw_svector_ostream os(buf);
|
2023-05-23 08:21:34 +08:00
|
|
|
os << "int" << getValueAttr().getInt();
|
2021-06-16 03:42:51 +08:00
|
|
|
setNameFn(getResult(), os.str());
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantFloatOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getValueAttr();
|
2021-06-16 03:42:51 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void Torch::ConstantFloatOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
2021-06-23 05:25:16 +08:00
|
|
|
// Calculate a stringified version of the number, compatible with MLIR
|
|
|
|
// identifier syntax. (In practice, this just removes the '+' from 'e+' in
|
|
|
|
// the float string representation.)
|
|
|
|
SmallVector<char> buf;
|
2022-12-08 04:20:41 +08:00
|
|
|
getValue().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
|
2024-01-30 01:59:33 +08:00
|
|
|
/*TruncateZero=*/false);
|
2021-06-23 05:25:16 +08:00
|
|
|
auto isValidMLIRIdentifierChar = [](char c) {
|
|
|
|
return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' ||
|
|
|
|
c == '-';
|
|
|
|
};
|
|
|
|
auto numberStr = llvm::to_vector<16>(
|
|
|
|
llvm::make_filter_range(buf, isValidMLIRIdentifierChar));
|
|
|
|
|
|
|
|
// Construct the identifier string.
|
|
|
|
buf.clear();
|
|
|
|
llvm::append_range(buf, StringRef("float"));
|
|
|
|
llvm::append_range(buf, numberStr);
|
|
|
|
setNameFn(getResult(), StringRef(buf.data(), buf.size()));
|
2021-06-16 03:42:51 +08:00
|
|
|
}
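The naming scheme above can be sketched on a plain string. `floatResultName` is a hypothetical helper that takes an already-printed number instead of an `APFloat`, keeps only the characters the lambda above accepts, and prefixes `"float"`. Unlike the original, it casts to `unsigned char` before calling `std::isalpha`/`std::isdigit`, which avoids undefined behavior for negative `char` values.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical helper: build an SSA result name from a printed float by
// dropping characters that are not valid in an MLIR name suffix.
std::string floatResultName(const std::string &printed) {
  std::string name = "float";
  for (char c : printed) {
    unsigned char uc = static_cast<unsigned char>(c);
    if (std::isalpha(uc) || std::isdigit(uc) || c == '_' || c == '$' ||
        c == '.' || c == '-')
      name.push_back(c);
  }
  return name;
}
```

For example, `floatResultName("1.000000e+00")` yields `"float1.000000e00"`: only the `+` is dropped.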
|
|
|
|
|
2022-09-20 12:40:19 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantNumberOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getValueAttr();
|
2022-09-20 12:40:19 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void Torch::ConstantNumberOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
|
|
|
patterns.add(+[](Torch::ConstantNumberOp op, PatternRewriter &rewriter) {
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
|
|
|
Value constValue;
|
2022-12-08 04:20:41 +08:00
|
|
|
Attribute value = op.getValueAttr();
|
2024-04-11 21:47:35 +08:00
|
|
|
if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) {
|
2022-09-20 12:40:19 +08:00
|
|
|
constValue = rewriter.create<Torch::ConstantFloatOp>(loc, floatValue);
|
2024-04-11 21:47:35 +08:00
|
|
|
} else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) {
|
2022-09-20 12:40:19 +08:00
|
|
|
constValue = rewriter.create<Torch::ConstantIntOp>(loc, intValue);
|
|
|
|
} else {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, op.getType(),
|
|
|
|
constValue);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
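The `ConstantNumberOp` pattern above dispatches on the kind of attribute it holds and then wraps the more specific constant in a `derefine` so the result keeps its `!torch.number` type. A hypothetical model of just the dispatch step, using `std::variant` in place of MLIR attributes (the returned op names are only labels):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>

// Hypothetical stand-in for the attribute payload: a float, an integer, or
// some other attribute kind (represented here by a string).
using AttrPayload = std::variant<double, int64_t, std::string>;

// Returns the name of the more specific constant op, or std::nullopt when the
// payload is neither a float nor an integer (the pattern returns failure()).
std::optional<std::string> specializedConstantOp(const AttrPayload &value) {
  if (std::holds_alternative<double>(value))
    return "torch.constant.float";
  if (std::holds_alternative<int64_t>(value))
    return "torch.constant.int";
  return std::nullopt;
}
```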
|
|
|
|
|
2021-06-16 07:47:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ConstantBoolOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getValueAttr();
|
2021-06-16 07:47:53 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void Torch::ConstantBoolOp::getAsmResultNames(
|
|
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
2022-12-08 04:20:41 +08:00
|
|
|
setNameFn(getResult(), getValue() ? "true" : "false");
|
2021-06-16 07:47:53 +08:00
|
|
|
}
|
|
|
|
|
2021-06-05 06:57:21 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-06-23 04:56:12 +08:00
|
|
|
// PrimUncheckedCastOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
|
|
|
|
mlir::TypeRange outputs) {
|
|
|
|
return isValidSubtype(outputs[0], inputs[0]);
|
|
|
|
}
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
if (auto derefineOp = getX().getDefiningOp<Torch::DerefineOp>()) {
|
|
|
|
if (derefineOp.getOperand().getType() == getType())
|
|
|
|
return derefineOp.getOperand();
|
2022-01-28 21:35:40 +08:00
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2021-06-23 04:56:12 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2021-06-05 06:57:21 +08:00
|
|
|
// Aten__Getitem__TOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2021-06-19 10:33:14 +08:00
|
|
|
void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
2021-06-05 06:57:21 +08:00
|
|
|
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
|
|
|
|
auto torchList = op.getOperand(0);
|
2022-03-10 08:44:22 +08:00
|
|
|
if (isListPotentiallyMutated(torchList))
|
2021-06-05 06:57:21 +08:00
|
|
|
return failure();
|
|
|
|
|
|
|
|
auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
|
|
|
|
if (!listConstruct)
|
|
|
|
return failure();
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
// Get the index, but be careful because it might be statically invalid.
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
|
2022-03-30 04:21:47 +08:00
|
|
|
op.getOperand(1), listConstruct.getNumOperands());
|
|
|
|
if (!indexOpt)
|
2022-03-10 08:44:22 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "statically invalid index");
|
2021-06-05 06:57:21 +08:00
|
|
|
|
2022-03-30 04:21:47 +08:00
|
|
|
rewriter.replaceOp(op, {listConstruct.getOperand(*indexOpt)});
|
2021-06-05 06:57:21 +08:00
|
|
|
return success();
|
|
|
|
});
|
2022-03-10 08:44:22 +08:00
|
|
|
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto sizeOp = op.getList().getDefiningOp<AtenSizeOp>();
|
2022-03-10 08:44:22 +08:00
|
|
|
if (!sizeOp)
|
|
|
|
return failure();
|
|
|
|
// This assumes that the size doesn't change between the
|
|
|
|
// AtenSizeOp and the Aten__Getitem__TOp.
|
|
|
|
// `t_` is the only op I can find that changes the shape in-place. It seems
|
|
|
|
// like otherwise we can treat the size of a tensor as having value
|
|
|
|
// semantics. The other view-like ops don't have in-place variants --
|
|
|
|
// they always return a new SSA value that is aliased to the input.
|
|
|
|
// Can we have a pass to normalize the `t_` case and then elsewhere in the
|
|
|
|
// compiler treat the size as having value semantics?
|
|
|
|
// There's a small number of such ops, and they are marked as `inplace_view`
|
|
|
|
// in PyTorch's `native_functions.yaml` file.
|
2024-01-30 01:59:33 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenSizeIntOp>(op, sizeOp.getSelf(),
|
|
|
|
op.getIdx());
|
2022-03-10 08:44:22 +08:00
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
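`matchLegalConstantIndexIntoListOfSize` is not shown in this excerpt. The sketch below assumes it follows Python's list-indexing rules: negative indices count from the end, and anything outside `[-size, size)` is rejected so that a statically invalid index is left for the runtime to report. `normalizeListIndex` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper: map a possibly negative index into [0, size), or
// return std::nullopt if it is out of range.
std::optional<int64_t> normalizeListIndex(int64_t index, int64_t size) {
  if (index < 0)
    index += size;
  if (index < 0 || index >= size)
    return std::nullopt;
  return index;
}
```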
|
|
|
|
|
2024-06-14 23:59:08 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenMeshgridOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void AtenMeshgridOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenMeshgridOp op, PatternRewriter &rewriter) {
|
|
|
|
Value constIndexing = rewriter.create<Torch::ConstantStrOp>(
|
|
|
|
op->getLoc(), rewriter.getStringAttr("ij"));
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMeshgridIndexingOp>(
|
|
|
|
op, op->getResultTypes(), op.getTensors(), constIndexing);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
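The canonicalization above makes PyTorch's default `"ij"` indexing explicit. For two 1-D inputs, `"ij"` indexing produces two outputs of shape `[len(x), len(y)]`. A minimal sketch of those semantics on plain vectors (`meshgridIJ` is a hypothetical name, limited to two inputs):

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Grid = std::vector<std::vector<int>>;

// Sketch of meshgrid(x, y, indexing="ij"): the first output repeats x along
// each row index i, the second repeats y along each column index j.
std::pair<Grid, Grid> meshgridIJ(const std::vector<int> &x,
                                 const std::vector<int> &y) {
  Grid gx(x.size(), std::vector<int>(y.size()));
  Grid gy(x.size(), std::vector<int>(y.size()));
  for (std::size_t i = 0; i < x.size(); ++i) {
    for (std::size_t j = 0; j < y.size(); ++j) {
      gx[i][j] = x[i];
      gy[i][j] = y[j];
    }
  }
  return {gx, gy};
}
```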
|
|
|
|
|
2024-05-31 09:56:47 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSplitSizesOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenSplitSizesOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenSplitSizesOp op, PatternRewriter &rewriter) {
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSplitWithSizesOp>(
|
|
|
|
op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim());
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
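The pattern above forwards to `split_with_sizes`, in which each requested size becomes one contiguous chunk along `dim` and the sizes must add up exactly to that dimension's length. A sketch of the resulting chunk boundaries (`splitRanges` is a hypothetical helper operating on plain integers):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical helper: compute [begin, end) offsets for each chunk, or
// std::nullopt if a size is negative or the sizes don't sum to dimSize.
std::optional<std::vector<std::pair<int64_t, int64_t>>>
splitRanges(int64_t dimSize, const std::vector<int64_t> &sizes) {
  std::vector<std::pair<int64_t, int64_t>> ranges;
  int64_t offset = 0;
  for (int64_t s : sizes) {
    if (s < 0)
      return std::nullopt;
    ranges.emplace_back(offset, offset + s);
    offset += s;
  }
  if (offset != dimSize)
    return std::nullopt;
  return ranges;
}
```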
|
|
|
|
|
2023-06-07 17:05:31 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenIsFloatingPointOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
|
2024-04-28 05:00:56 +08:00
|
|
|
auto operandType = dyn_cast<BaseTensorType>(getSelf().getType());
|
2023-06-07 17:05:31 +08:00
|
|
|
if (!operandType)
|
|
|
|
return nullptr;
|
|
|
|
if (operandType.hasDtype()) {
|
2024-05-31 14:45:13 +08:00
|
|
|
bool isFloatType = isa<mlir::FloatType>(operandType.getDtype());
|
2023-06-07 17:05:31 +08:00
|
|
|
return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType);
|
|
|
|
}
|
|
|
|
// The operand type has no known dtype.
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2022-07-29 07:00:02 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenAddTOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
|
2024-01-30 01:59:33 +08:00
|
|
|
auto lhsListConstruct =
|
|
|
|
op.getA().getDefiningOp<Torch::PrimListConstructOp>();
|
2022-07-29 07:00:02 +08:00
|
|
|
if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
|
|
|
|
return failure();
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
auto rhsListConstruct =
|
|
|
|
op.getB().getDefiningOp<Torch::PrimListConstructOp>();
|
2022-07-29 07:00:02 +08:00
|
|
|
if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
SmallVector<Value> concatenatedList;
|
|
|
|
for (auto a : lhsListConstruct.getOperands()) {
|
|
|
|
concatenatedList.push_back(a);
|
|
|
|
}
|
|
|
|
for (auto b : rhsListConstruct.getOperands()) {
|
|
|
|
concatenatedList.push_back(b);
|
|
|
|
}
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(op, op.getType(),
|
|
|
|
concatenatedList);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2022-09-27 05:35:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSliceTOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenSliceTOp op, PatternRewriter &rewriter) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto valueList = op.getL();
|
2022-09-27 05:35:50 +08:00
|
|
|
auto listConstructOp = valueList.getDefiningOp<PrimListConstructOp>();
|
|
|
|
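// Bail out on empty lists: the replacement below reads element 0's type.
if (!listConstructOp || listConstructOp.getElements().empty() ||
isListPotentiallyMutated(listConstructOp)) {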
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<Value> listElements =
|
2022-12-08 04:20:41 +08:00
|
|
|
llvm::to_vector<4>(listConstructOp.getElements());
|
2022-09-27 05:35:50 +08:00
|
|
|
int64_t size = static_cast<int64_t>(listElements.size());
|
|
|
|
|
|
|
|
int64_t start;
|
|
|
|
int64_t end;
|
|
|
|
int64_t step;
|
2024-05-31 14:45:13 +08:00
|
|
|
if (isa<Torch::NoneType>(op.getStart().getType())) {
|
2022-09-27 05:35:50 +08:00
|
|
|
start = 0;
|
2022-12-08 04:20:41 +08:00
|
|
|
} else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) {
|
2022-09-27 05:35:50 +08:00
|
|
|
return failure();
|
|
|
|
}
|
2024-05-31 14:45:13 +08:00
|
|
|
if (isa<Torch::NoneType>(op.getEnd().getType())) {
|
2022-09-27 05:35:50 +08:00
|
|
|
end = listElements.size();
|
2022-12-08 04:20:41 +08:00
|
|
|
} else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
|
2022-09-27 05:35:50 +08:00
|
|
|
return failure();
|
|
|
|
}
|
2022-12-08 04:20:41 +08:00
|
|
|
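// Only positive steps are handled: with a zero step the loop below never
// terminates, and with a negative step it indexes out of bounds.
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)) || step <= 0) {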
|
2022-09-27 05:35:50 +08:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
start = start >= 0 ? start : start + size;
|
|
|
|
start = start >= 0 ? start : 0;
|
|
|
|
end = end >= 0 ? end : end + size;
|
|
|
|
end = end < size ? end : size;
|
|
|
|
SmallVector<Value> newListElements;
|
|
|
|
|
|
|
|
for (int64_t i = start; i < end; i += step) {
|
|
|
|
newListElements.push_back(listElements[i]);
|
|
|
|
}
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<PrimListConstructOp>(
|
|
|
|
op, Torch::ListType::get(listElements[0].getType()), newListElements);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
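The start/end normalization in the `AtenSliceTOp` pattern above can be exercised on ordinary vectors. `sliceList` is a hypothetical helper that copies the pattern's arithmetic line for line and, like the guarded pattern, models only positive steps.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper reproducing the pattern's index arithmetic: negative
// bounds count from the end, start is clamped at 0, end is clamped at size.
std::vector<int> sliceList(const std::vector<int> &list, int64_t start,
                           int64_t end, int64_t step) {
  assert(step > 0 && "only positive steps are modeled");
  int64_t size = static_cast<int64_t>(list.size());
  start = start >= 0 ? start : start + size;
  start = start >= 0 ? start : 0;
  end = end >= 0 ? end : end + size;
  end = end < size ? end : size;
  std::vector<int> result;
  for (int64_t i = start; i < end; i += step)
    result.push_back(list[i]);
  return result;
}
```

Out-of-range bounds simply produce fewer elements. For example, a start past the end yields an empty list, which matches Python's slicing behavior.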
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenEqIntListOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto lhsLiteral = getA().getDefiningOp<Torch::PrimListConstructOp>();
|
2022-03-10 08:44:22 +08:00
|
|
|
if (!lhsLiteral)
|
|
|
|
return nullptr;
|
2022-12-08 04:20:41 +08:00
|
|
|
auto rhsLiteral = getB().getDefiningOp<Torch::PrimListConstructOp>();
|
2022-03-10 08:44:22 +08:00
|
|
|
if (!rhsLiteral)
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
// If the sizes don't match, then we know the lists aren't equal.
|
|
|
|
if (lhsLiteral.getNumOperands() != rhsLiteral.getNumOperands())
|
|
|
|
return getI1IntegerAttr(getContext(), false);
|
|
|
|
|
|
|
|
// If the sizes match and all corresponding list elements are the same Value,
|
|
|
|
// then we know the lists are equal.
|
|
|
|
// Note that we can't prove that the lists are not-equal with this method,
|
|
|
|
// since two different Values might be equal at runtime.
|
|
|
|
if (llvm::all_of(
|
|
|
|
llvm::zip(lhsLiteral.getOperands(), rhsLiteral.getOperands()),
|
|
|
|
[](const auto &pair) {
|
|
|
|
return std::get<0>(pair) == std::get<1>(pair);
|
|
|
|
}))
|
|
|
|
return getI1IntegerAttr(getContext(), true);
|
|
|
|
return nullptr;
|
2021-06-05 06:57:21 +08:00
|
|
|
}
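The fold above uses three-valued reasoning. Different lengths prove "not equal", identical SSA values prove "equal", and anything else is unknown. In the sketch below, each list element is modeled as an integer id standing in for an SSA `Value`, and `std::nullopt` means "cannot fold". `foldListEq` is a hypothetical name.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical model: equal ids mean "the same SSA value". Distinct ids prove
// nothing, because two different values may hold the same number at runtime.
std::optional<bool> foldListEq(const std::vector<int> &lhsIds,
                               const std::vector<int> &rhsIds) {
  if (lhsIds.size() != rhsIds.size())
    return false;
  if (lhsIds == rhsIds)
    return true;
  return std::nullopt; // Unknown: leave the comparison to runtime.
}
```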
|
|
|
|
|
2023-01-24 08:34:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimTupleConstructOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
LogicalResult PrimTupleConstructOp::verify() {
|
|
|
|
if (!(isValidSubtype(
|
|
|
|
Torch::TupleType::get(getContext(),
|
|
|
|
llvm::to_vector<6>(getElements().getType())),
|
|
|
|
getResult().getType())))
|
|
|
|
return emitOpError(
|
|
|
|
"failed to verify that contained types correspond to operand types");
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2021-11-08 23:56:40 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimTupleIndexOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
|
2024-01-30 01:59:33 +08:00
|
|
|
auto tupleConstruct =
|
|
|
|
op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
|
2021-11-08 23:56:40 +08:00
|
|
|
if (!tupleConstruct)
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
int64_t i;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getI(), m_TorchConstantInt(&i)))
|
2021-11-08 23:56:40 +08:00
|
|
|
return failure();
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
if (i >= (int64_t)tupleConstruct.getElements().size())
|
2021-11-08 23:56:40 +08:00
|
|
|
return failure();
|
|
|
|
|
2022-05-03 17:12:09 +08:00
|
|
|
// TODO: We should have a clear picture of whether we want to consistently
|
|
|
|
// allow refinement, and where. It seems desirable to require precise
|
|
|
|
// type equality for TupleConstruct / TupleIndex, but that might break
|
|
|
|
// things.
|
2022-12-08 04:20:41 +08:00
|
|
|
Value replacement = tupleConstruct.getElements()[i];
|
2022-05-19 21:12:58 +08:00
|
|
|
if (replacement.getType() != op.getType()) {
|
2024-05-31 14:45:13 +08:00
|
|
|
if (isa<BaseTensorType>(op.getType())) {
|
2022-05-19 21:12:58 +08:00
|
|
|
replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
|
|
|
|
op.getLoc(), op.getType(), replacement);
|
|
|
|
} else {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, replacement);
|
2021-11-08 23:56:40 +08:00
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimUninitializedOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void PrimUninitializedOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
|
|
|
patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
|
|
|
|
if (!op.use_empty())
|
|
|
|
return failure();
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2021-08-18 01:59:47 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimTupleUnpackOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
|
2024-01-30 01:59:33 +08:00
|
|
|
auto tupleConstruct =
|
|
|
|
op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
|
2021-08-18 01:59:47 +08:00
|
|
|
if (!tupleConstruct)
|
|
|
|
return failure();
|
|
|
|
|
2023-07-18 22:32:26 +08:00
|
|
|
llvm::SmallVector<Value> derefinedElements;
|
|
|
|
// The result types may be supertypes of the tuple element types.
|
|
|
|
// Ensure we maintain the exact type, with identity `derefine`s being
|
|
|
|
// folded.
|
|
|
|
for (auto [type, element] :
|
|
|
|
llvm::zip(op.getResultTypes(), tupleConstruct.getElements())) {
|
|
|
|
derefinedElements.push_back(
|
|
|
|
rewriter.createOrFold<DerefineOp>(op.getLoc(), type, element));
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, derefinedElements);
|
2021-08-18 01:59:47 +08:00
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2022-07-15 20:10:23 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimListUnpackOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](PrimListUnpackOp op, PatternRewriter &rewriter) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto torchList = op.getOperand();
|
2022-07-15 20:10:23 +08:00
|
|
|
if (isListPotentiallyMutated(torchList)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
|
|
|
|
if (!listConstruct)
|
|
|
|
return failure();
|
|
|
|
|
[Torch] Fix PrimListUnpackOp::getCanonicalizationPatterns (#3140)
Fix the case where the number of PrimListUnpackOp results is not equal
to the PrimList length.
See the following example:
```python
def forward(self, x):
if len(x.shape) == 5:
b0, t, c0, h0, w0 = x.shape
b, c, h, w = torch.mul(b0, t), c0, h0, w0
else:
b1, c1, h1, w1 = x.shape
b, c, h, w = b1, c1, h1, w1
res = torch.reshape(x, [b, c, h, w])
return res
```
Without this fix, the following assertion failure occurs:
```
/root/torch-mlir/externals/llvm-project/mlir/lib/IR/PatternMatch.cpp:118: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```
2024-04-11 19:48:49 +08:00
|
|
|
if (op->getNumResults() != listConstruct.getElements().size())
|
|
|
|
return failure();
|
|
|
|
|
2024-08-09 07:17:31 +08:00
|
|
|
SmallVector<Value> unpacked;
|
|
|
|
for (int i = 0, s = op->getNumResults(); i < s; ++i) {
|
|
|
|
auto element = listConstruct.getElements()[i];
|
|
|
|
if (element.getType() != op->getResult(i).getType()) {
|
|
|
|
element = rewriter.create<TensorStaticInfoCastOp>(
|
|
|
|
op.getLoc(), op->getResult(i).getType(), element);
|
|
|
|
}
|
|
|
|
|
|
|
|
unpacked.push_back(element);
|
|
|
|
}
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, unpacked);
|
2022-07-15 20:10:23 +08:00
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
2021-08-18 01:59:47 +08:00
|
|
|
static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
|
|
|
|
if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) {
|
|
|
|
return isa<Aten__Getitem__DictStrOp, Aten__Contains__StrOp,
|
2021-08-28 05:18:29 +08:00
|
|
|
AtenKeysStrOp, AtenGetDefaultStrOp, PrimDictConstructOp>(op);
|
2021-08-18 01:59:47 +08:00
|
|
|
}))
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
return torchDict.getDefiningOp<Torch::PrimDictConstructOp>();
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Aten__Getitem__DictStrOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto dictConstruct = getDictConstructIfNotModified(getSelf());
|
2021-08-18 01:59:47 +08:00
|
|
|
if (!dictConstruct)
|
|
|
|
return nullptr;
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
auto targetKey = getKey();
|
|
|
|
for (auto i : llvm::zip(dictConstruct.getKeys(), dictConstruct.getValues())) {
|
2021-08-18 01:59:47 +08:00
|
|
|
auto k = std::get<0>(i);
|
|
|
|
if (k == targetKey)
|
|
|
|
return std::get<1>(i);
|
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Aten__Contains__StrOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto dictConstruct = getDictConstructIfNotModified(getDict());
|
2021-08-18 01:59:47 +08:00
|
|
|
if (!dictConstruct)
|
|
|
|
return nullptr;
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
auto targetKey = getKey();
|
|
|
|
for (auto key : dictConstruct.getKeys()) {
|
2021-08-18 01:59:47 +08:00
|
|
|
if (key == targetKey)
|
|
|
|
return getI1IntegerAttr(getContext(), true);
|
|
|
|
}
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
2022-06-23 23:16:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Aten__Contains__IntListOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
static bool isListConstructNotModified(Value torchList) {
|
|
|
|
return llvm::all_of(torchList.getUsers(), [](Operation *op) {
|
2022-08-16 13:24:08 +08:00
|
|
|
return isa<Aten__Contains__IntListOp>(op);
|
|
|
|
});
|
2022-06-23 23:16:09 +08:00
|
|
|
}
|
|
|
|
|
2023-01-25 09:29:42 +08:00
|
|
|
OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto itemConstruct = getItem();
|
|
|
|
if (!isListConstructNotModified(getL()))
|
2022-06-23 23:16:09 +08:00
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
int64_t item;
|
|
|
|
SmallVector<int64_t> list;
|
|
|
|
|
|
|
|
if (!matchPattern(itemConstruct, m_TorchConstantInt(&item)))
|
|
|
|
return nullptr;
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(getL(), m_TorchListOfConstantInts(list)))
|
2022-06-23 23:16:09 +08:00
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
for (auto elem : list) {
|
|
|
|
if (elem == item)
|
|
|
|
return getI1IntegerAttr(getContext(), true);
|
|
|
|
}
|
|
|
|
return getI1IntegerAttr(getContext(), false);
|
|
|
|
}
|
|
|
|
|
2021-08-18 01:59:47 +08:00
|
|
|
using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
|
2022-09-20 12:40:19 +08:00
|
|
|
static OpFoldResult
|
|
|
|
atenBinaryIntOperatorFoldHelper(ArrayRef<Attribute> operands,
|
|
|
|
BinaryIntOperatorFn f) {
|
2024-05-31 14:45:13 +08:00
|
|
|
auto intLhs = dyn_cast_or_null<IntegerAttr>(operands[0]);
|
|
|
|
auto intRhs = dyn_cast_or_null<IntegerAttr>(operands[1]);
|
2022-09-20 12:40:19 +08:00
|
|
|
if (!intLhs || !intRhs) {
|
2021-08-18 01:59:47 +08:00
|
|
|
return nullptr;
|
2022-09-20 12:40:19 +08:00
|
|
|
}
|
|
|
|
return IntegerAttr::get(
|
|
|
|
intLhs.getType(),
|
|
|
|
f(intLhs.getValue().getSExtValue(), intRhs.getValue().getSExtValue()));
|
|
|
|
}
|
2021-08-18 01:59:47 +08:00
|
|
|
|
2022-09-20 12:40:19 +08:00
|
|
|
using BinaryFloatOperatorFn = std::function<double(double, double)>;
|
|
|
|
static OpFoldResult
|
|
|
|
atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
|
|
|
|
BinaryFloatOperatorFn f) {
|
|
|
|
double lhs, rhs;
|
|
|
|
auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool {
|
2024-04-11 21:47:35 +08:00
|
|
|
if (auto intLhs = dyn_cast_or_null<IntegerAttr>(attr)) {
|
2022-09-20 12:40:19 +08:00
|
|
|
value = static_cast<double>(intLhs.getValue().getSExtValue());
|
2024-04-11 21:47:35 +08:00
|
|
|
} else if (auto floatLhs = dyn_cast_or_null<FloatAttr>(attr)) {
|
2022-09-20 12:40:19 +08:00
|
|
|
value = floatLhs.getValue().convertToDouble();
|
|
|
|
} else {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
};
|
|
|
|
if (!parseDoubleAttribute(operands[0], lhs) ||
|
|
|
|
!parseDoubleAttribute(operands[1], rhs)) {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs));
|
2021-08-18 01:59:47 +08:00
|
|
|
}
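Both helpers above follow the same rule: fold only when every operand is a known constant, and otherwise return `nullptr` so the op stays in place. A minimal model of the integer case, with `std::optional` standing in for "maybe-constant attribute" (`foldBinaryInt` is a hypothetical name):

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>

// Hypothetical model: std::nullopt plays the role of a non-constant operand
// (a null Attribute), and the result is std::nullopt unless both are known.
std::optional<int64_t>
foldBinaryInt(std::optional<int64_t> lhs, std::optional<int64_t> rhs,
              const std::function<int64_t(int64_t, int64_t)> &f) {
  if (!lhs || !rhs)
    return std::nullopt;
  return f(*lhs, *rhs);
}
```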
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2023-06-21 01:14:09 +08:00
|
|
|
// AtenAliasOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); }

//===----------------------------------------------------------------------===//
// AtenFloordivIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
  // Do not fold a division by zero; leave it for the runtime to report.
  auto rhsAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
  if (rhsAttr && rhsAttr.getValue().getSExtValue() == 0)
    return nullptr;
  // Python-style floor division computed in integer arithmetic. This avoids
  // the precision loss of a round trip through double for large operands.
  return atenBinaryIntOperatorFoldHelper(
      adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t {
        int64_t q = a / b;
        if (a % b != 0 && ((a < 0) != (b < 0)))
          --q;
        return q;
      });
}
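
Floor division can be computed exactly in integer arithmetic. A round trip through `double` cannot represent every `int64_t` above 2^53, so it loses precision for large operands. The sketch below uses a hypothetical `floorDiv`, assumes `b != 0`, and shows how to adjust C++'s truncating `/` to Python-style flooring.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: floor division with Python semantics (rounds toward
// negative infinity). C++ `/` truncates toward zero, so the quotient is
// decremented when the remainder is nonzero and the operand signs differ.
// The caller must ensure b != 0.
int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  if (a % b != 0 && ((a < 0) != (b < 0)))
    --q;
  return q;
}
```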

void AtenFloordivIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // Rewrite (x * c) // c -> x for a nonzero constant c. Floor division is not
  // commutative, so the constant must be the divisor: c // (x * c) is not x.
  patterns.add(+[](AtenFloordivIntOp op, PatternRewriter &rewriter) {
    int64_t divisor;
    if (!matchPattern(op.getB(), m_TorchConstantInt(&divisor)) || divisor == 0)
      return failure();

    // A constant dividend is not an AtenMulIntOp, so fully constant divisions
    // are left to the folder.
    auto prevMulIntOp = op.getA().getDefiningOp<AtenMulIntOp>();
    if (!prevMulIntOp || !prevMulIntOp->hasOneUse())
      return failure();

    int64_t prevLhs, prevRhs;
    bool prevLConstant =
        matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
    bool prevRConstant =
        matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
    // Require exactly one constant factor; fully constant products fold.
    if (prevLConstant == prevRConstant)
      return failure();

    int64_t mulConstant = prevLConstant ? prevLhs : prevRhs;
    if (mulConstant != divisor)
      return failure();

    rewriter.replaceOp(op, prevMulIntOp.getOperand(prevLConstant ? 1 : 0));
    rewriter.eraseOp(prevMulIntOp);
    return success();
  });
}
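
A quick numeric check shows why the matched constant must be the divisor. It uses nonnegative operands, where C++ `/` already floors, and hypothetical helper names. `(x * c) // c` recovers `x`, while `c // (x * c)` generally does not.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: floor division for nonnegative operands, where C++ `/`
// already rounds down.
int64_t floorDivNonNeg(int64_t a, int64_t b) { return a / b; }

// (x * c) // c == x holds for nonzero c (ignoring overflow).
bool rewriteHolds(int64_t x, int64_t c) {
  return floorDivNonNeg(x * c, c) == x;
}

// Swapping the division operands does not give x back in general.
bool swappedRewriteHolds(int64_t x, int64_t c) {
  return floorDivNonNeg(c, x * c) == x;
}
```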

//===----------------------------------------------------------------------===//
// AtenRemainderIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
  // Do not fold a remainder by zero.
  auto rhsAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
  if (rhsAttr && rhsAttr.getValue().getSExtValue() == 0)
    return nullptr;
  // Python semantics: a nonzero result takes the sign of the divisor, unlike
  // C++ `%`, which takes the sign of the dividend.
  return atenBinaryIntOperatorFoldHelper(
      adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t {
        int64_t r = a % b;
        return (r != 0 && ((r < 0) != (b < 0))) ? r + b : r;
      });
}
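
This assumes `aten::remainder.int` follows Python semantics, as TorchScript integer `%` does. Under those semantics a nonzero remainder takes the sign of the divisor, whereas C++ `%` takes the sign of the dividend. Here is a minimal sketch of the conversion, using a hypothetical `pyRemainder` and assuming `b != 0`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: remainder with Python semantics, so that
// a == floor(a / b) * b + pyRemainder(a, b). C++ `%` instead takes the sign
// of the dividend. The caller must ensure b != 0.
int64_t pyRemainder(int64_t a, int64_t b) {
  int64_t r = a % b;
  return (r != 0 && ((r < 0) != (b < 0))) ? r + b : r;
}
```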

//===----------------------------------------------------------------------===//
// AtenAddIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
  return atenBinaryIntOperatorFoldHelper(
      adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
}

//===----------------------------------------------------------------------===//
// AtenSubIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
  return atenBinaryIntOperatorFoldHelper(
      adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
}

//===----------------------------------------------------------------------===//
// AtenCatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
  // Cap the folded result at 16 elements, a reasonable upper limit for shape
  // computations.
  constexpr int64_t kMaxFoldSize = 16;
  auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
  if (!list)
    return nullptr;

  // Concatenating a single tensor of the result type is the identity.
  auto elements = list.getElements();
  if (elements.size() == 1 && elements[0].getType() == getResult().getType())
    return elements[0];

  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
    return nullptr;

  auto bResultTy = resultTy.toBuiltinTensor();
  if (!bResultTy.hasStaticShape() || bResultTy.getNumElements() > kMaxFoldSize)
    return nullptr;

  auto dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
  if (!dimAttr)
    return nullptr;
  auto dim = dimAttr.getValue().getSExtValue();
  dim += dim < 0 ? bResultTy.getRank() : 0;
  if (dim < 0 || dim >= bResultTy.getRank())
    return nullptr;

  // Only fold when every non-concatenated dimension has size 1, so the
  // concatenation is a flat append of each operand's elements.
  for (int i = 0, s = bResultTy.getRank(); i < s; ++i) {
    if (i == dim)
      continue;
    if (bResultTy.getDimSize(i) != 1)
      return nullptr;
  }

  llvm::SmallVector<Attribute> values;
  for (auto operand : list.getOperands()) {
    DenseElementsAttr dattr;
    if (!matchPattern(operand, m_Constant(&dattr)))
      return nullptr;

    auto oty = dyn_cast<RankedTensorType>(dattr.getType());
    if (!oty || oty.getRank() != bResultTy.getRank())
      return nullptr;

    if (dattr.isSplat()) {
      for (int i = 0, s = oty.getDimSize(dim); i < s; ++i)
        values.push_back(dattr.getSplatValue<Attribute>());
    } else {
      auto evals = dattr.getValues<Attribute>();
      for (int i = 0, s = oty.getDimSize(dim); i < s; ++i)
        values.push_back(evals[i]);
    }
  }

  // Bail out if the operand sizes do not add up to the result size.
  if (static_cast<int64_t>(values.size()) != bResultTy.getNumElements())
    return nullptr;
  return DenseElementsAttr::get(bResultTy.clone(resultTy.getDtype()), values);
}
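
The fold relies on a layout fact. If every dimension other than the concatenation dimension has size 1, the row-major result is just each operand's elements appended in order. Here is a sketch on plain vectors, with a hypothetical `Tensor` struct and `catFlat` function standing in for constant `DenseElementsAttr` operands.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a small constant tensor: a shape plus row-major
// data.
struct Tensor {
  std::vector<int64_t> shape;
  std::vector<int> data;
};

// Concatenates along `dim` by appending data, which is valid only when every
// other dimension has size 1. Returns std::nullopt otherwise.
std::optional<Tensor> catFlat(const std::vector<Tensor> &operands,
                              int64_t dim) {
  if (operands.empty())
    return std::nullopt;
  Tensor result{operands[0].shape, {}};
  int64_t rank = static_cast<int64_t>(result.shape.size());
  if (dim < 0)
    dim += rank;
  if (dim < 0 || dim >= rank)
    return std::nullopt;
  result.shape[dim] = 0;
  for (const Tensor &t : operands) {
    if (static_cast<int64_t>(t.shape.size()) != rank)
      return std::nullopt;
    for (int64_t i = 0; i < rank; ++i)
      if (i != dim && t.shape[i] != 1)
        return std::nullopt;
    result.shape[dim] += t.shape[dim];
    result.data.insert(result.data.end(), t.data.begin(), t.data.end());
  }
  return result;
}
```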

void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  // Drop operands that are empty along the concatenation dimension.
  patterns.add(+[](AtenCatOp op, PatternRewriter &rewriter) {
    auto list = op.getTensors().getDefiningOp<PrimListConstructOp>();
    auto resultTy = dyn_cast<BaseTensorType>(op.getType());
    if (!list || !resultTy)
      return failure();

    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return failure();

    llvm::SmallVector<Value> filtered;
    for (auto operand : list.getOperands()) {
      auto operandTy = dyn_cast<BaseTensorType>(operand.getType());
      if (!operandTy || !operandTy.hasSizes())
        return failure();
      int64_t rank = operandTy.getSizes().size();
      int64_t adim = dim < 0 ? dim + rank : dim;
      if (adim < 0 || adim >= rank)
        return failure();
      if (operandTy.getSizes()[adim] != 0)
        filtered.push_back(operand);
    }

    // Nothing to drop, or every operand is empty: leave the op unchanged so
    // the operand list never becomes empty.
    if (filtered.size() == list.getNumOperands() || filtered.empty())
      return failure();

    auto newlist = rewriter.create<PrimListConstructOp>(
        op.getLoc(), list.getType(), filtered);
    rewriter.replaceOpWithNewOp<AtenCatOp>(op, op.getType(), newlist,
                                           op.getDim());
    return success();
  });
}
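
The canonicalization drops operands that are empty along the concatenation dimension, because they contribute no elements. The sketch below works on shapes only, via a hypothetical `dropEmptyOperands`. It conservatively returns the input unchanged when a dimension is out of range or when nothing would remain.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: each inner vector is one operand's shape. Operands with
// size 0 along `dim` are removed, unless that would remove all of them.
std::vector<std::vector<int64_t>>
dropEmptyOperands(const std::vector<std::vector<int64_t>> &shapes,
                  int64_t dim) {
  std::vector<std::vector<int64_t>> kept;
  for (const auto &shape : shapes) {
    int64_t rank = static_cast<int64_t>(shape.size());
    int64_t adim = dim < 0 ? dim + rank : dim;
    if (adim < 0 || adim >= rank)
      return shapes; // Out-of-range dim: do not rewrite.
    if (shape[adim] != 0)
      kept.push_back(shape);
  }
  return kept.empty() ? shapes : kept;
}
```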

//===----------------------------------------------------------------------===//
// AtenBroadcastToOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
  auto inType = dyn_cast<BaseTensorType>(getOperand(0).getType());
  auto outType = dyn_cast<BaseTensorType>(getResult().getType());
  if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
      !outType.hasDtype())
    return nullptr;

  if (!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
    return nullptr;

  // Broadcasting to the input's own shape is the identity.
  if (inType.getSizes() == outType.getSizes())
    return getOperand(0);

  // A constant splat input broadcasts to a splat of the output shape.
  auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  if (!selfAttr || !selfAttr.isSplat())
    return nullptr;

  auto attrty = RankedTensorType::get(outType.getSizes(), outType.getDtype());
  return DenseElementsAttr::get(attrty, selfAttr.getSplatValue<Attribute>());
}

//===----------------------------------------------------------------------===//
// AtenSliceTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
  DenseElementsAttr input =
      dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  IntegerAttr start = dyn_cast_or_null<IntegerAttr>(adaptor.getStart());
  IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
  IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
  IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
  auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
  auto outType = dyn_cast<ValueTensorType>(getResult().getType());

  if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
      !inType.hasDtype() || !outType.hasDtype() ||
      inType.getDtype() != outType.getDtype())
    return nullptr;

  // slice(x, dim, 0, INT64_MAX, 1) with identical types is the identity.
  if (start && end && step && step.getValue().getSExtValue() == 1 &&
      start.getValue().getSExtValue() == 0 &&
      end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() &&
      inType == outType)
    return getOperand(0);

  if (inType.getSizes().size() != outType.getSizes().size() ||
      !inType.areAllSizesKnown() || !outType.areAllSizesKnown())
    return nullptr;

  // A splat input slices to a splat of the output shape.
  if (input && input.isSplat())
    return DenseElementsAttr::get(outType.toBuiltinTensor(),
                                  input.getSplatValue<Attribute>());

  int64_t count = 1;
  for (auto size : outType.getSizes())
    count *= size;
  if (count == 0)
    return nullptr;

  if (!dim)
    return nullptr;
  int64_t inputRank = inType.getSizes().size();
  int64_t dimInt = dim.getValue().getSExtValue();
  if (dimInt < 0)
    dimInt += inputRank;
  if (dimInt < 0 || dimInt >= inputRank)
    return nullptr;

  // Fold the slice if the output tensor is small (currently at most 16
  // elements). All of start, end, and step must be constants.
  constexpr int64_t kMaxFold = 16;
  if (input && start && end && step && count <= kMaxFold) {
    int64_t dimSize = inType.getSizes()[dimInt];
    int64_t begin = start.getValue().getSExtValue();
    int64_t limit = end.getValue().getSExtValue();
    int64_t stride = step.getValue().getSExtValue();
    if (stride < 1)
      return nullptr;
    // Normalize negative indices, then clamp both bounds to [0, dimSize].
    begin = begin < 0 ? begin + dimSize : begin;
    limit = limit < 0 ? limit + dimSize : limit;
    begin = std::clamp<int64_t>(begin, 0, dimSize);
    limit = std::clamp<int64_t>(limit, 0, dimSize);

    // Row-major strides of the input.
    llvm::SmallVector<int64_t> inputStrides(inputRank, 1);
    for (int64_t i = inputRank - 2; i >= 0; i--)
      inputStrides[i] = inputStrides[i + 1] * inType.getSizes()[i + 1];

    auto inputValues = input.getValues<Attribute>();
    llvm::SmallVector<Attribute> values;
    values.reserve(count);
    // Visit the selected indices in row-major order: the sliced dimension
    // walks [begin, limit) by `stride`, every other dimension its full range.
    auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) {
      if (currDim >= inputRank)
        return;
      int64_t dBegin = (currDim == dimInt) ? begin : 0;
      int64_t dLimit =
          (currDim == dimInt) ? limit : inType.getSizes()[currDim];
      int64_t dStride = (currDim == dimInt) ? stride : 1;
      for (int64_t i = dBegin; i < dLimit; i += dStride) {
        int64_t offset = currOffset + inputStrides[currDim] * i;
        if (currDim == inputRank - 1)
          values.push_back(inputValues[offset]);
        else
          self(self, currDim + 1, offset);
      }
    };
    recursiveIter(recursiveIter, 0, 0);
    // The collected elements must match the declared output shape.
    if (static_cast<int64_t>(values.size()) != count)
      return nullptr;
    return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
  }

  // If the input and output shapes are the same and step == 1, we can fold.
  if (!step || step.getValue().getSExtValue() != 1)
    return nullptr;
  for (size_t i = 0; i < inType.getSizes().size(); ++i) {
    if (inType.getSizes()[i] != outType.getSizes()[i])
      return nullptr;
  }
  return getOperand(0);
}
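
The small-tensor path walks the kept indices in row-major order. It uses precomputed strides to turn a multi-index into a flat offset. Below is a self-contained sketch of that walk on a `std::vector`. The function name `sliceRowMajor` is hypothetical. Negative bounds are normalized and clamped as in Python slicing with a positive step.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: slices a row-major tensor along `dim` over
// [start, end) with a positive `step`. `dim` must be a valid index.
std::vector<int> sliceRowMajor(const std::vector<int> &data,
                               const std::vector<int64_t> &shape, int64_t dim,
                               int64_t start, int64_t end, int64_t step) {
  int64_t rank = static_cast<int64_t>(shape.size());
  int64_t size = shape[dim];
  start = std::clamp<int64_t>(start < 0 ? start + size : start, 0, size);
  end = std::clamp<int64_t>(end < 0 ? end + size : end, 0, size);

  // strides[i] is the product of shape[i + 1 .. rank - 1].
  std::vector<int64_t> strides(rank, 1);
  for (int64_t i = rank - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];

  std::vector<int> out;
  // Recursive walk over dimensions; `offset` is the flat index so far.
  auto visit = [&](auto &self, int64_t d, int64_t offset) -> void {
    int64_t lo = d == dim ? start : 0;
    int64_t hi = d == dim ? end : shape[d];
    int64_t st = d == dim ? step : 1;
    for (int64_t i = lo; i < hi; i += st) {
      int64_t next = offset + strides[d] * i;
      if (d == rank - 1)
        out.push_back(data[next]);
      else
        self(self, d + 1, next);
    }
  };
  if (step > 0)
    visit(visit, 0, 0);
  return out;
}
```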

//===----------------------------------------------------------------------===//
// AtenMulIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
  int64_t lhs, rhs;
  bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
  bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
  if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
    return getI64IntegerAttr(getContext(), 0);
  if (lConstant && rConstant)
    return getI64IntegerAttr(getContext(), lhs * rhs);
  return nullptr;
}

void AtenMulIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // Reassociate (x * c1) * c2 -> x * (c1 * c2) when the inner product has a
  // single use.
  patterns.add(+[](AtenMulIntOp op, PatternRewriter &rewriter) {
    int64_t lhs, rhs;
    bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
    bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
    // Require exactly one constant operand; fully constant products fold.
    if (lConstant == rConstant)
      return failure();

    int64_t firstConstant = lConstant ? lhs : rhs;
    Value firstOperand = lConstant ? op.getB() : op.getA();
    auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
    if (!prevMulIntOp || !prevMulIntOp->hasOneUse())
      return failure();

    int64_t prevLhs, prevRhs;
    bool prevLConstant =
        matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
    bool prevRConstant =
        matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
    if (prevLConstant == prevRConstant)
      return failure();

    int64_t prevConstant = prevLConstant ? prevLhs : prevRhs;
    auto newConstant = rewriter.create<Torch::ConstantIntOp>(
        op.getLoc(), rewriter.getI64IntegerAttr(prevConstant * firstConstant));
    rewriter.replaceOpWithNewOp<AtenMulIntOp>(
        op, op.getType(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0),
        newConstant);
    rewriter.eraseOp(prevMulIntOp);
    return success();
  });
}
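
The reassociation `(x * c1) * c2 -> x * (c1 * c2)` works because a product involving an unknown `x` is fully described by its constant coefficient. Below is a toy model of that bookkeeping, built on a hypothetical `ScaledVar` type. Overflow is ignored.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of the value `x * coeff` for an unknown integer x.
struct ScaledVar {
  int64_t coeff;
};

// Multiplying by a constant only updates the coefficient, leaving a single
// multiplication by x behind.
ScaledVar mulByConstant(ScaledVar v, int64_t c) { return {v.coeff * c}; }

// Evaluates the modeled expression for a concrete x.
int64_t evaluate(ScaledVar v, int64_t x) { return x * v.coeff; }
```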

//===----------------------------------------------------------------------===//
// AtenMulFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenMulFloatOp::fold(FoldAdaptor adaptor) {
  return atenBinaryFloatOperatorFoldHelper(
      adaptor.getOperands(), [](double a, double b) { return a * b; });
}

//===----------------------------------------------------------------------===//
// AtenSubFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
  return atenBinaryFloatOperatorFoldHelper(
      adaptor.getOperands(), [](double a, double b) { return a - b; });
}

//===----------------------------------------------------------------------===//
// AtenAddOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getA() || !adaptor.getB()) {
    return nullptr;
  }

  if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
    return atenBinaryIntOperatorFoldHelper(
        adaptor.getOperands(),
        [](int64_t a, int64_t b) -> int64_t { return a + b; });
  }
  return atenBinaryFloatOperatorFoldHelper(
      adaptor.getOperands(),
      [](double a, double b) -> double { return a + b; });
}

//===----------------------------------------------------------------------===//
// AtenMulOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getA() || !adaptor.getB()) {
    return nullptr;
  }

  if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
    return atenBinaryIntOperatorFoldHelper(
        adaptor.getOperands(),
        [](int64_t a, int64_t b) -> int64_t { return a * b; });
  }
  return atenBinaryFloatOperatorFoldHelper(
      adaptor.getOperands(),
      [](double a, double b) -> double { return a * b; });
}

//===----------------------------------------------------------------------===//
// AtenSubOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getA() || !adaptor.getB()) {
    return nullptr;
  }

  if (isa<IntegerAttr>(adaptor.getA()) && isa<IntegerAttr>(adaptor.getB())) {
    return atenBinaryIntOperatorFoldHelper(
        adaptor.getOperands(),
        [](int64_t a, int64_t b) -> int64_t { return a - b; });
  }
  return atenBinaryFloatOperatorFoldHelper(
      adaptor.getOperands(),
      [](double a, double b) -> double { return a - b; });
}

//===----------------------------------------------------------------------===//
// AtenDivOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getA() || !adaptor.getB()) {
    return nullptr;
  }
  // AtenDivOp always returns a float, so integer operands need no separate
  // handling: the float helper widens them to double.
  return atenBinaryFloatOperatorFoldHelper(
      adaptor.getOperands(),
      [](double a, double b) -> double { return a / b; });
}

//===----------------------------------------------------------------------===//
// AtenAddFloatIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getA() || !adaptor.getB()) {
    return nullptr;
  }
  return atenBinaryFloatOperatorFoldHelper(
      adaptor.getOperands(), [](double a, double b) { return a + b; });
}

//===----------------------------------------------------------------------===//
// AtenPowIntFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenPowIntFloatOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getA() || !adaptor.getB()) {
    return nullptr;
  }
  return atenBinaryFloatOperatorFoldHelper(
      adaptor.getOperands(), [](double a, double b) { return std::pow(a, b); });
}

//===----------------------------------------------------------------------===//
// AtenCeilScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getA()) {
    return nullptr;
  }
  auto floatValue = dyn_cast_or_null<FloatAttr>(adaptor.getA());
  if (!floatValue) {
    return nullptr;
  }
  return getI64IntegerAttr(
      getContext(),
      static_cast<int64_t>(std::ceil(floatValue.getValue().convertToDouble())));
}

//===----------------------------------------------------------------------===//
// AtenNegIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
  int64_t c;
  if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
    return getI64IntegerAttr(getContext(), -c);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenNegFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getA()) {
    return nullptr;
  }
  auto value = dyn_cast_or_null<FloatAttr>(adaptor.getA());
  if (!value) {
    return nullptr;
  }
  return getF64FloatAttr(getContext(), -value.getValue().convertToDouble());
}

//===----------------------------------------------------------------------===//
// AtenSqrtIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
  int64_t c;
  if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
    return getF64FloatAttr(getContext(), std::sqrt(c));
  return nullptr;
}

//===----------------------------------------------------------------------===//
// PrimDtypeOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
  BaseTensorType tensorType = cast<BaseTensorType>(getA().getType());
  if (tensorType.hasDtype()) {
    torch_upstream::ScalarType scalarType =
        Torch::getScalarTypeForType(tensorType.getDtype());
    return getI64IntegerAttr(getContext(), static_cast<int64_t>(scalarType));
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// PrimDeviceOp
//===----------------------------------------------------------------------===//

void PrimDeviceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add(+[](PrimDeviceOp op, PatternRewriter &rewriter) {
    // Device information isn't relevant to torch-mlir, so replace it with
    // "cpu".
    rewriter.replaceOpWithNewOp<Torch::ConstantDeviceOp>(op, "cpu");
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenCudaOp
//===----------------------------------------------------------------------===//

void AtenCudaOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add(+[](AtenCudaOp op, PatternRewriter &rewriter) {
    // Device information isn't relevant to torch-mlir, so the device transfer
    // is replaced by its input tensor.
    auto inputTensor = op.getSelf();
    rewriter.replaceOp(op, inputTensor);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenDeviceWithIndexOp
//===----------------------------------------------------------------------===//

void AtenDeviceWithIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenDeviceWithIndexOp op, PatternRewriter &rewriter) {
    std::string type;
    int64_t index;
    if (!matchPattern(op.getType(), m_TorchConstantStr(type))) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: type must be a constant string");
    }
    if (!matchPattern(op.getIndex(), m_TorchConstantInt(&index))) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: index must be a constant integer");
    }
    rewriter.replaceOpWithNewOp<Torch::ConstantDeviceOp>(
        op, type + ":" + std::to_string(index));
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
  // If a torch.aten.tensor op is initialized by a list with a single constant
  // integer element, fold it into a torch.vtensor.literal.
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
    return nullptr;
  Type eTy = resultTy.getDtype();
  // IntegerAttr::get below requires an integer dtype; skip other dtypes.
  if (!isa<IntegerType>(eTy))
    return nullptr;
  ShapedType shapedTy = resultTy.toBuiltinTensor();

  SmallVector<int64_t> data;
  if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
      data.size() == 1) {
    Attribute attribute = IntegerAttr::get(eTy, data[0]);
    return DenseElementsAttr::get(shapedTy, attribute);
  }

  return nullptr;
}
|
|
|
|
|
2024-02-29 04:04:52 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL: func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
%none = torch.constant.none
%false = torch.constant.bool false
%float1.000000e01 = torch.constant.float 1.000000e+01
%67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
return %67 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
%none = torch.constant.none
%false = torch.constant.bool false
%int45 = torch.constant.int 45
%67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
return %67 : !torch.vtensor<[],si32>
}
```
2024-04-19 22:17:06 +08:00
|
|
|
// AtenTensorIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) {
|
|
|
|
auto resultTy = dyn_cast<ValueTensorType>(getType());
|
|
|
|
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
|
|
|
return nullptr;
|
|
|
|
Type eTy = resultTy.getDtype();
|
2024-06-08 09:36:32 +08:00
|
|
|
ShapedType shapedTy = resultTy.toBuiltinTensor();
|

  int64_t data;
  if (matchPattern(getT(), m_TorchConstantInt(&data))) {
    Attribute attribute = IntegerAttr::get(eTy, data);
    return DenseElementsAttr::get(shapedTy, attribute);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenTensorFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) {
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
    return nullptr;
  Type eTy = resultTy.getDtype();
  ShapedType shapedTy = resultTy.toBuiltinTensor();

  double data;
  if (matchPattern(getT(), m_TorchConstantFloat(&data))) {
    Attribute attribute = FloatAttr::get(eTy, data);
    return DenseElementsAttr::get(shapedTy, attribute);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// Aten_ShapeAsTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
  auto selfTy = dyn_cast<BaseTensorType>(getSelf().getType());
  auto resultTy = dyn_cast<BaseTensorType>(getType());
  if (!selfTy || !resultTy || !selfTy.hasSizes() || !resultTy.hasDtype() ||
      !resultTy.hasSizes())
    return {};

  llvm::SmallVector<int64_t> values(selfTy.getSizes());
  if (llvm::any_of(values, [](int64_t d) { return d == Torch::kUnknownSize; }))
    return {};

  auto dty = dyn_cast<IntegerType>(resultTy.getDtype());
  if (!dty)
    return {};

  llvm::SmallVector<Attribute> attrs;
  for (auto val : values) {
    attrs.push_back(IntegerAttr::get(dty, val));
  }

  auto attrty = RankedTensorType::get(resultTy.getSizes(), dty);
  return DenseElementsAttr::get(attrty, attrs);
}
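`Aten_ShapeAsTensorOp::fold` only produces a constant when every input dimension is static. The following standalone C++ sketch isolates that rule outside MLIR: `std::optional` stands in for "no fold", and `kUnknownDim` and `foldShapeAsTensor` are hypothetical names, with `kUnknownDim = -1` assumed to play the role of `Torch::kUnknownSize`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for Torch::kUnknownSize (assumed to be -1).
constexpr int64_t kUnknownDim = -1;

// Returns the shape as constant tensor contents, or std::nullopt when any
// dimension is dynamic, mirroring the llvm::any_of check in the fold.
std::optional<std::vector<int64_t>>
foldShapeAsTensor(const std::vector<int64_t> &sizes) {
  bool anyDynamic = std::any_of(sizes.begin(), sizes.end(),
                                [](int64_t d) { return d == kUnknownDim; });
  if (anyDynamic)
    return std::nullopt;
  return sizes;
}
```

A 0-d input has an empty shape, so it folds to an empty list of extents.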

//===----------------------------------------------------------------------===//
// AtenIntTensorOp
//===----------------------------------------------------------------------===//

void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenIntTensorOp op, PatternRewriter &rewriter) {
    Value scalarInt = getScalarIntValue(op.getA(), op.getLoc(), rewriter);
    if (!scalarInt)
      return failure();
    rewriter.replaceOp(op, scalarInt);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenFloatTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) {
  // If a scalar number is converted to a 0-d tensor and passed on to
  // aten.Float.Tensor, fold to the scalar number.
  if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
    return numToTensorScalar.getA();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenDivFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) {
  double lhs, rhs;
  bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
  bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
  if (lConstant && lhs == 0.0)
    return getF64FloatAttr(getContext(), 0.0);
  if (lConstant && rConstant && rhs == 1.0)
    return getF64FloatAttr(getContext(), lhs);
  if (lConstant && rConstant)
    return getF64FloatAttr(getContext(), lhs / rhs);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenDivIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
  int64_t lhs, rhs;
  bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
  bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
  if (lConstant && rConstant)
    return getF64FloatAttr(getContext(), double(lhs) / rhs);
  return nullptr;
}
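The two division folds differ in how much they need to know: `AtenDivFloatOp::fold` folds a constant zero numerator on its own, while `AtenDivIntOp::fold` needs both operands and always produces a float. A minimal sketch of those rules on plain C++ values rather than MLIR attributes, assuming `std::nullopt` models a non-constant operand; `foldDivFloat` and `foldDivInt` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Models AtenDivFloatOp::fold; std::nullopt marks an operand that is not a
// compile-time constant. A constant 0.0 numerator folds to 0.0 even when the
// denominator is unknown.
std::optional<double> foldDivFloat(std::optional<double> lhs,
                                   std::optional<double> rhs) {
  if (lhs && *lhs == 0.0)
    return 0.0;
  if (lhs && rhs)
    return *lhs / *rhs; // also covers the rhs == 1.0 shortcut
  return std::nullopt;
}

// Models AtenDivIntOp::fold: int / int is true division, so the folded value
// is a double rather than a truncated integer.
std::optional<double> foldDivInt(std::optional<int64_t> lhs,
                                 std::optional<int64_t> rhs) {
  if (lhs && rhs)
    return static_cast<double>(*lhs) / static_cast<double>(*rhs);
  return std::nullopt;
}
```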

//===----------------------------------------------------------------------===//
// AtenIndexSelectOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
  auto self = getSelf();
  auto index = getIndex();
  auto selfTy = dyn_cast<ValueTensorType>(self.getType());
  auto indexTy = dyn_cast<ValueTensorType>(index.getType());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!selfTy || !indexTy || !resultTy || !selfTy.hasSizes() ||
      !indexTy.hasSizes() || !resultTy.hasSizes() || !selfTy.hasDtype() ||
      !indexTy.hasDtype() || !resultTy.hasDtype())
    return nullptr;

  auto selfSizes = selfTy.getSizes();
  auto indexSizes = indexTy.getSizes();
  auto resultSizes = resultTy.getSizes();

  if (selfTy.getDtype() != resultTy.getDtype() ||
      selfSizes.size() != resultSizes.size() || indexSizes.size() != 1)
    return nullptr;

  // If the result has the same static shape as the input and exactly one
  // index is selected, the selected dimension must have extent 1, so the
  // selection returns the input unchanged.
  bool fullTensor = true;
  for (int i = 0, s = selfSizes.size(); i < s; ++i) {
    fullTensor &= selfSizes[i] == resultSizes[i];
    fullTensor &= selfSizes[i] != Torch::kUnknownSize;
    fullTensor &= resultSizes[i] != Torch::kUnknownSize;
  }

  if (fullTensor && indexSizes[0] == 1)
    return self;

  // If the input tensor, the selected dimension, or the indices are not
  // constant, we can't fold.
  auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
  auto indexAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getIndex());

  if (!selfAttr || !dimAttr || !indexAttr)
    return {};

  // If the input's dimensions are all 1 except for one dimension, and if
  // there is a single index in the index list (as detected by the result
  // dimension being 1), then fold to a <1x1x...x1> tensor literal containing
  // a single element. Handles float and int types.
  int64_t dimInt = dimAttr.getInt();
  // If the selected dim is negative, count backwards from the last dim.
  if (dimInt < 0)
    dimInt = selfSizes.size() + dimInt;
  assert(uint64_t(dimInt) < selfSizes.size() && "selected dim out of range");

  for (int i = 0, s = selfSizes.size(); i < s; ++i) {
    if ((selfSizes[i] != 1 && i != dimInt) || resultSizes[i] != 1)
      return nullptr;
  }

  // Get the single index value for the selected dimension.
  auto splatValue = indexAttr.getSplatValue<IntegerAttr>();
  int64_t indexInt = getIntAttrAsSigned(splatValue);
  indexInt = indexInt < 0 && selfSizes[dimInt] ? indexInt + selfSizes[dimInt]
                                               : indexInt;

  // Extract the single constant value from the input tensor and turn the
  // extracted value into a single-element tensor of the output shape and dtype
  Attribute splattr = selfAttr.isSplat()
                          ? selfAttr.getSplatValue<Attribute>()
                          : selfAttr.getValues<Attribute>()[indexInt];

  auto dty = resultTy.getDtype();
  auto attrTy = resultTy.toBuiltinTensor();
  if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
    return DenseElementsAttr::get(
        attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));

  if (auto intAttr = dyn_cast<IntegerAttr>(splattr)) {
    return DenseElementsAttr::get(attrTy,
                                  IntegerAttr::get(dty, intAttr.getValue()));
  }
  return nullptr;
}
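The single-element path of `AtenIndexSelectOp::fold` works because, when every input dimension other than `dim` has extent 1, the flat position of the selected element equals the normalized index. A standalone sketch of that reasoning on plain `std::vector` data (row-major); `foldIndexSelectSingle` is a hypothetical helper, and unlike the fold above it handles only `double` elements and adds an explicit bounds check on the index.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Folds index_select for the single-element case: every input dimension other
// than `dim` has extent 1 and exactly one index is selected, so the selected
// element's flat position equals the normalized index.
std::optional<double> foldIndexSelectSingle(const std::vector<int64_t> &sizes,
                                            const std::vector<double> &data,
                                            int64_t dim, int64_t index) {
  const int64_t rank = static_cast<int64_t>(sizes.size());
  if (dim < 0)
    dim += rank; // negative dims count back from the last dim
  assert(dim >= 0 && dim < rank && "selected dim out of range");
  for (int64_t i = 0; i < rank; ++i)
    if (sizes[i] != 1 && i != dim)
      return std::nullopt;
  if (index < 0 && sizes[dim] != 0)
    index += sizes[dim]; // negative indices count back from the end
  // Extra bounds check: an assumption of this sketch, not in the fold above.
  if (index < 0 || index >= sizes[dim])
    return std::nullopt;
  return data[static_cast<size_t>(index)];
}
```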

//===----------------------------------------------------------------------===//
// AtenItemOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
  // see if we have a constant tensor
  DenseElementsAttr attr;
  if (matchPattern(getOperand(), m_Constant(&attr))) {
    auto splat = attr.getSplatValue<Attribute>();
    if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
      return intAttr.getType().isUnsignedInteger()
                 ? getI64IntegerAttr(getContext(), intAttr.getUInt())
                 : getI64IntegerAttr(getContext(), intAttr.getSInt());
    }
    if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
      return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());
    }
    return nullptr;
  }

  if (auto full = getOperand().getDefiningOp<Torch::AtenFullOp>()) {
    return full.getFillValue();
  }

  if (auto numToTensor =
          getOperand().getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
    return numToTensor.getA();
  }

  return nullptr;
}
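When `AtenItemOp::fold` sees a constant integer tensor, it widens the element to a 64-bit value using `getUInt()` for unsigned element types and `getSInt()` otherwise. The sketch below shows why that choice matters, working on raw bit patterns instead of `IntegerAttr`s; `widenToI64` is a hypothetical helper, not part of torch-mlir.

```cpp
#include <cassert>
#include <cstdint>

// Widens the low `bitWidth` bits of `raw` to int64_t: zero-extension for
// unsigned element types, sign-extension otherwise.
int64_t widenToI64(uint64_t raw, unsigned bitWidth, bool isUnsigned) {
  assert(bitWidth >= 1 && bitWidth <= 64 && "unsupported bit width");
  if (bitWidth < 64)
    raw &= (uint64_t{1} << bitWidth) - 1; // keep only the payload bits
  if (isUnsigned || bitWidth == 64)
    return static_cast<int64_t>(raw);
  const uint64_t signBit = uint64_t{1} << (bitWidth - 1);
  // (x ^ s) - s flips the sign bit and subtracts it back out, which
  // propagates the sign into the upper bits (two's complement).
  return static_cast<int64_t>((raw ^ signBit) - signBit);
}
```

With 8 bits, the pattern `0xFF` reads as 255 when unsigned but as -1 when signed, which is the distinction the fold preserves.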

//===----------------------------------------------------------------------===//
// AtenOnesOp, AtenZerosOp, AtenFullOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
  SmallVector<int64_t> sizes;
  if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
    return nullptr;
  }

  Type resultType = getResult().getType();
  BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
  if (!resultTensorType || !resultTensorType.hasDtype() ||
      !resultTensorType.hasSizes()) {
    return nullptr;
  }

  for (auto sz : sizes)
    if (sz == Torch::kUnknownSize || sz < 0)
      return nullptr;

  for (auto sz : resultTensorType.getSizes())
    if (sz == Torch::kUnknownSize || sz < 0)
      return nullptr;

  ShapedType shapedty =
      mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
          sizes, resultTensorType.getDtype());
  if (!shapedty) {
    return nullptr;
  }
  auto elementType = shapedty.getElementType();
  if (isa<IntegerType>(elementType)) {
    Attribute attribute = IntegerAttr::get(elementType, 1);
    return DenseElementsAttr::get(shapedty, attribute);
  }
  if (isa<FloatType>(elementType)) {
    Attribute attribute = FloatAttr::get(elementType, 1.0);
    return DenseElementsAttr::get(shapedty, attribute);
  }
  return nullptr;
}

OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
  SmallVector<int64_t> sizes;
  if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
    return nullptr;
  }

  Type resultType = getResult().getType();
  BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
  if (!resultTensorType || !resultTensorType.hasDtype() ||
      !resultTensorType.hasSizes()) {
    return nullptr;
  }

  for (auto sz : sizes)
    if (sz == Torch::kUnknownSize || sz < 0)
      return nullptr;

  for (auto sz : resultTensorType.getSizes())
    if (sz == Torch::kUnknownSize || sz < 0)
      return nullptr;

  ShapedType shapedty =
      mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
          sizes, resultTensorType.getDtype());
  if (!shapedty) {
    return nullptr;
  }

  auto elementType = shapedty.getElementType();
  if (isa<IntegerType>(elementType)) {
    Attribute attribute = IntegerAttr::get(elementType, 0);
    return DenseElementsAttr::get(shapedty, attribute);
  }
  if (isa<FloatType>(elementType)) {
    Attribute attribute = FloatAttr::get(elementType, 0.0);
    return DenseElementsAttr::get(shapedty, attribute);
  }

  return nullptr;
}

OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
  SmallVector<int64_t> sizes;
  if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
    return nullptr;
  }

  Type resultType = getResult().getType();
  BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
  if (!resultTensorType || !resultTensorType.hasDtype() ||
      !resultTensorType.hasSizes()) {
    return nullptr;
  }

  for (auto sz : sizes)
    if (sz == Torch::kUnknownSize || sz < 0)
      return nullptr;

  for (auto sz : resultTensorType.getSizes())
    if (sz == Torch::kUnknownSize || sz < 0)
      return nullptr;

  ShapedType shapedty = mlir::RankedTensorType::get(
      resultTensorType.getSizes(), resultTensorType.getDtype());

  auto elementType = shapedty.getElementType();
  if (isa<IntegerType>(elementType)) {
    int64_t value = 0;
    if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
      Attribute attribute = IntegerAttr::get(elementType, value);
      return DenseElementsAttr::get(shapedty, attribute);
    }
  }
  if (isa<FloatType>(elementType)) {
    double value = 0.0;
    if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
      Attribute attribute = FloatAttr::get(elementType, value);
      return DenseElementsAttr::get(shapedty, attribute);
    }
  }
  return nullptr;
}
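`AtenOnesOp`, `AtenZerosOp`, and `AtenFullOp` share the same guard: every requested extent must be concrete and non-negative before a splat constant is built. A standalone sketch of that guard plus the element count it implies; `kUnknownExtent` (assumed to mirror `Torch::kUnknownSize = -1`), `allExtentsStatic`, and `foldFull` are hypothetical names, and the sketch materializes every element where the real folds emit a compact splat attribute.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for Torch::kUnknownSize (assumed to be -1).
constexpr int64_t kUnknownExtent = -1;

// True when every extent is concrete and non-negative; the folds return
// nullptr otherwise.
bool allExtentsStatic(const std::vector<int64_t> &sizes) {
  for (int64_t sz : sizes)
    if (sz == kUnknownExtent || sz < 0)
      return false;
  return true;
}

// Materializes the contents of full(sizes, fill). The folds above encode the
// same data compactly as a splat DenseElementsAttr instead.
template <typename T>
std::optional<std::vector<T>> foldFull(const std::vector<int64_t> &sizes,
                                       T fill) {
  if (!allExtentsStatic(sizes))
    return std::nullopt;
  int64_t numElements = 1; // the empty product: a 0-d tensor has one element
  for (int64_t sz : sizes)
    numElements *= sz;
  return std::vector<T>(static_cast<size_t>(numElements), fill);
}
```

Ones and zeros are the special cases `fill = 1` and `fill = 0`.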

//===----------------------------------------------------------------------===//
// AtenCeilFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) {
  double c;
  if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
    return getI64IntegerAttr(getContext(), std::ceil(c));
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenWhereSelfOp
//===----------------------------------------------------------------------===//

static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
  if (!attr || !ty.hasDtype() || !ty.hasSizes())
    return nullptr;

  auto dty = ty.getDtype();

  if (auto valueDense = dyn_cast<DenseElementsAttr>(attr)) {
    if (!valueDense.isSplat())
      return nullptr;
    auto splattr = valueDense.getSplatValue<Attribute>();
    auto attrty = ty.toBuiltinTensor();
    return DenseElementsAttr::get(attrty, splattr);
  }

  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr)) {
    if (!isa<mlir::IntegerType>(dty))
      return nullptr;
    int64_t intval = intAttr.getInt();
    auto attrty = ty.toBuiltinTensor();
    return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval));
  }

  if (auto fpAttr = dyn_cast_or_null<FloatAttr>(attr)) {
    if (!isa<mlir::FloatType>(dty))
      return nullptr;
    double dblval = fpAttr.getValueAsDouble();
    auto attrty = ty.toBuiltinTensor();
    return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval));
  }

  return nullptr;
}

OpFoldResult AtenWhereSelfOp::fold(FoldAdaptor adaptor) {
  if (getSelf() == getOther())
    return getSelf();

  auto dense = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense ||
      !dense.isSplat())
    return nullptr;

  auto condattr = dense.getSplatValue<APInt>();
  auto value = getSelf();
  auto valueAttr = adaptor.getSelf();
  if (condattr.isZero()) {
    value = getOther();
    valueAttr = adaptor.getOther();
  }

  auto valueTy = dyn_cast<ValueTensorType>(value.getType());
  if (valueTy && valueTy.hasSizes() && valueTy.hasDtype() &&
      valueTy == resultTy)
    return value;

  return getBroadcastedAttr(valueAttr, resultTy);
}

//===----------------------------------------------------------------------===//
// AtenWhereScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenWhereScalarOp::fold(FoldAdaptor adaptor) {
  auto dense = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense ||
      !dense.isSplat())
    return nullptr;

  auto condattr = dense.getSplatValue<APInt>();
  auto valueAttr = adaptor.getSelf();
  if (condattr.isZero()) {
    valueAttr = adaptor.getOther();
  }

  return getBroadcastedAttr(valueAttr, resultTy);
}

void AtenWhereScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // where(cond, x, x) yields x at every position of `cond`, so rewrite it to
  // aten.full using `cond`'s runtime sizes and x as the fill value.
  patterns.add(+[](AtenWhereScalarOp op, PatternRewriter &rewriter) {
    auto cond = op.getCondition();
    auto self = op.getSelf();
    auto other = op.getOther();

    if (self != other)
      return rewriter.notifyMatchFailure(op, "differing output");

    auto condTy = dyn_cast<BaseTensorType>(cond.getType());
    if (!condTy || !condTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "output size unknown");

    SmallVector<Value> dims;
    auto torchIntTy = rewriter.getType<Torch::IntType>();
    for (int i = 0, s = condTy.getSizes().size(); i < s; ++i) {
      Value iv = rewriter.create<Torch::ConstantIntOp>(
          op.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
      dims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
          op.getLoc(), torchIntTy, cond, iv));
    }

    Value dimsList = rewriter.create<Torch::PrimListConstructOp>(
        op.getLoc(), Torch::ListType::get(torchIntTy), dims);

    Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
    rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
        op, op.getType(), dimsList, self, none, none, none, none);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenWhereScalarOtherOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenWhereScalarOtherOp::fold(FoldAdaptor adaptor) {
  auto dense = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense ||
      !dense.isSplat())
    return nullptr;

  auto condattr = dense.getSplatValue<APInt>();
  auto valueAttr = adaptor.getSelf();
  if (condattr.isZero()) {
    valueAttr = adaptor.getOther();
  }

  return getBroadcastedAttr(valueAttr, resultTy);
}

//===----------------------------------------------------------------------===//
// AtenWhereScalarSelfOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenWhereScalarSelfOp::fold(FoldAdaptor adaptor) {
  auto dense = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense ||
      !dense.isSplat())
    return nullptr;

  auto condattr = dense.getSplatValue<APInt>();
  auto valueAttr = adaptor.getSelf();
  if (condattr.isZero()) {
    valueAttr = adaptor.getOther();
  }

  return getBroadcastedAttr(valueAttr, resultTy);
}
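All of the `where` folds above reduce to the same decision: only a splat condition selects one branch, which `getBroadcastedAttr` then expands to the result shape. A plain C++ sketch for scalar branches; `isSplatCondition` and `foldWhereScalar` are hypothetical helpers, and the element count is assumed to stand in for the result type's static shape.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// True when the condition is non-empty and every element equals the first,
// i.e. it would be stored as a splat.
bool isSplatCondition(const std::vector<bool> &cond) {
  if (cond.empty())
    return false;
  for (bool c : cond)
    if (c != cond.front())
      return false;
  return true;
}

// Folds where(cond, self, other) for scalar branches: only a splat condition
// picks a single branch, which is then broadcast to the result's element
// count (the role getBroadcastedAttr plays for attributes).
std::optional<std::vector<double>>
foldWhereScalar(const std::vector<bool> &cond, double self, double other,
                std::size_t numElements) {
  if (!isSplatCondition(cond))
    return std::nullopt;
  const double chosen = cond.front() ? self : other;
  return std::vector<double>(numElements, chosen);
}
```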

//===----------------------------------------------------------------------===//
// PrimMaxIntOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
  // If both operands are the same, then the operation is an identity.
  if (getA() == getB())
    return getA();

  auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
  auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
  if (!lhs || !rhs)
    return nullptr;
  // Torch semantics are that !torch.int is 64-bit signed.
  return IntegerAttr::get(
      lhs.getType(),
      std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
}

//===----------------------------------------------------------------------===//
// PrimNumToTensorScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
  Attribute a = adaptor.getA();
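  // The folded attribute is materialized by TorchDialect::materializeConstant,
  // which asserts that a tensor-typed result is a !torch.vtensor, so require
  // ValueTensorType here rather than BaseTensorType (#3339).
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!a)
    return {};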
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
    return {};

  auto dty = resultTy.getDtype();
  if (auto iattr = dyn_cast<IntegerAttr>(a)) {
    a = IntegerAttr::get(dty, iattr.getInt());
  } else if (auto fattr = dyn_cast<FloatAttr>(a)) {
    a = FloatAttr::get(dty, fattr.getValueAsDouble());
  } else {
    // doesn't handle other types, like complex type
    return {};
  }

  auto mlirTensorType =
      RankedTensorType::get(resultTy.getSizes(), resultTy.getDtype());
  return SplatElementsAttr::get(mlirTensorType, a);
}

//===----------------------------------------------------------------------===//
// PrimMinSelfIntOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
  auto list = getOperand().getDefiningOp<PrimListConstructOp>();
  if (!list)
    return nullptr;
  // TODO: What does it return for an empty list?
  if (list->getNumOperands() == 0)
    return nullptr;

  SmallVector<int64_t> values;
  for (auto operand : list->getOperands()) {
    int64_t value;
    if (!matchPattern(operand, m_TorchConstantInt(&value)))
      return nullptr;
    values.push_back(value);
  }
  return getI64IntegerAttr(getContext(),
                           *std::min_element(values.begin(), values.end()));
}
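`PrimMinSelfIntOp::fold` evaluates `min` over a `prim.ListConstruct` only when the list is non-empty and every element is a constant `!torch.int`. A minimal sketch of that rule, assuming `std::nullopt` models a non-constant element; `foldMinOfList` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Models PrimMinSelfIntOp::fold: std::nullopt marks a list element that is not
// a compile-time constant. The fold gives up on an empty list or on any
// non-constant element.
std::optional<int64_t>
foldMinOfList(const std::vector<std::optional<int64_t>> &elems) {
  if (elems.empty())
    return std::nullopt;
  std::vector<int64_t> values;
  for (const auto &e : elems) {
    if (!e)
      return std::nullopt;
    values.push_back(*e);
  }
  return *std::min_element(values.begin(), values.end());
}
```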
|
|
|
|
|
2023-02-11 05:58:15 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimMinIntOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
|
|
|
|
// If both operands are the same, then the operation is an identity.
|
|
|
|
if (getA() == getB())
|
|
|
|
return getA();
|
|
|
|
|
2024-05-31 14:45:13 +08:00
|
|
|
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
|
|
|
|
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
|
2023-02-11 05:58:15 +08:00
|
|
|
if (!lhs || !rhs)
|
|
|
|
return nullptr;
|
|
|
|
// Torch semantics are that !torch.int is 64-bit signed.
|
|
|
|
return IntegerAttr::get(
|
|
|
|
lhs.getType(),
|
|
|
|
std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
|
|
|
|
}
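`PrimMinIntOp::fold` applies two independent rules:

- an identity fold when both operands are the same SSA value;
- constant folding when both operands are known integers.

The sketch below isolates that decision logic in plain C++. `Operand` and `FoldResult` are hypothetical stand-ins for MLIR values and `OpFoldResult`, not torch-mlir types.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical operand model: an SSA value id plus an optional constant.
struct Operand {
  int id;                       // Identity of the SSA value.
  std::optional<int64_t> value; // Set when the value is a known constant.
};

// Either an existing operand is forwarded (identity fold) or a new
// constant is produced; both empty means "no fold".
struct FoldResult {
  std::optional<int> forwardedOperand;
  std::optional<int64_t> constant;
};

FoldResult foldMinInt(const Operand &a, const Operand &b) {
  // min(x, x) == x, whether or not x is a constant.
  if (a.id == b.id)
    return {a.id, std::nullopt};
  // Both sides constant: !torch.int is 64-bit signed.
  if (a.value && b.value)
    return {std::nullopt, std::min(*a.value, *b.value)};
  return {std::nullopt, std::nullopt};
}
```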
|
|
|
|
|
2022-03-10 08:44:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ShapeCalculateOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2022-12-14 00:25:41 +08:00
|
|
|
template <typename CalculateOp>
|
|
|
|
static void
|
2023-09-13 06:09:57 +08:00
|
|
|
getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
|
2022-12-14 00:25:41 +08:00
|
|
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
2023-09-13 06:09:57 +08:00
|
|
|
if (!point.getRegionOrNull()) {
|
2022-12-14 00:25:41 +08:00
|
|
|
// First thing the op does is branch into the calculation.
|
|
|
|
regions.emplace_back(&op.getCalculation());
|
2022-03-10 08:44:22 +08:00
|
|
|
return;
|
|
|
|
}
|
2023-09-13 06:09:57 +08:00
|
|
|
if (point == op.getBody()) {
|
2022-03-10 08:44:22 +08:00
|
|
|
// Body returns control to the outer op, passing through results.
|
2022-12-14 00:25:41 +08:00
|
|
|
regions.emplace_back(op.getResults());
|
2022-03-10 08:44:22 +08:00
|
|
|
return;
|
|
|
|
}
|
2023-09-13 06:09:57 +08:00
|
|
|
assert(point == op.getCalculation());
|
2022-12-14 00:25:41 +08:00
|
|
|
// Calculation branches to the body.
|
|
|
|
regions.emplace_back(&op.getBody());
|
|
|
|
}
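The helper above encodes a three-step control flow:

1. Entering the op runs the `calculation` region.
2. The `calculation` region branches to the `body` region.
3. The `body` region returns to the parent op, passing its results.

Below is a minimal model of that successor relation. It uses a hypothetical `Point` enum in place of `RegionBranchPoint`, with the op's own results represented as `Point::Parent`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical labels for where control can be in a calculate op.
enum class Point { Parent, Calculation, Body };

std::vector<Point> successorsOf(Point from) {
  switch (from) {
  case Point::Parent:
    return {Point::Calculation}; // The op first runs the calculation.
  case Point::Calculation:
    return {Point::Body}; // The calculation hands off to the body.
  case Point::Body:
    return {Point::Parent}; // The body yields the op's results.
  }
  return {};
}
```

Following `successorsOf` from `Point::Parent` visits Calculation, then Body, then Parent again. This is the only path through the op.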
|
|
|
|
|
|
|
|
void ShapeCalculateOp::getSuccessorRegions(
|
2023-09-13 06:09:57 +08:00
|
|
|
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
|
|
|
getSuccessorRegionsForCalculateOp(*this, point, regions);
|
2022-12-14 00:25:41 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DtypeCalculateOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void DtypeCalculateOp::getSuccessorRegions(
|
2023-09-13 06:09:57 +08:00
|
|
|
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
|
|
|
getSuccessorRegionsForCalculateOp(*this, point, regions);
|
2022-03-10 08:44:22 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ShapeCalculateYieldShapesOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
|
2023-09-13 06:09:57 +08:00
|
|
|
RegionBranchPoint point) {
|
2022-03-10 08:44:22 +08:00
|
|
|
// The shape operands don't get forwarded to the body.
|
|
|
|
// MutableOperandRange always has an owning operation, even if empty, so
|
|
|
|
// create a 0-length range.
|
|
|
|
return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
|
|
|
|
}
|
|
|
|
|
2022-03-16 08:54:57 +08:00
|
|
|
LogicalResult ShapeCalculateYieldShapesOp::verify() {
|
|
|
|
auto parent = cast<ShapeCalculateOp>(getOperation()->getParentOp());
|
|
|
|
if (parent.getNumResults() != getNumOperands())
|
|
|
|
return emitOpError("expected number of shapes to match number of results");
|
2022-03-10 08:44:22 +08:00
|
|
|
return success();
|
|
|
|
}
|
Rework how global slot initializers work.
Rather than a per-global-slot initializer region, we now have one for
the whole module. For example, it might look like this:
```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
torch.initialize.global_slots [
@tensor(%0 : !torch.tensor)
@list(%1 : !torch.list<tensor>)
]
}
```
This new structure allows GlobalizeObjectGraph to create the initializer in a
much simpler way, avoiding the need to reason about whether different slots
alias each other. Reasoning about whether slots alias each other now is the
responsibility of InlineGlobalSlots, which has to do a much more complicated
analysis, implemented using MLIR's dataflow analysis framework.
Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to GlobalizeObjectGraph pass.
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
- Moving torchMlirAdjustStaticInformation for sharing with C++ code.
- EraseModuleInitializer pass
To make this a bit nicer, it would be good to have a `torch.module` op
with an initializer region attached. That would be more invasive though.
This change has highlighted certain aspects of our project layering
which are worth calling out. None of our backends can handle global
slots, so we enforce that there are no global slots before backend
lowering. At an earlier stage in the project, we had aspirations of
transparently handling mutable global state and such, but for reasons
described below, that is no longer a goal. So really global slots should
be seen as a progressive lowering step as part of inlining all the
IValue's in the original program (GlobalizeObjectGraph is also one such
step).
Over time, with insights from work like IREE-JAX, it has become clear
that there isn't a reliable programming model we can compile for users
where we just transparently handle mutable global state (and some other
things, like lists and dictionaries). There is a need for an "outer
program" that orchestrates more restricted subroutines of the kind we
can handle in our compile flow here. The benefit of that is that it
decouples considerations like shapes, dtypes, etc. from the program
constructs used in the outer program. As long as the outer program can
efficiently invoke (pipelining/async/etc.) high-performance
data-parallel numerical subroutines of the kind we compile in our flow
here, then there is a complete programming model. This is also
consistent with the direction of upstream PyTorch which is becoming more
tracing-based (which inherently loses a lot of program structure, which
then has to be applied back with an "outer program" orchestrating the
traced subroutines).
2022-07-14 02:45:56 +08:00
|
|
|
|
2024-02-27 00:46:56 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenNormScalarOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
LogicalResult AtenNormScalarOp::verify() {
|
|
|
|
|
|
|
|
// Verification of input type for torch.aten.norm.Scalar.
|
|
|
|
// Per PyTorch docs, only float and complex types are valid for norm
|
|
|
|
// operation.
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inTensor = cast<BaseTensorType>(getSelf().getType());
|
2024-02-27 00:46:56 +08:00
|
|
|
|
|
|
|
// If the input dtype is not statically known, there is nothing to verify.
|
|
|
|
if (!inTensor.hasDtype()) {
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
auto inTensorDtype = inTensor.getDtype();
|
|
|
|
|
|
|
|
// Check if dtype is one of those supported by norm operation.
|
|
|
|
// ComplexType will match any torch complex types, but each float must be
|
|
|
|
// checked individually.
|
2024-05-31 14:45:13 +08:00
|
|
|
if (!isa<mlir::ComplexType, mlir::Float16Type, mlir::Float32Type,
|
|
|
|
mlir::Float64Type>(inTensorDtype)) {
|
2024-02-27 00:46:56 +08:00
|
|
|
return emitOpError(
|
|
|
|
"expected a float or complex type for input tensor, but got ")
|
|
|
|
<< inTensorDtype;
|
|
|
|
}
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
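The dtype check in `AtenNormScalarOp::verify` works as a whitelist. Below is a sketch of the same acceptance rule over a hypothetical `DType` enum. It mirrors the `isa<>` list above exactly: it accepts f16, f32, f64 and complex types, and rejects bf16. Note that `AtenRenormOp::verify` below does accept bf16.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for the element types the verifier distinguishes.
enum class DType { I1, I64, F16, BF16, F32, F64, Complex64, Complex128 };

// Returns an error message, or std::nullopt if the op is accepted.
std::optional<std::string> verifyNormScalar(std::optional<DType> dtype) {
  if (!dtype)
    return std::nullopt; // Unknown dtype: nothing to check.
  switch (*dtype) {
  case DType::F16:
  case DType::F32:
  case DType::F64:
  case DType::Complex64:
  case DType::Complex128:
    return std::nullopt;
  default:
    return std::string("expected a float or complex type for input tensor");
  }
}
```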
|
|
|
|
|
2024-06-18 01:40:57 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenRenormOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
LogicalResult AtenRenormOp::verify() {
|
|
|
|
|
|
|
|
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
|
|
|
|
|
|
|
if (!selfType.hasDtype() || !selfType.hasSizes())
|
|
|
|
return success();
|
|
|
|
|
|
|
|
auto inShape = selfType.getSizes();
|
|
|
|
int64_t selfRank = inShape.size();
|
|
|
|
auto selfDtype = selfType.getDtype();
|
|
|
|
|
|
|
|
if (!isa<mlir::Float16Type, mlir::BFloat16Type, mlir::Float32Type,
|
|
|
|
mlir::Float64Type, mlir::ComplexType>(selfDtype))
|
|
|
|
return emitOpError(
|
|
|
|
"expected a float or complex type for input tensor, but got ")
|
|
|
|
<< selfDtype;
|
|
|
|
|
|
|
|
// According to the PyTorch documentation, the input must have at least 2 dims.
|
|
|
|
if (selfRank <= 1)
|
|
|
|
return emitOpError("renorm: input needs at least 2 dimensions, got ")
|
|
|
|
<< selfRank << " dimensions";
|
|
|
|
|
|
|
|
// Check if argument p is valid
|
|
|
|
auto pType = getP().getType();
|
|
|
|
|
|
|
|
if (isa<mlir::ComplexType>(pType))
|
|
|
|
return emitOpError("renorm: p must be real-valued");
|
|
|
|
|
|
|
|
// The argument 'p' can be either an integer or a floating-point number,
|
|
|
|
// so we need to consider both options and check if 'p' is within the correct
|
|
|
|
// range
|
|
|
|
int64_t pInt = 1;
|
|
|
|
double_t pDouble = 1;
|
|
|
|
if (!matchPattern(getP(), m_TorchConstantInt(&pInt)) &&
|
|
|
|
!matchPattern(getP(), m_TorchConstantFloat(&pDouble)))
|
|
|
|
return success();
|
|
|
|
|
|
|
|
if (pInt <= 0 || pDouble <= 0)
|
|
|
|
return emitOpError("renorm: non-positive norm not supported");
|
|
|
|
|
|
|
|
// Check if argument maxnorm is valid
|
|
|
|
auto maxnormType = getMaxnorm().getType();
|
|
|
|
if (isa<mlir::ComplexType>(maxnormType))
|
|
|
|
return emitOpError("renorm: maxnorm must be real-valued");
|
|
|
|
|
|
|
|
// The argument 'maxnorm' can be either an integer or a floating-point number,
|
|
|
|
// so we need to consider both options and check if 'maxnorm' is within the
|
|
|
|
// correct range
|
|
|
|
int64_t maxnormInt = 0;
|
|
|
|
double_t maxnormDouble = 0;
|
|
|
|
if (!matchPattern(getMaxnorm(), m_TorchConstantInt(&maxnormInt)) &&
|
|
|
|
!matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble)))
|
|
|
|
return success();
|
|
|
|
|
|
|
|
if (maxnormInt < 0 || maxnormDouble < 0)
|
|
|
|
return emitOpError("renorm: expected maxnorm to be >= 0");
|
|
|
|
|
|
|
|
// Get the dimension
|
|
|
|
int64_t dim;
|
|
|
|
if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
|
|
|
|
return success();
|
|
|
|
|
|
|
|
// Check that dim is in the correct range.
|
|
|
|
if (dim >= selfRank || dim < -selfRank)
|
|
|
|
return emitOpError("Dimension out of range (expected to be in range of [")
|
|
|
|
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim << ")";
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
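`AtenRenormOp::verify` checks its arguments in a fixed order. It returns success as soon as an argument is not a compile-time constant, so all later checks are skipped.

Below is a sketch of that ordering. It assumes scalar arguments are modelled as `std::optional<double>`, where an empty value means non-constant. The dtype and complex-argument checks are omitted, and the error strings are simplified.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical model: std::nullopt means "not a compile-time constant".
std::optional<std::string> verifyRenorm(int64_t rank, std::optional<double> p,
                                        std::optional<double> maxnorm,
                                        std::optional<int64_t> dim) {
  if (rank <= 1)
    return std::string("input needs at least 2 dimensions");
  if (!p)
    return std::nullopt; // Non-constant p: later checks are skipped.
  if (*p <= 0)
    return std::string("non-positive norm not supported");
  if (!maxnorm)
    return std::nullopt;
  if (*maxnorm < 0)
    return std::string("expected maxnorm to be >= 0");
  if (!dim)
    return std::nullopt;
  if (*dim >= rank || *dim < -rank)
    return std::string("dimension out of range");
  return std::nullopt;
}
```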
|
|
|
|
|
2024-02-27 00:46:56 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenPermuteOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-11-16 03:47:54 +08:00
|
|
|
LogicalResult AtenPermuteOp::verify() {
|
|
|
|
|
|
|
|
// Verification of the permute op for input & output dimensions with
|
|
|
|
// statically known sizes.
|
|
|
|
|
|
|
|
SmallVector<Value> permutation;
|
|
|
|
auto permutationObtained = getListConstructElements(getDims(), permutation);
|
|
|
|
if (!permutationObtained) {
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
auto outType = cast<BaseTensorType>(getResult().getType());
|
|
|
|
auto inType = cast<BaseTensorType>(getSelf().getType());
|
2023-11-16 03:47:54 +08:00
|
|
|
|
|
|
|
if (!outType.hasSizes() || !inType.hasSizes()) {
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
auto outShape = outType.getSizes();
|
|
|
|
auto inShape = inType.getSizes();
|
|
|
|
|
|
|
|
auto outRank = outShape.size();
|
|
|
|
|
|
|
|
if (outRank != inShape.size()) {
|
|
|
|
return emitOpError(
|
|
|
|
"expected input and output tensors to have same rank, but ")
|
|
|
|
<< inShape.size() << " != " << outRank << '.';
|
|
|
|
}
|
|
|
|
|
|
|
|
if (outRank != permutation.size()) {
|
|
|
|
return emitOpError() << "expected permutation to have size equal to the result "
|
|
|
|
"tensor rank. The permutation has "
|
|
|
|
<< permutation.size()
|
|
|
|
<< " elements, the output has rank " << outRank << '.';
|
|
|
|
}
|
|
|
|
|
|
|
|
// Initialization of the reverse permutation. -1 denotes an unknown
|
|
|
|
// permutation index.
|
|
|
|
SmallVector<int64_t> reversePermutation(outRank, -1);
|
|
|
|
|
|
|
|
// In this loop:
|
|
|
|
// (1) check that the permutation indices are in bounds, and not duplicated.
|
|
|
|
// (2) populate reversePermutation (to check for duplicates).
|
|
|
|
// (3) check that the input and output shapes agree with the permutation. For
|
|
|
|
// example, if the permutation is (1,2,0) and the input shape is (2,3,5),
|
|
|
|
// then the output shape must be (3,5,2).
|
|
|
|
|
|
|
|
for (uint64_t to = 0; to < outRank; ++to) {
|
|
|
|
int64_t from;
|
|
|
|
|
|
|
|
auto fromIsSet = matchPattern(permutation[to], m_TorchConstantInt(&from));
|
|
|
|
|
|
|
|
if (!fromIsSet) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
// If 'from' is the unknown index, continue.
|
|
|
|
if (from == -1) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!isValidDim(from, outRank)) {
|
|
|
|
return emitError("observed invalid index in permutation (")
|
|
|
|
<< from << ") for input tensor of rank " << outRank << '.';
|
|
|
|
}
|
|
|
|
|
|
|
|
if (reversePermutation[from] != -1) {
|
|
|
|
return emitOpError("has a duplicate dimension (")
|
|
|
|
<< from << ") in its permutation " << getDims() << '.';
|
|
|
|
}
|
|
|
|
reversePermutation[from] = to;
|
|
|
|
|
|
|
|
auto dimSizesDefined =
|
|
|
|
inShape[from] != kUnknownSize && outShape[to] != kUnknownSize;
|
|
|
|
auto dimSizesDifferent = inShape[from] != outShape[to];
|
|
|
|
|
|
|
|
if (dimSizesDefined && dimSizesDifferent) {
|
|
|
|
return emitOpError("has a permutation which is not compatible with the "
|
|
|
|
"input and output shapes. ")
|
|
|
|
<< "The input shape in dimension " << from << " is "
|
|
|
|
<< inShape[from] << ", and the output shape in dimension " << to
|
|
|
|
<< " is " << outShape[to]
|
|
|
|
<< " : they should be the same with this permutation. ";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
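The permutation check in `AtenPermuteOp::verify` can be exercised without MLIR. The sketch below rests on these assumptions:

- A hypothetical `kUnknown = -1` marks both unknown sizes and unknown permutation entries. This matches how the verifier skips `from == -1`.
- `isValidDim` accepts `0 <= from < rank`.
- `perm[to] = from` means output dimension `to` reads input dimension `from`.

Error strings are simplified.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

constexpr int64_t kUnknown = -1; // Hypothetical unknown size/index marker.

std::optional<std::string> verifyPermute(const std::vector<int64_t> &inShape,
                                         const std::vector<int64_t> &outShape,
                                         const std::vector<int64_t> &perm) {
  const int64_t rank = static_cast<int64_t>(outShape.size());
  if (static_cast<int64_t>(inShape.size()) != rank)
    return std::string("rank mismatch");
  if (static_cast<int64_t>(perm.size()) != rank)
    return std::string("permutation size mismatch");
  // reverse[from] = to; -1 means "not seen yet" (used to find duplicates).
  std::vector<int64_t> reverse(static_cast<std::size_t>(rank), -1);
  for (int64_t to = 0; to < rank; ++to) {
    const int64_t from = perm[to];
    if (from == kUnknown)
      continue; // Unknown entry: nothing to check.
    if (from < 0 || from >= rank)
      return std::string("invalid index in permutation");
    if (reverse[from] != -1)
      return std::string("duplicate dimension in permutation");
    reverse[from] = to;
    // Sizes must agree wherever both are statically known.
    const bool bothKnown = inShape[from] != kUnknown && outShape[to] != kUnknown;
    if (bothKnown && inShape[from] != outShape[to])
      return std::string("permutation incompatible with shapes");
  }
  return std::nullopt;
}
```

The first test case is the (1,2,0) example from the comment in the verifier: input (2,3,5) must produce output (3,5,2).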
|
|
|
|
|
2024-05-02 15:03:41 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// PrimsConvertElementTypeOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
|
|
|
|
auto inputType = cast<BaseTensorType>(getA().getType());
|
|
|
|
auto outputType = cast<BaseTensorType>(getResult().getType());
|
|
|
|
if (inputType != outputType)
|
|
|
|
return nullptr;
|
|
|
|
if (!inputType.hasDtype() || !outputType.hasDtype())
|
|
|
|
return nullptr;
|
|
|
|
if (inputType.getDtype() != outputType.getDtype())
|
|
|
|
return nullptr;
|
|
|
|
return getA();
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenMaxPool2dWithIndicesOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
|
|
|
|
if (!op.getResult1().use_empty()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "result1 of MaxPool2dWithIndices should be unused");
|
|
|
|
}
|
|
|
|
|
|
|
|
Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
|
|
|
|
op->getLoc(), op.getResult0().getType(), op.getSelf(),
|
|
|
|
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
|
|
|
|
op.getCeilMode());
|
|
|
|
|
|
|
|
op.getResult0().replaceAllUsesWith(result);
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
[Torch] Add support for static uneven divisible AdaptiveAvgPool2d (#3566)
The static uneven divisible AdaptiveAvgPool2d means that, although the
input size is not an integer multiple of the output size, the kernel and
stride sizes can still be fixed (not dynamic). The derivation of the
kernel and stride sizes is consistent with
torch/_decomp/decompositions.py:adaptive_avg_pool2d, as described below:
1. Stride Size
First, derive the start index of each reduce operation from the output
size (`n`): `start_index = ([0, 1, ..., n - 1] * input_size) //
output_size`. For each index `k`, if `k * (input_size % output_size) <
output_size`, then the current and all previous strides equal
`input_size // output_size`. So if `(n-1) * (input_size % output_size) <
output_size`, the stride stays static throughout the whole
AdaptiveAvgPool2d computation, namely `input_size // output_size`.
2. Kernel Size
torch/_decomp/decompositions.py:adaptive_avg_pool2d calculates a static
kernel size when the input/output sizes satisfy either of two
conditions: `input_size % output_size == 0` or `output_size %
(input_size % output_size) == 0`. If `input_size % output_size == 0`,
the kernel size equals `input_size // output_size`; otherwise it is
`input_size // output_size + 1`.
2024-08-01 11:37:53 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Aten_AdaptiveAvgPool2dOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
|
|
|
|
RewritePatternSet &patterns, MLIRContext *context) {
|
|
|
|
patterns.add(+[](Aten_AdaptiveAvgPool2dOp op, PatternRewriter &rewriter) {
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAdaptiveAvgPool2dOp>(
|
|
|
|
op, op.getType(), op.getSelf(), op.getOutputSize());
|
|
|
|
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
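The commit message above derives when adaptive average pooling has a static kernel and stride. Below is a brute-force cross-check. It assumes the usual adaptive-pooling window bounds, `start = floor(k * in / out)` and `end = ceil((k + 1) * in / out)`.

The check enumerates every window along one axis. It reports `{kernel, stride}` only if all windows share one length and consecutive starts are exactly `in / out` apart. This tests the property directly rather than through the two conditions in the message, and it is not the torch-mlir lowering.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Returns {kernel, stride} if pooling `inSize` elements down to `outSize`
// uses equal-length, equally spaced windows; otherwise std::nullopt.
std::optional<std::pair<int64_t, int64_t>>
staticAdaptiveWindow(int64_t inSize, int64_t outSize) {
  if (inSize <= 0 || outSize <= 0)
    return std::nullopt;
  std::vector<int64_t> starts, lengths;
  for (int64_t k = 0; k < outSize; ++k) {
    const int64_t start = (k * inSize) / outSize;                  // floor
    const int64_t end = ((k + 1) * inSize + outSize - 1) / outSize; // ceil
    starts.push_back(start);
    lengths.push_back(end - start);
  }
  const int64_t stride = inSize / outSize;
  const int64_t kernel = lengths.front();
  for (int64_t k = 0; k < outSize; ++k) {
    if (lengths[k] != kernel)
      return std::nullopt; // Window length varies: no static kernel.
    if (k > 0 && starts[k] - starts[k - 1] != stride)
      return std::nullopt; // Spacing varies: no static stride.
  }
  return std::make_pair(kernel, stride);
}
```

Two worked cases:

- For 7 -> 3 (remainder 1), the windows are [0,3), [2,5) and [4,7). This gives kernel 3 and stride 2, matching `input_size // output_size + 1` and `input_size // output_size`.
- For 5 -> 3, the starts are 0, 1 and 3, so no static stride exists.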
|
|
|
|
|
2024-03-14 03:17:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AtenLinalgCrossOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
LogicalResult AtenLinalgCrossOp::verify() {
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
|
|
|
auto otherType = cast<BaseTensorType>(getOther().getType());
|
2024-03-14 03:17:22 +08:00
|
|
|
|
|
|
|
if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() ||
|
|
|
|
!otherType.hasSizes()) {
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
Type selfDtype = selfType.getDtype();
|
|
|
|
Type otherDtype = otherType.getDtype();
|
|
|
|
|
|
|
|
// the operation succeeds only if both inputs have the same dtype
|
|
|
|
if (selfDtype != otherDtype) {
|
|
|
|
return emitOpError("input tensors must have the same dtype, but got ")
|
|
|
|
<< selfDtype << " and " << otherDtype;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check if any of the input tensors has torch.bool dtype.
|
|
|
|
// The operation does not support this type.
|
|
|
|
// The docs state that only float, double, cfloat and cdouble dtypes are
|
|
|
|
// supported, but, when testing, it fails only for boolean dtype. Update to
|
|
|
|
// fit the docs if necessary.
|
|
|
|
// https://pytorch.org/docs/stable/generated/torch.linalg.cross.html
|
|
|
|
if (selfDtype.isSignlessInteger(1) || otherDtype.isSignlessInteger(1)) {
|
|
|
|
return emitOpError("input tensors must not have bool dtype");
|
|
|
|
}
|
|
|
|
|
|
|
|
ArrayRef<int64_t> selfShape = selfType.getSizes();
|
|
|
|
ArrayRef<int64_t> otherShape = otherType.getSizes();
|
|
|
|
|
|
|
|
int64_t selfRank = selfShape.size();
|
|
|
|
int64_t otherRank = otherShape.size();
|
|
|
|
|
|
|
|
// check if both input tensors have the same number of dims
|
|
|
|
if (selfRank != otherRank) {
|
|
|
|
return emitOpError("input tensors must have the same number of dimensions, "
|
|
|
|
"but got ")
|
|
|
|
<< selfRank << " and " << otherRank;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Extract dim as a constant integer; skip verification if it is not constant.
|
|
|
|
int64_t dim;
|
|
|
|
if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) {
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// check if dim is in the correct range
|
|
|
|
if (dim >= selfRank || dim < -selfRank) {
|
|
|
|
return emitOpError("dim expected to be in range of [")
|
|
|
|
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
|
|
|
|
}
|
|
|
|
|
|
|
|
// compensate for possible negative dim value
|
|
|
|
if (dim < 0) {
|
|
|
|
dim += selfRank;
|
|
|
|
}
|
|
|
|
|
|
|
|
// check if the size of the dimensions specified by 'dim' is equal to 3
|
|
|
|
// (required by the operation)
|
|
|
|
if ((selfShape[dim] != 3 && selfShape[dim] != kUnknownSize) ||
|
|
|
|
(otherShape[dim] != 3 && otherShape[dim] != kUnknownSize)) {
|
|
|
|
return emitOpError("inputs dimension ")
|
|
|
|
<< dim << " must have length 3, but got " << selfShape[dim]
|
|
|
|
<< " and " << otherShape[dim];
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check if there is a disparity between dimension sizes.
|
|
|
|
// Dimensions at the same index must either have the same size,
|
|
|
|
// or one of them must be equal to 1.
|
|
|
|
int32_t i = 0;
|
|
|
|
for (auto [selfCurrent, otherCurrent] :
|
|
|
|
llvm::zip_equal(selfShape, otherShape)) {
|
|
|
|
if (selfCurrent != kUnknownSize && otherCurrent != kUnknownSize &&
    selfCurrent != otherCurrent && selfCurrent != 1 && otherCurrent != 1) {
|
|
|
|
return emitOpError("the size of first tensor (")
|
|
|
|
<< selfCurrent << ") must match the size of second tensor ("
|
|
|
|
<< otherCurrent << ") at dimension " << i
|
|
|
|
<< " or one of them must be 1";
|
|
|
|
}
|
|
|
|
++i;
|
|
|
|
}
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
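The shape rules in `AtenLinalgCrossOp::verify` are sketched below over plain vectors:

- both inputs have the same rank;
- `dim` lies in `[-rank, rank - 1]`;
- both inputs have length 3 along `dim`;
- sizes are pairwise broadcastable.

`kDyn = -1` is a hypothetical stand-in for `kUnknownSize`. As a design choice, the sketch treats an unknown size as compatible with any size in the broadcast check. It also omits the dtype checks.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

constexpr int64_t kDyn = -1; // Hypothetical marker for an unknown size.

std::optional<std::string> verifyCross(const std::vector<int64_t> &a,
                                       const std::vector<int64_t> &b,
                                       int64_t dim) {
  const int64_t rank = static_cast<int64_t>(a.size());
  if (static_cast<int64_t>(b.size()) != rank)
    return std::string("rank mismatch");
  if (dim >= rank || dim < -rank)
    return std::string("dim out of range");
  if (dim < 0)
    dim += rank; // Normalize a negative dim.
  auto notThree = [](int64_t s) { return s != 3 && s != kDyn; };
  if (notThree(a[dim]) || notThree(b[dim]))
    return std::string("dimension must have length 3");
  for (int64_t i = 0; i < rank; ++i) {
    const int64_t x = a[i], y = b[i];
    const bool known = x != kDyn && y != kDyn;
    // Known sizes must be equal, or one of them must be 1.
    if (known && x != y && x != 1 && y != 1)
      return std::string("sizes not broadcastable");
  }
  return std::nullopt;
}
```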
|
|
|
|
|
2024-06-15 13:48:39 +08:00
|
|
|
LogicalResult AtenKthvalueOp::verify() {
|
|
|
|
|
|
|
|
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
|
|
|
|
|
|
|
if (!selfType.hasDtype() || !selfType.hasSizes())
|
|
|
|
return success();
|
|
|
|
|
|
|
|
Type selfDtype = selfType.getDtype();
|
|
|
|
if (selfDtype.isSignlessInteger(1))
|
|
|
|
return emitOpError("input tensors must not have bool dtype");
|
|
|
|
|
|
|
|
int64_t dim;
|
|
|
|
if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
|
|
|
|
return success();
|
|
|
|
|
|
|
|
ArrayRef<int64_t> selfShape = selfType.getSizes();
|
|
|
|
int64_t selfRank = selfShape.size();
|
|
|
|
|
|
|
|
dim = toPositiveDim(dim, selfRank);
|
|
|
|
if (!isValidDim(dim, selfRank))
|
|
|
|
return emitOpError("dim expected to be in range of [")
|
|
|
|
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
|
|
|
|
|
|
|
|
// Extract k as a constant integer; skip verification if it is not constant.
|
|
|
|
int64_t k;
|
|
|
|
if (!matchPattern(getK(), m_TorchConstantInt(&k)))
|
|
|
|
return success();
|
|
|
|
|
|
|
|
// check if k is in the correct range
|
|
|
|
if (selfShape[dim] != kUnknownSize && (k < 1 || k > selfShape[dim]))
|
|
|
|
return emitOpError("k expected to be in range of [")
|
|
|
|
<< 1 << ", " << selfShape[dim] << "], but got " << k;
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
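`AtenKthvalueOp::verify` requires `1 <= k <= size(dim)` whenever that size is known. Below is a sketch with hypothetical names, where `kUnknownDim = -1` stands in for `kUnknownSize`.

The verifier above reports `dim` after `toPositiveDim` has normalized it. This version differs: it keeps the caller's original `dim` in the error text.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

constexpr int64_t kUnknownDim = -1; // Hypothetical marker for an unknown size.

std::optional<std::string> verifyKthvalue(const std::vector<int64_t> &shape,
                                          int64_t dim, int64_t k) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  const int64_t normalized = dim < 0 ? dim + rank : dim;
  if (normalized < 0 || normalized >= rank)
    return "dim " + std::to_string(dim) + " out of range";
  const int64_t size = shape[normalized];
  // Only check k when the reduced dimension has a static size.
  if (size != kUnknownDim && (k < 1 || k > size))
    return std::string("k out of range");
  return std::nullopt;
}
```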
|
|
|
|
|
2022-12-14 00:25:41 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DtypeCalculateYieldDtypesOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
|
2023-09-13 06:09:57 +08:00
|
|
|
RegionBranchPoint point) {
|
2022-12-14 00:25:41 +08:00
|
|
|
// The dtype operands don't get forwarded to the body.
|
|
|
|
// MutableOperandRange always has an owning operation, even if empty, so
|
|
|
|
// create a 0-length range.
|
|
|
|
return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
|
|
|
|
}
|
|
|
|
|
|
|
|
LogicalResult DtypeCalculateYieldDtypesOp::verify() {
|
|
|
|
auto parent = cast<DtypeCalculateOp>(getOperation()->getParentOp());
|
|
|
|
if (parent.getNumResults() != getNumOperands())
|
|
|
|
return emitOpError("expected number of dtypes to match number of results");
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-07-14 02:45:56 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// GlobalSlotModuleInitializerOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
LogicalResult GlobalSlotModuleInitializerOp::verify() {
|
|
|
|
// We centralize all verification of the global slots and the
|
|
|
|
// InitializeGlobalSlotsOp into here, since it requires processing the whole
|
|
|
|
// module.
|
|
|
|
|
|
|
|
// TODO: We should really have a `torch.module` and have this initializer be
|
|
|
|
// a region attached to it.
|
|
|
|
|
|
|
|
ModuleOp module = cast<ModuleOp>(getOperation()->getParentOp());
|
|
|
|
for (auto op : module.getOps<GlobalSlotModuleInitializerOp>()) {
|
|
|
|
if (op.getOperation() != getOperation())
|
|
|
|
return op.emitError("there must be only one global slot initializer");
|
|
|
|
}
|
|
|
|
|
|
|
|
// Collect the relevant symbol names we will verify.
|
|
|
|
DenseSet</*StringAttr*/ Attribute> knownGlobalSlots;
|
|
|
|
for (auto op : module.getOps<GlobalSlotOp>())
|
2022-12-08 04:20:41 +08:00
|
|
|
knownGlobalSlots.insert(op.getSymNameAttr());
|
2022-07-14 02:45:56 +08:00
|
|
|
DenseSet</*StringAttr*/ Attribute> initializedGlobalSlots;
|
|
|
|
auto initialize = cast<InitializeGlobalSlotsOp>(getBody()->getTerminator());
|
2022-12-08 04:20:41 +08:00
|
|
|
for (Attribute symName : initialize.getSlotSymNames()) {
|
2022-07-14 02:45:56 +08:00
    auto wasInserted = initializedGlobalSlots
2024-04-11 21:47:35 +08:00
                           .insert(cast<FlatSymbolRefAttr>(symName).getAttr())
2022-07-14 02:45:56 +08:00
                           .second;
    if (!wasInserted)
      return initialize.emitError("duplicate initialization of global slot: ")
             << symName;
  }
  auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) {
2024-04-11 21:47:35 +08:00
    return cast<StringAttr>(lhs).getValue() < cast<StringAttr>(rhs).getValue();
2022-07-14 02:45:56 +08:00
  };
  auto known = llvm::to_vector(knownGlobalSlots);
  llvm::sort(known, lessThanByStringValue);
  auto initialized = llvm::to_vector(initializedGlobalSlots);
  llvm::sort(initialized, lessThanByStringValue);

  // Check that the global slots in the module are all initialized.
  SymbolTable symbolTable(module);
  if (initializedGlobalSlots != knownGlobalSlots) {
    InFlightDiagnostic diag = initialize.emitOpError(
        "must have one initializer for each global slot in the module");
    for (auto knownGlobalSlot : known) {
2024-04-11 21:47:35 +08:00
      auto symName = FlatSymbolRefAttr::get(cast<StringAttr>(knownGlobalSlot));
2022-07-14 02:45:56 +08:00
      if (!initializedGlobalSlots.count(knownGlobalSlot)) {
        diag.attachNote(
                symbolTable.lookup<GlobalSlotOp>(symName.getAttr()).getLoc())
            .append("missing global slot initializer for ", symName);
      }
    }
    for (auto initializedGlobalSlot : initialized) {
      if (!knownGlobalSlots.count(initializedGlobalSlot)) {
        diag.attachNote().append(
            "unexpected global slot initializer for non-existent global slot ",
2024-04-11 21:47:35 +08:00
            FlatSymbolRefAttr::get(cast<StringAttr>(initializedGlobalSlot)));
2022-07-14 02:45:56 +08:00
      }
    }
    return diag;
  }

  // Check that initial values satisfy type bounds.
  for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) {
2024-04-28 05:00:56 +08:00
    auto symName = cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
2022-07-14 02:45:56 +08:00
    auto initialValue = initialize.getOperand(i);
    auto globalSlotOp = symbolTable.lookup<GlobalSlotOp>(symName.getValue());
2022-12-08 04:20:41 +08:00
    if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) {
2022-07-14 02:45:56 +08:00
      return initialize.emitOpError().append(
          "initial value for global slot ", symName, " has type ",
          initialValue.getType(), " which is not within the bound ",
2022-12-08 04:20:41 +08:00
          globalSlotOp.getTypeBound());
2022-07-14 02:45:56 +08:00
    }
  }

  auto walkResult = getOperation()->walk([](Operation *op) {
    // We only permit a small set of ops in the module initializer.
    // These ops are essentially those which can be produced by the IValue
    // importer.
2022-09-20 05:56:35 +08:00
    if (op->hasTrait<mlir::torch::Torch::OpTrait::AllowedInModuleInitializer>())
2022-07-14 02:45:56 +08:00
      return WalkResult::advance();
    op->emitOpError() << "is not allowed in a module initializer";
    return WalkResult::interrupt();
  });
  if (walkResult.wasInterrupted())
    return failure();

  return success();
}
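Before checking type bounds, the module-initializer verifier above performs two set-level checks: it rejects a slot that is initialized twice, and it compares the initialized slots against the module's known slots, attaching one note per missing or unexpected slot in sorted order. The following is a minimal sketch of that bookkeeping in standard C++. It assumes slot names are plain `std::string`s; `verifySlotInitializers` is a hypothetical helper, not a torch-mlir or MLIR API, and it returns diagnostic strings instead of emitting them on an op.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for the module-initializer slot checks: slot names
// are plain strings (not MLIR StringAttrs) and diagnostics are returned as
// strings rather than emitted on an operation. Empty result == success.
std::vector<std::string>
verifySlotInitializers(const std::set<std::string> &knownSlots,
                       const std::vector<std::string> &initializerOrder) {
  std::vector<std::string> diags;
  std::set<std::string> initialized;

  // Reject a slot that appears more than once in the initializer list.
  for (const std::string &name : initializerOrder) {
    if (!initialized.insert(name).second) {
      diags.push_back("duplicate initialization of global slot: @" + name);
      return diags;
    }
  }

  // Every known slot has exactly one initializer: nothing to report.
  if (initialized == knownSlots)
    return diags;

  diags.push_back(
      "must have one initializer for each global slot in the module");
  // std::set iterates in sorted order, so the notes are deterministic; the
  // MLIR code sorts explicitly because DenseSet iteration order is unspecified.
  for (const std::string &name : knownSlots)
    if (!initialized.count(name))
      diags.push_back("note: missing global slot initializer for @" + name);
  for (const std::string &name : initialized)
    if (!knownSlots.count(name))
      diags.push_back("note: unexpected global slot initializer for "
                      "non-existent global slot @" +
                      name);
  return diags;
}
```

Called with known slots `{"list", "tensor"}` and initializers in the order `{"tensor", "extra"}`, the sketch returns the "must have one initializer" message, then a note for the missing `@list`, then a note for the unexpected `@extra`. Using an ordered `std::set` makes the sort step implicit.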

//===----------------------------------------------------------------------===//
// InitializeGlobalSlotsOp
//===----------------------------------------------------------------------===//

ParseResult InitializeGlobalSlotsOp::parse(OpAsmParser &parser,
                                           OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseLSquare())
    return failure();
  SmallVector<Attribute> slotSymNames;
  while (!succeeded(parser.parseOptionalRSquare())) {
    NamedAttrList dummy;
    StringAttr slotSymName;
    if (parser.parseSymbolName(slotSymName, "dummy", dummy))
      return failure();
    slotSymNames.push_back(FlatSymbolRefAttr::get(slotSymName));
    if (parser.parseLParen())
      return failure();
    OpAsmParser::UnresolvedOperand initialValue;
    if (parser.parseOperand(initialValue))
      return failure();
    Type initialValueType;
    if (parser.parseColonType(initialValueType))
      return failure();
    if (parser.parseRParen())
      return failure();
    if (parser.resolveOperand(initialValue, initialValueType, result.operands))
      return failure();
  }
  result.addAttribute("slotSymNames",
                      ArrayAttr::get(parser.getContext(), slotSymNames));
  return success();
}
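`InitializeGlobalSlotsOp::parse` accepts an optional attribute dictionary followed by a bracketed list of `@symbol(%value : type)` entries, and `print` emits the same form, so the two round-trip. The following sketch reproduces just that bracketed grammar over plain text so the round trip can be seen in isolation. `SlotInit`, `parseSlotInits`, and `printSlotInits` are hypothetical names; the sketch assumes unquoted symbol names, skips the attribute dictionary, does not resolve operands, and treats everything up to the closing `)` as the type.

```cpp
#include <cctype>
#include <cstddef>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// One `@symbol(%value : type)` entry (hypothetical, text-only model).
struct SlotInit {
  std::string symbol, value, type;
};

namespace {
// Minimal cursor over the input, loosely mirroring the parseX /
// parseOptionalX calls used by the real parser.
struct Cursor {
  const std::string &s;
  std::size_t pos = 0;

  void skipSpace() {
    while (pos < s.size() && std::isspace(static_cast<unsigned char>(s[pos])))
      ++pos;
  }
  // Consume `c` if it is the next non-space character.
  bool consume(char c) {
    skipSpace();
    if (pos < s.size() && s[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }
  // Read a bare identifier made of [A-Za-z0-9_.$].
  std::string identifier() {
    skipSpace();
    std::size_t start = pos;
    while (pos < s.size() &&
           (std::isalnum(static_cast<unsigned char>(s[pos])) ||
            s[pos] == '_' || s[pos] == '.' || s[pos] == '$'))
      ++pos;
    return s.substr(start, pos - start);
  }
  // Read up to (not including) `c`, trimming trailing whitespace.
  std::string until(char c) {
    skipSpace();
    std::size_t start = pos;
    while (pos < s.size() && s[pos] != c)
      ++pos;
    std::string out = s.substr(start, pos - start);
    while (!out.empty() && std::isspace(static_cast<unsigned char>(out.back())))
      out.pop_back();
    return out;
  }
};
} // namespace

// Parse `[ @sym(%v : type) ... ]`; returns std::nullopt on malformed input.
std::optional<std::vector<SlotInit>> parseSlotInits(const std::string &text) {
  Cursor c{text};
  if (!c.consume('['))
    return std::nullopt;
  std::vector<SlotInit> inits;
  while (!c.consume(']')) {
    SlotInit init;
    if (!c.consume('@') || (init.symbol = c.identifier()).empty())
      return std::nullopt;
    if (!c.consume('(') || !c.consume('%') ||
        (init.value = c.identifier()).empty())
      return std::nullopt;
    if (!c.consume(':') || (init.type = c.until(')')).empty() ||
        !c.consume(')'))
      return std::nullopt;
    inits.push_back(init);
  }
  return inits;
}

// Print one entry per line inside brackets.
std::string printSlotInits(const std::vector<SlotInit> &inits) {
  std::ostringstream os;
  os << "[\n";
  for (const SlotInit &init : inits)
    os << "  @" << init.symbol << "(%" << init.value << " : " << init.type
       << ")\n";
  os << "]";
  return os.str();
}
```

Parsing the `torch.initialize.global_slots` operand list from the commit message example (`@tensor(%0 : !torch.tensor)` and `@list(%1 : !torch.list<tensor>)`) yields two entries, and printing them reproduces the same bracketed text. Input that ends before the closing `]` is rejected rather than looping.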

void InitializeGlobalSlotsOp::print(OpAsmPrinter &p) {
  p.printOptionalAttrDict(getOperation()->getAttrs(),
                          /*elidedAttrs=*/{"slotSymNames"});
  p << " [";
  p.printNewline();
  for (int i = 0, e = getNumOperands(); i < e; ++i) {
2022-12-08 04:20:41 +08:00
    p << " " << getSlotSymNames()[i] << "(" << getInitialValues()[i] << " : "
      << getInitialValues()[i].getType() << ")";
here, then there is a complete programming model. This is also
consistent with the direction of upstream PyTorch which is becoming more
tracing-based (which inherently loses a lot of program structure, which
then has to be applied back with an "outer program" orchestrating the
traced subroutines).
2022-07-14 02:45:56 +08:00
    p.printNewline();
  }
  p << "]";
}

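The printer emits an opening ` [`, then one `@slot(%value : type)` entry per operand on its own line, then a closing `]`. A minimal standalone sketch of that layout using plain strings instead of MLIR values follows. `printInitializeGlobalSlots` is a hypothetical helper, not a torch-mlir API. Like `InitializeGlobalSlotsOp::verify`, it refuses inputs whose slot and value counts differ. It does not model the indentation that `printNewline()` adds.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: renders the `[ ... ]` body the way
// InitializeGlobalSlotsOp::print lays it out (indentation not modeled).
std::optional<std::string>
printInitializeGlobalSlots(const std::vector<std::string> &symNames,
                           const std::vector<std::string> &values,
                           const std::vector<std::string> &types) {
  // Mirrors the verifier: every slot needs exactly one initial value.
  if (symNames.size() != values.size() || values.size() != types.size())
    return std::nullopt;
  std::string out = " [\n";
  for (size_t i = 0; i < symNames.size(); ++i)
    out += " " + symNames[i] + "(" + values[i] + " : " + types[i] + ")\n";
  out += "]";
  return out;
}
```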
LogicalResult InitializeGlobalSlotsOp::verify() {
  if (getInitialValues().size() != getSlotSymNames().size())
    return emitOpError("expected number of operands to match number of slots");
  return success();
}

Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```ll
module {
func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
%0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int
%2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
%3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
%4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
%5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
%int1 = torch.constant.int 1
%6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>
return %6 : !torch.vtensor<[?,?,3],f32>
}
}
```
For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):
# File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)
tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x); x = None
# File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)
sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y); y = None
# File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)
cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1); tanh = sigmoid = None
return (cat,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)}
```
Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.
- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
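The sketch below makes the `bind_symbolic_shape` example above concrete. It evaluates an affine shape binding for concrete symbol values and rejects any value outside its `min_val`/`max_val` range. `SymbolicInt`, `AffineDim` and `evaluateShape` are hypothetical names. Only maps of the form `constant + sum(coeff * symbol)` are modeled. Symbols bind positionally, so the third operand (`"s3"`) plays the role of `s2` in the `cat` result's map `(s0, s1 * 2 + s2, 3)`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of a torch.symbolic_int: a named size with a range.
struct SymbolicInt {
  std::string name;
  int64_t minVal;
  int64_t maxVal;
};

// Hypothetical model of one result expression: constant + sum_k coeffs[k] * s_k.
// E.g. "s1 * 2 + s2" over (s0, s1, s2) is {coeffs = {0, 2, 1}, constant = 0}.
struct AffineDim {
  std::vector<int64_t> coeffs;
  int64_t constant;
};

// Returns the concrete shape, or nullopt if a symbol value is out of range or
// the number of values does not match the number of bound symbols.
std::optional<std::vector<int64_t>>
evaluateShape(const std::vector<SymbolicInt> &symbols,
              const std::vector<int64_t> &values,
              const std::vector<AffineDim> &dims) {
  if (symbols.size() != values.size())
    return std::nullopt;
  for (size_t k = 0; k < symbols.size(); ++k)
    if (values[k] < symbols[k].minVal || values[k] > symbols[k].maxVal)
      return std::nullopt;
  std::vector<int64_t> shape;
  for (const AffineDim &d : dims) {
    if (d.coeffs.size() != values.size())
      return std::nullopt;
    int64_t size = d.constant;
    for (size_t k = 0; k < values.size(); ++k)
      size += d.coeffs[k] * values[k];
    shape.push_back(size);
  }
  return shape;
}
```

With `s0 = 6`, `s1 = 4` and `s3 = 7`, the `cat` result binds to the shape `6 x 15 x 3`, because `2 * 4 + 7 = 15`.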
2024-06-07 19:04:03 +08:00
//===----------------------------------------------------------------------===//
// BindSymbolicShapeOp
//===----------------------------------------------------------------------===//

//
// torch.bind_symbolic_shape %6, [%0, %1, %2], affine_map<()[s0, s1, s2] ->
// (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>
//

ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  OpAsmParser::UnresolvedOperand operand;
  SmallVector<OpAsmParser::UnresolvedOperand> shapeSymbols;
  AffineMapAttr shapeExpressions;
  Type operandType;

  if (parser.parseOperand(operand) || parser.parseComma() ||
      parser.parseLSquare() || parser.parseOperandList(shapeSymbols) ||
      parser.parseRSquare() || parser.parseComma() ||
      parser.parseAttribute(shapeExpressions, "shape_expressions",
                            result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(operandType)) {
    return failure();
  }

  if (parser.resolveOperand(operand, operandType, result.operands) ||
      parser.resolveOperands(shapeSymbols,
                             parser.getBuilder().getType<Torch::IntType>(),
                             result.operands)) {
    return failure();
  }

  return success();
}

// Use a custom printer so that the AffineMap is printed inline with the op
// instead of being hoisted into a top-level `#map` attribute alias.
void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
  p << " " << getOperand() << ", [";
  llvm::interleaveComma(getShapeSymbols(), p);
  p << "], " << "affine_map<" << getShapeExpressions().getValue() << ">";
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{"shape_expressions"});
  p << " : " << getOperand().getType();
}

LogicalResult BindSymbolicShapeOp::verify() {
  if (getShapeSymbols().empty())
    return emitOpError() << "requires non-empty shapeSymbols";

  for (Value symbol : getShapeSymbols()) {
    // The templated getDefiningOp is null-safe, so a block argument (which has
    // no defining op) is reported as an error instead of reaching isa<> with a
    // null pointer.
    if (!symbol.getDefiningOp<SymbolicIntOp>()) {
      return emitOpError()
             << "shape symbol must be produced by a SymbolicIntOp";
    }
  }

  return success();
}

2024-06-22 07:16:38 +08:00
//===----------------------------------------------------------------------===//
// AtenTriuIndicesOp
//===----------------------------------------------------------------------===//

LogicalResult AtenTriuIndicesOp::verify() {
  // Only verify when row, col, and offset are all constant ints.
  int64_t row;
  if (!matchPattern(getRow(), m_TorchConstantInt(&row)))
    return success();

  int64_t col;
  if (!matchPattern(getCol(), m_TorchConstantInt(&col)))
    return success();

  int64_t offset;
  if (!matchPattern(getOffset(), m_TorchConstantInt(&offset)))
    return success();

  // row and col must be non-negative.
  if (row < 0)
    return emitOpError("row must be non-negative, got ") << row;

  if (col < 0)
    return emitOpError("col must be non-negative, got ") << col;

  // When dtype is a constant, it must be torch.int32 or torch.int64.
  int64_t dtype;
  if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype)))
    return success();
  if (dtype != (int)torch_upstream::ScalarType::Int &&
      dtype != (int)torch_upstream::ScalarType::Long)
    return emitOpError(
        "'triu_indices' implemented only for torch.int32 and torch.int64");

  return success();
}

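The verifier only rejects negative `row`/`col` values and unsupported dtypes. The reference sketch below shows what the op computes. `torch.triu_indices(row, col, offset)` returns a 2 x N tensor of row and column coordinates, in row-major order, for entries with `j - i >= offset`. The sketch returns the same coordinates as `(i, j)` pairs. `triuIndices` is a hypothetical helper, not torch-mlir's lowering.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical reference helper: (i, j) coordinates of a row x col matrix on
// or above the diagonal shifted by `offset` (j - i >= offset), row-major.
// Negative row/col (rejected by the verifier) simply yield no coordinates.
std::vector<std::pair<int64_t, int64_t>> triuIndices(int64_t row, int64_t col,
                                                     int64_t offset) {
  std::vector<std::pair<int64_t, int64_t>> indices;
  for (int64_t i = 0; i < row; ++i)
    for (int64_t j = 0; j < col; ++j)
      if (j - i >= offset)
        indices.emplace_back(i, j);
  return indices;
}
```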
2024-07-18 21:08:12 +08:00
//===----------------------------------------------------------------------===//
// AtenTrilIndicesOp
//===----------------------------------------------------------------------===//

LogicalResult AtenTrilIndicesOp::verify() {
  // Only verify when row, col, and offset are all constant ints.
  int64_t row;
  if (!matchPattern(getRow(), m_TorchConstantInt(&row)))
    return success();

  int64_t col;
  if (!matchPattern(getCol(), m_TorchConstantInt(&col)))
    return success();

  int64_t offset;
  if (!matchPattern(getOffset(), m_TorchConstantInt(&offset)))
    return success();

  // row and col must be non-negative.
  if (row < 0)
    return emitOpError("row must be non-negative, got ") << row;

  if (col < 0)
    return emitOpError("col must be non-negative, got ") << col;

  // When dtype is a constant, it must be torch.int32 or torch.int64.
  int64_t dtype;
  if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype)))
    return success();
  if (dtype != (int)torch_upstream::ScalarType::Int &&
      dtype != (int)torch_upstream::ScalarType::Long)
    return emitOpError(
        "'tril_indices' implemented only for torch.int32 and torch.int64");

  return success();
}

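`torch.tril_indices` is the lower-triangle counterpart. It keeps the coordinates with `j - i <= offset`, again in row-major order. The reference sketch below uses `trilIndices`, a hypothetical helper rather than part of torch-mlir.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical reference helper: (i, j) coordinates of a row x col matrix on
// or below the diagonal shifted by `offset` (j - i <= offset), row-major.
std::vector<std::pair<int64_t, int64_t>> trilIndices(int64_t row, int64_t col,
                                                     int64_t offset) {
  std::vector<std::pair<int64_t, int64_t>> indices;
  for (int64_t i = 0; i < row; ++i)
    for (int64_t j = 0; j < col; ++j)
      if (j - i <= offset)
        indices.emplace_back(i, j);
  return indices;
}
```

Every `(i, j)` satisfies exactly one of `j - i <= offset` and `j - i >= offset + 1`. Therefore `tril_indices(row, col, offset)` and `triu_indices(row, col, offset + 1)` together cover each matrix entry exactly once.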
2024-09-06 13:06:17 +08:00
//===----------------------------------------------------------------------===//
// AtenRot90Op
//===----------------------------------------------------------------------===//

LogicalResult AtenRot90Op::verify() {
  // Check the rotation dimensions.
  SmallVector<Value> dims;
  if (!getListConstructElements(getDims(), dims))
    return success();

  if (dims.size() != 2)
    return emitOpError("expected total rotation dims == 2, but got dims = ")
           << dims.size();

  // Check the rank of the input tensor.
  auto selfType = cast<BaseTensorType>(getSelf().getType());
  if (!selfType.hasSizes())
    return success();

  auto selfShape = selfType.getSizes();
  int64_t selfRank = selfShape.size();

  if (selfRank < 2)
    return emitOpError("expected total dims >= 2, but got total dims = ")
           << selfRank;

  // This compares SSA values, so it only rejects the case where the same
  // value is passed for both rotation dims.
  if (dims[0] == dims[1])
    return emitOpError(
               "expected rotation dims to be different, but got dim0 = ")
           << dims[0] << " and dim1 = " << dims[1];

  return success();
}

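The verifier only checks structural preconditions: exactly two rotation dims, an input rank of at least 2, and distinct dims. The sketch below illustrates what those two dims mean by implementing the 2-D case of `torch.rot90` with `dims = [0, 1]` on nested vectors. The `Matrix` alias and both helpers are hypothetical. Higher-rank inputs, where the rotation acts on the plane spanned by two of several dims, are not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;  // hypothetical 2-D stand-in

// One 90-degree turn from dim 0 towards dim 1, as in torch.rot90(x, 1, [0, 1]):
// out[i][j] = in[j][cols - 1 - i]; a rows x cols input becomes cols x rows.
Matrix rot90Once(const Matrix &in) {
  size_t rows = in.size();
  size_t cols = rows == 0 ? 0 : in[0].size();
  Matrix out(cols, std::vector<int>(rows));
  for (size_t i = 0; i < cols; ++i)
    for (size_t j = 0; j < rows; ++j)
      out[i][j] = in[j][cols - 1 - i];
  return out;
}

// k is reduced modulo 4, so a negative k turns the other way (k = -1 == k = 3).
Matrix rot90(Matrix m, int k) {
  int turns = ((k % 4) + 4) % 4;
  for (int t = 0; t < turns; ++t)
    m = rot90Once(m);
  return m;
}
```

Rotating `{{0, 1}, {2, 3}}` once gives `{{1, 3}, {0, 2}}`.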