2021-10-16 06:23:59 +08:00
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
//
|
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
//
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
#include "PassDetail.h"
|
|
|
|
|
|
2022-02-09 04:57:23 +08:00
|
|
|
|
#include "mlir/IR/BuiltinDialect.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
2022-04-26 20:18:09 +08:00
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
2022-05-10 21:15:59 +08:00
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
|
#include "llvm/ADT/StringExtras.h"
|
2022-05-10 21:15:59 +08:00
|
|
|
|
#include <cstdint>
|
2021-10-16 06:23:59 +08:00
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace mlir::torch;
|
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
|
2022-03-11 01:25:21 +08:00
|
|
|
|
// Helper function to check whether the `dtype` is None or Float type.
|
|
|
|
|
static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
|
|
|
|
if (dtype.getType().isa<Torch::NoneType>())
|
|
|
|
|
return true;
|
|
|
|
|
int64_t dtypeInt;
|
|
|
|
|
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
|
|
|
|
return false;
|
|
|
|
|
Type resDtype =
|
|
|
|
|
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
|
|
|
|
return resDtype.isa<mlir::FloatType>();
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-01 03:56:32 +08:00
|
|
|
|
// Helper function to compute the return type of the reduction function.
|
|
|
|
|
// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
|
|
|
|
|
// the input tensor.
|
|
|
|
|
static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
|
|
|
|
|
Value input, Value dim, bool keepDim) {
|
2021-11-08 23:56:40 +08:00
|
|
|
|
BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
SmallVector<int64_t> sizes;
|
|
|
|
|
int64_t dimInt;
|
|
|
|
|
if (tensorType.hasSizes()) {
|
|
|
|
|
ArrayRef<int64_t> inputShape = tensorType.getSizes();
|
|
|
|
|
int64_t inputRank = inputShape.size();
|
|
|
|
|
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
|
|
|
|
|
dimInt = toPositiveDim(dimInt, inputRank);
|
|
|
|
|
if (!isValidDim(dimInt, inputRank)) {
|
|
|
|
|
(void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
sizes.append(inputShape.begin(), inputShape.end());
|
2022-02-01 03:56:32 +08:00
|
|
|
|
// The dimension to be reduced is set to 1 when `keepDim` is true else it
|
|
|
|
|
// is removed.
|
|
|
|
|
if (keepDim)
|
|
|
|
|
sizes[dimInt] = 1;
|
|
|
|
|
else
|
|
|
|
|
sizes.erase(sizes.begin() + dimInt - 1);
|
2021-11-08 23:56:40 +08:00
|
|
|
|
} else {
|
2022-02-01 03:56:32 +08:00
|
|
|
|
unsigned reducedRank = keepDim ? inputRank : inputRank - 1;
|
|
|
|
|
sizes.resize(reducedRank, kUnknownSize);
|
2021-11-08 23:56:40 +08:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type resultType = tensorType.getWithSizesAndDtype(
|
|
|
|
|
sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
|
|
|
|
|
: llvm::makeArrayRef(sizes),
|
|
|
|
|
tensorType.getDtype());
|
2022-02-01 03:56:32 +08:00
|
|
|
|
return resultType;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Reduction function to calculate sum along given `dim`.
|
|
|
|
|
static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
|
|
|
|
|
Operation *op, Value input, Value dim,
|
|
|
|
|
bool keepDim) {
|
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(dim.getType()), dim);
|
|
|
|
|
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
|
|
|
|
Value dtype = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Type resultType = computeReductionType(rewriter, op, input, dim, keepDim);
|
|
|
|
|
if (!resultType)
|
|
|
|
|
return nullptr;
|
|
|
|
|
return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
|
|
|
|
|
keepDimCst, dtype);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Redunction function to calculate max along given `dim`.
|
|
|
|
|
static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
|
|
|
|
Operation *op, Value input, Value dim,
|
|
|
|
|
bool keepDim) {
|
|
|
|
|
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
|
|
|
|
BaseTensorType valueType =
|
|
|
|
|
computeReductionType(rewriter, op, input, dim, keepDim)
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
if (!valueType)
|
|
|
|
|
return nullptr;
|
|
|
|
|
BaseTensorType indexType =
|
|
|
|
|
valueType
|
|
|
|
|
.getWithSizesAndDtype(
|
|
|
|
|
!valueType.hasSizes() ? Optional<ArrayRef<int64_t>>()
|
|
|
|
|
: llvm::makeArrayRef(valueType.getSizes()),
|
|
|
|
|
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
return rewriter
|
|
|
|
|
.create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
|
|
|
|
|
.values();
|
2021-11-08 23:56:40 +08:00
|
|
|
|
}
|
|
|
|
|
|
2021-11-19 02:02:20 +08:00
|
|
|
|
// Helper for creating `aten::sub_tensor_op`.
|
2021-11-19 20:18:41 +08:00
|
|
|
|
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Type tensorType, Value lhs, Value rhs) {
|
2021-11-19 02:02:20 +08:00
|
|
|
|
Value alpha =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
|
|
|
|
|
Value sub =
|
|
|
|
|
rewriter.create<AtenSubTensorOp>(loc, tensorType, lhs, rhs, alpha);
|
|
|
|
|
return sub;
|
|
|
|
|
}
|
|
|
|
|
|
2022-03-03 00:48:15 +08:00
|
|
|
|
// Helper to create a tensor filled with the given scalar. Scalar would be
|
|
|
|
|
// converted the to the element type of the given tensor type.
|
2022-02-09 04:57:23 +08:00
|
|
|
|
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
|
|
|
|
|
Type resultType, Value scalar, Value sizeList) {
|
|
|
|
|
BaseTensorType tensorType = resultType.cast<BaseTensorType>();
|
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
|
|
|
|
loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal,
|
2022-02-26 00:35:04 +08:00
|
|
|
|
/*device=*/noneVal, /*pin_memory=*/noneVal, /*memory_format=*/noneVal);
|
2022-03-16 07:57:33 +08:00
|
|
|
|
return rewriter.create<ValsemVariantAtenFillScalarOp>(loc, resultType,
|
|
|
|
|
emptyTensor, scalar);
|
2022-02-09 04:57:23 +08:00
|
|
|
|
}
|
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
|
|
|
|
|
// would be converted to the element type of the given `inputType`.
|
2022-02-09 04:57:23 +08:00
|
|
|
|
static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
|
|
|
|
|
BaseTensorType inputType, Value scalar) {
|
|
|
|
|
SmallVector<int64_t> sizes;
|
|
|
|
|
Type rank0TensorTy = inputType.getWithSizesAndDtype(
|
|
|
|
|
makeArrayRef(sizes), inputType.getOptionalDtype());
|
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
|
|
|
|
|
ValueRange{});
|
|
|
|
|
return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList);
|
|
|
|
|
}
|
|
|
|
|
|
2021-11-19 20:18:41 +08:00
|
|
|
|
// Share code between `softmax_backward` and `log_softmax_backward` ops.
|
|
|
|
|
// Returns x - y * sum(z, dim).
|
|
|
|
|
static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
|
|
|
|
|
Location loc, Operation *op,
|
|
|
|
|
Type tensorType, Value x,
|
|
|
|
|
Value y, Value z, Value dim) {
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Value sum =
|
|
|
|
|
createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
|
2021-11-19 20:18:41 +08:00
|
|
|
|
if (!sum)
|
|
|
|
|
return nullptr;
|
|
|
|
|
auto broadcastSizeType =
|
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op->getContext()));
|
|
|
|
|
Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
|
|
|
|
|
Value sumBroadcast =
|
|
|
|
|
rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
|
|
|
|
|
Value temp =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
|
|
|
|
|
|
|
|
|
|
Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
|
|
|
|
|
return sub;
|
|
|
|
|
}
|
|
|
|
|
|
2021-11-08 23:56:40 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSizeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value self = op.self();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
int64_t rank = getTensorRank(self);
|
|
|
|
|
if (rank < 0)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
|
|
|
|
|
SmallVector<Value> sizes;
|
|
|
|
|
for (int i = 0; i < rank; i++) {
|
|
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i));
|
|
|
|
|
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value sizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)), sizes);
|
|
|
|
|
rewriter.replaceOp(op, sizeList);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-03 12:09:21 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSelectIntOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-02-12 03:34:05 +08:00
|
|
|
|
Value start = op.index();
|
|
|
|
|
Value dim = op.dim();
|
|
|
|
|
Value self = op.self();
|
|
|
|
|
|
2021-12-03 12:09:21 +08:00
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
2022-02-12 03:34:05 +08:00
|
|
|
|
Value startPlusOne =
|
|
|
|
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
|
|
|
|
Value slice = rewriter.create<AtenSliceTensorOp>(
|
|
|
|
|
loc, computeReductionType(rewriter, op, self, dim, /*keepDim=*/true),
|
|
|
|
|
op.self(), dim, start, startPlusOne, /*step=*/one);
|
|
|
|
|
|
|
|
|
|
// `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after
|
|
|
|
|
// slicing, while `aten.select.int` does.
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSqueezeDimOp>(op, op.getResult().getType(),
|
|
|
|
|
slice, op.dim());
|
2021-12-03 12:09:21 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-25 03:26:37 +08:00
|
|
|
|
namespace {
|
2022-07-12 01:56:12 +08:00
|
|
|
|
class DecomposeAtenZeroOp
|
|
|
|
|
: public OpRewritePattern<AtenZeroOp> {
|
2022-03-25 03:26:37 +08:00
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
2022-07-12 01:56:12 +08:00
|
|
|
|
LogicalResult matchAndRewrite(AtenZeroOp op,
|
2022-03-25 03:26:37 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
|
|
|
|
|
rewriter.getI64IntegerAttr(0));
|
|
|
|
|
rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(op, op.getType(),
|
|
|
|
|
op.self(), zero);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-17 23:54:03 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenReshapeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value input = op.self();
|
|
|
|
|
// TODO: Handle non value tensor type operands.
|
|
|
|
|
if (!input.getType().isa<ValueTensorType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: only value tensor type operands are supported");
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), input,
|
|
|
|
|
op.shape());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-25 13:49:02 +08:00
|
|
|
|
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
|
|
|
|
// exp(x)/sum(exp(x)).
|
2022-02-01 03:56:32 +08:00
|
|
|
|
// To avoid overflow we use the following decomposition rule:
|
|
|
|
|
// x_max = max(input, dim, keepdim = True)
|
|
|
|
|
// unnorm = aten.exp(input - x_max)
|
|
|
|
|
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
|
2021-11-25 13:49:02 +08:00
|
|
|
|
template <typename OpTy>
|
|
|
|
|
static Value getSoftmaxResult(OpTy op, Type resultType,
|
|
|
|
|
PatternRewriter &rewriter) {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value dim = op.dim();
|
|
|
|
|
Value self = op.self();
|
2022-02-01 03:56:32 +08:00
|
|
|
|
Value xMax =
|
|
|
|
|
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
|
|
|
|
|
if (!xMax)
|
|
|
|
|
return nullptr;
|
|
|
|
|
Value unNormalized = createTensorSub(rewriter, loc, resultType, self, xMax);
|
|
|
|
|
Value unNormalizedExp =
|
|
|
|
|
rewriter.create<AtenExpOp>(loc, resultType, unNormalized);
|
|
|
|
|
Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim,
|
|
|
|
|
/*keepDim=*/true);
|
2021-11-25 13:49:02 +08:00
|
|
|
|
if (!sum)
|
|
|
|
|
return nullptr;
|
2022-02-01 03:56:32 +08:00
|
|
|
|
return rewriter.create<AtenDivTensorOp>(loc, resultType, unNormalizedExp,
|
|
|
|
|
sum);
|
2021-11-25 13:49:02 +08:00
|
|
|
|
}
|
|
|
|
|
|
2021-10-16 06:23:59 +08:00
|
|
|
|
// Decompose softmax into: exp(x) / sum(exp(x))
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value self = op.self();
|
|
|
|
|
if (!op.dtype().getType().isa<Torch::NoneType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented non-None dtype for softmax");
|
|
|
|
|
|
|
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
2021-11-08 23:56:40 +08:00
|
|
|
|
|
2021-11-25 13:49:02 +08:00
|
|
|
|
Value result = getSoftmaxResult(op, tensorType, rewriter);
|
|
|
|
|
if (!result)
|
2021-11-08 23:56:40 +08:00
|
|
|
|
return failure();
|
2021-11-25 13:49:02 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
|
|
|
|
|
result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value self = op.self();
|
|
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
bool halfToFloat;
|
|
|
|
|
if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected a boolean value for half_to_float");
|
|
|
|
|
|
|
|
|
|
// Currently, setting `halfToFloat` is not supported as the E2E testing for
|
|
|
|
|
// the same is not present on CPU.
|
|
|
|
|
if (halfToFloat)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "halfToFloat is currently not supported.");
|
|
|
|
|
|
|
|
|
|
Value result = getSoftmaxResult(op, tensorType, rewriter);
|
|
|
|
|
if (!result)
|
|
|
|
|
return op.emitError("failed to get softmax result");
|
2021-10-16 06:23:59 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
|
|
|
|
|
result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-08 23:56:40 +08:00
|
|
|
|
// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
|
|
|
|
// newGrad = gradOutput * output
|
|
|
|
|
// result = newGrad - output * sum(newGrad, dim))
|
|
|
|
|
//
|
|
|
|
|
// Refer to
|
|
|
|
|
// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_SoftmaxBackwardDataOp
|
|
|
|
|
: public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value gradOutput = op.grad_output();
|
|
|
|
|
Value output = op.output();
|
|
|
|
|
Value dim = op.dim();
|
|
|
|
|
|
|
|
|
|
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
|
|
|
|
Value newGrad =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
|
2021-11-19 20:18:41 +08:00
|
|
|
|
Value result = createSoftmaxBackwardCommonKernel(
|
|
|
|
|
rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
|
|
|
|
|
if (!result)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op,
|
|
|
|
|
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
|
|
|
|
|
rewriter.replaceOp(op, result);
|
2021-11-08 23:56:40 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-09 20:25:04 +08:00
|
|
|
|
// AtenTanhBackwardOp(gradOutput, output) =>
|
|
|
|
|
// result = gradOutput * (1 - output^2)
|
|
|
|
|
// To get away from broadcasts the above formula is expanded i.e.,
|
|
|
|
|
// result = gradOutput - (gradOutput * output^2)
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenTanhBackwardOp
|
|
|
|
|
: public OpRewritePattern<AtenTanhBackwardOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenTanhBackwardOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value gradOutput = op.grad_output();
|
|
|
|
|
|
|
|
|
|
// `output` is the value flowing out from tanh. Hence, tanh(x) = output.
|
2021-11-19 20:18:41 +08:00
|
|
|
|
// Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
|
2021-11-09 20:25:04 +08:00
|
|
|
|
Value output = op.output();
|
|
|
|
|
|
|
|
|
|
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
|
|
|
|
Value tanhSquare =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, tensorType, output, output);
|
|
|
|
|
Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, tensorType, tanhSquare, gradOutput);
|
|
|
|
|
|
2021-11-19 20:18:41 +08:00
|
|
|
|
Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
|
2022-02-15 21:14:32 +08:00
|
|
|
|
gradMulTanhSquare);
|
2021-11-09 20:25:04 +08:00
|
|
|
|
rewriter.replaceOp(op, newGrad);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-19 02:02:20 +08:00
|
|
|
|
// Aten_LogSoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
|
|
|
|
// result = gradOutput - (exp(output) * sum(gradOutput, dim))
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_LogSoftmaxBackwardDataOp
|
|
|
|
|
: public OpRewritePattern<Aten_LogSoftmaxBackwardDataOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value gradOutput = op.grad_output();
|
|
|
|
|
Value output = op.output();
|
|
|
|
|
Value dim = op.dim();
|
|
|
|
|
|
|
|
|
|
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
|
|
|
|
Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
|
2021-11-19 20:18:41 +08:00
|
|
|
|
Value result = createSoftmaxBackwardCommonKernel(
|
|
|
|
|
rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
|
|
|
|
|
if (!result)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op,
|
|
|
|
|
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
|
|
|
|
|
rewriter.replaceOp(op, result);
|
2021-11-19 02:02:20 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-25 16:53:55 +08:00
|
|
|
|
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenArgmaxOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.self();
|
|
|
|
|
Value dim = op.dim();
|
|
|
|
|
Value keepDim = op.keepdim();
|
|
|
|
|
Value result = op.result();
|
|
|
|
|
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
if (!indicesTensorType.hasSizes())
|
|
|
|
|
return failure();
|
|
|
|
|
BaseTensorType valueTensorType =
|
|
|
|
|
inputType
|
|
|
|
|
.getWithSizesAndDtype(indicesTensorType.getSizes(),
|
|
|
|
|
inputType.getDtype())
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
|
|
|
|
|
// `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input
|
|
|
|
|
// tensor is flattened to 1d tensor and then the reduction happens on the
|
|
|
|
|
// 0th dimension.
|
|
|
|
|
if (dim.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
BaseTensorType flattenType =
|
|
|
|
|
inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value end = rewriter.create<ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(getTensorRank(input) - 1));
|
|
|
|
|
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
|
|
|
|
|
dim, end);
|
|
|
|
|
}
|
|
|
|
|
Value maxResult =
|
|
|
|
|
rewriter
|
|
|
|
|
.create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
|
|
|
|
|
input, dim, keepDim)
|
|
|
|
|
.indices();
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, maxResult);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-11 16:39:34 +08:00
|
|
|
|
// To avoid overflow we use the following decomposition rule:
|
|
|
|
|
// x_max = aten.max(x, dim, keepdim=True)[0]
|
|
|
|
|
// shifted = x - x_max
|
|
|
|
|
// shifted_logsumexp = aten.log(aten.sum(aten.exp(shifted), dim, keepdim=True))
|
|
|
|
|
// log_softmax = shifted - shifted_logsumexp
|
|
|
|
|
template <typename OpTy>
|
|
|
|
|
static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value dim = op.dim();
|
|
|
|
|
Value self = op.self();
|
|
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
Value xMax =
|
|
|
|
|
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
|
|
|
|
|
if (!xMax)
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
|
|
Value shifted = createTensorSub(rewriter, loc, tensorType, self, xMax);
|
|
|
|
|
Value shiftedExp = rewriter.create<AtenExpOp>(loc, tensorType, shifted);
|
|
|
|
|
Value shiftedSumExp =
|
|
|
|
|
createSumAlongDimension(rewriter, loc, op, shiftedExp, dim,
|
|
|
|
|
/*keepDim=*/true);
|
|
|
|
|
if (!shiftedSumExp)
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
|
|
Value shiftedLogSumExp =
|
|
|
|
|
rewriter.create<AtenLogOp>(loc, shiftedSumExp.getType(), shiftedSumExp);
|
|
|
|
|
Value result =
|
|
|
|
|
createTensorSub(rewriter, loc, op.getType(), shifted, shiftedLogSumExp);
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
2021-11-03 01:06:04 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenLogSoftmaxIntOp
|
|
|
|
|
: public OpRewritePattern<AtenLogSoftmaxIntOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value self = op.self();
|
|
|
|
|
if (!op.dtype().getType().isa<Torch::NoneType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented non-None dtype for log_softmax");
|
|
|
|
|
|
|
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
2022-02-11 16:39:34 +08:00
|
|
|
|
Value logSoftmax = getLogSoftmaxResult(op, rewriter);
|
|
|
|
|
if (!logSoftmax)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "getLogSoftmaxResult function returned nullptr");
|
|
|
|
|
rewriter.replaceOp(op, logSoftmax);
|
2021-11-03 01:06:04 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-10 15:05:23 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_LogSoftmaxOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-02-11 16:39:34 +08:00
|
|
|
|
bool halfToFloat;
|
|
|
|
|
if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected a boolean value for half_to_float");
|
|
|
|
|
|
|
|
|
|
// Currently, setting `halfToFloat` is not supported as the E2E testing for
|
|
|
|
|
// the same is not present on CPU.
|
|
|
|
|
if (halfToFloat)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "halfToFloat is currently not supported.");
|
|
|
|
|
Value _logSoftmax = getLogSoftmaxResult(op, rewriter);
|
|
|
|
|
if (!_logSoftmax)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "getLogSoftmaxResult function returned nullptr");
|
|
|
|
|
rewriter.replaceOp(op, _logSoftmax);
|
2022-02-10 15:05:23 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks.
|
2021-10-21 13:15:10 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenMatmulOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value lhs = op.self();
|
|
|
|
|
Value rhs = op.other();
|
|
|
|
|
|
|
|
|
|
int lhsRank = getTensorRank(lhs);
|
|
|
|
|
int rhsRank = getTensorRank(rhs);
|
|
|
|
|
|
|
|
|
|
// If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
|
|
|
|
|
if (lhsRank == 2 && rhsRank == 2)
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
|
|
|
|
|
|
|
|
|
|
// If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
|
|
|
|
|
if (lhsRank == 3 && rhsRank == 3)
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
|
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
|
|
|
|
|
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
|
|
|
|
Value input) {
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
2022-02-09 04:57:23 +08:00
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
|
2022-02-09 04:57:23 +08:00
|
|
|
|
Value cst6 =
|
2022-02-15 21:14:32 +08:00
|
|
|
|
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(6));
|
2022-02-09 04:57:23 +08:00
|
|
|
|
Value sixTensor = createRank0Tensor(rewriter, loc, inputType, cst6);
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Value relu6Out =
|
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, inputType, relu, sixTensor);
|
|
|
|
|
return relu6Out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Hardswish(x) = x * Relu6(x+3)/6
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenHardswishOp : public OpRewritePattern<AtenHardswishOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenHardswishOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.self();
|
|
|
|
|
Type inputType = input.getType();
|
|
|
|
|
|
|
|
|
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value constantThree = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(3));
|
|
|
|
|
Value constantSix = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(6));
|
|
|
|
|
Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
|
|
|
|
|
loc, inputType, input, constantThree, /*alpha=*/constantOne);
|
|
|
|
|
Value relu6 = getRelu6Results(rewriter, loc, inputPlusThree);
|
|
|
|
|
Value divTensor =
|
|
|
|
|
rewriter.create<AtenDivScalarOp>(loc, inputType, relu6, constantSix);
|
|
|
|
|
Value mulTensor =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, inputType, divTensor, input);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, mulTensor);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-17 12:08:07 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decomposes `aten.t`:
//   - rank 0 or 1: the op is replaced by its operand unchanged;
//   - rank 2:      the op becomes `aten.transpose.int(self, 0, 1)`;
//   - otherwise:   the pattern does not match.
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.self();
    int lhsRank = getTensorRank(lhs);

    // A negative rank means the rank is unknown; report that explicitly
    // instead of claiming the tensor is "-1D".
    if (lhsRank < 0)
      return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");

    if (lhsRank > 2) {
      std::string errorMessage =
          "t() expects a tensor with <=2 dimensions, but self is " +
          std::to_string(lhsRank) + "D";
      return rewriter.notifyMatchFailure(op, errorMessage.c_str());
    }

    // Rank 0 and rank 1 tensors are returned as is.
    if (lhsRank < 2) {
      rewriter.replaceOp(op, lhs);
      return success();
    }

    // Rank 2: swap dimensions 0 and 1.
    Location loc = op.getLoc();
    Value zero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value one =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp<AtenTransposeIntOp>(op, op.getType(), lhs,
                                                    zero, one);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-07-01 13:02:31 +08:00
