//===- Bufferize.cpp - Bufferization for TCP dialect -------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Refback/IR/RefbackDialect.h"
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
#include "npcomp/Dialect/TCP/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::NPCOMP;

// TODO: Don't just open-code all shape transfer functions here.
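// For ops that support it, return SSA values for the shapes of the op's
// results, taken from the op's operands rather than from the results
// themselves. This is what lets us allocate the result buffers before the op
// itself has been lowered.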
static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
  OpBuilder builder(&op);

  if (auto broadcastTo = dyn_cast<tcp::BroadcastToOp>(op)) {
    return {broadcastTo.shape()};
  }

  if (auto splatted = dyn_cast<tcp::SplattedOp>(op)) {
    return {splatted.shape()};
  }

  // No shape transfer function.
  return {};
}

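// Allocate a buffer (refback.alloc_memref) for each result of `op`, using the
// shapes provided by bypassResultShapes. Optionally report those shapes back
// to the caller via `resultShapesOut`.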
static FailureOr<SmallVector<Value, 6>>
allocateResults(Operation *op, ConversionPatternRewriter &rewriter,
                Location loc,
                SmallVectorImpl<Value> *resultShapesOut = nullptr) {
  auto resultShapes = bypassResultShapes(*op);
  SmallVector<Value, 6> results;
  for (auto t : llvm::zip(op->getResults(), resultShapes)) {
    auto result = std::get<0>(t);
    auto resultShape = std::get<1>(t);
    auto tensorType = result.getType().cast<RankedTensorType>();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    auto memref =
        rewriter.create<refback::AllocMemRefOp>(loc, memrefType, resultShape);
    results.push_back(memref);
  }
  if (resultShapesOut)
    resultShapesOut->append(resultShapes.begin(), resultShapes.end());
  return results;
}

namespace {
// TODO: Lower to a "buffer version" of tcp::BroadcastTo instead of directly to
// loops.
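// Bufferizes tcp.broadcast_to by allocating the result buffer and emitting a
// perfectly nested scf.for loop nest that copies from the input memref,
// clamping the index to 0 along any dimension that is being broadcast.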
class LowerBroadcastToToLoopsPattern
    : public OpConversionPattern<tcp::BroadcastToOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tcp::BroadcastToOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = op.getType().cast<RankedTensorType>();
    auto inputType = op.operand().getType().cast<RankedTensorType>();
    SmallVector<Value, 6> resultShapes;
    auto resultsOrFailure =
        allocateResults(op, rewriter, op.getLoc(), &resultShapes);
    if (failed(resultsOrFailure))
      return failure();
    Value resultMemref = (*resultsOrFailure)[0];
    auto resultShape = resultShapes[0];
    Value inputMemref = operands[0];
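
    // Gather the output extents from the shape operand, then compare each
    // input extent against the corresponding output extent (aligning trailing
    // dimensions): a mismatch means that dimension is being broadcast.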
    SmallVector<Value, 6> outputExtents;
    for (int i = 0, e = resultType.getRank(); i < e; i++) {
      Value dimIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
      Value outputExtent = rewriter.create<shape::GetExtentOp>(
          op.getLoc(), rewriter.getIndexType(), resultShape, dimIndex);
      outputExtents.push_back(outputExtent);
    }
    int rankDiff = resultType.getRank() - inputType.getRank();
    SmallVector<Value, 6> inputDimRequiresBroadcasting;
    for (int i = 0, e = inputType.getRank(); i < e; i++) {
      // Calculate the relevant extents.
      Value inputExtent = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
      inputDimRequiresBroadcasting.push_back(
          rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, inputExtent,
                                  outputExtents[rankDiff + i]));
    }

    {
      OpBuilder::InsertionGuard guard(rewriter);
      Value c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
      Value c1 = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);

      SmallVector<Value, 6> inductionVariables;
      // Create the (perfectly nested) loops.
      // Loop invariant: At the start of iteration `i`, the rewriter insertion
      // point is inside `i` nested loops.
      for (int i = 0, e = resultType.getRank(); i < e; i++) {
        auto loop = rewriter.create<scf::ForOp>(
            op.getLoc(), c0, outputExtents[i], c1, ValueRange({}));
        Block *body = loop.getBody();
        inductionVariables.push_back(body->getArgument(0));
        // Leave the insertion point at the beginning of the body.
        rewriter.setInsertionPointToStart(body);
      }

      // Create the inner loop body.
      // When reading from the input, clamp any indices for dimensions that are
      // being broadcast.
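      // For example, broadcasting a memref<1x4xf32> input to a 3x4 result
      // reads input[0][j] for every output element [i][j], since dimension 0
      // of the input has extent 1 and therefore requires broadcasting.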
      SmallVector<Value, 6> inputIndices;
      for (int i = 0, e = inputType.getRank(); i < e; i++) {
        auto c0 = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
        auto select = rewriter.create<SelectOp>(
            op.getLoc(), inputDimRequiresBroadcasting[i], c0,
            inductionVariables[rankDiff + i]);
        inputIndices.push_back(select);
      }
      Value load =
          rewriter.create<LoadOp>(op.getLoc(), inputMemref, inputIndices);
      rewriter.create<StoreOp>(op.getLoc(), load, resultMemref,
                               inductionVariables);
    }
    rewriter.replaceOp(op, resultMemref);
    return success();
  }
};
} // namespace

namespace {
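// Bufferizes tcp.splatted by allocating the result buffer and filling it with
// the splat value via linalg.fill.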
class BufferizeSplattedOp : public OpConversionPattern<tcp::SplattedOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tcp::SplattedOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc());
    if (failed(resultsOrFailure))
      return failure();
    auto results = *resultsOrFailure;
    rewriter.create<linalg::FillOp>(op.getLoc(), results[0], op.splatVal());
    rewriter.replaceOp(op, results);
    return success();
  }
};
} // namespace

namespace {
class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
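  // Declare the dialects whose ops this pass may create, so that they are
  // loaded before the pass runs.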
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<refback::RefbackDialect>();
    registry.insert<linalg::LinalgDialect>();
    registry.insert<scf::SCFDialect>();
    registry.insert<shape::ShapeDialect>();
  }

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;

    OwningRewritePatternList patterns;

    ConversionTarget target(*context);

    // All lowering to buffers involves refback.alloc_memref ops.
    // TODO: This makes the tests cleaner, but otherwise isn't too essential as
    // we can just open-code the extents for the alloc.
    target.addLegalOp<refback::AllocMemRefOp>();

    patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
    target.addIllegalOp<tcp::BroadcastToOp>();
    patterns.insert<BufferizeSplattedOp>(typeConverter, context);
    target.addIllegalOp<tcp::SplattedOp>();
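
    // The lowering patterns above create ops from these dialects, so mark
    // them legal for the conversion.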
    target.addLegalDialect<linalg::LinalgDialect>();
    target.addLegalDialect<StandardOpsDialect>();
    target.addLegalDialect<scf::SCFDialect>();
    target.addLegalOp<shape::GetExtentOp>();

    if (failed(applyPartialConversion(func, target, std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::NPCOMP::createTCPBufferizePass() {
  return std::make_unique<TCPBufferizePass>();
}