From 486f95e84f587d020ba789b071b12f890510f1a1 Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Fri, 25 Feb 2022 18:04:33 -0500
Subject: [PATCH] Add bufferization pass for TMTensor ops

The pass is mostly borrowed from the BufferizeAnyLinalgOp pass in MLIR
upstream, with some minor changes. At a high level, it is a naive
partial bufferization pass that allocates a new buffer for every output
tensor. The initial value of an output buffer is copied from the
original buffer only if the op actually uses the original value.

One difference from the Linalg bufferization pass is how we determine
whether the loop body uses the initial value of an output operand. For
TMTensor ops this differs from op to op, because the payload region
does not represent the entire loop body.
---
 .../Dialect/TMTensor/IR/TMTensorInterfaces.td |   4 +-
 .../Dialect/TMTensor/IR/TMTensorOps.td        |  12 +-
 .../Dialect/TMTensor/Transforms/Passes.h      |   1 +
 .../Dialect/TMTensor/Transforms/Passes.td     |   7 +-
 .../lib/Dialect/TMTensor/IR/TMTensorOps.cpp   |  34 ++++
 .../Dialect/TMTensor/Transforms/Bufferize.cpp | 155 ++++++++++++++++++
 .../TMTensor/Transforms/CMakeLists.txt        |   1 +
 .../test/tmtensor/bufferize.mlir              | 114 +++++++++++++
 .../test/tmtensor/convert_to_loops.mlir       |   2 +-
 .../linalg_on_tensors_backends/refbackend.py  |   2 +
 10 files changed, 324 insertions(+), 8 deletions(-)
 create mode 100644 external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp
 create mode 100644 external/llvm-external-projects/torch-mlir-dialects/test/tmtensor/bufferize.mlir

diff --git a/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td b/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td
index beb25294d..7c4283549 100644
--- a/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td
+++ b/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td
@@ -287,8 +287,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return true if the payload uses the value loaded from `opOperand`. This
-        is useful to avoid loading from "write-only" memory that may be
+        Return true if the loop body uses the value loaded from `opOperand`.
+        This is useful to avoid loading from "write-only" memory that may be
         uninitialized, as well as properly cloning "read-write" operands.
       }],
       /*retTy=*/"bool",
diff --git a/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
index fd207500a..ae1319772 100644
--- a/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
+++ b/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
@@ -46,9 +46,11 @@ class TMTensor_Op<string mnemonic, list<OpTrait> traits = []> :
 // Non-structured ops
 //===----------------------------------------------------------------------===//
 
-def TMTensor_ScanOp : TMTensor_Op<"scan"
-    ,[DeclareOpInterfaceMethods<ScalarLoopOpInterface,
-        ["generateScalarImplementation"]>]> {
+def TMTensor_ScanOp : TMTensor_Op<"scan",
+    [DeclareOpInterfaceMethods<TMTensorInterface,
+        ["payloadUsesValueFromOperand"]>,
+     DeclareOpInterfaceMethods<ScalarLoopOpInterface,
+        ["generateScalarImplementation"]>]> {
   let summary = "Scan operator";
   let description = [{
     Computes the inclusive/exclusive scan along a given dimension.
@@ -97,7 +99,9 @@ def TMTensor_ScanOp : TMTensor_Op<"scan"
 }
 
 def TMTensor_ScatterOp : TMTensor_Op<"scatter",
-    [DeclareOpInterfaceMethods<ScalarLoopOpInterface,
-        ["generateScalarImplementation"]>]> {
+    [DeclareOpInterfaceMethods<TMTensorInterface,
+        ["payloadUsesValueFromOperand"]>,
+     DeclareOpInterfaceMethods<ScalarLoopOpInterface,
+        ["generateScalarImplementation"]>]> {
   let summary = "Scatter operator";
   let description = [{
diff --git a/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h b/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h
index 342e23490..f4eea1b67 100644
--- a/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h
+++ b/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h
@@ -17,6 +17,7 @@ namespace torch {
 namespace TMTensor {
 
 std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();
+std::unique_ptr<OperationPass<FuncOp>> createTMTensorBufferizePass();
 
 void registerPasses();
 
diff --git a/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td b/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td
index 6090a2241..b5250e4f4 100644
--- a/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td
+++ b/external/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td
@@ -13,9 +13,14 @@ include "mlir/Pass/PassBase.td"
 
 def TMTensorToLoops :
-    Pass<"torch-mlir-tm-tensor-to-loops", "FuncOp"> {
+    Pass<"tm-tensor-to-loops", "FuncOp"> {
   let summary = "Convert TMTensor ops to loops and Linalg ops.";
   let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
 }
 
+def TMTensorBufferize : Pass<"tm-tensor-bufferize", "FuncOp"> {
+  let summary = "Bufferize the TMTensor dialect";
+  let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()";
+}
+
 #endif // TORCH_MLIR_DIALECT_TMTENSOR_PASSES
diff --git a/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
index 5aa4560e6..1b461a9fa 100644
--- a/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
+++ b/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
@@ -166,6 +166,19 @@ SmallVector<StringRef> ScanOp::getLoopIteratorTypes() {
   return iteratorTypes;
 }
 
+bool ScanOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
+  Value operand = opOperand->get();
+  if (operand == accumulator())
+    return !inclusive();
+  else if (operand == output())
+    return false;
+  else {
+    assert(operand == input() &&
+           "operand must belong to the current tm_tensor.scan op");
+    return true;
+  }
+}
+
 // Generates naive scalar implementation of scan for a given operator f.
 // For inclusive,
 //   output[0] = input[0]
@@ -385,6 +398,27 @@ SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
   return iteratorTypes;
 }
 
+bool ScatterOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
+  unsigned bbArgNumber;
+  Value operand = opOperand->get();
+  if (operand == updates())
+    bbArgNumber = 0; // block arg 0 is `update`.
+  else if (operand == original())
+    bbArgNumber = 1; // block arg 1 is `original`.
+  else {
+    assert(operand == indices() &&
+           "operand must belong to the current tm_tensor.scatter op");
+    return true;
+  }
+
+  assert(this->getOperation()->getNumRegions() == 1 &&
+         "unexpected missing region (calling `payloadUsesValueFromOperand` "
+         "on a manually defined named TMTensor op?)");
+  Block &block = this->getOperation()->getRegion(0).front();
+  return !block.getArgument(bbArgNumber).use_empty();
+}
+
 SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
   Location loc = getLoc();
   Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
diff --git a/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp
new file mode 100644
index 000000000..98d74e4f1
--- /dev/null
+++ b/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp
@@ -0,0 +1,155 @@
+//===- Bufferize.cpp - Bufferization of TMTensor ops ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
+#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
+#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
+#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
+
+using namespace ::mlir;
+using namespace ::mlir::torch::TMTensor;
+
+// Allocate a buffer of the same shape as `memref` and copy its contents
+// into the new allocation.
+static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
+  auto memrefType = memref.getType().cast<MemRefType>();
+  auto alloc = b.create<memref::AllocOp>(
+      loc, memrefType, linalg::getDynOperands(loc, memref, b));
+  b.create<memref::CopyOp>(loc, memref, alloc);
+  return alloc;
+}
+
+static LogicalResult
+allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp,
+                          ValueRange outputs,
+                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
+  // Lazily compute loopRanges.
+  SmallVector<Range, 4> loopRanges;
+
+  // Allocate a buffer for every tensor result.
+  assert(tmtensorOp.getNumOutputs() == tmtensorOp->getNumResults());
+  for (const auto &en : llvm::enumerate(tmtensorOp->getResultTypes())) {
+    size_t resultIndex = en.index();
+    Type resultType = en.value();
+
+    auto tensorType = resultType.dyn_cast<RankedTensorType>();
+    if (tensorType == nullptr) {
+      tmtensorOp.emitOpError()
+          << "tensor to buffer conversion expects ranked tensor results";
+      return failure();
+    }
+    auto tensorShape = tensorType.getShape();
+    auto memrefType =
+        MemRefType::get(tensorShape, tensorType.getElementType());
+    Value resultTensor = outputs[resultIndex];
+
+    // Clone output buffers whose value is actually used.
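+    // For example, an exclusive tm_tensor.scan reads the accumulator's
+    // initial value, so its output buffer must start as a copy of the
+    // original, while a tm_tensor.scatter whose region never reads the
+    // `original` block argument can start from an uninitialized buffer.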
+    OpOperand *tiedOpOperand = tmtensorOp.getOutputOperand(resultIndex);
+    if (tmtensorOp.payloadUsesValueFromOperand(tiedOpOperand)) {
+      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
+      continue;
+    }
+
+    // Allocate buffers for statically-shaped results.
+    if (memrefType.hasStaticShape()) {
+      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
+      continue;
+    }
+
+    resultBuffers.push_back(b.create<memref::AllocOp>(
+        loc, memrefType, linalg::getDynOperands(loc, resultTensor, b)));
+  }
+  return success();
+}
+
+/// Create TMTensor op on buffers given the original tensor-based operation
+/// and the buffers for the outputs.
+static TMTensorOp
+createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter,
+                          TMTensorOp tmtensorOp, ValueRange inputs,
+                          ValueRange outputs) {
+  SmallVector<Value, 8> newOperands = inputs;
+  newOperands.append(outputs.begin(), outputs.end());
+  return tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands);
+}
+
+/// Generic conversion pattern that matches any TMTensorOp. This avoids
+/// template instantiating one pattern for each TMTensorOp.
+class BufferizeAnyTMTensorOp
+    : public OpInterfaceConversionPattern<TMTensorOp> {
+public:
+  using OpInterfaceConversionPattern<TMTensorOp>::OpInterfaceConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TMTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    SmallVector<Value, 2> newOutputBuffers;
+
+    SmallVector<Value> outputs(operands.begin() + op.getNumInputs(),
+                               operands.end());
+    if (failed(allocateBuffersForResults(loc, op, outputs, newOutputBuffers,
+                                         rewriter))) {
+      return op.emitOpError()
+             << "Failed to allocate buffers for tensor results.";
+    }
+
+    SmallVector<Value> inputs(operands.begin(),
+                              operands.begin() + op.getNumInputs());
+    createTMTensorOpOnBuffers(rewriter, op, inputs, newOutputBuffers);
+    // Replace the results of the old op with the new output buffers.
+    rewriter.replaceOp(op, newOutputBuffers);
+    return success();
+  }
+};
+
+namespace {
+/// Converts TMTensor operations that work on tensor-type operands or results
+/// to work on buffers.
+struct TMTensorBufferizePass
+    : public TMTensorBufferizeBase<TMTensorBufferizePass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
+                    memref::MemRefDialect, tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    ConversionTarget target(context);
+    bufferization::BufferizeTypeConverter typeConverter;
+
+    // Mark all Standard operations legal.
+    target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect,
+                           memref::MemRefDialect, tensor::TensorDialect>();
+
+    // Mark all TMTensor operations illegal as long as they work on tensors.
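+    // An op becomes legal once the type converter sees no tensor operands
+    // or results, i.e. once it has been rewritten to operate on memrefs, so
+    // the partial conversion leaves already-bufferized ops alone.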
+    auto isLegalOperation = [&](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalDialect<TMTensorDialect>(isLegalOperation);
+    RewritePatternSet patterns(&context);
+    patterns.add<BufferizeAnyTMTensorOp>(typeConverter, patterns.getContext());
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+torch::TMTensor::createTMTensorBufferizePass() {
+  return std::make_unique<TMTensorBufferizePass>();
+}
diff --git a/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/CMakeLists.txt b/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/CMakeLists.txt
index f3ab958a9..34f49ef8e 100644
--- a/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/CMakeLists.txt
+++ b/external/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(TorchMLIRTMTensorPasses
   ConvertToLoops.cpp
+  Bufferize.cpp
   Passes.cpp
 
   DEPENDS
diff --git a/external/llvm-external-projects/torch-mlir-dialects/test/tmtensor/bufferize.mlir b/external/llvm-external-projects/torch-mlir-dialects/test/tmtensor/bufferize.mlir
new file mode 100644
index 000000000..b1f589428
--- /dev/null
+++ b/external/llvm-external-projects/torch-mlir-dialects/test/tmtensor/bufferize.mlir
@@ -0,0 +1,114 @@
+// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-bufferize %s | FileCheck %s
+
+// -----
+// CHECK-LABEL: func @scan_1d_inclusive(
+// CHECK-SAME:     %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
+// CHECK-SAME:     %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
+// CHECK:        %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
+// CHECK:        %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
+// CHECK:        %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
+// CHECK:        tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>)
+// CHECK-SAME:     outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
+// CHECK:        ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
+// CHECK:          %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
+// CHECK:          tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
+// CHECK:        }
+// CHECK:        %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
+// CHECK:        %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
+// CHECK:        return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
+func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
+  %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
+    ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      tm_tensor.yield %sum : i32
+  } -> tensor<128xi32>, tensor<i32>
+  return %ret_out, %ret_acc: tensor<128xi32>, tensor<i32>
+}
+
+// -----
+// CHECK-LABEL: func @scan_1d_exclusive(
+// CHECK-SAME:     %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
+// CHECK-SAME:     %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
+// CHECK:        %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
+// CHECK:        %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref<i32>
+// CHECK:        %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32>
+// CHECK:        %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref<i32>
+// CHECK:        memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref<i32> to memref<i32>
+// CHECK:        tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>)
+// CHECK-SAME:     outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref<i32>) {
+// CHECK:        ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32):
+// CHECK:          %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32
+// CHECK:          tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32
+// CHECK:        }
+// CHECK:        %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
+// CHECK:        %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
+// CHECK:        return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
+func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
+  %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
+    ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
+    ^bb0(%arg0 : i32, %arg1 : i32):
+      %sum = arith.addi %arg0, %arg1 : i32
+      tm_tensor.yield %sum : i32
+  } -> tensor<128xi32>, tensor<i32>
+  return %ret_out, %ret_acc: tensor<128xi32>, tensor<i32>
+}
+
+// -----
+// CHECK-LABEL: func @scatter_update_scalar_1D(
+// CHECK-SAME:     %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
+// CHECK-SAME:     %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
+// CHECK-SAME:     %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
+// CHECK:        %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
+// CHECK:        %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
+// CHECK:        %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
+// CHECK:        tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
+// CHECK-SAME:     : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
+// CHECK:        ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
+// CHECK:          tm_tensor.yield %[[UPDATE_SCALAR]] : i32
+// CHECK:        }
+// CHECK:        %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
+// CHECK:        return %[[OUT_TENSOR]] : tensor<8xi32>
+func @scatter_update_scalar_1D(
+    %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
+    %updates: tensor<3xi32>) -> tensor<8xi32> {
+  %0 = tm_tensor.scatter unique_indices(true)
+      ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
+      outs(%original : tensor<8xi32>) {
+    ^bb0(%update: i32, %orig: i32):  // no predecessors
+      tm_tensor.yield %update: i32
+  } -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+
+// CHECK-LABEL: func @scatter_add_scalar_1D(
+// CHECK-SAME:     %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
+// CHECK-SAME:     %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
+// CHECK-SAME:     %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
+// CHECK:        %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
+// CHECK:        %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
+// CHECK:        %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
+// CHECK:        %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
+// CHECK:        memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
+// CHECK:        tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
+// CHECK-SAME:     : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
+// CHECK:        ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
+// CHECK:          %[[CST1:.*]] = arith.constant 1 : i32
+// CHECK:          %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32
+// CHECK:          tm_tensor.yield %[[ADD]] : i32
+// CHECK:        }
+// CHECK:        %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
+// CHECK:        return %[[OUT_TENSOR]] : tensor<8xi32>
+func @scatter_add_scalar_1D(
+    %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
+    %updates: tensor<3xi32>) -> tensor<8xi32> {
+  %0 = tm_tensor.scatter unique_indices(true)
+      ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
+      outs(%original : tensor<8xi32>) {
+    ^bb0(%update: i32, %orig: i32):  // no predecessors
+      %cst1 = arith.constant 1: i32
+      %add = arith.addi %orig, %cst1: i32
+      tm_tensor.yield %add: i32
+  } -> tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
diff --git a/external/llvm-external-projects/torch-mlir-dialects/test/tmtensor/convert_to_loops.mlir b/external/llvm-external-projects/torch-mlir-dialects/test/tmtensor/convert_to_loops.mlir
index bb642464e..b4b9d2638 100644
--- a/external/llvm-external-projects/torch-mlir-dialects/test/tmtensor/convert_to_loops.mlir
+++ b/external/llvm-external-projects/torch-mlir-dialects/test/tmtensor/convert_to_loops.mlir
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-dialects-opt -split-input-file -torch-mlir-tm-tensor-to-loops %s | FileCheck %s
+// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-to-loops %s | FileCheck %s
 
 func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
   %c0 = memref.alloc() : memref<i32>
diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
index a5826132b..8bb4c4985 100644
--- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
+++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -166,6 +166,7 @@ class RefBackendInvoker:
 LOWERING_PIPELINE = ",".join([
     # Bufferize.
     "builtin.func(scf-bufferize)",
+    "builtin.func(tm-tensor-bufferize)",
     "builtin.func(linalg-bufferize)",
     "builtin.func(refback-munge-memref-copy)",
     "func-bufferize",
@@ -183,6 +184,7 @@ LOWERING_PIPELINE = ",".join([
     # global seed used in stateful rng.
     "refback-insert-rng-globals",
     # Lower to LLVM
+    "builtin.func(tm-tensor-to-loops)",
    "builtin.func(convert-linalg-to-loops)",
     "builtin.func(lower-affine)",
     "convert-scf-to-cf",
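
Note for reviewers: the Python LOWERING_PIPELINE above is the only in-tree
user of the new pass. To exercise it from C++ instead, a minimal sketch
follows; it assumes a pre-existing MLIR `context` and `module` (both
hypothetical names), while createTMTensorBufferizePass() and the renamed
tm-tensor-to-loops pass are the entry points touched by this patch.

    mlir::PassManager pm(&context);
    // Bufferize TMTensor ops, then lower them to loops, mirroring the
    // refbackend pipeline entries added above.
    pm.addNestedPass<mlir::FuncOp>(
        mlir::torch::TMTensor::createTMTensorBufferizePass());
    pm.addNestedPass<mlir::FuncOp>(
        mlir::torch::TMTensor::createTMTensorToLoopsPass());
    if (mlir::failed(pm.run(module)))
      llvm::errs() << "TMTensor lowering failed\n";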