//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
                                           BaseTensorType newType,
                                           Value tensor) {
  auto originalType = tensor.getType().cast<BaseTensorType>();
  // Adjust the static information in the type to match between the original
  // and new types.
  if (!originalType.hasSameSizesAndDtype(newType)) {
    tensor = builder.create<TensorStaticInfoCastOp>(
        loc, originalType.getWithSizesAndDtypeFrom(newType), tensor);
  }

  // Unless the original and new types are both value tensors, we end up
  // creating one op that converts between the value and non-value tensor
  // domains. If the original and new types are both non-value tensors, then
  // we do the copy by going to a value tensor and back.
  if (tensor.getType().isa<NonValueTensorType>())
    tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
  if (newType.isa<NonValueTensorType>())
    tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);

  return tensor;
}
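
// Illustrative note on the helper above (example IR, op syntax assumed rather
// than taken from this file): copying a value tensor
// `%t : !torch.vtensor<[2,3],f32>` to the least-static-information
// `!torch.tensor` type would conceptually materialize:
//   %0 = torch.tensor_static_info_cast %t
//            : !torch.vtensor<[2,3],f32> to !torch.vtensor
//   %1 = torch.copy.to_tensor %0 : !torch.tensor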

bool mlir::torch::Torch::isListPotentiallyMutated(Value list) {
  assert(list.getType().isa<Torch::ListType>());
  return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands);
}

bool mlir::torch::Torch::potentiallyMutatesListOperands(Operation *op) {
  // TODO: Find a better place to put this assertion.
  assert((!op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
          op->hasTrait<OpTrait::ReadOnly>()) &&
         "HasValueSemantics should imply ReadOnly!");
  // ReadOnly ops trivially do not mutate any list operands.
  if (op->hasTrait<Torch::OpTrait::ReadOnly>())
    return false;

  // Ops with no MemoryEffectOpInterface effects also do not mutate any list
  // operands.
  if (auto effects = dyn_cast<MemoryEffectOpInterface>(op)) {
    if (effects.hasNoEffect())
      return false;
  }

  // Conservatively assume that an op might mutate any list operands.
  return true;
}
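
// For example (op names assumed): `torch.aten.append.t` writes to its list
// operand and is not ReadOnly, so it reports the list as potentially mutated,
// while a ReadOnly op such as `torch.aten.len.t` does not.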

static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
  return IntegerAttr::get(IntegerType::get(context, 64), value);
}

static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
  return FloatAttr::get(Float64Type::get(context), value);
}

static Value getScalarValue(Value input, Location loc,
                            PatternRewriter &rewriter) {
  Value scalar = nullptr;
  if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
    if (valueTensorLiteralOp &&
        getTensorRank(valueTensorLiteralOp.getResult()) == 0) {
      auto tensorType =
          valueTensorLiteralOp.value().getType().cast<RankedTensorType>();
      if (tensorType.getElementType().isa<mlir::IntegerType>()) {
        auto val = valueTensorLiteralOp.value()
                       .cast<DenseElementsAttr>()
                       .getSplatValue<int64_t>();
        scalar = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(val));
      }
    }
  } else if (auto primNumToTensorScalarOp =
                 input.getDefiningOp<PrimNumToTensorScalarOp>()) {
    scalar = primNumToTensorScalarOp.a();
  }
  return scalar;
}
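
// Illustrative behavior of getScalarValue (assumed op syntax): for
//   %t = torch.vtensor.literal(dense<42> : tensor<si64>)
//            : !torch.vtensor<[],si64>
// it materializes `torch.constant.int 42` and returns that Value; for
// `%t = torch.prim.NumToTensor.Scalar %n` it returns %n directly; anything
// else yields a null Value.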

//===----------------------------------------------------------------------===//
// MethodOp
//===----------------------------------------------------------------------===//

LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto func =
      symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, functionAttr());
  if (!func)
    return emitError() << "'@" << function()
                       << "' does not reference a valid function";
  if (func.getVisibility() != SymbolTable::Visibility::Private)
    return emitError() << "'@" << function()
                       << "' must reference a private function";
  if (func.isDeclaration())
    return emitError() << "'@" << function()
                       << "' must reference a function that is defined (not "
                          "merely declared)";
  auto expectedReceiverArgType = NnModuleType::get(
      getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
  if (func.getFunctionType().getNumInputs() == 0 ||
      func.getFunctionType().getInput(0) != expectedReceiverArgType) {
    return emitError() << "the referenced function '" << function()
                       << "' must have a first argument of type "
                       << expectedReceiverArgType;
  }
  return success();
}
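
// A minimal sketch of IR this verifier accepts (names hypothetical, syntax
// assumed):
//   func.func private @forward(%arg0: !torch.nn.Module<"MyModule">) {
//     return
//   }
//   torch.class_type @MyModule {
//     torch.method "forward", @forward
//   }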

//===----------------------------------------------------------------------===//
// NnModuleOp
//===----------------------------------------------------------------------===//

LogicalResult NnModuleOp::verify() {
  for (Operation &child : *getBody())
    if (!isa<SlotOp, NnModuleTerminatorOp>(&child))
      return child.emitOpError() << "is not allowed inside 'torch.nn_module'";
  return success();
}

LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto classType = symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(
      *this, SymbolRefAttr::get(getContext(), getClassName()));
  if (!classType)
    return emitError() << "'" << getClassName()
                       << "' does not reference a valid class type";

  auto attrs = llvm::to_vector<6>(getBody()->getOps<SlotOp>());
  auto attrDefs = llvm::to_vector<6>(classType.getBody()->getOps<AttrOp>());
  if (attrs.size() != attrDefs.size())
    return emitError() << "number of 'torch.slot's in a 'torch.nn_module' must "
                          "match number of 'torch.attr's in "
                          "the corresponding 'torch.class_type'";
  for (int i = 0, e = attrs.size(); i != e; i++) {
    SlotOp attr = attrs[i];
    AttrOp attrDef = attrDefs[i];
    if (!isValidSubtype(attr.value().getType(), attrDef.type()) ||
        attr.name() != attrDef.name()) {
      return attr.emitOpError()
          .append("is expected to match type and name of '",
                  attrDef.getOperation(), "'")
          .attachNote(attrDef.getLoc())
          .append("see torch.attr at corresponding index ", i, " here");
    }
  }

  return success();
}
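
// A matching class-type/module pair that satisfies the slot/attr checks above
// (names hypothetical, syntax assumed):
//   torch.class_type @MyModule {
//     torch.attr "flag" : !torch.bool
//   }
//   %true = torch.constant.bool true
//   %m = torch.nn_module {
//     torch.slot "flag", %true : !torch.bool
//   } : !torch.nn.Module<"MyModule">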

//===----------------------------------------------------------------------===//
// PrimListConstructOp
//===----------------------------------------------------------------------===//

LogicalResult PrimListConstructOp::verify() {
  auto resultType = getResult().getType();
  auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
  auto matchResultElementType = [&](Type type) {
    return isValidSubtype(type, resultElementType);
  };
  if (!llvm::all_of(getOperandTypes(), matchResultElementType)) {
    return emitError() << "operand types should have the same type as the "
                          "list contained type";
  }

  return success();
}
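
// Example of a list construction that passes this verifier (syntax assumed):
//   %int1 = torch.constant.int 1
//   %int2 = torch.constant.int 2
//   %l = torch.prim.ListConstruct %int1, %int2
//            : (!torch.int, !torch.int) -> !torch.list<int>
// Each operand type must be a valid subtype of the list's contained type.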

//===----------------------------------------------------------------------===//
// PrimDictConstructOp
//===----------------------------------------------------------------------===//

LogicalResult PrimDictConstructOp::verify() {
  auto isValidSubTypeOf = [](Type expectedType) {
    return [=](Type type) { return isValidSubtype(type, expectedType); };
  };

  if (!llvm::all_of(keys().getTypes(), isValidSubTypeOf(getKeyType())))
    return emitError() << "keys should be of Dict key type";

  if (!llvm::all_of(values().getTypes(), isValidSubTypeOf(getValueType())))
    return emitError() << "values should be of Dict value type";

  return success();
}

//===----------------------------------------------------------------------===//
// ClassTypeOp
//===----------------------------------------------------------------------===//

LogicalResult ClassTypeOp::verify() {
  llvm::StringMap<Operation *> namesToOps;
  for (Operation &child : getBody()->without_terminator()) {
    if (!isa<AttrOp, MethodOp>(&child))
      return child.emitOpError() << "is not allowed inside `torch.class_type`";
    StringRef name;
    if (auto attr = dyn_cast<AttrOp>(child))
      name = attr.name();
    else
      name = cast<MethodOp>(child).name();
    auto itAndWasInserted = namesToOps.insert({name, &child});
    auto it = itAndWasInserted.first;
    bool wasInserted = itAndWasInserted.second;
    if (!wasInserted) {
      auto diag = emitOpError().append("has duplicate attr/method with name '",
                                       name, "'");
      diag.attachNote(it->second->getLoc())
          .append("see first conflicting attr/method here");
      diag.attachNote(child.getLoc())
          .append("see second conflicting attr/method here");
      return failure();
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// PrimLoopOp
//===----------------------------------------------------------------------===//

OperandRange
PrimLoopOp::getSuccessorEntryOperands(Optional<unsigned int> index) {
  assert(index.hasValue() && index.value() == 0);
  return iterArgsInit();
}

void PrimLoopOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  (void)operands;

  if (!index.hasValue()) {
    regions.emplace_back(&region(), region().getArguments().slice(1));
    return;
  }
  assert(*index == 0);
  regions.emplace_back(&region(), region().getArguments().slice(1));
  regions.emplace_back(getResults());
}

bool PrimLoopOp::isForLike() {
  bool b;
  return matchPattern(initialCondition(), m_TorchConstantBool(&b)) && b;
}
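
// Note for readers: `torch.prim.Loop` models both Python `for`- and
// `while`-style loops with a single op, following TorchScript. A loop is
// "for-like" exactly when its initial condition is the constant true (e.g.
// produced by `torch.constant.bool true`), so only the max trip count governs
// termination.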

//===----------------------------------------------------------------------===//
// PrimLoopConditionOp
//===----------------------------------------------------------------------===//

MutableOperandRange
PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
  // Pass all operands except the condition to the successor which is the
  // parent loop op.
  return iterArgsMutable();
}

//===----------------------------------------------------------------------===//
// PrimIfOp
//===----------------------------------------------------------------------===//

ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  Type boolType = builder.getType<Torch::BoolType>();
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, boolType, result.operands))
    return failure();
  // Parse results type list.
  if (parser.parseArrowTypeList(result.types))
    return failure();
  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Parse the 'else' region.
  if (parser.parseKeyword("else"))
    return failure();
  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
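
// The custom assembly handled by parse/print here looks roughly like this
// (illustrative sketch, names hypothetical):
//   %0 = torch.prim.If %cond -> (!torch.int) {
//     torch.prim.If.yield %a : !torch.int
//   } else {
//     torch.prim.If.yield %b : !torch.int
//   }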
void PrimIfOp::print(OpAsmPrinter &p) {
  p << " " << condition();
  p << " -> (" << getResultTypes() << ") ";
  p.printRegion(thenRegion(), /*printEntryBlockArgs=*/false);
  p << " else ";
  p.printRegion(elseRegion(), /*printEntryBlockArgs=*/false);

  p.printOptionalAttrDict((*this)->getAttrs());
}

void PrimIfOp::getSuccessorRegions(Optional<unsigned> index,
                                   ArrayRef<Attribute> operands,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
  // The `then` and the `else` region branch back to the parent operation.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // If the condition is constant, we can give a more precise answer.
  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
    Region *executedRegion =
        condAttr.getValue().isOneValue() ? &thenRegion() : &elseRegion();
    regions.push_back(RegionSuccessor(executedRegion));
    return;
  }

  // If the condition isn't constant, both regions may be executed.
  regions.push_back(RegionSuccessor(&thenRegion()));
  regions.push_back(RegionSuccessor(&elseRegion()));
  return;
}

/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region, ValueRange blockArgs = {}) {
  assert(llvm::hasSingleElement(region) && "expected single-block region");
  Block *block = &region.front();
  Operation *terminator = block->getTerminator();
  ValueRange results = terminator->getOperands();
  rewriter.mergeBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
}

void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If the condition is constant, delete the dead branch and inline the live
  // branch.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    auto constantBool = op.condition().getDefiningOp<Torch::ConstantBoolOp>();
    if (!constantBool)
      return rewriter.notifyMatchFailure(op, "non-constant condition");
    replaceOpWithRegion(
        rewriter, op, constantBool.value() ? op.thenRegion() : op.elseRegion());
    return success();
  });
  // If the thenRegion and elseRegion yield the same values, then use those
  // values directly.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    auto trueTerminator = op.thenRegion().front().getTerminator();
    auto falseTerminator = op.elseRegion().front().getTerminator();
    bool madeChange = false;
    SmallVector<int> resultsToErase;
    for (auto t : llvm::zip(trueTerminator->getOperands(),
                            falseTerminator->getOperands(), op->getResults())) {
      auto trueVal = std::get<0>(t);
      auto falseVal = std::get<1>(t);
      auto resultToBeReplaced = std::get<2>(t);
      if (trueVal == falseVal) {
        madeChange |= !resultToBeReplaced.use_empty();
        resultToBeReplaced.replaceAllUsesWith(trueVal);
      }
    }
    // We leave it up to a separate pattern (not yet implemented) to erase the
    // results that are now dead. That transformation is independently useful,
    // and also pretty tricky to implement because it changes the number of
    // results.
    return success(madeChange);
  });
  // Erase any dead results.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    llvm::BitVector resultsToErase(op.getNumResults());
    for (auto result : llvm::enumerate(op->getResults())) {
      if (result.value().use_empty())
        resultsToErase.set(result.index());
    }

    // If no results have uses and there are no side effects, just erase the op.
    // Approximate the body having no side effects by checking if it is just a
    // terminator.
    // Note: We don't want to make this logic too fancy, because in general,
    // checking for recursive side effects can result in a quadratic amount of
    // work (N nested If's each resulting in O(N) work). It should probably be
    // split into its own pattern if we want to make it fancier.
    if (resultsToErase.all() &&
        llvm::hasSingleElement(op.thenRegion().front()) &&
        llvm::hasSingleElement(op.elseRegion().front())) {
      rewriter.eraseOp(op);
      return success();
    }

    // If there are no results to erase, we're done.
    if (!resultsToErase.any())
      return failure();

    SmallVector<Type> newResultTypes;
    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
      if (resultsToErase[i])
        continue;
      newResultTypes.push_back(op->getResult(i).getType());
    }
    auto newIf =
        rewriter.create<PrimIfOp>(op->getLoc(), newResultTypes, op.condition());
    rewriter.inlineRegionBefore(op.thenRegion(), newIf.thenRegion(),
                                newIf.thenRegion().end());
    rewriter.inlineRegionBefore(op.elseRegion(), newIf.elseRegion(),
                                newIf.elseRegion().end());
    newIf.thenRegion().front().getTerminator()->eraseOperands(resultsToErase);
    newIf.elseRegion().front().getTerminator()->eraseOperands(resultsToErase);
    SmallVector<Value> replacementValues;
    for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) {
      if (resultsToErase[i])
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(newIf->getResult(nextNewValue++));
    }
    rewriter.replaceOp(op, replacementValues);

    return success();
  });
}
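
// Sketch of the constant-condition canonicalization above (illustrative IR,
// syntax assumed): when %cond is produced by `torch.constant.bool true`, the
// op
//   %0 = torch.prim.If %cond -> (!torch.int) { ...yield %a... } else {...}
// is replaced by inlining the then-region, so uses of %0 become uses of %a
// and the dead else-region is discarded.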

//===----------------------------------------------------------------------===//
// DerefineOp
//===----------------------------------------------------------------------===//

bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
                                   mlir::TypeRange outputs) {
  return isValidSubtype(inputs[0], outputs[0]);
}

OpFoldResult DerefineOp::fold(ArrayRef<Attribute> operands) {
  auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>();
  if (!uncheckedCast)
    return nullptr;
  if (uncheckedCast.getOperand().getType() == getType())
    return uncheckedCast.getOperand();
  return nullptr;
}
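
// Illustrative fold (syntax assumed): in
//   %1 = torch.prim.unchecked_cast %0 : !torch.optional<int> -> !torch.int
//   %2 = torch.derefine %1 : !torch.int to !torch.optional<int>
// the derefine folds to %0, since the unchecked_cast's operand already has
// the derefined type.
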
void DerefineOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) {
    bool madeChange = false;
    for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) {
      if (use.getOwner()->hasTrait<OpTrait::AllowsTypeRefinement>()) {
        use.set(op.getOperand());
        madeChange = true;
      }
    }
    return success(madeChange);
  });
}

static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
  Value lhs = op->getOperand(0);
  Value rhs = op->getOperand(1);
  // Look through DerefineOp's to get more refined static information.
  if (auto derefine = lhs.getDefiningOp<DerefineOp>())
    lhs = derefine.getOperand();
  if (auto derefine = rhs.getDefiningOp<DerefineOp>())
    rhs = derefine.getOperand();
  Type lhsType = lhs.getType();
  Type rhsType = rhs.getType();

  // If either type is a NoneType, make it be the lhsType.
  if (rhsType.isa<Torch::NoneType>()) {
    std::swap(lhsType, rhsType);
    std::swap(lhs, rhs);
  }

  // For now, check a few specific cases.

  // If both types are the singleton `!torch.none` type, then we don't even
  // need to look at the values.
  if (lhsType.isa<Torch::NoneType>() && rhsType.isa<Torch::NoneType>())
    return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue);

  // If neither type is a subtype of the other, then the result is false.
  // TODO: Implement and use subtype infra for this.
  // For now, check a specific case.
  // If the rhs is not OptionalType, then we know it cannot be None.
  if (lhsType.isa<Torch::NoneType>() && !rhsType.isa<Torch::OptionalType>()) {
    return IntegerAttr::get(IntegerType::get(op->getContext(), 1),
                            !equalIsTrue);
  }

  return nullptr;
}
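
// Worked example for the helper above (illustrative, syntax assumed): in
//   %0 = torch.aten.__is__ %none, %x : !torch.none, !torch.int -> !torch.bool
// the lhs is None and the rhs is a non-optional int, so `__is__` folds to
// false and `__isnot__` folds to true.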

//===----------------------------------------------------------------------===//
// Aten__RangeLengthOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
  auto lo = operands[0];
  auto hi = operands[1];
  auto step = operands[2];
  if (!lo || !hi || !step)
    return nullptr;
  auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
  auto hiInt = hi.dyn_cast_or_null<IntegerAttr>().getValue();
  auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
  // TODO: Implement folding for negative steps.
  if (stepInt.isNegative())
    return nullptr;
  // From the Python language spec:
  //   r[i] = lo + step*i such that i >= 0 and r[i] < hi
  // so the number of valid indices i (the range length) is
  // ceildiv(hi - lo, step).
  return IntegerAttr::get(lo.getType(),
                          llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
                                                       APInt::Rounding::UP));
}
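
// Sanity check of the formula above: lo = 0, hi = 10, step = 3 gives the
// range {0, 3, 6, 9}, and ceildiv(10 - 0, 3) = 4 matches; lo == hi gives
// length ceildiv(0, step) = 0.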

//===----------------------------------------------------------------------===//
// Aten__DeriveIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
  auto index = operands[0];
  auto start = operands[1];
  auto step = operands[2];
  if (!index || !start || !step)
    return nullptr;
  auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
  auto startInt = start.dyn_cast_or_null<IntegerAttr>().getValue();
  auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
  return IntegerAttr::get(index.getType(), startInt + stepInt * indexInt);
}
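
// Worked example: start = 1, step = 3, index = 2 derives 1 + 3 * 2 = 7, i.e.
// the element at position 2 of range(1, hi, 3).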

//===----------------------------------------------------------------------===//
// Aten__Is__Op
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
  return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true);
}

//===----------------------------------------------------------------------===//
// Aten__Isnot__Op
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
  return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false);
}

//===----------------------------------------------------------------------===//
// Aten__Not__Op
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
  bool value;
  if (!matchPattern(getOperand(), m_TorchConstantBool(&value)))
    return nullptr;
  return IntegerAttr::get(IntegerType::get(getContext(), 1), !value);
}
|
|
|
|
|
//===----------------------------------------------------------------------===//
// AtenNeBoolOp
//===----------------------------------------------------------------------===//

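// `a != a` is trivially false; otherwise fold only when both operands are
// constant bools.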
OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
  if (getOperand(0) == getOperand(1))
    return IntegerAttr::get(IntegerType::get(getContext(), 1), false);

  bool a, b;
  if (!matchPattern(getOperand(0), m_TorchConstantBool(&a)))
    return nullptr;
  if (!matchPattern(getOperand(1), m_TorchConstantBool(&b)))
    return nullptr;
  return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
}

//===----------------------------------------------------------------------===//
// AtenSqueezeOp
//===----------------------------------------------------------------------===//

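// Squeezing a statically known rank-0 tensor is a no-op; fold to the operand.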
OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
  if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
      return getOperand();
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenSqueezeDimOp
//===----------------------------------------------------------------------===//

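// Squeezing a dimension of a statically known rank-0 tensor is a no-op; fold
// to the operand.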
OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
  if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
      return getOperand(0);
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenToDtypeOp
//===----------------------------------------------------------------------===//

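// `aten.to.dtype` folds to its input only when it provably does not change
// the tensor type and the non_blocking/copy/memory_format arguments are
// trivial.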
OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
  bool nonBlocking, copyArg;
  // The non_blocking arg must be `False`.
  if (!matchPattern(non_blocking(), m_TorchConstantBool(&nonBlocking)) ||
      nonBlocking)
    return nullptr;
  // The copy arg must be `False`.
  if (!matchPattern(copy(), m_TorchConstantBool(&copyArg)) || copyArg)
    return nullptr;
  // The memory_format arg must be `none`.
  if (!memory_format().getType().isa<Torch::NoneType>())
    return nullptr;

  auto inputType = self().getType().cast<BaseTensorType>();
  auto resType = getType().cast<BaseTensorType>();
  // If the types aren't equal, then we can't fold.
  if (inputType != resType)
    return nullptr;
  // If the type does not have a statically known dtype, then we cannot fold.
  // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
  // since the `unk` could be dynamically different for the operand and result.
  if (!inputType.hasDtype())
    return nullptr;
  // Fold when both the input tensor and result are of the same type.
  return getOperand(0);
}

//===----------------------------------------------------------------------===//
// AtenToDtypeLayoutOp
//===----------------------------------------------------------------------===//

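// Same as `aten.to.dtype` above, but the layout, device, and pin_memory
// arguments must also be trivial for the op to fold to its input.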
OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
  // The pin_memory arg should be either constant `False` or `none`.
  if (!pin_memory().getType().isa<Torch::NoneType>()) {
    bool pinMemory;
    if (!matchPattern(pin_memory(), m_TorchConstantBool(&pinMemory)))
      return nullptr;
    else if (pinMemory)
      return nullptr;
  }

  // The non_blocking arg should be constant `False`.
  bool nonBlocking;
  if (!matchPattern(non_blocking(), m_TorchConstantBool(&nonBlocking)))
    return nullptr;
  else if (nonBlocking)
    return nullptr;

  // The copy arg should be constant `False`.
  bool copyArg;
  if (!matchPattern(copy(), m_TorchConstantBool(&copyArg)))
    return nullptr;
  else if (copyArg)
    return nullptr;

  // The device arg must be `none`.
  if (!device().getType().isa<Torch::NoneType>())
    return nullptr;

  // The memory_format arg must be `none`.
  if (!memory_format().getType().isa<Torch::NoneType>())
    return nullptr;

  auto inputType = self().getType().cast<BaseTensorType>();
  auto resType = getType().cast<BaseTensorType>();
  // If the types aren't equal, then we can't fold.
  if (inputType != resType)
    return nullptr;
  // If the type does not have a statically known dtype, then we cannot fold.
  // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
  // since the `unk` could be dynamically different for the operand and result.
  if (!inputType.hasDtype())
    return nullptr;

  // The layout arg should be either `none` or `0` i.e. strided.
  if (!layout().getType().isa<Torch::NoneType>()) {
    int64_t tensorLayout;
    if (!matchPattern(layout(), m_TorchConstantInt(&tensorLayout)))
      return nullptr;
    else if (tensorLayout != torch_upstream::Layout::Strided)
      return nullptr;
  }

  // Fold when both the input tensor and result are of the same type and the
  // layout arg is strided.
  return getOperand(0);
}

//===----------------------------------------------------------------------===//
// AtenViewOp
//===----------------------------------------------------------------------===//

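// A `view` from a rank-1 tensor to a rank-1 tensor preserves the element
// count, so it is treated as the identity; fold to the operand.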
OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
  auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
  if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
    return nullptr;
  auto resType = getType().dyn_cast<BaseTensorType>();
  if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
    return nullptr;
  // Fold when both the input tensor and result are unity rank tensors.
  return getOperand(0);
}

//===----------------------------------------------------------------------===//
// AtenDimOp
//===----------------------------------------------------------------------===//

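// Fold `aten.dim` to a constant when the operand's rank is statically known.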
OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
  if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
    if (tensorType.hasSizes())
      return IntegerAttr::get(IntegerType::get(getContext(), 64),
                              tensorType.getSizes().size());
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenLenTOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
  // `len([1,1,1])` -> `3`, if it is not mutated.
  if (auto listConstruct =
          getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
    if (!isListPotentiallyMutated(listConstruct)) {
      return IntegerAttr::get(IntegerType::get(getContext(), 64),
                              listConstruct.getNumOperands());
    }
  }
  return nullptr;
}

void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  // `len(t.size())` -> `t.ndim`
  patterns.add(+[](AtenLenTOp op, PatternRewriter &rewriter) {
    auto size = op.getOperand().getDefiningOp<AtenSizeOp>();
    if (!size)
      return rewriter.notifyMatchFailure(op, "operand not AtenSizeOp");
    rewriter.replaceOpWithNewOp<AtenDimOp>(op, size.getOperand());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenLenStrOp
//===----------------------------------------------------------------------===//

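// Fold `aten.len.str` of a constant string to the string's length.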
OpFoldResult AtenLenStrOp::fold(ArrayRef<Attribute> operands) {
  if (auto stringConstruct = s().getDefiningOp<ConstantStrOp>())
    return getI64IntegerAttr(getContext(),
                             stringConstruct.valueAttr().getValue().size());

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenAddTensorOp
//===----------------------------------------------------------------------===//

void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) {
    // The lhs and rhs of the add.tensor op must be 0d tensors for the
    // canonicalization to be carried out.
    // `aten.add.tensor(self, other, alpha)` is canonicalized to
    // `aten.add.int(self, aten.mul.int(other, alpha))`.

    Value lhs = getScalarValue(op.self(), op.getLoc(), rewriter);
    if (!lhs)
      return rewriter.notifyMatchFailure(op, "lhs scalar is empty");
    if (!lhs.getType().isa<Torch::IntType>())
      return rewriter.notifyMatchFailure(op, "lhs scalar is not IntType");

    Value rhs = getScalarValue(op.other(), op.getLoc(), rewriter);
    if (!rhs)
      return rewriter.notifyMatchFailure(op, "rhs scalar is empty");
    if (!rhs.getType().isa<Torch::IntType>())
      return rewriter.notifyMatchFailure(op, "rhs scalar is not IntType");

    Value mul = rewriter.create<AtenMulIntOp>(op->getLoc(), rhs, op.alpha());
    Value add = rewriter.create<AtenAddIntOp>(op->getLoc(), lhs, mul);
    rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
        op, op.self().getType(), add);
    return success();
  });
}

//===----------------------------------------------------------------------===//
|
|
|
|
// AtenSizeOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|
|
|
MLIRContext *context) {
|
|
|
|
patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
|
Introduce `!torch.tensor` / `!torch.vtensor` types.
This removes our reliance on the numpy dialect and avoids our off-label
use of the builtin tnesor type for modeling unknown dtypes. The
`!torch.vtensor` (`ValueTensorType`) type is a value-semantic tensor.
The `!torch.tensor` (`NonValueTensorType`) type is a non-value-semantic
tensor. The new types look as follows syntactically:
```
// Least-static-information, non-value-semantic tensor.
!torch.tensor
// Explicit form of least-static-information variant.
!torch.tensor<*,unk>
// Least-static-information, value-semantic tensor.
!torch.vtensor
// Explicit form of least-static-information variant.
!torch.vtensor<*,unk>
// Fixed-set of allowable element types, with first-class support for
// Torch's frontend signedness semantics.
!torch.tensor<*,si32>
// First-class support for unknown dtypes.
!torch.tensor<[?,?,?],unk>
// Standard MLIR representation of `?` for unknown dimensions.
!torch.tensor<[?,2,?,4],unk>
// Statically shaped / dtyped example.
!torch.vtensor<[1,2,3,4],f32>
```
This required fairly significant changes throughout the compiler, but
overall it is a big cleanup. We now have a much clearer layering of "the
Torch frontend lowering" vs "lowering to std + linalg + etc.".
At the C++ level, there is `ValueTensorType`, `NonValueTensorType`.
We also have a helper `BaseTensorType` (kind of like ShapedType) which
interoperates with those two.
Included changes:
- New `torch.tensor(dense<0.0> : tensor<5xf32>) : !torch.tensor` op for
creating torch tensor literals in the frontend.
- Consistently use signedness for the types (except i1 which I didn't
touch -- we need to sort out the situation with !basicpy.BoolType
there anyway so will be attending to that soon)
- Frontend can annotate whether an argument to the function has value
semantics. We currently require this, as our backend contract does not
currently allow us to even model the non-value-semantic case. Before,
the value-semantic assumption was randomly injected in the middle of
the pass pipeline.
- Move ArrayToTensor (now called MaximizeValueSemantics) and
RefinePublicReturn passes to torch dialect.
- The TorchToStd and TorchToLinalg passes are now type conversions from
`!torch.vtensor` to `tensor` and use the dialect conversion infra.
The overall conversion pipeline is set up following the best practices
of the "Type Conversions the Not-So-Hard Way" talk. This required
introducing `torch-func-builtin-tensorize` and
`torch-finalizing-builtin-tensorize` passes analogous to the upstream
bufferization passes with the corresponding names (mostly just
copypasta from there).
- Misc Torch-level canonicalizations -- we now cleanly layer the
lowering to std later in the pipeline, so we are gradually lessening
our reliance on random std constant folding before we get to that
point.
Recommended review order:
- New types in TorchTypes.td/TorchTypes.h/TorchDialect.cpp
- New ops in TorchOps.td / TorchOps.cpp
- Less important / more mechanical stuff
- Frontend changes.
- Pass changes/additions in `Torch/Transforms` and `Conversion/`
2021-05-21 08:07:18 +08:00
|
|
|
auto type = op.getOperand().getType().dyn_cast<BaseTensorType>();
|
|
|
|
if (!type || !type.areAllSizesKnown())
|
|
|
|
return rewriter.notifyMatchFailure(op, "all sizes not known");
|
    SmallVector<Value> listElements;
    for (int64_t size : type.getSizes()) {
      listElements.push_back(rewriter.create<Torch::ConstantIntOp>(
          op->getLoc(), rewriter.getI64IntegerAttr(size)));
    }
    rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
        op, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
        listElements);
    return success();
  });
  // One-off pattern to erase if dead.
  // TODO: Use the effects infra to express the semantics of this op and enable
  // a centralized "erase if dead" canonicalization.
  // Specifically, we need to mark the op as only MemoryEffects::Allocate
  // so that `mlir::wouldOpBeTriviallyDead` does the right thing.
  patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
    if (!op.use_empty())
      return failure();
    rewriter.eraseOp(op);
    return success();
  });
}
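
// Illustrative sketch (exposition only; IR syntax spelled approximately): for
// a statically shaped operand such as %t : !torch.vtensor<[2,3],f32>, the
// first pattern above rewrites `torch.aten.size %t` into two
// `torch.constant.int` ops (2 and 3) feeding a `torch.prim.ListConstruct`
// that produces the list result, and the second pattern simply erases any
// remaining unused `torch.aten.size`.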

//===----------------------------------------------------------------------===//
// AtenSizeIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
  auto type = getOperand(0).getType().dyn_cast<BaseTensorType>();
  if (!type || !type.hasSizes())
    return nullptr;
  llvm::Optional<int64_t> dimOpt = matchLegalConstantIndexIntoListOfSize(
      this->dim(), type.getSizes().size());
  if (!dimOpt)
    return nullptr;
  if (type.getSizes()[*dimOpt] == kUnknownSize)
    return nullptr;
  return IntegerAttr::get(IntegerType::get(getContext(), 64),
                          type.getSizes()[*dimOpt]);
}
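
// Illustrative sketch (exposition only): with an operand of type
// !torch.vtensor<[2,3],f32> and a constant dimension of 1,
// `torch.aten.size.int` folds to the integer constant 3; the fold bails out
// (returns nullptr) if the dimension is not a legal constant index into the
// size list or if the size along that dimension is unknown.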

//===----------------------------------------------------------------------===//
// Comparison operator folding helpers
//===----------------------------------------------------------------------===//

static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) {
  return IntegerAttr::get(IntegerType::get(context, 1),
                          static_cast<int64_t>(value));
}

using ConstantFloatComparator = std::function<bool(double, double)>;
template <typename OpTy>
static OpFoldResult
floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
  if (op.getOperand(0) == op.getOperand(1))
    return getI1IntegerAttr(op.getContext(), comparator(0, 0));

  double lhs, rhs;
  if (!matchPattern(op.getOperand(0), m_TorchConstantFloat(&lhs)) ||
      !matchPattern(op.getOperand(1), m_TorchConstantFloat(&rhs)))
    return nullptr;

  return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs));
}
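
// Note (exposition only): when both operands are the same SSA value, the
// helper folds via `comparator(0, 0)`, i.e. it evaluates the comparison on an
// arbitrary pair of equal numbers; otherwise it requires both operands to be
// constant floats before folding.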

//===----------------------------------------------------------------------===//
// AtenLtFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
  return floatComparatorFoldHelper(*this,
                                   [](double a, double b) { return a < b; });
}

//===----------------------------------------------------------------------===//
// AtenGtFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
  return floatComparatorFoldHelper(*this,
                                   [](double a, double b) { return a > b; });
}

//===----------------------------------------------------------------------===//
// AtenGeFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
  return floatComparatorFoldHelper(*this,
                                   [](double a, double b) { return a >= b; });
}

//===----------------------------------------------------------------------===//
// AtenEqFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
  return floatComparatorFoldHelper(*this,
                                   [](double a, double b) { return a == b; });
}
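
// Illustrative sketch (exposition only; IR syntax spelled approximately):
// `torch.aten.gt.float %a, %b` with constant operands 2.0 and 1.0 folds to
// the i1 constant true, while `torch.aten.lt.float %x, %x` folds to false
// because the two operands are the same value.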

using ConstantIntComparator = std::function<bool(int64_t, int64_t)>;
template <typename OpTy>
static OpFoldResult intComparatorFoldHelper(OpTy op,
                                            ConstantIntComparator comparator) {
  Value lhsValue = op->getOperand(0);
  Value rhsValue = op->getOperand(1);
  if (lhsValue == rhsValue)
    return getI1IntegerAttr(op.getContext(), comparator(0, 0));

  int64_t lhs, rhs;
  bool lhsIsConstant = matchPattern(lhsValue, m_TorchConstantInt(&lhs));
  bool rhsIsConstant = matchPattern(rhsValue, m_TorchConstantInt(&rhs));
  if (lhsIsConstant && rhsIsConstant)
    return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs));

  // Ensure that if there is a constant, it is on the right.
  if (lhsIsConstant && !rhsIsConstant) {
    std::swap(lhs, rhs);
    std::swap(lhsValue, rhsValue);
    std::swap(lhsIsConstant, rhsIsConstant);
    auto newComparator = [comparator](int64_t lhs, int64_t rhs) {
      return comparator(rhs, lhs);
    };
    comparator = newComparator;
  }

  // Fold comparisons of AtenSizeIntOp against negative values.
  // AtenSizeIntOp is known to always be non-negative.
  if (rhsIsConstant && rhs < 0) {
    // We can return `comparator(0, -1)` here because of the property:
    // If x >= 0 && y < 0, then:
    // - cmp(x, y) == cmp(x + 1, y)
    // - cmp(x, y) == cmp(x, y - 1)
    // By induction all cases here are covered.
    if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>())
      return getI1IntegerAttr(op->getContext(), comparator(0, -1));
  }

  // Fold comparisons of AtenSizeIntOp against 0:
  // - torch.aten.size.int >= 0 ==> True.
  // - torch.aten.size.int < 0 ==> False.
  // (and the operand-swapped versions of the above)
  if (rhsIsConstant && rhs == 0) {
    if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>()) {
      // >= 0 comparison.
      if (comparator(0, 0) && comparator(1, 0))
        return getI1IntegerAttr(op->getContext(), true);
      // < 0 comparison.
      if (!comparator(0, 0) && comparator(-1, 0) && !comparator(1, 0))
        return getI1IntegerAttr(op->getContext(), false);
    }
  }

  return nullptr;
}
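
// Illustrative sketch (exposition only; IR syntax spelled approximately):
// because `torch.aten.size.int` always produces a non-negative value, a
// comparison such as `torch.aten.ge.int %size, %zero` (with %size defined by
// torch.aten.size.int and %zero a constant 0) folds to true, and
// `torch.aten.lt.int %size, %zero` folds to false, even though %size itself
// is not a compile-time constant.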

//===----------------------------------------------------------------------===//
// AtenNeIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a != b; });
}

//===----------------------------------------------------------------------===//
// AtenEqIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a == b; });
}

//===----------------------------------------------------------------------===//
// AtenEqStrOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
  if (getOperand(0) == getOperand(1))
    return getI1IntegerAttr(getContext(), true);

  auto aStr = a().getDefiningOp<ConstantStrOp>();
  auto bStr = b().getDefiningOp<ConstantStrOp>();

  // Compare the underlying string values, not the identity of the defining
  // ops, so that distinct constants holding equal strings fold to true.
  if (aStr && bStr)
    return getI1IntegerAttr(getContext(), aStr.value() == bStr.value());
  return nullptr;
}
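
// Illustrative sketch (exposition only; IR syntax spelled approximately): two
// distinct `torch.constant.str "foo"` values feeding `torch.aten.eq.str` fold
// to true, since the fold compares the string contents rather than the
// defining operations.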

//===----------------------------------------------------------------------===//
// AtenLtIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a < b; });
}

//===----------------------------------------------------------------------===//
// AtenLeIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a <= b; });
}

//===----------------------------------------------------------------------===//
// AtenGtIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a > b; });
}

//===----------------------------------------------------------------------===//
// AtenGeIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a >= b; });
}

//===----------------------------------------------------------------------===//
// AtenBoolFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
  double c;
  if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
    return getI1IntegerAttr(getContext(), c != 0.0);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenBoolIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
  int64_t c;
  if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
    return getI1IntegerAttr(getContext(), c != 0);
  return nullptr;
}
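
// Illustrative sketch (exposition only; IR syntax spelled approximately):
// `torch.aten.Bool.int` of the constant 0 folds to false and of any nonzero
// constant folds to true; `torch.aten.Bool.float` behaves analogously with
// respect to 0.0.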

//===----------------------------------------------------------------------===//
// AtenFloatScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
  // Constant fold int -> float conversion.
  if (auto integerAttr = operands[0].dyn_cast_or_null<IntegerAttr>()) {
    return FloatAttr::get(
        mlir::Float64Type::get(getContext()),
        static_cast<double>(integerAttr.getValue().getSExtValue()));
  }
  // If the input is float type already, the op is an identity.
  if (getType() == getOperand().getType())
    return getOperand();
  return nullptr;
}
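
// Illustrative sketch (exposition only): a constant integer operand of 3
// folds to the f64 attribute 3.0, and when the operand already has the float
// type the op folds to its own operand (an identity).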

//===----------------------------------------------------------------------===//
// AtenIntScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
  // Constant fold float -> int conversion.
  if (auto floatAttr = operands[0].dyn_cast_or_null<FloatAttr>()) {
    return IntegerAttr::get(
        mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
        static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
  }
  // If the input is int type already, the op is an identity.
  if (getType() == getOperand().getType())
    return getOperand();
  return nullptr;
}
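
// Note (exposition only): the float -> int constant fold above truncates
// toward zero via the C++ cast, so a constant operand of 2.7 folds to 2.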

//===----------------------------------------------------------------------===//
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//

LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
  if (!attr)
    return failure();
  RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
  NonValueTensorType returnType =
      NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
                              tensorType.getElementType());
  inferredReturnTypes.push_back(returnType);
  return success();
}
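
// Illustrative sketch (exposition only; syntax spelled approximately): a
// literal whose value attribute is `dense<0.0> : tensor<2x3xf32>` gets the
// inferred result type !torch.tensor<[2,3],f32>, carrying over the shape and
// element type of the backing builtin tensor attribute.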

static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
  if (a.hasSizes() && b.hasSizes()) {
    if (failed(verifyCompatibleShape(a.getSizes(), b.getSizes())))
      return false;
  }
  if (a.hasDtype() && b.hasDtype()) {
    if (a.getDtype() != b.getDtype())
      return false;
  }
  return true;
}
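
// Note (exposition only): sizes are only checked when both types have them,
// and likewise for dtypes, so a shape like [?,3] is compatible with [2,3],
// and a type with no size/dtype information is compatible with anything.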

bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
                                                      TypeRange actual) {
  if (!actual[0].isa<BaseTensorType>())
    return false;
  return areSizesAndDtypesCompatible(inferred[0].cast<BaseTensorType>(),
                                     actual[0].cast<BaseTensorType>());
}

//===----------------------------------------------------------------------===//
// ValueTensorLiteralOp
//===----------------------------------------------------------------------===//

LogicalResult ValueTensorLiteralOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
  if (!attr)
    return failure();
  RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
  ValueTensorType returnType =
      ValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
                           tensorType.getElementType());
  inferredReturnTypes.push_back(returnType);
  return success();
}

OpFoldResult ValueTensorLiteralOp::fold(ArrayRef<Attribute> operands) {
  return valueAttr();
}

//===----------------------------------------------------------------------===//
// TensorStaticInfoCastOp
//===----------------------------------------------------------------------===//

bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
                                               mlir::TypeRange outputs) {
  return areSizesAndDtypesCompatible(inputs[0].cast<BaseTensorType>(),
                                     outputs[0].cast<BaseTensorType>());
}

void TensorStaticInfoCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
    auto reverseCast =
        op.operand().getDefiningOp<Torch::TensorStaticInfoCastOp>();
    if (!reverseCast || reverseCast.operand().getType() != op.getType())
      return failure();

    rewriter.replaceOp(op, reverseCast.operand());
    return success();
  });
  patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
    if (isValidSubtype(op.getOperand().getType(), op.getType())) {
      SmallVector<std::reference_wrapper<OpOperand>> usesToChange(
          llvm::make_filter_range(op->getUses(), [](OpOperand &operand) {
            return operand.getOwner()
                ->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
          }));

      if (usesToChange.empty())
        return failure();

      for (OpOperand &use : usesToChange) {
        Operation *user = use.getOwner();
        user->setOperand(use.getOperandNumber(), op.operand());
      }

      return success();
    }
    return failure();
  });
}
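
// Illustrative sketch (exposition only): the first pattern removes a cast
// whose operand is itself a cast back to the original type (a cast-of-cast
// round trip), and the second forwards the operand of a type-refining cast
// directly to users that allow type refinement, leaving the cast in place for
// any remaining users.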
|
|
|
//===----------------------------------------------------------------------===//
// CopyToNonValueTensorOp
//===----------------------------------------------------------------------===//

LogicalResult CopyToNonValueTensorOp::verify() {
  auto resultType = getResult().getType().cast<BaseTensorType>();
  auto operandType = getOperand().getType().cast<BaseTensorType>();
  if (!resultType.hasSameSizesAndDtype(operandType))
    return emitError() << "operand and result must have same sizes and dtype";
  return success();
}

LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto resultType = operands[0].getType().cast<ValueTensorType>();
  inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
  return success();
}

void CopyToNonValueTensorOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
}
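// Marking an Allocate effect on the result (rather than leaving the op
// side-effect free) models that, e.g. (illustrative IR):
//   %t = torch.copy.to_tensor %v : !torch.tensor
// creates a fresh mutable tensor, so two such copies of the same value must
// not be CSE'd into one.
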
//===----------------------------------------------------------------------===//
// CopyToValueTensorOp
//===----------------------------------------------------------------------===//

LogicalResult CopyToValueTensorOp::verify() {
  auto resultType = getResult().getType().cast<BaseTensorType>();
  auto operandType = getOperand().getType().cast<BaseTensorType>();
  if (!resultType.hasSameSizesAndDtype(operandType))
    return emitError() << "operand and result must have same sizes and dtype";
  return success();
}

LogicalResult CopyToValueTensorOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto resultType = operands[0].getType().cast<NonValueTensorType>();
  inferredReturnTypes.push_back(resultType.getWithValueSemantics());
  return success();
}

void CopyToValueTensorOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getOperand());
}
//===----------------------------------------------------------------------===//
// ConstantNoneOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstantNoneOp::fold(ArrayRef<Attribute> operands) {
  return TypeAttr::get(Torch::NoneType::get(getContext()));
}

void ConstantNoneOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "none");
}

//===----------------------------------------------------------------------===//
// ConstantStrOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstantStrOp::fold(ArrayRef<Attribute> operands) {
  return valueAttr();
}

void ConstantStrOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "str");
}

//===----------------------------------------------------------------------===//
// ConstantDeviceOp
//===----------------------------------------------------------------------===//

void ConstantDeviceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), value());
}
//===----------------------------------------------------------------------===//
// ConstantIntOp
//===----------------------------------------------------------------------===//

ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) {
  Builder builder(result.getContext());
  result.addTypes(builder.getType<Torch::IntType>());
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  int64_t value;
  if (parser.parseInteger(value))
    return failure();
  result.addAttribute("value", builder.getI64IntegerAttr(value));
  return success();
}

void ConstantIntOp::print(OpAsmPrinter &p) {
  p << " ";
  p << value().getSExtValue();
  p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
}

OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {
  return valueAttr();
}

void Torch::ConstantIntOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  SmallVector<char> buf;
  llvm::raw_svector_ostream os(buf);
  os << "int" << value();
  setNameFn(getResult(), os.str());
}
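// The parser/printer above implement a custom assembly form, so a constant
// integer is written as, e.g. (illustrative):
//   %int5 = torch.constant.int 5
// with the "%int5" result name coming from getAsmResultNames.
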
//===----------------------------------------------------------------------===//
// ConstantFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef<Attribute> operands) {
  return valueAttr();
}

void Torch::ConstantFloatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Calculate a stringified version of the number, compatible with MLIR
  // identifier syntax. (in practice, this just removes the '+' from 'e+' in
  // float string representation).
  SmallVector<char> buf;
  value().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
                   /*TruncateZero=*/false);
  auto isValidMLIRIdentifierChar = [](char c) {
    return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' ||
           c == '-';
  };
  auto numberStr = llvm::to_vector<16>(
      llvm::make_filter_range(buf, isValidMLIRIdentifierChar));

  // Construct the identifier string.
  buf.clear();
  llvm::append_range(buf, StringRef("float"));
  llvm::append_range(buf, numberStr);
  setNameFn(getResult(), StringRef(buf.data(), buf.size()));
}
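// For example (illustrative), a torch.constant.float with value 1.0e+10 gets
// an SSA name along the lines of "%float1.000000e10" (the exact digits depend
// on APFloat's formatting); the '+' is dropped so the name stays a valid MLIR
// identifier.
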
//===----------------------------------------------------------------------===//
// ConstantBoolOp
//===----------------------------------------------------------------------===//

OpFoldResult Torch::ConstantBoolOp::fold(ArrayRef<Attribute> operands) {
  return valueAttr();
}

void Torch::ConstantBoolOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), value() ? "true" : "false");
}
//===----------------------------------------------------------------------===//
// PrimUncheckedCastOp
//===----------------------------------------------------------------------===//

bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
                                            mlir::TypeRange outputs) {
  return isValidSubtype(outputs[0], inputs[0]);
}

OpFoldResult PrimUncheckedCastOp::fold(ArrayRef<Attribute> operands) {
  if (auto derefineOp = x().getDefiningOp<Torch::DerefineOp>()) {
    if (derefineOp.operand().getType() == getType())
      return derefineOp.operand();
  }
  return nullptr;
}
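// For example, the fold above simplifies IR like (illustrative; the types are
// hypothetical):
//   %opt = torch.derefine %i : !torch.int to !torch.optional<int>
//   %j = torch.prim.unchecked_cast %opt : !torch.optional<int> -> !torch.int
// into a direct use of %i, since the cast exactly undoes the derefine.
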
//===----------------------------------------------------------------------===//
// Aten__Getitem__TOp
//===----------------------------------------------------------------------===//

void Aten__Getitem__TOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
    auto torchList = op.getOperand(0);
    if (isListPotentiallyMutated(torchList))
      return failure();

    auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
    if (!listConstruct)
      return failure();

    // Get the index, but be careful because it might be statically invalid.
    llvm::Optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
        op.getOperand(1), listConstruct.getNumOperands());
    if (!indexOpt)
      return rewriter.notifyMatchFailure(op, "statically invalid index");

    rewriter.replaceOp(op, {listConstruct.getOperand(*indexOpt)});
    return success();
  });
  patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
    auto sizeOp = op.list().getDefiningOp<AtenSizeOp>();
    if (!sizeOp)
      return failure();
    // This assumes that the size doesn't change between the AtenSizeOp and
    // the Aten__Getitem__TOp.
    // `t_` is the only op I can find that changes the shape in-place. It seems
    // like otherwise we can treat the size of a tensor as having value
    // semantics. The other view-like ops don't have in-place variants --
    // they always return a new SSA value that is aliased to the input.
    // Can we have a pass to normalize the `t_` case and then elsewhere in the
    // compiler treat the size as having value semantics?
    // There's a small number of such ops, and they are marked as
    // `inplace_view` in PyTorch's `native_functions.yaml` file.
    rewriter.replaceOpWithNewOp<AtenSizeIntOp>(op, sizeOp.self(), op.idx());
    return success();
  });
}
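// For example, the first pattern above rewrites IR like (illustrative; SSA
// names and values are hypothetical):
//   %list = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
//   %elem = torch.aten.__getitem__.t %list, %int0 : !torch.list<int>, !torch.int -> !torch.int
// into a direct use of %int1, and the second pattern turns a
// torch.aten.__getitem__.t of a torch.aten.size result into
// torch.aten.size.int on the original tensor.
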
//===----------------------------------------------------------------------===//
// AtenEqIntListOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenEqIntListOp::fold(ArrayRef<Attribute> operands) {
  auto lhsLiteral = a().getDefiningOp<Torch::PrimListConstructOp>();
  if (!lhsLiteral)
    return nullptr;
  auto rhsLiteral = b().getDefiningOp<Torch::PrimListConstructOp>();
  if (!rhsLiteral)
    return nullptr;

  // If the sizes don't match, then we know the lists aren't equal.
  if (lhsLiteral.getNumOperands() != rhsLiteral.getNumOperands())
    return getI1IntegerAttr(getContext(), false);

  // If the sizes match and all corresponding list elements are the same Value,
  // then we know the lists are equal.
  // Note that we can't prove that the lists are not-equal with this method,
  // since two different Value's might dynamically be equal.
  if (llvm::all_of(
          llvm::zip(lhsLiteral.getOperands(), rhsLiteral.getOperands()),
          [](const auto &pair) {
            return std::get<0>(pair) == std::get<1>(pair);
          }))
    return getI1IntegerAttr(getContext(), true);
  return nullptr;
}
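// For example (illustrative), two torch.prim.ListConstruct ops built from the
// same SSA values fold to true, and list literals of different lengths fold
// to false; lists of the same length built from different SSA values are left
// alone, since those values might still be equal at runtime.
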
//===----------------------------------------------------------------------===//
// PrimTupleIndexOp
//===----------------------------------------------------------------------===//

void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
    if (!tupleConstruct)
      return failure();

    int64_t i;
    if (!matchPattern(op.i(), m_TorchConstantInt(&i)))
      return failure();

    if (i >= (int64_t)tupleConstruct.elements().size())
      return failure();

    // TODO: We should have a clear picture of whether we want to consistently
    // allow refinement, and where. It seems desirable to require precise
    // type equality for TupleConstruct / TupleIndex, but that might break
    // things.
    Value replacement = tupleConstruct.elements()[i];
    if (replacement.getType() != op.getType()) {
      if (op.getType().isa<BaseTensorType>()) {
        replacement = rewriter.create<Torch::TensorStaticInfoCastOp>(
            op.getLoc(), op.getType(), replacement);
      } else {
        return failure();
      }
    }
    rewriter.replaceOp(op, replacement);
    return success();
  });
}
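// For example (schematic; the types are hypothetical), indexing a
// torch.prim.TupleConstruct with a constant index is rewritten to a direct
// use of the corresponding element, inserting a torch.tensor_static_info_cast
// if the tuple element's tensor type differs from the result type of the
// torch.prim.TupleIndex.
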
//===----------------------------------------------------------------------===//
// PrimUninitializedOp
//===----------------------------------------------------------------------===//

void PrimUninitializedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
    if (!op.use_empty())
      return failure();
    rewriter.eraseOp(op);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// PrimTupleUnpackOp
//===----------------------------------------------------------------------===//

void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
    if (!tupleConstruct)
      return failure();

    rewriter.replaceOp(op, tupleConstruct.elements());
    return success();
  });
}
static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
  if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) {
        return isa<Aten__Getitem__DictStrOp, Aten__Contains__StrOp,
                   AtenKeysStrOp, AtenGetDefaultStrOp, PrimDictConstructOp>(op);
      }))
    return nullptr;

  return torchDict.getDefiningOp<Torch::PrimDictConstructOp>();
}
//===----------------------------------------------------------------------===//
// Aten__Getitem__DictStrOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
  auto dictConstruct = getDictConstructIfNotModified(self());
  if (!dictConstruct)
    return nullptr;

  auto targetKey = key();
  for (auto i : llvm::zip(dictConstruct.keys(), dictConstruct.values())) {
    auto k = std::get<0>(i);
    if (k == targetKey)
      return std::get<1>(i);
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Aten__Contains__StrOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten__Contains__StrOp::fold(ArrayRef<Attribute> operands) {
  auto dictConstruct = getDictConstructIfNotModified(dict());
  if (!dictConstruct)
    return nullptr;

  auto targetKey = key();
  for (auto key : dictConstruct.keys()) {
    if (key == targetKey)
      return getI1IntegerAttr(getContext(), true);
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// Aten__Contains__IntListOp
//===----------------------------------------------------------------------===//

static bool isListConstructNotModified(Value torchList) {
  return llvm::all_of(torchList.getUsers(), [](Operation *op) {
    return isa<Aten__Contains__IntListOp>(op);
  });
}

OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
  auto itemConstruct = item();
  if (!isListConstructNotModified(l()))
    return nullptr;

  int64_t item;
  SmallVector<int64_t> list;

  if (!matchPattern(itemConstruct, m_TorchConstantInt(&item)))
    return nullptr;

  if (!matchPattern(l(), m_TorchConstantIntList(list)))
    return nullptr;

  for (auto elem : list) {
    if (elem == item)
      return getI1IntegerAttr(getContext(), true);
  }
  return getI1IntegerAttr(getContext(), false);
}
using BinaryIntOperatorFn = std::function<int64_t(int64_t, int64_t)>;
template <typename OpTy>
static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op,
                                                    BinaryIntOperatorFn f) {
  int64_t lhs, rhs;
  if (!matchPattern(op.getOperand(0), m_TorchConstantInt(&lhs)) ||
      !matchPattern(op.getOperand(1), m_TorchConstantInt(&rhs)))
    return nullptr;

  return getI64IntegerAttr(op.getContext(), f(lhs, rhs));
}
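// The helper above is shared by the integer folders below: it only fires when
// both operands are torch.constant.int ops and otherwise leaves the op
// untouched. For example (illustrative), folding torch.aten.add.int of
// constants 2 and 3 produces the constant 5.
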
//===----------------------------------------------------------------------===//
// AtenFloordivIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
  return atenBinaryIntOperatorFoldHelper(
      *this, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
}
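// Note that the lambda above implements Python-style floor division, e.g.
// (illustrative) -7 floordiv 2 folds to -4, not the -3 that C++ integer
// division would give.
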
//===----------------------------------------------------------------------===//
// AtenRemainderIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
  return atenBinaryIntOperatorFoldHelper(
      *this, [](int64_t a, int64_t b) { return a % b; });
}

//===----------------------------------------------------------------------===//
// AtenAddIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
  return atenBinaryIntOperatorFoldHelper(
      *this, [](int64_t a, int64_t b) { return a + b; });
}

//===----------------------------------------------------------------------===//
// AtenSubIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
  return atenBinaryIntOperatorFoldHelper(
      *this, [](int64_t a, int64_t b) { return a - b; });
}

//===----------------------------------------------------------------------===//
// AtenMulIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
  int64_t lhs, rhs;
  bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
  bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
  if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
    return getI64IntegerAttr(getContext(), 0);
  if (lConstant && rConstant)
    return getI64IntegerAttr(getContext(), lhs * rhs);
  return nullptr;
}
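// Note that the AtenMulIntOp fold above can fire even when only one operand
// is a constant, as long as that constant is 0: e.g. (illustrative)
// 0 * %unknown folds to 0.
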
//===----------------------------------------------------------------------===//
// AtenNegIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
  int64_t c;
  if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
    return getI64IntegerAttr(getContext(), -c);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenSqrtIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
  int64_t c;
  if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
    return getF64FloatAttr(getContext(), std::sqrt(c));
  return nullptr;
}
//===----------------------------------------------------------------------===//
// PrimDtypeOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
  BaseTensorType tensorType = a().getType().cast<BaseTensorType>();
  if (tensorType.hasDtype()) {
    torch_upstream::ScalarType scalarType =
        Torch::getScalarTypeForType(tensorType.getDtype());
    return getI64IntegerAttr(getContext(), static_cast<int64_t>(scalarType));
  }
  return nullptr;
}
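// For example (illustrative), prim.dtype of a tensor whose dtype is known to
// be f32 folds to the integer code that torch_upstream::ScalarType uses for
// torch.float32; the exact numeric value comes from getScalarTypeForType.
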
//===----------------------------------------------------------------------===//
// AtenIntTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
  // If a scalar number is converted to a 0-d tensor and passed on to
  // aten.Int.Tensor, fold to the scalar number.
  if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
    return numToTensorScalar.a();
  return nullptr;
}
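// For example (illustrative; types are hypothetical), the fold above turns
//   %t = torch.prim.NumToTensor.Scalar %int5 : !torch.int -> !torch.vtensor<[],si64>
//   %i = torch.aten.Int.Tensor %t : !torch.vtensor<[],si64> -> !torch.int
// into a direct use of %int5.
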
//===----------------------------------------------------------------------===//
// AtenFloatTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
  // If a scalar number is converted to a 0-d tensor and passed on to
  // aten.Float.Tensor, fold to the scalar number.
  if (auto numToTensorScalar = a().getDefiningOp<PrimNumToTensorScalarOp>())
    return numToTensorScalar.a();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenDivFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
  double lhs, rhs;
  bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
  bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
  if (lConstant && lhs == 0.0)
    return getF64FloatAttr(getContext(), 0.0);
  if (lConstant && rConstant && rhs == 1.0)
    return getF64FloatAttr(getContext(), lhs);
  if (lConstant && rConstant)
    return getF64FloatAttr(getContext(), lhs / rhs);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenCeilFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
  double c;
  if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
    return getI64IntegerAttr(getContext(), std::ceil(c));
  return nullptr;
}
//===----------------------------------------------------------------------===//
// PrimMaxIntOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
  // If both operands are the same, then the operation is an identity.
  if (a() == b())
    return a();

  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!lhs || !rhs)
    return nullptr;
  // Torch semantics are that !torch.int is 64-bit signed.
  return IntegerAttr::get(
      lhs.getType(),
      std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
}
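// Using getSExtValue above matters for negative values: e.g. (illustrative)
// prim.max.int of -3 and 5 should fold to 5, which a comparison of the raw
// (unsigned) APInt bit patterns would get wrong.
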
//===----------------------------------------------------------------------===//
// PrimMinSelfIntOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
  auto list = getOperand().getDefiningOp<PrimListConstructOp>();
  if (!list)
    return nullptr;
  // TODO: What does it return for an empty list?
  if (list->getNumOperands() == 0)
    return nullptr;

  SmallVector<int64_t> values;
  for (auto operand : list->getOperands()) {
    int64_t value;
    if (!matchPattern(operand, m_TorchConstantInt(&value)))
      return nullptr;
    values.push_back(value);
  }
  return getI64IntegerAttr(getContext(),
                           *std::min_element(values.begin(), values.end()));
}
//===----------------------------------------------------------------------===//
// ShapeCalculateOp
//===----------------------------------------------------------------------===//

void ShapeCalculateOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  (void)operands;

  if (!index.hasValue()) {
    // First thing the op does is branch into the shape calculation.
    regions.emplace_back(&shapeCalculation());
    return;
  }
  if (*index == 0) {
    // Body returns control to the outer op, passing through results.
    regions.emplace_back(getResults());
    return;
  }
  assert(*index == 1);
  // Shape calculation branches to the body.
  regions.emplace_back(&body());
}
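// Conceptually, the region wiring above corresponds to ops of the form
// (schematic, illustrative IR):
//   %result = torch.shape.calculate {
//     ...compute the value...
//     torch.shape.calculate.yield %value : !torch.vtensor
//   } shapes {
//     ...compute the shape...
//     torch.shape.calculate.yield.shapes %shape : !torch.list<int>
//   } : !torch.vtensor
// Control enters the shape-calculation region first, then the body, and the
// body's yield forwards its operands as the op's results.
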
//===----------------------------------------------------------------------===//
// ShapeCalculateYieldShapesOp
//===----------------------------------------------------------------------===//

MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
    Optional<unsigned> index) {
  // The shape operands don't get forwarded to the body.
  // MutableOperandRange always has an owning operation, even if empty, so
  // create a 0-length range.
  return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
}

LogicalResult ShapeCalculateYieldShapesOp::verify() {
  auto parent = cast<ShapeCalculateOp>(getOperation()->getParentOp());
  if (parent.getNumResults() != getNumOperands())
    return emitOpError("expected number of shapes to match number of results");
  return success();
}