|
|
|
|
// Decompose aten.repeat into aten.expand and aten.view ops.
|
|
|
|
|
//
|
|
|
|
|
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
|
|
|
|
|
//
|
|
|
|
|
// For shape [S1, S2, S3] and repeats [M0, M1, M2, M3]
|
|
|
|
|
// MS0 = M0; MS1 = M1 * S1; MS2 = M2 * S2; MS3 = M3 * S3
|
|
|
|
|
//
|
|
|
|
|
// def aten_repeat(self, repeats):
|
|
|
|
|
// sizes = self.size()
|
|
|
|
|
// unsqueezed_sizes = []
|
|
|
|
|
// expanded_sizes = []
|
|
|
|
|
//      reshaped_sizes = []
|
|
|
|
|
// leading_rank = repeats.size() - sizes.size()
|
|
|
|
|
// for r in range(leading_rank):
|
|
|
|
|
// unsqueezed_sizes.append(1)
|
|
|
|
|
// expanded_sizes.append(repeats[r])
|
|
|
|
|
// reshaped_sizes.append(repeats[r])
|
|
|
|
|
//
|
|
|
|
|
// for s, m in zip(sizes, repeats[leading_rank:]):
|
|
|
|
|
// unsqueezed_sizes += [1, s]
|
|
|
|
|
// expanded_sizes += [m, s]
|
|
|
|
|
// reshaped_sizes += [m * s]
|
|
|
|
|
// return self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes)
|
|
|
|
|
//
|
|
|
|
|
namespace {
|
|
|
|
|
// Decomposes `aten.repeat` into view -> broadcast_to -> view:
//   1. view `self` with a unit dimension inserted before every original
//      dimension (plus one unit dimension per leading repeat),
//   2. broadcast each unit dimension to the matching repeat count,
//   3. view the result so each (repeat, size) pair collapses to repeat * size.
//
// The pattern fails to match when `self` is unranked or not a value tensor,
// when `repeats` is not built by a list construct, or when `repeats` has
// fewer entries than `self` has dimensions.
class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenRepeatOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    MLIRContext *context = op.getContext();
    int rank = getTensorRank(self);
    if (rank < 0)
      return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");

