mirror of https://github.com/llvm/torch-mlir
Add simplification pattern to tm_tensor to linalg conversion
Attempts to compute, for each dimension of a tm_tensor.npbroadcast op, whether that dimension is actually broadcasted. If no dims are broadcasted and the input and output ranks match, the op is folded away; otherwise it is lowered to a linalg-style broadcast.
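For illustration, a rough sketch of the lowering (not taken from the patch: the shapes, value names, and exact printed syntax are assumptions): when some input dimension really does broadcast, e.g. a tensor<1x?xf32> expanded into a tensor<8x?xf32>, the emitted linalg.generic pins that dimension to the constant index 0 in the input indexing map while the output uses the identity map; when no dimension broadcasts and the ranks match, the op instead folds to a plain tensor.cast.

    // Schematic IR sketch; shapes, value names, and attribute syntax are assumptions.
    #map_in  = affine_map<(d0, d1) -> (0, d1)>   // broadcasted dim reads index 0
    #map_out = affine_map<(d0, d1) -> (d0, d1)>  // identity map on the output
    %res = linalg.generic {indexing_maps = [#map_in, #map_out],
                           iterator_types = ["parallel", "parallel"]}
        ins(%input : tensor<1x?xf32>) outs(%init : tensor<8x?xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
    } -> tensor<8x?xf32>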
parent a7f506adc4
commit 7c2e5031b9
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"

@@ -26,7 +27,76 @@
 using namespace mlir;
 using namespace mlir::torch::TMTensor;
 
-/// Pattern rewriter hook to lower a `ScalarLoopOpInterface` to loops.
+namespace {
+class SimplifyNumpyBroadcast : public OpRewritePattern<NumpyBroadcastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(NumpyBroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = broadcastOp.getLoc();
+    Value input = broadcastOp.getInput();
+    Value output = broadcastOp.getOutput();
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto outputType = output.getType().cast<RankedTensorType>();
+    int64_t inputRank = inputType.getRank();
+    int64_t outputRank = outputType.getRank();
+    int64_t diff = outputRank - inputRank;
+
+    Value oneIndex =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+
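+    // For each input dim, record whether it is broadcasted: a dim counts as
+    // not broadcasted if it provably equals the matching output dim, and as
+    // broadcasted if it is provably of size 1.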
+    SmallVector<bool> broadcastedStatus;
+    for (int64_t i = 0, e = inputRank; i < e; ++i) {
+      FailureOr<bool> dimsEqual =
+          ValueBoundsConstraintSet::areEqual(input, output, i, i + diff);
+      if (succeeded(dimsEqual) && *dimsEqual) {
+        broadcastedStatus.push_back(false);
+        continue;
+      }
+      FailureOr<bool> isUnit =
+          ValueBoundsConstraintSet::areEqual(input, oneIndex, i, std::nullopt);
+      if (succeeded(isUnit) && *isUnit) {
+        broadcastedStatus.push_back(true);
+        continue;
+      }
+      // Unable to statically bound all input dims to a broadcast status; bail.
+      return failure();
+    }
+
+    // If no dims are broadcasted and the rank doesn't change, we can just fold
+    // the op away entirely.
+    if (!llvm::any_of(broadcastedStatus, [](bool b) { return b; }) &&
+        inputRank == outputRank) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          broadcastOp, broadcastOp.getResult(0).getType(), input);
+      return success();
+    }
+
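+    // Build the input indexing map: broadcasted dims read the constant index
+    // 0, every other input dim maps to its matching output dim (offset by
+    // `diff`); the output keeps the identity map.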
+    SmallVector<AffineExpr> inputExprs;
+    for (int64_t i = 0, e = inputRank; i < e; ++i) {
+      if (broadcastedStatus[i]) {
+        inputExprs.push_back(rewriter.getAffineConstantExpr(0));
+        continue;
+      }
+      inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
+    }
+
+    SmallVector<AffineMap> indexingMaps = {
+        AffineMap::get(outputRank, 0, inputExprs, broadcastOp.getContext()),
+        rewriter.getMultiDimIdentityMap(outputRank)};
+    SmallVector<utils::IteratorType> iteratorTypes(
+        outputRank, utils::IteratorType::parallel);
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        broadcastOp, output.getType(), input, output, indexingMaps,
+        iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+          b.create<linalg::YieldOp>(loc, args[0]);
+        });
+    return success();
+  }
+};
+} // namespace
+
+/// Pattern rewriter hook to lower a `tm_tensor.npbroadcast` to linalg.
 namespace {
 class LowerNumpyBroadcastToLinalg : public OpRewritePattern<NumpyBroadcastOp> {
 public:

@@ -106,6 +176,16 @@ struct TMTensorBroadcastToLinalgPass
   void runOnOperation() override {
     MLIRContext *context = &getContext();
+
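+    // First run the simplification pattern to completion, then lower whatever
+    // broadcasts remain with the existing lowering pattern below.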
+    {
+      RewritePatternSet patterns(context);
+      patterns.insert<SimplifyNumpyBroadcast>(context);
+      if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                              std::move(patterns)))) {
+        return signalPassFailure();
+      }
+    }
+
+    {
     RewritePatternSet patterns(context);
     patterns.insert<LowerNumpyBroadcastToLinalg>(context);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),

@@ -113,6 +193,7 @@ struct TMTensorBroadcastToLinalgPass
       return signalPassFailure();
     }
+    }
   }
 };
 } // namespace