//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/PatternMatch.h" #include "npcomp/Conversion/BasicpyToStd/Patterns.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" using namespace mlir; using namespace mlir::NPCOMP; namespace { bool isLegalBinaryOpType(Type type) { if (type.isIntOrFloat()) { return type.getIntOrFloatBitWidth() > 1; // Do not match i1 } return false; } // Convert to std ops when all types match. It is assumed that additional // patterns and type inference are used to get into this form. class NumericBinaryExpr : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Basicpy::BinaryExprOp op, PatternRewriter &rewriter) const override { // Match failure unless if both: // a) operands/results are the same type // b) matches a set of supported primitive types // c) the operation maps to a simple std op without further massaging auto valueType = op.left().getType(); if (valueType != op.right().getType() || valueType != op.result().getType()) return failure(); if (!isLegalBinaryOpType(valueType)) return failure(); auto operation = Basicpy::symbolizeBinaryOperation(op.operation()); if (!operation) return failure(); auto left = op.left(); auto right = op.right(); // Generally, int and float ops in std are different. using Basicpy::BinaryOperation; if (valueType.isa()) { // Note that not all operations make sense or are defined for integer // math. Of specific note is the Div vs FloorDiv distinction. switch (*operation) { case BinaryOperation::Add: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::BitAnd: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::BitOr: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::BitXor: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::FloorDiv: // TODO: This is not a precise match for negative division. // SignedDivIOp rounds towards zero and python rounds towards // most negative. rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::LShift: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::Mod: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::Mult: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::RShift: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::Sub: rewriter.replaceOpWithNewOp(op, left, right); return success(); default: return failure(); } } else if (valueType.isa()) { // Note that most operations are not supported on floating point values. // In addition, some cannot be directly implemented with single std // ops. switch (*operation) { case BinaryOperation::Add: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::Div: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::FloorDiv: // TODO: Implement floating point floor division. return rewriter.notifyMatchFailure( op, "floating point floor division not implemented"); case BinaryOperation::Mod: // TODO: Implement floating point mod. return rewriter.notifyMatchFailure( op, "floating point mod not implemented"); case BinaryOperation::Mult: rewriter.replaceOpWithNewOp(op, left, right); return success(); case BinaryOperation::Sub: rewriter.replaceOpWithNewOp(op, left, right); return success(); default: return failure(); } } return failure(); } }; Optional mapBasicpyPredicateToCmpI(Basicpy::CompareOperation predicate) { using Basicpy::CompareOperation; switch (predicate) { case CompareOperation::Eq: return CmpIPredicate::eq; case CompareOperation::Gt: return CmpIPredicate::sgt; case CompareOperation::GtE: return CmpIPredicate::sge; case CompareOperation::Is: return CmpIPredicate::eq; case CompareOperation::IsNot: return CmpIPredicate::ne; case CompareOperation::Lt: return CmpIPredicate::slt; case CompareOperation::LtE: return CmpIPredicate::sle; case CompareOperation::NotEq: return CmpIPredicate::ne; default: return llvm::None; } } Optional mapBasicpyPredicateToCmpF(Basicpy::CompareOperation predicate) { using Basicpy::CompareOperation; switch (predicate) { case CompareOperation::Eq: return CmpFPredicate::OEQ; case CompareOperation::Gt: return CmpFPredicate::OGT; case CompareOperation::GtE: return CmpFPredicate::OGE; case CompareOperation::Is: return CmpFPredicate::OEQ; case CompareOperation::IsNot: return CmpFPredicate::ONE; case CompareOperation::Lt: return CmpFPredicate::OLT; case CompareOperation::LtE: return CmpFPredicate::OLE; case CompareOperation::NotEq: return CmpFPredicate::ONE; default: return llvm::None; } } class NumericCompare : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Basicpy::BinaryCompareOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto valueType = op.left().getType(); if (valueType != op.right().getType()) return failure(); if (!isLegalBinaryOpType(valueType)) return failure(); auto bpyPredicate = Basicpy::symbolizeCompareOperation(op.operation()); if (!bpyPredicate) return failure(); if (valueType.isa()) { if (auto stdPredicate = mapBasicpyPredicateToCmpI(*bpyPredicate)) { auto cmp = rewriter.create(loc, *stdPredicate, op.left(), op.right()); rewriter.replaceOpWithNewOp( op, Basicpy::BoolType::get(rewriter.getContext()), cmp); return success(); } else { return rewriter.notifyMatchFailure(op, "unsupported compare operation"); } } else if (valueType.isa()) { if (auto stdPredicate = mapBasicpyPredicateToCmpF(*bpyPredicate)) { auto cmp = rewriter.create(loc, *stdPredicate, op.left(), op.right()); rewriter.replaceOpWithNewOp( op, Basicpy::BoolType::get(rewriter.getContext()), cmp); return success(); } else { return rewriter.notifyMatchFailure(op, "unsupported compare operation"); } } return failure(); } }; // Converts the as_i1 op for numeric types. class NumericToI1 : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Basicpy::AsI1Op op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto operandType = op.operand().getType(); if (operandType.isa()) { auto zero = rewriter.create(loc, 0, operandType); rewriter.replaceOpWithNewOp(op, CmpIPredicate::ne, op.operand(), zero); return success(); } else if (operandType.isa()) { auto zero = rewriter.create(loc, operandType, FloatAttr::get(operandType, 0.0)); rewriter.replaceOpWithNewOp(op, CmpFPredicate::ONE, op.operand(), zero); return success(); } return failure(); } }; } // namespace void mlir::NPCOMP::populateBasicpyToStdPrimitiveOpPatterns( RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add(context); patterns.add(context); patterns.add(context); }