    // The intermediate types below are built as value tensor types, so bail
    // out before creating any IR if `self` is not one.
    auto selfType = self.getType().dyn_cast<ValueTensorType>();
    if (!selfType)
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: self is not a value tensor");

    SmallVector<Value> repeats;
    if (!getListConstructElements(op.repeats(), repeats))
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: repeats not list of Scalar");

    if (rank > static_cast<int>(repeats.size())) {
      return rewriter.notifyMatchFailure(
          op, "repeats are not matched with self's rank");
    }

    // Appends `vals` to `dimSizes` and records, for each value, its constant
    // integer if it is one, or a dynamic extent otherwise, in `shape`.
    auto insertDimSizes = [](SmallVector<Value> &dimSizes,
                             SmallVector<int64_t> &shape,
                             ArrayRef<Value> vals) {
      dimSizes.append(vals.begin(), vals.end());
      for (Value val : vals) {
        int64_t extent;
        if (!matchPattern(val, m_TorchConstantInt(&extent)))
          extent = ShapedType::kDynamicSize;
        shape.push_back(extent);
      }
    };

    Value one = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));

    SmallVector<Value> unsqueezedSizes, expandedSizes, reshapedSizes;
    SmallVector<int64_t> unsqueezedIntSizes, expandedIntSizes;

    // Number of repeat entries that have no matching dimension in `self`.
    // The subtraction cannot wrap: rank <= repeats.size() was checked above.
    const size_t leadingRank = repeats.size() - static_cast<size_t>(rank);
    for (size_t i = 0; i < leadingRank; ++i) {
      insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef<Value>{one});
      insertDimSizes(expandedSizes, expandedIntSizes,
                     ArrayRef<Value>{repeats[i]});
      reshapedSizes.push_back(repeats[i]);
    }

    ArrayRef<int64_t> selfShape = selfType.getSizes();
    for (int i = 0; i < rank; i++) {
      Value scale = repeats[i + leadingRank];

      // Use a constant for statically known sizes, `aten.size.int` otherwise.
      Value dimSize;
      if (selfShape[i] == ShapedType::kDynamicSize) {
        Value dim = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(i));
        dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
      } else {
        dimSize = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(selfShape[i]));
      }

      insertDimSizes(unsqueezedSizes, unsqueezedIntSizes,
                     ArrayRef<Value>{one, dimSize});
      insertDimSizes(expandedSizes, expandedIntSizes,
                     ArrayRef<Value>{scale, dimSize});

      Value scaledSize = rewriter.create<AtenMulIntOp>(loc, dimSize, scale);
      reshapedSizes.push_back(scaledSize);
    }

    Type dtype = selfType.getDtype();
    Type unsqueezedType = ValueTensorType::get(
        context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
    Type expandedType = ValueTensorType::get(
        context, llvm::makeArrayRef(expandedIntSizes), dtype);

    auto listType = Torch::ListType::get(Torch::IntType::get(context));
    Value unsqueezedDims =
        rewriter.create<PrimListConstructOp>(loc, listType, unsqueezedSizes);
    Value expandedDims =
        rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
    Value reshapedDims =
        rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
    auto reshaped =
        rewriter.create<AtenViewOp>(loc, unsqueezedType, self, unsqueezedDims);
    auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
                                                       reshaped, expandedDims);

    rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), expanded,
                                            reshapedDims);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
// Decompose aten.expand into aten.broadcast_to op.
|
2021-11-03 00:48:29 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Replaces `aten.expand` with `aten.broadcast_to` on the same operand and
// size list. Only matches when `implicit` is the constant `false`.
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenExpandOp op,
                                PatternRewriter &rewriter) const override {
    bool isImplicit = false;
    const bool isConstant =
        matchPattern(op.implicit(), m_TorchConstantBool(&isImplicit));
    if (!isConstant || isImplicit)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: requires implicit to be false");

    Value broadcast = rewriter.create<AtenBroadcastToOp>(
        op.getLoc(), op.getType(), op.self(), op.size());
    rewriter.replaceOp(op, broadcast);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-12 01:21:36 +08:00
|
|
|
|
// Decompose aten.where.Scalar into aten.where.self op.
|
|
|
|
|
namespace {
|
|
|
|
|
// Rewrites `aten.where.Scalar` as `aten.where.self` by first wrapping both
// scalar operands (`self`, then `other`) with `createRank0Tensor`.
class DecomposeAtenWhereScalarOp : public OpRewritePattern<AtenWhereScalarOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenWhereScalarOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto resultTy = op.getType().cast<BaseTensorType>();

    Value lhsTensor = createRank0Tensor(rewriter, loc, resultTy, op.self());
    Value rhsTensor = createRank0Tensor(rewriter, loc, resultTy, op.other());

    Value select = rewriter.create<AtenWhereSelfOp>(
        loc, resultTy, op.condition(), lhsTensor, rhsTensor);
    rewriter.replaceOp(op, select);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// Decompose aten.where.ScalarOther into aten.where.self op.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenWhereScalarOtherOp
|
|
|
|
|
: public OpRewritePattern<AtenWhereScalarOtherOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.other());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.condition(),
|
|
|
|
|
op.self(), otherTensor);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// Decompose aten.where.ScalarSelf into aten.where.self op.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenWhereScalarSelfOp
|
|
|
|
|
: public OpRewritePattern<AtenWhereScalarSelfOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.self());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.condition(),
|
|
|
|
|
selfTensor, op.other());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-04-08 12:47:57 +08:00
|
|
|
|
// Decompose aten.convolution_overrideable to aten.convolution
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenConvolutionOverrideableOp
|
|
|
|
|
: public OpRewritePattern<AtenConvolutionOverrideableOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenConvolutionOverrideableOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
|
|
|
|
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
|
|
|
|
|
op.stride(), op.padding(), op.dilation(), op.transposed(),
|
|
|
|
|
op.output_padding(), op.groups());
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-07-08 14:44:03 +08:00
|
|
|
|
// Decompose aten._convolution to aten.convolution
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_ConvolutionOp
|
|
|
|
|
: public OpRewritePattern<Aten_ConvolutionOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_ConvolutionOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
|
|
|
|
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
|
|
|
|
|
op.stride(), op.padding(), op.dilation(), op.transposed(),
|
|
|
|
|
op.output_padding(), op.groups());
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-04-08 12:47:57 +08:00
|
|
|
|
// Decompose aten.conv2d to aten.convolution
|
|
|
|
|
namespace {
|
|
|
|
|
// Replaces `aten.conv2d` with `aten.convolution`, supplying the two operands
// conv2d does not have: `transposed = false` and an empty `output_padding`
// list.
class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenConv2dOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto intListTy = Torch::ListType::get(Torch::IntType::get(op.getContext()));

    Value noOutputPadding = rewriter.create<PrimListConstructOp>(
        loc, intListTy, SmallVector<Value>());
    Value notTransposed = rewriter.create<Torch::ConstantBoolOp>(loc, false);

    auto convolution = rewriter.create<AtenConvolutionOp>(
        loc, op->getResultTypes(), op.input(), op.weight(), op.bias(),
        op.stride(), op.padding(), op.dilation(), notTransposed,
        noOutputPadding, op.groups());
    rewriter.replaceOp(op, convolution->getResults());
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
|
2021-11-11 17:02:13 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decomposes `aten.addmm(self, mat1, mat2, beta, alpha)` into
//   matmul      = aten.mm(mat1, mat2)
//   scaledInput = aten.mul.Scalar(self, beta)
//   result      = aten.add.Tensor(scaledInput, matmul, alpha)
//
// The pattern fails to match unless `mat1` and `mat2` are rank 2 and `self`
// has a known floating-point dtype.
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenAddmmOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    Value mat1 = op.mat1();
    Value mat2 = op.mat2();

    // The operands `mat1`, `mat2` to aten.addmm must be of rank 2.
    if (getTensorRank(mat1) != 2 || getTensorRank(mat2) != 2) {
      return rewriter.notifyMatchFailure(
          op, "expected mat1, mat2 operands to aten.addmm to be rank 2");
    }

    // TODO: Handle integer type operands.
    // Check `hasDtype()` first so a tensor with unknown dtype is rejected
    // rather than queried, matching the other patterns in this file.
    auto inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasDtype() ||
        !inputType.getDtype().isa<mlir::FloatType>()) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: non-floating point dtype");
    }

    // matrix multiplication: matmul = mat1 @ mat2
    Value matmul = rewriter.create<AtenMmOp>(loc, op.getType(), mat1, mat2);
    // scaledInput = self * beta
    Value scaledInput = rewriter.create<AtenMulScalarOp>(loc, input.getType(),
                                                         input, op.beta());
    // result = scaledInput + alpha * matmul
    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), scaledInput,
                                                 matmul, op.alpha());
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
// Decompose aten.mean into: sum(x) / numTensorElements.
|
2021-11-19 23:59:29 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Rewrites `aten.mean` as `aten.sum(self, dtype)` divided (via
// `aten.div.Scalar`) by `aten.numel(self)`. Both new tensor ops carry the
// original op's result type.
class DecomposeAtenMeanOp : public OpRewritePattern<AtenMeanOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMeanOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    auto resultTy = op.result().getType().cast<BaseTensorType>();

    Value total = rewriter.create<AtenSumOp>(loc, resultTy, self, op.dtype());
    Value elementCount = rewriter.create<AtenNumelOp>(loc, self);

    Value quotient =
        rewriter.create<AtenDivScalarOp>(loc, resultTy, total, elementCount);
    rewriter.replaceOp(op, quotient);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-11 01:25:21 +08:00
|
|
|
|
// productDimSize = product(size(dim) for dim in dims)
|
|
|
|
|
// aten.mean(x, dims) = aten.sum(x, dims) / productDimSize.
|
|
|
|
|
namespace {
|
|
|
|
|
// Decomposes `aten.mean.dim(x, dims, keepdim, dtype)` into
//   aten.sum.dim_IntList(x, dims, keepdim, dtype) / productDimSize
// where `productDimSize` is the product of the sizes of the reduced
// dimensions.
//
// The pattern fails to match unless the input has a known floating-point
// dtype, `dtype` is None or a constant float dtype, and `dims` is built by a
// list construct.
class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMeanDimOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    Value dimList = op.dim();
    Value keepDim = op.keepdim();
    Value dtype = op.dtype();
    Type outputType = op.getType();
    MLIRContext *context = op.getContext();

    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
        !isNoneOrFloatDtype(context, dtype)) {
      return rewriter.notifyMatchFailure(
          op, "only floating-point type is supported");
    }

    auto dimListConstruct = dimList.getDefiningOp<PrimListConstructOp>();
    if (!dimListConstruct) {
      return rewriter.notifyMatchFailure(
          op, "expect dimList to be constructed from list construct");
    }

    // Compute sum along dimensions specified in `dimList`.
    Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
        loc, outputType, input, dimList, keepDim, dtype);

    // `productDimSize` is product of sizes of dimensions to be reduced.
    auto reducedDims = dimListConstruct.elements();
    Value productDimSize;
    if (reducedDims.empty()) {
      // An empty `dim` list means "reduce over every dimension", so the
      // divisor is the total element count rather than the empty product 1.
      productDimSize = rewriter.create<AtenNumelOp>(loc, input);
    } else {
      productDimSize = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      for (Value dim : reducedDims) {
        Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
        productDimSize =
            rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
      }
    }
    rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
                                                 productDimSize);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Rewrites `aten.square(x)` as `aten.mul.Tensor(x, x)`.
