Add bufferization pass for TMTensor ops

The pass is mostly borrowed from the BufferizeAnyLinalgOp pass in
upstream MLIR, with some minor changes. At a high level, it is a naive
partial bufferization pass that allocates a new buffer for every
output tensor. The new buffer is initialized with a copy of the
original buffer only if the op actually reads the output operand's
initial value.
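
For example (a hand-written sketch based on the bufferization tests
added below, not verbatim pass output; the value names are
illustrative), an output whose initial value is read gets an alloc
plus a copy, while a write-only output only gets a fresh alloc:

  %out_buf = bufferization.to_memref %out_tensor : memref<128xi32>
  %new_buf = memref.alloc() : memref<128xi32>
  // The copy is emitted only when the initial value is actually read.
  memref.copy %out_buf, %new_buf : memref<128xi32> to memref<128xi32>
  // ... the bufferized TMTensor op then writes into %new_buf ...
  %result = bufferization.to_tensor %new_buf : memref<128xi32>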

One difference from the Linalg bufferization pass is how to tell
whether the loop body uses the initial value of an output operand.
For TMTensor ops this has to be answered per op (via the new
payloadUsesValueFromOperand interface method), because the payload
region does not represent the entire loop body.
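
As a concrete case (copied from the scan tests added below): the
payload region of tm_tensor.scan is identical for the inclusive and
exclusive variants, yet only the exclusive variant reads the
accumulator's initial value, so ScanOp::payloadUsesValueFromOperand
decides based on the inclusive() attribute rather than on uses inside
the region:

  %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
    ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
  ^bb0(%arg0 : i32, %arg1 : i32):
    // Same payload as the inclusive case; whether %acc's buffer needs
    // a copy is not visible from this region alone.
    %sum = arith.addi %arg0, %arg1 : i32
    tm_tensor.yield %sum : i32
  } -> tensor<128xi32>, tensor<i32>

Only this exclusive form gets a memref.copy of the accumulator buffer;
the inclusive form starts from an uninitialized alloc.
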
Yi Zhang 2022-02-25 18:04:33 -05:00
parent 5ec70c175d
commit 486f95e84f
10 changed files with 324 additions and 8 deletions


@@ -287,8 +287,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
>,
InterfaceMethod<
/*desc=*/[{
-    Return true if the payload uses the value loaded from `opOperand`. This
-    is useful to avoid loading from "write-only" memory that may be
+    Return true if the loop body uses the value loaded from `opOperand`.
+    This is useful to avoid loading from "write-only" memory that may be
uninitialized, as well as properly cloning "read-write" operands.
}],
/*retTy=*/"bool",


@@ -46,9 +46,11 @@ class TMTensor_Op<string mnemonic, list<Trait> traits = []> :
// Non-structured ops
//===----------------------------------------------------------------------===//
-def TMTensor_ScanOp : TMTensor_Op<"scan"
-    ,[DeclareOpInterfaceMethods<ScalarLoopOpInterface,
-        ["generateScalarImplementation"]>]> {
+def TMTensor_ScanOp : TMTensor_Op<"scan",
+    [DeclareOpInterfaceMethods<TMTensorInterface,
+        ["payloadUsesValueFromOperand"]>,
+     DeclareOpInterfaceMethods<ScalarLoopOpInterface,
+        ["generateScalarImplementation"]>]> {
let summary = "Scan operator";
let description = [{
Computes the inclusive/exclusive scan along a given dimension.
@@ -97,7 +99,9 @@ def TMTensor_ScanOp : TMTensor_Op<"scan"
}
def TMTensor_ScatterOp : TMTensor_Op<"scatter",
-    [DeclareOpInterfaceMethods<ScalarLoopOpInterface,
+    [DeclareOpInterfaceMethods<TMTensorInterface,
+        ["payloadUsesValueFromOperand"]>,
+     DeclareOpInterfaceMethods<ScalarLoopOpInterface,
["generateScalarImplementation"]>]> {
let summary = "Scatter operator";
let description = [{


@@ -17,6 +17,7 @@ namespace torch {
namespace TMTensor {
std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();
std::unique_ptr<OperationPass<FuncOp>> createTMTensorBufferizePass();
void registerPasses();


@@ -13,9 +13,14 @@
include "mlir/Pass/PassBase.td"
def TMTensorToLoops :
Pass<"torch-mlir-tm-tensor-to-loops", "FuncOp"> {
Pass<"tm-tensor-to-loops", "FuncOp"> {
let summary = "Convert TMTensor ops to loops and Linalg ops.";
let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
}
+def TMTensorBufferize : Pass<"tm-tensor-bufferize", "FuncOp"> {
+  let summary = "Bufferize the TMTensor dialect";
+  let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()";
+}
#endif // TORCH_MLIR_DIALECT_TMTENSOR_PASSES


@@ -166,6 +166,19 @@ SmallVector<StringRef> ScanOp::getLoopIteratorTypes() {
return iteratorTypes;
}
bool ScanOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
Value operand = opOperand->get();
if (operand == accumulator())
return !inclusive();
else if (operand == output())
return false;
else {
assert(operand == input() &&
"operand must belong to the current tm_tensor.scan op");
return true;
}
}
// Generates naive scalar implementation of scan for a given operator f.
// For inclusive,
// output[0] = input[0]
@@ -385,6 +398,27 @@ SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
return iteratorTypes;
}
bool ScatterOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
unsigned bbArgNumber;
Value operand = opOperand->get();
if (operand == updates())
bbArgNumber = 0; // block arg 0 is `update`.
else if (operand == original())
bbArgNumber = 1; // block arg 1 is `original`.
else {
assert(operand == indices() &&
"operand must belong to the current tm_tensor.scatter op");
return true;
}
assert(this->getOperation()->getNumRegions() == 1 &&
"unexpected "
"missing region (calling `payloadUsesValueFromOperand` on "
"manually defined named TMTensor op?)");
Block &block = this->getOperation()->getRegion(0).front();
return !block.getArgument(bbArgNumber).use_empty();
}
SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
Location loc = getLoc();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);


@@ -0,0 +1,155 @@
//===- Bufferize.cpp - Bufferization of tmtensor ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
using namespace ::mlir;
using namespace ::mlir::torch::TMTensor;
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
auto memrefType = memref.getType().cast<MemRefType>();
auto alloc = b.create<memref::AllocOp>(
loc, memrefType, linalg::getDynOperands(loc, memref, b));
b.create<memref::CopyOp>(loc, memref, alloc);
return alloc;
}
static LogicalResult
allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp,
ValueRange outputs,
SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
// Lazily compute loopRanges.
SmallVector<Range, 4> loopRanges;
// Allocate a buffer for every tensor result.
assert(tmtensorOp.getNumOutputs() == tmtensorOp->getNumResults());
for (const auto &en : llvm::enumerate(tmtensorOp->getResultTypes())) {
size_t resultIndex = en.index();
Type resultType = en.value();
auto tensorType = resultType.dyn_cast<RankedTensorType>();
if (tensorType == nullptr) {
tmtensorOp.emitOpError()
<< "tensor to buffer conversion expects ranked tensor results";
return failure();
}
auto tensorShape = tensorType.getShape();
auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
Value resultTensor = outputs[resultIndex];
// Clone output buffers whose value is actually used.
OpOperand *tiedOpOperand = tmtensorOp.getOutputOperand(resultIndex);
if (tmtensorOp.payloadUsesValueFromOperand(tiedOpOperand)) {
resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
continue;
}
// Allocate buffers for statically-shaped results.
if (memrefType.hasStaticShape()) {
resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
continue;
}
resultBuffers.push_back(b.create<memref::AllocOp>(
loc, memrefType, linalg::getDynOperands(loc, resultTensor, b)));
}
return success();
}
/// Create TMTensor op on buffers given the original tensor-based operation and
/// the buffers for the outputs.
static TMTensorOp createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter,
TMTensorOp tmtensorOp,
ValueRange inputs,
ValueRange outputs) {
SmallVector<Value, 8> newOperands = inputs;
newOperands.append(outputs.begin(), outputs.end());
return tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands);
}
/// Generic conversion pattern that matches any TMTensorOp. This avoids template
/// instantiating one pattern for each TMTensorOp.
class BufferizeAnyTMTensorOp : public OpInterfaceConversionPattern<TMTensorOp> {
public:
using OpInterfaceConversionPattern<TMTensorOp>::OpInterfaceConversionPattern;
LogicalResult
matchAndRewrite(TMTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
SmallVector<Value, 2> newOutputBuffers;
SmallVector<Value> outputs(operands.begin() + op.getNumInputs(),
operands.end());
if (failed(allocateBuffersForResults(loc, op, outputs, newOutputBuffers,
rewriter))) {
return op.emitOpError()
<< "Failed to allocate buffers for tensor results.";
}
SmallVector<Value> inputs(operands.begin(),
operands.begin() + op.getNumInputs());
createTMTensorOpOnBuffers(rewriter, op, inputs, newOutputBuffers);
// Replace the results of the old op with the new output buffers.
rewriter.replaceOp(op, newOutputBuffers);
return success();
}
};
namespace {
/// Converts TMTensor operations that work on tensor-type operands or results to
/// work on buffers.
struct TMTensorBufferizePass
: public TMTensorBufferizeBase<TMTensorBufferizePass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
torch::TMTensor::TMTensorDialect>();
}
void runOnOperation() override {
MLIRContext &context = getContext();
ConversionTarget target(context);
bufferization::BufferizeTypeConverter typeConverter;
// Mark all Standard operations legal.
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
StandardOpsDialect, tensor::TensorDialect>();
// Mark all TMTensor operations illegal as long as they work on tensors.
auto isLegalOperation = [&](Operation *op) {
return typeConverter.isLegal(op);
};
target.addDynamicallyLegalDialect<TMTensorDialect>(isLegalOperation);
RewritePatternSet patterns(&context);
patterns.add<BufferizeAnyTMTensorOp>(typeConverter, patterns.getContext());
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
torch::TMTensor::createTMTensorBufferizePass() {
return std::make_unique<TMTensorBufferizePass>();
}


@@ -1,5 +1,6 @@
add_mlir_library(TorchMLIRTMTensorPasses
ConvertToLoops.cpp
Bufferize.cpp
Passes.cpp
DEPENDS


@@ -0,0 +1,114 @@
// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-bufferize %s | FileCheck %s
// -----
// CHECK-LABEL: func @scan_1d_inclusive(
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
// CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>)
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
// CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
// CHECK: }
// CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
tm_tensor.yield %sum : i32
} -> tensor<128xi32>, tensor<i32>
return %ret_out, %ret_acc: tensor<128xi32>, tensor<i32>
}
// -----
// CHECK-LABEL: func @scan_1d_exclusive(
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
// CHECK: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32>
// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
// CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32>
// CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>)
// CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
// CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
// CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
// CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
// CHECK: }
// CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
tm_tensor.yield %sum : i32
} -> tensor<128xi32>, tensor<i32>
return %ret_out, %ret_acc: tensor<128xi32>, tensor<i32>
}
// -----
// CHECK-LABEL: func @scatter_update_scalar_1D(
// CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
// CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
// CHECK: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
// CHECK: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
// CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32
// CHECK: }
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
func @scatter_update_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
outs(%original : tensor<8xi32>) {
^bb0(%update: i32, %orig: i32): // no predecessors
tm_tensor.yield %update: i32
} -> tensor<8xi32>
return %0 : tensor<8xi32>
}
// CHECK-LABEL: func @scatter_add_scalar_1D(
// CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
// CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
// CHECK: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
// CHECK: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
// CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
// CHECK: %[[CST1:.*]] = arith.constant 1 : i32
// CHECK: %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32
// CHECK: tm_tensor.yield %[[ADD]] : i32
// CHECK: }
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
func @scatter_add_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter unique_indices(true)
ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
outs(%original : tensor<8xi32>) {
^bb0(%update: i32, %orig: i32): // no predecessors
%cst1 = arith.constant 1: i32
%add = arith.addi %orig, %cst1: i32
tm_tensor.yield %add: i32
} -> tensor<8xi32>
return %0 : tensor<8xi32>
}


@@ -1,4 +1,4 @@
-// RUN: torch-mlir-dialects-opt -split-input-file -torch-mlir-tm-tensor-to-loops %s | FileCheck %s
+// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-to-loops %s | FileCheck %s
func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32>


@@ -166,6 +166,7 @@ class RefBackendInvoker:
LOWERING_PIPELINE = ",".join([
# Bufferize.
"builtin.func(scf-bufferize)",
"builtin.func(tm-tensor-bufferize)",
"builtin.func(linalg-bufferize)",
"builtin.func(refback-munge-memref-copy)",
"func-bufferize",
@@ -183,6 +184,7 @@ LOWERING_PIPELINE = ",".join([
# global seed used in stateful rng.
"refback-insert-rng-globals",
# Lower to LLVM
"builtin.func(tm-tensor-to-loops)",
"builtin.func(convert-linalg-to-loops)",
"builtin.func(lower-affine)",
"convert-scf-to-cf",