mirror of https://github.com/llvm/torch-mlir
Add bufferization pass for TMTensor ops
The pass is mostly borrowed from the BufferizeAnyLinalgOp pass in upstream MLIR, with some minor changes. At a high level, it is a naive partial bufferization pass that allocates new buffers for all output tensors. An output buffer is initialized with a copy of the original value only if the original value is actually used. One difference from the Linalg bufferization pass is how we determine whether the loop body uses the init value of an output operand: for TMTensor ops this differs from op to op, because the payload region does not represent the entire loop body.
parent 5ec70c175d
commit 486f95e84f
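To make that concrete, here is the exclusive-scan case distilled from the bufferize.mlir test added below; the function and value names are simply the ones used in that test, and the "after" IR is transcribed from its CHECK lines rather than being a definitive dump. The tensor-level input is:

func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
  %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
    ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
    ^bb0(%arg0 : i32, %arg1 : i32):
      %sum = arith.addi %arg0, %arg1 : i32
      tm_tensor.yield %sum : i32
  } -> tensor<128xi32>, tensor<i32>
  return %ret_out, %ret_acc: tensor<128xi32>, tensor<i32>
}

and after -tm-tensor-bufferize the expected IR is roughly:

func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
  %in_memref = bufferization.to_memref %in : memref<128xi32>
  %acc_memref = bufferization.to_memref %acc : memref<i32>
  // The scan output is write-only in the loop body, so a fresh buffer suffices.
  %out_new = memref.alloc() : memref<128xi32>
  // The exclusive scan reads the accumulator's initial value, so its new buffer
  // is initialized with a copy of the original.
  %acc_new = memref.alloc() : memref<i32>
  memref.copy %acc_memref, %acc_new : memref<i32> to memref<i32>
  tm_tensor.scan dimension(0) inclusive(false)
    ins(%in_memref : memref<128xi32>) outs(%out_new, %acc_new : memref<128xi32>, memref<i32>) {
    ^bb0(%arg0 : i32, %arg1 : i32):
      %sum = arith.addi %arg0, %arg1 : i32
      tm_tensor.yield %sum : i32
  }
  %ret_out = bufferization.to_tensor %out_new : memref<128xi32>
  %ret_acc = bufferization.to_tensor %acc_new : memref<i32>
  return %ret_out, %ret_acc : tensor<128xi32>, tensor<i32>
}

In the inclusive variant the accumulator's init value is not read (ScanOp::payloadUsesValueFromOperand returns !inclusive() for it), so no memref.copy is emitted; that contrast is exactly what the two scan test cases below check.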
@@ -287,8 +287,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return true if the payload uses the value loaded from `opOperand`. This
-        is useful to avoid loading from "write-only" memory that may be
+        Return true if the loop body uses the value loaded from `opOperand`.
+        This is useful to avoid loading from "write-only" memory that may be
         uninitialized, as well as properly cloning "read-write" operands.
       }],
       /*retTy=*/"bool",
@@ -46,9 +46,11 @@ class TMTensor_Op<string mnemonic, list<Trait> traits = []> :
 // Non-structured ops
 //===----------------------------------------------------------------------===//
 
-def TMTensor_ScanOp : TMTensor_Op<"scan"
-    ,[DeclareOpInterfaceMethods<ScalarLoopOpInterface,
-              ["generateScalarImplementation"]>]> {
+def TMTensor_ScanOp : TMTensor_Op<"scan",
+    [DeclareOpInterfaceMethods<TMTensorInterface,
+        ["payloadUsesValueFromOperand"]>,
+     DeclareOpInterfaceMethods<ScalarLoopOpInterface,
+        ["generateScalarImplementation"]>]> {
   let summary = "Scan operator";
   let description = [{
     Computes the inclusive/exclusive scan along a given dimension.
@@ -97,7 +99,9 @@ def TMTensor_ScanOp : TMTensor_Op<"scan"
 }
 
 def TMTensor_ScatterOp : TMTensor_Op<"scatter",
-    [DeclareOpInterfaceMethods<ScalarLoopOpInterface,
+    [DeclareOpInterfaceMethods<TMTensorInterface,
+        ["payloadUsesValueFromOperand"]>,
+     DeclareOpInterfaceMethods<ScalarLoopOpInterface,
         ["generateScalarImplementation"]>]> {
   let summary = "Scatter operator";
   let description = [{
@@ -17,6 +17,7 @@ namespace torch {
 namespace TMTensor {
 
 std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();
+std::unique_ptr<OperationPass<FuncOp>> createTMTensorBufferizePass();
 
 void registerPasses();
 
@@ -13,9 +13,14 @@
 include "mlir/Pass/PassBase.td"
 
 def TMTensorToLoops :
-    Pass<"torch-mlir-tm-tensor-to-loops", "FuncOp"> {
+    Pass<"tm-tensor-to-loops", "FuncOp"> {
   let summary = "Convert TMTensor ops to loops and Linalg ops.";
   let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
 }
 
+def TMTensorBufferize : Pass<"tm-tensor-bufferize", "FuncOp"> {
+  let summary = "Bufferize the TMTensor dialect";
+  let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()";
+}
+
 #endif // TORCH_MLIR_DIALECT_TMTENSOR_PASSES
@@ -166,6 +166,19 @@ SmallVector<StringRef> ScanOp::getLoopIteratorTypes() {
   return iteratorTypes;
 }
 
+bool ScanOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
+  Value operand = opOperand->get();
+  if (operand == accumulator())
+    return !inclusive();
+  else if (operand == output())
+    return false;
+  else {
+    assert(operand == input() &&
+           "operand must belong to the current tm_tensor.scan op");
+    return true;
+  }
+}
+
 // Generates naive scalar implementation of scan for a given operator f.
 // For inclusive,
 //   output[0] = input[0]
@@ -385,6 +398,27 @@ SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
   return iteratorTypes;
 }
 
+bool ScatterOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
+  unsigned bbArgNumber;
+  Value operand = opOperand->get();
+  if (operand == updates())
+    bbArgNumber = 0; // block arg 0 is `update`.
+  else if (operand == original())
+    bbArgNumber = 1; // block arg 1 is `original`.
+  else {
+    assert(operand == indices() &&
+           "operand must belong to the current tm_tensor.scatter op");
+    return true;
+  }
+
+  assert(this->getOperation()->getNumRegions() == 1 &&
+         "unexpected "
+         "missing region (calling `payloadUsesValueFromOperand` on "
+         "manually defined named TMTensor op?)");
+  Block &block = this->getOperation()->getRegion(0).front();
+  return !block.getArgument(bbArgNumber).use_empty();
+}
+
 SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
   Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
@@ -0,0 +1,155 @@
+//===- Bufferize.cpp - Bufferization of tmtensor ops ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
+#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
+#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
+#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
+
+using namespace ::mlir;
+using namespace ::mlir::torch::TMTensor;
+
+static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
+  auto memrefType = memref.getType().cast<MemRefType>();
+  auto alloc = b.create<memref::AllocOp>(
+      loc, memrefType, linalg::getDynOperands(loc, memref, b));
+  b.create<memref::CopyOp>(loc, memref, alloc);
+  return alloc;
+}
+
+static LogicalResult
+allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp,
+                          ValueRange outputs,
+                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
+  // Lazily compute loopRanges.
+  SmallVector<Range, 4> loopRanges;
+
+  // Allocate a buffer for every tensor result.
+  assert(tmtensorOp.getNumOutputs() == tmtensorOp->getNumResults());
+  for (const auto &en : llvm::enumerate(tmtensorOp->getResultTypes())) {
+    size_t resultIndex = en.index();
+    Type resultType = en.value();
+
+    auto tensorType = resultType.dyn_cast<RankedTensorType>();
+    if (tensorType == nullptr) {
+      tmtensorOp.emitOpError()
+          << "tensor to buffer conversion expects ranked tensor results";
+      return failure();
+    }
+    auto tensorShape = tensorType.getShape();
+    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
+    Value resultTensor = outputs[resultIndex];
+
+    // Clone output buffers whose value is actually used.
+    OpOperand *tiedOpOperand = tmtensorOp.getOutputOperand(resultIndex);
+    if (tmtensorOp.payloadUsesValueFromOperand(tiedOpOperand)) {
+      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
+      continue;
+    }
+
+    // Allocate buffers for statically-shaped results.
+    if (memrefType.hasStaticShape()) {
+      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
+      continue;
+    }
+
+    resultBuffers.push_back(b.create<memref::AllocOp>(
+        loc, memrefType, linalg::getDynOperands(loc, resultTensor, b)));
+  }
+  return success();
+}
+
+/// Create TMTensor op on buffers given the original tensor-based operation and
+/// the buffers for the outputs.
+static TMTensorOp createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter,
+                                            TMTensorOp tmtensorOp,
+                                            ValueRange inputs,
+                                            ValueRange outputs) {
+  SmallVector<Value, 8> newOperands = inputs;
+  newOperands.append(outputs.begin(), outputs.end());
+  return tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands);
+}
+
+/// Generic conversion pattern that matches any TMTensorOp. This avoids template
+/// instantiating one pattern for each TMTensorOp.
+class BufferizeAnyTMTensorOp : public OpInterfaceConversionPattern<TMTensorOp> {
+public:
+  using OpInterfaceConversionPattern<TMTensorOp>::OpInterfaceConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TMTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    SmallVector<Value, 2> newOutputBuffers;
+
+    SmallVector<Value> outputs(operands.begin() + op.getNumInputs(),
+                               operands.end());
+    if (failed(allocateBuffersForResults(loc, op, outputs, newOutputBuffers,
+                                         rewriter))) {
+      return op.emitOpError()
+             << "Failed to allocate buffers for tensor results.";
+    }
+
+    SmallVector<Value> inputs(operands.begin(),
+                              operands.begin() + op.getNumInputs());
+    createTMTensorOpOnBuffers(rewriter, op, inputs, newOutputBuffers);
+    // Replace the results of the old op with the new output buffers.
+    rewriter.replaceOp(op, newOutputBuffers);
+    return success();
+  }
+};
+
+namespace {
+/// Converts TMTensor operations that work on tensor-type operands or results to
+/// work on buffers.
+struct TMTensorBufferizePass
+    : public TMTensorBufferizeBase<TMTensorBufferizePass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    torch::TMTensor::TMTensorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    ConversionTarget target(context);
+    bufferization::BufferizeTypeConverter typeConverter;
+
+    // Mark all Standard operations legal.
+    target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
+                           StandardOpsDialect, tensor::TensorDialect>();
+
+    // Mark all TMTensor operations illegal as long as they work on tensors.
+    auto isLegalOperation = [&](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalDialect<TMTensorDialect>(isLegalOperation);
+    RewritePatternSet patterns(&context);
+    patterns.add<BufferizeAnyTMTensorOp>(typeConverter, patterns.getContext());
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+torch::TMTensor::createTMTensorBufferizePass() {
+  return std::make_unique<TMTensorBufferizePass>();
+}
@@ -1,5 +1,6 @@
 add_mlir_library(TorchMLIRTMTensorPasses
   ConvertToLoops.cpp
+  Bufferize.cpp
   Passes.cpp
 
   DEPENDS
external/llvm-external-projects/torch-mlir-dialects/test/tmtensor/bufferize.mlir (new file, mode 100644, 114 lines)
@@ -0,0 +1,114 @@
+// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-bufferize %s | FileCheck %s
+
+// -----
+// CHECK-LABEL:   func @scan_1d_inclusive(
+// CHECK-SAME:        %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
+// CHECK-SAME:        %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
+// CHECK:           %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
+// CHECK:           %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
+// CHECK:           %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
+// CHECK:           tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>)
+// CHECK-SAME:        outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
+// CHECK:           ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
+// CHECK:             %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
+// CHECK:             tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
+// CHECK:           }
+// CHECK:           %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
+// CHECK:           %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
+// CHECK:           return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
+func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
+  %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
+    ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      tm_tensor.yield %sum : i32
+  } -> tensor<128xi32>, tensor<i32>
+  return %ret_out, %ret_acc: tensor<128xi32>, tensor<i32>
+}
+
+// -----
+// CHECK-LABEL:   func @scan_1d_exclusive(
+// CHECK-SAME:        %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
+// CHECK-SAME:        %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
+// CHECK:           %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
+// CHECK:           %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32>
+// CHECK:           %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
+// CHECK:           %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
+// CHECK:           memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32>
+// CHECK:           tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>)
+// CHECK-SAME:        outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
+// CHECK:           ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
+// CHECK:             %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
+// CHECK:             tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
+// CHECK:           }
+// CHECK:           %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
+// CHECK:           %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
+// CHECK:           return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
+func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
+  %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
+    ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      tm_tensor.yield %sum : i32
+  } -> tensor<128xi32>, tensor<i32>
+  return %ret_out, %ret_acc: tensor<128xi32>, tensor<i32>
+}
+
+// -----
+// CHECK-LABEL:   func @scatter_update_scalar_1D(
+// CHECK-SAME:        %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
+// CHECK-SAME:        %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
+// CHECK-SAME:        %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
+// CHECK:           %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
+// CHECK:           %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
+// CHECK:           %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
+// CHECK:           tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
+// CHECK-SAME:        : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
+// CHECK:           ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
+// CHECK:             tm_tensor.yield %[[UPDATE_SCALAR]] : i32
+// CHECK:           }
+// CHECK:           %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
+// CHECK:           return %[[OUT_TENSOR]] : tensor<8xi32>
+func @scatter_update_scalar_1D(
+    %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
+    %updates: tensor<3xi32>) -> tensor<8xi32> {
+  %0 = tm_tensor.scatter unique_indices(true)
+    ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
+    outs(%original : tensor<8xi32>) {
+  ^bb0(%update: i32, %orig: i32):  // no predecessors
+    tm_tensor.yield %update: i32
+  } -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+
+// CHECK-LABEL:   func @scatter_add_scalar_1D(
+// CHECK-SAME:        %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
+// CHECK-SAME:        %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
+// CHECK-SAME:        %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
+// CHECK:           %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
+// CHECK:           %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
+// CHECK:           %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
+// CHECK:           %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
+// CHECK:           memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
+// CHECK:           tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
+// CHECK-SAME:        : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
+// CHECK:           ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
+// CHECK:             %[[CST1:.*]] = arith.constant 1 : i32
+// CHECK:             %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32
+// CHECK:             tm_tensor.yield %[[ADD]] : i32
+// CHECK:           }
+// CHECK:           %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
+// CHECK:           return %[[OUT_TENSOR]] : tensor<8xi32>
func @scatter_add_scalar_1D(
+    %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
+    %updates: tensor<3xi32>) -> tensor<8xi32> {
+  %0 = tm_tensor.scatter unique_indices(true)
+    ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
+    outs(%original : tensor<8xi32>) {
+  ^bb0(%update: i32, %orig: i32):  // no predecessors
+    %cst1 = arith.constant 1: i32
+    %add = arith.addi %orig, %cst1: i32
+    tm_tensor.yield %add: i32
+  } -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-dialects-opt -split-input-file -torch-mlir-tm-tensor-to-loops %s | FileCheck %s
+// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-to-loops %s | FileCheck %s
 
 func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
   %c0 = memref.alloc() : memref<i32>
@@ -166,6 +166,7 @@ class RefBackendInvoker:
 LOWERING_PIPELINE = ",".join([
     # Bufferize.
     "builtin.func(scf-bufferize)",
+    "builtin.func(tm-tensor-bufferize)",
     "builtin.func(linalg-bufferize)",
     "builtin.func(refback-munge-memref-copy)",
     "func-bufferize",
@@ -183,6 +184,7 @@ LOWERING_PIPELINE = ",".join([
     # global seed used in stateful rng.
     "refback-insert-rng-globals",
    # Lower to LLVM
+    "builtin.func(tm-tensor-to-loops)",
    "builtin.func(convert-linalg-to-loops)",
    "builtin.func(lower-affine)",
    "convert-scf-to-cf",