class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSquareOp op,
                                PatternRewriter &rewriter) const override {
    Value operand = op.self();
    Value product = rewriter.create<AtenMulTensorOp>(op.getLoc(), op.getType(),
                                                     operand, operand);
    rewriter.replaceOp(op, product);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-02 01:30:58 +08:00
|
|
|
|
// Silu(x) = sigmoid(x) * x
|
|
|
|
|
namespace {
|
|
|
|
|
// Rewrites `aten.silu(x)` as `aten.mul.Tensor(aten.sigmoid(x), x)`.
class DecomposeAtenSiluOp : public OpRewritePattern<AtenSiluOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSiluOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type resultTy = op.getType();
    Value operand = op.self();

    Value gate = rewriter.create<AtenSigmoidOp>(loc, resultTy, operand);
    Value gated = rewriter.create<AtenMulTensorOp>(loc, resultTy, gate, operand);
    rewriter.replaceOp(op, gated);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-17 00:34:03 +08:00
|
|
|
|
// pDash = 1.0 - p
|
|
|
|
|
// boolMask = aten.rand_like(input) < pDash
|
|
|
|
|
// dropout(input, p, train=True) = (boolMask * input) / pDash
|
|
|
|
|
// dropout(input, p, train=False) = input
|
|
|
|
|
namespace {
|
|
|
|
|
// Decomposes `aten.dropout(input, p, train)`:
//   - train == false: the op is replaced by `input`;
//   - train == true:  mask = bernoulli(input, 1 - p), and the result is
//                     (mask * input) / (1 - p).
// `train` must be a boolean constant; in training mode the input must have a
// known floating-point dtype.
class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenDropoutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.input();
    Value dropProb = op.p();

    bool isTraining = false;
    if (!matchPattern(op.train(), m_TorchConstantBool(&isTraining)))
      return rewriter.notifyMatchFailure(op,
                                         "train must be a boolean constant");

    // Inference mode: dropout is the identity.
    if (!isTraining) {
      rewriter.replaceOp(op, self);
      return success();
    }

    auto selfType = self.getType().cast<BaseTensorType>();
    const bool isFloatInput =
        selfType.hasDtype() && selfType.getDtype().isa<mlir::FloatType>();
    if (!isFloatInput)
      return rewriter.notifyMatchFailure(
          op, "only support floating type input for training mode");

    Value cstNone = rewriter.create<ConstantNoneOp>(loc);
    Value cstFloatOne =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
    // keepProb = 1.0 - p
    Value keepProb = rewriter.create<AtenSubFloatOp>(loc, cstFloatOne, dropProb);
    Value mask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
        loc, selfType, self, keepProb, /*generator=*/cstNone);
    Value masked = rewriter.create<AtenMulTensorOp>(loc, selfType, mask, self);

    Value rescaled =
        rewriter.create<AtenDivScalarOp>(loc, op.getType(), masked, keepProb);
    rewriter.replaceOp(op, rescaled);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
// Decompose aten.var into: sum(square(x - mean))/(numTensorElements-1)
|
|
|
|
|
// for unbiased and mean(square(x - mean)) for biased case.
|
|
|
|
|
namespace {
|
|
|
|
|
// Decomposes `aten.var(self, unbiased)` over the whole tensor:
//   square = (self - mean(self))^2
//   unbiased: sum(square) / (numel(square) - 1)
//   biased:   mean(square)
//
// The pattern fails to match unless the input has a known floating-point
// dtype, the result type is a rank 0 tensor, and `unbiased` is a boolean
// constant.
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenVarOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
    if (!inputTensorTy.hasDtype() ||
        !inputTensorTy.getDtype().isa<mlir::FloatType>()) {
      return rewriter.notifyMatchFailure(op,
                                         "Only aten.var support floating type");
    }
    BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
    if (!rank0FloatTensorTy.hasSizes() ||
        rank0FloatTensorTy.getSizes().size() != 0) {
      return rewriter.notifyMatchFailure(
          op, "expected aten.var to have a rank 0 tensor type");
    }

    bool unbiased = false;
    if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) {
      return rewriter.notifyMatchFailure(
          op, "Only support constant unbiased for aten.var");
    }

    // `dtype` is None for the mean/sum reductions created below.
    Value dtype = rewriter.create<ConstantNoneOp>(loc);
    Value mean =
        rewriter.create<AtenMeanOp>(loc, rank0FloatTensorTy, self, dtype);
    Value subMean = createTensorSub(rewriter, loc, inputTensorTy, self, mean);
    Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);

    if (!unbiased) {
      // Biased estimator: plain mean of the squared deviations.
      rewriter.replaceOpWithNewOp<AtenMeanOp>(op, rank0FloatTensorTy, square,
                                              dtype);
      return success();
    }

    // Bessel's correction is used. Divide the square sum by
    // numTensorElements-1.
    Value squareSum =
        rewriter.create<AtenSumOp>(loc, rank0FloatTensorTy, square, dtype);
    Value numTensorElements = rewriter.create<AtenNumelOp>(loc, square);
    Value cst1 = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value numTensorElementsSub1 =
        rewriter.create<AtenSubIntOp>(loc, numTensorElements, cst1);
    rewriter.replaceOpWithNewOp<AtenDivScalarOp>(
        op, rank0FloatTensorTy, squareSum, numTensorElementsSub1);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// Decompose aten.std to sqrt(var(x))
|
|
|
|
|
namespace {
|
|
|
|
|
// Rewrites `aten.std(self, unbiased)` as `aten.sqrt(aten.var(self, unbiased))`.
// Only matches when the input has a known floating-point dtype.
class DecomposeAtenStdOp : public OpRewritePattern<AtenStdOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenStdOp op,
                                PatternRewriter &rewriter) const override {
    Value operand = op.self();
    auto operandTy = operand.getType().cast<BaseTensorType>();
    const bool isFloatInput =
        operandTy.hasDtype() && operandTy.getDtype().isa<mlir::FloatType>();
    if (!isFloatInput)
      return rewriter.notifyMatchFailure(op,
                                         "Only aten.std support floating type");

    Location loc = op->getLoc();
    Type resultTy = op.getType();
    Value variance =
        rewriter.create<AtenVarOp>(loc, resultTy, operand, op.unbiased());
    Value stdDev = rewriter.create<AtenSqrtOp>(loc, resultTy, variance);
    rewriter.replaceOp(op, stdDev);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-14 22:46:44 +08:00
|
|
|
|
// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
|
|
|
|
|
namespace {
|
|
|
|
|
// Rewrites `aten.hardsigmoid(x)` as `max(0, min(1, (x + 3) / 6))`, with the
// bounds 0 and 1 wrapped by `createRank0Tensor`.
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenHardsigmoidOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    auto selfType = self.getType().cast<BaseTensorType>();

    // Materializes a `torch.constant.int` holding `value`.
    auto makeIntConstant = [&](int64_t value) -> Value {
      return rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(value));
    };

    // scaled = (self + 3) / 6.
    Value cstOne = makeIntConstant(1);
    Value cstThree = makeIntConstant(3);
    Value cstSix = makeIntConstant(6);
    Value shifted = rewriter.create<AtenAddScalarOp>(
        loc, selfType, self, cstThree, /*alpha=*/cstOne);
    Value scaled =
        rewriter.create<AtenDivScalarOp>(loc, selfType, shifted, cstSix);

    // Clamp from above by 1, then from below by 0.
    Value cstZero = makeIntConstant(0);
    Value upperBound = createRank0Tensor(rewriter, loc, selfType, cstOne);
    Value cappedAbove =
        rewriter.create<AtenMinimumOp>(loc, selfType, upperBound, scaled);
    Value lowerBound = createRank0Tensor(rewriter, loc, selfType, cstZero);

    Value clamped = rewriter.create<AtenMaximumOp>(loc, op.getType(),
                                                   lowerBound, cappedAbove);
    rewriter.replaceOp(op, clamped);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-09 04:57:23 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Rewrites `aten.hardtanh(x, min_val, max_val)` as
// `min(max_val, max(x, min_val))`, with both bounds wrapped by
// `createRank0Tensor`.
class DecomposeAtenHardtanhOp : public OpRewritePattern<AtenHardtanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenHardtanhOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    auto selfType = self.getType().cast<BaseTensorType>();

    // Clamp from below first ...
    Value lowerBound = createRank0Tensor(rewriter, loc, selfType, op.min_val());
    Value clampedBelow =
        rewriter.create<AtenMaximumOp>(loc, selfType, self, lowerBound);

    // ... then from above.
    Value upperBound = createRank0Tensor(rewriter, loc, selfType, op.max_val());
    Value clamped = rewriter.create<AtenMinimumOp>(loc, op.getType(),
                                                   upperBound, clampedBelow);
    rewriter.replaceOp(op, clamped);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Rewrites `aten.rand_like` as `aten.empty_like` followed by a uniform fill
// over [0.0, 1.0] with no generator. Only matches when the input has a known
// floating-point dtype.
class DecomposeAtenRandLikeOp : public OpRewritePattern<AtenRandLikeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenRandLikeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    Type resultTy = op.getType();

    auto selfType = self.getType().cast<BaseTensorType>();
    const bool isFloatInput =
        selfType.hasDtype() && selfType.getDtype().isa<mlir::FloatType>();
    if (!isFloatInput)
      return rewriter.notifyMatchFailure(op,
                                         "only support floating-point type");

    // Bounds and generator for the uniform op: from = 0.0, to = 1.0, None.
    Value cstNone = rewriter.create<ConstantNoneOp>(loc);
    Value lowerBound =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
    Value upperBound =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));

    // The tensor to fill inherits all creation arguments of the original op.
    Value uninitialized = rewriter.create<AtenEmptyLikeOp>(
        loc, resultTy, self, op.dtype(), op.layout(), op.device(),
        op.pin_memory(), op.memory_format());

    Value filled = rewriter.create<ValsemVariantAtenUniformOp>(
        loc, resultTy, uninitialized, /*from=*/lowerBound, /*to=*/upperBound,
        /*generator=*/cstNone);
    rewriter.replaceOp(op, filled);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
