// File: torch-mlir/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
// (293 lines, 11 KiB, C++)

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Helper function returning the rank of a value whose type is a
// `BaseTensorType`.
// Returns -1 when the tensor type carries no size information, i.e. the rank
// cannot be determined.
static int getTensorRank(Value tensor) {
  auto baseType = tensor.getType().cast<BaseTensorType>();
  // Without sizes there is no shape to count dimensions on.
  if (!baseType.hasSizes())
    return -1;
  return baseType.getSizes().size();
}
// Creates an `AtenSumDimIntListOp` that reduces `input` over the single
// dimension `dim`, together with the list / bool / none constant operands the
// op needs.
//
// The result type reuses the dtype of `input`; its sizes are derived as
// follows:
//   - `input` has no size information -> the result has no size information;
//   - `dim` is a constant int         -> the reduced dimension becomes 1 when
//                                        `keepDim` is true and is dropped
//                                        otherwise;
//   - `dim` is not a constant int     -> every result dimension is
//                                        `kUnknownSize`.
//
// Returns a null Value, after notifying a match failure on `op`, when `dim` is
// a constant that `isValidDim` rejects for the input rank. No IR is created in
// that case.
static Value createAtenSum(PatternRewriter &rewriter, Location loc,
                           Operation *op, Value input, Value dim,
                           bool keepDim) {
  BaseTensorType tensorType = input.getType().cast<BaseTensorType>();

  // Work out the result shape first so that a failure leaves no ops behind.
  // `resultSizes` stays empty (None) only when the input has no sizes; an
  // engaged-but-empty ArrayRef denotes a ranked, rank-0 result.
  SmallVector<int64_t> sizes;
  Optional<ArrayRef<int64_t>> resultSizes;
  if (tensorType.hasSizes()) {
    ArrayRef<int64_t> inputShape = tensorType.getSizes();
    int64_t inputRank = inputShape.size();
    int64_t dimInt;
    if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
      dimInt = toPositiveDim(dimInt, inputRank);
      if (!isValidDim(dimInt, inputRank)) {
        (void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
        return nullptr;
      }
      sizes.append(inputShape.begin(), inputShape.end());
      if (keepDim)
        sizes[dimInt] = 1;
      else
        sizes.erase(sizes.begin() + dimInt);
    } else {
      // Which dimension is reduced is unknown, so only the rank is known.
      // A rank-0 input cannot lose a dimension.
      int64_t resultRank =
          (keepDim || inputRank == 0) ? inputRank : inputRank - 1;
      sizes.resize(resultRank, kUnknownSize);
    }
    // `sizes` is final at this point, so the view taken here stays valid
    // until the result type has been built below.
    resultSizes = llvm::makeArrayRef(sizes);
  }
  Type resultType =
      tensorType.getWithSizesAndDtype(resultSizes, tensorType.getDtype());

  Value dimList = rewriter.create<PrimListConstructOp>(
      loc, Torch::ListType::get(dim.getType()), dim);
  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
  // The dtype operand is None: no explicit result dtype is requested.
  Value dtype = rewriter.create<ConstantNoneOp>(loc);
  Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, input,
                                                   dimList, keepDimCst, dtype);
  return sum;
}
namespace {
// Rewrites `AtenSizeOp` on a tensor of known rank into one `AtenSizeIntOp`
// per dimension, gathered into a `PrimListConstructOp` of Torch ints.
// Fails to match when the rank of the operand cannot be determined.
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSizeOp op,
                                PatternRewriter &rewriter) const override {
    Value tensor = op.self();
    const int64_t numDims = getTensorRank(tensor);
    if (numDims < 0)
      return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");

    Location loc = op.getLoc();
    SmallVector<Value> dimSizes;
    for (int64_t idx = 0; idx != numDims; ++idx) {
      // Constant index of the dimension being queried.
      Value idxCst = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(idx));
      Value dimSize = rewriter.create<AtenSizeIntOp>(loc, tensor, idxCst);
      dimSizes.push_back(dimSize);
    }

    auto intListType =
        Torch::ListType::get(Torch::IntType::get(op.getContext()));
    Value sizeList =
        rewriter.create<PrimListConstructOp>(loc, intListType, dimSizes);
    rewriter.replaceOp(op, sizeList);
    return success();
  }
};
} // namespace
// Decompose softmax into: exp(x) / sum(exp(x)), where the sum is taken over
// the softmax `dim` with keepDim=true.
// The pattern only applies when the `dtype` operand is None and the input
// tensor has a known floating point dtype.
namespace {
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
                                PatternRewriter &rewriter) const override {
    const bool dtypeIsNone = op.dtype().getType().isa<Torch::NoneType>();
    if (!dtypeIsNone)
      return rewriter.notifyMatchFailure(
          op, "Unimplemented non-None dtype for softmax");

    Value input = op.self();
    auto inputType = input.getType().cast<BaseTensorType>();
    const bool isFloatTensor =
        inputType.hasDtype() && inputType.getDtype().isa<mlir::FloatType>();
    if (!isFloatTensor)
      return rewriter.notifyMatchFailure(op, "Only support floating type");

    Location loc = op.getLoc();
    // Numerator: exp(x).
    Value numerator = rewriter.create<AtenExpOp>(loc, inputType, input);
    // Denominator: sum(exp(x)) along `dim`; null means the helper already
    // reported a match failure.
    Value denominator = createAtenSum(rewriter, loc, op, numerator, op.dim(),
                                      /*keepDim=*/true);
    if (!denominator)
      return failure();
    // Quotient: exp(x) / sum(exp(x)).
    Value quotient = rewriter.create<AtenDivTensorOp>(loc, inputType,
                                                      numerator, denominator);
    // Cast back to the original result type of the softmax op.
    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
                                                        quotient);
    return success();
  }
};
} // namespace
// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
//    newGrad = gradOutput * output
//    result = newGrad - output * sum(newGrad, dim)
//
// The sum keeps the reduced dimension and is broadcast back to the size of
// `output` before the multiplication. Only tensors with a known floating point
// dtype are handled.
//
// Refer to
// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
namespace {
class DecomposeAten_SoftmaxBackwardDataOp
    : public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
                                PatternRewriter &rewriter) const override {
    Value grad = op.grad_output();
    auto gradType = grad.getType().cast<BaseTensorType>();
    const bool isFloatTensor =
        gradType.hasDtype() && gradType.getDtype().isa<mlir::FloatType>();
    if (!isFloatTensor)
      return rewriter.notifyMatchFailure(op, "Only support floating type");

    Location loc = op.getLoc();
    Value softmaxOut = op.output();

    // newGrad = gradOutput * output
    Value scaledGrad =
        rewriter.create<AtenMulTensorOp>(loc, gradType, grad, softmaxOut);

    // sum(newGrad, dim); a null value means the helper already reported the
    // match failure.
    Value reduced = createAtenSum(rewriter, loc, op, scaledGrad, op.dim(),
                                  /*keepDim=*/true);
    if (!reduced)
      return failure();

    // Broadcast the reduction back to the size of `output`.
    auto intListType =
        Torch::ListType::get(Torch::IntType::get(op.getContext()));
    Value targetSize =
        rewriter.create<AtenSizeOp>(loc, intListType, softmaxOut);
    Value reducedBroadcast = rewriter.create<AtenBroadcastToOp>(
        loc, gradType, reduced, targetSize);

    // output * sum(newGrad, dim)
    Value correction = rewriter.create<AtenMulTensorOp>(
        loc, gradType, softmaxOut, reducedBroadcast);

    // newGrad - correction, using alpha = 1 for the subtraction.
    Value one =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
    Value difference = rewriter.create<AtenSubTensorOp>(
        loc, gradType, scaledGrad, correction, one);
    rewriter.replaceOp(op, difference);
    return success();
  }
};
} // namespace
// Decompose aten.log_softmax op into: log(softmax(x))
// The pattern only applies when the `dtype` operand is None and the input
// tensor has a known floating point dtype.
namespace {
class DecomposeAtenLogSoftmaxIntOp
    : public OpRewritePattern<AtenLogSoftmaxIntOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
                                PatternRewriter &rewriter) const override {
    Value dtypeOperand = op.dtype();
    if (!dtypeOperand.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented non-None dtype for log_softmax");

    Value input = op.self();
    auto inputType = input.getType().cast<BaseTensorType>();
    const bool isFloatTensor =
        inputType.hasDtype() && inputType.getDtype().isa<mlir::FloatType>();
    if (!isFloatTensor)
      return rewriter.notifyMatchFailure(op, "Only support floating type");

    // Inner op: softmax(x, dim), forwarding the (None) dtype operand.
    Value inner = rewriter.create<AtenSoftmaxIntOp>(
        op.getLoc(), inputType, input, op.dim(), dtypeOperand);
    // Outer op: log(softmax(x, dim)) replaces the original op.
    rewriter.replaceOpWithNewOp<AtenLogOp>(op, op.getType(), inner);
    return success();
  }
};
} // namespace
// Decompose torch.matmul into: torch.mm and torch.bmm according to ranks.
//   - both operands of rank 2 -> `aten.mm`
//   - both operands of rank 3 -> `aten.bmm`
// Any other rank combination (including operands of unknown rank) is left
// untouched and reported as a match failure, so the pattern never claims
// success without having rewritten the op.
namespace {
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMatmulOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.self();
    Value rhs = op.other();
    // getTensorRank yields -1 when the rank is unknown, which matches neither
    // case below.
    int lhsRank = getTensorRank(lhs);
    int rhsRank = getTensorRank(rhs);

