Import TCP pad

pull/155/head
Aaron J Arthurs 2020-12-17 12:56:46 -06:00 committed by Sean Silva
parent 689b40c7a6
commit fc650c9447
5 changed files with 112 additions and 1 deletions

View File

@ -54,4 +54,22 @@ def TCP_SplattedOp : TCP_Op<"splatted"> {
let assemblyFormat = "$splatVal `,` $shape attr-dict `:` functional-type(operands, results)"; let assemblyFormat = "$splatVal `,` $shape attr-dict `:` functional-type(operands, results)";
} }
// ODS definition of the `pad` op in the TCP dialect.
//
// Takes a ranked tensor, two extent tensors giving the per-dimension padding
// to add before (`lowerExpansion`) and after (`upperExpansion`) the data, and
// a fill value; produces a ranked tensor. No verifier is declared here, so the
// shape relationships listed in the description are not enforced by this
// definition.
def TCP_PadOp : TCP_Op<"pad"> {
let summary = "Pads a tensor with a fill value";
let description = [{
Pads a tensor with `fillVal` along the borders of each dimension according to `lowerExpansion` and `upperExpansion`. Note that this op is unmanaged, meaning that it assumes its operands and their shapes are valid.
The tensors have dimensions:
- operand: [D1, D2, ..., DN]
- lowerExpansion: [L1, L2, ..., LN]
- upperExpansion: [U1, U2, ..., UN]
- fillVal: scalar
- result: [D1+L1+U1, D2+L2+U2, ..., DN+LN+UN]
}];
// `fillVal` is declared as AnyType: this definition does not constrain it to
// match the element type of `operand`.
let arguments = (ins AnyRankedTensor:$operand, Shape_ExtentTensorType:$lowerExpansion, Shape_ExtentTensorType:$upperExpansion, AnyType:$fillVal);
let results = (outs AnyRankedTensor:$result);
// Textual form: the four operands separated by commas, an optional attribute
// dictionary, then `:` and the functional type of all operands and results.
let assemblyFormat = "$operand `,` $lowerExpansion `,` $upperExpansion `,` $fillVal attr-dict `:` functional-type(operands, results)";
}
#endif // TCP_OPS #endif // TCP_OPS

View File

@ -37,6 +37,22 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
return {splatted.shape()}; return {splatted.shape()};
} }
if (auto pad = dyn_cast<tcp::PadOp>(op)) {
  // Result shape of tcp.pad: for every dimension of the operand, the output
  // extent is lowerExpansion[d] + upperExpansion[d] + dim(operand, d). The
  // extents are materialized as IR and packed into a single extent tensor.
  auto loc = op.getLoc();
  const int rank =
      pad.operand().getType().cast<RankedTensorType>().getRank();
  SmallVector<Value, 6> resultExtents;
  for (int dim = 0; dim != rank; ++dim) {
    // Index constant used to read this dimension's padding amounts.
    Value dimIndex = builder.create<ConstantIndexOp>(loc, dim);
    Value lowerPad = builder.create<tensor::ExtractOp>(
        loc, pad.lowerExpansion(), ValueRange({dimIndex}));
    Value upperPad = builder.create<tensor::ExtractOp>(
        loc, pad.upperExpansion(), ValueRange({dimIndex}));
    Value inputExtent = builder.create<DimOp>(loc, pad.operand(), dim);
    Value padding = builder.create<AddIOp>(loc, lowerPad, upperPad);
    Value paddedExtent = builder.create<AddIOp>(loc, padding, inputExtent);
    resultExtents.push_back(paddedExtent);
  }
  Value resultShape =
      builder.create<TensorFromElementsOp>(loc, ValueRange(resultExtents));
  return {resultShape};
}
// No shape transfer function. // No shape transfer function.
return {}; return {};
} }
@ -158,6 +174,39 @@ public:
}; };
} // namespace } // namespace
namespace {
// Conversion pattern that rewrites tcp.pad onto buffers.
//
// The result buffer is obtained from allocateResults, filled entirely with
// the fill value, and then the input is copied into the subview of the result
// that starts at the lower expansion and has the operand's extents.
class BufferizePadOp : public OpConversionPattern<tcp::PadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tcp::PadOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    auto allocated = allocateResults(op, rewriter, loc);
    if (failed(allocated))
      return failure();
    auto resultBuffers = *allocated;

