2021-10-16 06:23:59 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "PassDetail.h"
|
|
|
|
|
2022-02-09 04:57:23 +08:00
|
|
|
#include "mlir/IR/BuiltinDialect.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2022-12-09 01:26:38 +08:00
|
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
2022-04-26 20:18:09 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
2022-05-10 21:15:59 +08:00
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
#include "llvm/ADT/StringExtras.h"
|
2022-12-09 01:26:38 +08:00
|
|
|
#include "llvm/ADT/StringSet.h"
|
2022-05-10 21:15:59 +08:00
|
|
|
#include <cstdint>
|
2021-10-16 06:23:59 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
2022-03-11 01:25:21 +08:00
|
|
|
// Helper function to check whether the `dtype` is None or Float type.
|
|
|
|
static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
|
|
|
if (dtype.getType().isa<Torch::NoneType>())
|
|
|
|
return true;
|
|
|
|
int64_t dtypeInt;
|
|
|
|
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
|
|
|
return false;
|
2023-01-21 02:40:13 +08:00
|
|
|
FailureOr<Type> resDtype =
|
2022-03-11 01:25:21 +08:00
|
|
|
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
2023-01-21 02:40:13 +08:00
|
|
|
if (failed(resDtype))
|
|
|
|
return false;
|
|
|
|
return resDtype->isa<mlir::FloatType>();
|
2022-03-11 01:25:21 +08:00
|
|
|
}
|
|
|
|
|
2022-02-01 03:56:32 +08:00
|
|
|
// Helper function to compute the return type of the reduction function.
|
|
|
|
// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
|
|
|
|
// the input tensor.
|
|
|
|
static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
|
2022-06-29 15:23:57 +08:00
|
|
|
BaseTensorType tensorType, Value dim,
|
|
|
|
bool keepDim) {
|
2021-11-08 23:56:40 +08:00
|
|
|
SmallVector<int64_t> sizes;
|
|
|
|
int64_t dimInt;
|
|
|
|
if (tensorType.hasSizes()) {
|
|
|
|
ArrayRef<int64_t> inputShape = tensorType.getSizes();
|
|
|
|
int64_t inputRank = inputShape.size();
|
|
|
|
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
|
|
|
|
dimInt = toPositiveDim(dimInt, inputRank);
|
|
|
|
if (!isValidDim(dimInt, inputRank)) {
|
|
|
|
(void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
sizes.append(inputShape.begin(), inputShape.end());
|
2022-02-01 03:56:32 +08:00
|
|
|
// The dimension to be reduced is set to 1 when `keepDim` is true else it
|
|
|
|
// is removed.
|
|
|
|
if (keepDim)
|
|
|
|
sizes[dimInt] = 1;
|
|
|
|
else
|
2022-11-23 17:36:44 +08:00
|
|
|
sizes.erase(sizes.begin() + dimInt);
|
2021-11-08 23:56:40 +08:00
|
|
|
} else {
|
2022-02-01 03:56:32 +08:00
|
|
|
unsigned reducedRank = keepDim ? inputRank : inputRank - 1;
|
|
|
|
sizes.resize(reducedRank, kUnknownSize);
|
2021-11-08 23:56:40 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Type resultType = tensorType.getWithSizesAndDtype(
|
2022-12-20 18:17:27 +08:00
|
|
|
sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
|
2021-11-08 23:56:40 +08:00
|
|
|
: llvm::makeArrayRef(sizes),
|
2023-01-04 06:19:18 +08:00
|
|
|
tensorType.getOptionalDtype());
|
2022-02-01 03:56:32 +08:00
|
|
|
return resultType;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Reduction function to calculate sum along given `dim`.
|
|
|
|
static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
|
|
|
|
Operation *op, Value input, Value dim,
|
|
|
|
bool keepDim) {
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(dim.getType()), dim);
|
|
|
|
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
|
|
|
Value dtype = rewriter.create<ConstantNoneOp>(loc);
|
2022-06-29 15:23:57 +08:00
|
|
|
Type resultType = computeReductionType(
|
|
|
|
rewriter, op, input.getType().cast<BaseTensorType>(), dim, keepDim);
|
2022-02-01 03:56:32 +08:00
|
|
|
if (!resultType)
|
|
|
|
return nullptr;
|
|
|
|
return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
|
|
|
|
keepDimCst, dtype);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Redunction function to calculate max along given `dim`.
|
|
|
|
static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
|
|
|
Operation *op, Value input, Value dim,
|
|
|
|
bool keepDim) {
|
|
|
|
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
|
|
|
BaseTensorType valueType =
|
2022-06-29 15:23:57 +08:00
|
|
|
computeReductionType(rewriter, op, input.getType().cast<BaseTensorType>(),
|
|
|
|
dim, keepDim)
|
2022-02-01 03:56:32 +08:00
|
|
|
.cast<BaseTensorType>();
|
|
|
|
if (!valueType)
|
|
|
|
return nullptr;
|
|
|
|
BaseTensorType indexType =
|
|
|
|
valueType
|
|
|
|
.getWithSizesAndDtype(
|
2022-12-20 18:17:27 +08:00
|
|
|
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
|
2022-02-01 03:56:32 +08:00
|
|
|
: llvm::makeArrayRef(valueType.getSizes()),
|
|
|
|
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
return rewriter
|
|
|
|
.create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
|
2022-12-08 04:20:41 +08:00
|
|
|
.getValues();
|
2021-11-08 23:56:40 +08:00
|
|
|
}
|
|
|
|
|
2021-11-19 02:02:20 +08:00
|
|
|
// Helper for creating `aten::sub_tensor_op`.
|
2021-11-19 20:18:41 +08:00
|
|
|
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
2022-02-15 21:14:32 +08:00
|
|
|
Type tensorType, Value lhs, Value rhs) {
|
2021-11-19 02:02:20 +08:00
|
|
|
Value alpha =
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
|
|
|
|
Value sub =
|
|
|
|
rewriter.create<AtenSubTensorOp>(loc, tensorType, lhs, rhs, alpha);
|
|
|
|
return sub;
|
|
|
|
}
|
|
|
|
|
2022-03-03 00:48:15 +08:00
|
|
|
// Helper to create a tensor filled with the given scalar. Scalar would be
|
|
|
|
// converted the to the element type of the given tensor type.
|
2022-02-09 04:57:23 +08:00
|
|
|
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
|
|
|
|
Type resultType, Value scalar, Value sizeList) {
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
2022-09-23 10:24:36 +08:00
|
|
|
return rewriter.create<AtenFullOp>(
|
|
|
|
loc, resultType, sizeList, scalar, /*dtype=*/noneVal, /*layout=*/noneVal,
|
|
|
|
/*device=*/noneVal, /*memory_format=*/noneVal);
|
2022-02-09 04:57:23 +08:00
|
|
|
}
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
|
|
|
|
// would be converted to the element type of the given `inputType`.
|
2022-02-09 04:57:23 +08:00
|
|
|
static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
|
|
|
|
BaseTensorType inputType, Value scalar) {
|
|
|
|
SmallVector<int64_t> sizes;
|
|
|
|
Type rank0TensorTy = inputType.getWithSizesAndDtype(
|
|
|
|
makeArrayRef(sizes), inputType.getOptionalDtype());
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
|
|
|
|
ValueRange{});
|
|
|
|
return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList);
|
|
|
|
}
|
|
|
|
|
2021-11-19 20:18:41 +08:00
|
|
|
// Share code between `softmax_backward` and `log_softmax_backward` ops.
|
|
|
|
// Returns x - y * sum(z, dim).
|
|
|
|
static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
|
|
|
|
Location loc, Operation *op,
|
|
|
|
Type tensorType, Value x,
|
|
|
|
Value y, Value z, Value dim) {
|
2022-02-15 21:14:32 +08:00
|
|
|
Value sum =
|
|
|
|
createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
|
2021-11-19 20:18:41 +08:00
|
|
|
if (!sum)
|
|
|
|
return nullptr;
|
|
|
|
auto broadcastSizeType =
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op->getContext()));
|
|
|
|
Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
|
|
|
|
Value sumBroadcast =
|
|
|
|
rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
|
|
|
|
Value temp =
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
|
|
|
|
|
|
|
|
Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
|
|
|
|
return sub;
|
|
|
|
}
|
|
|
|
|
2022-11-23 02:37:28 +08:00
|
|
|
namespace {
|
|
|
|
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
|
|
|
/// number of dimensions across which the max needs to be computed.
|
|
|
|
/// Eg:
|
|
|
|
/// INPUT:
|
|
|
|
/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
|
|
|
|
///
|
|
|
|
/// OUTPUT:
|
|
|
|
/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1
|
|
|
|
/// input_2 = aten.max.dim(input_1, 1, keepdim) #2
|
|
|
|
/// final_output = aten.max.dim(input_2, 0, keepdim) #3
|
|
|
|
///
|
|
|
|
/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
|
|
|
|
/// of the `aten.amax` op and create an `aten.amax.dim` op.
|
|
|
|
/// Input tensor to the next `aten.amax.dim` op is thus the output of the
|
|
|
|
/// previous `aten.amax.dim` op.
|
|
|
|
class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenAmaxOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
SmallVector<int64_t, 4> dims;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
|
2022-11-23 02:37:28 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"non-const dim parameter unsupported");
|
|
|
|
|
|
|
|
bool keepDim;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
2022-11-23 02:37:28 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Expected a constant boolean value for keepDim");
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2022-11-23 02:37:28 +08:00
|
|
|
// For every dimension included in `dim` of the op, iterated over in
|
|
|
|
// reverse order, we create a call to aten.max.dim.
|
2022-12-13 00:56:28 +08:00
|
|
|
std::sort(dims.begin(), dims.end());
|
|
|
|
std::reverse(dims.begin(), dims.end());
|
|
|
|
for (int64_t dimInt : dims) {
|
2022-11-23 02:37:28 +08:00
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
2022-12-13 00:56:28 +08:00
|
|
|
loc, rewriter.getI64IntegerAttr(dimInt));
|
2022-11-23 02:37:28 +08:00
|
|
|
// The input to the next invocation of aten.max.dim is the output of the
|
|
|
|
// previous aten.max.dim op.
|
|
|
|
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, input);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // end namespace
|
|
|
|
|
2021-11-08 23:56:40 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenSizeOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2021-11-08 23:56:40 +08:00
|
|
|
MLIRContext *context = op.getContext();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeRank)
|
2021-11-08 23:56:40 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
|
2022-12-13 00:56:28 +08:00
|
|
|
unsigned rank = *maybeRank;
|
2021-11-08 23:56:40 +08:00
|
|
|
SmallVector<Value> sizes;
|
2022-12-13 00:56:28 +08:00
|
|
|
for (unsigned i = 0; i < rank; i++) {
|
2021-11-08 23:56:40 +08:00
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(i));
|
|
|
|
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
|
|
|
|
}
|
|
|
|
|
|
|
|
Value sizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)), sizes);
|
|
|
|
rewriter.replaceOp(op, sizeList);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-12-03 12:09:21 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenSelectIntOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value start = op.getIndex();
|
|
|
|
Value dim = op.getDim();
|
|
|
|
Value self = op.getSelf();
|
2022-02-12 03:34:05 +08:00
|
|
|
|
2023-01-18 02:14:14 +08:00
|
|
|
// convert `start` to non-negative: start += int(start < 0) * dimSize
|
|
|
|
Value zero =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value isNegative = rewriter.create<AtenLtIntOp>(loc, start, zero);
|
|
|
|
isNegative = rewriter.create<AtenIntBoolOp>(loc, isNegative);
|
|
|
|
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
|
|
|
|
Value indexOffset = rewriter.create<AtenMulIntOp>(loc, isNegative, dimSize);
|
|
|
|
start = rewriter.create<AtenAddIntOp>(loc, start, indexOffset);
|
|
|
|
|
2021-12-03 12:09:21 +08:00
|
|
|
Value one =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
2022-02-12 03:34:05 +08:00
|
|
|
Value startPlusOne =
|
|
|
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
|
|
|
Value slice = rewriter.create<AtenSliceTensorOp>(
|
2022-06-29 15:23:57 +08:00
|
|
|
loc,
|
|
|
|
computeReductionType(rewriter, op,
|
|
|
|
self.getType().cast<BaseTensorType>(), dim,
|
|
|
|
/*keepDim=*/true),
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getSelf(), dim, start, startPlusOne, /*step=*/one);
|
2022-02-12 03:34:05 +08:00
|
|
|
|
|
|
|
// `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after
|
|
|
|
// slicing, while `aten.select.int` does.
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSqueezeDimOp>(op, op.getResult().getType(),
|
2022-12-08 04:20:41 +08:00
|
|
|
slice, op.getDim());
|
2021-12-03 12:09:21 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-25 03:26:37 +08:00
|
|
|
namespace {
|
2022-08-01 20:32:35 +08:00
|
|
|
class DecomposeAtenNarrowOp : public OpRewritePattern<AtenNarrowOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenNarrowOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value start = op.getStart();
|
|
|
|
Value dim = op.getDim();
|
|
|
|
Value length = op.getLength();
|
2022-08-01 20:32:35 +08:00
|
|
|
|
|
|
|
Value one =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
Value startPlusLength =
|
|
|
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length);
|
2022-11-06 20:44:05 +08:00
|
|
|
|
2022-08-01 20:32:35 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getResult().getType(), op.getSelf(), /*dim=*/dim, /*start=*/start,
|
2022-08-01 20:32:35 +08:00
|
|
|
/*end=*/startPlusLength, /*step=*/one);
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenZeroOp
|
|
|
|
: public OpRewritePattern<AtenZeroOp> {
|
2022-03-25 03:26:37 +08:00
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
2022-07-12 01:56:12 +08:00
|
|
|
LogicalResult matchAndRewrite(AtenZeroOp op,
|
2022-03-25 03:26:37 +08:00
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
|
|
|
|
rewriter.getI64IntegerAttr(0));
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(), op.getSelf(),
|
2022-10-28 23:06:11 +08:00
|
|
|
zero);
|
2022-03-25 03:26:37 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-12-17 23:54:03 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenReshapeOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2021-12-17 23:54:03 +08:00
|
|
|
// TODO: Handle non value tensor type operands.
|
|
|
|
if (!input.getType().isa<ValueTensorType>()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only value tensor type operands are supported");
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), input,
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getShape());
|
2021-12-17 23:54:03 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-11-25 13:49:02 +08:00
|
|
|
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
|
|
|
// exp(x)/sum(exp(x)).
|
2022-02-01 03:56:32 +08:00
|
|
|
// To avoid overflow we use the following decomposition rule:
|
|
|
|
// x_max = max(input, dim, keepdim = True)
|
|
|
|
// unnorm = aten.exp(input - x_max)
|
|
|
|
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
|
2021-11-25 13:49:02 +08:00
|
|
|
template <typename OpTy>
|
2022-11-25 13:56:37 +08:00
|
|
|
static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
|
2021-11-25 13:49:02 +08:00
|
|
|
PatternRewriter &rewriter) {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value dim = op.getDim();
|
2022-02-01 03:56:32 +08:00
|
|
|
Value xMax =
|
|
|
|
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
|
|
|
|
if (!xMax)
|
|
|
|
return nullptr;
|
|
|
|
Value unNormalized = createTensorSub(rewriter, loc, resultType, self, xMax);
|
|
|
|
Value unNormalizedExp =
|
|
|
|
rewriter.create<AtenExpOp>(loc, resultType, unNormalized);
|
|
|
|
Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim,
|
|
|
|
/*keepDim=*/true);
|
2021-11-25 13:49:02 +08:00
|
|
|
if (!sum)
|
|
|
|
return nullptr;
|
2022-02-01 03:56:32 +08:00
|
|
|
return rewriter.create<AtenDivTensorOp>(loc, resultType, unNormalizedExp,
|
|
|
|
sum);
|
2021-11-25 13:49:02 +08:00
|
|
|
}
|
|
|
|
|
2021-10-16 06:23:59 +08:00
|
|
|
// Decompose softmax into: exp(x) / sum(exp(x))
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
|
|
|
if (!op.getDtype().getType().isa<Torch::NoneType>())
|
2021-10-16 06:23:59 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Unimplemented non-None dtype for softmax");
|
|
|
|
|
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
2021-11-08 23:56:40 +08:00
|
|
|
|
2022-11-25 13:56:37 +08:00
|
|
|
Value result = getSoftmaxResult(op, self, tensorType, rewriter);
|
2021-11-25 13:49:02 +08:00
|
|
|
if (!result)
|
2021-11-08 23:56:40 +08:00
|
|
|
return failure();
|
2021-11-25 13:49:02 +08:00
|
|
|
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
|
|
|
|
result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2021-11-25 13:49:02 +08:00
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
bool halfToFloat;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat)))
|
2021-11-25 13:49:02 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Expected a boolean value for half_to_float");
|
|
|
|
|
2022-11-25 13:56:37 +08:00
|
|
|
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
if (!resultTensorType.hasDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
}
|
|
|
|
Type resultTensorDtype = resultTensorType.getDtype();
|
2022-11-25 13:56:37 +08:00
|
|
|
// `torch.ops.aten._softmax`'s softmax with half to float conversion is not
|
|
|
|
// supported on CPU, but we go ahead with the decomposing.
|
|
|
|
// TODO: Add an e2e test once upstream support is added.
|
|
|
|
// If `half_to_float` is set, we convert the input's elemental type to match
|
|
|
|
// that of output's.
|
|
|
|
if (halfToFloat) {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
|
|
|
self = rewriter.create<AtenToDtypeOp>(
|
|
|
|
loc, resultTensorType, self,
|
2023-01-04 06:19:18 +08:00
|
|
|
getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
|
2022-11-25 13:56:37 +08:00
|
|
|
/*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
|
|
|
|
}
|
|
|
|
Value result = getSoftmaxResult(op, self, resultTensorType, rewriter);
|
2021-11-25 13:49:02 +08:00
|
|
|
if (!result)
|
|
|
|
return op.emitError("failed to get softmax result");
|
2022-11-25 13:56:37 +08:00
|
|
|
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
|
2021-10-16 06:23:59 +08:00
|
|
|
result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-11-08 23:56:40 +08:00
|
|
|
// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
|
|
|
// newGrad = gradOutput * output
|
|
|
|
// result = newGrad - output * sum(newGrad, dim))
|
|
|
|
//
|
|
|
|
// Refer to
|
|
|
|
// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
|
|
|
|
namespace {
|
|
|
|
class DecomposeAten_SoftmaxBackwardDataOp
|
|
|
|
: public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value gradOutput = op.getGradOutput();
|
|
|
|
Value output = op.getOutput();
|
|
|
|
Value dim = op.getDim();
|
2021-11-08 23:56:40 +08:00
|
|
|
|
|
|
|
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
|
|
Value newGrad =
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
|
2021-11-19 20:18:41 +08:00
|
|
|
Value result = createSoftmaxBackwardCommonKernel(
|
|
|
|
rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
|
|
|
|
if (!result)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op,
|
|
|
|
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
|
|
|
|
rewriter.replaceOp(op, result);
|
2021-11-08 23:56:40 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-11-09 20:25:04 +08:00
|
|
|
// AtenTanhBackwardOp(gradOutput, output) =>
|
|
|
|
// result = gradOutput * (1 - output^2)
|
|
|
|
// To get away from broadcasts the above formula is expanded i.e.,
|
|
|
|
// result = gradOutput - (gradOutput * output^2)
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenTanhBackwardOp
|
|
|
|
: public OpRewritePattern<AtenTanhBackwardOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenTanhBackwardOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value gradOutput = op.getGradOutput();
|
2021-11-09 20:25:04 +08:00
|
|
|
|
|
|
|
// `output` is the value flowing out from tanh. Hence, tanh(x) = output.
|
2021-11-19 20:18:41 +08:00
|
|
|
// Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
|
2022-12-08 04:20:41 +08:00
|
|
|
Value output = op.getOutput();
|
2021-11-09 20:25:04 +08:00
|
|
|
|
|
|
|
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
|
|
Value tanhSquare =
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, tensorType, output, output);
|
|
|
|
Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
|
|
|
|
loc, tensorType, tanhSquare, gradOutput);
|
|
|
|
|
2021-11-19 20:18:41 +08:00
|
|
|
Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
|
2022-02-15 21:14:32 +08:00
|
|
|
gradMulTanhSquare);
|
2021-11-09 20:25:04 +08:00
|
|
|
rewriter.replaceOp(op, newGrad);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-11-19 02:02:20 +08:00
|
|
|
// Aten_LogSoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
|
|
|
// result = gradOutput - (exp(output) * sum(gradOutput, dim))
|
|
|
|
namespace {
|
|
|
|
class DecomposeAten_LogSoftmaxBackwardDataOp
|
|
|
|
: public OpRewritePattern<Aten_LogSoftmaxBackwardDataOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value gradOutput = op.getGradOutput();
|
|
|
|
Value output = op.getOutput();
|
|
|
|
Value dim = op.getDim();
|
2021-11-19 02:02:20 +08:00
|
|
|
|
|
|
|
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
|
|
Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
|
2021-11-19 20:18:41 +08:00
|
|
|
Value result = createSoftmaxBackwardCommonKernel(
|
|
|
|
rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
|
|
|
|
if (!result)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op,
|
|
|
|
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
|
|
|
|
rewriter.replaceOp(op, result);
|
2021-11-19 02:02:20 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-01-25 16:53:55 +08:00
|
|
|
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenArgmaxOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
|
|
|
Value dim = op.getDim();
|
|
|
|
Value keepDim = op.getKeepdim();
|
|
|
|
Value result = op.getResult();
|
2022-01-25 16:53:55 +08:00
|
|
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeInputRank) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected input tensor to have a rank");
|
|
|
|
}
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-01-25 16:53:55 +08:00
|
|
|
if (!indicesTensorType.hasSizes())
|
|
|
|
return failure();
|
|
|
|
BaseTensorType valueTensorType =
|
|
|
|
inputType
|
2023-01-04 06:19:18 +08:00
|
|
|
.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
|
|
|
|
inputType.getOptionalDtype())
|
2022-01-25 16:53:55 +08:00
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
|
|
|
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
|
|
|
|
// `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input
|
|
|
|
// tensor is flattened to 1d tensor and then the reduction happens on the
|
|
|
|
// 0th dimension.
|
|
|
|
if (dim.getType().isa<Torch::NoneType>()) {
|
|
|
|
BaseTensorType flattenType =
|
2023-01-04 06:19:18 +08:00
|
|
|
inputType
|
|
|
|
.getWithSizesAndDtype({kUnknownSize},
|
|
|
|
inputType.getOptionalDtype())
|
2022-01-25 16:53:55 +08:00
|
|
|
.cast<BaseTensorType>();
|
|
|
|
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value end = rewriter.create<ConstantIntOp>(
|
2022-12-13 00:56:28 +08:00
|
|
|
loc, rewriter.getI64IntegerAttr(inputRank - 1));
|
2022-01-25 16:53:55 +08:00
|
|
|
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
|
|
|
|
dim, end);
|
|
|
|
}
|
|
|
|
Value maxResult =
|
|
|
|
rewriter
|
|
|
|
.create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
|
|
|
|
input, dim, keepDim)
|
2022-12-08 04:20:41 +08:00
|
|
|
.getIndices();
|
2022-01-25 16:53:55 +08:00
|
|
|
|
|
|
|
rewriter.replaceOp(op, maxResult);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-11 16:39:34 +08:00
|
|
|
// To avoid overflow we use the following decomposition rule:
|
|
|
|
// x_max = aten.max(x, dim, keepdim=True)[0]
|
|
|
|
// shifted = x - x_max
|
|
|
|
// shifted_logsumexp = aten.log(aten.sum(aten.exp(shifted), dim, keepdim=True))
|
|
|
|
// log_softmax = shifted - shifted_logsumexp
|
|
|
|
template <typename OpTy>
|
|
|
|
static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value dim = op.getDim();
|
|
|
|
Value self = op.getSelf();
|
2022-02-11 16:39:34 +08:00
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
Value xMax =
|
|
|
|
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
|
|
|
|
if (!xMax)
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
Value shifted = createTensorSub(rewriter, loc, tensorType, self, xMax);
|
|
|
|
Value shiftedExp = rewriter.create<AtenExpOp>(loc, tensorType, shifted);
|
|
|
|
Value shiftedSumExp =
|
|
|
|
createSumAlongDimension(rewriter, loc, op, shiftedExp, dim,
|
|
|
|
/*keepDim=*/true);
|
|
|
|
if (!shiftedSumExp)
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
Value shiftedLogSumExp =
|
|
|
|
rewriter.create<AtenLogOp>(loc, shiftedSumExp.getType(), shiftedSumExp);
|
|
|
|
Value result =
|
|
|
|
createTensorSub(rewriter, loc, op.getType(), shifted, shiftedLogSumExp);
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
2021-11-03 01:06:04 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenLogSoftmaxIntOp
|
|
|
|
: public OpRewritePattern<AtenLogSoftmaxIntOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
|
|
|
if (!op.getDtype().getType().isa<Torch::NoneType>())
|
2021-11-03 01:06:04 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Unimplemented non-None dtype for log_softmax");
|
|
|
|
|
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
2022-02-11 16:39:34 +08:00
|
|
|
Value logSoftmax = getLogSoftmaxResult(op, rewriter);
|
|
|
|
if (!logSoftmax)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "getLogSoftmaxResult function returned nullptr");
|
|
|
|
rewriter.replaceOp(op, logSoftmax);
|
2021-11-03 01:06:04 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-10 15:05:23 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Aten_LogSoftmaxOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-02-11 16:39:34 +08:00
|
|
|
bool halfToFloat;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat)))
|
2022-02-11 16:39:34 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Expected a boolean value for half_to_float");
|
|
|
|
|
|
|
|
// Currently, setting `halfToFloat` is not supported as the E2E testing for
|
|
|
|
// the same is not present on CPU.
|
|
|
|
if (halfToFloat)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "halfToFloat is currently not supported.");
|
|
|
|
Value _logSoftmax = getLogSoftmaxResult(op, rewriter);
|
|
|
|
if (!_logSoftmax)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "getLogSoftmaxResult function returned nullptr");
|
|
|
|
rewriter.replaceOp(op, _logSoftmax);
|
2022-02-10 15:05:23 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks.
|
2021-10-21 13:15:10 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenMatmulOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = op.getSelf();
|
|
|
|
Value rhs = op.getOther();
|
2021-10-21 13:15:10 +08:00
|
|
|
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeLhsRank = getTensorRank(lhs);
|
|
|
|
std::optional<unsigned> maybeRhsRank = getTensorRank(rhs);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeLhsRank || !maybeRhsRank) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected input tensors to have a rank");
|
|
|
|
}
|
|
|
|
unsigned lhsRank = *maybeLhsRank;
|
|
|
|
unsigned rhsRank = *maybeRhsRank;
|
2021-10-21 13:15:10 +08:00
|
|
|
|
2022-12-09 01:26:38 +08:00
|
|
|
if (lhsRank == 2 && rhsRank == 2) {
|
|
|
|
// If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
|
2021-10-21 13:15:10 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
|
2022-12-09 01:26:38 +08:00
|
|
|
} else if (lhsRank == 3 && rhsRank == 3) {
|
|
|
|
// If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
|
2021-10-21 13:15:10 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
|
2022-12-09 01:26:38 +08:00
|
|
|
} else {
|
|
|
|
return failure();
|
|
|
|
}
|
2021-10-21 13:15:10 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-10-04 02:32:17 +08:00
|
|
|
// Decompose aten.mv into: aten.matmul.
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenMvOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = op.getSelf();
|
|
|
|
Value rhs = op.getVec();
|
2022-10-04 02:32:17 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), lhs, rhs);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
|
|
|
|
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
|
|
|
Value input) {
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
2022-02-09 04:57:23 +08:00
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
|
2022-02-09 04:57:23 +08:00
|
|
|
Value cst6 =
|
2022-02-15 21:14:32 +08:00
|
|
|
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(6));
|
2022-02-09 04:57:23 +08:00
|
|
|
Value sixTensor = createRank0Tensor(rewriter, loc, inputType, cst6);
|
2022-02-15 21:14:32 +08:00
|
|
|
Value relu6Out =
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, inputType, relu, sixTensor);
|
|
|
|
return relu6Out;
|
|
|
|
}
|
|
|
|
|
2022-09-23 20:39:15 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenRelu6Op : public OpRewritePattern<AtenRelu6Op> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenRelu6Op op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value relu6 = getRelu6Results(rewriter, loc, op.getSelf());
|
2022-09-23 20:39:15 +08:00
|
|
|
rewriter.replaceOp(op, relu6);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
// Hardswish(x) = x * Relu6(x+3)/6
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenHardswishOp : public OpRewritePattern<AtenHardswishOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenHardswishOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2022-02-15 21:14:32 +08:00
|
|
|
Type inputType = input.getType();
|
|
|
|
|
|
|
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
Value constantThree = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(3));
|
|
|
|
Value constantSix = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(6));
|
|
|
|
Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
|
|
|
|
loc, inputType, input, constantThree, /*alpha=*/constantOne);
|
|
|
|
Value relu6 = getRelu6Results(rewriter, loc, inputPlusThree);
|
|
|
|
Value divTensor =
|
|
|
|
rewriter.create<AtenDivScalarOp>(loc, inputType, relu6, constantSix);
|
|
|
|
Value mulTensor =
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, inputType, divTensor, input);
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, mulTensor);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2023-01-04 00:30:16 +08:00
|
|
|
// LeakyRelu = max(0,x) + negative_slope * min(0,x)
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenLeakyReluOp : public OpRewritePattern<AtenLeakyReluOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenLeakyReluOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value input = op.getSelf();
|
|
|
|
Value negativeSlope = op.getNegativeSlope();
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
|
|
|
Value constantZero =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value constantOne =
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
|
|
|
|
Value positiveOutput =
|
|
|
|
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
|
|
|
|
Value negativeOutput =
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
|
|
|
|
Value scaledNegativeOutput = rewriter.create<AtenMulScalarOp>(
|
|
|
|
loc, resType, negativeOutput, negativeSlope);
|
|
|
|
Value leakyReluOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
loc, resType, positiveOutput, scaledNegativeOutput, constantOne);
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, leakyReluOutput);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// LeakyReluBackward = max(0,grad) + negative_slope * min(0,x)
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenLeakyReluBackwardOp
|
|
|
|
: public OpRewritePattern<AtenLeakyReluBackwardOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenLeakyReluBackwardOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value gradOutput = op.getGradOutput();
|
|
|
|
Value input = op.getSelf();
|
|
|
|
Value negativeSlope = op.getNegativeSlope();
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
|
|
|
bool selfIsResult = false;
|
|
|
|
if (!matchPattern(op.getSelfIsResult(),
|
|
|
|
m_TorchConstantBool(&selfIsResult)) ||
|
|
|
|
selfIsResult)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: self_is_result should be false");
|
|
|
|
|
|
|
|
Value constantZero =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value constantOne =
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
|
|
|
|
Value positiveOutput =
|
|
|
|
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, gradOutput);
|
|
|
|
Value negativeOutput =
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
|
|
|
|
Value scaledNegativeOutput = rewriter.create<AtenMulScalarOp>(
|
|
|
|
loc, resType, negativeOutput, negativeSlope);
|
|
|
|
Value leakyReluBackwardOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
loc, resType, positiveOutput, scaledNegativeOutput, constantOne);
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, leakyReluBackwardOutput);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-12-17 12:08:07 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenTOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> lhsRank = getTensorRank(lhs);
|
2021-12-17 12:08:07 +08:00
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!lhsRank) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
} else if (*lhsRank > 2) {
|
2021-12-17 12:08:07 +08:00
|
|
|
std::string errorMessage =
|
|
|
|
"t() expects a tensor with <=2 dimensions, but self is " +
|
2022-12-13 00:56:28 +08:00
|
|
|
std::to_string(*lhsRank) + "D";
|
2021-12-17 12:08:07 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, errorMessage.c_str());
|
2022-12-13 00:56:28 +08:00
|
|
|
} else if (*lhsRank < 2)
|
2021-12-17 12:08:07 +08:00
|
|
|
rewriter.replaceOp(op, lhs);
|
|
|
|
else {
|
|
|
|
Value zero =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value one =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
rewriter.replaceOpWithNewOp<AtenTransposeIntOp>(op, op.getType(), lhs,
|
|
|
|
zero, one);
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-08-24 08:36:05 +08:00
|
|
|
// Decompose aten.roll into aten.slice and aten.cat ops.
|
|
|
|
// https://pytorch.org/docs/stable/generated/torch.roll.html
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenRollOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
SmallVector<Value> shifts;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getShifts(), shifts))
|
2022-08-24 08:36:05 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: shifts not list of Scalar");
|
|
|
|
SmallVector<Value> dims;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getDims(), dims))
|
2022-08-24 08:36:05 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: dims not list of Scalar");
|
|
|
|
|
|
|
|
if (shifts.size() != dims.size())
|
|
|
|
return op.emitError("list sizes of shifts and dims are not the same");
|
|
|
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
2022-12-08 04:20:41 +08:00
|
|
|
auto self = op.getSelf();
|
2022-08-24 08:36:05 +08:00
|
|
|
auto selfTy = self.getType().cast<BaseTensorType>();
|
|
|
|
// roll(input, shift, dim) = cat({
|
|
|
|
// slice(input, dim, -shift, none),
|
|
|
|
// slice(input, dim, 0, -shift)}, dim)
|
|
|
|
auto imitateRoll = [&](Value input, Value shift, Value dim,
|
|
|
|
int64_t cstDim) {
|
|
|
|
Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
|
|
|
|
ArrayRef<int64_t> inputShape = selfTy.getSizes();
|
|
|
|
SmallVector<int64_t> sizes;
|
|
|
|
sizes.append(inputShape.begin(), inputShape.end());
|
2022-11-29 20:33:31 +08:00
|
|
|
sizes[cstDim] = kUnknownSize;
|
2022-08-24 08:36:05 +08:00
|
|
|
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
|
2023-01-04 06:19:18 +08:00
|
|
|
selfTy.getOptionalDtype());
|
2022-08-24 08:36:05 +08:00
|
|
|
Value slice0 = rewriter.create<AtenSliceTensorOp>(
|
|
|
|
loc, sliceTy, input, dim, negShift, constNone, constOne);
|
|
|
|
Value slice1 = rewriter.create<AtenSliceTensorOp>(
|
|
|
|
loc, sliceTy, input, dim, constZero, negShift, constOne);
|
|
|
|
|
|
|
|
Type listType = Torch::ListType::get(sliceTy);
|
|
|
|
Value slices = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
|
|
|
|
return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
|
|
|
|
};
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeRank)
|
2022-08-24 08:36:05 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
|
2022-12-13 00:56:28 +08:00
|
|
|
unsigned rank = *maybeRank;
|
2022-08-24 08:36:05 +08:00
|
|
|
Value output = self;
|
|
|
|
auto nShifts = shifts.size();
|
|
|
|
for (size_t k = 0; k < nShifts; ++k) {
|
|
|
|
auto dim = dims[k];
|
|
|
|
int64_t cstDim = -1;
|
|
|
|
if (!matchPattern(dim, m_TorchConstantInt(&cstDim)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: dim must be constant");
|
|
|
|
|
|
|
|
cstDim = toPositiveDim(cstDim, rank);
|
|
|
|
output = imitateRoll(output, shifts[k], dim, cstDim);
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, output);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-07-01 13:02:31 +08:00
|
|
|
// Decompose aten.repeat into aten.expand and aten.view ops.
|
|
|
|
//
|
|
|
|
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
|
|
|
|
//
|
|
|
|
// For shape [S1, S2, S3] and repeats [M0, M1, M2, M3]
|
|
|
|
// MS0 = M0; MS1 = M1 * S1; MS2 = M2 * S2; MS3 = M3 * S3
|
|
|
|
//
|
|
|
|
// def aten_repeat(self, repeats):
|
|
|
|
// sizes = self.size()
|
|
|
|
// unsqueezed_sizes = []
|
|
|
|
// expanded_sizes = []
|
|
|
|
// reshape_sizes = []
|
|
|
|
// leading_rank = repeats.size() - sizes.size()
|
|
|
|
// for r in range(leading_rank):
|
|
|
|
// unsqueezed_sizes.append(1)
|
|
|
|
// expanded_sizes.append(repeats[r])
|
|
|
|
// reshaped_sizes.append(repeats[r])
|
|
|
|
//
|
|
|
|
// for s, m in zip(sizes, repeats[leading_rank:]):
|
|
|
|
// unsqueezed_sizes += [1, s]
|
|
|
|
// expanded_sizes += [m, s]
|
|
|
|
// reshaped_sizes += [m * s]
|
2022-07-22 20:42:14 +08:00
|
|
|
// return
|
|
|
|
// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
|
2022-07-01 13:02:31 +08:00
|
|
|
//
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenRepeatOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2022-07-01 13:02:31 +08:00
|
|
|
MLIRContext *context = op.getContext();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeRank)
|
2022-07-01 13:02:31 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
|
2022-12-13 00:56:28 +08:00
|
|
|
unsigned rank = *maybeRank;
|
2022-07-01 13:02:31 +08:00
|
|
|
|
|
|
|
SmallVector<Value> repeats;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getRepeats(), repeats))
|
2022-07-01 13:02:31 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Unimplemented: repeats not list of Scalar");
|
|
|
|
|
2022-12-13 00:56:28 +08:00
|
|
|
if (rank > repeats.size()) {
|
2022-07-01 13:02:31 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "repeats are not matched with self's rank");
|
|
|
|
}
|
|
|
|
|
|
|
|
auto insertDimSizes = [](SmallVector<Value> &dimSizes,
|
|
|
|
SmallVector<int64_t> &shape,
|
|
|
|
const ArrayRef<Value> &vals) {
|
|
|
|
dimSizes.insert(dimSizes.end(), vals.begin(), vals.end());
|
|
|
|
std::transform(vals.begin(), vals.end(), std::back_inserter(shape),
|
|
|
|
[&](Value val) -> int64_t {
|
|
|
|
int64_t cst_val;
|
|
|
|
if (matchPattern(val, m_TorchConstantInt(&cst_val))) {
|
|
|
|
return cst_val;
|
|
|
|
} else {
|
2022-11-29 20:33:31 +08:00
|
|
|
return kUnknownSize;
|
2022-07-01 13:02:31 +08:00
|
|
|
}
|
|
|
|
});
|
|
|
|
};
|
|
|
|
|
|
|
|
Value one = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
|
|
|
SmallVector<Value> unsqueezedSizes, expandedSizes, reshapedSizes;
|
|
|
|
SmallVector<int64_t> unsqueezedIntSizes, expandedIntSizes;
|
|
|
|
auto leadingRank = repeats.size() - rank;
|
|
|
|
assert(leadingRank >= 0 && "leadingRank should greater than 0");
|
|
|
|
for (size_t i = 0; i < leadingRank; ++i) {
|
|
|
|
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one});
|
2022-07-22 20:42:14 +08:00
|
|
|
insertDimSizes(expandedSizes, expandedIntSizes,
|
|
|
|
ArrayRef<Value>{repeats[i]});
|
2022-07-01 13:02:31 +08:00
|
|
|
reshapedSizes.push_back(repeats[i]);
|
|
|
|
}
|
|
|
|
|
|
|
|
auto selfType = self.getType().dyn_cast<BaseTensorType>();
|
|
|
|
auto selfShape = selfType.getSizes();
|
2022-12-13 00:56:28 +08:00
|
|
|
for (unsigned i = 0; i < rank; i++) {
|
2022-07-01 13:02:31 +08:00
|
|
|
auto scale = repeats[i + leadingRank];
|
|
|
|
Value dimSize;
|
2022-11-29 20:33:31 +08:00
|
|
|
if (selfShape[i] == kUnknownSize) {
|
2022-07-01 13:02:31 +08:00
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(i));
|
|
|
|
dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
|
|
|
|
} else {
|
|
|
|
dimSize = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(selfShape[i]));
|
|
|
|
}
|
|
|
|
|
2022-07-22 20:42:14 +08:00
|
|
|
insertDimSizes(unsqueezedSizes, unsqueezedIntSizes,
|
|
|
|
ArrayRef<Value>{one, dimSize});
|
|
|
|
insertDimSizes(expandedSizes, expandedIntSizes,
|
|
|
|
ArrayRef<Value>{scale, dimSize});
|
2022-07-01 13:02:31 +08:00
|
|
|
|
|
|
|
Value scaledSize = rewriter.create<AtenMulIntOp>(loc, dimSize, scale);
|
|
|
|
reshapedSizes.push_back(scaledSize);
|
|
|
|
}
|
|
|
|
|
2023-01-04 06:19:18 +08:00
|
|
|
Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
|
2022-07-22 20:42:14 +08:00
|
|
|
Type unsqueezedType = ValueTensorType::get(
|
|
|
|
context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
|
|
|
|
Type expandedType = ValueTensorType::get(
|
|
|
|
context, llvm::makeArrayRef(expandedIntSizes), dtype);
|
2022-07-01 13:02:31 +08:00
|
|
|
|
|
|
|
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
|
|
|
Value unsqueezedDims =
|
|
|
|
rewriter.create<PrimListConstructOp>(loc, listType, unsqueezedSizes);
|
|
|
|
Value expandedDims =
|
|
|
|
rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
|
|
|
|
Value reshapedDims =
|
|
|
|
rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
|
2022-12-08 04:20:41 +08:00
|
|
|
auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType, op.getSelf(),
|
2022-07-22 20:42:14 +08:00
|
|
|
unsqueezedDims);
|
2022-07-01 13:02:31 +08:00
|
|
|
auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
|
|
|
|
reshaped, expandedDims);
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), expanded,
|
|
|
|
reshapedDims);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-08-23 11:52:54 +08:00
|
|
|
// Decompose aten.flatten.using_ints into aten.view op.
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenFlattenUsingIntsOp
|
|
|
|
: public OpRewritePattern<AtenFlattenUsingIntsOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenFlattenUsingIntsOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2022-08-23 11:52:54 +08:00
|
|
|
MLIRContext *context = op.getContext();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeRank)
|
2022-08-23 11:52:54 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
2022-12-13 00:56:28 +08:00
|
|
|
unsigned rank = *maybeRank;
|
2022-08-23 11:52:54 +08:00
|
|
|
|
|
|
|
int64_t start, end;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getStartDim(), m_TorchConstantInt(&start)) ||
|
|
|
|
!matchPattern(op.getEndDim(), m_TorchConstantInt(&end))) {
|
2022-08-23 11:52:54 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: requires start and end dims to be constants");
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<Value, 4> newSizes;
|
|
|
|
if (rank == 0) {
|
|
|
|
Value one =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
newSizes.push_back(one);
|
|
|
|
} else {
|
|
|
|
start = toPositiveDim(start, rank);
|
|
|
|
end = toPositiveDim(end, rank);
|
|
|
|
|
|
|
|
if (start > end) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected end dim larger than start dim");
|
|
|
|
}
|
|
|
|
|
|
|
|
newSizes.reserve(rank - end + start);
|
2022-08-26 06:00:01 +08:00
|
|
|
for (int64_t k = 0; k < start; ++k) {
|
2022-08-23 11:52:54 +08:00
|
|
|
Value dim =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(k));
|
|
|
|
newSizes.push_back(
|
|
|
|
rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dim));
|
|
|
|
}
|
2022-11-24 21:02:59 +08:00
|
|
|
Value flattenDimSize =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
2022-08-23 11:52:54 +08:00
|
|
|
newSizes.push_back(flattenDimSize);
|
2022-08-26 06:00:01 +08:00
|
|
|
for (int64_t k = end + 1; k < rank; ++k) {
|
2022-08-23 11:52:54 +08:00
|
|
|
Value dim =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(k));
|
|
|
|
newSizes.push_back(
|
|
|
|
rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dim));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Value newSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, ListType::get(IntType::get(context)), newSizes);
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.getSelf(),
|
2022-08-23 11:52:54 +08:00
|
|
|
newSizeList);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
// Decompose aten.expand into aten.broadcast_to op.
|
2021-11-03 00:48:29 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenExpandOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
bool implicit = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getImplicit(), m_TorchConstantBool(&implicit)) ||
|
2021-11-03 00:48:29 +08:00
|
|
|
implicit) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: requires implicit to be false");
|
|
|
|
}
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
|
|
|
|
op.getSize());
|
2021-11-03 00:48:29 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-12 01:21:36 +08:00
|
|
|
// Decompose aten.where.Scalar into aten.where.self op.
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenWhereScalarOp : public OpRewritePattern<AtenWhereScalarOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenWhereScalarOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
|
|
|
|
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
|
2022-03-12 01:21:36 +08:00
|
|
|
selfTensor, otherTensor);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Decompose aten.where.ScalarOther into aten.where.self op.
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenWhereScalarOtherOp
|
|
|
|
: public OpRewritePattern<AtenWhereScalarOtherOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
|
|
|
|
op.getSelf(), otherTensor);
|
2022-03-12 01:21:36 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Decompose aten.where.ScalarSelf into aten.where.self op.
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenWhereScalarSelfOp
|
|
|
|
: public OpRewritePattern<AtenWhereScalarSelfOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
|
|
|
|
selfTensor, op.getOther());
|
2022-03-12 01:21:36 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-11-04 15:57:29 +08:00
|
|
|
// Decompose aten.convolution_overrideable to aten.convolution op.
|
2022-04-08 12:47:57 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenConvolutionOverrideableOp
|
|
|
|
: public OpRewritePattern<AtenConvolutionOverrideableOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenConvolutionOverrideableOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
|
|
|
|
op.getOutputPadding(), op.getGroups());
|
2022-04-08 12:47:57 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-08-22 11:17:36 +08:00
|
|
|
// Decompose aten._convolution-like to aten.convolution
|
2022-07-08 14:44:03 +08:00
|
|
|
namespace {
|
2022-11-04 15:57:29 +08:00
|
|
|
template <typename ConvolutionLikeOp>
|
2022-08-22 11:17:36 +08:00
|
|
|
class DecomposeAten_ConvolutionLikeOp
|
|
|
|
: public OpRewritePattern<ConvolutionLikeOp> {
|
2022-07-08 14:44:03 +08:00
|
|
|
public:
|
2022-08-22 11:17:36 +08:00
|
|
|
using OpRewritePattern<ConvolutionLikeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ConvolutionLikeOp op,
|
2022-07-08 14:44:03 +08:00
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
|
|
|
|
op.getOutputPadding(), op.getGroups());
|
2022-07-08 14:44:03 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-04-08 12:47:57 +08:00
|
|
|
// Decompose aten.conv2d to aten.convolution
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenConv2dOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
Value emptyList = rewriter.create<PrimListConstructOp>(
|
|
|
|
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
SmallVector<Value>());
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList,
|
|
|
|
op.getGroups());
|
2022-04-08 12:47:57 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-08-25 00:19:35 +08:00
|
|
|
// Decompose aten.conv_transpose2d to aten.convolution
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenConvTranspose2dOp
|
|
|
|
: public OpRewritePattern<AtenConvTranspose2dInputOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenConvTranspose2dInputOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(), /*transposed=*/cstTrue,
|
|
|
|
op.getOutputPadding(), op.getGroups());
|
2022-11-04 15:57:29 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Decompose aten.convolution_backward_overrideable to aten.convolution_backward
|
|
|
|
// op.
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenConvolutionBackwardOverrideableOp
|
|
|
|
: public OpRewritePattern<AtenConvolutionBackwardOverrideableOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenConvolutionBackwardOverrideableOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionBackwardOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getResultTypes(), op.getGradOutput(), op.getInput(), op.getWeight(),
|
|
|
|
none, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
|
|
|
|
op.getOutputPadding(), op.getGroups(), op.getOutputMask());
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenConvolutionBackwardOp
|
|
|
|
: public OpRewritePattern<AtenConvolutionBackwardOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenConvolutionBackwardOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
MLIRContext *context = op.getContext();
|
2022-12-13 00:56:28 +08:00
|
|
|
Value gradOutput = op.getGradOutput();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeGradRank = getTensorRank(gradOutput);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeGradRank) {
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"expected grad output to have a rank");
|
|
|
|
}
|
|
|
|
unsigned gradRank = *maybeGradRank;
|
2022-11-04 15:57:29 +08:00
|
|
|
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(2));
|
|
|
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
|
|
|
loc, rewriter.getBoolAttr(false));
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getInput();
|
|
|
|
Value weight = op.getWeight();
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
|
|
if (gradRank != 4)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only 2D convolutions supported.");
|
|
|
|
|
|
|
|
SmallVector<Value> padding;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getPadding(), padding))
|
2022-11-04 15:57:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "padding must be a list.");
|
|
|
|
|
|
|
|
SmallVector<Value> strides;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getStride(), strides))
|
2022-11-04 15:57:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "stride must be a list.");
|
|
|
|
for (Value stride : strides) {
|
|
|
|
Value cmp = rewriter.create<Torch::AtenEqIntOp>(loc, stride, cstOne);
|
|
|
|
rewriter.create<Torch::RuntimeAssertOp>(
|
|
|
|
loc, cmp, "unimplemented: only strides of 1 supported.");
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<Value> dilations;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getDilation(), dilations))
|
2022-11-04 15:57:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "dilation must be a list.");
|
|
|
|
for (Value dilation : dilations) {
|
|
|
|
Value cmp = rewriter.create<Torch::AtenEqIntOp>(loc, dilation, cstOne);
|
|
|
|
rewriter.create<Torch::RuntimeAssertOp>(
|
|
|
|
loc, cmp, "unimplemented: only dilations of 1 supported.");
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<bool> outMask;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getOutputMask(), m_TorchListOfConstantBools(outMask)))
|
2022-11-04 15:57:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only constant bool output_mask is supported.");
|
|
|
|
// Support for `False` values for output mask unimplemented.
|
|
|
|
if (!llvm::all_of(outMask, [](bool mask) { return mask; }))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only true values for output_mask supported.");
|
|
|
|
|
|
|
|
bool transposed;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
|
2022-11-04 15:57:29 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only constant transposed is supported.");
|
|
|
|
if (transposed)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: transposed convolutions are not supported.");
|
|
|
|
|
|
|
|
// Rotate weight.
|
|
|
|
SmallVector<Value> axes;
|
2022-12-13 00:56:28 +08:00
|
|
|
for (unsigned i = 2; i < gradRank; i++) {
|
2022-11-04 15:57:29 +08:00
|
|
|
axes.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
}
|
|
|
|
Value axesList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), axes);
|
|
|
|
weight = rewriter.create<Torch::AtenFlipOp>(loc, weight.getType(), weight,
|
|
|
|
axesList);
|
|
|
|
// Calculate padding for first convolution.
|
|
|
|
SmallVector<Value> gradInputPaddingValues;
|
2022-12-13 00:56:28 +08:00
|
|
|
for (unsigned i = 2; i < gradRank; i++) {
|
2022-11-04 15:57:29 +08:00
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(i));
|
|
|
|
Value outDim = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
|
|
|
|
|
|
|
|
// Calculate 1 + (weightDim // 2) * 2, which fixes issues with
|
|
|
|
// even-sized weight.
|
|
|
|
Value weightDim = rewriter.create<Torch::AtenSizeIntOp>(loc, weight, dim);
|
|
|
|
weightDim =
|
|
|
|
rewriter.create<Torch::AtenFloordivIntOp>(loc, weightDim, cstTwo);
|
|
|
|
weightDim = rewriter.create<Torch::AtenMulIntOp>(loc, weightDim, cstTwo);
|
|
|
|
weightDim = rewriter.create<Torch::AtenAddIntOp>(loc, weightDim, cstOne);
|
|
|
|
Value gradOutDim =
|
|
|
|
rewriter.create<Torch::AtenSizeIntOp>(loc, gradOutput, dim);
|
|
|
|
|
|
|
|
// Calculate (((outDim - 1) * stride) + weightDim - gradOutDim) // 2,
|
|
|
|
// the padding value for this dimension. Derived from the formula at
|
|
|
|
// https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
|
|
|
Value padVal = rewriter.create<Torch::AtenSubIntOp>(loc, outDim, cstOne);
|
|
|
|
padVal =
|
|
|
|
rewriter.create<Torch::AtenMulIntOp>(loc, padVal, strides[i - 2]);
|
|
|
|
padVal = rewriter.create<Torch::AtenAddIntOp>(loc, padVal, weightDim);
|
|
|
|
padVal = rewriter.create<Torch::AtenSubIntOp>(loc, padVal, gradOutDim);
|
|
|
|
padVal = rewriter.create<Torch::AtenFloordivIntOp>(loc, padVal, cstTwo);
|
|
|
|
|
|
|
|
gradInputPaddingValues.push_back(padVal);
|
|
|
|
}
|
|
|
|
Value gradInputPadding = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
loc, ListType::get(IntType::get(context)), gradInputPaddingValues);
|
|
|
|
Value weightTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
|
|
|
loc, weight.getType(), weight, cstZero, cstOne);
|
|
|
|
// Convolve grad_output with weight.
|
|
|
|
Value gradInput = rewriter.create<Torch::AtenConvolutionOp>(
|
|
|
|
loc, op.getResultTypes()[0], gradOutput, weightTransposed, cstNone,
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getStride(), gradInputPadding, op.getDilation(), op.getTransposed(),
|
|
|
|
op.getOutputPadding(), op.getGroups());
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
|
|
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
|
|
|
loc, gradOutput.getType(), gradOutput, cstZero, cstOne);
|
|
|
|
Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
|
|
|
loc, input.getType(), input, cstZero, cstOne);
|
|
|
|
// Convolve input with grad_output.
|
|
|
|
Value gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
|
|
|
loc, op.getResultTypes()[1], inputTransposed, gradOutputTransposed,
|
2022-12-08 04:20:41 +08:00
|
|
|
cstNone, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
|
|
|
|
op.getOutputPadding(), op.getGroups());
|
2022-11-04 15:57:29 +08:00
|
|
|
gradWeight = rewriter.create<Torch::AtenTransposeIntOp>(
|
|
|
|
loc, gradWeight.getType(), gradWeight, cstZero, cstOne);
|
|
|
|
|
|
|
|
SmallVector<Value> dimIntList{cstZero};
|
2022-12-13 00:56:28 +08:00
|
|
|
for (unsigned i = 2; i < gradRank; i++)
|
2022-11-04 15:57:29 +08:00
|
|
|
dimIntList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
Value gradIntList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
dimIntList);
|
|
|
|
// Sum grad_output along dim 1.
|
|
|
|
Value gradBias = rewriter.create<Torch::AtenSumDimIntListOp>(
|
|
|
|
loc, op.getResultTypes()[2], gradOutput, gradIntList, cstFalse,
|
|
|
|
cstNone);
|
2022-08-25 00:19:35 +08:00
|
|
|
|
2022-11-04 15:57:29 +08:00
|
|
|
rewriter.replaceOp(op, {gradInput, gradWeight, gradBias});
|
2022-08-25 00:19:35 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
|
2021-11-11 17:02:13 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenAddmmOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
|
|
|
Value mat1 = op.getMat1();
|
|
|
|
Value mat2 = op.getMat2();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> mat1Rank = getTensorRank(mat1);
|
|
|
|
std::optional<unsigned> mat2Rank = getTensorRank(mat2);
|
2021-11-11 17:02:13 +08:00
|
|
|
|
|
|
|
// The operands `mat1`, `mat2` to aten.addmm must be of rank 2.
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!mat1Rank || !mat2Rank || *mat1Rank != 2 || *mat2Rank != 2) {
|
2021-11-11 17:02:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected mat1, mat2 operands to aten.addmm to be rank 2");
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: Handle integer type operands.
|
2023-01-04 06:19:18 +08:00
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
2021-11-11 17:02:13 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: non-floating point dtype");
|
|
|
|
}
|
|
|
|
|
|
|
|
// matrix multiplication: matmul = mat1 @ mat2
|
|
|
|
Value matmul = rewriter.create<AtenMmOp>(loc, op.getType(), mat1, mat2);
|
|
|
|
// scaledInput = self * beta
|
|
|
|
Value scaledInput = rewriter.create<AtenMulScalarOp>(loc, input.getType(),
|
2022-12-08 04:20:41 +08:00
|
|
|
input, op.getBeta());
|
2021-11-11 17:02:13 +08:00
|
|
|
// result = scaledInput + alpha * matmul
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), scaledInput,
|
2022-12-08 04:20:41 +08:00
|
|
|
matmul, op.getAlpha());
|
2021-11-11 17:02:13 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
// Decompose aten.mean into: sum(x)/div(numTensorElements).
|
2021-11-19 23:59:29 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenMeanOp : public OpRewritePattern<AtenMeanOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenMeanOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
|
|
|
Value output = op.getResult();
|
2021-11-19 23:59:29 +08:00
|
|
|
BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
|
2022-12-08 01:51:37 +08:00
|
|
|
Value sum =
|
|
|
|
rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.getDtype());
|
2021-11-19 23:59:29 +08:00
|
|
|
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
|
2022-12-08 01:51:37 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputTensorType, sum,
|
|
|
|
numTensorElements);
|
2021-11-19 23:59:29 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-11 01:25:21 +08:00
|
|
|
// productDimSize = product(size(dim) for dim in dims)
|
|
|
|
// aten.mean(x, dims) = aten.sum(x, dims) / productDimSize.
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenMeanDimOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeInputRank) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
}
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value dimList = op.getDim();
|
|
|
|
Value keepDim = op.getKeepdim();
|
|
|
|
Value dtype = op.getDtype();
|
2022-12-08 01:51:37 +08:00
|
|
|
Type outputType = op.getType();
|
2022-03-11 01:25:21 +08:00
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
|
|
|
|
!isNoneOrFloatDtype(context, dtype)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only floating-point type is supported");
|
|
|
|
}
|
|
|
|
|
2022-08-03 00:08:06 +08:00
|
|
|
SmallVector<Value> dimListElements;
|
|
|
|
if (!getListConstructElements(dimList, dimListElements) &&
|
|
|
|
!dimList.getType().isa<Torch::NoneType>()) {
|
2022-03-11 01:25:21 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2022-08-03 00:08:06 +08:00
|
|
|
op, "expected `dim` to be `None` or constructed from list construct");
|
2022-03-11 01:25:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// Compute sum along dimensions specified in `dimList`.
|
|
|
|
Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
|
2022-12-08 01:51:37 +08:00
|
|
|
loc, outputType, input, dimList, keepDim, dtype);
|
2022-03-11 01:25:21 +08:00
|
|
|
|
|
|
|
// `productDimSize` is product of sizes of dimensions to be reduced.
|
2022-07-28 22:24:24 +08:00
|
|
|
Value productDimSize;
|
|
|
|
// Case: Reduce along all dims.
|
2022-08-03 00:08:06 +08:00
|
|
|
if (dimListElements.empty() && inputRank != 0) {
|
2022-07-28 22:24:24 +08:00
|
|
|
productDimSize = rewriter.create<AtenNumelOp>(loc, input);
|
|
|
|
} else {
|
|
|
|
productDimSize = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
2022-08-03 00:08:06 +08:00
|
|
|
for (Value dim : dimListElements) {
|
2022-07-28 22:24:24 +08:00
|
|
|
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
|
|
|
|
productDimSize =
|
|
|
|
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
|
|
|
|
}
|
2022-03-11 01:25:21 +08:00
|
|
|
}
|
2022-12-08 01:51:37 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
|
|
|
|
productDimSize);
|
2022-03-11 01:25:21 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenSquareOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2022-01-30 01:10:50 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), self, self);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-02 01:30:58 +08:00
|
|
|
// Silu(x) = sigmoid(x) * x
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenSiluOp : public OpRewritePattern<AtenSiluOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenSiluOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2022-03-02 01:30:58 +08:00
|
|
|
Value sigmoid =
|
|
|
|
rewriter.create<AtenSigmoidOp>(op.getLoc(), op.getType(), self);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), sigmoid,
|
|
|
|
self);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-17 00:34:03 +08:00
|
|
|
// pDash = 1.0 - p
|
|
|
|
// boolMask = aten.rand_like(input) < pDash
|
|
|
|
// dropout(input, p, train=True) = (boolMask * input) / pDash
|
|
|
|
// dropout(input, p, train=False) = input
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenDropoutOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getInput();
|
|
|
|
Value prob = op.getP();
|
2022-02-17 00:34:03 +08:00
|
|
|
bool train = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train)))
|
2022-02-17 00:34:03 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"train must be a boolean constant");
|
|
|
|
if (!train) {
|
|
|
|
rewriter.replaceOp(op, input);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only support floating type input for training mode");
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value floatOne =
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
|
|
|
|
Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
|
|
|
|
loc, inputType, input, oneMinusP, /*generator=*/noneVal);
|
|
|
|
Value maskedInput =
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, op.getType(), maskedInput,
|
|
|
|
oneMinusP);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-06-29 15:23:57 +08:00
|
|
|
// Decompose aten.var into: aten.var.dim op.
|
2022-01-30 01:10:50 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenVarOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeInputRank) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
}
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-01-30 01:10:50 +08:00
|
|
|
BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
|
2022-03-10 08:44:22 +08:00
|
|
|
if (!rank0FloatTensorTy.hasSizes() ||
|
|
|
|
rank0FloatTensorTy.getSizes().size() != 0) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected aten.var to have a rank 0 tensor type");
|
|
|
|
}
|
2022-01-30 01:10:50 +08:00
|
|
|
|
2022-06-29 15:23:57 +08:00
|
|
|
SmallVector<Value> dims;
|
|
|
|
for (unsigned i = 0; i < inputRank; i++)
|
|
|
|
dims.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dims);
|
2022-01-30 01:10:50 +08:00
|
|
|
|
2022-06-29 15:23:57 +08:00
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenVarDimOp>(op, rank0FloatTensorTy, self,
|
2022-12-08 04:20:41 +08:00
|
|
|
dimList, op.getUnbiased(),
|
2022-06-29 15:23:57 +08:00
|
|
|
/*keepdim=*/cstFalse);
|
2022-01-30 01:10:50 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Decompose aten.std to sqrt(var(x))
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenStdOp : public OpRewritePattern<AtenStdOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenStdOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2022-01-30 01:10:50 +08:00
|
|
|
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
|
|
|
if (!inputTensorTy.hasDtype() ||
|
|
|
|
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"Only aten.std support floating type");
|
|
|
|
}
|
|
|
|
Value var = rewriter.create<AtenVarOp>(op->getLoc(), op.getType(),
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getSelf(), op.getUnbiased());
|
2022-01-30 01:10:50 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), var);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-07-18 03:00:29 +08:00
|
|
|
// Softplus(x, beta, threshold) =
|
|
|
|
// x * beta > threshold ? x : log(1 + exp(x * beta)) / beta
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenSoftplusOp : public OpRewritePattern<AtenSoftplusOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenSoftplusOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2022-07-18 03:00:29 +08:00
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
2022-07-22 20:42:14 +08:00
|
|
|
Value inputTimesBeta =
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.create<AtenMulScalarOp>(loc, inputType, input, op.getBeta());
|
2022-07-18 03:00:29 +08:00
|
|
|
|
|
|
|
// out = log1p(exp(input * beta)) / beta
|
|
|
|
Value exp = rewriter.create<AtenExpOp>(loc, inputType, inputTimesBeta);
|
|
|
|
Value log1p = rewriter.create<AtenLog1pOp>(loc, inputType, exp);
|
2022-07-22 20:42:14 +08:00
|
|
|
Value out =
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.create<AtenDivScalarOp>(loc, inputType, log1p, op.getBeta());
|
2022-07-18 03:00:29 +08:00
|
|
|
|
|
|
|
// Select where x * beta > threshold
|
|
|
|
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
|
|
|
|
rewriter.getI1Type());
|
|
|
|
Value condition = rewriter.create<AtenGtScalarOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, boolResType, inputTimesBeta, op.getThreshold());
|
2022-07-18 03:00:29 +08:00
|
|
|
|
2022-07-22 20:42:14 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), condition,
|
|
|
|
input, out);
|
2022-07-18 03:00:29 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-07-27 04:02:01 +08:00
|
|
|
// Decompose aten.std.dim to sqrt(var.dim(x))
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenStdDimOp : public OpRewritePattern<AtenStdDimOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenStdDimOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2022-07-27 04:02:01 +08:00
|
|
|
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
if (!inputTensorType.hasDtype() ||
|
|
|
|
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "aten.std.dim expects input tensor of floating-point type");
|
|
|
|
}
|
|
|
|
|
|
|
|
Value varDim =
|
|
|
|
rewriter.create<AtenVarDimOp>(op->getLoc(), op.getType(), self,
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getDim(), op.getUnbiased(), op.getKeepdim());
|
2022-07-27 04:02:01 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varDim);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-12-22 13:02:40 +08:00
|
|
|
// Decompose aten.std.correction to sqrt(var.correction(x))
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenStdCorrectionOp
|
|
|
|
: public OpRewritePattern<AtenStdCorrectionOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenStdCorrectionOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value self = op.getSelf();
|
|
|
|
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
if (!inputTensorType.hasDtype() ||
|
|
|
|
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op,
|
|
|
|
"aten.std.correction expects input tensor of floating-point type");
|
|
|
|
}
|
|
|
|
|
|
|
|
Value varCorrection = rewriter.create<AtenVarCorrectionOp>(
|
|
|
|
op->getLoc(), op.getType(), self, op.getDim(), op.getCorrection(),
|
|
|
|
op.getKeepdim());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varCorrection);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-14 22:46:44 +08:00
|
|
|
// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenHardsigmoidOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2022-02-09 04:57:23 +08:00
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
2022-02-14 22:46:44 +08:00
|
|
|
|
|
|
|
// outputTensor = (input + 3) / 6.
|
|
|
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
Value constantThree = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(3));
|
|
|
|
Value constantSix = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(6));
|
|
|
|
Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
|
|
|
|
loc, inputType, input, constantThree, /*alpha=*/constantOne);
|
|
|
|
Value outputTensor = rewriter.create<AtenDivScalarOp>(
|
|
|
|
loc, inputType, inputPlusThree, constantSix);
|
|
|
|
|
|
|
|
// result = max(0, min(1, (input+3)/6))
|
2022-02-09 04:57:23 +08:00
|
|
|
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value oneTensor = createRank0Tensor(rewriter, loc, inputType, constantOne);
|
2022-02-14 22:46:44 +08:00
|
|
|
Value minResult =
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, inputType, oneTensor, outputTensor);
|
2022-02-09 04:57:23 +08:00
|
|
|
Value zeroTensor =
|
|
|
|
createRank0Tensor(rewriter, loc, inputType, constantZero);
|
2022-02-14 22:46:44 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), zeroTensor,
|
|
|
|
minResult);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-09 04:57:23 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenHardtanhOp : public OpRewritePattern<AtenHardtanhOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenHardtanhOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2022-02-09 04:57:23 +08:00
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
|
|
|
// result = min(maxVal, max(minVal, x))
|
2022-12-08 04:20:41 +08:00
|
|
|
Value minVal = createRank0Tensor(rewriter, loc, inputType, op.getMinVal());
|
2022-02-09 04:57:23 +08:00
|
|
|
Value maxResult =
|
|
|
|
rewriter.create<AtenMaximumOp>(loc, inputType, input, minVal);
|
2022-12-08 04:20:41 +08:00
|
|
|
Value maxVal = createRank0Tensor(rewriter, loc, inputType, op.getMaxVal());
|
2022-02-09 04:57:23 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenMinimumOp>(op, op.getType(), maxVal,
|
|
|
|
maxResult);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenRandLikeOp : public OpRewritePattern<AtenRandLikeOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenRandLikeOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2022-05-13 07:00:59 +08:00
|
|
|
Type resultType = op.getType();
|
2022-02-26 00:35:04 +08:00
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
2022-05-13 07:00:59 +08:00
|
|
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
2022-02-26 00:35:04 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support floating-point type");
|
|
|
|
}
|
|
|
|
|
2022-05-13 07:00:59 +08:00
|
|
|
// Create a uniform random op with low and high set to 0.0 and 1.0,
|
2022-02-26 00:35:04 +08:00
|
|
|
// respectively.
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
2022-05-13 07:00:59 +08:00
|
|
|
Value zero =
|
2022-02-26 00:35:04 +08:00
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
2022-05-13 07:00:59 +08:00
|
|
|
Value one =
|
2022-02-26 00:35:04 +08:00
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
2022-09-23 10:24:36 +08:00
|
|
|
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, resultType, input, zero, op.getDtype(), op.getLayout(), op.getDevice(),
|
|
|
|
op.getPinMemory(), op.getMemoryFormat());
|
2022-10-28 23:06:11 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor,
|
|
|
|
/*from=*/zero, /*to=*/one,
|
|
|
|
/*generator=*/none);
|
2022-02-26 00:35:04 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Bernoulli(x, p) = (randLike(float(x)) < p).cast(type(x)). Here,
|
2022-02-26 00:35:04 +08:00
|
|
|
// 1. p must be a float tensor.
|
|
|
|
// 2. The shape of p should be broadcastable to the shape of x.
|
|
|
|
// 3. Bernoulli(x, p) returns a tensor of the same type as that of x.
|
2022-02-09 04:57:23 +08:00
|
|
|
static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
|
|
|
|
Operation *op, Location loc,
|
2022-02-26 00:35:04 +08:00
|
|
|
Value input, Value prob,
|
|
|
|
Value &output) {
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
auto probType = prob.getType().cast<BaseTensorType>();
|
|
|
|
// Both the `input` and `prob` must be ranked tensors.
|
|
|
|
if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
|
|
|
|
!probType.hasDtype()) {
|
2022-02-09 04:57:23 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2022-02-26 00:35:04 +08:00
|
|
|
op, "can't decompose bernoulli like ops without sizes or dtype");
|
|
|
|
}
|
|
|
|
// The `prob` is expected to be a float type tensor.
|
|
|
|
if (!probType.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "probabilities must be a float type tensor");
|
2022-02-09 04:57:23 +08:00
|
|
|
}
|
2022-02-04 19:43:25 +08:00
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
// Since the `aten.randLike` op expects float-type operand, create a
|
2022-02-26 00:35:04 +08:00
|
|
|
// float-type tensor with the same shape as that of the `input`.
|
|
|
|
Value floatTensor =
|
|
|
|
convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value randomVal = rewriter.create<AtenRandLikeOp>(
|
|
|
|
loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none,
|
2022-12-08 04:20:41 +08:00
|
|
|
/*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none);
|
2022-02-26 00:35:04 +08:00
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
// Bernoulli(x, p) = randLike(float(x)) < p.
|
2022-02-26 00:35:04 +08:00
|
|
|
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
|
|
|
|
rewriter.getI1Type());
|
|
|
|
Value lessThanP =
|
|
|
|
rewriter.create<AtenLtTensorOp>(loc, boolResType, randomVal, prob);
|
|
|
|
|
|
|
|
// As the `output` is expected to be of the `input` type, convert the boolean
|
|
|
|
// tensor `lessThanP` to a `input` type tensor.
|
|
|
|
output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype());
|
2022-02-09 04:57:23 +08:00
|
|
|
return success();
|
2022-02-04 19:43:25 +08:00
|
|
|
}
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
// aten.bernoulli(x) = randLike(x) < x. Here, the input x is a tensor
|
2022-02-26 00:35:04 +08:00
|
|
|
// containing probabilities to be used for drawing the binary random number.
|
2022-02-04 19:43:25 +08:00
|
|
|
class DecomposeAtenBernoulliOp : public OpRewritePattern<AtenBernoulliOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenBernoulliOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
|
|
|
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
2022-02-04 19:43:25 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "The generator has to ben None because only global default "
|
|
|
|
"generator is supported");
|
2022-02-26 00:35:04 +08:00
|
|
|
Value output;
|
|
|
|
if (failed(
|
|
|
|
decomposeBernoulliLikeOp(rewriter, op, loc, input, input, output)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "decomposeBernoulliLikeOp failed to decompose the op");
|
|
|
|
rewriter.replaceOp(op, output);
|
2022-02-04 19:43:25 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
// aten.bernoulli.float(x, p) = (randLike(float(x)) < tensor(p)).cast(type(x)).
|
2022-02-26 00:35:04 +08:00
|
|
|
// Since the input x can be an integer tensor, it's important to cast it to
|
2022-12-08 04:20:41 +08:00
|
|
|
// float type before passing it to the `aten.randLike` op.
|
2022-03-16 07:57:33 +08:00
|
|
|
class DecomposeValsemVariantAtenBernoulliFloatOp
|
|
|
|
: public OpRewritePattern<ValsemVariantAtenBernoulliFloatOp> {
|
2022-02-04 19:43:25 +08:00
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
2022-03-16 07:57:33 +08:00
|
|
|
LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliFloatOp op,
|
2022-02-04 19:43:25 +08:00
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
|
|
|
Value p = op.getP();
|
|
|
|
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
2022-02-26 00:35:04 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "The generator has to ben None because only global default "
|
|
|
|
"generator is supported");
|
2022-02-04 19:43:25 +08:00
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
SmallVector<int64_t> empty;
|
|
|
|
Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
|
|
|
|
rewriter.getF64Type());
|
|
|
|
Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
|
|
|
|
Value output;
|
|
|
|
if (failed(
|
|
|
|
decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "decomposeBernoulliLikeOp failed to decompose the op");
|
|
|
|
rewriter.replaceOp(op, output);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
// aten.bernoulli.Tensor(x, p) = (randLike(float(x)) < p).cast(type(x)).
|
2022-02-26 00:35:04 +08:00
|
|
|
// Since the input x can be an integer tensor, it's important to cast it to
|
2022-12-08 04:20:41 +08:00
|
|
|
// float type before passing it to the `aten.randLike` op.
|
2022-10-28 23:06:11 +08:00
|
|
|
class DecomposeAtenBernoulliTensorOp
|
|
|
|
: public OpRewritePattern<AtenBernoulliTensorOp> {
|
2022-02-26 00:35:04 +08:00
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
2022-10-28 23:06:11 +08:00
|
|
|
LogicalResult matchAndRewrite(AtenBernoulliTensorOp op,
|
2022-02-26 00:35:04 +08:00
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
|
|
|
Value prob = op.getP();
|
|
|
|
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
2022-02-04 19:43:25 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "The generator has to ben None because only global default "
|
|
|
|
"generator is supported");
|
2022-02-26 00:35:04 +08:00
|
|
|
Value output;
|
|
|
|
if (failed(
|
|
|
|
decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "decomposeBernoulliLikeOp failed to decompose the op");
|
|
|
|
rewriter.replaceOp(op, output);
|
2022-02-04 19:43:25 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-11-25 06:01:48 +08:00
|
|
|
namespace {
|
2022-02-15 21:14:32 +08:00
|
|
|
template <typename OpTy, typename T1T2Op>
|
2021-11-25 06:01:48 +08:00
|
|
|
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
2022-02-15 21:14:32 +08:00
|
|
|
PatternRewriter &rewriter) const override {
|
2021-11-25 06:01:48 +08:00
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
|
|
|
Value tensor1 = op.getTensor1();
|
|
|
|
Value tensor2 = op.getTensor2();
|
|
|
|
Value value = op.getValue();
|
2021-11-25 06:01:48 +08:00
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
Value product =
|
|
|
|
rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input,
|
|
|
|
product, value);
|
2021-11-25 06:01:48 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
2021-12-10 21:36:19 +08:00
|
|
|
|
|
|
|
class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
|
|
|
|
using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenLayerNormOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
auto input = op.getInput().getType().cast<BaseTensorType>();
|
2021-12-10 21:36:19 +08:00
|
|
|
if (!input.hasSizes())
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "input tensor should have known sizes.");
|
|
|
|
int64_t inputRank = input.getSizes().size();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value normalizedShape = op.getNormalizedShape();
|
2021-12-10 21:36:19 +08:00
|
|
|
SmallVector<Value> normalizedShapeSizesTorchInt;
|
|
|
|
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
|
2022-03-16 20:51:57 +08:00
|
|
|
int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
|
|
|
|
std::vector<int64_t> meanVarSizes(inputRank, 1);
|
|
|
|
for (int i = 0; i < axis; i++)
|
|
|
|
meanVarSizes[i] = input.getSizes()[i];
|
2021-12-10 21:36:19 +08:00
|
|
|
auto meanVarType = input.getWithSizesAndDtype(
|
2023-01-04 06:19:18 +08:00
|
|
|
llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
|
2021-12-10 21:36:19 +08:00
|
|
|
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, op.getType(), meanVarType, meanVarType, op.getInput(),
|
|
|
|
op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
|
2021-12-10 21:36:19 +08:00
|
|
|
rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
2021-11-25 06:01:48 +08:00
|
|
|
} // namespace
|
|
|
|
|
2022-09-02 09:29:22 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenNativeLayerNormOp
|
|
|
|
: public OpRewritePattern<AtenNativeLayerNormOp> {
|
|
|
|
using OpRewritePattern<AtenNativeLayerNormOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenNativeLayerNormOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
auto context = op.getContext();
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
auto inputTy = op.getInput().getType().cast<BaseTensorType>();
|
2022-09-02 09:29:22 +08:00
|
|
|
if (!inputTy.hasSizes())
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "input tensor should have known sizes.");
|
|
|
|
int64_t inputRank = inputTy.getSizes().size();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value normalizedShape = op.getNormalizedShape();
|
2022-09-02 09:29:22 +08:00
|
|
|
SmallVector<Value> normalizedShapeSizesTorchInt;
|
|
|
|
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
|
|
|
|
int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
|
|
|
|
auto reduceDimInts = llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
|
|
|
|
auto reducedTy = op.getResult(1).getType();
|
|
|
|
auto sizeListType = ListType::get(IntType::get(context));
|
|
|
|
|
|
|
|
// build reduce dims
|
|
|
|
SmallVector<Value> reduceDimVals;
|
|
|
|
reduceDimVals.reserve(reduceDimInts.size());
|
|
|
|
std::transform(reduceDimInts.begin(), reduceDimInts.end(),
|
|
|
|
std::back_inserter(reduceDimVals), [&](int64_t d) {
|
|
|
|
return rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(d));
|
|
|
|
});
|
|
|
|
Value reduceDimList =
|
|
|
|
rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
|
|
|
|
Value one = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
|
|
|
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
// mean(x)
|
|
|
|
Value inputMean = rewriter.create<AtenMeanDimOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none);
|
2022-09-02 09:29:22 +08:00
|
|
|
|
|
|
|
// x - mean(x)
|
|
|
|
Value inputMeanExpanded =
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.getInput());
|
2022-09-02 09:29:22 +08:00
|
|
|
Value inputZeroMean = rewriter.create<AtenSubTensorOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, inputTy, op.getInput(), inputMeanExpanded, one);
|
2022-09-02 09:29:22 +08:00
|
|
|
// var(x) = mean((x - mean(x))^2)
|
|
|
|
Value inputZeroMeanSquare = rewriter.create<AtenMulTensorOp>(
|
|
|
|
loc, inputTy, inputZeroMean, inputZeroMean);
|
|
|
|
Value inputVar = rewriter.create<AtenMeanDimOp>(
|
|
|
|
loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none);
|
|
|
|
|
|
|
|
// rsqrt(var(x) + eps)
|
|
|
|
Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, reducedTy, inputVar, op.getEps(), one);
|
2022-09-02 09:29:22 +08:00
|
|
|
Value inputRsqrtVar =
|
|
|
|
rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);
|
|
|
|
|
|
|
|
// (x - mean(x)) * rsqrt(var(x) + eps)
|
|
|
|
Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, inputTy, inputRsqrtVar, op.getInput());
|
2022-09-02 09:29:22 +08:00
|
|
|
Value inputNormalized = rewriter.create<AtenMulTensorOp>(
|
|
|
|
loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
|
|
|
|
Value out = rewriter.create<TensorStaticInfoCastOp>(
|
|
|
|
loc, op.getResult(0).getType(), inputNormalized);
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value weight = op.getWeight();
|
|
|
|
Value bias = op.getBias();
|
2022-09-02 09:29:22 +08:00
|
|
|
if (!weight.getType().isa<Torch::NoneType>()) {
|
|
|
|
out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
|
|
|
|
}
|
|
|
|
if (!bias.getType().isa<Torch::NoneType>()) {
|
|
|
|
out =
|
|
|
|
rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar});
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-12-14 03:01:10 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.emptyLike` op into `aten.size` and `aten.empty` ops.
|
2021-12-14 03:01:10 +08:00
|
|
|
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenEmptyLikeOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
auto sizeListType =
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
|
|
|
Value sizeList =
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
|
2021-12-14 03:01:10 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), sizeList, op.getDtype(), op.getLayout(), op.getDevice(),
|
|
|
|
op.getPinMemory(), op.getMemoryFormat());
|
2021-12-14 03:01:10 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-12-23 21:22:45 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// The `aten.arange` op is converted to `aten.arange.startStep` op.
|
2021-12-23 21:22:45 +08:00
|
|
|
class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenArangeOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
// The AtenArangeOp doesn't have a start and step value. Therefore we set
|
|
|
|
// them as default values 0 and 1, respectively.
|
|
|
|
Value start, step;
|
|
|
|
start = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
step = rewriter.create<Torch::ConstantIntOp>(loc,
|
|
|
|
rewriter.getI64IntegerAttr(1));
|
|
|
|
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), start, op.getEnd(), step, op.getDtype(), op.getLayout(),
|
|
|
|
op.getDevice(), op.getPinMemory());
|
2021-12-23 21:22:45 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// The `aten.arange.start` op is converted to `aten.arange.startStep` op.
|
2021-12-23 21:22:45 +08:00
|
|
|
class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenArangeStartOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
// The AtenArangeStartOp doesn't have a step value. Therefore we set it as
|
|
|
|
// default value 1.
|
|
|
|
Value step;
|
|
|
|
step = rewriter.create<Torch::ConstantIntOp>(loc,
|
|
|
|
rewriter.getI64IntegerAttr(1));
|
|
|
|
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(), op.getLayout(),
|
|
|
|
op.getDevice(), op.getPinMemory());
|
2021-12-23 21:22:45 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-12-21 19:51:19 +08:00
|
|
|
namespace {
|
2022-09-23 10:24:36 +08:00
|
|
|
// Decompose constant tensor full like ops.
|
2021-12-21 19:51:19 +08:00
|
|
|
template <typename OpTy, int fillVal>
|
|
|
|
class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value constVal = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(fillVal));
|
2022-09-23 10:24:36 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenFullLikeOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), op.getSelf(), constVal, op.getDtype(), op.getLayout(),
|
|
|
|
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
|
2021-12-21 19:51:19 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-08 00:08:10 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenNativeBatchNormOp
|
|
|
|
: public OpRewritePattern<AtenNativeBatchNormOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenNativeBatchNormOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
MLIRContext *context = op.getContext();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getInput();
|
|
|
|
Value weight = op.getWeight();
|
|
|
|
Value bias = op.getBias();
|
|
|
|
Value runningMean = op.getRunningMean();
|
|
|
|
Value runningVar = op.getRunningVar();
|
|
|
|
Value eps = op.getEps();
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
// TODO: Add support for `training` mode.
|
|
|
|
bool training = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training)) ||
|
2022-02-08 00:08:10 +08:00
|
|
|
training)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: training mode is not supported");
|
|
|
|
|
|
|
|
// Rank of the input tensor must be greater than or equal to 2. The shape of
|
|
|
|
// the `input` is supposed to be (N, C, D?, H?, W?).
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeInputRank || *maybeInputRank < 2)
|
2022-02-08 00:08:10 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "input must have rank greater than or equal to 2");
|
2022-12-13 00:56:28 +08:00
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
// In the inference mode, the `runningMean` and `runningVar` must not be
|
|
|
|
// None.
|
|
|
|
if (runningMean.getType().isa<Torch::NoneType>() ||
|
|
|
|
runningVar.getType().isa<Torch::NoneType>())
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "running stats must not be None in inference mode");
|
|
|
|
|
|
|
|
// Rank of `runningMean` and `runningVar` must be exactly 1.
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> runningMeanRank = getTensorRank(runningMean);
|
|
|
|
std::optional<unsigned> runningVarRank = getTensorRank(runningVar);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!runningMeanRank || !runningVarRank || *runningMeanRank != 1 ||
|
|
|
|
*runningVarRank != 1)
|
2022-02-08 00:08:10 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, "expected runningMean and runningVar to be rank 1");
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
Value zero =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value one =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
Value numFeatures = rewriter.create<AtenSizeIntOp>(loc, input, /*dim=*/one);
|
2022-02-25 03:41:55 +08:00
|
|
|
// TODO: Add Runtime Asserts to check the shape of weight, bias,
|
2022-12-08 04:20:41 +08:00
|
|
|
// runningMean and runningVar to be (numFeatures).
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
// The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?)
|
|
|
|
// to make it broadcast-compatible with (N, C, D?, H?, W?).
|
|
|
|
// 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?)
|
|
|
|
// 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?)
|
|
|
|
SmallVector<Value> runningStatsShape(inputRank, one);
|
|
|
|
runningStatsShape[1] = numFeatures;
|
|
|
|
Value runningStatsSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, ListType::get(IntType::get(context)), runningStatsShape);
|
|
|
|
|
|
|
|
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
|
2022-11-29 20:33:31 +08:00
|
|
|
runningStatsShapeInt[1] = kUnknownSize;
|
2023-01-04 06:19:18 +08:00
|
|
|
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
|
2022-02-08 00:08:10 +08:00
|
|
|
Type reshapeType = ValueTensorType::get(
|
|
|
|
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
|
|
|
|
|
|
|
|
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
|
|
|
|
runningStatsSizeList);
|
|
|
|
runningVar = rewriter.create<AtenViewOp>(loc, reshapeType, runningVar,
|
|
|
|
runningStatsSizeList);
|
|
|
|
|
|
|
|
// normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
|
|
|
|
Value inputSubMean = rewriter.create<AtenSubTensorOp>(
|
|
|
|
loc, input.getType(), input, runningMean, /*alpha=*/one);
|
|
|
|
Value varEps = rewriter.create<AtenAddScalarOp>(
|
|
|
|
loc, runningVar.getType(), runningVar, eps, /*alpha=*/one);
|
|
|
|
Value invStd = rewriter.create<AtenRsqrtOp>(loc, varEps.getType(), varEps);
|
|
|
|
Value normalizedInput = rewriter.create<AtenMulTensorOp>(
|
|
|
|
loc, inputSubMean.getType(), inputSubMean, invStd);
|
|
|
|
|
|
|
|
// The `weight` and `bias` must be reshaped to (1, C, 1?, 1?, 1?) to make it
|
|
|
|
// broadcast-compatible with (N, C, D?, H?, W?).
|
|
|
|
// 1. weight = weight.view(1, C, 1?, 1?, 1?)
|
|
|
|
// 2. bias = bias.view(1, C, 1?, 1?, 1?)
|
|
|
|
// 3. output = normalizedInput * weight + bias
|
|
|
|
Value batchNormOutput = normalizedInput;
|
|
|
|
if (!weight.getType().isa<Torch::NoneType>()) {
|
2022-02-25 03:41:55 +08:00
|
|
|
// Rank of `weight` must be exactly 1.
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> weightRank = getTensorRank(weight);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!weightRank || *weightRank != 1)
|
2022-02-08 00:08:10 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
|
|
|
|
weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
|
|
|
|
runningStatsSizeList);
|
|
|
|
batchNormOutput = rewriter.create<AtenMulTensorOp>(
|
|
|
|
loc, batchNormOutput.getType(), batchNormOutput, weight);
|
|
|
|
}
|
|
|
|
if (!bias.getType().isa<Torch::NoneType>()) {
|
2022-02-25 03:41:55 +08:00
|
|
|
// Rank of `bias` must be exactly 1.
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> biasRank = getTensorRank(bias);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!biasRank || *biasRank != 1)
|
2022-02-08 00:08:10 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
|
|
|
|
bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
|
|
|
|
runningStatsSizeList);
|
|
|
|
batchNormOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one);
|
|
|
|
}
|
|
|
|
|
|
|
|
// The `mean` and `invstd` outputs are empty tensors in inference mode.
|
|
|
|
Value zeroList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(zero.getType()), zero);
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value emptyMeanTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
|
|
|
loc, op.getType(1), zeroList, /*dtype=*/none, /*layout=*/none,
|
2022-12-08 04:20:41 +08:00
|
|
|
/*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none);
|
2022-02-08 00:08:10 +08:00
|
|
|
Value emptyInvStdTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
|
|
|
loc, op.getType(2), zeroList, /*dtype=*/none, /*layout=*/none,
|
2022-12-08 04:20:41 +08:00
|
|
|
/*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none);
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
rewriter.replaceOp(op,
|
|
|
|
{batchNormOutput, emptyMeanTensor, emptyInvStdTensor});
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompse `Aten_UnsafeViewOp` into `AtenViewOp`. UnsafeView() differs from
|
2022-02-10 16:11:05 +08:00
|
|
|
// view() in that the returned tensor isn't treated as a view for the purposes
|
|
|
|
// of automatic differentiation. It's only safe to use if the `self` tensor is
|
|
|
|
// temporary. For example, the viewed tensor here (a + b) is discarded
|
|
|
|
// immediately after viewing:
|
|
|
|
//
|
2022-12-08 04:20:41 +08:00
|
|
|
// res = UnsafeView(a + b, size);
|
2022-02-10 16:11:05 +08:00
|
|
|
//
|
|
|
|
// This is a hack because in-place operations on tensors treated like views
|
|
|
|
// can be much more expensive than the same operations on non-view tensors.
|
|
|
|
|
|
|
|
// Refer to
|
|
|
|
// https://github.com/pytorch/pytorch/blob/364055b2771ecf9b54f1d67a8bf44bb5496476d4/aten/src/ATen/native/TensorShape.cpp#L2072
|
|
|
|
namespace {
|
|
|
|
class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Aten_UnsafeViewOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.getSelf(),
|
|
|
|
op.getSize());
|
2022-02-10 16:11:05 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
// In PyTorch, ReshapeAlias just uses an already computed stride.
|
2022-03-29 12:54:28 +08:00
|
|
|
// See
|
|
|
|
// https://github.com/pytorch/pytorch/blob/d8c31a819d4a65e732b5901e3b994e1869851f1a/aten/src/ATen/native/TensorShape.cpp#L1153
|
|
|
|
// Note that this is the same decomposition as in AOTAutograd
|
2022-12-08 04:20:41 +08:00
|
|
|
// https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/Src/aotAutograd.py#L104
|
2022-03-29 12:54:28 +08:00
|
|
|
namespace {
|
2022-05-13 20:06:24 +08:00
|
|
|
class DecomposeAten_ReshapeAliasOp
|
|
|
|
: public OpRewritePattern<Aten_ReshapeAliasOp> {
|
2022-03-29 12:54:28 +08:00
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.getSelf(),
|
|
|
|
op.getSize());
|
2022-03-29 12:54:28 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-02-28 14:14:40 +08:00
|
|
|
namespace {
|
|
|
|
// Decompose constant tensor like ops.
|
|
|
|
template <typename OpTy, typename NewOpTy>
|
|
|
|
class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value dtype = op.getDtype();
|
2022-03-25 00:40:21 +08:00
|
|
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
|
|
|
BaseTensorType tensorType =
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getSelf().getType().template cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
if (!tensorType.hasDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected input tensor to have a dtype");
|
|
|
|
}
|
2022-03-25 00:40:21 +08:00
|
|
|
dtype =
|
|
|
|
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
|
|
|
}
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.getSize(), dtype,
|
|
|
|
op.getLayout(), op.getDevice(),
|
|
|
|
op.getPinMemory());
|
2022-02-28 14:14:40 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-03 21:41:14 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.full` op into `aten.broadcastTo`
|
2022-03-03 21:41:14 +08:00
|
|
|
class DecomposeAtenFullOp : public OpRewritePattern<AtenFullOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenFullOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-09-23 10:24:36 +08:00
|
|
|
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
if (!outTy.hasDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
}
|
2022-09-23 10:24:36 +08:00
|
|
|
SmallVector<int64_t> empty;
|
|
|
|
auto dtype =
|
2022-12-08 04:20:41 +08:00
|
|
|
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
|
2022-09-23 10:24:36 +08:00
|
|
|
Type tensorType =
|
|
|
|
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
|
|
|
|
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType,
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getFillValue());
|
2022-09-23 10:24:36 +08:00
|
|
|
fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), fillVal,
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getSize());
|
2022-03-03 21:41:14 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-05-19 00:29:04 +08:00
|
|
|
namespace {
|
|
|
|
// Decompose `aten.linear` op into `aten.matmul` and `aten.add` ops.
|
|
|
|
class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenLinearOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getInput();
|
|
|
|
Value weight = op.getWeight();
|
|
|
|
Value bias = op.getBias();
|
2022-05-19 00:29:04 +08:00
|
|
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected input to be rank 2 or greater");
|
|
|
|
|
|
|
|
BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
|
|
|
|
// `weight` must be a rank 2 matrix.
|
|
|
|
if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2");
|
|
|
|
|
|
|
|
SmallVector<int64_t> transposeShape =
|
|
|
|
llvm::to_vector(llvm::reverse(weightType.getSizes()));
|
|
|
|
Type transposeType = weightType.getWithSizesAndDtype(
|
2023-01-04 06:19:18 +08:00
|
|
|
llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
|
2022-05-19 00:29:04 +08:00
|
|
|
Value transposeWeight =
|
|
|
|
rewriter.create<AtenTOp>(loc, transposeType, weight);
|
|
|
|
|
|
|
|
Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
|
|
|
|
transposeWeight);
|
|
|
|
if (bias.getType().isa<Torch::NoneType>()) {
|
|
|
|
rewriter.replaceOp(op, matmul);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
|
|
|
|
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
|
|
|
|
|
|
|
|
Value alpha =
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getBias(), alpha);
|
2022-05-19 00:29:04 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-10-12 05:03:10 +08:00
|
|
|
namespace {
|
|
|
|
// Decompose `aten.mish` op into `aten.tanh` and `aten.softplus` ops.
|
|
|
|
// Mish(x) = x * Tanh(Softplus(x))
|
|
|
|
class DecomposeAtenMishOp : public OpRewritePattern<AtenMishOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenMishOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2022-10-12 05:03:10 +08:00
|
|
|
Type type = op.getType();
|
|
|
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
if (!inputType.hasDtype())
|
|
|
|
return rewriter.notifyMatchFailure(op, "Dtype not present");
|
|
|
|
|
|
|
|
Type dType = inputType.getDtype();
|
|
|
|
// Form default Value tensors for `beta` and `threshold` operands
|
|
|
|
// of `aten.softplus` op.
|
|
|
|
Value beta = getConstantWithGivenDtypeAndValue(rewriter, loc, 1.0, dType);
|
|
|
|
Value threshold =
|
|
|
|
getConstantWithGivenDtypeAndValue(rewriter, loc, 20.0, dType);
|
|
|
|
Value softplusOp =
|
|
|
|
rewriter.create<AtenSoftplusOp>(loc, type, input, beta, threshold);
|
|
|
|
Value tanhOp = rewriter.create<AtenTanhOp>(loc, type, softplusOp);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, type, input, tanhOp);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-03 22:25:22 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.fullLike` op into `aten.emptyLike` and `aten.fill` ops.
|
2022-03-03 22:25:22 +08:00
|
|
|
class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenFullLikeOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-09-23 10:24:36 +08:00
|
|
|
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
if (!outTy.hasDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
}
|
2022-09-23 10:24:36 +08:00
|
|
|
SmallVector<int64_t> empty;
|
|
|
|
auto dtype =
|
2022-12-08 04:20:41 +08:00
|
|
|
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
|
2022-09-23 10:24:36 +08:00
|
|
|
Type tensorType =
|
|
|
|
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
|
|
|
|
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getLoc(), tensorType, op.getFillValue());
|
2022-09-23 10:24:36 +08:00
|
|
|
fillVal =
|
|
|
|
convertTensorToDtype(rewriter, op.getLoc(), fillVal, outTy.getDtype());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), fillVal,
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getSelf());
|
2022-03-03 22:25:22 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-10 23:18:08 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op.
|
2022-03-10 23:18:08 +08:00
|
|
|
class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenIndexPutOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
2022-10-28 23:06:11 +08:00
|
|
|
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
|
2022-03-10 23:18:08 +08:00
|
|
|
/*unsafe=*/cstFalse);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-14 16:12:37 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenExpandAsOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
auto sizeListType =
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
|
|
|
Value sizeList =
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getOther());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
|
2022-03-14 16:12:37 +08:00
|
|
|
sizeList);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-17 21:35:17 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.ToCopy` op into `valsem.aten.copy` op.
|
2022-03-17 21:35:17 +08:00
|
|
|
class DecomposeAten_ToCopyOp : public OpRewritePattern<Aten_ToCopyOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Aten_ToCopyOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2023-01-04 06:19:18 +08:00
|
|
|
auto resultType = op.getType().cast<BaseTensorType>();
|
|
|
|
if (!resultType.hasDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
}
|
|
|
|
Type resultDtype = resultType.getDtype();
|
2022-10-04 21:05:59 +08:00
|
|
|
Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
|
|
|
|
resultDtype);
|
2022-09-23 10:24:36 +08:00
|
|
|
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), op.getLayout(),
|
|
|
|
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
|
2022-10-28 23:06:11 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenCopyOp>(op, op.getType(), emptyTensor,
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getSelf(), op.getNonBlocking());
|
2022-03-17 21:35:17 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-12-22 10:13:59 +08:00
|
|
|
namespace {
|
|
|
|
// Decompose `aten.copy` op into `aten.to.dtype` and `aten.expand_as`.
|
|
|
|
class DecomposeAtenCopyOp : public OpRewritePattern<AtenCopyOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenCopyOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2023-01-04 06:19:18 +08:00
|
|
|
auto resultType = op.getType().cast<BaseTensorType>();
|
|
|
|
if (!resultType.hasDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
}
|
|
|
|
Type resultDtype = resultType.getDtype();
|
2022-12-22 10:13:59 +08:00
|
|
|
Value srcToDtype =
|
|
|
|
convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), srcToDtype,
|
|
|
|
op.getSelf());
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-25 00:40:21 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.newEmpty` op into `aten.empty.memoryFormat` op.
|
2022-03-25 00:40:21 +08:00
|
|
|
class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenNewEmptyOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
2022-12-08 04:20:41 +08:00
|
|
|
Value dtype = op.getDtype();
|
2022-03-25 00:40:21 +08:00
|
|
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
2022-12-08 04:20:41 +08:00
|
|
|
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
if (!tensorType.hasDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected input tensor to have a dtype");
|
|
|
|
}
|
2022-03-25 00:40:21 +08:00
|
|
|
dtype =
|
|
|
|
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(),
|
|
|
|
op.getPinMemory(), /*memoryFormat=*/noneVal);
|
2022-03-25 00:40:21 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-03-24 15:12:59 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.indexPut.hackedTwin` op into `valsem.aten.indexPutImpl`
|
2022-03-24 15:12:59 +08:00
|
|
|
// op.
|
|
|
|
class DecomposeAtenIndexPutHackedTwinOp
|
|
|
|
: public OpRewritePattern<AtenIndexPutHackedTwinOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenIndexPutHackedTwinOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
2022-10-28 23:06:11 +08:00
|
|
|
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
|
2022-03-24 15:12:59 +08:00
|
|
|
/*unsafe=*/cstFalse);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-04-26 20:18:09 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.pad` op into `aten.constantPadNd` op.
|
2022-04-26 20:18:09 +08:00
|
|
|
class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenPadOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value value = op.getValue();
|
2022-04-26 20:18:09 +08:00
|
|
|
if (value.getType().isa<Torch::OptionalType>())
|
|
|
|
return rewriter.notifyMatchFailure(op, "optional type not supported");
|
|
|
|
if (value.getType().isa<Torch::NoneType>())
|
|
|
|
value = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
op.getLoc(), rewriter.getF64FloatAttr(0));
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), op.getSelf(), op.getPad(), value);
|
2022-04-26 20:18:09 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-04-27 19:07:40 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.to.dtypeLayout` op into `aten.to.dtype` op.
|
2022-04-27 19:07:40 +08:00
|
|
|
class DecomposeAtenToDtypeLayoutOp
|
|
|
|
: public OpRewritePattern<AtenToDtypeLayoutOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
// TODO: Add support for pinMemory arg equal to `True`.
|
|
|
|
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
|
2022-04-27 19:07:40 +08:00
|
|
|
bool pinMemory;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
2022-04-27 19:07:40 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, "unimplemented: pinMemory must be a constant");
|
2022-04-27 19:07:40 +08:00
|
|
|
else if (pinMemory)
|
|
|
|
return rewriter.notifyMatchFailure(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, "unimplemented: pinMemory is expected to be false");
|
2022-04-27 19:07:40 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: Add support for non-None device arg.
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
2022-04-27 19:07:40 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: device arg must be None");
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: Add support for non-strided layout.
|
|
|
|
// torch.layout is by default strided i.e. 0.
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
2022-04-27 19:07:40 +08:00
|
|
|
int64_t tensorLayout;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
2022-04-27 19:07:40 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: layout must be a constant");
|
|
|
|
else if (tensorLayout != torch_upstream::Layout::Strided)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: layout is expected to be strided");
|
|
|
|
}
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
|
|
|
|
op.getDtype(), op.getNonBlocking(),
|
|
|
|
op.getCopy(), op.getMemoryFormat());
|
2022-04-27 19:07:40 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-08-11 07:24:02 +08:00
|
|
|
namespace {
|
|
|
|
// Decompose `aten.to.device` op into `aten.to.dtype` op.
|
|
|
|
class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenToDeviceOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
// Device information isn't relevant to torch-mlir, so we can drop that info
|
|
|
|
// here.
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
|
|
|
|
op.getDtype(), op.getNonBlocking(),
|
|
|
|
op.getCopy(), op.getMemoryFormat());
|
2022-08-11 07:24:02 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-05-13 20:06:24 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op.
|
2022-05-13 20:06:24 +08:00
|
|
|
//
|
|
|
|
// For AdaptiveAvgPool2d op, when the input size is an integer multiple of
|
2022-12-08 04:20:41 +08:00
|
|
|
// output size the kernelSize, stride and padding is calculated as follows:
|
2022-05-13 20:06:24 +08:00
|
|
|
// strideH = inH // outH
|
|
|
|
// strideW = inH // outH
|
|
|
|
// kernelH = inH - [(outH - 1) * strideH]
|
|
|
|
// kernelW = inW - [(outW - 1) * strideW]
|
|
|
|
// paddingH = 0, paddingW = 0
|
|
|
|
//
|
|
|
|
// For the special case, when the output size is one for all dimensions,
|
|
|
|
// the kernel size is same as the input size.
|
|
|
|
class DecomposeAtenAdaptiveAvgPool2dOp
|
|
|
|
: public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(input);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeRank) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
}
|
|
|
|
unsigned rank = *maybeRank;
|
2022-05-13 20:06:24 +08:00
|
|
|
SmallVector<Value, 2> inputHW;
|
|
|
|
Value dimH = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(rank - 2));
|
|
|
|
inputHW.push_back(
|
|
|
|
/*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
|
|
|
|
Value dimW = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(rank - 1));
|
|
|
|
inputHW.push_back(
|
|
|
|
/*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value outputShape = op.getOutputSize();
|
2022-05-13 20:06:24 +08:00
|
|
|
SmallVector<Value> outputShapeSizesTorchInt;
|
|
|
|
getListConstructElements(outputShape, outputShapeSizesTorchInt);
|
|
|
|
|
|
|
|
// TODO: Add support for cases other than:
|
|
|
|
// 1.) inH == outH and inW == outW.
|
|
|
|
// 2.) outH == outW == 1
|
|
|
|
bool unitOutputSize = true;
|
|
|
|
for (Value outShape : outputShapeSizesTorchInt) {
|
|
|
|
int64_t outShapeInt;
|
|
|
|
if (!matchPattern(outShape, m_TorchConstantInt(&outShapeInt))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "output size is expected to be a constant");
|
|
|
|
}
|
|
|
|
if (outShapeInt != 1) {
|
|
|
|
unitOutputSize = false;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
SmallVector<Value, 2> kernelSize;
|
|
|
|
|
|
|
|
for (unsigned i = 0; i < inputHW.size(); i++) {
|
|
|
|
if (unitOutputSize) {
|
|
|
|
BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
|
|
|
|
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
|
|
|
kernelSize.push_back(inputShape[rank - 2 + i] == kUnknownSize
|
|
|
|
? inputHW[i]
|
|
|
|
: rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(
|
|
|
|
inputShape[rank - 2 + i])));
|
|
|
|
} else {
|
|
|
|
Value cond = rewriter.create<AtenEqIntOp>(loc, inputHW[i],
|
|
|
|
outputShapeSizesTorchInt[i]);
|
|
|
|
rewriter.create<RuntimeAssertOp>(
|
|
|
|
loc, cond,
|
|
|
|
"unimplemented: only support cases where input and output size are "
|
|
|
|
"equal for non-unit output size");
|
|
|
|
|
|
|
|
Value outMinusOne = rewriter.create<AtenSubIntOp>(
|
|
|
|
loc, outputShapeSizesTorchInt[i], constantOne);
|
|
|
|
kernelSize.push_back(
|
|
|
|
rewriter.create<AtenSubIntOp>(loc, inputHW[i], outMinusOne));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
|
|
|
|
// Currently we only support cases where input size is equal to the output
|
|
|
|
// size or unit output size. For the former case, stride is always equal to
|
|
|
|
// one and for the latter the stride value doesn't matter, since the kernel
|
|
|
|
// size is same as the input size. Therfore, keeping the stride as one for
|
|
|
|
// the latter case as well for the ease of implementation.
|
|
|
|
Value strideList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
|
|
|
ValueRange{constantOne, constantOne});
|
|
|
|
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
|
|
|
ValueRange{constantZero, constantZero});
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
|
|
|
|
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
|
2022-12-08 04:20:41 +08:00
|
|
|
/*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
|
|
|
|
/*divisorOverride=*/constantNone);
|
2022-05-13 20:06:24 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-06-03 15:41:13 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.clampMin` op into `aten.clamp` op.
|
2022-06-03 15:41:13 +08:00
|
|
|
class DecomposeAtenClampMinOp : public OpRewritePattern<AtenClampMinOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenClampMinOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenClampOp>(op, op.getType(), op.getSelf(),
|
|
|
|
op.getMin(), /*max=*/constantNone);
|
2022-06-03 15:41:13 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.clampMax` op into `aten.clamp` op.
|
2022-06-03 15:41:13 +08:00
|
|
|
class DecomposeAtenClampMaxOp : public OpRewritePattern<AtenClampMaxOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenClampMaxOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenClampOp>(op, op.getType(), op.getSelf(),
|
|
|
|
/*min=*/constantNone, op.getMax());
|
2022-06-03 15:41:13 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-05-30 16:08:54 +08:00
|
|
|
namespace {
|
|
|
|
// Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
|
|
|
|
// `aten.add.Tensor` op.
|
|
|
|
class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenBaddbmmOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value bmm =
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(), op.getBatch2());
|
2022-05-30 16:08:54 +08:00
|
|
|
Value alphaTimesBmm =
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
|
|
|
|
Value input = op.getSelf();
|
2022-05-30 16:08:54 +08:00
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
BaseTensorType resultType =
|
|
|
|
op->getResult(0).getType().cast<BaseTensorType>();
|
|
|
|
if (inputType.hasDtype() && resultType.hasDtype() &&
|
|
|
|
inputType.getDtype() != resultType.getDtype()) {
|
|
|
|
input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype());
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), alphaTimesBmm, op.getSelf(), op.getBeta());
|
2022-05-30 16:08:54 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-06-09 14:09:28 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.floorDivide` op into `aten.div.TensorMode` op.
|
2022-06-09 14:09:28 +08:00
|
|
|
class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenFloorDivideOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
// https://pytorch.org/docs/stable/generated/torch.floorDivide.html
|
|
|
|
// PyTorch aten.floorDivide is a misnomer because it actually rounds
|
2022-08-06 23:38:06 +08:00
|
|
|
// the quotient towards zero instead of taking its floor.
|
2022-06-09 14:09:28 +08:00
|
|
|
Value cstStrFloor =
|
2022-08-06 23:38:06 +08:00
|
|
|
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
|
2022-06-09 14:09:28 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), op.getSelf(), op.getOther(),
|
|
|
|
/*roundingMode=*/cstStrFloor);
|
2022-06-09 14:09:28 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-06-03 20:38:59 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.numpyT` op into `aten.permute` op.
|
2022-06-03 20:38:59 +08:00
|
|
|
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenNumpyTOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeInputRank) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
}
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-06-03 20:38:59 +08:00
|
|
|
|
|
|
|
SmallVector<Value> dimListElements;
|
2022-12-13 00:56:28 +08:00
|
|
|
SmallVector<int> dimListInts(llvm::reverse(
|
|
|
|
llvm::iota_range<int>(0, inputRank, /*inclusive=*/false)));
|
|
|
|
for (int dimListInt : dimListInts) {
|
2022-06-03 20:38:59 +08:00
|
|
|
dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
2022-12-13 00:56:28 +08:00
|
|
|
loc, rewriter.getI64IntegerAttr(dimListInt)));
|
|
|
|
}
|
2022-06-03 20:38:59 +08:00
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
|
|
|
|
dimListElements);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenPermuteOp>(op, op.getType(), self, dimList);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-07-22 20:42:14 +08:00
|
|
|
template <typename OpTy>
|
|
|
|
static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|
|
|
bool unbiased, int64_t correction) {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = op.getSelf();
|
|
|
|
Value dimList = op.getDim();
|
|
|
|
Value keepDim = op.getKeepdim();
|
2022-07-22 20:42:14 +08:00
|
|
|
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
|
|
|
Type outputType = op.getType();
|
|
|
|
BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
if (!outputTensorType.hasDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"expected result type to have a dtype");
|
|
|
|
}
|
2022-12-08 01:51:37 +08:00
|
|
|
Type newOutputType = outputTensorType.getWithSizesAndDtype(
|
2022-07-22 20:42:14 +08:00
|
|
|
outputTensorType.getSizes(), rewriter.getF64Type());
|
|
|
|
if (!inputTensorTy.hasDtype() ||
|
|
|
|
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "support floating-point type input only");
|
|
|
|
}
|
|
|
|
|
|
|
|
// Upcasting the input tensor to `F64` dtype for higher precision during the
|
|
|
|
// computation of the result.
|
|
|
|
if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
|
|
|
|
self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
|
|
|
|
inputTensorTy = self.getType().cast<BaseTensorType>();
|
|
|
|
}
|
|
|
|
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
if (!maybeInputRank) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
}
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-07-22 20:42:14 +08:00
|
|
|
SmallVector<Value> dimListElements;
|
|
|
|
bool isNoneOrEmpty = true;
|
|
|
|
if (!dimList.getType().template isa<Torch::NoneType>()) {
|
|
|
|
if (!getListConstructElements(dimList, dimListElements))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expect dimList to be constructed from list construct");
|
|
|
|
if (!dimListElements.empty() || inputRank == 0)
|
|
|
|
isNoneOrEmpty = false;
|
|
|
|
}
|
|
|
|
if (isNoneOrEmpty) {
|
|
|
|
for (unsigned i = 0; i < inputRank; i++)
|
|
|
|
dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
dimListElements);
|
|
|
|
}
|
|
|
|
Type meanDimResultType = inputTensorTy;
|
|
|
|
for (unsigned i = 0; i < dimListElements.size(); i++)
|
|
|
|
meanDimResultType = computeReductionType(
|
|
|
|
rewriter, op, meanDimResultType.cast<BaseTensorType>(),
|
|
|
|
dimListElements[i],
|
|
|
|
/*keepDim=*/true);
|
|
|
|
|
|
|
|
Value constantNone = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value constantTrue = rewriter.create<ConstantBoolOp>(loc, true);
|
|
|
|
Value meanAlongDims = rewriter.create<AtenMeanDimOp>(
|
|
|
|
loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue,
|
|
|
|
/*dtype=*/constantNone);
|
|
|
|
Value subMean =
|
|
|
|
createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims);
|
|
|
|
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
|
|
|
|
|
|
|
|
if (!unbiased) {
|
2022-12-08 01:51:37 +08:00
|
|
|
Value result = rewriter.create<AtenMeanDimOp>(
|
|
|
|
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
|
2022-07-22 20:42:14 +08:00
|
|
|
result = convertTensorToDtype(rewriter, loc, result,
|
|
|
|
outputTensorType.getDtype());
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
// Divide the square sum by productDimSize - correction.
|
|
|
|
Value squareSum = rewriter.create<AtenSumDimIntListOp>(
|
2022-12-08 01:51:37 +08:00
|
|
|
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
|
2022-07-22 20:42:14 +08:00
|
|
|
|
|
|
|
// `productDimSize` is product of sizes of dimensions to be reduced.
|
|
|
|
Value constantOne =
|
|
|
|
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
Value productDimSize = constantOne;
|
|
|
|
for (Value dim : dimListElements) {
|
|
|
|
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
|
|
|
|
productDimSize =
|
|
|
|
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
|
|
|
|
}
|
|
|
|
Value cstCorrection = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getI64IntegerAttr(correction));
|
|
|
|
// The `correction` value should be less than or equal to `productDimSize +
|
|
|
|
// 1`.
|
|
|
|
Value productDimSizePlusOne =
|
|
|
|
rewriter.create<AtenAddIntOp>(loc, productDimSize, constantOne);
|
|
|
|
Value cond =
|
|
|
|
rewriter.create<AtenGeIntOp>(loc, productDimSizePlusOne, cstCorrection);
|
|
|
|
rewriter.create<RuntimeAssertOp>(
|
|
|
|
loc, cond,
|
|
|
|
"correction value should be less than or equal to productDimSize + 1");
|
|
|
|
Value productDimSizeSubCorrection =
|
|
|
|
rewriter.create<AtenSubIntOp>(loc, productDimSize, cstCorrection);
|
2022-12-08 01:51:37 +08:00
|
|
|
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
|
|
|
|
productDimSizeSubCorrection);
|
2022-07-22 20:42:14 +08:00
|
|
|
result =
|
|
|
|
convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype());
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-06-29 15:23:57 +08:00
|
|
|
// Decompose aten.var(x, dims) into:
|
|
|
|
// sub = aten.sub(x, aten.mean(x, dims))
|
|
|
|
// square = aten.square(sub)
|
|
|
|
// For Unbiased case:
|
|
|
|
// out = aten.sum(square, dims) / (productDimSize-1)
|
|
|
|
// For Biased case:
|
|
|
|
// out = aten.mean(square, dims)
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenVarDimOp : public OpRewritePattern<AtenVarDimOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenVarDimOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
bool unbiased;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getUnbiased(), m_TorchConstantBool(&unbiased))) {
|
2022-06-29 15:23:57 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Only support constant unbiased for aten.var");
|
|
|
|
}
|
2022-07-22 20:42:14 +08:00
|
|
|
int64_t correction = unbiased ? 1 : 0;
|
|
|
|
if (failed(calculateVariance<AtenVarDimOp>(op, rewriter, unbiased,
|
|
|
|
correction)))
|
|
|
|
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
2022-06-29 15:23:57 +08:00
|
|
|
|
2022-07-22 20:42:14 +08:00
|
|
|
// Decompose aten.var(x, dims) into:
|
|
|
|
// sub = aten.sub(x, aten.mean(x, dims))
|
|
|
|
// square = aten.square(sub)
|
|
|
|
// out = aten.sum(square, dims) / (productDimSize - correction)
|
|
|
|
namespace {
|
|
|
|
class DecomposeAtenVarCorrectionOp
|
|
|
|
: public OpRewritePattern<AtenVarCorrectionOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenVarCorrectionOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
int64_t correction;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!op.getCorrection().getType().isa<Torch::NoneType>()) {
|
|
|
|
if (!matchPattern(op.getCorrection(), m_TorchConstantInt(&correction)))
|
2022-07-22 20:42:14 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Only support constant int correction for aten.var");
|
2022-06-29 15:23:57 +08:00
|
|
|
} else {
|
2022-07-22 20:42:14 +08:00
|
|
|
// The default value in case of `correction` being None is 1.
|
|
|
|
correction = 1;
|
2022-06-29 15:23:57 +08:00
|
|
|
}
|
2022-07-22 20:42:14 +08:00
|
|
|
bool unbiased = correction == 0 ? false : true;
|
|
|
|
if (failed(calculateVariance<AtenVarCorrectionOp>(op, rewriter, unbiased,
|
|
|
|
correction)))
|
|
|
|
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
|
2022-06-29 15:23:57 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-05-10 21:15:59 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose the `aten.selectScatter` operation into `aten.sliceScatter` op.
|
2022-05-10 21:15:59 +08:00
|
|
|
class DecomposeAtenSelectScatterOp
|
|
|
|
: public OpRewritePattern<AtenSelectScatterOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenSelectScatterOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value start = op.getIndex();
|
|
|
|
Value dim = op.getDim();
|
|
|
|
Value self = op.getSelf();
|
|
|
|
Value src = op.getSrc();
|
2022-05-10 21:15:59 +08:00
|
|
|
|
|
|
|
Value one =
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
Value startPlusOne =
|
|
|
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
|
|
|
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
|
|
|
|
SmallVector<int64_t> sizes;
|
|
|
|
if (!srcTensorType.hasSizes())
|
|
|
|
return rewriter.notifyMatchFailure(op, "src tensor must have size");
|
|
|
|
|
|
|
|
ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
|
|
|
|
// `src` has a reduced rank. Hence add 1.
|
|
|
|
int64_t srcRank = srcShape.size() + 1;
|
|
|
|
int64_t dimInt = 0;
|
|
|
|
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
|
|
|
|
dimInt = toPositiveDim(dimInt, srcRank);
|
|
|
|
if (!isValidDim(dimInt, srcRank))
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
|
|
|
|
|
|
|
sizes.append(srcShape.begin(), srcShape.end());
|
|
|
|
sizes.insert(sizes.begin() + dimInt, 1);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
sizes.resize(srcShape.size() + 1, kUnknownSize);
|
|
|
|
}
|
2023-01-04 06:19:18 +08:00
|
|
|
Type srcType = srcTensorType.getWithSizesAndDtype(
|
|
|
|
llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
|
2022-05-10 21:15:59 +08:00
|
|
|
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
|
2022-05-10 21:15:59 +08:00
|
|
|
/*step=*/one);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-08-09 06:56:49 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAten_EmbeddingBagOp
|
|
|
|
: public OpRewritePattern<Aten_EmbeddingBagOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value weight = op.getWeight();
|
|
|
|
Value indices = op.getIndices();
|
|
|
|
Value offsets = op.getOffsets();
|
|
|
|
Value scaleGradByFreq = op.getScaleGradByFreq();
|
|
|
|
Value mode = op.getMode();
|
|
|
|
Value sparse = op.getSparse();
|
|
|
|
Value perSampleWeights = op.getPerSampleWeights();
|
|
|
|
Value includeLastOffset = op.getIncludeLastOffset();
|
|
|
|
Value paddingIdx = op.getPaddingIdx();
|
2022-08-09 06:56:49 +08:00
|
|
|
|
|
|
|
auto resultType0 = op->getResult(0).getType();
|
|
|
|
auto resultType1 = op->getResult(1).getType();
|
|
|
|
auto resultType2 = op->getResult(2).getType();
|
|
|
|
auto resultType3 = op->getResult(3).getType();
|
|
|
|
|
|
|
|
mlir::TypeRange returnTypes{resultType0, resultType1, resultType2,
|
|
|
|
resultType3};
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEmbeddingBagPaddingIdxOp>(
|
|
|
|
op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode,
|
|
|
|
sparse, perSampleWeights, includeLastOffset, paddingIdx);
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-09-06 22:07:17 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.liftFreshCopy` op into `aten.clone` op.
|
2022-09-06 22:07:17 +08:00
|
|
|
class DecomposeAtenLiftFreshCopyOp
|
|
|
|
: public OpRewritePattern<AtenLiftFreshCopyOp> {
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenLiftFreshCopyOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value constantNone = rewriter.create<ConstantNoneOp>(op.getLoc());
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenCloneOp>(op, op.getType(), op.getSelf(),
|
2022-09-06 22:07:17 +08:00
|
|
|
/*memoryFormat=*/constantNone);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-09-06 21:29:24 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.index.TensorHackedTwin` op into `aten.index.Tensor` op.
|
2022-09-06 21:29:24 +08:00
|
|
|
class DecomposeAtenIndexTensorHackedTwinOp
|
|
|
|
: public OpRewritePattern<AtenIndexTensorHackedTwinOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenIndexTensorHackedTwinOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, op.getType(), op.getSelf(),
|
|
|
|
op.getIndices());
|
2022-09-06 21:29:24 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-10-20 19:02:09 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenMseLossOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
// The `reduction` arg would have only three valid values.
|
|
|
|
// 0 means no reduction.
|
|
|
|
// 1 means mean reduction.
|
|
|
|
// 2 means sum reduction.
|
|
|
|
int64_t reductionType;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reductionType)))
|
2022-10-20 19:02:09 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Expected a constant integer value for reduction");
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
BaseTensorType resultType = op.getType().cast<BaseTensorType>();
|
2022-12-08 04:20:41 +08:00
|
|
|
BaseTensorType inputType = op.getSelf().getType().cast<BaseTensorType>();
|
2022-10-20 19:02:09 +08:00
|
|
|
if (!inputType.hasSizes())
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Expected the input tensor to have sizes");
|
|
|
|
BaseTensorType subType =
|
|
|
|
inputType
|
|
|
|
.getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
|
2023-01-04 06:19:18 +08:00
|
|
|
resultType.getOptionalDtype())
|
2022-10-20 19:02:09 +08:00
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
|
2022-10-20 19:02:09 +08:00
|
|
|
Value result = rewriter.create<AtenSquareOp>(loc, subType, sub);
|
|
|
|
if (reductionType == torch_upstream::Reduction::None) {
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
if (reductionType == torch_upstream::Reduction::Mean)
|
|
|
|
result = rewriter.create<AtenMeanDimOp>(loc, resultType, result,
|
|
|
|
/*dim=*/cstNone,
|
|
|
|
/*keepdim=*/cstFalse,
|
|
|
|
/*dtype=*/cstNone);
|
|
|
|
else
|
|
|
|
result = rewriter.create<AtenSumDimIntListOp>(
|
|
|
|
loc, resultType, result, /*dim=*/cstNone, /*keepdim=*/cstFalse,
|
|
|
|
/*dtype=*/cstNone);
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-11-06 20:44:05 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenRandintLowOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Type resultType = op.getType();
|
|
|
|
BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
if (!resultTensorType.hasDtype()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
}
|
2022-11-06 20:44:05 +08:00
|
|
|
|
|
|
|
int64_t cstLow, cstHigh;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getLow(), m_TorchConstantInt(&cstLow)))
|
2022-11-06 20:44:05 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: low must be a constant integer");
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getHigh(), m_TorchConstantInt(&cstHigh)))
|
2022-11-06 20:44:05 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: high must be a constant integer");
|
|
|
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
|
|
|
Value low = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
loc, rewriter.getF64FloatAttr((double)cstLow));
|
|
|
|
Value high = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
loc, rewriter.getF64FloatAttr((double)cstHigh));
|
|
|
|
|
|
|
|
BaseTensorType floatResultType =
|
|
|
|
resultTensorType
|
|
|
|
.getWithSizesAndDtype(resultTensorType.getSizes(),
|
|
|
|
rewriter.getF32Type())
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
|
|
|
|
/*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(),
|
|
|
|
/*memoryFormat=*/none);
|
2022-11-06 20:44:05 +08:00
|
|
|
|
|
|
|
Value result =
|
|
|
|
rewriter.create<AtenUniformOp>(loc, floatResultType, emptyTensor,
|
|
|
|
/*from=*/low,
|
|
|
|
/*to=*/high,
|
|
|
|
/*generator=*/none);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
|
|
|
|
op, resultType, result,
|
|
|
|
getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
|
2022-12-08 04:20:41 +08:00
|
|
|
/*nonBlocking=*/cstFalse, /*copy=*/cstFalse, /*memoryFormat=*/none);
|
2022-11-06 20:44:05 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-11-15 22:39:40 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `aten.varMean.correction` op into `aten.var.correction` and
|
2022-11-15 22:39:40 +08:00
|
|
|
// `aten.mean.dim` op.
|
|
|
|
class DecomposeAtenVarMeanCorrectionOp
|
|
|
|
: public OpRewritePattern<AtenVarMeanCorrectionOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenVarMeanCorrectionOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value var = rewriter.create<AtenVarCorrectionOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(), op.getKeepdim());
|
2022-11-15 22:39:40 +08:00
|
|
|
Value mean =
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.create<AtenMeanDimOp>(loc, op.getType(0), op.getSelf(), op.getDim(),
|
|
|
|
op.getKeepdim(), /*dtype=*/noneVal);
|
2022-11-15 22:39:40 +08:00
|
|
|
rewriter.replaceOp(op, {var, mean});
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-11-21 16:38:47 +08:00
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
// Decompose `prims.convertElementType` op into `aten.to.dtype` op.
|
2022-11-21 16:38:47 +08:00
|
|
|
class DecomposePrimsConvertElementTypeOp
|
|
|
|
: public OpRewritePattern<PrimsConvertElementTypeOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(PrimsConvertElementTypeOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), op.getA(), op.getDtype(), /*nonBlocking=*/cstFalse,
|
|
|
|
/*copy=*/cstFalse, /*memoryFormat=*/cstNone);
|
2022-11-21 16:38:47 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2023-01-11 14:01:45 +08:00
|
|
|
namespace {
|
|
|
|
// Decompose `prims.var` op into `aten.var.correction` op.
|
|
|
|
class DecomposePrimsVarOp : public OpRewritePattern<PrimsVarOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(PrimsVarOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
if (!op.getOutputDtype().getType().isa<Torch::NoneType>())
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "Unimplemented non-None dtype for prims::var op");
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenVarCorrectionOp>(
|
|
|
|
op, op.getType(), op.getInp(), op.getDims(), op.getCorrection(),
|
|
|
|
/*keepdim=*/cstFalse);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
// Decompose `prims.sqrt` op into `aten.sqrt` op.
|
|
|
|
class DecomposePrimsSqrtOp : public OpRewritePattern<PrimsSqrtOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(PrimsSqrtOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), op.getSelf());
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-11-14 15:08:13 +08:00
|
|
|
namespace {
|
|
|
|
// The op is decomposed using the Box-Muller transform.
|
|
|
|
// Refer: https://en.wikipedia.org/wiki/Box-Muller_transform
|
|
|
|
class DecomposeAtenRandnGeneratorOp
|
|
|
|
: public OpRewritePattern<AtenRandnGeneratorOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Type resultType = op.getType();
|
|
|
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value low = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
loc, rewriter.getF64FloatAttr((double)0.0));
|
|
|
|
Value high = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
loc, rewriter.getF64FloatAttr((double)1.0));
|
|
|
|
Value cstMinusTwo = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
loc, rewriter.getF64FloatAttr((double)-2.0));
|
|
|
|
Value cstTwoPie = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));
|
|
|
|
|
|
|
|
Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
|
|
|
|
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
|
2022-11-14 15:08:13 +08:00
|
|
|
/*memory_format=*/none);
|
|
|
|
Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
|
|
|
|
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
|
2022-11-14 15:08:13 +08:00
|
|
|
/*memory_format=*/none);
|
|
|
|
|
|
|
|
Value uOne = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
|
|
|
|
/*from=*/low,
|
|
|
|
/*to=*/high,
|
2022-12-08 04:20:41 +08:00
|
|
|
/*generator=*/op.getGenerator());
|
2022-11-14 15:08:13 +08:00
|
|
|
Value uTwo = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
|
|
|
|
/*from=*/low,
|
|
|
|
/*to=*/high,
|
2022-12-08 04:20:41 +08:00
|
|
|
/*generator=*/op.getGenerator());
|
2022-11-14 15:08:13 +08:00
|
|
|
|
|
|
|
Value logUOne = rewriter.create<AtenLogOp>(loc, resultType, uOne);
|
|
|
|
Value minusTwoLogUOne =
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resultType, logUOne, cstMinusTwo);
|
|
|
|
Value r = rewriter.create<AtenSqrtOp>(loc, resultType, minusTwoLogUOne);
|
|
|
|
Value theta =
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resultType, uTwo, cstTwoPie);
|
|
|
|
Value cosTheta = rewriter.create<AtenCosOp>(loc, resultType, theta);
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), r, cosTheta);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
// Decompose `aten.randn` op into `aten.randn.generator` op.
|
|
|
|
class DecomposeAtenRandnOp : public OpRewritePattern<AtenRandnOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenRandnOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenRandnGeneratorOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op, op.getType(), op.getSize(), /*generator=*/none, op.getDtype(),
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
2022-11-14 15:08:13 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2023-01-16 19:40:21 +08:00
|
|
|
namespace {
|
|
|
|
// Decompose `aten.randn_like` op into `aten.randn.generator` op.
|
|
|
|
class DecomposeAtenRandnLikeOp : public OpRewritePattern<AtenRandnLikeOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenRandnLikeOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
|
|
|
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
|
|
|
|
int64_t memoryFormat;
|
|
|
|
if (!matchPattern(op.getMemoryFormat(),
|
|
|
|
m_TorchConstantInt(&memoryFormat)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: the memory format should be specified in "
|
|
|
|
"an integer constant");
|
|
|
|
if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
|
|
|
memoryFormat != torch_upstream::MemoryFormat::Preserve)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only none, contiguous and preserve "
|
|
|
|
"memory_format is supported");
|
|
|
|
}
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
|
|
|
auto sizeListType =
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
|
|
|
Value sizeList =
|
|
|
|
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
|
|
|
|
rewriter.replaceOpWithNewOp<AtenRandnGeneratorOp>(
|
|
|
|
op, op.getType(), sizeList, /*generator=*/none, op.getDtype(),
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-12-09 23:22:26 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenVarMeanOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
Value var = rewriter.create<AtenVarDimOp>(loc, op.getType(0), op.getSelf(),
|
|
|
|
/*dim=*/noneVal, op.getUnbiased(),
|
|
|
|
/*keepdim=*/falseVal);
|
|
|
|
Value mean = rewriter.create<AtenMeanOp>(loc, op.getType(0), op.getSelf(),
|
|
|
|
/*dtype=*/noneVal);
|
|
|
|
rewriter.replaceOp(op, {var, mean});
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-12-29 22:52:23 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeAtenNewEmptyStridedOp
|
|
|
|
: public OpRewritePattern<AtenNewEmptyStridedOp> {
|
|
|
|
public:
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
SmallVector<int64_t> sizeListInts, strideListInts;
|
|
|
|
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "all size list elements must be constant ints");
|
|
|
|
if (!matchPattern(op.getStride(),
|
|
|
|
m_TorchListOfConstantInts(strideListInts)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "all stride list elements must be constant ints");
|
|
|
|
|
|
|
|
// We only support the cases with default stride values.
|
|
|
|
// For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
|
|
|
|
// Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
|
|
|
|
// stride[2] == 1.
|
|
|
|
bool isDefaultStride = true;
|
|
|
|
for (unsigned i = 0; i < strideListInts.size(); i++) {
|
|
|
|
int64_t defaultStride = 1;
|
|
|
|
for (unsigned j = i + 1; j < sizeListInts.size(); j++)
|
|
|
|
defaultStride *= sizeListInts[j];
|
|
|
|
if (defaultStride != strideListInts[i]) {
|
|
|
|
isDefaultStride = false;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!isDefaultStride)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only default strides supported for new_empty_strided op");
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
|
|
|
|
op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2021-10-16 06:23:59 +08:00
|
|
|
namespace {
|
|
|
|
class DecomposeComplexOpsPass
|
|
|
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
2022-12-09 01:26:38 +08:00
|
|
|
private:
|
|
|
|
llvm::StringSet<> legalOpsSet;
|
|
|
|
|
|
|
|
template <typename DecomposePattern>
|
|
|
|
void addPatternIfTargetOpIsIllegal(RewritePatternSet &patterns) {
|
|
|
|
MLIRContext *context = &getContext();
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<OperationName> opName =
|
|
|
|
DecomposePattern(context).getRootKind();
|
2022-12-09 01:26:38 +08:00
|
|
|
// Because the `DecomposeComplexOpsPass` uses a greedy algorithm
|
|
|
|
// to apply patterns, only patterns that we for sure know we want to run
|
|
|
|
// must be added. This restricts the set of patterns allowed in this file to
|
|
|
|
// patterns that apply to a single op. In other words, patterns that match
|
|
|
|
// on `Operation *` are not allowed, since there is no way of telling if
|
|
|
|
// that pattern will match on an op in the `legalOpsSet` or not.
|
|
|
|
assert(opName && "All decomposition patterns must target a single op");
|
|
|
|
if (!legalOpsSet.contains(opName->getStringRef()))
|
|
|
|
patterns.add<DecomposePattern>(context);
|
|
|
|
}
|
|
|
|
|
2022-08-18 07:23:52 +08:00
|
|
|
public:
|
|
|
|
DecomposeComplexOpsPass() = default;
|
|
|
|
DecomposeComplexOpsPass(ArrayRef<std::string> legalOps) {
|
|
|
|
this->legalOps = legalOps;
|
|
|
|
}
|
2021-10-16 06:23:59 +08:00
|
|
|
void runOnOperation() override {
|
|
|
|
MLIRContext *context = &getContext();
|
|
|
|
RewritePatternSet patterns(context);
|
2022-12-09 01:26:38 +08:00
|
|
|
// The strings in the `legalOps` ArrayRef don't exist during the call to the
|
|
|
|
// constructor `DecomposeComplexOpsPass`, so the creation of the
|
|
|
|
// `legalOpsSet` must be delayed to when `runOnOperation` gets called.
|
|
|
|
legalOpsSet.clear();
|
|
|
|
legalOpsSet.insert(legalOps.begin(), legalOps.end());
|
|
|
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFlattenUsingIntsOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
|
|
|
|
patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
|
|
|
|
patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionOverrideableOp>(
|
|
|
|
patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
|
|
|
|
patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionBackwardOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMaxOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeViewOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_ReshapeAliasOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeValsemVariantAtenBernoulliFloatOp>(
|
|
|
|
patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSiluOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
|
|
|
|
patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
|
2022-12-22 10:13:59 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
|
2022-12-22 13:02:40 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorHackedTwinOp>(
|
|
|
|
patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
|
2023-01-11 14:01:45 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposePrimsVarOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
|
2023-01-16 19:40:21 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
|
2022-12-09 23:22:26 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
|
2023-01-04 00:30:16 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
|
2022-12-29 22:52:23 +08:00
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
|
|
|
GreedyRewriteConfig config;
|
|
|
|
config.useTopDownTraversal = true;
|
2023-01-11 07:07:19 +08:00
|
|
|
config.maxIterations = GreedyRewriteConfig::kNoLimit;
|
2022-12-09 01:26:38 +08:00
|
|
|
|
|
|
|
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
|
|
|
config))) {
|
2021-10-16 06:23:59 +08:00
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
2022-08-19 08:01:54 +08:00
|
|
|
|
2022-04-27 03:27:51 +08:00
|
|
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
2022-08-19 08:01:54 +08:00
|
|
|
mlir::torch::Torch::createDecomposeComplexOpsPass(
|
|
|
|
ArrayRef<std::string> legalOps) {
|
|
|
|
return std::make_unique<DecomposeComplexOpsPass>(legalOps);
|
2023-01-21 02:40:13 +08:00
|
|
|
}
|