mirror of https://github.com/llvm/torch-mlir
Add simplification pattern to tm_tensor to linalg conversion
Attempts to compute whether or not each dimension in a tm_tensor.npbroadcast op is either broadcasted or not. If all dims are not broadcasted and the input and output ranks match, then the op is folded away, else it generates a linalg style broadcast.numpy_style_broadcast
parent
a7f506adc4
commit
7c2e5031b9
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||
|
@ -26,7 +27,76 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch::TMTensor;
|
||||
|
||||
/// Pattern rewriter hook to lower a `ScalarLoopOpInterface` to loops.
|
||||
namespace {
|
||||
class SimplifyNumpyBroadcast : public OpRewritePattern<NumpyBroadcastOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(NumpyBroadcastOp broadcastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = broadcastOp.getLoc();
|
||||
Value input = broadcastOp.getInput();
|
||||
Value output = broadcastOp.getOutput();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto outputType = output.getType().cast<RankedTensorType>();
|
||||
int64_t inputRank = inputType.getRank();
|
||||
int64_t outputRank = outputType.getRank();
|
||||
int64_t diff = outputRank - inputRank;
|
||||
|
||||
Value oneIndex =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
|
||||
SmallVector<bool> broadcastedStatus;
|
||||
for (int64_t i = 0, e = inputRank; i < e; ++i) {
|
||||
FailureOr<bool> dimsEqual =
|
||||
ValueBoundsConstraintSet::areEqual(input, output, i, i + diff);
|
||||
if (succeeded(dimsEqual) && *dimsEqual) {
|
||||
broadcastedStatus.push_back(false);
|
||||
continue;
|
||||
}
|
||||
FailureOr<bool> isUnit =
|
||||
ValueBoundsConstraintSet::areEqual(input, oneIndex, i, std::nullopt);
|
||||
if (succeeded(isUnit) || *isUnit) {
|
||||
broadcastedStatus.push_back(true);
|
||||
continue;
|
||||
}
|
||||
// Unable to statically bound all input dims to a broadcast status; bail.
|
||||
return failure();
|
||||
}
|
||||
|
||||
// If no dims are broadcasted and the rank doesn't change, we can just fold
|
||||
// the op away entirely.
|
||||
if (!llvm::any_of(broadcastedStatus, [](bool b) { return b; }) &&
|
||||
inputRank == outputRank) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
broadcastOp, broadcastOp.getResult(0).getType(), input);
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<AffineExpr> inputExprs;
|
||||
for (int64_t i = 0, e = inputRank; i < e; ++i) {
|
||||
if (broadcastedStatus[i]) {
|
||||
inputExprs.push_back(rewriter.getAffineConstantExpr(0));
|
||||
continue;
|
||||
}
|
||||
inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
|
||||
}
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
AffineMap::get(outputRank, 0, inputExprs, broadcastOp.getContext()),
|
||||
rewriter.getMultiDimIdentityMap(outputRank)};
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
outputRank, utils::IteratorType::parallel);
|
||||
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
|
||||
broadcastOp, output.getType(), input, output, indexingMaps,
|
||||
iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Pattern rewriter hook to lower a `tm_tensor.npbroadcast` to linalg.
|
||||
namespace {
|
||||
class LowerNumpyBroadcastToLinalg : public OpRewritePattern<NumpyBroadcastOp> {
|
||||
public:
|
||||
|
@ -106,11 +176,22 @@ struct TMTensorBroadcastToLinalgPass
|
|||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.insert<LowerNumpyBroadcastToLinalg>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
{
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.insert<SimplifyNumpyBroadcast>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.insert<LowerNumpyBroadcastToLinalg>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue