2020-09-17 08:31:40 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
2020-10-07 07:14:37 +08:00
|
|
|
#include "npcomp/RefBackend/RefBackend.h"
|
2020-09-17 08:31:40 +08:00
|
|
|
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Pass/PassRegistry.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2020-10-08 08:30:10 +08:00
|
|
|
#include "npcomp/Dialect/Refback/IR/RefbackDialect.h"
|
|
|
|
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
|
2020-09-17 08:31:40 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::NPCOMP;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// Conversion pattern that replaces an `ExtractElementOp` with a `LoadOp`.
///
/// The load is built from the operands supplied by the conversion driver
/// (viewed through the op's adaptor), not from the operands still attached to
/// the original op. NOTE(review): this presumably relies on the aggregate
/// operand having been converted to a memref by the type converter -- confirm
/// against the pass that registers this pattern.
class LowerExtractElementOp : public OpConversionPattern<ExtractElementOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Always succeeds: `op` is replaced by a load of `aggregate[indices...]`.
  LogicalResult
  matchAndRewrite(ExtractElementOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Give the flat operand list named accessors.
    ExtractElementOp::Adaptor remapped(operands);
    Value source = remapped.aggregate();
    ValueRange position = remapped.indices();

    rewriter.replaceOpWithNewOp<LoadOp>(op, source, position);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class LowerTensorFromElementsOp
|
|
|
|
: public OpConversionPattern<TensorFromElementsOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(TensorFromElementsOp op, ArrayRef<Value> operands,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
int numberOfElements = op.elements().size();
|
|
|
|
auto resultType = MemRefType::get(
|
|
|
|
{numberOfElements}, op.getType().cast<TensorType>().getElementType());
|
|
|
|
Value result = rewriter.create<AllocOp>(op.getLoc(), resultType);
|
|
|
|
for (auto element : llvm::enumerate(op.elements())) {
|
|
|
|
Value index =
|
|
|
|
rewriter.create<ConstantIndexOp>(op.getLoc(), element.index());
|
|
|
|
rewriter.create<StoreOp>(op.getLoc(), element.value(), result, index);
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, {result});
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// Conversion pattern that replaces a `TensorCastOp` with a `MemRefCastOp`.
///
/// The new cast's result type is whatever the pattern's type converter yields
/// for the original result type; its input is the first operand supplied by
/// the conversion driver.
class LowerTensorCastOp : public OpConversionPattern<TensorCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Always succeeds: `op` is replaced by the memref cast.
  LogicalResult
  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Value source = operands.front();
    auto convertedType = typeConverter->convertType(op.getType());
    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, convertedType, source);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// Conversion pattern that removes a `TensorLoadOp` by forwarding its operand.
///
/// Every use of the op's result is redirected to the first operand supplied
/// by the conversion driver; no new operation is created.
class LowerTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Always succeeds: `op` is replaced by its (remapped) operand.
  LogicalResult
  matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Value forwarded = operands.front();
    rewriter.replaceOp(op, forwarded);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
// TODO: Upstream this.
|
|
|
|
/// Pass that applies a partial dialect conversion rewriting the tensor-based
/// standard ops `ExtractElementOp`, `TensorFromElementsOp`, `TensorCastOp` and
/// `TensorLoadOp` into memref-based equivalents.
///
/// Ranked tensor types are converted to memrefs of the same shape and element
/// type. Where converted and unconverted code meet, the type converter
/// inserts `refback.memref_to_tensor` / `refback.tensor_to_memref` style
/// casts (`refback::MemrefToTensorOp` / `refback::TensorToMemrefOp`).
class LowerStdToMemref : public LowerStdToMemrefBase<LowerStdToMemref> {
  /// Registers the refback dialect, whose cast ops are created by the
  /// materialization callbacks in runOnOperation().
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<refback::RefbackDialect>();
  }

  /// Builds the type converter, conversion target and pattern list, then runs
  /// the partial conversion on the current operation. Signals pass failure if
  /// the conversion fails.
  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    TypeConverter typeConverter;
    // Identity conversion for every type; the ranked-tensor conversion
    // registered next is the only one that changes a type.
    typeConverter.addConversion([](Type type) { return type; });
    typeConverter.addConversion([](RankedTensorType type) -> Type {
      return MemRefType::get(type.getShape(), type.getElementType());
    });
    // memref -> tensor: wrap a single memref value in a refback cast so it
    // can be used where a ranked tensor is expected.
    typeConverter.addSourceMaterialization([](OpBuilder &builder,
                                              RankedTensorType type,
                                              ValueRange inputs, Location loc) {
      assert(inputs.size() == 1);
      assert(inputs[0].getType().isa<MemRefType>());
      return static_cast<Value>(
          builder.create<refback::MemrefToTensorOp>(loc, type, inputs[0]));
    });
    // tensor -> memref: wrap a single ranked tensor value in a refback cast
    // so it can be used where a memref is expected.
    typeConverter.addTargetMaterialization([](OpBuilder &builder,
                                              MemRefType type,
                                              ValueRange inputs, Location loc) {
      assert(inputs.size() == 1);
      assert(inputs[0].getType().isa<RankedTensorType>());
      return static_cast<Value>(
          builder.create<refback::TensorToMemrefOp>(loc, type, inputs[0]));
    });

    OwningRewritePatternList patterns;

    ConversionTarget target(*context);

    target.addLegalDialect<StandardOpsDialect>();

    // The casting ops are introduced by the type converter, so they must be
    // legal.
    target.addLegalOp<refback::MemrefToTensorOp>();
    target.addLegalOp<refback::TensorToMemrefOp>();

    // Each pattern is paired with marking its source op illegal, overriding
    // the blanket legality of the standard dialect declared above.
    patterns.insert<LowerExtractElementOp>(typeConverter, context);
    target.addIllegalOp<ExtractElementOp>();
    patterns.insert<LowerTensorFromElementsOp>(typeConverter, context);
    target.addIllegalOp<TensorFromElementsOp>();
    patterns.insert<LowerTensorCastOp>(typeConverter, context);
    target.addIllegalOp<TensorCastOp>();
    patterns.insert<LowerTensorLoadOp>(typeConverter, context);
    target.addIllegalOp<TensorLoadOp>();

    if (failed(applyPartialConversion(func, target, patterns)))
      return signalPassFailure();
  }
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Creates a new instance of the `LowerStdToMemref` pass and hands ownership
/// to the caller.
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerStdToMemrefPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LowerStdToMemref>();
  return pass;
}
|