mirror of https://github.com/llvm/torch-mlir
Add TMTensor::Attention and lower ScaledDotProductAttentionOp to it (#2027)
parent c76a48308e
commit 0302cf1d92
@@ -244,6 +244,90 @@ def TMTensor_SortOp : TMTensor_Op<"sort",
  }];
}

def TMTensor_AttentionOp : TMTensor_Op<"attention",
  [DeclareOpInterfaceMethods<TMTensorInterface,
      ["payloadUsesValueFromOperand"]>,
   DeclareOpInterfaceMethods<ScalarLoopOpInterface,
      ["generateScalarImplementation"]>]> {
  let summary = "Attention operator";
  let description = [{
    This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
    the attention. Each of the inputs has shape BxNxd, where B is the batch
    dimension, N is the sequence length and d is the head dimension. Typically
    N >>> d. Mathematically, the attention is defined as
    matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
    this operator also performs scaling, masking and dropout, but we leave
    that out of the current implementation.
  }];

  let arguments = (ins Variadic<AnyShaped>:$inputs,
                       Variadic<AnyShaped>:$outputs
  );

  let builders = [
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)>
  ];

  let results = (outs Variadic<AnyRankedTensor>:$result);
  let assemblyFormat = [{
    attr-dict
    `ins` `(` $inputs `:` type($inputs) `)`
    `outs` `(` $outputs `:` type($outputs) `)`
    (`->` type($result)^)?
  }];

  let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
    Value getQuery() {
      return getInputOperand(0)->get();
    }
    Value getKey() {
      return getInputOperand(1)->get();
    }
    Value getValue() {
      return getInputOperand(2)->get();
    }
    Value getOutput() {
      return getOutputOperand(0)->get();
    }
    ShapedType getQueryType() {
      return getQuery().getType().cast<ShapedType>();
    }
    ShapedType getKeyType() {
      return getKey().getType().cast<ShapedType>();
    }
    ShapedType getValueType() {
      return getValue().getType().cast<ShapedType>();
    }
    ShapedType getOutputType() {
      return getOutput().getType().cast<ShapedType>();
    }
    int64_t getQueryRank() {
      return getQueryType().getRank();
    }
    int64_t getKeyRank() {
      return getKeyType().getRank();
    }
    int64_t getValueRank() {
      return getValueType().getRank();
    }
    int64_t getOutputRank() {
      return getOutputType().getRank();
    }
    int64_t getIterationDomainRank() {
      return 2;
    };
    // Method to implement for specifying output range for
    // DestinationStyleOpInterface
    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
      std::pair<unsigned, unsigned> outputsIndexAndLength =
          getODSOperandIndexAndLength(1);
      return std::make_pair<int64_t, int64_t>(
          outputsIndexAndLength.first,
          outputsIndexAndLength.first + outputsIndexAndLength.second);
    }
  }];
}

//===----------------------------------------------------------------------===//
// Pure ops
//===----------------------------------------------------------------------===//
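For reference, the computation described above, matmul(softmax(matmul(Q, transpose(K))), V), can be sketched in a few lines of NumPy. This is an illustrative sketch only (the helper names are not part of the commit) and, matching the description, it leaves out scaling, masking and dropout:

import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability; it cancels in the normalization.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # q, k, v: B x N x d. The intermediate weight matrix is B x N x N.
    weight = softmax(q @ np.swapaxes(k, -1, -2))
    return weight @ v  # B x N x d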
@@ -84,6 +84,243 @@ OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
  return builder.getI64IntegerAttr(t.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//

LogicalResult AttentionOp::verify() {
  Operation *op = getOperation();
  ShapedType queryType = getQueryType();
  ShapedType keyType = getKeyType();
  ShapedType valueType = getValueType();
  ShapedType outputType = getOutputType();
  ArrayRef<int64_t> queryShape = queryType.getShape();
  ArrayRef<int64_t> keyShape = keyType.getShape();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(queryShape, keyShape)))
    return op->emitOpError("incompatible key shape");
  if (failed(verifyCompatibleShape(queryShape, valueShape)))
    return op->emitOpError("incompatible value shape");
  if (failed(verifyCompatibleShape(queryShape, outputShape)))
    return op->emitOpError("incompatible output shape");
  return success();
}

SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
  int64_t iterationDomainRank = getIterationDomainRank();
  SmallVector<Range> loopBounds(iterationDomainRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getQuery();
  for (auto dim : llvm::seq<int64_t>(0, iterationDomainRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}

bool AttentionOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  Value operand = opOperand->get();
  return operand == getQuery() || operand == getKey() || operand == getValue();
}

// Performs a matmul between lhs and rhs.
// Note that "transposed" means the last two dims of rhs are swapped.
static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
                   Value rhs, ValueRange rhsSizes, Value output,
                   ValueRange outputSizes, bool transposed = false) {
  auto elementType = lhs.getType().cast<MemRefType>().getElementType();
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  auto rank = outputSizes.size();
  Value reductionDimSize = lhsSizes[lhsSizes.size() - 1];

  // Loop over the output.
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(rank, zero), outputSizes,
      SmallVector<Value>(rank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value acc = b.create<arith::ConstantOp>(
            loc, elementType, b.getFloatAttr(elementType, 0.0));
        Value sum =
            b.create<scf::ForOp>(
                 loc, zero, reductionDimSize, one, SmallVector<Value>{acc},
                 [&](OpBuilder &b, Location loc, Value i, ValueRange accs) {
                   SmallVector<Value> lhsIVs(localIVs), rhsIVs(localIVs);
                   lhsIVs[lhsIVs.size() - 1] = i;
                   rhsIVs[rhsIVs.size() - 2] = i;
                   if (transposed)
                     std::swap(rhsIVs[rhsIVs.size() - 1],
                               rhsIVs[rhsIVs.size() - 2]);

                   Value acc = accs[0];
                   Value rElem = b.create<memref::LoadOp>(loc, lhs, lhsIVs);
                   Value cElem = b.create<memref::LoadOp>(loc, rhs, rhsIVs);
                   Value x = b.create<arith::MulFOp>(loc, rElem, cElem);
                   x = b.create<arith::AddFOp>(loc, x, acc);

                   b.create<scf::YieldOp>(loc, x);
                 })
                ->getResult(0);
        b.create<memref::StoreOp>(loc, sum, output, localIVs);
        b.create<scf::YieldOp>(loc);
      });
}

LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
                                                        Location loc,
                                                        ValueRange ivs) {

  Value query = getQuery();
  Value key = getKey();
  Value value = getValue();
  Value output = getOutput();
  auto queryType = query.getType().cast<MemRefType>();
  auto keyType = key.getType().cast<MemRefType>();
  auto valueType = value.getType().cast<MemRefType>();
  auto queryRank = queryType.getRank();
  auto keyRank = keyType.getRank();
  auto valueRank = valueType.getRank();
  auto keySizes = keyType.getShape();
  Type elementType = queryType.getElementType();

  Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
                                            b.getFloatAttr(elementType, 0.0));

  SmallVector<Value> queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes;
  for (auto i = 0; i < queryRank; i++)
    queryDynSizes.push_back(b.create<memref::DimOp>(loc, query, i));
  for (auto i = 0; i < keyRank; i++)
    keyDynSizes.push_back(b.create<memref::DimOp>(loc, key, i));
  for (auto i = 0; i < valueRank; i++)
    valueDynSizes.push_back(b.create<memref::DimOp>(loc, value, i));
  for (auto i = 0; i < queryRank; i++)
    outputDynSizes.push_back(b.create<memref::DimOp>(loc, output, i));

  // weight = query @ transpose(key)
  auto weightRank = queryRank;
  auto weightSizes = SmallVector<int64_t>(queryType.getShape());
  weightSizes[weightRank - 1] = keySizes[keyRank - 2];
  auto weightType = MemRefType::get(weightSizes, queryType.getElementType());
  SmallVector<Value> weightDynSizes(queryDynSizes);
  weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2];
  Value weight = b.create<memref::AllocOp>(loc, weightType, weightDynSizes);
  matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes,
         /*transposed=*/true);

  // weight = softmax(weight)
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value dim = weightDynSizes[weightRank - 1];
  Value scaleFactor = b.create<math::SqrtOp>(
      loc, b.create<arith::UIToFPOp>(
               loc, elementType,
               b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
                                              queryDynSizes[queryRank - 1])));
  // Calculate max(weight).
  Value init = b.create<memref::LoadOp>(loc, weight,
                                        SmallVector<Value>(weightRank, zero));
  Value globalMax =
      b.create<scf::ParallelOp>(
           loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
           SmallVector<Value>(weightRank, one), init,
           [&](OpBuilder &b, Location loc, ValueRange localIVs,
               ValueRange accs) {
             b.create<scf::ReduceOp>(
                 loc, init,
                 [&](OpBuilder &b, Location loc, Value elem, Value acc) {
                   Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
                   Value max = b.create<arith::MaxFOp>(loc, x, acc);
                   b.create<scf::ReduceReturnOp>(loc, max);
                 });
           })
          .getResult(0);
  // weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
      SmallVector<Value>(weightRank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<arith::SubFOp>(loc, x, globalMax);
        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
        b.create<scf::YieldOp>(loc);
      });
  // Calculate exp(weight).
  SmallVector<Value> min(weightRank, zero),
      max(weightDynSizes.begin(), weightDynSizes.end()), steps(weightRank, one);
  b.create<scf::ParallelOp>(
      loc, min, max, steps,
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<math::ExpOp>(loc, x);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
        b.create<scf::YieldOp>(loc);
      });
  Value expWeightSum = b.create<memref::AllocOp>(
      loc,
      MemRefType::get(
          SmallVector<int64_t>(weightSizes.begin(), weightSizes.end() - 1),
          elementType),
      SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1});
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank - 1, zero),
      SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1},
      SmallVector<Value>(weightRank - 1, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        b.create<memref::StoreOp>(loc, zeroF, expWeightSum, localIVs);
      });
  // Loop over all dims but -1.
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank - 1, zero),
      SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end() - 1),
      SmallVector<Value>(weightRank - 1, one),
      [&](OpBuilder &b, Location loc, ValueRange outsideDims) {
        // Sum over the last dim.
        b.create<scf::ParallelOp>(
            loc, zero, dim, one,
            [&](OpBuilder &b, Location loc, ValueRange localIVs) {
              SmallVector<Value> coords(outsideDims);
              coords.push_back(localIVs[0]);
              Value x =
                  b.create<memref::LoadOp>(loc, expWeightSum, outsideDims);
              Value y = b.create<memref::LoadOp>(loc, weight, coords);
              Value sum = b.create<arith::AddFOp>(loc, x, y);
              b.create<memref::StoreOp>(loc, sum, expWeightSum, outsideDims);
              b.create<scf::YieldOp>(loc);
            });
      });
  // Calculate exp(weight) / sum(exp(weight)).
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank, zero),
      SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end()),
      SmallVector<Value>(weightRank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        SmallVector<Value> sumIVs(localIVs);
        sumIVs.pop_back();
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
        x = b.create<arith::DivFOp>(loc, x, sum);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
        b.create<scf::YieldOp>(loc);
      });

  // output = weight @ value
  matmul(b, loc, weight, weightDynSizes, value, valueDynSizes, output,
         outputDynSizes, /*transposed=*/false);

  return success();
}

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//
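A NumPy mirror of the staging in generateScalarImplementation above may help when reading the nested scf.parallel loops. Unlike the plain formula in the earlier sketch, this version includes the 1/sqrt(d) scaling and subtracts a single global maximum before exponentiating, as the code does; since that maximum is one constant shared by every element, it cancels in the per-row normalization and only serves numerical stability. Names are illustrative, not from the commit:

import numpy as np

def attention_scalar_reference(query, key, value):
    d = query.shape[-1]
    # weight = query @ transpose(key)
    weight = query @ np.swapaxes(key, -1, -2)
    # weight = (weight - max(weight)) / sqrt(d), with one global max
    weight = (weight - weight.max()) / np.sqrt(d)
    # exp(weight), then normalize by the sum over the last dimension
    weight = np.exp(weight)
    weight = weight / weight.sum(axis=-1, keepdims=True)
    # output = weight @ value
    return weight @ value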
@@ -652,6 +889,7 @@ bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
                   outputBuffers);                                             \
  }

DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)
@@ -9040,6 +9040,35 @@ def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [
  }];
}

def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$query,
    AnyTorchTensorType:$key,
    AnyTorchTensorType:$value,
    AnyTorchOptionalTensorType:$attn_mask,
    Torch_FloatType:$dropout_p,
    Torch_BoolType:$is_causal,
    AnyTorchOptionalFloatType:$scale
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenScaledDotProductAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 7, 1);
    }
    void AtenScaledDotProductAttentionOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 7, 1);
    }
  }];
}

def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
    AllowsTypeRefinement,
    HasValueSemantics,
@@ -13,10 +13,13 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/ValueRange.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
@@ -1494,6 +1497,68 @@ public:
};
} // namespace

namespace {
class ConvertAtenScaledDotProductAttentionOp
    : public OpConversionPattern<AtenScaledDotProductAttentionOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value mask = op.getAttnMask();
    Value dropoutP = op.getDropoutP();
    Value isCausal = op.getIsCausal();
    Value scale = op.getScale();
    Type elementType =
        adaptor.getQuery().getType().cast<ShapedType>().getElementType();

    // Verify inputs (only support defaults).
    if (!mask.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "attention masking not supported");
    double dropout;
    if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
        dropout > 0.0)
      return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
    bool causal;
    if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
      return rewriter.notifyMatchFailure(
          op.getLoc(), "causal attention masking not supported");
    if (!scale.getType().isa<Torch::NoneType>()) {
      double scaleFloat;
      if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
          scaleFloat != 1.0)
        return rewriter.notifyMatchFailure(op.getLoc(),
                                           "only default scale supported");
    }

    SmallVector<int64_t> outSizes(
        adaptor.getQuery().getType().cast<ShapedType>().getShape());
    SmallVector<int64_t> valueSizes(
        adaptor.getValue().getType().cast<ShapedType>().getShape());
    outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
    SmallVector<Value> outSizesDynamic(
        getTensorSizes(rewriter, op.getLoc(), adaptor.getQuery()));
    outSizesDynamic[outSizesDynamic.size() - 1] = getTensorSizes(
        rewriter, op.getLoc(), adaptor.getValue())[valueSizes.size() - 1];
    Type outType = RankedTensorType::get(outSizes, elementType);
    Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
                                        elementType);

    // Overwrite with tm_tensor::attention.
    auto attention = rewriter.create<AttentionOp>(
        op.getLoc(), outType,
        SmallVector<Value>{adaptor.getQuery(), adaptor.getKey(),
                           adaptor.getValue()},
        SmallVector<Value>{output});

    rewriter.replaceOp(op, attention.getResult());

    return success();
  }
};
} // namespace

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
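The guards at the top of matchAndRewrite only admit the default configuration and reject everything else with notifyMatchFailure. A rough Python restatement of what the pattern accepts (a hypothetical helper, not part of the commit):

def sdpa_lowering_supported(attn_mask, dropout_p, is_causal, scale):
    # No mask, dropout disabled, non-causal, default scale. In the C++ pattern,
    # dropout_p, is_causal and (when present) scale must also be compile-time
    # constants for matchPattern to succeed.
    return (attn_mask is None
            and dropout_p == 0.0
            and not is_causal
            and scale in (None, 1.0))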
@@ -1516,7 +1581,8 @@ public:
    ConversionTarget target(*context);
    target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
                           tensor::TensorDialect, arith::ArithDialect,
                           math::MathDialect, Torch::TorchDialect,
                           TMTensorDialect>();

    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });
@@ -1536,6 +1602,9 @@ public:
    patterns.add<ConvertAtenSortOp>(typeConverter, context);
    target.addIllegalOp<AtenCumsumOp>();
    patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
    target.addIllegalOp<AtenScaledDotProductAttentionOp>();
    patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
                                                         context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
@@ -6867,6 +6867,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.list<int> {\n"
"    %int-1 = torch.constant.int -1\n"
"    %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %1 = torch.aten._set_item.t %arg0, %int-1, %0 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
"    return %arg0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.zeros\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
"    return %arg0 : !torch.list<int>\n"
"  }\n"
@@ -8559,6 +8565,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %int11 = torch.constant.int 11\n"
"    return %int11 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.int {\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
"    %int11 = torch.constant.int 11\n"
"    return %int11 : !torch.int\n"
@@ -522,6 +522,15 @@ def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end
def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
    return upstream_shape_functions.linear(input, weight, bias)

@check_shape_function([
    Invocation(TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Same shape
    Invocation(TensorOfShape(3, 2, 16, 8), TensorOfShape(3, 2, 8, 8), TensorOfShape(3, 2, 8, 4)), # Different shape
])
def aten〇scaled_dot_product_attention〡shape(query: List[int], key: List[int], value: List[int], attn_mask: Optional[List[int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> List[int]:
    outshape = query
    outshape[-1] = value[-1]
    return outshape

@check_shape_function([
    Invocation([2, 3]),
])
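A worked example of the shape function above, using the second ("Different shape") invocation from the decorator; this is illustrative only and not part of the commit. Note that `outshape = query` aliases the input list, so the query list itself is updated in place.

query = [3, 2, 16, 8]
key   = [3, 2, 8, 8]
value = [3, 2, 8, 4]

# The result is the query shape with its last dimension replaced by value[-1].
out = list(query)
out[-1] = value[-1]
assert out == [3, 2, 16, 4]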
@@ -1904,6 +1913,11 @@ def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)]))
def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int:
    _, query_dtype = query_rank_dtype
    return query_dtype

@check_dtype_function(_check_two_tensor_op())
def aten〇logical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
    return torch.bool
@@ -565,7 +565,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
    emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
    emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
    emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)")

    # Dict ops.
    emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
@@ -3695,6 +3695,51 @@ class MoveDimIntNegativeIndexModule(torch.nn.Module):
def MoveDimIntNegativeIndexModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 2))

# ==============================================================================

class ScaledDotProductAttentionSameModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True)
    ])
    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(query, key, value)

@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule())
def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
    query = torch.randn(1, 1, 5, 5, dtype=torch.float32)
    key = torch.randn(1, 1, 5, 5, dtype=torch.float32)
    value = torch.randn(1, 1, 5, 5, dtype=torch.float32)
    module.forward(query, key, value)

class ScaledDotProductAttentionDifferentModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True)
    ])
    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(query, key, value)

@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
    query = torch.randn(3, 2, 8, 4, dtype=torch.float32)
    key = torch.randn(3, 2, 16, 4, dtype=torch.float32)
    value = torch.randn(3, 2, 16, 4, dtype=torch.float32)
    module.forward(query, key, value)

# ==============================================================================
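The shapes exercised by the "different" test case above can also be sanity-checked directly against eager PyTorch; this is an illustrative snippet, not part of the test suite:

import torch

q = torch.randn(3, 2, 8, 4)
k = torch.randn(3, 2, 16, 4)
v = torch.randn(3, 2, 16, 4)
out = torch.ops.aten.scaled_dot_product_attention(q, k, v)
assert out.shape == (3, 2, 8, 4)  # query shape, last dim taken from value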