mirror of https://github.com/llvm/torch-mlir
Add TMTensor::Attention and lower ScaledDotProductAttentionOp to it (#2027)
parent c76a48308e
commit 0302cf1d92

@@ -244,6 +244,90 @@ def TMTensor_SortOp : TMTensor_Op<"sort",
  }];
}

def TMTensor_AttentionOp : TMTensor_Op<"attention",
  [DeclareOpInterfaceMethods<TMTensorInterface,
      ["payloadUsesValueFromOperand"]>,
   DeclareOpInterfaceMethods<ScalarLoopOpInterface,
      ["generateScalarImplementation"]>]> {
  let summary = "Attention operator";
  let description = [{
    This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
    the attention. Each of the inputs has shape BxNxd, where B is the batch
    dimension, N is the sequence length and d is the head dimension. Typically
    N >>> d. Mathematically, the attention is defined as
    matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
    this operator also performs scaling, masking and dropout, but we leave
    that out of the current implementation.
  }];

  let arguments = (ins Variadic<AnyShaped>:$inputs,
                       Variadic<AnyShaped>:$outputs
  );

  let builders = [
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)>
  ];

  let results = (outs Variadic<AnyRankedTensor>:$result);
  let assemblyFormat = [{
    attr-dict
    `ins` `(` $inputs `:` type($inputs) `)`
    `outs` `(` $outputs `:` type($outputs) `)`
    (`->` type($result)^)?
  }];

  let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
    Value getQuery() {
      return getInputOperand(0)->get();
    }
    Value getKey() {
      return getInputOperand(1)->get();
    }
    Value getValue() {
      return getInputOperand(2)->get();
    }
    Value getOutput() {
      return getOutputOperand(0)->get();
    }
    ShapedType getQueryType() {
      return getQuery().getType().cast<ShapedType>();
    }
    ShapedType getKeyType() {
      return getKey().getType().cast<ShapedType>();
    }
    ShapedType getValueType() {
      return getValue().getType().cast<ShapedType>();
    }
    ShapedType getOutputType() {
      return getOutput().getType().cast<ShapedType>();
    }
    int64_t getQueryRank() {
      return getQueryType().getRank();
    }
    int64_t getKeyRank() {
      return getKeyType().getRank();
    }
    int64_t getValueRank() {
      return getValueType().getRank();
    }
    int64_t getOutputRank() {
      return getOutputType().getRank();
    }
    int64_t getIterationDomainRank() {
      return 2;
    }
    // Method to implement for specifying output range for
    // DestinationStyleOpInterface
    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
      std::pair<unsigned, unsigned> outputsIndexAndLength =
          getODSOperandIndexAndLength(1);
      return std::make_pair<int64_t, int64_t>(
          outputsIndexAndLength.first,
          outputsIndexAndLength.first + outputsIndexAndLength.second);
    }
  }];
}
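
A rough reference for the semantics described above (editor's sketch in NumPy, not part of the
patch): the op computes matmul(softmax(matmul(Q, transpose(K))), V); the scalar lowering later in
this commit additionally divides the scores by sqrt(d), which the sketch includes.

import numpy as np

def attention_reference(q, k, v):
    # q, k: (B, N, d); v: (B, N, d_v); returns (B, N, d_v).
    d = q.shape[-1]
    w = q @ k.transpose(0, 2, 1) / np.sqrt(d)   # B x N x N scores
    w = np.exp(w - w.max())                     # global max subtraction, as in the lowering
    return (w / w.sum(axis=-1, keepdims=True)) @ v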

//===----------------------------------------------------------------------===//
// Pure ops
//===----------------------------------------------------------------------===//

@@ -84,6 +84,243 @@ OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
  return builder.getI64IntegerAttr(t.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//

LogicalResult AttentionOp::verify() {
  Operation *op = getOperation();
  ShapedType queryType = getQueryType();
  ShapedType keyType = getKeyType();
  ShapedType valueType = getValueType();
  ShapedType outputType = getOutputType();
  ArrayRef<int64_t> queryShape = queryType.getShape();
  ArrayRef<int64_t> keyShape = keyType.getShape();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(queryShape, keyShape)))
    return op->emitOpError("incompatible key shape");
  if (failed(verifyCompatibleShape(queryShape, valueShape)))
    return op->emitOpError("incompatible value shape");
  if (failed(verifyCompatibleShape(queryShape, outputShape)))
    return op->emitOpError("incompatible output shape");
  return success();
}

SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
  int64_t iterationDomainRank = getIterationDomainRank();
  SmallVector<Range> loopBounds(iterationDomainRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getQuery();
  for (auto dim : llvm::seq<int64_t>(0, iterationDomainRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}

bool AttentionOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  Value operand = opOperand->get();
  return operand == getQuery() || operand == getKey() || operand == getValue();
}

// Performs a matmul between lhs and rhs
// Note that "transposed" means the last two dims of rhs are swapped
static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
                   Value rhs, ValueRange rhsSizes, Value output,
                   ValueRange outputSizes, bool transposed = false) {
  auto elementType = lhs.getType().cast<MemRefType>().getElementType();
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  auto rank = outputSizes.size();
  Value reductionDimSize = lhsSizes[lhsSizes.size() - 1];

  // Loop over output
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(rank, zero), outputSizes,
      SmallVector<Value>(rank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value acc = b.create<arith::ConstantOp>(
            loc, elementType, b.getFloatAttr(elementType, 0.0));
        Value sum =
            b.create<scf::ForOp>(
                 loc, zero, reductionDimSize, one, SmallVector<Value>{acc},
                 [&](OpBuilder &b, Location loc, Value i, ValueRange accs) {
                   SmallVector<Value> lhsIVs(localIVs), rhsIVs(localIVs);
                   lhsIVs[lhsIVs.size() - 1] = i;
                   rhsIVs[rhsIVs.size() - 2] = i;
                   if (transposed)
                     std::swap(rhsIVs[rhsIVs.size() - 1],
                               rhsIVs[rhsIVs.size() - 2]);

                   Value acc = accs[0];
                   Value rElem = b.create<memref::LoadOp>(loc, lhs, lhsIVs);
                   Value cElem = b.create<memref::LoadOp>(loc, rhs, rhsIVs);
                   Value x = b.create<arith::MulFOp>(loc, rElem, cElem);
                   x = b.create<arith::AddFOp>(loc, x, acc);

                   b.create<scf::YieldOp>(loc, x);
                 })
                ->getResult(0);
        b.create<memref::StoreOp>(loc, sum, output, localIVs);
        b.create<scf::YieldOp>(loc);
      });
}
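
The helper above emits an explicit loop nest: an scf.parallel loop over every output element and
an inner scf.for reduction over the last dimension of lhs; "transposed" swaps the last two index
positions used to read rhs. The same indexing in plain Python (editor's sketch, not part of the
patch, assuming rank-3 nested lists):

def matmul_scalar(lhs, rhs, out, transposed=False):
    # lhs: (B, M, K); rhs: (B, K, N), or (B, N, K) when transposed; out: (B, M, N).
    B, M, N = len(out), len(out[0]), len(out[0][0])
    K = len(lhs[0][0])                    # reduction dim = last dim of lhs
    for b in range(B):
        for m in range(M):
            for n in range(N):
                acc = 0.0
                for k in range(K):
                    r = rhs[b][n][k] if transposed else rhs[b][k][n]
                    acc += lhs[b][m][k] * r
                out[b][m][n] = acc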

LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
                                                         Location loc,
                                                         ValueRange ivs) {
  Value query = getQuery();
  Value key = getKey();
  Value value = getValue();
  Value output = getOutput();
  auto queryType = query.getType().cast<MemRefType>();
  auto keyType = key.getType().cast<MemRefType>();
  auto valueType = value.getType().cast<MemRefType>();
  auto queryRank = queryType.getRank();
  auto keyRank = keyType.getRank();
  auto valueRank = valueType.getRank();
  auto keySizes = keyType.getShape();
  Type elementType = queryType.getElementType();

  Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
                                            b.getFloatAttr(elementType, 0.0));

  SmallVector<Value> queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes;
  for (auto i = 0; i < queryRank; i++)
    queryDynSizes.push_back(b.create<memref::DimOp>(loc, query, i));
  for (auto i = 0; i < keyRank; i++)
    keyDynSizes.push_back(b.create<memref::DimOp>(loc, key, i));
  for (auto i = 0; i < valueRank; i++)
    valueDynSizes.push_back(b.create<memref::DimOp>(loc, value, i));
  for (auto i = 0; i < queryRank; i++)
    outputDynSizes.push_back(b.create<memref::DimOp>(loc, output, i));

  // weight = query @ transpose(key)
  auto weightRank = queryRank;
  auto weightSizes = SmallVector<int64_t>(queryType.getShape());
  weightSizes[weightRank - 1] = keySizes[keyRank - 2];
  auto weightType = MemRefType::get(weightSizes, queryType.getElementType());
  SmallVector<Value> weightDynSizes(queryDynSizes);
  weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2];
  Value weight = b.create<memref::AllocOp>(loc, weightType, weightDynSizes);
  matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes,
         /*transposed=*/true);

  // weight = softmax(weight)
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value dim = weightDynSizes[weightRank - 1];
  Value scaleFactor = b.create<math::SqrtOp>(
      loc, b.create<arith::UIToFPOp>(
               loc, elementType,
               b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
                                              queryDynSizes[queryRank - 1])));
  // calculate max(weight)
  Value init = b.create<memref::LoadOp>(loc, weight,
                                        SmallVector<Value>(weightRank, zero));
  Value globalMax =
      b.create<scf::ParallelOp>(
           loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
           SmallVector<Value>(weightRank, one), init,
           [&](OpBuilder &b, Location loc, ValueRange localIVs,
               ValueRange accs) {
             b.create<scf::ReduceOp>(
                 loc, init,
                 [&](OpBuilder &b, Location loc, Value elem, Value acc) {
                   Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
                   Value max = b.create<arith::MaxFOp>(loc, x, acc);
                   b.create<scf::ReduceReturnOp>(loc, max);
                 });
           })
          .getResult(0);
  // weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
      SmallVector<Value>(weightRank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<arith::SubFOp>(loc, x, globalMax);
        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
        b.create<scf::YieldOp>(loc);
      });
  // calculate exp(weight)
  SmallVector<Value> min(weightRank, zero),
      max(weightDynSizes.begin(), weightDynSizes.end()), steps(weightRank, one);
  b.create<scf::ParallelOp>(
      loc, min, max, steps,
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<math::ExpOp>(loc, x);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
        b.create<scf::YieldOp>(loc);
      });
  Value expWeightSum = b.create<memref::AllocOp>(
      loc,
      MemRefType::get(
          SmallVector<int64_t>(weightSizes.begin(), weightSizes.end() - 1),
          elementType),
      SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1});
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank - 1, zero),
      SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1},
      SmallVector<Value>(weightRank - 1, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        b.create<memref::StoreOp>(loc, zeroF, expWeightSum, localIVs);
      });
  // Loop over all dims but -1
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank - 1, zero),
      SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end() - 1),
      SmallVector<Value>(weightRank - 1, one),
      [&](OpBuilder &b, Location loc, ValueRange outsideDims) {
        // Sum over last dim
        b.create<scf::ParallelOp>(
            loc, zero, dim, one,
            [&](OpBuilder &b, Location loc, ValueRange localIVs) {
              SmallVector<Value> coords(outsideDims);
              coords.push_back(localIVs[0]);
              Value x =
                  b.create<memref::LoadOp>(loc, expWeightSum, outsideDims);
              Value y = b.create<memref::LoadOp>(loc, weight, coords);
              Value sum = b.create<arith::AddFOp>(loc, x, y);
              b.create<memref::StoreOp>(loc, sum, expWeightSum, outsideDims);
              b.create<scf::YieldOp>(loc);
            });
      });
  // calculate exp(weight) / sum(exp(weight))
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank, zero),
      SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end()),
      SmallVector<Value>(weightRank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        SmallVector<Value> sumIVs(localIVs);
        sumIVs.pop_back();
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
        x = b.create<arith::DivFOp>(loc, x, sum);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
        b.create<scf::YieldOp>(loc);
      });

  // output = weight @ value
  matmul(b, loc, weight, weightDynSizes, value, valueDynSizes, output,
         outputDynSizes, /*transposed=*/false);

  return success();
}
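
One step worth noting: the code above subtracts a single global maximum (not a per-row maximum)
and only then divides by sqrt(d) before exponentiating. Because softmax is invariant to adding
the same constant to every score, this changes the numerical range but not the result. A quick
check (editor's sketch, not part of the patch):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

w = np.random.randn(2, 4, 4)
c = w.max()                                    # any constant shift
assert np.allclose(softmax((w - c) / 2.0), softmax(w / 2.0))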

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

@@ -652,6 +889,7 @@ bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
                     outputBuffers);                                           \
  }

DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)

@@ -9040,6 +9040,35 @@ def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [
  }];
}

def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$query,
    AnyTorchTensorType:$key,
    AnyTorchTensorType:$value,
    AnyTorchOptionalTensorType:$attn_mask,
    Torch_FloatType:$dropout_p,
    Torch_BoolType:$is_causal,
    AnyTorchOptionalFloatType:$scale
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenScaledDotProductAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 7, 1);
    }
    void AtenScaledDotProductAttentionOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 7, 1);
    }
  }];
}

def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
    AllowsTypeRefinement,
    HasValueSemantics,

@@ -13,10 +13,13 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/ValueRange.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"

@@ -1494,6 +1497,68 @@ public:
};
} // namespace

namespace {
class ConvertAtenScaledDotProductAttentionOp
    : public OpConversionPattern<AtenScaledDotProductAttentionOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value mask = op.getAttnMask();
    Value dropoutP = op.getDropoutP();
    Value isCausal = op.getIsCausal();
    Value scale = op.getScale();
    Type elementType =
        adaptor.getQuery().getType().cast<ShapedType>().getElementType();

    // Verify inputs (only support defaults)
    if (!mask.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "attention masking not supported");
    double dropout;
    if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
        dropout > 0.0)
      return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
    bool causal;
    if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
      return rewriter.notifyMatchFailure(
          op.getLoc(), "causal attention masking not supported");
    if (!scale.getType().isa<Torch::NoneType>()) {
      double scaleFloat;
      if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
          scaleFloat != 1.0)
        return rewriter.notifyMatchFailure(op.getLoc(),
                                           "only default scale supported");
    }

    SmallVector<int64_t> outSizes(
        adaptor.getQuery().getType().cast<ShapedType>().getShape());
    SmallVector<int64_t> valueSizes(
        adaptor.getValue().getType().cast<ShapedType>().getShape());
    outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
    SmallVector<Value> outSizesDynamic(
        getTensorSizes(rewriter, op.getLoc(), adaptor.getQuery()));
    outSizesDynamic[outSizesDynamic.size() - 1] = getTensorSizes(
        rewriter, op.getLoc(), adaptor.getValue())[valueSizes.size() - 1];
    Type outType = RankedTensorType::get(outSizes, elementType);
    Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
                                        elementType);

    // Overwrite with tm_tensor::attention
    auto attention = rewriter.create<AttentionOp>(
        op.getLoc(), outType,
        SmallVector<Value>{adaptor.getQuery(), adaptor.getKey(),
                           adaptor.getValue()},
        SmallVector<Value>{output});

    rewriter.replaceOp(op, attention.getResult());

    return success();
  }
};
} // namespace
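
In terms of the PyTorch surface, the pattern above only fires for the all-defaults call form;
anything else is reported as a match failure and left for another lowering path. A hedged
illustration (editor's sketch, not part of the patch):

import torch

q = k = v = torch.randn(1, 1, 5, 5)
# Accepted by this pattern: no mask, zero dropout, non-causal, default scale.
torch.ops.aten.scaled_dot_product_attention(q, k, v)
# Rejected (notifyMatchFailure): an attention mask, dropout_p > 0.0, is_causal=True,
# or an explicit scale other than 1.0.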

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------

@@ -1516,7 +1581,8 @@ public:
    ConversionTarget target(*context);
    target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
                           tensor::TensorDialect, arith::ArithDialect,
                           math::MathDialect, Torch::TorchDialect,
                           TMTensorDialect>();

    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });

@@ -1536,6 +1602,9 @@ public:
    patterns.add<ConvertAtenSortOp>(typeConverter, context);
    target.addIllegalOp<AtenCumsumOp>();
    patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
    target.addIllegalOp<AtenScaledDotProductAttentionOp>();
    patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
                                                         context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))

@@ -6867,6 +6867,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.list<int> {\n"
"    %int-1 = torch.constant.int -1\n"
"    %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %1 = torch.aten._set_item.t %arg0, %int-1, %0 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
"    return %arg0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.zeros\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
"    return %arg0 : !torch.list<int>\n"
"  }\n"

@@ -8559,6 +8565,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %int11 = torch.constant.int 11\n"
"    return %int11 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.int {\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
"    %int11 = torch.constant.int 11\n"
"    return %int11 : !torch.int\n"

@@ -522,6 +522,15 @@ def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end
def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
    return upstream_shape_functions.linear(input, weight, bias)

@check_shape_function([
    Invocation(TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Same shape
    Invocation(TensorOfShape(3, 2, 16, 8), TensorOfShape(3, 2, 8, 8), TensorOfShape(3, 2, 8, 4)), # Different shape
])
def aten〇scaled_dot_product_attention〡shape(query: List[int], key: List[int], value: List[int], attn_mask: Optional[List[int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> List[int]:
    outshape = query
    outshape[-1] = value[-1]
    return outshape
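
# A worked instance of the shape rule above (editor's sketch, not part of the patch): the output
# keeps the query shape except for the last dimension, which comes from value. The names below
# are illustrative only.
_example_query, _example_value = [3, 2, 16, 8], [3, 2, 8, 4]
_example_out = list(_example_query)
_example_out[-1] = _example_value[-1]
assert _example_out == [3, 2, 16, 4]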

@check_shape_function([
    Invocation([2, 3]),
])

@@ -1904,6 +1913,11 @@ def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)]))
def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int:
    _, query_dtype = query_rank_dtype
    return query_dtype

@check_dtype_function(_check_two_tensor_op())
def aten〇logical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
    return torch.bool

@@ -565,7 +565,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
    emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
    emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
    emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)")

    # Dict ops.
    emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)

@@ -3695,6 +3695,51 @@ class MoveDimIntNegativeIndexModule(torch.nn.Module):
def MoveDimIntNegativeIndexModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 2))

# ==============================================================================

class ScaledDotProductAttentionSameModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True)
    ])
    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(query, key, value)

@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule())
def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
    query = torch.randn(1, 1, 5, 5, dtype=torch.float32)
    key = torch.randn(1, 1, 5, 5, dtype=torch.float32)
    value = torch.randn(1, 1, 5, 5, dtype=torch.float32)
    module.forward(query, key, value)

class ScaledDotProductAttentionDifferentModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True)
    ])
    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(query, key, value)

@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
    query = torch.randn(3, 2, 8, 4, dtype=torch.float32)
    key = torch.randn(3, 2, 16, 4, dtype=torch.float32)
    value = torch.randn(3, 2, 16, 4, dtype=torch.float32)
    module.forward(query, key, value)

# ==============================================================================