diff --git a/include/npcomp/Dialect/TCF/IR/TCFOps.td b/include/npcomp/Dialect/TCF/IR/TCFOps.td
index 3d53028c6..4c1b00213 100644
--- a/include/npcomp/Dialect/TCF/IR/TCFOps.td
+++ b/include/npcomp/Dialect/TCF/IR/TCFOps.td
@@ -95,4 +95,25 @@ def TCF_MatmulOp : TCF_Op<"matmul"> {
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
 }
 
+def TCF_ConvNCHWOp : TCF_Op<"conv_2d_nchw"> {
+  let summary = "2-D convolution";
+  let description = [{
+    Performs 2-D convolution. This op is inspired by PyTorch's Conv2d layer
+    (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html).
+
+    The tensors have the following dimensions:
+    - in: [N, Cin, H, W]
+    - filter: [Cout, Cin, KH, KW]
+    - result: [N, Cout, Hout, Wout]
+
+    The tensors must meet the following conditions; otherwise, this op aborts the program:
+    - H is greater than or equal to KH
+    - W is greater than or equal to KW
+    - Cin matches between in and filter
+  }];
+  let arguments = (ins 4DTensorOf<[F32]>:$in, 4DTensorOf<[F32]>:$filter);
+  let results = (outs 4DTensorOf<[F32]>:$result);
+
+  let assemblyFormat = "$in `,` $filter attr-dict `:` functional-type(operands, results)";
+}
+
 #endif // #ifndef TCF_OPS
diff --git a/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp b/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp
index 2494744a3..71232264f 100644
--- a/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp
+++ b/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp
@@ -32,6 +32,51 @@ static SmallVector<Value, 6> bypassResultShapes(Operation *op,
         op->getLoc(), ValueRange({lhsRows, rhsCols}));
     return {shape};
   }
+  // TODO: This only supports the NCHW data format. Consider other formats
+  // and lower ranks.
+  if (auto conv2dNCHW = dyn_cast<tcf::ConvNCHWOp>(op)) {
+    // TODO: Replace the hard-coded stride/dilation/padding constant ops.
+    // TODO: Consider migrating this SSA shape-computing graph to a complex op,
+    // or use the `mlir-linalg-ods-gen` approach and define a `*.tc` spec file.
+    auto cI0 = builder.create<ConstantOp>(
+        op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 0));
+    auto cI1 = builder.create<ConstantOp>(
+        op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 1));
+    auto cI2 = builder.create<ConstantOp>(
+        op->getLoc(), builder.getIntegerAttr(builder.getIndexType(), 2));
+    auto stride = cI1;
+    auto dilation = cI1;
+    auto padding = cI0;
+    auto strideHeight = stride;
+    auto strideWidth = stride;
+    auto dilationHeight = dilation;
+    auto dilationWidth = dilation;
+    auto paddingHeight = padding;
+    auto paddingWidth = padding;
+    auto batch = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 0);
+    auto height = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 2);
+    auto width = builder.create<DimOp>(op->getLoc(), conv2dNCHW.in(), 3);
+    auto filterOutChannels =
+        builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 0);
+    auto filterHeight =
+        builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 2);
+    auto filterWidth =
+        builder.create<DimOp>(op->getLoc(), conv2dNCHW.filter(), 3);
+    // Output height.
+    auto twicePaddingHeight =
+        builder.create<MulIOp>(op->getLoc(), paddingHeight, cI2);
+    auto heightPlusTwicePadding =
+        builder.create<AddIOp>(op->getLoc(), height, twicePaddingHeight);
+    auto filterHeightMinusOne =
+        builder.create<SubIOp>(op->getLoc(), filterHeight, cI1);
+    auto dilationFilterHeight = builder.create<MulIOp>(
+        op->getLoc(), dilationHeight, filterHeightMinusOne);
+    auto outHeightUnstridedPlusOne = builder.create<SubIOp>(
+        op->getLoc(), heightPlusTwicePadding, dilationFilterHeight);
+    auto outHeightUnstrided =
+        builder.create<SubIOp>(op->getLoc(), outHeightUnstridedPlusOne, cI1);
+    auto outHeightMinusOne = builder.create<UnsignedDivIOp>(
+        op->getLoc(), outHeightUnstrided, strideHeight);
+    auto outHeight =
+        builder.create<AddIOp>(op->getLoc(), outHeightMinusOne, cI1);
+    // Output width.
+    auto twicePaddingWidth =
+        builder.create<MulIOp>(op->getLoc(), paddingWidth, cI2);
+    auto widthPlusTwicePadding =
+        builder.create<AddIOp>(op->getLoc(), width, twicePaddingWidth);
+    auto filterWidthMinusOne =
+        builder.create<SubIOp>(op->getLoc(), filterWidth, cI1);
+    auto dilationFilterWidth = builder.create<MulIOp>(
+        op->getLoc(), dilationWidth, filterWidthMinusOne);
+    auto outWidthUnstridedPlusOne = builder.create<SubIOp>(
+        op->getLoc(), widthPlusTwicePadding, dilationFilterWidth);
+    auto outWidthUnstrided =
+        builder.create<SubIOp>(op->getLoc(), outWidthUnstridedPlusOne, cI1);
+    auto outWidthMinusOne = builder.create<UnsignedDivIOp>(
+        op->getLoc(), outWidthUnstrided, strideWidth);
+    auto outWidth =
+        builder.create<AddIOp>(op->getLoc(), outWidthMinusOne, cI1);
+    // Output shape.
+    auto shape = builder.create<TensorFromElementsOp>(
+        op->getLoc(),
+        ValueRange({batch, filterOutChannels, outHeight, outWidth}));
+    return {shape};
+  }
 
   // No shape transfer function.
   return {};
@@ -76,6 +121,59 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertConvNCHW : public OpRewritePattern<tcf::ConvNCHWOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tcf::ConvNCHWOp op,
+                                PatternRewriter &rewriter) const override {
+    // Create the constraints, and the assuming region.
+    Value inputCin = rewriter.create<DimOp>(op.getLoc(), op.in(), 1);
+    Value inputH = rewriter.create<DimOp>(op.getLoc(), op.in(), 2);
+    Value inputW = rewriter.create<DimOp>(op.getLoc(), op.in(), 3);
+    Value filterCin = rewriter.create<DimOp>(op.getLoc(), op.filter(), 1);
+    Value filterKH = rewriter.create<DimOp>(op.getLoc(), op.filter(), 2);
+    Value filterKW = rewriter.create<DimOp>(op.getLoc(), op.filter(), 3);
+    Value matchingCin = rewriter.create<CmpIOp>(
+        op.getLoc(), CmpIPredicate::eq, inputCin, filterCin);
+    Value validFilterH = rewriter.create<CmpIOp>(
+        op.getLoc(), CmpIPredicate::uge, inputH, filterKH);
+    Value validFilterW = rewriter.create<CmpIOp>(
+        op.getLoc(), CmpIPredicate::uge, inputW, filterKW);
+    Value witnessCin = rewriter.create<shape::CstrRequireOp>(
+        op.getLoc(), matchingCin, "input and filter in-channels must be equal");
+    Value witnessFilterH = rewriter.create<shape::CstrRequireOp>(
+        op.getLoc(), validFilterH,
+        "input height must be greater than or equal to filter KH-dimension");
+    Value witnessFilterW = rewriter.create<shape::CstrRequireOp>(
+        op.getLoc(), validFilterW,
+        "input width must be greater than or equal to filter KW-dimension");
+    Value assumingAll = rewriter.create<shape::AssumingAllOp>(
+        op.getLoc(), witnessCin.getType(),
+        ValueRange({witnessCin, witnessFilterH, witnessFilterW}));
+    auto assuming = rewriter.create<shape::AssumingOp>(
+        op.getLoc(), ArrayRef<Type>{op.getType()}, assumingAll);
+
+    // Build the region body.
+    rewriter.createBlock(&assuming.doRegion());
+    // Create the init tensor for the ConvNCHW.
+    // TODO: Expand supported data types.
+    Value c0 = rewriter.create<ConstantOp>(op.getLoc(),
+                                           rewriter.getF32FloatAttr(0.0));
+    Value shape = bypassResultShapes(op, rewriter)[0];
+    Value initTensor =
+        rewriter.create<tcp::SplattedOp>(op.getLoc(), op.getType(), c0, shape);
+
+    // Create the ConvNCHW.
+    auto conv2dNCHW = rewriter.create<linalg::ConvNCHWOp>(
+        op.getLoc(), TypeRange(op.getType()),
+        ValueRange({op.in(), op.filter()}), ValueRange(),
+        ValueRange(initTensor));
+    rewriter.create<shape::AssumingYieldOp>(op.getLoc(),
+                                            conv2dNCHW.getResults());
+
+    // Finally, replace with the results of the shape.assuming.
+    rewriter.replaceOp(op, assuming.getResults());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertTCFToLinalg : public ConvertTCFToLinalgBase<ConvertTCFToLinalg> {
 public:
@@ -91,6 +189,7 @@ public:
     MLIRContext *context = &getContext();
     OwningRewritePatternList patterns;
     patterns.insert<ConvertMatmul>(context);
+    patterns.insert<ConvertConvNCHW>(context);
     return std::move(patterns);
   }
 };
diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp
index 30710180c..ab3470e95 100644
--- a/lib/RefBackend/RefBackend.cpp
+++ b/lib/RefBackend/RefBackend.cpp
@@ -26,6 +26,7 @@
 #include "npcomp/RefBackend/RefBackend.h"
 #include "PassDetail.h"
 
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -283,6 +284,9 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
   // Final conversion to an LLVM module.
   // --------------------------------------------------------------------------
 
+  // Convert affine to std control flow in preparation for going to LLVM.
+  pm.addNestedPass<FuncOp>(createLowerAffinePass());
+
   // Convert scf to std control flow in preparation for going to LLVM.
   pm.addNestedPass<FuncOp>(createLowerToCFGPass());
diff --git a/test/Conversion/TCFToLinalg/basic.mlir b/test/Conversion/TCFToLinalg/basic.mlir
index e72dd0b49..e19206686 100644
--- a/test/Conversion/TCFToLinalg/basic.mlir
+++ b/test/Conversion/TCFToLinalg/basic.mlir
@@ -23,3 +23,50 @@ func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = tcf.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// CHECK-LABEL: func @tcf_conv_2d_nchw(
+// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:    %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK: %[[C0F32:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[C3:.*]] = constant 3 : index
+// CHECK: %[[CHANNELS:.*]] = dim %[[IN]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[HEIGHT:.*]] = dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[WIDTH:.*]] = dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILTERCHANNELS:.*]] = dim %[[FILTER]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILTERHEIGHT:.*]] = dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILTERWIDTH:.*]] = dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[CMPCHANNELS:.*]] = cmpi "eq", %[[CHANNELS]], %[[FILTERCHANNELS]] : index
+// CHECK: %[[CMPHEIGHT:.*]] = cmpi "uge", %[[HEIGHT]], %[[FILTERHEIGHT]] : index
+// CHECK: %[[CMPWIDTH:.*]] = cmpi "uge", %[[WIDTH]], %[[FILTERWIDTH]] : index
+// CHECK: %[[CSTRCHANNELS:.*]] = shape.cstr_require %[[CMPCHANNELS]], "input and filter in-channels must be equal"
+// CHECK: %[[CSTRHEIGHT:.*]] = shape.cstr_require %[[CMPHEIGHT]], "input height must be greater than or equal to filter KH-dimension"
+// CHECK: %[[CSTRWIDTH:.*]] = shape.cstr_require %[[CMPWIDTH]], "input width must be greater than or equal to filter KW-dimension"
+// CHECK: %[[WITNESS:.*]] = shape.assuming_all %[[CSTRCHANNELS]], %[[CSTRHEIGHT]], %[[CSTRWIDTH]]
+// CHECK: %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?x?x?x?xf32>) {
+// CHECK:   %[[BATCH:.*]] = dim %[[IN]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK:   %[[HEIGHT:.*]] = dim %[[IN]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK:   %[[WIDTH:.*]] = dim %[[IN]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK:   %[[OUTCHANNELS:.*]] = dim %[[FILTER]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK:   %[[FILTERHEIGHT:.*]] = dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK:   %[[FILTERWIDTH:.*]] = dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK:   %[[FILTERHEIGHTM1:.*]] = subi %[[FILTERHEIGHT]], %[[C1]] : index
+// CHECK:   %[[HEIGHTV0:.*]] = subi %[[HEIGHT]], %[[FILTERHEIGHTM1]] : index
+// CHECK:   %[[HEIGHTV0M1:.*]] = subi %[[HEIGHTV0]], %[[C1]] : index
+// CHECK:   %[[OUTHEIGHT:.*]] = addi %[[HEIGHTV0M1]], %[[C1]] : index
+// CHECK:   %[[FILTERWIDTHM1:.*]] = subi %[[FILTERWIDTH]], %[[C1]] : index
+// CHECK:   %[[WIDTHV0:.*]] = subi %[[WIDTH]], %[[FILTERWIDTHM1]] : index
+// CHECK:   %[[WIDTHV0M1:.*]] = subi %[[WIDTHV0]], %[[C1]] : index
+// CHECK:   %[[OUTWIDTH:.*]] = addi %[[WIDTHV0M1]], %[[C1]] : index
+// CHECK:   %[[SHAPE:.*]] = tensor_from_elements %[[BATCH]], %[[OUTCHANNELS]], %[[OUTHEIGHT]], %[[OUTWIDTH]] : tensor<4xindex>
+// CHECK:   %[[INIT_TENSOR:.*]] = tcp.splatted %[[C0F32]], %[[SHAPE]] : (f32, tensor<4xindex>) -> tensor<?x?x?x?xf32>
+// CHECK:   %[[CONVNCHW:.*]] = linalg.conv_2d_nchw ins(%[[IN]], %[[FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) init(%[[INIT_TENSOR]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK:   shape.assuming_yield %[[CONVNCHW]] : tensor<?x?x?x?xf32>
+// CHECK: }
+// CHECK: return %[[RET:.*]] : tensor<?x?x?x?xf32>
+func @tcf_conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = tcf.conv_2d_nchw %arg0, %arg1 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
diff --git a/test/Dialect/TCF/ops.mlir b/test/Dialect/TCF/ops.mlir
index 8f5f9577f..69e9d11d0 100644
--- a/test/Dialect/TCF/ops.mlir
+++ b/test/Dialect/TCF/ops.mlir
@@ -17,3 +17,10 @@ func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = tcf.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// CHECK-LABEL: func @conv_2d_nchw
+func @conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // CHECK: tcf.conv_2d_nchw %arg0, %arg1 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %0 = tcf.conv_2d_nchw %arg0, %arg1 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
diff --git a/test/npcomp-run-mlir/conv_2d_nchw.mlir b/test/npcomp-run-mlir/conv_2d_nchw.mlir
new file mode 100644
index 000000000..7f7a97293
--- /dev/null
+++ b/test/npcomp-run-mlir/conv_2d_nchw.mlir
@@ -0,0 +1,67 @@
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<2x1x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=BATCH
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<2x2x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=SAME_CHANNELS
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x2x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=DIFFERENT_CHANNELS
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x1x1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=TINY_SQUARE
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x32x32xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x32x32xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=HUGE_SQUARE
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=ZERO_KH_KW
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x0x0xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=ZERO_H_W
+
+// BATCH: output #0: dense<0.000000e+00> : tensor<2x1x1x1xf32>
+
+// SAME_CHANNELS: output #0: dense<0.000000e+00> : tensor<1x2x1x1xf32>
+
+// DIFFERENT_CHANNELS: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
+
+// TINY_SQUARE: output #0: dense<0.000000e+00> : tensor<1x1x2x2xf32>
+
+// HUGE_SQUARE: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
+
+// ZERO_KH_KW: output #0: dense<0.000000e+00> : tensor<1x1x3x3xf32>
+
+// ZERO_H_W: output #0: dense<0.000000e+00> : tensor<1x1x1x1xf32>
+
+func @conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = tcf.conv_2d_nchw %arg0, %arg1 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
diff --git a/test/npcomp-run-mlir/invalid-conv_2d_nchw.mlir b/test/npcomp-run-mlir/invalid-conv_2d_nchw.mlir
new file mode 100644
index 000000000..b1866bee2
--- /dev/null
+++ b/test/npcomp-run-mlir/invalid-conv_2d_nchw.mlir
@@ -0,0 +1,28 @@
+// RUN: not npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x2x2x2xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=CHANNELS
+
+// RUN: not npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x3x2xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=HEIGHT
+
+// RUN: not npcomp-run-mlir %s \
+// RUN:   -invoke conv_2d_nchw \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x2xf32>" \
+// RUN:   -arg-value="dense<0.0> : tensor<1x1x2x3xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=WIDTH
+
+// CHANNELS: NPCOMP: aborting: input and filter in-channels must be equal
+// HEIGHT: NPCOMP: aborting: input height must be greater than or equal to filter KH-dimension
+// WIDTH: NPCOMP: aborting: input width must be greater than or equal to filter KW-dimension
+func @conv_2d_nchw(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = tcf.conv_2d_nchw %arg0, %arg1 : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
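
Note on the output-shape computation: the SSA graph built in bypassResultShapes follows the usual convolution size formula, Hout = (H + 2*padding - dilation*(KH - 1) - 1) / stride + 1 (and analogously for Wout). With the hard-coded stride = 1, dilation = 1, and padding = 0, this reduces to Hout = H - KH + 1. The snippet below is an illustrative sketch only, not part of the patch; the concrete 5x5 input and 3x3 filter sizes are made up for the example, assuming statically shaped 4-D f32 tensors are accepted by the op.

// Illustrative only: with H = W = 5 and KH = KW = 3,
// Hout = Wout = 5 - 3 + 1 = 3, so the result is tensor<1x16x3x3xf32>.
func @conv_2d_nchw_example(%in: tensor<1x8x5x5xf32>, %filter: tensor<16x8x3x3xf32>) -> tensor<1x16x3x3xf32> {
  %0 = tcf.conv_2d_nchw %in, %filter : (tensor<1x8x5x5xf32>, tensor<16x8x3x3xf32>) -> tensor<1x16x3x3xf32>
  return %0 : tensor<1x16x3x3xf32>
}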