//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/Conversion/TCFToStd/TCFToStd.h" #include "../PassDetail.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Traits.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "npcomp/Dialect/TCF/IR/TCFOps.h" #include "npcomp/Dialect/TCP/IR/TCPDialect.h" #include "npcomp/Dialect/TCP/IR/TCPOps.h" using namespace mlir; using namespace mlir::NPCOMP; static RankedTensorType getExtentTensorType(Builder &builder) { return RankedTensorType::get({ShapedType::kDynamicSize}, builder.getIndexType()); } // Non-templated version of the body of ConvertBinaryElementwise to keep things // simple. static LogicalResult matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) { Value lhs = op->getOperand(0); Value rhs = op->getOperand(1); Location loc = op->getLoc(); Value result = op->getResult(0); auto lhsType = lhs.getType().dyn_cast(); auto rhsType = rhs.getType().dyn_cast(); if (!lhsType || !rhsType) return rewriter.notifyMatchFailure(op, "requires ranked tensors"); Value lhsShape = rewriter.create(loc, lhs); Value rhsShape = rewriter.create(loc, rhs); // Create the constraints, and the assuming region. Value witness = rewriter.create(loc, lhsShape, rhsShape); auto assuming = rewriter.create( loc, ArrayRef{result.getType()}, witness); // Start building the region body. rewriter.createBlock(&assuming.doRegion()); Value broadcastedShape = rewriter.create( loc, getExtentTensorType(rewriter), lhsShape, rhsShape, /*error=*/nullptr); // TODO: It's annoying to do the dynamic broadcast above then // do the static transfer function here. Would be nice if they could // somehow be unified. SmallVector broadcastedStaticShape; OpTrait::util::getBroadcastedShape(lhsType.getShape(), rhsType.getShape(), broadcastedStaticShape); auto resultType = RankedTensorType::get(broadcastedStaticShape, lhsType.getElementType()); Value lhsBroadcasted = rewriter.create( loc, resultType, lhs, broadcastedShape); Value rhsBroadcasted = rewriter.create( loc, resultType, rhs, broadcastedShape); Value binaryOpResult; if (isa(op)) { binaryOpResult = rewriter.create(loc, result.getType(), lhsBroadcasted, rhsBroadcasted); } else if (isa(op)) { // XXX: remove TCP dep // XXX: remove TCP ops from TCP auto pred = rewriter.create(loc, CmpFPredicate::OGT, lhsBroadcasted, rhsBroadcasted); binaryOpResult = rewriter.create(loc, pred, lhsBroadcasted, rhsBroadcasted); } else if (isa(op)) { binaryOpResult = rewriter.create(loc, result.getType(), lhsBroadcasted, rhsBroadcasted); } else { op->dump(); llvm::report_fatal_error( "unhandled op (see dump above): TCF->Std binary elementwise"); } rewriter.create(loc, binaryOpResult); // Finally, replace with the results of the shape.assuming rewriter.replaceOp(op, assuming.getResults()); return success(); } namespace { template class ConvertBinaryElementwise : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const override { return matchAndRewriteBinaryElementwise(op, rewriter); } }; } // namespace static LogicalResult matchAndRewriteUnaryElementwise(Operation *op, PatternRewriter &rewriter) { if (isa(op)) { rewriter.replaceOpWithNewOp(op, op->getOperand(0)); } else if (isa(op)) { rewriter.replaceOpWithNewOp(op, op->getOperand(0)); } else { op->dump(); llvm::report_fatal_error( "unhandled op (see dump above): TCF->TCP unary elementwise"); } return success(); } namespace { template class ConvertUnaryElementwise : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const override { return matchAndRewriteUnaryElementwise(op, rewriter); } }; } // namespace namespace { class ConvertTCFToStd : public ConvertTCFToStdBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns()); } FrozenRewritePatternList getPatterns() { MLIRContext *context = &getContext(); OwningRewritePatternList patterns; patterns.insert, ConvertUnaryElementwise>(context); patterns.insert, ConvertBinaryElementwise, ConvertBinaryElementwise>(context); return std::move(patterns); } }; } // namespace std::unique_ptr> mlir::NPCOMP::createConvertTCFToStdPass() { return std::make_unique(); }