// Bernoulli(x, p) = (rand_like(float(x)) < p).cast(type(x)). Here,
|
|
|
|
|
// 1. p must be a float tensor.
|
|
|
|
|
// 2. The shape of p should be broadcastable to the shape of x.
|
|
|
|
|
// 3. Bernoulli(x, p) returns a tensor of the same type as that of x.
|
2022-02-09 04:57:23 +08:00
|
|
|
|
static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
|
|
|
|
|
Operation *op, Location loc,
|
2022-02-26 00:35:04 +08:00
|
|
|
|
Value input, Value prob,
|
|
|
|
|
Value &output) {
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
auto probType = prob.getType().cast<BaseTensorType>();
|
|
|
|
|
// Both the `input` and `prob` must be ranked tensors.
|
|
|
|
|
if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
|
|
|
|
|
!probType.hasDtype()) {
|
2022-02-09 04:57:23 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2022-02-26 00:35:04 +08:00
|
|
|
|
op, "can't decompose bernoulli like ops without sizes or dtype");
|
|
|
|
|
}
|
|
|
|
|
// The `prob` is expected to be a float type tensor.
|
|
|
|
|
if (!probType.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "probabilities must be a float type tensor");
|
2022-02-09 04:57:23 +08:00
|
|
|
|
}
|
2022-02-04 19:43:25 +08:00
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// Since the `aten.rand_like` op expects float-type operand, create a
|
|
|
|
|
// float-type tensor with the same shape as that of the `input`.
|
|
|
|
|
Value floatTensor =
|
|
|
|
|
convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value randomVal = rewriter.create<AtenRandLikeOp>(
|
|
|
|
|
loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
|
|
|
|
|
|
|
|
|
|
// Bernoulli(x, p) = rand_like(float(x)) < p.
|
|
|
|
|
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
|
|
|
|
|
rewriter.getI1Type());
|
|
|
|
|
Value lessThanP =
|
|
|
|
|
rewriter.create<AtenLtTensorOp>(loc, boolResType, randomVal, prob);
|
|
|
|
|
|
|
|
|
|
// As the `output` is expected to be of the `input` type, convert the boolean
|
|
|
|
|
// tensor `lessThanP` to a `input` type tensor.
|
|
|
|
|
output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype());
|
2022-02-09 04:57:23 +08:00
|
|
|
|
return success();
|
2022-02-04 19:43:25 +08:00
|
|
|
|
}
|
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// aten.bernoulli(x) = rand_like(x) < x. Here, the input x is a tensor
|
|
|
|
|
// containing probabilities to be used for drawing the binary random number.
|
2022-02-04 19:43:25 +08:00
|
|
|
|
// Decomposes `aten.bernoulli(self)` through `decomposeBernoulliLikeOp`, using
// `self` both as the input and as the probability tensor. Only matches when
// the generator operand is None.
class DecomposeAtenBernoulliOp : public OpRewritePattern<AtenBernoulliOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenBernoulliOp op,
                                PatternRewriter &rewriter) const override {
    const bool usesDefaultGenerator =
        op.generator().getType().isa<Torch::NoneType>();
    if (!usesDefaultGenerator)
      return rewriter.notifyMatchFailure(
          op, "The generator has to ben None because only global default "
              "generator is supported");

    Value probabilities = op.self();
    Value sampled;
    LogicalResult status = decomposeBernoulliLikeOp(
        rewriter, op, op.getLoc(), probabilities, probabilities, sampled);
    if (failed(status))
      return rewriter.notifyMatchFailure(
          op, "decomposeBernoulliLikeOp failed to decompose the op");

    rewriter.replaceOp(op, sampled);
    return success();
  }
};
|
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// aten.bernoulli.float(x, p) = (rand_like(float(x)) < tensor(p)).cast(type(x)).
|
|
|
|
|
// Since the input x can be an integer tensor, it's important to cast it to
|
|
|
|
|
// float type before passing it to the `aten.rand_like` op.
|
2022-03-16 07:57:33 +08:00
|
|
|
|
class DecomposeValsemVariantAtenBernoulliFloatOp
|
|
|
|
|
: public OpRewritePattern<ValsemVariantAtenBernoulliFloatOp> {
|
2022-02-04 19:43:25 +08:00
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
2022-03-16 07:57:33 +08:00
|
|
|
|
LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliFloatOp op,
|
2022-02-04 19:43:25 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-02-26 00:35:04 +08:00
|
|
|
|
Value input = op.self();
|
|
|
|
|
Value p = op.p();
|
|
|
|
|
if (!op.generator().getType().isa<Torch::NoneType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "The generator has to ben None because only global default "
|
|
|
|
|
"generator is supported");
|
2022-02-04 19:43:25 +08:00
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
SmallVector<int64_t> empty;
|
|
|
|
|
Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
|
|
|
|
|
rewriter.getF64Type());
|
|
|
|
|
Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
|
|
|
|
|
Value output;
|
|
|
|
|
if (failed(
|
|
|
|
|
decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "decomposeBernoulliLikeOp failed to decompose the op");
|
|
|
|
|
rewriter.replaceOp(op, output);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// aten.bernoulli.Tensor(x, p) = (rand_like(float(x)) < p).cast(type(x)).
|
|
|
|
|
// Since the input x can be an integer tensor, it's important to cast it to
|
|
|
|
|
// float type before passing it to the `aten.rand_like` op.
|
2022-03-16 07:57:33 +08:00
|
|
|
|
class DecomposeValsemVariantAtenBernoulliTensorOp
|
|
|
|
|
: public OpRewritePattern<ValsemVariantAtenBernoulliTensorOp> {
|
2022-02-26 00:35:04 +08:00
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
2022-03-16 07:57:33 +08:00
|
|
|
|
LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliTensorOp op,
|
2022-02-26 00:35:04 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.self();
|
|
|
|
|
Value prob = op.p();
|
|
|
|
|
if (!op.generator().getType().isa<Torch::NoneType>())
|
2022-02-04 19:43:25 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "The generator has to ben None because only global default "
|
|
|
|
|
"generator is supported");
|
2022-02-26 00:35:04 +08:00
|
|
|
|
Value output;
|
|
|
|
|
if (failed(
|
|
|
|
|
decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "decomposeBernoulliLikeOp failed to decompose the op");
|
|
|
|
|
rewriter.replaceOp(op, output);
|
2022-02-04 19:43:25 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-25 06:01:48 +08:00
|
|
|
|
namespace {
|
2022-02-15 21:14:32 +08:00
|
|
|
|
template <typename OpTy, typename T1T2Op>
|
2021-11-25 06:01:48 +08:00
|
|
|
|
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
|
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
2022-02-15 21:14:32 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
2021-11-25 06:01:48 +08:00
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.self();
|
|
|
|
|
Value tensor1 = op.tensor1();
|
|
|
|
|
Value tensor2 = op.tensor2();
|
|
|
|
|
Value value = op.value();
|
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Value product =
|
|
|
|
|
rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input,
|
|
|
|
|
product, value);
|
2021-11-25 06:01:48 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
2021-12-10 21:36:19 +08:00
|
|
|
|
|
|
|
|
|
class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
  using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;

  // Rewrites `AtenLayerNormOp` into `AtenNativeLayerNormOp` and replaces the
  // original op with the first result of the new op. The second and third
  // result types are computed here: they keep the leading input sizes and use
  // size 1 for every dimension covered by `normalized_shape`.
  //
  // Fails to match when the input has no known sizes, when `normalized_shape`
  // cannot be read as a list of elements, or when it has more elements than
  // the input has dimensions.
  LogicalResult matchAndRewrite(AtenLayerNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    auto input = op.input().getType().cast<BaseTensorType>();
    if (!input.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "input tensor should have known sizes.");
    ArrayRef<int64_t> inputSizes = input.getSizes();
    int64_t inputRank = inputSizes.size();

    // The number of normalized dimensions is taken from the element count of
    // the `normalized_shape` list, so the list must be readable here.
    SmallVector<Value> normalizedShapeSizesTorchInt;
    if (!getListConstructElements(op.normalized_shape(),
                                  normalizedShapeSizesTorchInt))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: normalized_shape must come from a list "
              "construct op");
    int64_t normalizedRank = normalizedShapeSizesTorchInt.size();
    if (normalizedRank > inputRank)
      return rewriter.notifyMatchFailure(
          op, "normalized_shape has more dimensions than the input tensor");

    // Dimensions [0, axis) are kept; dimensions [axis, inputRank) become 1.
    int64_t axis = inputRank - normalizedRank;
    std::vector<int64_t> meanVarSizes(inputRank, 1);
    for (int64_t i = 0; i < axis; i++)
      meanVarSizes[i] = inputSizes[i];

    auto meanVarType = input.getWithSizesAndDtype(
        llvm::makeArrayRef(meanVarSizes), input.getDtype());
    auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
        loc, op.getType(), meanVarType, meanVarType, op.input(),
        op.normalized_shape(), op.weight(), op.bias(), op.eps());
    // Only the first result replaces the original op.
    rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
    return success();
  }
};
|
2021-11-25 06:01:48 +08:00
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-14 03:01:10 +08:00
|
|
|
|
namespace {
|
2021-12-21 19:51:19 +08:00
|
|
|
|
// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
|
2021-12-14 03:01:10 +08:00
|
|
|
|
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Rewrites `op` as `AtenEmptyMemoryFormatOp` whose size operand is the
  // `AtenSizeOp` of `self`; every other operand is forwarded unchanged.
  LogicalResult matchAndRewrite(AtenEmptyLikeOp op,
                                PatternRewriter &rewriter) const override {
    // Result type of the size query: a list of torch ints.
    auto intListType =
        Torch::ListType::get(Torch::IntType::get(op.getContext()));
    Value selfSizes =
        rewriter.create<AtenSizeOp>(op.getLoc(), intListType, op.self());

    rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
        op, op.getType(), selfSizes, op.dtype(), op.layout(), op.device(),
        op.pin_memory(), op.memory_format());
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-23 21:22:45 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// The `aten.arange` op is converted to `aten.arange.start_step` op.
|
|
|
|
|
class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
  using OpRewritePattern::OpRewritePattern;