    // A single unit-stride constant is shared by every dimension.
    auto unitStride = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));

    // Per dimension: offset = lowerExpansion[d], size = dim(operand, d),
    // stride = 1. Together these describe where the unpadded data lives
    // inside the result buffer.
    SmallVector<Value, 6> offsets, sizes, strides;
    const int rank = op.getType().cast<RankedTensorType>().getRank();
    for (int dim = 0; dim != rank; ++dim) {
      Value dimIndex = rewriter.create<ConstantIndexOp>(loc, dim);
      Value lowerPad = rewriter.create<tensor::ExtractOp>(
          loc, op.lowerExpansion(), ValueRange({dimIndex}));
      Value inputExtent = rewriter.create<DimOp>(loc, op.operand(), dim);
      offsets.push_back(lowerPad);
      sizes.push_back(inputExtent);
      strides.push_back(unitStride);
    }

    // Fill the whole result first, then overwrite the interior region with
    // the input data.
    rewriter.create<linalg::FillOp>(loc, resultBuffers[0], op.fillVal());
    auto interior = rewriter.create<SubViewOp>(
        loc, resultBuffers[0], ValueRange(offsets), ValueRange(sizes),
        ValueRange(strides));
    // operands[0] is the already-converted form of the pad input.
    Value inputBuffer = operands[0];
    rewriter.create<linalg::CopyOp>(loc, inputBuffer, interior);

    rewriter.replaceOp(op, resultBuffers);
    return success();
  }
};
} // namespace
namespace { namespace {
class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> { class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
void getDependentDialects(::mlir::DialectRegistry &registry) const override { void getDependentDialects(::mlir::DialectRegistry &registry) const override {
@ -185,6 +234,8 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
target.addIllegalOp<tcp::BroadcastToOp>(); target.addIllegalOp<tcp::BroadcastToOp>();
patterns.insert<BufferizeSplattedOp>(typeConverter, context); patterns.insert<BufferizeSplattedOp>(typeConverter, context);
target.addIllegalOp<tcp::SplattedOp>(); target.addIllegalOp<tcp::SplattedOp>();
patterns.insert<BufferizePadOp>(typeConverter, context);
target.addIllegalOp<tcp::PadOp>();
target.addLegalDialect<linalg::LinalgDialect>(); target.addLegalDialect<linalg::LinalgDialect>();
target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<StandardOpsDialect>();

View File

@ -25,3 +25,35 @@ func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
%0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32> %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32> return %0 : tensor<?x?xf32>
} }
// Bufferization of a rank-1 tcp.pad. Expected lowering: compute the padded
// extent (lower + upper + input dim), allocate the result buffer, fill it with
// the fill value, then copy the bufferized input into the subview that starts
// at the lower expansion. The copy source is matched through the captured
// TENSOR_MREF value rather than a literal SSA name, so the test does not
// depend on SSA numbering and does verify which buffer is copied.
// CHECK-LABEL: func @tcp_pad(
// CHECK-SAME: %[[TENSOR:[a-zA-Z0-9]+]]: tensor<?xf32>,
// CHECK-SAME: %[[LOWER_EXPANSION:[a-zA-Z0-9]+]]: tensor<?xindex>,
// CHECK-SAME: %[[UPPER_EXPANSION:[a-zA-Z0-9]+]]: tensor<?xindex>,
// CHECK-SAME: %[[FILL_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<?xf32> {
// CHECK: %[[TENSOR_MREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
// CHECK: %[[LOWER_EXPANSION_MREF:.*]] = tensor_to_memref %[[LOWER_EXPANSION]] : memref<?xindex>
// CHECK: %[[UPPER_EXPANSION_MREF:.*]] = tensor_to_memref %[[UPPER_EXPANSION]] : memref<?xindex>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[LOWER_EXTENT_D1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0]]] : tensor<?xindex>
// CHECK: %[[UPPER_EXTENT_D1:.*]] = tensor.extract %[[UPPER_EXPANSION]][%[[C0]]] : tensor<?xindex>
// CHECK: %[[C0_0:.*]] = constant 0 : index
// CHECK: %[[D1:.*]] = dim %[[TENSOR]], %[[C0_0]] : tensor<?xf32>
// CHECK: %[[D1_EXPANSION:.*]] = addi %[[LOWER_EXTENT_D1]], %[[UPPER_EXTENT_D1]] : index
// CHECK: %[[D1_OUT:.*]] = addi %[[D1_EXPANSION]], %[[D1]] : index
// CHECK: %[[D1_OUT_TENSOR:.*]] = tensor_from_elements %[[D1_OUT]] : tensor<1xindex>
// CHECK: %[[D1_OUT_MREF:.*]] = refback.alloc_memref %[[D1_OUT_TENSOR]] : memref<?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[C0_1:.*]] = constant 0 : index
// CHECK: %[[LOWER_EXTENT_D1_1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0_1]]] : tensor<?xindex>
// CHECK: %[[C0_2:.*]] = constant 0 : index
// CHECK: %[[D1_1:.*]] = dim %[[TENSOR]], %[[C0_2]] : tensor<?xf32>
// CHECK: linalg.fill(%[[D1_OUT_MREF]], %[[FILL_VAL]]) : memref<?xf32>, f32
// CHECK: %[[SUBVIEW:.*]] = subview %[[D1_OUT_MREF]][%[[LOWER_EXTENT_D1_1]]] [%[[D1_1]]] [%[[C1]]] : memref<?xf32> to memref<?xf32, #map>
// CHECK: linalg.copy(%[[TENSOR_MREF]], %[[SUBVIEW]]) : memref<?xf32>, memref<?xf32, #map>
// CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[D1_OUT_MREF]] : memref<?xf32>
// CHECK: return %[[RESULT_TENSOR]] : tensor<?xf32>
func @tcp_pad(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>, %arg3: f32) -> tensor<?xf32> {
  %0 = tcp.pad %arg0, %arg1, %arg2, %arg3 : (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, f32) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

View File

@ -13,3 +13,10 @@ func @splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
%0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32> %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32> return %0 : tensor<?x?xf32>
} }
// Smoke test for the tcp.pad custom assembly on a rank-4 tensor: it only
// checks that a `tcp.pad` op appears in the output for this function.
// NOTE(review): presumably a parse/print round-trip test; the RUN line is not
// visible here, so confirm which tool and flags produce the checked output.
// CHECK-LABEL: @pad
func @pad(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>, %arg3: f32) -> tensor<?x?x?x?xf32> {
// CHECK: tcp.pad
%0 = tcp.pad %arg0, %arg1, %arg2, %arg3 : (tensor<?x?x?x?xf32>, tensor<?xindex>, tensor<?xindex>, f32) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}

View File

@ -42,7 +42,8 @@ convertAttrToTensor(Attribute attr) {
auto extents = llvm::to_vector<6>(llvm::map_range( auto extents = llvm::to_vector<6>(llvm::map_range(
type.getShape(), [](int64_t x) { return static_cast<std::int32_t>(x); })); type.getShape(), [](int64_t x) { return static_cast<std::int32_t>(x); }));
auto elementType = type.getElementType(); auto elementType = type.getElementType();
if (auto denseFp = attr.dyn_cast<DenseFPElementsAttr>()) { auto denseFp = attr.dyn_cast<DenseFPElementsAttr>();
if (denseFp) {
if (elementType.isF32()) { if (elementType.isF32()) {
auto values = llvm::to_vector<100>(llvm::map_range( auto values = llvm::to_vector<100>(llvm::map_range(
denseFp, [](APFloat f) { return f.convertToFloat(); })); denseFp, [](APFloat f) { return f.convertToFloat(); }));
@ -50,6 +51,8 @@ convertAttrToTensor(Attribute attr) {
refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()), refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
refbackrt::ElementType::F32, static_cast<void *>(values.data())); refbackrt::ElementType::F32, static_cast<void *>(values.data()));
} }
} else {
return make_string_error("unhandled argument; must be dense floating-point");
} }
return make_string_error("unhandled argument"); return make_string_error("unhandled argument");
} }