Avoid introducing DimOp's in LowerBroadcastToToLoops.

This makes sure we stay resonably canonically using the shape machinery.
(In fact, DimOp should probably be in the shape dialect since it hides a
`shape.shape_of` call)
pull/1/head
Sean Silva 2020-05-11 13:12:16 -07:00
parent 1ef8b91a95
commit fec2ee0072
2 changed files with 19 additions and 5 deletions

View File

@ -124,8 +124,6 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
// Convert tcp ops to Linalg where possible.
pm.addPass(createConvertTCPToLinalgPass());
// TODO: legalize `dim` to shape.shape_of + tcp.get_extent
// --------------------------------------------------------------------------
// Tensor to buffer (memref) conversion.
// --------------------------------------------------------------------------

View File

@ -115,6 +115,21 @@ public:
}
};
// TODO: This should be layered in better somewhere.
// We currently only create DimOp's during LowerBroadcastToToLoopsPattern,
// so for now just stuff it in here.
class LowerDimOpToShape : public OpRewritePattern<DimOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp op,
PatternRewriter &rewriter) const override {
auto shape =
rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand());
rewriter.replaceOpWithNewOp<tcp::GetExtentOp>(op, shape, op.index());
return success();
}
};
namespace {
class LowerBroadcastToToLoops
: public LowerBroadcastToToLoopsBase<LowerBroadcastToToLoops> {
@ -126,12 +141,13 @@ class LowerBroadcastToToLoops
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<loop::LoopOpsDialect>();
target.addLegalDialect<tcp::TCPDialect>();
target.addIllegalOp<tcp::BroadcastToOp>();
OwningRewritePatternList patterns;
target.addIllegalOp<tcp::BroadcastToOp>();
patterns.insert<LowerBroadcastToToLoopsPattern>(context);
target.addIllegalOp<DimOp>();
patterns.insert<LowerDimOpToShape>(context);
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}