  // Rewrites `op` as `AtenArangeStartStepOp`. `AtenArangeOp` carries neither
  // a start nor a step operand, so constants 0 and 1 are materialized for
  // them; the remaining operands are forwarded unchanged.
  LogicalResult matchAndRewrite(AtenArangeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Value zeroStart = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Value unitStep = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));

    rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
        op, op.getType(), zeroStart, op.end(), unitStep, op.dtype(),
        op.layout(), op.device(), op.pin_memory());
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
// The `aten.arange.start` op is converted to `aten.arange.start_step` op.
|
|
|
|
|
class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
  using OpRewritePattern::OpRewritePattern;

  // Rewrites `op` as `AtenArangeStartStepOp`. `AtenArangeStartOp` carries no
  // step operand, so a constant 1 is materialized for it; the remaining
  // operands are forwarded unchanged.
  LogicalResult matchAndRewrite(AtenArangeStartOp op,
                                PatternRewriter &rewriter) const override {
    Value unitStep = rewriter.create<Torch::ConstantIntOp>(
        op.getLoc(), rewriter.getI64IntegerAttr(1));

    rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
        op, op.getType(), op.start(), op.end(), unitStep, op.dtype(),
        op.layout(), op.device(), op.pin_memory());
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-21 19:51:19 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose constant tensor allocation like ops.
|
|
|
|
|
template <typename OpTy, int fillVal>
|
|
|
|
|
class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
|
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
// Allocate a memory block.
|
|
|
|
|
Value initTensor = rewriter.create<AtenEmptyLikeOp>(
|
|
|
|
|
loc, op.getType(), op.self(), op.dtype(), op.layout(), op.device(),
|
|
|
|
|
op.pin_memory(), op.memory_format());
|
|
|
|
|
Value constVal = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(fillVal));
|
|
|
|
|
// Initialize the allocated memory block with `fillVal`.
|
2022-03-16 07:57:33 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
|
2022-02-16 03:58:03 +08:00
|
|
|
|
op, initTensor.getType(), initTensor, constVal);
|
2021-12-21 19:51:19 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-08 00:08:10 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenNativeBatchNormOp
|
|
|
|
|
: public OpRewritePattern<AtenNativeBatchNormOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNativeBatchNormOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
Value input = op.input();
|
|
|
|
|
Value weight = op.weight();
|
|
|
|
|
Value bias = op.bias();
|
|
|
|
|
Value runningMean = op.running_mean();
|
|
|
|
|
Value runningVar = op.running_var();
|
|
|
|
|
Value eps = op.eps();
|
|
|
|
|
|
|
|
|
|
// TODO: Add support for `training` mode.
|
|
|
|
|
bool training = false;
|
|
|
|
|
if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
|
|
|
|
|
training)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: training mode is not supported");
|
|
|
|
|
|
|
|
|
|
// Rank of the input tensor must be greater than or equal to 2. The shape of
|
|
|
|
|
// the `input` is supposed to be (N, C, D?, H?, W?).
|
|
|
|
|
int64_t inputRank = getTensorRank(input);
|
|
|
|
|
if (inputRank < 2)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "input must have rank greater than or equal to 2");
|
|
|
|
|
|
|
|
|
|
// In the inference mode, the `runningMean` and `runningVar` must not be
|
|
|
|
|
// None.
|
|
|
|
|
if (runningMean.getType().isa<Torch::NoneType>() ||
|
|
|
|
|
runningVar.getType().isa<Torch::NoneType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "running stats must not be None in inference mode");
|
|
|
|
|
|
|
|
|
|
// Rank of `runningMean` and `runningVar` must be exactly 1.
|
|
|
|
|
if (getTensorRank(runningMean) != 1 || getTensorRank(runningVar) != 1)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected running_mean and running_var to be rank 1");
|
|
|
|
|
|
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value numFeatures = rewriter.create<AtenSizeIntOp>(loc, input, /*dim=*/one);
|
2022-02-25 03:41:55 +08:00
|
|
|
|
// TODO: Add Runtime Asserts to check the shape of weight, bias,
|
|
|
|
|
// running_mean and running_var to be (numFeatures).
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
|
|
// The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?)
|
|
|
|
|
// to make it broadcast-compatible with (N, C, D?, H?, W?).
|
|
|
|
|
// 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?)
|
|
|
|
|
// 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?)
|
|
|
|
|
SmallVector<Value> runningStatsShape(inputRank, one);
|
|
|
|
|
runningStatsShape[1] = numFeatures;
|
|
|
|
|
Value runningStatsSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, ListType::get(IntType::get(context)), runningStatsShape);
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
|
|
|
|
|
runningStatsShapeInt[1] = ShapedType::kDynamicSize;
|
|
|
|
|
Type dtype = input.getType().cast<ValueTensorType>().getDtype();
|
|
|
|
|
Type reshapeType = ValueTensorType::get(
|
|
|
|
|
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
|
|
|
|
|
|
|
|
|
|
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
|
|
|
|
|
runningStatsSizeList);
|
|
|
|
|
runningVar = rewriter.create<AtenViewOp>(loc, reshapeType, runningVar,
|
|
|
|
|
runningStatsSizeList);
|
|
|
|
|
|
|
|
|
|
// normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
|
|
|
|
|
Value inputSubMean = rewriter.create<AtenSubTensorOp>(
|
|
|
|
|
loc, input.getType(), input, runningMean, /*alpha=*/one);
|
|
|
|
|
Value varEps = rewriter.create<AtenAddScalarOp>(
|
|
|
|
|
loc, runningVar.getType(), runningVar, eps, /*alpha=*/one);
|
|
|
|
|
Value invStd = rewriter.create<AtenRsqrtOp>(loc, varEps.getType(), varEps);
|
|
|
|
|
Value normalizedInput = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, inputSubMean.getType(), inputSubMean, invStd);
|
|
|
|
|
|
|
|
|
|
// The `weight` and `bias` must be reshaped to (1, C, 1?, 1?, 1?) to make it
|
|
|
|
|
// broadcast-compatible with (N, C, D?, H?, W?).
|
|
|
|
|
// 1. weight = weight.view(1, C, 1?, 1?, 1?)
|
|
|
|
|
// 2. bias = bias.view(1, C, 1?, 1?, 1?)
|
|
|
|
|
// 3. output = normalizedInput * weight + bias
|
|
|
|
|
Value batchNormOutput = normalizedInput;
|
|
|
|
|
if (!weight.getType().isa<Torch::NoneType>()) {
|
2022-02-25 03:41:55 +08:00
|
|
|
|
// Rank of `weight` must be exactly 1.
|
2022-02-08 00:08:10 +08:00
|
|
|
|
if (getTensorRank(weight) != 1)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
|
|
|
|
|
weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
|
|
|
|
|
runningStatsSizeList);
|
|
|
|
|
batchNormOutput = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, batchNormOutput.getType(), batchNormOutput, weight);
|
|
|
|
|
}
|
|
|
|
|
if (!bias.getType().isa<Torch::NoneType>()) {
|
2022-02-25 03:41:55 +08:00
|
|
|
|
// Rank of `bias` must be exactly 1.
|
2022-02-08 00:08:10 +08:00
|
|
|
|
if (getTensorRank(bias) != 1)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
|
|
|
|
|
bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
|
|
|
|
|
runningStatsSizeList);
|
|
|
|
|
batchNormOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
|
loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The `mean` and `invstd` outputs are empty tensors in inference mode.
|
|
|
|
|
Value zeroList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(zero.getType()), zero);
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value emptyMeanTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
|
|
|
|
loc, op.getType(1), zeroList, /*dtype=*/none, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
|
|
|
|
|
Value emptyInvStdTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
|
|
|
|
loc, op.getType(2), zeroList, /*dtype=*/none, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op,
|
|
|
|
|
{batchNormOutput, emptyMeanTensor, emptyInvStdTensor});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-10 16:11:05 +08:00
|
|
|
|
// Decompose `Aten_UnsafeViewOp` into `AtenViewOp`. _unsafe_view() differs from
|
|
|
|
|
// view() in that the returned tensor isn't treated as a view for the purposes
|
|
|
|
|
// of automatic differentiation. It's only safe to use if the `self` tensor is
|
|
|
|
|
// temporary. For example, the viewed tensor here (a + b) is discarded
|
|
|
|
|
// immediately after viewing:
|
|
|
|
|
//
|
|
|
|
|
// res = _unsafe_view(a + b, size);
|
|
|
|
|
//
|
|
|
|
|
// This is a hack because in-place operations on tensors treated like views
|
|
|
|
|
// can be much more expensive than the same operations on non-view tensors.
|
|
|
|
|
|
|
|
|
|
// Refer to
|
|
|
|
|
// https://github.com/pytorch/pytorch/blob/364055b2771ecf9b54f1d67a8bf44bb5496476d4/aten/src/ATen/native/TensorShape.cpp#L2072
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
  using OpRewritePattern::OpRewritePattern;

  // Replaces `op` one-for-one with `AtenViewOp`, keeping the result type and
  // forwarding the `self` and `size` operands.
  LogicalResult matchAndRewrite(Aten_UnsafeViewOp op,
                                PatternRewriter &rewriter) const override {
    Value source = op.self();
    Value targetSize = op.size();
    rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), source,
                                            targetSize);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-29 12:54:28 +08:00
