diff --git a/include/npcomp/Dialect/TCP/IR/TCPOps.td b/include/npcomp/Dialect/TCP/IR/TCPOps.td index 81f90f0fa..5a24dbcc1 100644 --- a/include/npcomp/Dialect/TCP/IR/TCPOps.td +++ b/include/npcomp/Dialect/TCP/IR/TCPOps.td @@ -54,4 +54,22 @@ def TCP_SplattedOp : TCP_Op<"splatted"> { let assemblyFormat = "$splatVal `,` $shape attr-dict `:` functional-type(operands, results)"; } +def TCP_PadOp : TCP_Op<"pad"> { + let summary = "Pads a tensor with a fill value"; + let description = [{ + Pads a tensor with `fillVal` along the borders of each dimension according to `lowerExpansion` and `upperExpansion`. Note that this op is unmanaged, meaning that it assumes its operands and their shapes are valid. + + The tensors have dimensions: + - operand: [D1, D2, ..., DN] + - lowerExpansion: [L1, L2, ..., LN] + - upperExpansion: [U1, U2, ..., UN] + - fillVal: scalar + - result: [D1+L1+U1, D2+L2+U2, ..., DN+LN+UN] + }]; + let arguments = (ins AnyRankedTensor:$operand, Shape_ExtentTensorType:$lowerExpansion, Shape_ExtentTensorType:$upperExpansion, AnyType:$fillVal); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = "$operand `,` $lowerExpansion `,` $upperExpansion `,` $fillVal attr-dict `:` functional-type(operands, results)"; +} + #endif // TCP_OPS diff --git a/lib/Dialect/TCP/Transforms/Bufferize.cpp b/lib/Dialect/TCP/Transforms/Bufferize.cpp index 776f98c4e..4e68b9b0b 100644 --- a/lib/Dialect/TCP/Transforms/Bufferize.cpp +++ b/lib/Dialect/TCP/Transforms/Bufferize.cpp @@ -37,6 +37,22 @@ static SmallVector bypassResultShapes(Operation &op) { return {splatted.shape()}; } + if (auto pad = dyn_cast(op)) { + SmallVector outDims; + auto inputType = pad.operand().getType().cast(); + for (int i = 0, e = inputType.getRank(); i < e; i++) { + Value dimIndex = builder.create(op.getLoc(), i); + Value lowerExpansion = builder.create(op.getLoc(), pad.lowerExpansion(), ValueRange({dimIndex})); + Value upperExpansion = builder.create(op.getLoc(), pad.upperExpansion(), ValueRange({dimIndex})); + Value operandDim = builder.create(op.getLoc(), pad.operand(), i); + Value totalExpansion = builder.create(op.getLoc(), lowerExpansion, upperExpansion); + Value outDim = builder.create(op.getLoc(), totalExpansion, operandDim); + outDims.push_back(outDim); + } + Value outDimTensor = builder.create(op.getLoc(), ValueRange(outDims)); + return {outDimTensor}; + } + // No shape transfer function. return {}; } @@ -158,6 +174,39 @@ public: }; } // namespace +namespace { +class BufferizePadOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tcp::PadOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc()); + if (failed(resultsOrFailure)) + return failure(); + auto results = *resultsOrFailure; + auto c1 = rewriter.create(op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + SmallVector offsets, sizes, strides; + auto resultType = op.getType().cast(); + for (int i = 0, e = resultType.getRank(); i < e; i++) { + Value dimIndex = rewriter.create(op.getLoc(), i); + Value offset = rewriter.create(op.getLoc(), op.lowerExpansion(), ValueRange({dimIndex})); + Value size = rewriter.create(op.getLoc(), op.operand(), i); + Value stride = c1; + offsets.push_back(offset); + sizes.push_back(size); + strides.push_back(stride); + } + rewriter.create(op.getLoc(), results[0], op.fillVal()); + auto unpadded = rewriter.create(op.getLoc(), results[0], ValueRange(offsets), ValueRange(sizes), ValueRange(strides)); + Value inputMemref = operands[0]; + rewriter.create(op.getLoc(), inputMemref, unpadded); + rewriter.replaceOp(op, results); + return success(); + } +}; +} // namespace + namespace { class TCPBufferizePass : public TCPBufferizeBase { void getDependentDialects(::mlir::DialectRegistry ®istry) const override { @@ -185,6 +234,8 @@ class TCPBufferizePass : public TCPBufferizeBase { target.addIllegalOp(); patterns.insert(typeConverter, context); target.addIllegalOp(); + patterns.insert(typeConverter, context); + target.addIllegalOp(); target.addLegalDialect(); target.addLegalDialect(); diff --git a/test/Dialect/TCP/bufferize.mlir b/test/Dialect/TCP/bufferize.mlir index 16fd050b2..12ad0bc9f 100644 --- a/test/Dialect/TCP/bufferize.mlir +++ b/test/Dialect/TCP/bufferize.mlir @@ -25,3 +25,35 @@ func @tcp_splatted(%arg0: f32, %arg1: tensor) -> tensor { %0 = tcp.splatted %arg0, %arg1 : (f32, tensor) -> tensor return %0 : tensor } + +// CHECK-LABEL: func @tcp_pad( +// CHECK-SAME: %[[TENSOR:[a-zA-Z0-9]+]]: tensor, +// CHECK-SAME: %[[LOWER_EXPANSION:[a-zA-Z0-9]+]]: tensor, +// CHECK-SAME: %[[UPPER_EXPANSION:[a-zA-Z0-9]+]]: tensor, +// CHECK-SAME: %[[FILL_VAL:[a-zA-Z0-9]+]]: f32) -> tensor { +// CHECK: %[[TENSOR_MREF:.*]] = tensor_to_memref %[[TENSOR]] : memref +// CHECK: %[[LOWER_EXPANSION_MREF:.*]] = tensor_to_memref %[[LOWER_EXPANSION]] : memref +// CHECK: %[[UPPER_EXPANSION_MREF:.*]] = tensor_to_memref %[[UPPER_EXPANSION]] : memref +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[LOWER_EXTENT_D1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0]]] : tensor +// CHECK: %[[UPPER_EXTENT_D1:.*]] = tensor.extract %[[UPPER_EXPANSION]][%[[C0]]] : tensor +// CHECK: %[[C0_0:.*]] = constant 0 : index +// CHECK: %[[D1:.*]] = dim %[[TENSOR]], %[[C0_0]] : tensor +// CHECK: %[[D1_EXPANSION:.*]] = addi %[[LOWER_EXTENT_D1]], %[[UPPER_EXTENT_D1]] : index +// CHECK: %[[D1_OUT:.*]] = addi %[[D1_EXPANSION]], %[[D1]] : index +// CHECK: %[[D1_OUT_TENSOR:.*]] = tensor_from_elements %[[D1_OUT]] : tensor<1xindex> +// CHECK: %[[D1_OUT_MREF:.*]] = refback.alloc_memref %[[D1_OUT_TENSOR]] : memref +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[C0_1:.*]] = constant 0 : index +// CHECK: %[[LOWER_EXTENT_D1_1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0_1]]] : tensor +// CHECK: %[[C0_2:.*]] = constant 0 : index +// CHECK: %[[D1_1:.*]] = dim %[[TENSOR]], %[[C0_2]] : tensor +// CHECK: linalg.fill(%[[D1_OUT_MREF]], %[[FILL_VAL]]) : memref, f32 +// CHECK: %[[SUBVIEW:.*]] = subview %[[D1_OUT_MREF]][%[[LOWER_EXTENT_D1_1]]] [%[[D1_1]]] [%[[C1]]] : memref to memref +// CHECK: linalg.copy(%0, %[[SUBVIEW]]) : memref, memref +// CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[D1_OUT_MREF]] : memref +// CHECK: return %[[RESULT_TENSOR]] : tensor +func @tcp_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: f32) -> tensor { + %0 = tcp.pad %arg0, %arg1, %arg2, %arg3 : (tensor, tensor, tensor, f32) -> tensor + return %0 : tensor +} diff --git a/test/Dialect/TCP/ops.mlir b/test/Dialect/TCP/ops.mlir index ca08c173e..9f1a836ec 100644 --- a/test/Dialect/TCP/ops.mlir +++ b/test/Dialect/TCP/ops.mlir @@ -13,3 +13,10 @@ func @splatted(%arg0: f32, %arg1: tensor) -> tensor { %0 = tcp.splatted %arg0, %arg1 : (f32, tensor) -> tensor return %0 : tensor } + +// CHECK-LABEL: @pad +func @pad(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: f32) -> tensor { + // CHECK: tcp.pad + %0 = tcp.pad %arg0, %arg1, %arg2, %arg3 : (tensor, tensor, tensor, f32) -> tensor + return %0 : tensor +} diff --git a/tools/npcomp-run-mlir/npcomp-run-mlir.cpp b/tools/npcomp-run-mlir/npcomp-run-mlir.cpp index 904e90ede..a2494f2d7 100644 --- a/tools/npcomp-run-mlir/npcomp-run-mlir.cpp +++ b/tools/npcomp-run-mlir/npcomp-run-mlir.cpp @@ -42,7 +42,8 @@ convertAttrToTensor(Attribute attr) { auto extents = llvm::to_vector<6>(llvm::map_range( type.getShape(), [](int64_t x) { return static_cast(x); })); auto elementType = type.getElementType(); - if (auto denseFp = attr.dyn_cast()) { + auto denseFp = attr.dyn_cast(); + if (denseFp) { if (elementType.isF32()) { auto values = llvm::to_vector<100>(llvm::map_range( denseFp, [](APFloat f) { return f.convertToFloat(); })); @@ -50,6 +51,8 @@ convertAttrToTensor(Attribute attr) { refbackrt::ArrayRef(extents.data(), extents.size()), refbackrt::ElementType::F32, static_cast(values.data())); } + } else { + return make_string_error("unhandled argument; must be dense floating-point"); } return make_string_error("unhandled argument"); }