mirror of https://github.com/llvm/torch-mlir
Import TCP pad
parent
689b40c7a6
commit
fc650c9447
|
@ -54,4 +54,22 @@ def TCP_SplattedOp : TCP_Op<"splatted"> {
|
|||
let assemblyFormat = "$splatVal `,` $shape attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TCP_PadOp : TCP_Op<"pad"> {
|
||||
let summary = "Pads a tensor with a fill value";
|
||||
let description = [{
|
||||
Pads a tensor with `fillVal` along the borders of each dimension according to `lowerExpansion` and `upperExpansion`. Note that this op is unmanaged, meaning that it assumes its operands and their shapes are valid.
|
||||
|
||||
The tensors have dimensions:
|
||||
- operand: [D1, D2, ..., DN]
|
||||
- lowerExpansion: [L1, L2, ..., LN]
|
||||
- upperExpansion: [U1, U2, ..., UN]
|
||||
- fillVal: scalar
|
||||
- result: [D1+L1+U1, D2+L2+U2, ..., DN+LN+UN]
|
||||
}];
|
||||
let arguments = (ins AnyRankedTensor:$operand, Shape_ExtentTensorType:$lowerExpansion, Shape_ExtentTensorType:$upperExpansion, AnyType:$fillVal);
|
||||
let results = (outs AnyRankedTensor:$result);
|
||||
|
||||
let assemblyFormat = "$operand `,` $lowerExpansion `,` $upperExpansion `,` $fillVal attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
#endif // TCP_OPS
|
||||
|
|
|
@ -37,6 +37,22 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
|
|||
return {splatted.shape()};
|
||||
}
|
||||
|
||||
if (auto pad = dyn_cast<tcp::PadOp>(op)) {
|
||||
SmallVector<Value, 6> outDims;
|
||||
auto inputType = pad.operand().getType().cast<RankedTensorType>();
|
||||
for (int i = 0, e = inputType.getRank(); i < e; i++) {
|
||||
Value dimIndex = builder.create<ConstantIndexOp>(op.getLoc(), i);
|
||||
Value lowerExpansion = builder.create<tensor::ExtractOp>(op.getLoc(), pad.lowerExpansion(), ValueRange({dimIndex}));
|
||||
Value upperExpansion = builder.create<tensor::ExtractOp>(op.getLoc(), pad.upperExpansion(), ValueRange({dimIndex}));
|
||||
Value operandDim = builder.create<DimOp>(op.getLoc(), pad.operand(), i);
|
||||
Value totalExpansion = builder.create<AddIOp>(op.getLoc(), lowerExpansion, upperExpansion);
|
||||
Value outDim = builder.create<AddIOp>(op.getLoc(), totalExpansion, operandDim);
|
||||
outDims.push_back(outDim);
|
||||
}
|
||||
Value outDimTensor = builder.create<TensorFromElementsOp>(op.getLoc(), ValueRange(outDims));
|
||||
return {outDimTensor};
|
||||
}
|
||||
|
||||
// No shape transfer function.
|
||||
return {};
|
||||
}
|
||||
|
@ -158,6 +174,39 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizePadOp : public OpConversionPattern<tcp::PadOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tcp::PadOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc());
|
||||
if (failed(resultsOrFailure))
|
||||
return failure();
|
||||
auto results = *resultsOrFailure;
|
||||
auto c1 = rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
||||
SmallVector<Value, 6> offsets, sizes, strides;
|
||||
auto resultType = op.getType().cast<RankedTensorType>();
|
||||
for (int i = 0, e = resultType.getRank(); i < e; i++) {
|
||||
Value dimIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
|
||||
Value offset = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.lowerExpansion(), ValueRange({dimIndex}));
|
||||
Value size = rewriter.create<DimOp>(op.getLoc(), op.operand(), i);
|
||||
Value stride = c1;
|
||||
offsets.push_back(offset);
|
||||
sizes.push_back(size);
|
||||
strides.push_back(stride);
|
||||
}
|
||||
rewriter.create<linalg::FillOp>(op.getLoc(), results[0], op.fillVal());
|
||||
auto unpadded = rewriter.create<SubViewOp>(op.getLoc(), results[0], ValueRange(offsets), ValueRange(sizes), ValueRange(strides));
|
||||
Value inputMemref = operands[0];
|
||||
rewriter.create<linalg::CopyOp>(op.getLoc(), inputMemref, unpadded);
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
|
||||
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
|
||||
|
@ -185,6 +234,8 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
|
|||
target.addIllegalOp<tcp::BroadcastToOp>();
|
||||
patterns.insert<BufferizeSplattedOp>(typeConverter, context);
|
||||
target.addIllegalOp<tcp::SplattedOp>();
|
||||
patterns.insert<BufferizePadOp>(typeConverter, context);
|
||||
target.addIllegalOp<tcp::PadOp>();
|
||||
|
||||
target.addLegalDialect<linalg::LinalgDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
|
|
|
@ -25,3 +25,35 @@ func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
|||
%0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tcp_pad(
|
||||
// CHECK-SAME: %[[TENSOR:[a-zA-Z0-9]+]]: tensor<?xf32>,
|
||||
// CHECK-SAME: %[[LOWER_EXPANSION:[a-zA-Z0-9]+]]: tensor<?xindex>,
|
||||
// CHECK-SAME: %[[UPPER_EXPANSION:[a-zA-Z0-9]+]]: tensor<?xindex>,
|
||||
// CHECK-SAME: %[[FILL_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<?xf32> {
|
||||
// CHECK: %[[TENSOR_MREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
|
||||
// CHECK: %[[LOWER_EXPANSION_MREF:.*]] = tensor_to_memref %[[LOWER_EXPANSION]] : memref<?xindex>
|
||||
// CHECK: %[[UPPER_EXPANSION_MREF:.*]] = tensor_to_memref %[[UPPER_EXPANSION]] : memref<?xindex>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[LOWER_EXTENT_D1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0]]] : tensor<?xindex>
|
||||
// CHECK: %[[UPPER_EXTENT_D1:.*]] = tensor.extract %[[UPPER_EXPANSION]][%[[C0]]] : tensor<?xindex>
|
||||
// CHECK: %[[C0_0:.*]] = constant 0 : index
|
||||
// CHECK: %[[D1:.*]] = dim %[[TENSOR]], %[[C0_0]] : tensor<?xf32>
|
||||
// CHECK: %[[D1_EXPANSION:.*]] = addi %[[LOWER_EXTENT_D1]], %[[UPPER_EXTENT_D1]] : index
|
||||
// CHECK: %[[D1_OUT:.*]] = addi %[[D1_EXPANSION]], %[[D1]] : index
|
||||
// CHECK: %[[D1_OUT_TENSOR:.*]] = tensor_from_elements %[[D1_OUT]] : tensor<1xindex>
|
||||
// CHECK: %[[D1_OUT_MREF:.*]] = refback.alloc_memref %[[D1_OUT_TENSOR]] : memref<?xf32>
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[C0_1:.*]] = constant 0 : index
|
||||
// CHECK: %[[LOWER_EXTENT_D1_1:.*]] = tensor.extract %[[LOWER_EXPANSION]][%[[C0_1]]] : tensor<?xindex>
|
||||
// CHECK: %[[C0_2:.*]] = constant 0 : index
|
||||
// CHECK: %[[D1_1:.*]] = dim %[[TENSOR]], %[[C0_2]] : tensor<?xf32>
|
||||
// CHECK: linalg.fill(%[[D1_OUT_MREF]], %[[FILL_VAL]]) : memref<?xf32>, f32
|
||||
// CHECK: %[[SUBVIEW:.*]] = subview %[[D1_OUT_MREF]][%[[LOWER_EXTENT_D1_1]]] [%[[D1_1]]] [%[[C1]]] : memref<?xf32> to memref<?xf32, #map>
|
||||
// CHECK: linalg.copy(%0, %[[SUBVIEW]]) : memref<?xf32>, memref<?xf32, #map>
|
||||
// CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[D1_OUT_MREF]] : memref<?xf32>
|
||||
// CHECK: return %[[RESULT_TENSOR]] : tensor<?xf32>
|
||||
func @tcp_pad(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>, %arg3: f32) -> tensor<?xf32> {
|
||||
%0 = tcp.pad %arg0, %arg1, %arg2, %arg3 : (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, f32) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
|
|
@ -13,3 +13,10 @@ func @splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
|||
%0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @pad
|
||||
func @pad(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>, %arg3: f32) -> tensor<?x?x?x?xf32> {
|
||||
// CHECK: tcp.pad
|
||||
%0 = tcp.pad %arg0, %arg1, %arg2, %arg3 : (tensor<?x?x?x?xf32>, tensor<?xindex>, tensor<?xindex>, f32) -> tensor<?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
|
|
@ -42,7 +42,8 @@ convertAttrToTensor(Attribute attr) {
|
|||
auto extents = llvm::to_vector<6>(llvm::map_range(
|
||||
type.getShape(), [](int64_t x) { return static_cast<std::int32_t>(x); }));
|
||||
auto elementType = type.getElementType();
|
||||
if (auto denseFp = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
auto denseFp = attr.dyn_cast<DenseFPElementsAttr>();
|
||||
if (denseFp) {
|
||||
if (elementType.isF32()) {
|
||||
auto values = llvm::to_vector<100>(llvm::map_range(
|
||||
denseFp, [](APFloat f) { return f.convertToFloat(); }));
|
||||
|
@ -50,6 +51,8 @@ convertAttrToTensor(Attribute attr) {
|
|||
refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
|
||||
refbackrt::ElementType::F32, static_cast<void *>(values.data()));
|
||||
}
|
||||
} else {
|
||||
return make_string_error("unhandled argument; must be dense floating-point");
|
||||
}
|
||||
return make_string_error("unhandled argument");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue