Add simplification pattern to tm_tensor to linalg conversion

Attempts to compute whether each dimension in a tm_tensor.npbroadcast
op is broadcasted. If no dims are broadcasted and the input and output
ranks match, the op is folded away; otherwise a linalg-style
(numpy_style_broadcast) broadcast is generated.
Quinn Dawkins 2023-08-26 21:03:16 -04:00
parent a7f506adc4
commit 7c2e5031b9
1 changed file with 87 additions and 6 deletions

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
@ -26,7 +27,76 @@
using namespace mlir;
using namespace mlir::torch::TMTensor;
/// Pattern rewriter hook to simplify or fold away a `tm_tensor.npbroadcast`.
namespace {
class SimplifyNumpyBroadcast : public OpRewritePattern<NumpyBroadcastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Attempts to statically classify each input dimension of the broadcast as
  /// either broadcasted (provably of size 1) or not broadcasted (provably
  /// equal to the corresponding output dim). If every dim is not broadcasted
  /// and the input/output ranks match, the op folds to a `tensor.cast`;
  /// otherwise it is rewritten as a `linalg.generic` broadcast. Fails to match
  /// if any dim cannot be classified.
  LogicalResult matchAndRewrite(NumpyBroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    Location loc = broadcastOp.getLoc();
    Value input = broadcastOp.getInput();
    Value output = broadcastOp.getOutput();
    auto inputType = input.getType().cast<RankedTensorType>();
    auto outputType = output.getType().cast<RankedTensorType>();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = outputType.getRank();
    // Numpy-style broadcasting aligns the input dims with the *trailing*
    // output dims; `diff` is the offset between the two dim numberings.
    int64_t diff = outputRank - inputRank;
    Value oneIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
    SmallVector<bool> broadcastedStatus;
    for (int64_t i = 0, e = inputRank; i < e; ++i) {
      // Not broadcasted if input dim i provably equals output dim i + diff.
      FailureOr<bool> dimsEqual =
          ValueBoundsConstraintSet::areEqual(input, output, i, i + diff);
      if (succeeded(dimsEqual) && *dimsEqual) {
        broadcastedStatus.push_back(false);
        continue;
      }
      // Broadcasted if input dim i provably equals 1. This must be a
      // conjunction: with `||`, a failed analysis would dereference an
      // invalid FailureOr (UB) and misclassify the dim as broadcasted.
      FailureOr<bool> isUnit =
          ValueBoundsConstraintSet::areEqual(input, oneIndex, i, std::nullopt);
      if (succeeded(isUnit) && *isUnit) {
        broadcastedStatus.push_back(true);
        continue;
      }
      // Unable to statically bound all input dims to a broadcast status; bail.
      return failure();
    }

    // If no dims are broadcasted and the rank doesn't change, we can just fold
    // the op away entirely; the cast reconciles any static/dynamic shape
    // differences between the input and the result type.
    if (llvm::none_of(broadcastedStatus, [](bool b) { return b; }) &&
        inputRank == outputRank) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          broadcastOp, broadcastOp.getResult(0).getType(), input);
      return success();
    }

    // Build the input indexing map: broadcasted dims always read index 0;
    // the rest map identity-style onto the aligned output dim.
    SmallVector<AffineExpr> inputExprs;
    for (int64_t i = 0, e = inputRank; i < e; ++i) {
      if (broadcastedStatus[i]) {
        inputExprs.push_back(rewriter.getAffineConstantExpr(0));
        continue;
      }
      inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
    }
    SmallVector<AffineMap> indexingMaps = {
        AffineMap::get(outputRank, 0, inputExprs, broadcastOp.getContext()),
        rewriter.getMultiDimIdentityMap(outputRank)};
    SmallVector<utils::IteratorType> iteratorTypes(
        outputRank, utils::IteratorType::parallel);
    // The generic simply forwards the (possibly 0-indexed) input element.
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        broadcastOp, output.getType(), input, output, indexingMaps,
        iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
          b.create<linalg::YieldOp>(loc, args[0]);
        });
    return success();
  }
};
} // namespace
/// Pattern rewriter hook to lower a `tm_tensor.npbroadcast` to linalg.
namespace {
class LowerNumpyBroadcastToLinalg : public OpRewritePattern<NumpyBroadcastOp> {
public:
@ -106,6 +176,16 @@ struct TMTensorBroadcastToLinalgPass
void runOnOperation() override {
MLIRContext *context = &getContext();
{
RewritePatternSet patterns(context);
patterns.insert<SimplifyNumpyBroadcast>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
{
RewritePatternSet patterns(context);
patterns.insert<LowerNumpyBroadcastToLinalg>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
@ -113,6 +193,7 @@ struct TMTensorBroadcastToLinalgPass
return signalPassFailure();
}
}
}
};
} // namespace