//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/TCF/IR/TCFOps.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"

using namespace mlir;
using namespace mlir::NPCOMP;

// Shape transfer function: computes the result shapes of `op` directly from
// its operands, so the shapes are available as SSA values before the op
// itself is lowered.
static SmallVector<Value, 6> bypassResultShapes(Operation *op,
                                                OpBuilder &builder) {
  if (auto matmul = dyn_cast<tcf::MatmulOp>(op)) {
    // A matmul result is `lhsRows x rhsCols`.
    auto lhsRows = builder.create<DimOp>(op->getLoc(), matmul.lhs(), 0);
    auto rhsCols = builder.create<DimOp>(op->getLoc(), matmul.rhs(), 1);
    auto shape = builder.create<tensor::FromElementsOp>(
        op->getLoc(), ValueRange({lhsRows, rhsCols}));
    return {shape};
  }
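  // The conv branch below materializes the standard convolution output-size
  // arithmetic as an SSA graph of index ops. For reference, each spatial
  // dimension is computed as:
  //
  //   out = (in + 2 * padding - dilation * (filter - 1) - 1) / stride + 1
  //
  // with stride = dilation = 1 and padding = 0 hard-coded below.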
  // TODO: This only supports the NCHW data format. Consider other formats and
  // lower ranks.
  if (auto conv2dNCHW = dyn_cast<tcf::ConvNCHWOp>(op)) {
    // TODO: Replace hard-coded stride/dilation/padding constant-ops.
    // TODO: Consider migrating this SSA shape-computing graph to a complex op
    // or use the `mlir-linalg-ods-gen` approach and define a `*.tc` spec file.
    auto cI0 = builder.create<ConstantOp>(
        op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 0));
    auto cI1 = builder.create<ConstantOp>(
        op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 1));
    auto cI2 = builder.create<ConstantOp>(
        op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 2));
    auto stride = cI1;
    auto dilation = cI1;
    auto padding = cI0;
    auto strideHeight = stride;
    auto strideWidth = stride;
    auto dilationHeight = dilation;
    auto dilationWidth = dilation;
    auto paddingHeight = padding;
    auto paddingWidth = padding;
    auto batch = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 0);
    auto height = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 2);
    auto width = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 3);
    auto filterOutChannels =
        builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 0);
    auto filterHeight =
        builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 2);
    auto filterWidth =
        builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 3);
    // Output height.
    auto twicePaddingHeight =
        builder.create<MulIOp>(op->getLoc(), paddingHeight, cI2);
    auto heightPlusTwicePadding =
        builder.create<AddIOp>(op->getLoc(), height, twicePaddingHeight);
    auto filterHeightMinusOne =
        builder.create<SubIOp>(op->getLoc(), filterHeight, cI1);
    auto dilationFilterHeight = builder.create<MulIOp>(
        op->getLoc(), dilationHeight, filterHeightMinusOne);
    auto outHeightUnstridedPlusOne = builder.create<SubIOp>(
        op->getLoc(), heightPlusTwicePadding, dilationFilterHeight);
    auto outHeightUnstrided =
        builder.create<SubIOp>(op->getLoc(), outHeightUnstridedPlusOne, cI1);
    auto outHeightMinusOne = builder.create<UnsignedDivIOp>(
        op->getLoc(), outHeightUnstrided, strideHeight);
    auto outHeight =
        builder.create<AddIOp>(op->getLoc(), outHeightMinusOne, cI1);
    // Output width.
    auto twicePaddingWidth =
        builder.create<MulIOp>(op->getLoc(), paddingWidth, cI2);
    auto widthPlusTwicePadding =
        builder.create<AddIOp>(op->getLoc(), width, twicePaddingWidth);
    auto filterWidthMinusOne =
        builder.create<SubIOp>(op->getLoc(), filterWidth, cI1);
    auto dilationFilterWidth = builder.create<MulIOp>(
        op->getLoc(), dilationWidth, filterWidthMinusOne);
    auto outWidthUnstridedPlusOne = builder.create<SubIOp>(
        op->getLoc(), widthPlusTwicePadding, dilationFilterWidth);
    auto outWidthUnstrided =
        builder.create<SubIOp>(op->getLoc(), outWidthUnstridedPlusOne, cI1);
    auto outWidthMinusOne = builder.create<UnsignedDivIOp>(
        op->getLoc(), outWidthUnstrided, strideWidth);
    auto outWidth = builder.create<AddIOp>(op->getLoc(), outWidthMinusOne, cI1);
    // Output shape.
    auto shape = builder.create<tensor::FromElementsOp>(
        op->getLoc(),
        ValueRange({batch, filterOutChannels, outHeight, outWidth}));
    return {shape};
  }

  // No shape transfer function.
  return {};
}
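// Lowers tcf.matmul to linalg.matmul, guarding the contraction with a shape
// constraint. As a rough sketch (the exact IR depends on the dialect versions
// in use), a `tcf.matmul` on `tensor<?x?xf32>` operands becomes:
//
//   %lhsK = dim %lhs, 1
//   %rhsK = dim %rhs, 0
//   %eq = cmpi eq, %lhsK, %rhsK
//   %witness = shape.cstr_require %eq, "mismatching contracting dimension..."
//   %result = shape.assuming %witness -> tensor<?x?xf32> {
//     %init = tcp.splatted %c0_f32, %shape
//     %mm = linalg.matmul ins(%lhs, %rhs) outs(%init)
//     shape.assuming_yield %mm
//   }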
namespace {
class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tcf::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    // Create the constraints, and the assuming region.
    Value lhsK = rewriter.create<DimOp>(op.getLoc(), op.lhs(), 1);
    Value rhsK = rewriter.create<DimOp>(op.getLoc(), op.rhs(), 0);
    Value matchingK =
        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK);
    Value witness = rewriter.create<shape::CstrRequireOp>(
        op.getLoc(), matchingK,
        "mismatching contracting dimension for matmul");
    auto assuming = rewriter.create<shape::AssumingOp>(
        op.getLoc(), ArrayRef<Type>{op.getType()}, witness);

    // Build the region body.
    rewriter.createBlock(&assuming.doRegion());
    // Create the init tensor for the matmul.
    // TODO: Expand supported data types.
    Value c0 = rewriter.create<ConstantOp>(op.getLoc(),
                                           rewriter.getF32FloatAttr(0.0));
    Value shape = bypassResultShapes(op, rewriter)[0];
    Value initTensor =
        rewriter.create<tcp::SplattedOp>(op.getLoc(), op.getType(), c0, shape);

    // Create the matmul.
    auto matmul = rewriter.create<linalg::MatmulOp>(
        op.getLoc(), TypeRange(op.getType()), op.getOperands(),
        ValueRange(initTensor));
    rewriter.create<shape::AssumingYieldOp>(op.getLoc(), matmul.getResult(0));

    // Finally, replace with the results of the shape.assuming.
    rewriter.replaceOp(op, assuming.getResults());
    return success();
  }
};
} // namespace

namespace {
// Lowers the TCF convolution to a named linalg convolution, guarding it with
// constraints that the in-channel counts match and that the filter fits
// within the input's spatial extents.
class ConvertConvNCHW : public OpRewritePattern<tcf::ConvNCHWOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tcf::ConvNCHWOp op,
                                PatternRewriter &rewriter) const override {
    // Create the constraints, and the assuming region.
    Value inputCin = rewriter.create<DimOp>(op.getLoc(), op.in(), 1);
    Value inputH = rewriter.create<DimOp>(op.getLoc(), op.in(), 2);
    Value inputW = rewriter.create<DimOp>(op.getLoc(), op.in(), 3);
    Value filterCin = rewriter.create<DimOp>(op.getLoc(), op.filter(), 1);
    Value filterKH = rewriter.create<DimOp>(op.getLoc(), op.filter(), 2);
    Value filterKW = rewriter.create<DimOp>(op.getLoc(), op.filter(), 3);
    Value matchingCin = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq,
                                                inputCin, filterCin);
    Value validFilterH = rewriter.create<CmpIOp>(
        op.getLoc(), CmpIPredicate::uge, inputH, filterKH);
    Value validFilterW = rewriter.create<CmpIOp>(
        op.getLoc(), CmpIPredicate::uge, inputW, filterKW);
    Value witnessCin = rewriter.create<shape::CstrRequireOp>(
        op.getLoc(), matchingCin,
        "input and filter in-channels must be equal");
    Value witnessFilterH = rewriter.create<shape::CstrRequireOp>(
        op.getLoc(), validFilterH,
        "input height must be greater than or equal to filter KH-dimension");
    Value witnessFilterW = rewriter.create<shape::CstrRequireOp>(
        op.getLoc(), validFilterW,
        "input width must be greater than or equal to filter KW-dimension");
    Value assumingAll = rewriter.create<shape::AssumingAllOp>(
        op.getLoc(), witnessCin.getType(),
        ValueRange({witnessCin, witnessFilterH, witnessFilterW}));
    auto assuming = rewriter.create<shape::AssumingOp>(
        op.getLoc(), ArrayRef<Type>{op.getType()}, assumingAll);

    // Build the region body.
    rewriter.createBlock(&assuming.doRegion());
    // Create the init tensor for the ConvNCHW.
    // TODO: Expand supported data types.
    Value c0 = rewriter.create<ConstantOp>(op.getLoc(),
                                           rewriter.getF32FloatAttr(0.0));
    Value shape = bypassResultShapes(op, rewriter)[0];
    Value initTensor =
        rewriter.create<tcp::SplattedOp>(op.getLoc(), op.getType(), c0, shape);

    // Unit strides and dilations.
    auto strides = rewriter.getI64VectorAttr({1, 1});
    auto dilations = rewriter.getI64VectorAttr({1, 1});

    // Create the ConvNCHW.
    auto conv2dNCHW = rewriter.create<linalg::ConvNCHWOp>(
        op.getLoc(), TypeRange(op.getType()),
        ValueRange({op.in(), op.filter()}), ValueRange(initTensor), strides,
        dilations);
    rewriter.create<shape::AssumingYieldOp>(op.getLoc(),
                                            conv2dNCHW.getResults());

    // Finally, replace with the results of the shape.assuming.
    rewriter.replaceOp(op, assuming.getResults());
    return success();
  }
};
} // namespace

namespace {
class ConvertTCFToLinalg : public ConvertTCFToLinalgBase<ConvertTCFToLinalg> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect, tcp::TCPDialect,
                    tensor::TensorDialect, linalg::LinalgDialect>();
  }

  void runOnOperation() override {
    (void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
  }

  FrozenRewritePatternSet getPatterns() {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.add<ConvertMatmul>(context);
    patterns.add<ConvertConvNCHW>(context);
    return std::move(patterns);
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTCFToLinalgPass() {
  return std::make_unique<ConvertTCFToLinalg>();
}
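// A minimal sketch of exercising this pass from the command line, assuming it
// is registered under the conventional flag derived from its name:
//
//   npcomp-opt -convert-tcf-to-linalg input.mlir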