//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SMLoc.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TMTensor;

//===----------------------------------------------------------------------===//
// Utils.
//===----------------------------------------------------------------------===//
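// Shared helper for the DEFINE_OP_GET_EFFECTS macro at the end of this file:
// op results are reported as allocations, input buffers as reads, and output
// buffers as both reads and writes.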
static void getEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ResultRange results, ArrayRef<OpOperand *> inputBuffers,
    ArrayRef<OpOperand *> outputBuffers) {
  for (OpResult value : results) {
    effects.emplace_back(MemoryEffects::Allocate::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (OpOperand *value : inputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (OpOperand *value : outputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), value,
                         SideEffects::DefaultResource::get());
  }
}
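// Returns the size of dimension `dim` of `v` as a Value, handling both ranked
// tensor and memref operands.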
Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v,
                            int64_t dim) {
  return TypeSwitch<Type, Value>(v.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
        return builder.create<tensor::DimOp>(loc, v, dim);
      })
      .Case<MemRefType>([&](MemRefType t) -> Value {
        return builder.create<memref::DimOp>(loc, v, dim);
      })
      .Default([&](Type t) { return Value(); });
}
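// Returns the size of dimension `dim` of `v`: a static size is returned as an
// integer attribute, a dynamic one as a Value computed by getDimValue.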
OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
                              int64_t dim) {
  auto t = cast<ShapedType>(v.getType());
  if (t.isDynamicDim(dim)) {
    return getDimValue(builder, loc, v, dim);
  }
  return builder.getI64IntegerAttr(t.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//

LogicalResult AttentionOp::verify() {
  Operation *op = getOperation();
  ShapedType queryType = getQueryType();
  ShapedType keyType = getKeyType();
  ArrayRef<int64_t> queryShape = queryType.getShape();
  ArrayRef<int64_t> keyShape = keyType.getShape();
  for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
    if (keyShape[i] != queryShape[i])
      return op->emitOpError("query and key batch mismatch");
  }
  if (keyShape.back() != queryShape.back())
    return op->emitOpError("query and key head dimension mismatch");
  return success();
}

SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
  SmallVector<Range> loopBounds;
  return loopBounds;
}

SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes;
  return iteratorTypes;
}

bool AttentionOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  Value operand = opOperand->get();
  return operand == getQuery() || operand == getKey() || operand == getValue();
}
// Performs a matmul between lhs and rhs
// Note that "transposed" means the last two dims of rhs are swapped
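// In index form (leading dims are the shared batch dims):
//   transposed == false: output[..., i, j] = sum_k lhs[..., i, k] * rhs[..., k, j]
//   transposed == true:  output[..., i, j] = sum_k lhs[..., i, k] * rhs[..., j, k]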
static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
                   Value rhs, ValueRange rhsSizes, Value output,
                   ValueRange outputSizes, bool transposed = false) {
  auto elementType = cast<MemRefType>(lhs.getType()).getElementType();
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  auto rank = outputSizes.size();
  Value reductionDimSize = lhsSizes[lhsSizes.size() - 1];

  // Loop over output
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(rank, zero), outputSizes,
      SmallVector<Value>(rank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value acc = b.create<arith::ConstantOp>(
            loc, elementType, b.getFloatAttr(elementType, 0.0));
        Value sum =
            b.create<scf::ForOp>(
                 loc, zero, reductionDimSize, one, SmallVector<Value>{acc},
                 [&](OpBuilder &b, Location loc, Value i, ValueRange accs) {
                   SmallVector<Value> lhsIVs(localIVs), rhsIVs(localIVs);
                   lhsIVs[lhsIVs.size() - 1] = i;
                   rhsIVs[rhsIVs.size() - 2] = i;
                   if (transposed)
                     std::swap(rhsIVs[rhsIVs.size() - 1],
                               rhsIVs[rhsIVs.size() - 2]);

                   Value acc = accs[0];
                   Value rElem = b.create<memref::LoadOp>(loc, lhs, lhsIVs);
                   Value cElem = b.create<memref::LoadOp>(loc, rhs, rhsIVs);
                   Value x = b.create<arith::MulFOp>(loc, rElem, cElem);
                   x = b.create<arith::AddFOp>(loc, x, acc);

                   b.create<scf::YieldOp>(loc, x);
                 })
                ->getResult(0);
        b.create<memref::StoreOp>(loc, sum, output, localIVs);
      });
}
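// Scalar reference implementation of attention on memrefs: computes
// weight = softmax((query @ key^T) / sqrt(head_dim)) and then
// output = weight @ value, using explicit scf loops.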
LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
                                                         Location loc,
                                                         ValueRange ivs) {

  Value query = getQuery();
  Value key = getKey();
  Value value = getValue();
  Value output = getOutput();
  auto queryType = cast<MemRefType>(query.getType());
  auto keyType = cast<MemRefType>(key.getType());
  auto valueType = cast<MemRefType>(value.getType());
  auto queryRank = queryType.getRank();
  auto keyRank = keyType.getRank();
  auto valueRank = valueType.getRank();
  auto keySizes = keyType.getShape();
  Type elementType = queryType.getElementType();

  Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
                                            b.getFloatAttr(elementType, 0.0));

  // TODO: This needs to be fixed, it assumes everything is dynamic however if
  // any shapes are static the `memref.alloc` generated is illegal.
  SmallVector<Value> queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes;
  for (auto i = 0; i < queryRank; i++)
    queryDynSizes.push_back(b.create<memref::DimOp>(loc, query, i));
  for (auto i = 0; i < keyRank; i++)
    keyDynSizes.push_back(b.create<memref::DimOp>(loc, key, i));
  for (auto i = 0; i < valueRank; i++)
    valueDynSizes.push_back(b.create<memref::DimOp>(loc, value, i));
  for (auto i = 0; i < queryRank; i++)
    outputDynSizes.push_back(b.create<memref::DimOp>(loc, output, i));

  // weight = query @ key
  auto weightRank = queryRank;
  auto weightSizes = SmallVector<int64_t>(queryType.getShape());
  weightSizes[weightRank - 1] = keySizes[keyRank - 2];
  auto weightType = MemRefType::get(weightSizes, queryType.getElementType());

  // Setup the weight dynamic sizes:
  SmallVector<Value> weightDynSizes(queryDynSizes);
  weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2];

  SmallVector<Value> weightFilteredDynSizes;
  for (int i = 0; i < weightRank; ++i)
    if (weightSizes[i] == ShapedType::kDynamic)
      weightFilteredDynSizes.push_back(weightDynSizes[i]);

  Value weight =
      b.create<memref::AllocOp>(loc, weightType, weightFilteredDynSizes);
  matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes,
         /*transposed=*/true);

  // weight = softmax(weight)
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value dim = weightDynSizes[weightRank - 1];
  Value scaleFactor = b.create<math::SqrtOp>(
      loc, b.create<arith::UIToFPOp>(
               loc, elementType,
               b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
                                              queryDynSizes[queryRank - 1])));
  // calculate max(weight)
  Value init = b.create<memref::LoadOp>(loc, weight,
                                        SmallVector<Value>(weightRank, zero));
  Value globalMax =
      b.create<scf::ParallelOp>(
           loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
           SmallVector<Value>(weightRank, one), init,
           [&](OpBuilder &b, Location loc, ValueRange localIVs,
               ValueRange accs) {
             auto reduceOp = b.create<scf::ReduceOp>(loc, init);
             // Build reduce body.
             Block &reductionBody = reduceOp.getReductions()[0].front();
             auto bodyBuilder = OpBuilder::atBlockEnd(&reductionBody);
             Value acc = reductionBody.getArgument(0);
             Value x =
                 bodyBuilder.create<memref::LoadOp>(loc, weight, localIVs);
             Value max = bodyBuilder.create<arith::MaximumFOp>(loc, x, acc);
             bodyBuilder.create<scf::ReduceReturnOp>(loc, max);
           })
          .getResult(0);
  // weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
      SmallVector<Value>(weightRank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<arith::SubFOp>(loc, x, globalMax);
        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
      });
  // calculate exp(weight)
  SmallVector<Value> min(weightRank, zero),
      max(weightDynSizes.begin(), weightDynSizes.end()), steps(weightRank, one);
  b.create<scf::ParallelOp>(
      loc, min, max, steps,
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<math::ExpOp>(loc, x);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
      });

  llvm::SmallVector<Value> expWeightDynDims(weightFilteredDynSizes);
  if (weightSizes.back() == ShapedType::kDynamic)
    expWeightDynDims.resize(expWeightDynDims.size() - 1);

  Value expWeightSum = b.create<memref::AllocOp>(
      loc,
      MemRefType::get(
          SmallVector<int64_t>(weightSizes.begin(), weightSizes.end() - 1),
          elementType),
      expWeightDynDims);
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank - 1, zero),
      SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1},
      SmallVector<Value>(weightRank - 1, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        b.create<memref::StoreOp>(loc, zeroF, expWeightSum, localIVs);
      });
  // Loop over all dims but -1
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank - 1, zero),
      SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end() - 1),
      SmallVector<Value>(weightRank - 1, one),
      [&](OpBuilder &b, Location loc, ValueRange outsideDims) {
        // Sum over last dim
        b.create<scf::ParallelOp>(
            loc, zero, dim, one,
            [&](OpBuilder &b, Location loc, ValueRange localIVs) {
              SmallVector<Value> coords(outsideDims);
              coords.push_back(localIVs[0]);
              Value x =
                  b.create<memref::LoadOp>(loc, expWeightSum, outsideDims);
              Value y = b.create<memref::LoadOp>(loc, weight, coords);
              Value sum = b.create<arith::AddFOp>(loc, x, y);
              b.create<memref::StoreOp>(loc, sum, expWeightSum, outsideDims);
            });
      });
  // calculate exp(weight) / sum(exp(weight))
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank, zero),
      SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end()),
      SmallVector<Value>(weightRank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        SmallVector<Value> sumIVs(localIVs);
        sumIVs.pop_back();
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
        x = b.create<arith::DivFOp>(loc, x, sum);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
      });

  // output = weight @ value
  matmul(b, loc, weight, weightDynSizes, value, valueDynSizes, output,
         outputDynSizes, /*transposed=*/false);

  return success();
}
//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

LogicalResult ScanOp::verify() {
  if (getNumInputs() != 1) {
    return emitOpError("expected one input operand");
  }
  if (getNumOutputs() != 2) {
    return emitOpError("expected two output operands");
  }
  if (!isa<ShapedType>(input().getType())) {
    return emitOpError("expected first input element type to be shaped");
  }
  auto accumulatorType = cast<ShapedType>(accumulator().getType());
  auto inputType = cast<ShapedType>(input().getType());
  auto outputType = cast<ShapedType>(output().getType());
  ArrayRef<int64_t> inputShapes = inputType.getShape();
  ArrayRef<int64_t> outputShapes = outputType.getShape();
  if (accumulatorType.getElementType() != inputType.getElementType()) {
    return emitOpError(
        "expected input/accumulator element types to be identical");
  }
  ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
  int64_t accumulatorRank = accumulatorType.getRank();
  if (accumulatorRank != inputType.getRank() - 1) {
    return emitOpError(
        "expected accumulator rank to be equal to input rank - 1");
  }
  SmallVector<int64_t> expectedAccumulatorShape;
  for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
    if (i != getDimension())
      expectedAccumulatorShape.push_back(inputShapes[i]);
  }
  if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != ShapedType::kDynamic &&
                            std::get<1>(s) != ShapedType::kDynamic &&
                            std::get<0>(s) != std::get<1>(s);
                   })) {
    return emitOpError("incompatible input/accumulator shapes");
  }
  if (inputType.getElementType() != outputType.getElementType()) {
    return emitOpError("expected input/output element types to be identical");
  }
  if (inputShapes.size() != outputShapes.size()) {
    return emitOpError("expected input/output to have identical ranks");
  }
  if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != ShapedType::kDynamic &&
                            std::get<1>(s) != ShapedType::kDynamic &&
                            std::get<0>(s) != std::get<1>(s);
                   })) {
    return emitOpError("incompatible input/output shapes");
  }
  return success();
}

SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = input();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType> ScanOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}

bool ScanOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  Value operand = opOperand->get();
  if (operand == accumulator())
    return !getInclusive();
  else if (operand == output())
    return false;
  else {
    assert(operand == input() &&
           "operand must belong to the current tm_tensor.scan op");
    return true;
  }
}
// Generates naive scalar implementation of scan for a given operator f.
// For inclusive,
//     output[0] = input[0]
//     output[i] = f(output[i-1], input[i])
//
// For exclusive,
//     output[0] = 0
//     output[i] = f(output[i-1], input[i-1])
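//
// For example, an inclusive scan with f = addition maps [1, 2, 3] to
// [1, 3, 6].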
LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  SmallVector<Value> indices, scanBlkArgs;
  indices.append(ivs.begin(), ivs.end());
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  uint64_t scanDim = getDimension();
  Value cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                       indices[scanDim], zero);
  bool isInclusive = getInclusive();
  SmallVector<Value> accIndices;
  for (size_t i = 0; i < indices.size(); i++) {
    if (i != scanDim)
      accIndices.push_back(indices[i]);
  }

  auto scfIf = b.create<scf::IfOp>(
      loc, cond,
      [&](OpBuilder &b, Location loc) {
        if (isInclusive) {
          auto value = b.create<memref::LoadOp>(loc, input(), indices);
          b.create<memref::StoreOp>(loc, value, output(), indices);
        } else {
          auto value = b.create<memref::LoadOp>(loc, accumulator(), accIndices);
          b.create<memref::StoreOp>(loc, value, output(), indices);
        }
        b.create<scf::YieldOp>(loc);
      },
      [&](OpBuilder &b, Location loc) {
        SmallVector<Value> indices(ivs.begin(), ivs.end());
        Value iv = indices[scanDim];
        Value ivMinusOne = b.create<arith::SubIOp>(loc, iv, one);
        indices[scanDim] = ivMinusOne;
        scanBlkArgs.push_back(b.create<memref::LoadOp>(loc, output(), indices));
        Value i0;
        if (!isInclusive)
          i0 = b.create<memref::LoadOp>(loc, input(), indices);
        indices[scanDim] = iv;
        if (isInclusive)
          i0 = b.create<memref::LoadOp>(loc, input(), indices);
        scanBlkArgs.push_back(i0);
      });

  auto &srcBlock = getRegion().front();
  Region &thisRegion = scfIf.getElseRegion();
  IRMapping bvm;
  {
    OpBuilder::InsertionGuard guard(b);
    auto &block = thisRegion.front();
    b.setInsertionPointToEnd(&block);
    for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
      bvm.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvm);
    }
    b.create<memref::StoreOp>(
        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
        output(), indices);
    b.create<memref::StoreOp>(
        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
        accumulator(), accIndices);
    b.create<scf::YieldOp>(loc);
  }
  return success();
}
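// Folds memref.cast producers into the operands of `op` when the cast is
// compatible with the consumer; used by ScanOp::fold below.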
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

LogicalResult ScanOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
static Type getComplexElementTypeOrSelf(Type ty) {
  if (auto complex = dyn_cast_or_null<ComplexType>(ty))
    return complex.getElementType();
  return ty;
}

static bool isInvalid(ArrayRef<int64_t> dimsPos, int64_t rank) {
  // early exit.
  if (static_cast<int64_t>(dimsPos.size()) > rank)
    return true;
  DenseSet<int64_t> uniqued;
  for (int64_t dim : dimsPos)
    uniqued.insert(dim);
  if (static_cast<int64_t>(dimsPos.size()) != uniqued.size())
    return true;
  return llvm::any_of(
      dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; });
}

LogicalResult ScatterOp::verify() {
  Operation *op = getOperation();
  if (getInputs().size() != 2) {
    return op->emitOpError("expected two input operands");
  }
  if (getOutputs().size() != 1) {
    return op->emitOpError("expected one output operand");
  }
  auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
    return t1.getShape()[dim] == t2.getShape()[dim];
  };

  auto indicesType = getIndicesType();
  if (indicesType.getRank() != 2 ||
      !indicesType.getElementType().isInteger(32)) {
    return emitOpError("expected indices to be of rank 2 of i32 element type");
  }
  auto indexDepth = getIndexDepth();
  if (ShapedType::isDynamic(indexDepth)) {
    return emitOpError("expected index depth to be static");
  }

  ArrayRef<int64_t> dimMap = getDimensionMap();
  if (static_cast<int64_t>(dimMap.size()) != indexDepth) {
    return op->emitOpError("invalid number of dimension map entries ");
  }

  auto originalType = getOriginalType();
  if (isInvalid(dimMap, originalType.getRank()))
    return op->emitOpError("dimension map is invalid");

  // The first dimension of the indices should match the first dimension of the
  // output. They indicate the number of updates.
  auto updateType = getUpdateType();
  if (updateType.getRank() < 1) {
    return emitOpError("expected update value to be at least rank 1");
  }
  if (!checkDimensionsMatch(indicesType, updateType, 0)) {
    return emitOpError(
        "mismatch in shape of indices and update value at dim#0");
  }
  if (updateType.getRank() - 1 > originalType.getRank()) {
    return emitOpError(
        "update value rank exceeds the rank of the original value");
  }

  // indexDepth + update dims should cover the original dims. The first dim of
  // update is the number of updates.
  if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
    return emitOpError(
        "index depth and update value does not cover rank of original value");
  }

  // Validate the non-indexed update dims cover the full slice size of the
  // original tensor.
  int64_t fullSliceDims = originalType.getRank() - indexDepth;
  for (auto it :
       llvm::zip(llvm::seq<unsigned>(indexDepth, originalType.getRank()),
                 llvm::seq<unsigned>(updateType.getRank() - fullSliceDims,
                                     updateType.getRank()))) {
    int64_t originalDim = std::get<0>(it);
    int64_t updateDim = std::get<1>(it);
    if (!originalType.isDynamicDim(originalDim) &&
        updateType.getDimSize(updateDim) >
            originalType.getDimSize(originalDim)) {
      return op->emitOpError("shape of update value dim#")
             << updateDim << " exceeds original value at dim#" << originalDim;
    }
  }

  // Check that the remaining update indices do not exceed the update length.
  int64_t insertDims = originalType.getRank() - updateType.getRank() + 1;
  for (auto it : llvm::zip(
           llvm::seq<unsigned>(insertDims, indexDepth),
           llvm::seq<unsigned>(1, updateType.getRank() - fullSliceDims))) {
    int64_t originalDim = std::get<0>(it);
    int64_t updateDim = std::get<1>(it);
    if (!originalType.isDynamicDim(originalDim) &&
        updateType.getDimSize(updateDim) >
            originalType.getDimSize(originalDim)) {
      return op->emitOpError("indexed shape of update value dim#")
             << updateDim << " exceeds original value at dim#" << originalDim
             << " " << updateType.getDimSize(updateDim) << " "
             << originalType.getDimSize(originalDim);
    }
  }

  Region &region = this->getRegion();
  Block *body = &region.front();
  if (body->getNumArguments() != 2) {
    return op->emitOpError("expected region to have two arguments");
  }
  Type arg0Type = body->getArgument(0).getType();
  Type arg1Type = body->getArgument(1).getType();
  if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() ||
      !getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) {
    return emitOpError(
        "expected region to have scalar argument of integer or float types");
  }
  if (arg0Type != updateType.getElementType()) {
    return emitOpError("mismatch in argument 0 of region ")
           << arg0Type << " and element type of update value "
           << updateType.getElementType();
  }
  if (arg1Type != originalType.getElementType()) {
    return emitOpError("mismatch in argument 1 of region ")
           << arg1Type << " and element type of original value "
           << originalType.getElementType();
  }
  if (arg0Type != arg1Type) {
    return emitOpError("mismatch in region argument types ")
           << arg0Type << " and " << arg1Type;
  }
  auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator());
  if (yieldOp->getNumOperands() != 1) {
    return yieldOp.emitOpError("expected region to yield a single value");
  }
  auto yieldedType = yieldOp->getOperand(0).getType();
  if (yieldedType != arg0Type) {
    return yieldOp.emitOpError("mismatch in type of yielded value ")
           << yieldedType << " and argument of the region " << arg0Type;
  }
  return success();
}

SmallVector<utils::IteratorType> ScatterOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getUpdateType().getRank(),
                                                 utils::IteratorType::parallel);
  if (!getUniqueIndices()) {
    iteratorTypes[0] = utils::IteratorType::reduction;
  }
  return iteratorTypes;
}

bool ScatterOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  unsigned bbArgNumber;
  Value operand = opOperand->get();
  if (operand == updates())
    bbArgNumber = 0; // block arg 0 is `update`.
  else {
    bool isValidOperand = operand == indices() || operand == original();
    (void)isValidOperand;
    assert(isValidOperand &&
           "operand must belong to the current tm_tensor.scatter op");
    return true;
  }

  assert(this->getOperation()->getNumRegions() == 1 &&
         "unexpected "
         "missing region (calling `payloadUsesValueFromOperand` on "
         "manually defined named TMTensor op?)");
  Block &block = this->getOperation()->getRegion(0).front();
  return !block.getArgument(bbArgNumber).use_empty();
}

SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  SmallVector<Range> ranges;
  for (auto dim : llvm::seq<int64_t>(0, getUpdateType().getRank())) {
    Value ub = getDimValue(builder, loc, updates(), dim);
    ranges.emplace_back(Range{zero, ub, one});
  }
  return ranges;
}
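// Scalar lowering of scatter for a single update element: the leading
// induction variable selects the update slice, the loaded indices (remapped
// through the dimension map) select the target position in `original`, and
// the op's region combines the update value with the current value before it
// is stored back.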
LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
                                                      Location loc,
                                                      ValueRange ivs) {
  auto indexDepth = getIndexDepth();
  Value update = b.create<memref::LoadOp>(loc, updates(), ivs);
  SmallVector<Value> starts;
  SmallVector<Value> loadIndices;
  loadIndices.push_back(ivs.front());
  loadIndices.push_back(Value());

  // Populate with empty values.
  auto originalTy = cast<ShapedType>(original().getType());
  starts.resize(originalTy.getRank(), Value());
  auto updateIvs = ivs.drop_front(1);

  int64_t offset = starts.size() - updateIvs.size();
  for (auto it : llvm::enumerate(updateIvs)) {
    starts[it.index() + offset] = it.value();
  }

  ArrayRef<int64_t> dimMap = getDimensionMap();
  for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
    loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
    Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
    Value ret = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);

    auto dim = dimMap[i];
    if (starts[dim])
      ret = b.create<arith::AddIOp>(loc, ret, starts[dim]);
    starts[dim] = ret;
  }

  Value init = b.create<memref::LoadOp>(loc, original(), starts);

  IRMapping bvm;
  Block &block = getRegion().front();
  bvm.map(block.getArgument(0), update);
  bvm.map(block.getArgument(1), init);
  for (auto &blockOp : block.without_terminator()) {
    b.clone(blockOp, bvm);
  }
  // The last op is the linalg_ext.yield op. Store its operand to the
  // destination.
  b.create<memref::StoreOp>(
      loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
      original(), starts);
  return success();
}

//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
LogicalResult SortOp::verify() {
  Operation *op = getOperation();
  if (getNumInputs()) {
    return op->emitOpError("does not expect to take any inputs");
  }
  if (getNumOutputs() == 0) {
    return op->emitOpError("expected at least one `outs` operand");
  }

  Block &block = getRegion().front();
  size_t numOutputs = getNumOutputs();
  if (block.getNumArguments() != 2 * numOutputs) {
    return op->emitOpError("region block should have ")
           << 2 * numOutputs << " arguments";
  }

  int64_t rank = getOperandRank();
  int sortDim = getDimension();
  if (sortDim < 0 || sortDim >= rank) {
    return op->emitOpError("dimension must be within [0, ") << rank << ")";
  }

  ArrayRef<int64_t> shape = getOperandShape();
  for (auto indexedOperand : llvm::enumerate(getOutputs())) {
    int index = indexedOperand.index();
    auto operandType = getOperandType(index);
    if (operandType.getRank() != rank) {
      return op->emitOpError("expected operand ")
             << index << " to be rank " << rank << ", same as other operands";
    }
    if (operandType.getShape() != shape) {
      return op->emitOpError("expected operand ")
             << index << " to have same shape as other operands";
    }
    Type elemType = operandType.getElementType();
    for (int i : {2 * index, 2 * index + 1}) {
      Type argType = block.getArgument(i).getType();
      if (argType != elemType) {
        return op->emitOpError("region block argument #")
               << i << " should be of type " << elemType << " but got "
               << argType;
      }
    }
  }

  auto yieldOp = cast<YieldOp>(block.getTerminator());
  if (yieldOp.getNumOperands() != 1) {
    return op->emitOpError("should yield exactly one operand");
  }
  auto ty = dyn_cast<IntegerType>(yieldOp.getOperand(0).getType());
  if (!ty || ty.getWidth() != 1) {
    return op->emitOpError("should yield i1 type");
  }

  return success();
}

SmallVector<utils::IteratorType> SortOp::getLoopIteratorTypes() {
  // All loops except the dimension to sort along are parallel.
  SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}

SmallVector<Range> SortOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = operand(0);
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}
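// For each point of the iteration domain this performs one bubble-sort pass
// along the sort dimension: adjacent element pairs are compared with the op's
// region and swapped when out of order. Repeated across the reduction
// dimension by the surrounding loop nest, this sorts all outputs in lockstep.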
LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  auto sortDim = getDimension();
  SmallVector<Value> indices, sortBlkArgs;
  indices.append(ivs.begin(), ivs.end());
  // Bubble sort innermost loop.
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value ub;
  if (getOperandType(0).isDynamicDim(sortDim)) {
    ub = b.create<memref::DimOp>(loc, operand(0), sortDim);
  } else {
    ub = b.create<arith::ConstantIndexOp>(
        loc, getOperandType(0).getDimSize(sortDim));
  }
  ub = b.create<arith::SubIOp>(loc, ub, one);
  auto scfFor = b.create<scf::ForOp>(
      loc, zero, ub, one, ValueRange{},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) {
        SmallVector<Value> indices(ivs);
        Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one);
        for (auto output : getOutputOperands()) {
          indices[sortDim] = iv;
          sortBlkArgs.push_back(
              b.create<memref::LoadOp>(loc, output->get(), indices));
          indices[sortDim] = ivPlusOne;
          sortBlkArgs.push_back(
              b.create<memref::LoadOp>(loc, output->get(), indices));
        }
      });

  auto &srcBlock = getRegion().front();
  Region &region = scfFor.getRegion();
  IRMapping bvm;
  {
    OpBuilder::InsertionGuard guard(b);
    auto &block = region.front();
    b.setInsertionPointToEnd(&block);
    for (auto it : llvm::zip(srcBlock.getArguments(), sortBlkArgs)) {
      bvm.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvm);
    }
  }
  Value cond = bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0));

  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&region.front());
  b.create<scf::IfOp>(
      loc, cond,
      [&](OpBuilder &b, Location loc) {
        // Do not swap the pairs if true.
        b.create<scf::YieldOp>(loc);
      },
      [&](OpBuilder &b, Location loc) {
        // Swap the pairs if false.
        SmallVector<Value> indices(ivs.begin(), ivs.end());
        Value ivPlusOne =
            b.create<arith::AddIOp>(loc, scfFor.getInductionVar(), one);
        for (int i = 0, e = getNumOutputs(); i < e; ++i) {
          Value v1 = sortBlkArgs[i * 2];
          Value v2 = sortBlkArgs[i * 2 + 1];
          indices[sortDim] = scfFor.getInductionVar();
          b.create<memref::StoreOp>(loc, v2, getOutputOperand(i)->get(),
                                    indices);
          indices[sortDim] = ivPlusOne;
          b.create<memref::StoreOp>(loc, v1, getOutputOperand(i)->get(),
                                    indices);
        }
        b.create<scf::YieldOp>(loc);
      });
  b.create<scf::YieldOp>(loc);
  return success();
}

bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  // All operands of SortOp will be sorted. So, we'll end up loading/storing
  // from them - hence setting this utility to always return `true`.
  return true;
}

//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//
LogicalResult TopkOp::verify() {
  Operation *op = getOperation();
  if (getNumInputs() != 1 && getNumInputs() != 2) {
    return op->emitOpError("expected one or two input operands");
  }
  if (getNumOutputs() != 2) {
    return op->emitOpError("expected two output operands");
  }
  // First check added to eliminate comparison of different int types
  if (getInputRank() < 0 ||
      (getDimension() >= static_cast<uint64_t>(getInputRank()))) {
    return op->emitOpError("dimension exceeds rank");
  }
  // Ensure input/output element types match
  auto inputValuesType = cast<ShapedType>(values().getType());
  auto outputValuesType = cast<ShapedType>(outputValues().getType());
  if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
    return op->emitOpError("expected input/output value types to be identical");
  }
  // Indices must be int if provided
  auto outputIndicesType = cast<ShapedType>(outputIndices().getType());
  if (auto inputIndices = indices()) {
    auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
    if (!inputIndicesType.getElementType().isInteger(32) ||
        !outputIndicesType.getElementType().isInteger(32)) {
      return op->emitOpError("expected input/output indices types to be int32");
    }
  }

  // Ranks must match
  if (inputValuesType.getRank() != outputValuesType.getRank()) {
    return op->emitOpError("expected input/output to have the same rank");
  }
  if (auto inputIndices = indices()) {
    auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
    if (inputIndicesType.getRank() != outputIndicesType.getRank()) {
      return op->emitOpError("expected input/output to have the same rank");
    }
  }
  // Input indices and values must have the same shape.
  if (auto inputIndices = indices()) {
    auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
    if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType)))
      return op->emitOpError("input indices/values shape must match");
  }
  // Output indices and values must have the same shape.
  if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType)))
    return op->emitOpError("output indices/values shape must match");
  // Input shape must match the output shape except for the dimension()
  uint64_t dim = getDimension();
  if (!llvm::all_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(),
                                              outputValuesType.getShape())),
                    [dim](auto e) {
                      if (e.index() == dim) {
                        return true;
                      }
                      std::tuple<int64_t, int64_t> s = e.value();
                      return succeeded(verifyCompatibleShape(std::get<0>(s),
                                                             std::get<1>(s)));
                    })) {
    return op->emitOpError("incompatible input/output shapes");
  }
  // Check region compatibility
  Block &block = getRegion().front();
  if (block.getNumArguments() != 2) {
    return op->emitOpError("region block should have 2 arguments");
  }
  if (block.getArgument(0).getType() != inputValuesType.getElementType() ||
      block.getArgument(1).getType() != inputValuesType.getElementType()) {
    return op->emitOpError("region block types must match input");
  }
  auto terminatorOp = llvm::dyn_cast<YieldOp>(block.getTerminator());
  if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) {
    return op->emitOpError("region block must end with a linalg_ext.yield i1!");
  }
  return success();
}

SmallVector<utils::IteratorType> TopkOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getInputRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}

SmallVector<Range> TopkOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getInputRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = values();
  for (auto dim : llvm::enumerate(getInputType().getShape())) {
    loopBounds[dim.index()].offset = zero;
    loopBounds[dim.index()].size =
        getDimValue(builder, loc, source, dim.index());
    loopBounds[dim.index()].stride = one;
  }
  return loopBounds;
}
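// For the element selected by `ivs`, this walks the current top-K list along
// the output dimension and inserts the new value at its sorted position: at
// each slot the region decides whether the incoming value wins, ties are
// broken in favor of the smaller index, and the displaced value becomes the
// loop-carried candidate for the remaining slots.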
LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  uint64_t kDim = getDimension();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value initialValue = b.create<memref::LoadOp>(loc, values(), ivs);

  // If the indices tensor is not provided, the value index is derived from the
  // loop induction variables.
  Value initialIndex;
  if (indices()) {
    initialIndex = b.create<memref::LoadOp>(loc, *indices(), ivs);
  } else {
    Value rawInitialIndex = ivs[kDim];
    initialIndex =
        b.create<arith::IndexCastOp>(loc, b.getI32Type(), rawInitialIndex);
  }

  // Compute K (ub) from the selected dim of the output
  Value ub = b.create<memref::DimOp>(loc, outputValues(), getDimension());

  // Inner K loop functions:
  //   Load current K value and index
  //   Compare N/K using inserted block compare
  //   Check if N == K using strict weak ordering, select which index came first
  //   Select new K value from N/K comparison
  //   Select new K index from N/K comparison or which index came first
  //   Store new k value and index
  //   Yield loop carry values after K selection
  Value kValue, kIndex;
  auto scfFor = b.create<scf::ForOp>(
      loc, zero, ub, one, ValueRange{initialValue, initialIndex},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) {
        SmallVector<Value> indices(ivs);
        indices[kDim] = iv;
        kValue = b.create<memref::LoadOp>(loc, outputValues(), indices);
        kIndex = b.create<memref::LoadOp>(loc, outputIndices(), indices);
      });

  SmallVector<Value> indices(ivs);
  indices[kDim] = scfFor.getInductionVar();
  auto loopCarryValues = scfFor.getRegionIterArgs();

  // Retrieve region as black box comparison function f(x,y). Plug into op.
  auto &srcBlock = getRegion().front();
  IRMapping bvmF; // f(x,y)
  IRMapping bvmR; // f(y,x)
  {
    // Save previous insertion point. Continue within loop body.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToEnd(&scfFor.getRegion().front());
    SmallVector<Value> forwardValues{loopCarryValues[0], kValue};
    SmallVector<Value> reverseValues{kValue, loopCarryValues[0]};
    for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) {
      bvmF.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) {
      bvmR.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvmF);
      b.clone(blockOp, bvmR);
    }
    Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0));
    Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0));

    // Check value equality using strictly weak ordering from the region:
    //   f(x,y) --> forwardCmpRes
    //   f(y,x) --> reverseCmpRes
    //   if forwardCmpRes == reverseCmpRes then select which came first
    Value cmpValuesEqual = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes);
    Value cmpFirstIndex = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex);
    Value combinedCmpEqRes =
        b.create<arith::AndIOp>(loc, cmpValuesEqual, cmpFirstIndex);
    // True if N > K or N came before K
    Value indexCmpRes =
        b.create<arith::OrIOp>(loc, forwardCmpRes, combinedCmpEqRes);
    // Select results for K based on comparisons
    Value resultKValue = b.create<arith::SelectOp>(loc, forwardCmpRes,
                                                   loopCarryValues[0], kValue);
    Value resultKIndex =
        b.create<arith::SelectOp>(loc, indexCmpRes, loopCarryValues[1], kIndex);
    b.create<memref::StoreOp>(loc, resultKValue, outputValues(), indices);
    b.create<memref::StoreOp>(loc, resultKIndex, outputIndices(), indices);
    // Select loop carry, opposite of K results
    Value resultCarryValue = b.create<arith::SelectOp>(
        loc, forwardCmpRes, kValue, loopCarryValues[0]);
    Value resultCarryIndex =
        b.create<arith::SelectOp>(loc, indexCmpRes, kIndex, loopCarryValues[1]);
    b.create<scf::YieldOp>(loc, ValueRange{resultCarryValue, resultCarryIndex});
  }
  return success();
}

bool TopkOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  // Set to true so that output operands are always initialized.
  return true;
}
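// Implements getEffects for each op in terms of getEffectsImpl above.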
#define DEFINE_OP_GET_EFFECTS(OP_NAME)                                         \
  void OP_NAME::getEffects(                                                    \
      SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
          &effects) {                                                          \
    OpOperandVector inputBuffers = getInputBufferOperands();                   \
    OpOperandVector outputBuffers = getOutputBufferOperands();                 \
    getEffectsImpl(effects, getOperation()->getResults(), inputBuffers,        \
                   outputBuffers);                                             \
  }

DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)
DEFINE_OP_GET_EFFECTS(TopkOp)

namespace {
/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
/// changes.
struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
  using OpInterfaceRewritePattern<TMTensorOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(TMTensorOp op,
                                PatternRewriter &rewriter) const override {
    // If no operand comes from a tensor::CastOp that can be folded, fail.
    bool hasTensorCastOperand =
        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
          if (isa<BlockArgument>(opOperand->get()))
            return false;
          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
        });
    if (!hasTensorCastOperand)
      return failure();

    SmallVector<Type, 4> newResultTypes;
    newResultTypes.reserve(op->getNumResults());
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(op->getNumOperands());
    // Inputs may fold.
    for (OpOperand *opOperand : op.getInputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
                                ? tensorCastOp.getSource()
                                : opOperand->get());
    }
    // Init tensors may fold, in which case the resultType must also change.
    for (OpOperand *opOperand : op.getOutputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      bool fold = canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand()
                                 : opOperand->get());
      newResultTypes.push_back(newOperands.back().getType());
    }
    // Clone op.
    Operation *newOp =
        op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
      Value oldResult = std::get<0>(result);
      Value newResult = std::get<1>(result);
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// TMTensorDialect
//===----------------------------------------------------------------------===//

void TMTensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastOp>(getContext());
}

#define GET_OP_CLASSES
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc"