// Mirror of https://github.com/llvm/torch-mlir (839 lines, 33 KiB, C++).
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
|
||
|
||
#include "mlir/Transforms/DialectConversion.h"
|
||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||
#include "llvm/ADT/StringExtras.h"
|
||
|
||
using namespace mlir;
|
||
using namespace mlir::torch;
|
||
using namespace mlir::torch::Torch;
|
||
|
||
// Helper funtion to get rank of `Base tensor type`.
|
||
// -1 is returned if the tensorRank can't be determined.
|
||
static int getTensorRank(Value tensor) {
|
||
int tensorRank = -1;
|
||
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
|
||
|
||
if (tensorType.hasSizes()) {
|
||
ArrayRef<int64_t> tensorShape = tensorType.getSizes();
|
||
tensorRank = tensorShape.size();
|
||
}
|
||
return tensorRank;
|
||
}
|
||
|
||
// Creates `aten.sum.dim_IntList(input, [dim], keepDim, dtype=None)` and returns
// its result.
//
// The result type is derived from `input`'s type:
//   * input of unknown rank -> result of unknown rank;
//   * constant `dim`        -> input sizes with `dim` set to 1 (`keepDim`) or
//                              removed (`!keepDim`);
//   * non-constant `dim`    -> every size unknown, with the rank reduced by one
//                              when `!keepDim` (and the input rank is > 0).
//
// Returns nullptr, after notifying a match failure on `op`, when `dim` is a
// constant that is out of range for the input rank. This check runs before any
// op is created.
static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
                                     Operation *op, Value input, Value dim,
                                     bool keepDim) {
  BaseTensorType tensorType = input.getType().cast<BaseTensorType>();

  // Track "sizes are known" separately from `sizes`: an empty `sizes` is a
  // legitimate rank-0 result, not an unknown rank.
  bool hasKnownSizes = tensorType.hasSizes();
  SmallVector<int64_t> sizes;
  if (hasKnownSizes) {
    ArrayRef<int64_t> inputShape = tensorType.getSizes();
    int64_t inputRank = inputShape.size();
    int64_t dimInt;
    if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
      dimInt = toPositiveDim(dimInt, inputRank);
      if (!isValidDim(dimInt, inputRank)) {
        (void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
        return nullptr;
      }
      sizes.append(inputShape.begin(), inputShape.end());
      if (keepDim)
        sizes[dimInt] = 1;
      else
        sizes.erase(sizes.begin() + dimInt);
    } else {
      // The reduced dimension is unknown, so no individual size is known.
      int64_t resultRank =
          (keepDim || inputRank == 0) ? inputRank : inputRank - 1;
      sizes.resize(resultRank, kUnknownSize);
    }
  }

  Type resultType = tensorType.getWithSizesAndDtype(
      hasKnownSizes ? Optional<ArrayRef<int64_t>>(llvm::makeArrayRef(sizes))
                    : Optional<ArrayRef<int64_t>>(),
      tensorType.getDtype());

  Value dimList = rewriter.create<PrimListConstructOp>(
      loc, Torch::ListType::get(dim.getType()), dim);
  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
  Value dtype = rewriter.create<ConstantNoneOp>(loc);
  return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
                                              keepDimCst, dtype);
}

