From fec2ee007279af041d38a94f3843f483fe9156f8 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Mon, 11 May 2020 13:12:16 -0700 Subject: [PATCH] Avoid introducing DimOp's in LowerBroadcastToToLoops. This makes sure we stay resonably canonically using the shape machinery. (In fact, DimOp should probably be in the shape dialect since it hides a `shape.shape_of` call) --- lib/E2E/E2E.cpp | 2 -- lib/E2E/LowerToHybridTensorMemRef.cpp | 22 +++++++++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/lib/E2E/E2E.cpp b/lib/E2E/E2E.cpp index 441c1da74..c14c61574 100644 --- a/lib/E2E/E2E.cpp +++ b/lib/E2E/E2E.cpp @@ -124,8 +124,6 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) { // Convert tcp ops to Linalg where possible. pm.addPass(createConvertTCPToLinalgPass()); - // TODO: legalize `dim` to shape.shape_of + tcp.get_extent - // -------------------------------------------------------------------------- // Tensor to buffer (memref) conversion. // -------------------------------------------------------------------------- diff --git a/lib/E2E/LowerToHybridTensorMemRef.cpp b/lib/E2E/LowerToHybridTensorMemRef.cpp index 0ecbac736..0c4ec96ae 100644 --- a/lib/E2E/LowerToHybridTensorMemRef.cpp +++ b/lib/E2E/LowerToHybridTensorMemRef.cpp @@ -115,6 +115,21 @@ public: } }; +// TODO: This should be layered in better somewhere. +// We currently only create DimOp's during LowerBroadcastToToLoopsPattern, +// so for now just stuff it in here. +class LowerDimOpToShape : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DimOp op, + PatternRewriter &rewriter) const override { + auto shape = + rewriter.create(op.getLoc(), op.getOperand()); + rewriter.replaceOpWithNewOp(op, shape, op.index()); + return success(); + } +}; + namespace { class LowerBroadcastToToLoops : public LowerBroadcastToToLoopsBase { @@ -126,12 +141,13 @@ class LowerBroadcastToToLoops target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); - target.addIllegalOp(); + OwningRewritePatternList patterns; - - target.addIllegalOp(); patterns.insert(context); + target.addIllegalOp(); + patterns.insert(context); + if (failed(applyPartialConversion(func, target, patterns))) { return signalPassFailure(); }