|
|
|
|
// In PyTorch, _reshape_alias just uses an already computed stride.
|
|
|
|
|
// See
|
|
|
|
|
// https://github.com/pytorch/pytorch/blob/d8c31a819d4a65e732b5901e3b994e1869851f1a/aten/src/ATen/native/TensorShape.cpp#L1153
|
|
|
|
|
// Note that this is the same decomposition as in AOTAutograd
|
|
|
|
|
// https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/_src/aot_autograd.py#L104
|
|
|
|
|
namespace {
|
2022-05-13 20:06:24 +08:00
|
|
|
|
class DecomposeAten_ReshapeAliasOp
|
|
|
|
|
: public OpRewritePattern<Aten_ReshapeAliasOp> {
|
2022-03-29 12:54:28 +08:00
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
|
|
|
|
|
op.size());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-28 14:14:40 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose constant tensor like ops.
|
|
|
|
|
template <typename OpTy, typename NewOpTy>
|
|
|
|
|
class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
|
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-03-25 00:40:21 +08:00
|
|
|
|
Value dtype = op.dtype();
|
|
|
|
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
BaseTensorType tensorType =
|
|
|
|
|
op.self().getType().template cast<BaseTensorType>();
|
|
|
|
|
dtype =
|
|
|
|
|
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.size(), dtype,
|
|
|
|
|
op.layout(), op.device(),
|
2022-02-28 14:14:40 +08:00
|
|
|
|
op.pin_memory());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-03 21:41:14 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.full` op into `aten.empty` and `aten.fill` ops.
|
|
|
|
|
class DecomposeAtenFullOp : public OpRewritePattern<AtenFullOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Rewrites `op` in two steps:
  //   1. `AtenEmptyMemoryFormatOp` built from the operands of `op`, with a
  //      None memory format;
  //   2. `ValsemVariantAtenFillScalarOp` filling that tensor with
  //      `fill_value`.
  LogicalResult matchAndRewrite(AtenFullOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // `AtenFullOp` has no memory_format operand; pass None for it.
    Value noMemoryFormat = rewriter.create<Torch::ConstantNoneOp>(loc);
    Value uninitialized = rewriter.create<AtenEmptyMemoryFormatOp>(
        loc, op.getType(), op.size(), op.dtype(), op.layout(), op.device(),
        op.pin_memory(), /*memory_format=*/noMemoryFormat);

    rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
        op, op.getType(), uninitialized, op.fill_value());
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-03 22:25:22 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.full_like` op into `aten.empty_like` and `aten.fill` ops.
|
|
|
|
|
class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Rewrites `op` in two steps:
  //   1. `AtenEmptyLikeOp` built from the operands of `op`;
  //   2. `ValsemVariantAtenFillScalarOp` filling that tensor with
  //      `fill_value`.
  LogicalResult matchAndRewrite(AtenFullLikeOp op,
                                PatternRewriter &rewriter) const override {
    Type resultType = op.getType();

    Value uninitialized = rewriter.create<AtenEmptyLikeOp>(
        op.getLoc(), resultType, op.self(), op.dtype(), op.layout(),
        op.device(), op.pin_memory(), op.memory_format());

    rewriter.replaceOpWithNewOp<ValsemVariantAtenFillScalarOp>(
        op, resultType, uninitialized, op.fill_value());
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-10 23:18:08 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.index_put` op into `valsem.aten.index_put_impl` op.
|
|
|
|
|
class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Replaces `op` with `ValsemVariantAtenIndexPutImplOp`, forwarding all of
  // its operands and passing a constant `false` for the extra `unsafe`
  // operand.
  LogicalResult matchAndRewrite(AtenIndexPutOp op,
                                PatternRewriter &rewriter) const override {
    Value unsafeFlag =
        rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);

    rewriter.replaceOpWithNewOp<ValsemVariantAtenIndexPutImplOp>(
        op, op.getType(), op.self(), op.indices(), op.values(),
        op.accumulate(), /*unsafe=*/unsafeFlag);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-14 16:12:37 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
  using OpRewritePattern::OpRewritePattern;

  // Replaces `op` with `AtenBroadcastToOp` applied to `self`, using the
  // `AtenSizeOp` of the `other` operand as the target size list.
  LogicalResult matchAndRewrite(AtenExpandAsOp op,
                                PatternRewriter &rewriter) const override {
    // Result type of the size query: a list of torch ints.
    auto intListType =
        Torch::ListType::get(Torch::IntType::get(op.getContext()));
    Value otherSizes =
        rewriter.create<AtenSizeOp>(op.getLoc(), intListType, op.other());

    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.self(),
                                                   otherSizes);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-17 21:35:17 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten._to_copy` op into `valsem.aten.copy` op.
|
|
|
|
|
class DecomposeAten_ToCopyOp : public OpRewritePattern<Aten_ToCopyOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Rewrites `op` in two steps:
  //   1. `AtenEmptyLikeOp` built from the operands of `op`, giving the
  //      destination tensor;
  //   2. `ValsemVariantAtenCopyOp` copying `self` into that destination,
  //      forwarding `non_blocking`.
  LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                PatternRewriter &rewriter) const override {
    Type resultType = op.getType();

    Value destination = rewriter.create<AtenEmptyLikeOp>(
        op.getLoc(), resultType, op.self(), op.dtype(), op.layout(),
        op.device(), op.pin_memory(), op.memory_format());

    rewriter.replaceOpWithNewOp<ValsemVariantAtenCopyOp>(
        op, resultType, destination, op.self(), op.non_blocking());
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-25 00:40:21 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.new_empty` op into `aten.empty.memory_format` op.
|
|
|
|
|
class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
  using OpRewritePattern::OpRewritePattern;

  // Replaces `op` with `AtenEmptyMemoryFormatOp`, forwarding `size`,
  // `layout`, `device` and `pin_memory` and passing None as the memory
  // format. When the `dtype` operand is None, a dtype value derived from the
  // dtype of `self` is passed instead.
  LogicalResult matchAndRewrite(AtenNewEmptyOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value noMemoryFormat = rewriter.create<ConstantNoneOp>(loc);

    Value resolvedDtype = op.dtype();
    if (resolvedDtype.getType().isa<Torch::NoneType>()) {
      // No explicit dtype: fall back to the dtype of the `self` tensor.
      Type selfDtype = op.self().getType().cast<BaseTensorType>().getDtype();
      resolvedDtype = getDtypeIntValueForType(rewriter, loc, selfDtype);
    }

    rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
        op, op.getType(), op.size(), resolvedDtype, op.layout(), op.device(),
        op.pin_memory(), /*memory_format=*/noMemoryFormat);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-24 15:12:59 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.index_put.hacked_twin` op into `valsem.aten.index_put_impl`
|
|
|
|
|
// op.
|
|
|
|
|
class DecomposeAtenIndexPutHackedTwinOp
|
|
|
|
|
: public OpRewritePattern<AtenIndexPutHackedTwinOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenIndexPutHackedTwinOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
|
rewriter.replaceOpWithNewOp<ValsemVariantAtenIndexPutImplOp>(
|
|
|
|
|
op, op.getType(), op.self(), op.indices(), op.values(), op.accumulate(),
|
|
|
|
|
/*unsafe=*/cstFalse);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-04-26 20:18:09 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.pad` op into `aten.constant_pad_nd` op.
|
|
|
|
|
class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
  using OpRewritePattern::OpRewritePattern;

  // Replaces `op` with `AtenConstantPadNdOp`, forwarding `self` and `pad`.
  // A `value` operand of None type is substituted by a float constant 0; a
  // `value` of optional type is rejected with a match failure.
  LogicalResult matchAndRewrite(AtenPadOp op,
                                PatternRewriter &rewriter) const override {
    Value padValue = op.value();
    Type padValueType = padValue.getType();

    if (padValueType.isa<Torch::OptionalType>())
      return rewriter.notifyMatchFailure(op, "optional type not supported");

    // No pad value given: pad with 0.0.
    if (padValueType.isa<Torch::NoneType>())
      padValue = rewriter.create<Torch::ConstantFloatOp>(
          op.getLoc(), rewriter.getF64FloatAttr(0));

    rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
        op, op.getType(), op.self(), op.pad(), padValue);
    return success();
  }
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-04-27 19:07:40 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.to.dtype_layout` op into `aten.to.dtype` op.
|
|
|
|
|
class DecomposeAtenToDtypeLayoutOp
|
|
|
|
|
: public OpRewritePattern<AtenToDtypeLayoutOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
// TODO: Add support for pin_memory arg equal to `True`.
|
|
|
|
|
if (!op.pin_memory().getType().isa<Torch::NoneType>()) {
|
|
|
|
|
bool pinMemory;
|
|
|
|
|
if (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: pin_memory must be a constant");
|
|
|
|
|
else if (pinMemory)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: pin_memory is expected to be false");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: Add support for non-None device arg.
|
|
|
|
|
if (!op.device().getType().isa<Torch::NoneType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: device arg must be None");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: Add support for non-strided layout.
|
|
|
|
|
// torch.layout is by default strided i.e. 0.
|
|
|
|
|
if (!op.layout().getType().isa<Torch::NoneType>()) {
|
|
|
|
|
int64_t tensorLayout;
|
|
|
|
|
if (!matchPattern(op.layout(), m_TorchConstantInt(&tensorLayout)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: layout must be a constant");
|
|
|
|
|
else if (tensorLayout != torch_upstream::Layout::Strided)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: layout is expected to be strided");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.self(),
|
|
|
|
|
op.dtype(), op.non_blocking(),
|
|
|
|
|
op.copy(), op.memory_format());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-05-13 20:06:24 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
|
|
|
|
|
//
|
|
|
|
|
// For AdaptiveAvgPool2d op, when the input size is an integer multiple of
|
|
|
|
|
// output size the kernel_size, stride and padding is calculated as follows:
|
|
|
|
|
// strideH = inH // outH
|
|
|
|
|
// strideW = inW // outW
|
|
|
|
|
// kernelH = inH - [(outH - 1) * strideH]
|
|
|
|
|
// kernelW = inW - [(outW - 1) * strideW]
|
|
|
|
|
// paddingH = 0, paddingW = 0
|
|
|
|
|
//
|
|
|
|
|
// For the special case, when the output size is one for all dimensions,
|
|
|
|
|
// the kernel size is same as the input size.
|
|
|
|
|
class DecomposeAtenAdaptiveAvgPool2dOp
|
|
|
|
|
: public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
|
|
|
|
|
Value input = op.self();
|
|
|
|
|
int64_t rank = getTensorRank(input);
|
|
|
|
|
SmallVector<Value, 2> inputHW;
|
|
|
|
|
Value dimH = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(rank - 2));
|
|
|
|
|
inputHW.push_back(
|
|
|
|
|
/*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
|
|
|
|
|
Value dimW = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(rank - 1));
|
|
|
|
|
inputHW.push_back(
|
|
|
|
|
/*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
|
|
|
|
|
|
|
|
|
|
Value outputShape = op.output_size();
|
|
|
|
|
SmallVector<Value> outputShapeSizesTorchInt;
|
|
|
|
|
getListConstructElements(outputShape, outputShapeSizesTorchInt);
|
|
|
|
|
|
|
|
|
|
// TODO: Add support for cases other than:
|
|
|
|
|
// 1.) inH == outH and inW == outW.
|
|
|
|
|
// 2.) outH == outW == 1
|
|
|
|
|
bool unitOutputSize = true;
|
|
|
|
|
for (Value outShape : outputShapeSizesTorchInt) {
|
|
|
|
|
int64_t outShapeInt;
|
|
|
|
|
if (!matchPattern(outShape, m_TorchConstantInt(&outShapeInt))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "output size is expected to be a constant");
|
|
|
|
|
}
|
|
|
|
|
if (outShapeInt != 1) {
|
|
|
|
|
unitOutputSize = false;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
SmallVector<Value, 2> kernelSize;
|
|
|
|
|
|
|
|
|
|
for (unsigned i = 0; i < inputHW.size(); i++) {
|
|
|
|
|
if (unitOutputSize) {
|
|
|
|
|
BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
|
|
|
|
kernelSize.push_back(inputShape[rank - 2 + i] == kUnknownSize
|
|
|
|
|
? inputHW[i]
|
|
|
|
|
: rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(
|
|
|
|
|
inputShape[rank - 2 + i])));
|
|
|
|
|
} else {
|
|
|
|
|
Value cond = rewriter.create<AtenEqIntOp>(loc, inputHW[i],
|
|
|
|
|
outputShapeSizesTorchInt[i]);
|
|
|
|
|
rewriter.create<RuntimeAssertOp>(
|
|
|
|
|
loc, cond,
|
|
|
|
|
"unimplemented: only support cases where input and output size are "
|
|
|
|
|
"equal for non-unit output size");
|
|
|
|
|
|
|
|
|
|
Value outMinusOne = rewriter.create<AtenSubIntOp>(
|
|
|
|
|
loc, outputShapeSizesTorchInt[i], constantOne);
|
|
|
|
|
kernelSize.push_back(
|
|
|
|
|
rewriter.create<AtenSubIntOp>(loc, inputHW[i], outMinusOne));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
|
|
|
|
|
// Currently we only support cases where input size is equal to the output
|
|
|
|
|
// size or unit output size. For the former case, stride is always equal to
|
|
|
|
|
// one and for the latter the stride value doesn't matter, since the kernel
|
|
|
|
|
// size is same as the input size. Therfore, keeping the stride as one for
|
|
|
|
|
// the latter case as well for the ease of implementation.
|
|
|
|
|
Value strideList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
|
|
|
|
ValueRange{constantOne, constantOne});
|
|
|
|
|
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
|
|
|
|
ValueRange{constantZero, constantZero});
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
|
|
|
|
|
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
|
|
|
|
|
/*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue,
|
|
|
|
|
/*divisor_override=*/constantNone);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-06-03 15:41:13 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.clamp_min` op into `aten.clamp` op.
|
|
|
|
|
class DecomposeAtenClampMinOp : public OpRewritePattern<AtenClampMinOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenClampMinOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenClampOp>(op, op.getType(), op.self(),
|
|
|
|
|
op.min(), /*max=*/constantNone);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.clamp_max` op into `aten.clamp` op.
|
|
|
|
|
class DecomposeAtenClampMaxOp : public OpRewritePattern<AtenClampMaxOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenClampMaxOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenClampOp>(op, op.getType(), op.self(),
|
|
|
|
|
/*min=*/constantNone, op.max());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-05-30 16:08:54 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
|
|
|
|
|
// `aten.add.Tensor` op.
|
|
|
|
|
class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenBaddbmmOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value bmm =
|
|
|
|
|
rewriter.create<AtenBmmOp>(loc, op.getType(), op.batch1(), op.batch2());
|
|
|
|
|
Value alphaTimesBmm =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.alpha());
|
|
|
|
|
Value input = op.self();
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
BaseTensorType resultType =
|
|
|
|
|
op->getResult(0).getType().cast<BaseTensorType>();
|
|
|
|
|
if (inputType.hasDtype() && resultType.hasDtype() &&
|
|
|
|
|
inputType.getDtype() != resultType.getDtype()) {
|
|
|
|
|
input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype());
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(
|
|
|
|
|
op, op.getType(), alphaTimesBmm, op.self(), op.beta());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-06-09 14:09:28 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.floor_divide` op into `aten.div.Tensor_mode` op.
|
|
|
|
|
class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenFloorDivideOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value cstStrFloor =
|
|
|
|
|
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
|
|
|
|
|
op, op.getType(), op.self(), op.other(),
|
|
|
|
|
/*rounding_mode=*/cstStrFloor);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-06-03 20:38:59 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.numpy_T` op into `aten.permute` op.
|
|
|
|
|
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNumpyTOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value self = op.self();
|
|
|
|
|
int64_t inputRank = getTensorRank(self);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> dimListElements;
|
|
|
|
|
for (int64_t i = inputRank - 1; i >= 0; i--)
|
|
|
|
|
dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
|
|
|
|
|
dimListElements);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenPermuteOp>(op, op.getType(), self, dimList);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-05-10 21:15:59 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose the `aten.select_scatter` operation into `aten.slice_scatter` op.
|
|
|
|
|
class DecomposeAtenSelectScatterOp
|
|
|
|
|
: public OpRewritePattern<AtenSelectScatterOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSelectScatterOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value start = op.index();
|
|
|
|
|
Value dim = op.dim();
|
|
|
|
|
Value self = op.self();
|
|
|
|
|
Value src = op.src();
|
|
|
|
|
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value startPlusOne =
|
|
|
|
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
|
|
|
|
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
|
|
|
|
|
SmallVector<int64_t> sizes;
|
|
|
|
|
if (!srcTensorType.hasSizes())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "src tensor must have size");
|
|
|
|
|
|
|
|
|
|
ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
|
|
|
|
|
// `src` has a reduced rank. Hence add 1.
|
|
|
|
|
int64_t srcRank = srcShape.size() + 1;
|
|
|
|
|
int64_t dimInt = 0;
|
|
|
|
|
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
|
|
|
|
|
dimInt = toPositiveDim(dimInt, srcRank);
|
|
|
|
|
if (!isValidDim(dimInt, srcRank))
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
|
|
|
|
|
|
|
|
|
sizes.append(srcShape.begin(), srcShape.end());
|
|
|
|
|
sizes.insert(sizes.begin() + dimInt, 1);
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
sizes.resize(srcShape.size() + 1, kUnknownSize);
|
|
|
|
|
}
|
|
|
|
|
Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
|
|
|
|
|
srcTensorType.getDtype());
|
|
|
|
|
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
|
|
|
|
|
op, op.self().getType(), self, src, dim, start, startPlusOne,
|
|
|
|
|
/*step=*/one);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-10-16 06:23:59 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeComplexOpsPass
|
|
|
|
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
|
|
|
|
void runOnOperation() override {
|
|
|
|
|
MLIRContext *context = &getContext();
|
|
|
|
|
RewritePatternSet patterns(context);
|
|
|
|
|
ConversionTarget target(*context);
|
|
|
|
|
target.addLegalDialect<Torch::TorchDialect>();
|
|
|
|
|
|
|
|
|
|
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenSoftmaxIntOp>();
|
2021-11-25 13:49:02 +08:00
|
|
|
|
patterns.add<DecomposeAten_SoftmaxOp>(context);
|
|
|
|
|
target.addIllegalOp<Aten_SoftmaxOp>();
|
2022-02-10 15:05:23 +08:00
|
|
|
|
patterns.add<DecomposeAten_LogSoftmaxOp>(context);
|
|
|
|
|
target.addIllegalOp<Aten_LogSoftmaxOp>();
|
2021-11-03 01:06:04 +08:00
|
|
|
|
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
2021-12-14 03:01:10 +08:00
|
|
|
|
patterns.add<DecomposeAtenEmptyLikeOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenEmptyLikeOp>();
|
2021-12-21 19:51:19 +08:00
|
|
|
|
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(
|
|
|
|
|
context);
|
|
|
|
|
target.addIllegalOp<AtenOnesLikeOp>();
|
|
|
|
|
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
|
|
|
|
|
context);
|
|
|
|
|
target.addIllegalOp<AtenZerosLikeOp>();
|
2022-07-01 13:02:31 +08:00
|
|
|
|
patterns.add<DecomposeAtenRepeatOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenRepeatOp>();
|
2021-11-03 00:48:29 +08:00
|
|
|
|
patterns.add<DecomposeAtenExpandOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenExpandOp>();
|
2022-03-12 01:21:36 +08:00
|
|
|
|
patterns.add<DecomposeAtenWhereScalarOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenWhereScalarOp>();
|
|
|
|
|
patterns.add<DecomposeAtenWhereScalarOtherOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
|
|
|
|
patterns.add<DecomposeAtenWhereScalarSelfOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
2021-11-08 23:56:40 +08:00
|
|
|
|
patterns.add<DecomposeAtenSizeOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenSizeOp>();
|
2021-12-17 23:54:03 +08:00
|
|
|
|
patterns.add<DecomposeAtenReshapeOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenReshapeOp>();
|
2021-11-08 23:56:40 +08:00
|
|
|
|
patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
|
|
|
|
|
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
2021-11-09 20:25:04 +08:00
|
|
|
|
patterns.add<DecomposeAtenTanhBackwardOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenTanhBackwardOp>();
|
2021-11-11 17:02:13 +08:00
|
|
|
|
patterns.add<DecomposeAtenAddmmOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenAddmmOp>();
|
2021-11-19 23:59:29 +08:00
|
|
|
|
patterns.add<DecomposeAtenMeanOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenMeanOp>();
|
2022-03-11 01:25:21 +08:00
|
|
|
|
patterns.add<DecomposeAtenMeanDimOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenMeanDimOp>();
|
2021-12-03 12:09:21 +08:00
|
|
|
|
patterns.add<DecomposeAtenSelectIntOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenSelectIntOp>();
|
2021-10-21 13:15:10 +08:00
|
|
|
|
patterns.add<DecomposeAtenMatmulOp>(context);
|
2021-12-17 12:08:07 +08:00
|
|
|
|
target.addIllegalOp<AtenTOp>();
|
|
|
|
|
patterns.add<DecomposeAtenTOp>(context);
|
2021-11-19 02:02:20 +08:00
|
|
|
|
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
|
|
|
|
|
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
2021-10-21 13:15:10 +08:00
|
|
|
|
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
|
Add type promotion code to refine types.
The types have different levels of categories: where
complex > floating > integral > boolean (> means left hand
side has higher category).
The operands have different levels of priorities where:
dimensioned tensor > 0-dim tensor > scalar == wrapped 0-dim tensor.
This is represented by the `ResultTypeState.dimResult`,
`ResultTypeState.zeroResult` and `ResultTypeState..wrappedResult` in
the source code.
For operands of the same priorities, the result type should be the
highest categories with sufficient width to hold all operands.
By default, only the highest priority operands participate in the type
promotion logic. Lower priority operands participate if they are in
a higher category than any higher priority operands.
For example, <[],f32> (lower priority) and <[1], si64> tensor would
result in <[?],f32> tensor because floating > integeral. Another example
<[],f64> (lower priority) and <[1], f32> tensor would result in
<[?], f32> tensor because f32 and f64 are the same category.
The ScalarType enum definition, type promotion table, ResultTypeState
struct definition and some helpers are copied from
aten/src/ATen/native/TypeProperties.*
Other references:
- https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
- https://github.com/pytorch/pytorch/issues/9515
Other minor changes:
1. Fix `visitExpandLikeOp` to consider cases where the given sizes list
size is larger than the input rank.
2. Add back the somehow deleted `torch.aten.softmax.int` tests in
decompose-complex-ops.mlir.
2021-10-21 03:31:28 +08:00
|
|
|
|
int lhsRank = getTensorRank(op.self());
|
|
|
|
|
int rhsRank = getTensorRank(op.other());
|
2021-10-16 06:23:59 +08:00
|
|
|
|
|
2021-10-21 13:15:10 +08:00
|
|
|
|
// Make aten.matmul legal if the following condition is satisfied.
|
|
|
|
|
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
|
|
|
|
|
});
|
2022-02-15 21:14:32 +08:00
|
|
|
|
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(
|
|
|
|
|
context);
|
2021-12-04 00:37:37 +08:00
|
|
|
|
target.addIllegalOp<AtenAddcmulOp>();
|
2022-02-15 21:14:32 +08:00
|
|
|
|
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(
|
|
|
|
|
context);
|
2021-12-04 00:37:37 +08:00
|
|
|
|
target.addIllegalOp<AtenAddcdivOp>();
|
2021-12-10 21:36:19 +08:00
|
|
|
|
target.addIllegalOp<AtenLayerNormOp>();
|
|
|
|
|
patterns.add<DecomposeAtenLayerNormOp>(context);
|
2022-02-08 00:08:10 +08:00
|
|
|
|
target.addIllegalOp<AtenNativeBatchNormOp>();
|
|
|
|
|
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
|
2022-04-08 12:47:57 +08:00
|
|
|
|
target.addIllegalOp<AtenConvolutionOverrideableOp>();
|
|
|
|
|
patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
|
2022-07-08 14:44:03 +08:00
|
|
|
|
target.addIllegalOp<Aten_ConvolutionOp>();
|
|
|
|
|
patterns.add<DecomposeAten_ConvolutionOp>(context);
|
2022-04-08 12:47:57 +08:00
|
|
|
|
target.addIllegalOp<AtenConv2dOp>();
|
|
|
|
|
patterns.add<DecomposeAtenConv2dOp>(context);
|
2021-12-23 21:22:45 +08:00
|
|
|
|
patterns.add<DecomposeAtenArangeOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenArangeOp>();
|
|
|
|
|
patterns.add<DecomposeAtenArangeStartOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenArangeStartOp>();
|
2022-01-25 16:53:55 +08:00
|
|
|
|
patterns.add<DecomposeAtenArgMaxOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenArgmaxOp>();
|
2022-01-30 01:10:50 +08:00
|
|
|
|
patterns.add<DecomposeAtenSquareOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenSquareOp>();
|
|
|
|
|
patterns.add<DecomposeAtenVarOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenVarOp>();
|
|
|
|
|
patterns.add<DecomposeAtenStdOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenStdOp>();
|
2022-02-10 16:11:05 +08:00
|
|
|
|
patterns.add<DecomposeAten_UnsafeViewOp>(context);
|
|
|
|
|
target.addIllegalOp<Aten_UnsafeViewOp>();
|
2022-03-29 12:54:28 +08:00
|
|
|
|
patterns.add<DecomposeAten_ReshapeAliasOp>(context);
|
|
|
|
|
target.addIllegalOp<Aten_ReshapeAliasOp>();
|
2022-02-04 19:43:25 +08:00
|
|
|
|
patterns.add<DecomposeAtenBernoulliOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenBernoulliOp>();
|
2022-03-16 07:57:33 +08:00
|
|
|
|
patterns.add<DecomposeValsemVariantAtenBernoulliFloatOp>(context);
|
|
|
|
|
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
|
|
|
|
|
patterns.add<DecomposeValsemVariantAtenBernoulliTensorOp>(context);
|
|
|
|
|
target.addIllegalOp<ValsemVariantAtenBernoulliTensorOp>();
|
2022-07-12 01:56:12 +08:00
|
|
|
|
patterns.add<DecomposeAtenZeroOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenZeroOp>();
|
2022-02-26 00:35:04 +08:00
|
|
|
|
patterns.add<DecomposeAtenRandLikeOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenRandLikeOp>();
|
2022-02-14 22:46:44 +08:00
|
|
|
|
patterns.add<DecomposeAtenHardsigmoidOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenHardsigmoidOp>();
|
2022-02-15 21:14:32 +08:00
|
|
|
|
patterns.add<DecomposeAtenHardswishOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenHardswishOp>();
|
2022-03-02 01:30:58 +08:00
|
|
|
|
patterns.add<DecomposeAtenSiluOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenSiluOp>();
|
2022-02-28 14:14:40 +08:00
|
|
|
|
patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
|
|
|
|
|
context);
|
|
|
|
|
target.addIllegalOp<AtenNewZerosOp>();
|
|
|
|
|
patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(
|
|
|
|
|
context);
|
|
|
|
|
target.addIllegalOp<AtenNewOnesOp>();
|
2022-02-09 04:57:23 +08:00
|
|
|
|
patterns.add<DecomposeAtenHardtanhOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenHardtanhOp>();
|
2022-03-03 21:41:14 +08:00
|
|
|
|
patterns.add<DecomposeAtenFullOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenFullOp>();
|
2022-03-03 22:25:22 +08:00
|
|
|
|
patterns.add<DecomposeAtenFullLikeOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenFullLikeOp>();
|
2022-03-10 23:18:08 +08:00
|
|
|
|
patterns.add<DecomposeAtenIndexPutOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenIndexPutOp>();
|
2022-03-14 16:12:37 +08:00
|
|
|
|
patterns.add<DecomposeAtenExpandAsOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenExpandAsOp>();
|
2022-03-17 21:35:17 +08:00
|
|
|
|
patterns.add<DecomposeAten_ToCopyOp>(context);
|
|
|
|
|
target.addIllegalOp<Aten_ToCopyOp>();
|
2022-02-17 00:34:03 +08:00
|
|
|
|
patterns.add<DecomposeAtenDropoutOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenDropoutOp>();
|
2022-03-25 00:40:21 +08:00
|
|
|
|
target.addIllegalOp<AtenNewEmptyOp>();
|
|
|
|
|
patterns.add<DecomposeAtenNewEmptyOp>(context);
|
2022-03-24 15:12:59 +08:00
|
|
|
|
patterns.add<DecomposeAtenIndexPutHackedTwinOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
|
2022-04-26 20:18:09 +08:00
|
|
|
|
target.addIllegalOp<AtenPadOp>();
|
|
|
|
|
patterns.add<DecomposeAtenPadOp>(context);
|
2022-04-27 19:07:40 +08:00
|
|
|
|
patterns.add<DecomposeAtenToDtypeLayoutOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
2022-05-13 20:06:24 +08:00
|
|
|
|
patterns.add<DecomposeAtenAdaptiveAvgPool2dOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
|
2022-06-03 15:41:13 +08:00
|
|
|
|
patterns.add<DecomposeAtenClampMinOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenClampMinOp>();
|
|
|
|
|
patterns.add<DecomposeAtenClampMaxOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenClampMaxOp>();
|
2022-05-30 16:08:54 +08:00
|
|
|
|
patterns.add<DecomposeAtenBaddbmmOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenBaddbmmOp>();
|
2022-06-09 14:09:28 +08:00
|
|
|
|
patterns.add<DecomposeAtenFloorDivideOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenFloorDivideOp>();
|
2022-06-03 20:38:59 +08:00
|
|
|
|
patterns.add<DecomposeAtenNumpyTOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenNumpyTOp>();
|
2022-05-10 21:15:59 +08:00
|
|
|
|
patterns.add<DecomposeAtenSelectScatterOp>(context);
|
|
|
|
|
target.addIllegalOp<AtenSelectScatterOp>();
|
2021-12-10 21:36:19 +08:00
|
|
|
|
|
2021-10-16 06:23:59 +08:00
|
|
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
|
|
|
std::move(patterns)))) {
|
|
|
|
|
return signalPassFailure();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
2022-04-27 03:27:51 +08:00
|
|
|
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
2021-10-16 06:23:59 +08:00
|
|
|
|
mlir::torch::Torch::createDecomposeComplexOpsPass() {
|
|
|
|
|
return std::make_unique<DecomposeComplexOpsPass>();
|
|
|
|
|
}
|