//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_upstream; // For ScalarType and type helpers.
// Helper function to get the rank of a tensor of `BaseTensorType`.
// -1 is returned if the tensor rank can't be determined.
static int getTensorRank(Value tensor) {
int tensorRank = -1;
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
if (tensorType.hasSizes()) {
ArrayRef<int64_t> tensorShape = tensorType.getSizes();
tensorRank = tensorShape.size();
}
return tensorRank;
}
// Helper function to compute the return type of the reduction function.
// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
// the input tensor.
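// For example, reducing a tensor of shape [3, 4, 5] along `dim` = 1 yields
// shape [3, 1, 5] when `keepDim` is true and [3, 5] when `keepDim` is false.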
static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
Value input, Value dim, bool keepDim) {
BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
SmallVector<int64_t> sizes;
int64_t dimInt;
if (tensorType.hasSizes()) {
ArrayRef<int64_t> inputShape = tensorType.getSizes();
int64_t inputRank = inputShape.size();
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, inputRank);
if (!isValidDim(dimInt, inputRank)) {
(void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
return nullptr;
}
sizes.append(inputShape.begin(), inputShape.end());
// The dimension to be reduced is set to 1 when `keepDim` is true else it
// is removed.
if (keepDim)
sizes[dimInt] = 1;
else
        // Erase the reduced dimension, which lives at index `dimInt`.
        sizes.erase(sizes.begin() + dimInt);
} else {
unsigned reducedRank = keepDim ? inputRank : inputRank - 1;
sizes.resize(reducedRank, kUnknownSize);
}
}
Type resultType = tensorType.getWithSizesAndDtype(
sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(sizes),
tensorType.getDtype());
return resultType;
}
// Reduction function to calculate sum along given `dim`.
static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim,
bool keepDim) {
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(dim.getType()), dim);
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
Value dtype = rewriter.create<ConstantNoneOp>(loc);
Type resultType = computeReductionType(rewriter, op, input, dim, keepDim);
if (!resultType)
return nullptr;
return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
keepDimCst, dtype);
}
// Reduction function to calculate max along given `dim`.
static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim,
bool keepDim) {
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
  // Bail out before casting: `computeReductionType` returns a null Type on
  // failure, and casting a null Type would assert.
  Type reducedType = computeReductionType(rewriter, op, input, dim, keepDim);
  if (!reducedType)
    return nullptr;
  BaseTensorType valueType = reducedType.cast<BaseTensorType>();
BaseTensorType indexType =
valueType
.getWithSizesAndDtype(
!valueType.hasSizes() ? Optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(valueType.getSizes()),
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
.cast<BaseTensorType>();
return rewriter
.create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
.values();
}
// Helper for creating an `aten.sub.Tensor` op.
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
Type tensorType, Value lhs, Value rhs) {
Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
Value sub =
rewriter.create<AtenSubTensorOp>(loc, tensorType, lhs, rhs, alpha);
return sub;
}
// Share code between `softmax_backward` and `log_softmax_backward` ops.
// Returns x - y * sum(z, dim).
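// The two callers below instantiate it as follows:
//   softmax backward:     x = z = gradOutput * output, y = output
//   log_softmax backward: x = z = gradOutput,          y = exp(output)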
static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
Location loc, Operation *op,
Type tensorType, Value x,
Value y, Value z, Value dim) {
  Value sum =
      createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
if (!sum)
return nullptr;
auto broadcastSizeType =
Torch::ListType::get(Torch::IntType::get(op->getContext()));
Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
Value sumBroadcast =
rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
Value temp =
rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
return sub;
}
namespace {
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSizeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
MLIRContext *context = op.getContext();
int64_t rank = getTensorRank(self);
if (rank < 0)
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
SmallVector<Value> sizes;
for (int i = 0; i < rank; i++) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
}
Value sizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), sizes);
rewriter.replaceOp(op, sizeList);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSelectIntOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value start = op.index();
Value dim = op.dim();
Value self = op.self();
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value startPlusOne =
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
Value slice = rewriter.create<AtenSliceTensorOp>(
loc, computeReductionType(rewriter, op, self, dim, /*keepDim=*/true),
op.self(), dim, start, startPlusOne, /*step=*/one);
    // `aten.slice.Tensor` doesn't squeeze the dim even when its size becomes 1
    // after slicing, while `aten.select.int` does.
rewriter.replaceOpWithNewOp<AtenSqueezeDimOp>(op, op.getResult().getType(),
slice, op.dim());
return success();
}
};
} // namespace
namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenReshapeOp op,
PatternRewriter &rewriter) const override {
Value input = op.self();
// TODO: Handle non value tensor type operands.
if (!input.getType().isa<ValueTensorType>()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only value tensor type operands are supported");
}
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), input,
op.shape());
return success();
}
};
} // namespace
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)).
// To avoid overflow we use the following decomposition rule:
// x_max = max(input, dim, keepdim = True)
// unnorm = aten.exp(input - x_max)
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
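// Shifting by x_max doesn't change the result, since
// exp(x - c)/sum(exp(x - c)) = exp(x)/sum(exp(x)) for any constant c. For
// example, for input = [1, 2, 3]: x_max = 3, unnorm = exp([-2, -1, 0]) ~=
// [0.135, 0.368, 1.0], and softmax ~= [0.090, 0.245, 0.665].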
template <typename OpTy>
static Value getSoftmaxResult(OpTy op, Type resultType,
PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value dim = op.dim();
Value self = op.self();
Value xMax =
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
if (!xMax)
return nullptr;
Value unNormalized = createTensorSub(rewriter, loc, resultType, self, xMax);
Value unNormalizedExp =
rewriter.create<AtenExpOp>(loc, resultType, unNormalized);
Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim,
/*keepDim=*/true);
if (!sum)
return nullptr;
return rewriter.create<AtenDivTensorOp>(loc, resultType, unNormalizedExp,
sum);
}
// Decompose softmax into: exp(x) / sum(exp(x))
namespace {
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
PatternRewriter &rewriter) const override {
Value self = op.self();
if (!op.dtype().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Unimplemented non-None dtype for softmax");
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type");
Value result = getSoftmaxResult(op, tensorType, rewriter);
if (!result)
return failure();
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
result);
return success();
}
};
} // namespace
namespace {
class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
PatternRewriter &rewriter) const override {
Value self = op.self();
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type");
bool halfToFloat;
if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
return rewriter.notifyMatchFailure(
op, "Expected a boolean value for half_to_float");
    // Currently, `halfToFloat` is not supported because there is no E2E test
    // coverage for it on CPU.
if (halfToFloat)
return rewriter.notifyMatchFailure(
op, "halfToFloat is currently not supported.");
Value result = getSoftmaxResult(op, tensorType, rewriter);
if (!result)
return op.emitError("failed to get softmax result");
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
result);
return success();
}
};
} // namespace
// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
// newGrad = gradOutput * output
// result = newGrad - output * sum(newGrad, dim)
//
// Refer to
// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
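// Derivation sketch: with y = softmax(x) and incoming gradient g, the
// gradient w.r.t. x is y * (g - sum(g * y, dim)), which distributes to
// (g * y) - y * sum(g * y, dim), i.e. newGrad - output * sum(newGrad, dim).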
namespace {
class DecomposeAten_SoftmaxBackwardDataOp
: public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value gradOutput = op.grad_output();
Value output = op.output();
Value dim = op.dim();
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type");
Value newGrad =
rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
Value result = createSoftmaxBackwardCommonKernel(
rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
if (!result)
return rewriter.notifyMatchFailure(
op,
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
// AtenTanhBackwardOp(gradOutput, output) =>
// result = gradOutput * (1 - output^2)
// To avoid materializing a broadcasted `1` tensor, the above formula is
// expanded by distributing gradOutput, i.e.,
// result = gradOutput - (gradOutput * output^2)
namespace {
class DecomposeAtenTanhBackwardOp
: public OpRewritePattern<AtenTanhBackwardOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTanhBackwardOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value gradOutput = op.grad_output();
// `output` is the value flowing out from tanh. Hence, tanh(x) = output.
// Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
Value output = op.output();
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type");
Value tanhSquare =
rewriter.create<AtenMulTensorOp>(loc, tensorType, output, output);
Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
loc, tensorType, tanhSquare, gradOutput);
Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
gradMulTanhSquare);
rewriter.replaceOp(op, newGrad);
return success();
}
};
} // namespace
// Aten_LogSoftmaxBackwardDataOp(gradOutput, output, dim) =>
// result = gradOutput - (exp(output) * sum(gradOutput, dim))
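// Derivation sketch: with y = log_softmax(x) and incoming gradient g, the
// gradient w.r.t. x is g - softmax(x) * sum(g, dim); since softmax(x) =
// exp(y), this is g - exp(y) * sum(g, dim).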
namespace {
class DecomposeAten_LogSoftmaxBackwardDataOp
: public OpRewritePattern<Aten_LogSoftmaxBackwardDataOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value gradOutput = op.grad_output();
Value output = op.output();
Value dim = op.dim();
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type");
Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
Value result = createSoftmaxBackwardCommonKernel(
rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
if (!result)
return rewriter.notifyMatchFailure(
op,
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
namespace {
class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenArgmaxOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
Value dim = op.dim();
Value keepDim = op.keepdim();
Value result = op.result();
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
if (!indicesTensorType.hasSizes())
return failure();
BaseTensorType valueTensorType =
inputType
.getWithSizesAndDtype(indicesTensorType.getSizes(),
inputType.getDtype())
.cast<BaseTensorType>();
    // If `dim` is of `NoneType`, reduce along all the dimensions.
    // `AtenMaxDimOp` doesn't support a `NoneType` dim, so the input tensor is
    // first flattened to a 1-d tensor and then the reduction happens on the
    // 0th dimension.
if (dim.getType().isa<Torch::NoneType>()) {
BaseTensorType flattenType =
inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
.cast<BaseTensorType>();
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(getTensorRank(input) - 1));
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
dim, end);
}
Value maxResult =
rewriter
.create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
input, dim, keepDim)
.indices();
rewriter.replaceOp(op, maxResult);
return success();
}
};
} // namespace
// To avoid overflow we use the following decomposition rule:
// x_max = aten.max(x, dim, keepdim=True)[0]
// shifted = x - x_max
// shifted_logsumexp = aten.log(aten.sum(aten.exp(shifted), dim, keepdim=True))
// log_softmax = shifted - shifted_logsumexp
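// For example, for input = [1, 2, 3]: x_max = 3, shifted = [-2, -1, 0],
// shifted_logsumexp = log(exp(-2) + exp(-1) + exp(0)) ~= 0.408, and
// log_softmax ~= [-2.408, -1.408, -0.408].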
template <typename OpTy>
static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value dim = op.dim();
Value self = op.self();
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
Value xMax =
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
if (!xMax)
return nullptr;
Value shifted = createTensorSub(rewriter, loc, tensorType, self, xMax);
Value shiftedExp = rewriter.create<AtenExpOp>(loc, tensorType, shifted);
Value shiftedSumExp =
createSumAlongDimension(rewriter, loc, op, shiftedExp, dim,
/*keepDim=*/true);
if (!shiftedSumExp)
return nullptr;
Value shiftedLogSumExp =
rewriter.create<AtenLogOp>(loc, shiftedSumExp.getType(), shiftedSumExp);
Value result =
createTensorSub(rewriter, loc, op.getType(), shifted, shiftedLogSumExp);
return result;
}
namespace {
class DecomposeAtenLogSoftmaxIntOp
: public OpRewritePattern<AtenLogSoftmaxIntOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
PatternRewriter &rewriter) const override {
Value self = op.self();
if (!op.dtype().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Unimplemented non-None dtype for log_softmax");
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type");
Value logSoftmax = getLogSoftmaxResult(op, rewriter);
if (!logSoftmax)
return rewriter.notifyMatchFailure(
op, "getLogSoftmaxResult function returned nullptr");
rewriter.replaceOp(op, logSoftmax);
return success();
}
};
} // namespace
namespace {
class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_LogSoftmaxOp op,
PatternRewriter &rewriter) const override {
bool halfToFloat;
if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
return rewriter.notifyMatchFailure(
op, "Expected a boolean value for half_to_float");
    // Currently, `halfToFloat` is not supported because there is no E2E test
    // coverage for it on CPU.
if (halfToFloat)
return rewriter.notifyMatchFailure(
op, "halfToFloat is currently not supported.");
Value _logSoftmax = getLogSoftmaxResult(op, rewriter);
if (!_logSoftmax)
return rewriter.notifyMatchFailure(
op, "getLogSoftmaxResult function returned nullptr");
rewriter.replaceOp(op, _logSoftmax);
return success();
}
};
} // namespace
// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks.
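// Rank 2 x rank 2, e.g. [M, K] @ [K, N] -> [M, N], maps to `aten.mm`; rank 3
// x rank 3, e.g. [B, M, K] @ [B, K, N] -> [B, M, N], maps to `aten.bmm`. All
// other rank combinations are left untouched (see the dynamic legality rule
// registered for `AtenMatmulOp` in the pass below).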
namespace {
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenMatmulOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.self();
Value rhs = op.other();
int lhsRank = getTensorRank(lhs);
int rhsRank = getTensorRank(rhs);
    // If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
    if (lhsRank == 2 && rhsRank == 2) {
      rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
      return success();
    }
    // If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
    if (lhsRank == 3 && rhsRank == 3) {
      rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
      return success();
    }
    // For any other rank combination there is nothing to rewrite; report
    // failure rather than claiming success without modifying the IR.
    return rewriter.notifyMatchFailure(op, "unsupported ranks for aten.matmul");
}
};
} // namespace
namespace {
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.self();
int lhsRank = getTensorRank(lhs);
auto loc = op.getLoc();
if (lhsRank > 2 || lhsRank < 0) {
std::string errorMessage =
"t() expects a tensor with <=2 dimensions, but self is " +
std::to_string(lhsRank) + "D";
return rewriter.notifyMatchFailure(op, errorMessage.c_str());
} else if (lhsRank < 2)
rewriter.replaceOp(op, lhs);
else {
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenTransposeIntOp>(op, op.getType(), lhs,
zero, one);
}
return success();
}
};
} // namespace
// Decompose aten.expand into aten.broadcast_to op.
namespace {
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenExpandOp op,
PatternRewriter &rewriter) const override {
bool implicit = false;
if (!matchPattern(op.implicit(), m_TorchConstantBool(&implicit)) ||
implicit) {
return rewriter.notifyMatchFailure(
op, "unimplemented: requires implicit to be false");
}
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.self(),
op.size());
return success();
}
};
} // namespace
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
namespace {
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAddmmOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
Value mat1 = op.mat1();
Value mat2 = op.mat2();
// The operands `mat1`, `mat2` to aten.addmm must be of rank 2.
if (getTensorRank(mat1) != 2 || getTensorRank(mat2) != 2) {
return rewriter.notifyMatchFailure(
op, "expected mat1, mat2 operands to aten.addmm to be rank 2");
}
// TODO: Handle integer type operands.
if (!input.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non-floating point dtype");
}
// matrix multiplication: matmul = mat1 @ mat2
Value matmul = rewriter.create<AtenMmOp>(loc, op.getType(), mat1, mat2);
// scaledInput = self * beta
Value scaledInput = rewriter.create<AtenMulScalarOp>(loc, input.getType(),
input, op.beta());
// result = scaledInput + alpha * matmul
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), scaledInput,
matmul, op.alpha());
return success();
}
};
} // namespace
// Decompose aten.mean into: sum(x)/div(numTensorElements).
namespace {
class DecomposeAtenMeanOp : public OpRewritePattern<AtenMeanOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenMeanOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
Value output = op.result();
BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
    Value sum =
        rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.dtype());
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputTensorType, sum,
numTensorElements);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSquareOp op,
PatternRewriter &rewriter) const override {
Value self = op.self();
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), self, self);
return success();
}
};
} // namespace
// Decompose aten.var into: sum(square(x - mean))/(numTensorElements-1)
// for unbiased and mean(square(x - mean)) for biased case.
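// For example, for x = [1, 2, 3]: mean = 2 and square(x - mean) = [1, 0, 1],
// so the unbiased variance is 2/(3-1) = 1 and the biased variance is 2/3.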
namespace {
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(op,
"Only aten.var support floating type");
}
BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
assert(rank0FloatTensorTy.getSizes().size() == 0 &&
"Op should have rank 0 tensor type");
bool unbiased;
if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) {
return rewriter.notifyMatchFailure(
op, "Only support constant unbiased for aten.var");
}
Value dtype = rewriter.create<ConstantNoneOp>(loc);
Value mean =
rewriter.create<AtenMeanOp>(loc, rank0FloatTensorTy, self, dtype);
Value subMean = createTensorSub(rewriter, loc, inputTensorTy, self, mean);
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
Value var;
if (unbiased) {
      // Bessel's correction is used. Divide the square sum by
      // numTensorElements - 1.
Value squareSum =
rewriter.create<AtenSumOp>(loc, rank0FloatTensorTy, square, dtype);
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, square);
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value numTensorElementsSub1 =
rewriter.create<AtenSubIntOp>(loc, numTensorElements, cst1);
var = rewriter.replaceOpWithNewOp<AtenDivScalarOp>(
op, rank0FloatTensorTy, squareSum, numTensorElementsSub1);
} else {
var = rewriter.replaceOpWithNewOp<AtenMeanOp>(op, rank0FloatTensorTy,
square, dtype);
}
return success();
}
};
} // namespace
// Decompose aten.std to sqrt(var(x))
namespace {
class DecomposeAtenStdOp : public OpRewritePattern<AtenStdOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenStdOp op,
PatternRewriter &rewriter) const override {
Value self = op.self();
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
if (!inputTensorTy.hasDtype() ||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(op,
"Only aten.std support floating type");
}
Value var = rewriter.create<AtenVarOp>(op->getLoc(), op.getType(),
op.self(), op.unbiased());
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), var);
return success();
}
};
} // namespace
// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
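// The function is 0 for x <= -3, 1 for x >= 3, and linear in between, e.g.
// Hardsigmoid(-3) = 0, Hardsigmoid(0) = (0 + 3)/6 = 0.5, Hardsigmoid(3) = 1.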
namespace {
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenHardsigmoidOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
Type inputType = input.getType();
// outputTensor = (input + 3) / 6.
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constantThree = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(3));
Value constantSix = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(6));
Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
loc, inputType, input, constantThree, /*alpha=*/constantOne);
Value outputTensor = rewriter.create<AtenDivScalarOp>(
loc, inputType, inputPlusThree, constantSix);
// result = max(0, min(1, (input+3)/6))
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value zeroTensor = rewriter.create<AtenZerosLikeOp>(
loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
/*pin_memory=*/none, /*memory_format=*/none);
Value oneTensor = rewriter.create<AtenOnesLikeOp>(
loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
/*pin_memory=*/none, /*memory_format=*/none);
Value minResult =
rewriter.create<AtenMinimumOp>(loc, inputType, oneTensor, outputTensor);
rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), zeroTensor,
minResult);
return success();
}
};
} // namespace
// Returns a tensor sampled from a Bernoulli(p) distribution.
// Decompose aten.bernoulli(x, p) to aten.lt(aten.uniform(x, 0, 1), p).
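// This works because for u ~ Uniform(0, 1), P(u < p) = p, so the boolean
// tensor (u < p) is itself a Bernoulli(p) sample.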
static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
Location loc, Value input, double p) {
  BaseTensorType inputType = input.getType().cast<BaseTensorType>();
  // Bail out before creating any ops if the input shape is unknown.
  if (!inputType.hasSizes())
    return nullptr;
  // `intType` holds the integer code corresponding to the input dtype, as
  // consumed by the `aten.to.dtype` op.
  int intType = (int)getScalarTypeForType(inputType.getDtype());
  Value convertIntVal =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(intType));
BaseTensorType boolType =
inputType
.getWithSizesAndDtype(
inputType.getSizes(),
IntegerType::get(op->getContext(), 1, IntegerType::Signless))
.cast<BaseTensorType>();
Value prob =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(p));
Value lb =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value ub =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
// Create a uniform random op with low and high set to lb and ub respectively.
Value uniformRandom = rewriter.create<PseudoAtenUniformOp>(
loc, inputType, input, lb, ub, noneVal);
  Value boolResult =
      rewriter.create<AtenLtScalarOp>(loc, boolType, uniformRandom, prob);
  // `boolResult` is a boolean tensor, so convert it back to the original
  // dtype.
  Value convertBack = rewriter.create<AtenToDtypeOp>(
      loc, inputType, boolResult, convertIntVal, falseVal, falseVal, noneVal);
return convertBack;
}
namespace {
class DecomposeAtenBernoulliOp : public OpRewritePattern<AtenBernoulliOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBernoulliOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
Value generator = op.generator();
if (!generator.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to ben None because only global default "
"generator is supported");
Value result = decomposeBernoulliLikeOp(rewriter, op, loc, self, /*p=*/0.5);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
class DecomposePseudoAtenBernoulliFloatOp
: public OpRewritePattern<PseudoAtenBernoulliFloatOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PseudoAtenBernoulliFloatOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self();
Value generator = op.generator();
double p;
if (!matchPattern(op.p(), m_TorchConstantFloat(&p)))
return rewriter.notifyMatchFailure(op, "p should be constant float");
if (!generator.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to ben None because only global default "
"generator is supported");
Value result = decomposeBernoulliLikeOp(rewriter, op, loc, self, p);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
template <typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
Value tensor1 = op.tensor1();
Value tensor2 = op.tensor2();
Value value = op.value();
    Value product =
        rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input,
                                                 product, value);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLayerNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto input = op.input().getType().cast<BaseTensorType>();
if (!input.hasSizes())
return rewriter.notifyMatchFailure(
op, "input tensor should have known sizes.");
int64_t inputRank = input.getSizes().size();
Value normalizedShape = op.normalized_shape();
SmallVector<Value> normalizedShapeSizesTorchInt;
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
std::vector<int64_t> meanVarSizes;
for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
meanVarSizes.push_back(input.getSizes()[i]);
auto meanVarType = input.getWithSizesAndDtype(
llvm::makeArrayRef(meanVarSizes), input.getDtype());
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
loc, op.getType(), meanVarType, meanVarType, op.input(),
op.normalized_shape(), op.weight(), op.bias(), op.eps());
rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
return success();
}
};
} // namespace
namespace {
class DecomposeAtenBatchNormOp : public OpRewritePattern<AtenBatchNormOp> {
using OpRewritePattern<AtenBatchNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBatchNormOp op,
PatternRewriter &rewriter) const override {
// TODO: Add support for `training` mode.
bool training = false;
if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
training)
return rewriter.notifyMatchFailure(
op, "unimplemented: training mode is not supported");
    // In inference mode, the `mean` and `invstd` outputs should have shape
    // {0}.
BaseTensorType tensorType = op.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: non-floating point type input");
Type emptyType =
tensorType.getWithSizesAndDtype({0}, tensorType.getDtype());
// The first output tensor of the `AtenNativeBatchNormOp` is essentially
// `AtenBatchNormOp` result.
auto nativeBatchNorm = rewriter.create<AtenNativeBatchNormOp>(
op.getLoc(), op.getType(), /*meanType=*/emptyType,
/*invStdType=*/emptyType, op.input(), op.weight(), op.bias(),
op.running_mean(), op.running_var(), op.training(), op.momentum(),
op.eps());
rewriter.replaceOp(op, nativeBatchNorm.getResult(0));
return success();
}
};
} // namespace
namespace {
// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenEmptyLikeOp op,
PatternRewriter &rewriter) const override {
auto sizeListType =
Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.self());
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), sizeList, op.dtype(), op.layout(), op.device(),
op.pin_memory(), op.memory_format());
return success();
}
};
} // namespace
namespace {
// The `aten.arange` op is converted to `aten.arange.start_step` op.
class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenArangeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
    // `AtenArangeOp` doesn't have start and step values, so set them to the
    // default values 0 and 1, respectively.
Value start, step;
start = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
step = rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
op, op.getType(), start, op.end(), step, op.dtype(), op.layout(),
op.device(), op.pin_memory());
return success();
}
};
} // namespace
namespace {
// The `aten.arange.start` op is converted to `aten.arange.start_step` op.
class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenArangeStartOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
    // `AtenArangeStartOp` doesn't have a step value, so set it to the default
    // value 1.
Value step;
step = rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
op, op.getType(), op.start(), op.end(), step, op.dtype(), op.layout(),
op.device(), op.pin_memory());
return success();
}
};
} // namespace
namespace {
// Decompose constant tensor allocation like ops.
template <typename OpTy, int fillVal>
class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Allocate a memory block.
Value initTensor = rewriter.create<AtenEmptyLikeOp>(
loc, op.getType(), op.self(), op.dtype(), op.layout(), op.device(),
op.pin_memory(), op.memory_format());
Value constVal = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(fillVal));
// Initialize the allocated memory block with `fillVal`.
rewriter.replaceOpWithNewOp<PseudoAtenFillScalarOp>(
op, initTensor.getType(), initTensor, constVal);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenNativeBatchNormOp
: public OpRewritePattern<AtenNativeBatchNormOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNativeBatchNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = op.getContext();
Value input = op.input();
Value weight = op.weight();
Value bias = op.bias();
Value runningMean = op.running_mean();
Value runningVar = op.running_var();
Value eps = op.eps();
// TODO: Add support for optional type parameters.
if (weight.getType().isa<OptionalType>() ||
bias.getType().isa<OptionalType>() ||
runningMean.getType().isa<OptionalType>() ||
runningVar.getType().isa<OptionalType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: optional type arg is not supported");
// TODO: Add support for `training` mode.
bool training = false;
if (!matchPattern(op.training(), m_TorchConstantBool(&training)) ||
training)
return rewriter.notifyMatchFailure(
op, "unimplemented: training mode is not supported");
// Rank of the input tensor must be greater than or equal to 2. The shape of
// the `input` is supposed to be (N, C, D?, H?, W?).
int64_t inputRank = getTensorRank(input);
if (inputRank < 2)
return rewriter.notifyMatchFailure(
op, "input must have rank greater than or equal to 2");
// In the inference mode, the `runningMean` and `runningVar` must not be
// None.
if (runningMean.getType().isa<Torch::NoneType>() ||
runningVar.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "running stats must not be None in inference mode");
// Rank of `runningMean` and `runningVar` must be exactly 1.
if (getTensorRank(runningMean) != 1 || getTensorRank(runningVar) != 1)
return rewriter.notifyMatchFailure(
op, "expected running_mean and running_var to be rank 1");
// The shape of `runningMean` and `runningVar` must be (numFeatures). Here,
// 'numFeatures' is C from an expected 'input' of size (N,C,D?,H?,W?).
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value numFeatures = rewriter.create<AtenSizeIntOp>(loc, input, /*dim=*/one);
auto dim0EqualsNumFeatures = [&](Value v) {
Value dim0 = rewriter.create<AtenSizeIntOp>(loc, v, /*dim=*/zero);
Value eqCmp = rewriter.create<AtenEqIntOp>(loc, BoolType::get(context),
dim0, numFeatures);
rewriter.create<RuntimeAssertOp>(
loc, eqCmp,
rewriter.getStringAttr("size of the 0th dimension must be equal to "
"the number of features"));
};
dim0EqualsNumFeatures(runningMean);
dim0EqualsNumFeatures(runningVar);
// The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?)
// to make it broadcast-compatible with (N, C, D?, H?, W?).
// 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?)
// 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?)
SmallVector<Value> runningStatsShape(inputRank, one);
runningStatsShape[1] = numFeatures;
Value runningStatsSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), runningStatsShape);
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
runningStatsShapeInt[1] = ShapedType::kDynamicSize;
Type dtype = input.getType().cast<ValueTensorType>().getDtype();
Type reshapeType = ValueTensorType::get(
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
runningStatsSizeList);
runningVar = rewriter.create<AtenViewOp>(loc, reshapeType, runningVar,
runningStatsSizeList);
// normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
Value inputSubMean = rewriter.create<AtenSubTensorOp>(
loc, input.getType(), input, runningMean, /*alpha=*/one);
Value varEps = rewriter.create<AtenAddScalarOp>(
loc, runningVar.getType(), runningVar, eps, /*alpha=*/one);
Value invStd = rewriter.create<AtenRsqrtOp>(loc, varEps.getType(), varEps);
Value normalizedInput = rewriter.create<AtenMulTensorOp>(
loc, inputSubMean.getType(), inputSubMean, invStd);
// The `weight` and `bias` must be reshaped to (1, C, 1?, 1?, 1?) to make it
// broadcast-compatible with (N, C, D?, H?, W?).
// 1. weight = weight.view(1, C, 1?, 1?, 1?)
// 2. bias = bias.view(1, C, 1?, 1?, 1?)
// 3. output = normalizedInput * weight + bias
Value batchNormOutput = normalizedInput;
if (!weight.getType().isa<Torch::NoneType>()) {
// The shape of the `weight` tensor must be (numFeatures).
if (getTensorRank(weight) != 1)
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
dim0EqualsNumFeatures(weight);
weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
runningStatsSizeList);
batchNormOutput = rewriter.create<AtenMulTensorOp>(
loc, batchNormOutput.getType(), batchNormOutput, weight);
}
if (!bias.getType().isa<Torch::NoneType>()) {
// The shape of the `bias` tensor must be (numFeatures).
if (getTensorRank(bias) != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
dim0EqualsNumFeatures(bias);
bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
runningStatsSizeList);
batchNormOutput = rewriter.create<AtenAddTensorOp>(
loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one);
}
// The `mean` and `invstd` outputs are empty tensors in inference mode.
Value zeroList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(zero.getType()), zero);
Value none = rewriter.create<ConstantNoneOp>(loc);
Value emptyMeanTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, op.getType(1), zeroList, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
Value emptyInvStdTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, op.getType(2), zeroList, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
rewriter.replaceOp(op,
{batchNormOutput, emptyMeanTensor, emptyInvStdTensor});
return success();
}
};
} // namespace
// Decompose `Aten_UnsafeViewOp` into `AtenViewOp`. _unsafe_view() differs from
// view() in that the returned tensor isn't treated as a view for the purposes
// of automatic differentiation. It's only safe to use if the `self` tensor is
// temporary. For example, the viewed tensor here (a + b) is discarded
// immediately after viewing:
//
// res = _unsafe_view(a + b, size);
//
// This is a hack because in-place operations on tensors treated like views
// can be much more expensive than the same operations on non-view tensors.
// Refer to
// https://github.com/pytorch/pytorch/blob/364055b2771ecf9b54f1d67a8bf44bb5496476d4/aten/src/ATen/native/TensorShape.cpp#L2072
namespace {
class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_UnsafeViewOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
op.size());
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect>();
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
target.addIllegalOp<AtenSoftmaxIntOp>();
patterns.add<DecomposeAten_SoftmaxOp>(context);
target.addIllegalOp<Aten_SoftmaxOp>();
patterns.add<DecomposeAten_LogSoftmaxOp>(context);
target.addIllegalOp<Aten_LogSoftmaxOp>();
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
target.addIllegalOp<AtenLogSoftmaxIntOp>();
patterns.add<DecomposeAtenEmptyLikeOp>(context);
target.addIllegalOp<AtenEmptyLikeOp>();
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(
context);
target.addIllegalOp<AtenOnesLikeOp>();
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
context);
target.addIllegalOp<AtenZerosLikeOp>();
patterns.add<DecomposeAtenExpandOp>(context);
target.addIllegalOp<AtenExpandOp>();
patterns.add<DecomposeAtenSizeOp>(context);
target.addIllegalOp<AtenSizeOp>();
patterns.add<DecomposeAtenReshapeOp>(context);
target.addIllegalOp<AtenReshapeOp>();
patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
patterns.add<DecomposeAtenTanhBackwardOp>(context);
target.addIllegalOp<AtenTanhBackwardOp>();
patterns.add<DecomposeAtenAddmmOp>(context);
target.addIllegalOp<AtenAddmmOp>();
patterns.add<DecomposeAtenMeanOp>(context);
target.addIllegalOp<AtenMeanOp>();
patterns.add<DecomposeAtenSelectIntOp>(context);
target.addIllegalOp<AtenSelectIntOp>();
    patterns.add<DecomposeAtenMatmulOp>(context);
    patterns.add<DecomposeAtenTOp>(context);
    target.addIllegalOp<AtenTOp>();
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
int lhsRank = getTensorRank(op.self());
int rhsRank = getTensorRank(op.other());
      // `aten.matmul` remains legal (i.e. is not decomposed) unless both
      // operands are rank 2 or both are rank 3.
});
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(context);
target.addIllegalOp<AtenAddcmulOp>();
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(context);
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<AtenLayerNormOp>();
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeBatchNormOp>();
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
target.addIllegalOp<AtenBatchNormOp>();
patterns.add<DecomposeAtenBatchNormOp>(context);
patterns.add<DecomposeAtenArangeOp>(context);
target.addIllegalOp<AtenArangeOp>();
patterns.add<DecomposeAtenArangeStartOp>(context);
target.addIllegalOp<AtenArangeStartOp>();
patterns.add<DecomposeAtenArgMaxOp>(context);
target.addIllegalOp<AtenArgmaxOp>();
patterns.add<DecomposeAtenSquareOp>(context);
target.addIllegalOp<AtenSquareOp>();
patterns.add<DecomposeAtenVarOp>(context);
target.addIllegalOp<AtenVarOp>();
patterns.add<DecomposeAtenStdOp>(context);
target.addIllegalOp<AtenStdOp>();
patterns.add<DecomposeAten_UnsafeViewOp>(context);
target.addIllegalOp<Aten_UnsafeViewOp>();
patterns.add<DecomposeAtenBernoulliOp>(context);
target.addIllegalOp<AtenBernoulliOp>();
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
patterns.add<DecomposeAtenHardsigmoidOp>(context);
target.addIllegalOp<AtenHardsigmoidOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::Torch::createDecomposeComplexOpsPass() {
return std::make_unique<DecomposeComplexOpsPass>();
}