    // If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
    if (lhsRank == 2 && rhsRank == 2) {
      rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
      return success();
    }

    // If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
    if (lhsRank == 3 && rhsRank == 3) {
      rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
      return success();
    }

    // Nothing was rewritten: a pattern must not return success in that case.
    return rewriter.notifyMatchFailure(
        op, "unimplemented: matmul is only decomposed when both operands "
            "have rank 2 or both have rank 3");
  }
};
} // namespace
// Decompose torch.expand into torch.broadcast_to op.
// The pattern only applies when the `implicit` operand is a constant bool
// equal to false.
namespace {
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenExpandOp op,
                                PatternRewriter &rewriter) const override {
    bool implicitValue = false;
    const bool implicitIsConstant =
        matchPattern(op.implicit(), m_TorchConstantBool(&implicitValue));
    // Reject both a non-constant `implicit` and a constant true one.
    if (!implicitIsConstant || implicitValue)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: requires implicit to be false");

    Value input = op.self();
    Value targetSize = op.size();
    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), input,
                                                   targetSize);
    return success();
  }
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect>();
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
target.addIllegalOp<AtenSoftmaxIntOp>();
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
target.addIllegalOp<AtenLogSoftmaxIntOp>();
patterns.add<DecomposeAtenExpandOp>(context);
target.addIllegalOp<AtenExpandOp>();
patterns.add<DecomposeAtenSizeOp>(context);
target.addIllegalOp<AtenSizeOp>();
patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
patterns.add<DecomposeAtenMatmulOp>(context);
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
Add type promotion code to refine types. The types have different levels of categories: where complex > floating > integral > boolean (> means left hand side has higher category). The operands have different levels of priorities where: dimensioned tensor > 0-dim tensor > scalar == wrapped 0-dim tensor. This is represented by the `ResultTypeState.dimResult`, `ResultTypeState.zeroResult` and `ResultTypeState..wrappedResult` in the source code. For operands of the same priorities, the result type should be the highest categories with sufficient width to hold all operands. By default, only the highest priority operands participate in the type promotion logic. Lower priority operands participate if they are in a higher category than any higher priority operands. For example, <[],f32> (lower priority) and <[1], si64> tensor would result in <[?],f32> tensor because floating > integeral. Another example <[],f64> (lower priority) and <[1], f32> tensor would result in <[?], f32> tensor because f32 and f64 are the same category. The ScalarType enum definition, type promotion table, ResultTypeState struct definition and some helpers are copied from aten/src/ATen/native/TypeProperties.* Other references: - https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc - https://github.com/pytorch/pytorch/issues/9515 Other minor changes: 1. Fix `visitExpandLikeOp` to consider cases where the given sizes list size is larger than the input rank. 2. Add back the somehow deleted `torch.aten.softmax.int` tests in decompose-complex-ops.mlir.
2021-10-21 03:31:28 +08:00
int lhsRank = getTensorRank(op.self());
int rhsRank = getTensorRank(op.other());
// Make aten.matmul legal if the following condition is satisfied.
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
// Factory for the pass: returns a newly allocated DecomposeComplexOpsPass
// owned by the caller.
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::Torch::createDecomposeComplexOpsPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<DecomposeComplexOpsPass>();
  return pass;
}