// Helper for creating `aten::sub_tensor_op`.
|
||
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
||
Type tensorType, Value lhs, Value rhs) {
|
||
Value alpha =
|
||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
|
||
Value sub =
|
||
rewriter.create<AtenSubTensorOp>(loc, tensorType, lhs, rhs, alpha);
|
||
return sub;
|
||
}
|
||
|
||
// Share code between `softmax_backward` and `log_softmax_backward` ops.
|
||
// Returns x - y * sum(z, dim).
|
||
static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
|
||
Location loc, Operation *op,
|
||
Type tensorType, Value x,
|
||
Value y, Value z, Value dim) {
|
||
Value sum = createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
|
||
if (!sum)
|
||
return nullptr;
|
||
auto broadcastSizeType =
|
||
Torch::ListType::get(Torch::IntType::get(op->getContext()));
|
||
Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
|
||
Value sumBroadcast =
|
||
rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
|
||
Value temp =
|
||
rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
|
||
|
||
Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
|
||
return sub;
|
||
}
|
||
|
||
namespace {
|
||
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenSizeOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value self = op.self();
|
||
MLIRContext *context = op.getContext();
|
||
int64_t rank = getTensorRank(self);
|
||
if (rank < 0)
|
||
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
|
||
SmallVector<Value> sizes;
|
||
for (int i = 0; i < rank; i++) {
|
||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||
loc, rewriter.getI64IntegerAttr(i));
|
||
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
|
||
}
|
||
|
||
Value sizeList = rewriter.create<PrimListConstructOp>(
|
||
loc, Torch::ListType::get(Torch::IntType::get(context)), sizes);
|
||
rewriter.replaceOp(op, sizeList);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenSelectIntOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value one =
|
||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||
Value end =
|
||
rewriter.create<AtenAddIntOp>(loc, one.getType(), op.index(), one);
|
||
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(op, op.getResult().getType(),
|
||
op.self(), op.dim(),
|
||
op.index(), end, one);
|
||
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
||
// exp(x)/sum(exp(x)).
|
||
template <typename OpTy>
|
||
static Value getSoftmaxResult(OpTy op, Type resultType,
|
||
PatternRewriter &rewriter) {
|
||
Location loc = op.getLoc();
|
||
Value dim = op.dim();
|
||
Value self = op.self();
|
||
|
||
// exp(x)
|
||
Value exp = rewriter.create<AtenExpOp>(loc, resultType, self);
|
||
// sum(exp(x))
|
||
Value sum =
|
||
createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
|
||
if (!sum)
|
||
return nullptr;
|
||
// exp(x) / sum(exp(x))
|
||
return rewriter.create<AtenDivTensorOp>(loc, resultType, exp, sum);
|
||
}
|
||
|
||
// Decompose softmax into: exp(x) / sum(exp(x))
|
||
namespace {
|
||
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Value self = op.self();
|
||
if (!op.dtype().getType().isa<Torch::NoneType>())
|
||
return rewriter.notifyMatchFailure(
|
||
op, "Unimplemented non-None dtype for softmax");
|
||
|
||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||
|
||
Value result = getSoftmaxResult(op, tensorType, rewriter);
|
||
if (!result)
|
||
return failure();
|
||
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
|
||
result);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Value self = op.self();
|
||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||
bool halfToFloat;
|
||
if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
|
||
return rewriter.notifyMatchFailure(
|
||
op, "Expected a boolean value for half_to_float");
|
||
|
||
// Currently, setting `halfToFloat` is not supported as the E2E testing for
|
||
// the same is not present on CPU.
|
||
if (halfToFloat)
|
||
return rewriter.notifyMatchFailure(
|
||
op, "halfToFloat is currently not supported.");
|
||
|
||
Value result = getSoftmaxResult(op, tensorType, rewriter);
|
||
if (!result)
|
||
return op.emitError("failed to get softmax result");
|
||
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
|
||
result);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
||
// newGrad = gradOutput * output
|
||
// result = newGrad - output * sum(newGrad, dim))
|
||
//
|
||
// Refer to
|
||
// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
|
||
namespace {
|
||
class DecomposeAten_SoftmaxBackwardDataOp
|
||
: public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value gradOutput = op.grad_output();
|
||
Value output = op.output();
|
||
Value dim = op.dim();
|
||
|
||
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||
|
||
Value newGrad =
|
||
rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
|
||
Value result = createSoftmaxBackwardCommonKernel(
|
||
rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
|
||
if (!result)
|
||
return rewriter.notifyMatchFailure(
|
||
op,
|
||
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
|
||
rewriter.replaceOp(op, result);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// AtenTanhBackwardOp(gradOutput, output) =>
|
||
// result = gradOutput * (1 - output^2)
|
||
// To get away from broadcasts the above formula is expanded i.e.,
|
||
// result = gradOutput - (gradOutput * output^2)
|
||
namespace {
|
||
class DecomposeAtenTanhBackwardOp
|
||
: public OpRewritePattern<AtenTanhBackwardOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenTanhBackwardOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value gradOutput = op.grad_output();
|
||
|
||
// `output` is the value flowing out from tanh. Hence, tanh(x) = output.
|
||
// Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
|
||
Value output = op.output();
|
||
|
||
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||
|
||
Value tanhSquare =
|
||
rewriter.create<AtenMulTensorOp>(loc, tensorType, output, output);
|
||
Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
|
||
loc, tensorType, tanhSquare, gradOutput);
|
||
|
||
Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
|
||
gradMulTanhSquare);
|
||
rewriter.replaceOp(op, newGrad);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Aten_LogSoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
||
// result = gradOutput - (exp(output) * sum(gradOutput, dim))
|
||
namespace {
|
||
class DecomposeAten_LogSoftmaxBackwardDataOp
|
||
: public OpRewritePattern<Aten_LogSoftmaxBackwardDataOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value gradOutput = op.grad_output();
|
||
Value output = op.output();
|
||
Value dim = op.dim();
|
||
|
||
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||
|
||
Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
|
||
Value result = createSoftmaxBackwardCommonKernel(
|
||
rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
|
||
if (!result)
|
||
return rewriter.notifyMatchFailure(
|
||
op,
|
||
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
|
||
rewriter.replaceOp(op, result);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
|
||
namespace {
|
||
class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenArgmaxOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value input = op.self();
|
||
Value dim = op.dim();
|
||
Value keepDim = op.keepdim();
|
||
Value result = op.result();
|
||
|
||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
|
||
|
||
if (!indicesTensorType.hasSizes())
|
||
return failure();
|
||
BaseTensorType valueTensorType =
|
||
inputType
|
||
.getWithSizesAndDtype(indicesTensorType.getSizes(),
|
||
inputType.getDtype())
|
||
.cast<BaseTensorType>();
|
||
|
||
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
|
||
// `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input
|
||
// tensor is flattened to 1d tensor and then the reduction happens on the
|
||
// 0th dimension.
|
||
if (dim.getType().isa<Torch::NoneType>()) {
|
||
BaseTensorType flattenType =
|
||
inputType.getWithSizesAndDtype({kUnknownSize}, inputType.getDtype())
|
||
.cast<BaseTensorType>();
|
||
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||
Value end = rewriter.create<ConstantIntOp>(
|
||
loc, rewriter.getI64IntegerAttr(getTensorRank(input) - 1));
|
||
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
|
||
dim, end);
|
||
}
|
||
Value maxResult =
|
||
rewriter
|
||
.create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
|
||
input, dim, keepDim)
|
||
.indices();
|
||
|
||
rewriter.replaceOp(op, maxResult);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Decompose aten.log_softmax op into: log(softmax(x))
|
||
namespace {
|
||
class DecomposeAtenLogSoftmaxIntOp
|
||
: public OpRewritePattern<AtenLogSoftmaxIntOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value self = op.self();
|
||
Value dim = op.dim();
|
||
if (!op.dtype().getType().isa<Torch::NoneType>())
|
||
return rewriter.notifyMatchFailure(
|
||
op, "Unimplemented non-None dtype for log_softmax");
|
||
|
||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||
|
||
// softmax(x, dim)
|
||
Value softmax = rewriter.create<AtenSoftmaxIntOp>(loc, tensorType, self,
|
||
dim, op.dtype());
|
||
rewriter.replaceOpWithNewOp<AtenLogOp>(op, op.getType(), softmax);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks.
|
||
namespace {
|
||
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenMatmulOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Value lhs = op.self();
|
||
Value rhs = op.other();
|
||
|
||
int lhsRank = getTensorRank(lhs);
|
||
int rhsRank = getTensorRank(rhs);
|
||
|
||
// If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
|
||
if (lhsRank == 2 && rhsRank == 2)
|
||
rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
|
||
|
||
// If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
|
||
if (lhsRank == 3 && rhsRank == 3)
|
||
rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
|
||
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenTOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Value lhs = op.self();
|
||
int lhsRank = getTensorRank(lhs);
|
||
auto loc = op.getLoc();
|
||
|
||
if (lhsRank > 2 || lhsRank < 0) {
|
||
std::string errorMessage =
|
||
"t() expects a tensor with <=2 dimensions, but self is " +
|
||
std::to_string(lhsRank) + "D";
|
||
return rewriter.notifyMatchFailure(op, errorMessage.c_str());
|
||
} else if (lhsRank < 2)
|
||
rewriter.replaceOp(op, lhs);
|
||
else {
|
||
Value zero =
|
||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||
Value one =
|
||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||
rewriter.replaceOpWithNewOp<AtenTransposeIntOp>(op, op.getType(), lhs,
|
||
zero, one);
|
||
}
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Decompose aten.expand into aten.broadcast_to op.
|
||
namespace {
|
||
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenExpandOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
bool implicit = false;
|
||
if (!matchPattern(op.implicit(), m_TorchConstantBool(&implicit)) ||
|
||
implicit) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "unimplemented: requires implicit to be false");
|
||
}
|
||
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.self(),
|
||
op.size());
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
|
||
namespace {
|
||
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenAddmmOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value input = op.self();
|
||
Value mat1 = op.mat1();
|
||
Value mat2 = op.mat2();
|
||
|
||
// The operands `mat1`, `mat2` to aten.addmm must be of rank 2.
|
||
if (getTensorRank(mat1) != 2 || getTensorRank(mat2) != 2) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "expected mat1, mat2 operands to aten.addmm to be rank 2");
|
||
}
|
||
|
||
// TODO: Handle integer type operands.
|
||
if (!input.getType()
|
||
.cast<ValueTensorType>()
|
||
.getDtype()
|
||
.isa<mlir::FloatType>()) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "unimplemented: non-floating point dtype");
|
||
}
|
||
|
||
// matrix multiplication: matmul = mat1 @ mat2
|
||
Value matmul = rewriter.create<AtenMmOp>(loc, op.getType(), mat1, mat2);
|
||
// scaledInput = self * beta
|
||
Value scaledInput = rewriter.create<AtenMulScalarOp>(loc, input.getType(),
|
||
input, op.beta());
|
||
// result = scaledInput + alpha * matmul
|
||
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), scaledInput,
|
||
matmul, op.alpha());
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Decompose aten.mean into: sum(x)/div(numTensorElements).
|
||
namespace {
|
||
class DecomposeAtenMeanOp : public OpRewritePattern<AtenMeanOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenMeanOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value input = op.self();
|
||
Value output = op.result();
|
||
BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
|
||
Value sum = rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.dtype());
|
||
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
|
||
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputTensorType, sum,
|
||
numTensorElements);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenSquareOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Value self = op.self();
|
||
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), self, self);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Decompose aten.var into: sum(square(x - mean))/(numTensorElements-1)
|
||
// for unbiased and mean(square(x - mean)) for biased case.
|
||
namespace {
|
||
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenVarOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value self = op.self();
|
||
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
||
if (!inputTensorTy.hasDtype() ||
|
||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
||
return rewriter.notifyMatchFailure(op,
|
||
"Only aten.var support floating type");
|
||
}
|
||
BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
|
||
assert(rank0FloatTensorTy.getSizes().size() == 0 &&
|
||
"Op should have rank 0 tensor type");
|
||
|
||
bool unbiased;
|
||
if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "Only support constant unbiased for aten.var");
|
||
}
|
||
|
||
Value dtype = rewriter.create<ConstantNoneOp>(loc);
|
||
Value mean =
|
||
rewriter.create<AtenMeanOp>(loc, rank0FloatTensorTy, self, dtype);
|
||
Value subMean = createTensorSub(rewriter, loc, inputTensorTy, self, mean);
|
||
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
|
||
Value var;
|
||
if (unbiased) {
|
||
// Bessel’s correction is used. Divide the square sum by
|
||
// numTensorElements-1.
|
||
Value squareSum =
|
||
rewriter.create<AtenSumOp>(loc, rank0FloatTensorTy, square, dtype);
|
||
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, square);
|
||
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
|
||
loc, rewriter.getI64IntegerAttr(1));
|
||
Value numTensorElementsSub1 =
|
||
rewriter.create<AtenSubIntOp>(loc, numTensorElements, cst1);
|
||
var = rewriter.replaceOpWithNewOp<AtenDivScalarOp>(
|
||
op, rank0FloatTensorTy, squareSum, numTensorElementsSub1);
|
||
} else {
|
||
var = rewriter.replaceOpWithNewOp<AtenMeanOp>(op, rank0FloatTensorTy,
|
||
square, dtype);
|
||
}
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
// Decompose aten.std to sqrt(var(x))
|
||
namespace {
|
||
class DecomposeAtenStdOp : public OpRewritePattern<AtenStdOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenStdOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Value self = op.self();
|
||
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
||
if (!inputTensorTy.hasDtype() ||
|
||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
||
return rewriter.notifyMatchFailure(op,
|
||
"Only aten.std support floating type");
|
||
}
|
||
Value var = rewriter.create<AtenVarOp>(op->getLoc(), op.getType(),
|
||
op.self(), op.unbiased());
|
||
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), var);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
template<typename OpTy, typename T1T2Op>
|
||
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
|
||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(OpTy op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
Value input = op.self();
|
||
Value tensor1 = op.tensor1();
|
||
Value tensor2 = op.tensor2();
|
||
Value value = op.value();
|
||
|
||
Value product = rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
|
||
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input, product,
|
||
value);
|
||
return success();
|
||
}
|
||
};
|
||
|
||
class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
|
||
using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenLayerNormOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
|
||
auto input = op.input().getType().cast<BaseTensorType>();
|
||
if (!input.hasSizes())
|
||
return rewriter.notifyMatchFailure(
|
||
op, "input tensor should have known sizes.");
|
||
int64_t inputRank = input.getSizes().size();
|
||
Value normalizedShape = op.normalized_shape();
|
||
SmallVector<Value> normalizedShapeSizesTorchInt;
|
||
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
|
||
std::vector<int64_t> meanVarSizes;
|
||
for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
|
||
meanVarSizes.push_back(input.getSizes()[i]);
|
||
auto meanVarType = input.getWithSizesAndDtype(
|
||
llvm::makeArrayRef(meanVarSizes), input.getDtype());
|
||
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
|
||
loc, op.getType(), meanVarType, meanVarType, op.input(),
|
||
op.normalized_shape(), op.weight(), op.bias(), op.eps());
|
||
rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
|
||
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
|
||
public:
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenEmptyLikeOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
auto sizeListType =
|
||
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
||
Value sizeList =
|
||
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.self());
|
||
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
||
op, op.getType(), sizeList, op.dtype(), op.layout(), op.device(),
|
||
op.pin_memory(), op.memory_format());
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
// The `aten.arange` op is converted to `aten.arange.start_step` op.
|
||
class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenArangeOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
// The AtenArangeOp doesn't have a start and step value. Therefore we set
|
||
// them as default values 0 and 1, respectively.
|
||
Value start, step;
|
||
start = rewriter.create<Torch::ConstantIntOp>(
|
||
loc, rewriter.getI64IntegerAttr(0));
|
||
step = rewriter.create<Torch::ConstantIntOp>(loc,
|
||
rewriter.getI64IntegerAttr(1));
|
||
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
||
op, op.getType(), start, op.end(), step, op.dtype(), op.layout(),
|
||
op.device(), op.pin_memory());
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
// The `aten.arange.start` op is converted to `aten.arange.start_step` op.
|
||
class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
|
||
using OpRewritePattern::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(AtenArangeStartOp op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
// The AtenArangeStartOp doesn't have a step value. Therefore we set it as
|
||
// default value 1.
|
||
Value step;
|
||
step = rewriter.create<Torch::ConstantIntOp>(loc,
|
||
rewriter.getI64IntegerAttr(1));
|
||
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
||
op, op.getType(), op.start(), op.end(), step, op.dtype(), op.layout(),
|
||
op.device(), op.pin_memory());
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
// Decompose constant tensor allocation like ops.
|
||
template <typename OpTy, int fillVal>
|
||
class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
|
||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||
LogicalResult matchAndRewrite(OpTy op,
|
||
PatternRewriter &rewriter) const override {
|
||
Location loc = op.getLoc();
|
||
// Allocate a memory block.
|
||
Value initTensor = rewriter.create<AtenEmptyLikeOp>(
|
||
loc, op.getType(), op.self(), op.dtype(), op.layout(), op.device(),
|
||
op.pin_memory(), op.memory_format());
|
||
Value constVal = rewriter.create<Torch::ConstantIntOp>(
|
||
loc, rewriter.getI64IntegerAttr(fillVal));
|
||
// Initialize the allocated memory block with `fillVal`.
|
||
rewriter.replaceOpWithNewOp<AtenFill_ScalarOp>(op, initTensor.getType(),
|
||
initTensor, constVal);
|
||
return success();
|
||
}
|
||
};
|
||
} // namespace
|
||
|
||
namespace {
|
||
class DecomposeComplexOpsPass
|
||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||
void runOnOperation() override {
|
||
MLIRContext *context = &getContext();
|
||
RewritePatternSet patterns(context);
|
||
ConversionTarget target(*context);
|
||
target.addLegalDialect<Torch::TorchDialect>();
|
||
|
||
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
|
||
target.addIllegalOp<AtenSoftmaxIntOp>();
|
||
patterns.add<DecomposeAten_SoftmaxOp>(context);
|
||
target.addIllegalOp<Aten_SoftmaxOp>();
|
||
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
|
||
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
||
patterns.add<DecomposeAtenEmptyLikeOp>(context);
|
||
target.addIllegalOp<AtenEmptyLikeOp>();
|
||
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(
|
||
context);
|
||
target.addIllegalOp<AtenOnesLikeOp>();
|
||
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
|
||
context);
|
||
target.addIllegalOp<AtenZerosLikeOp>();
|
||
patterns.add<DecomposeAtenExpandOp>(context);
|
||
target.addIllegalOp<AtenExpandOp>();
|
||
patterns.add<DecomposeAtenSizeOp>(context);
|
||
target.addIllegalOp<AtenSizeOp>();
|
||
patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
|
||
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
||
patterns.add<DecomposeAtenTanhBackwardOp>(context);
|
||
target.addIllegalOp<AtenTanhBackwardOp>();
|
||
patterns.add<DecomposeAtenAddmmOp>(context);
|
||
target.addIllegalOp<AtenAddmmOp>();
|
||
patterns.add<DecomposeAtenMeanOp>(context);
|
||
target.addIllegalOp<AtenMeanOp>();
|
||
patterns.add<DecomposeAtenSelectIntOp>(context);
|
||
target.addIllegalOp<AtenSelectIntOp>();
|
||
patterns.add<DecomposeAtenMatmulOp>(context);
|
||
target.addIllegalOp<AtenTOp>();
|
||
patterns.add<DecomposeAtenTOp>(context);
|
||
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
|
||
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
||
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
|
||
int lhsRank = getTensorRank(op.self());
|
||
int rhsRank = getTensorRank(op.other());
|
||
|
||
// Make aten.matmul legal if the following condition is satisfied.
|
||
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
|
||
});
|
||
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(context);
|
||
target.addIllegalOp<AtenAddcmulOp>();
|
||
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(context);
|
||
target.addIllegalOp<AtenAddcdivOp>();
|
||
target.addIllegalOp<AtenLayerNormOp>();
|
||
patterns.add<DecomposeAtenLayerNormOp>(context);
|
||
patterns.add<DecomposeAtenArangeOp>(context);
|
||
target.addIllegalOp<AtenArangeOp>();
|
||
patterns.add<DecomposeAtenArangeStartOp>(context);
|
||
target.addIllegalOp<AtenArangeStartOp>();
|
||
patterns.add<DecomposeAtenArgMaxOp>(context);
|
||
target.addIllegalOp<AtenArgmaxOp>();
|
||
patterns.add<DecomposeAtenSquareOp>(context);
|
||
target.addIllegalOp<AtenSquareOp>();
|
||
patterns.add<DecomposeAtenVarOp>(context);
|
||
target.addIllegalOp<AtenVarOp>();
|
||
patterns.add<DecomposeAtenStdOp>(context);
|
||
target.addIllegalOp<AtenStdOp>();
|
||
|
||
if (failed(applyPartialConversion(getOperation(), target,
|
||
std::move(patterns)))) {
|
||
return signalPassFailure();
|
||
}
|
||
}
|
||
};
|
||
} // namespace
|
||
// Factory for the `DecomposeComplexOpsPass`, which runs on each function.
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::Torch::createDecomposeComplexOpsPass() {
  auto pass = std::make_unique<DecomposeComplexOpsPass>();
  return pass;
}