mirror of https://github.com/llvm/torch-mlir
Avoid introducing DimOp's in LowerBroadcastToToLoops.
This makes sure we stay reasonably canonical in using the shape machinery. (In fact, DimOp should probably be in the shape dialect, since it hides a `shape.shape_of` call.)
parent 1ef8b91a95
commit fec2ee0072
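A rough sketch of the lowering this commit switches to, in pseudo-IR (the op names `dim`, `shape.shape_of`, and `tcp.get_extent` come from the diff below, but their exact assembly syntax is an assumption here, not taken from the source):

  // Before: a standard `dim` op, which hides the shape computation.
  %extent = dim %arg0, 0 : tensor<?xf32>

  // After: the shape is materialized explicitly through the shape dialect and
  // the extent is read off it, keeping the shape computation visible.
  %shape = shape.shape_of %arg0 : tensor<?xf32>
  %extent = tcp.get_extent %shape, 0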
@@ -124,8 +124,6 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
   // Convert tcp ops to Linalg where possible.
   pm.addPass(createConvertTCPToLinalgPass());
 
-  // TODO: legalize `dim` to shape.shape_of + tcp.get_extent
-
   // --------------------------------------------------------------------------
   // Tensor to buffer (memref) conversion.
   // --------------------------------------------------------------------------
@@ -115,6 +115,21 @@ public:
   }
 };
 
+// TODO: This should be layered in better somewhere.
+// We currently only create DimOp's during LowerBroadcastToToLoopsPattern,
+// so for now just stuff it in here.
+class LowerDimOpToShape : public OpRewritePattern<DimOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(DimOp op,
+                                PatternRewriter &rewriter) const override {
+    auto shape =
+        rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand());
+    rewriter.replaceOpWithNewOp<tcp::GetExtentOp>(op, shape, op.index());
+    return success();
+  }
+};
+
 namespace {
 class LowerBroadcastToToLoops
     : public LowerBroadcastToToLoopsBase<LowerBroadcastToToLoops> {
@@ -126,12 +141,13 @@ class LowerBroadcastToToLoops
     target.addLegalDialect<StandardOpsDialect>();
     target.addLegalDialect<loop::LoopOpsDialect>();
     target.addLegalDialect<tcp::TCPDialect>();
-    target.addIllegalOp<tcp::BroadcastToOp>();
-
     OwningRewritePatternList patterns;
 
+    target.addIllegalOp<tcp::BroadcastToOp>();
     patterns.insert<LowerBroadcastToToLoopsPattern>(context);
+    target.addIllegalOp<DimOp>();
+    patterns.insert<LowerDimOpToShape>(context);
 
     if (failed(applyPartialConversion(func, target, patterns))) {
       return signalPassFailure();
     }