Add TMTensor::Attention and lower ScaledDotProductAttentionOp to it (#2027)

pull/2153/head snapshot-20230517.841
gpetters94 2023-05-16 15:17:45 -04:00 committed by GitHub
parent c76a48308e
commit 0302cf1d92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 491 additions and 2 deletions

View File

@ -244,6 +244,90 @@ def TMTensor_SortOp : TMTensor_Op<"sort",
}]; }];
} }
def TMTensor_AttentionOp : TMTensor_Op<"attention",
[DeclareOpInterfaceMethods<TMTensorInterface,
["payloadUsesValueFromOperand"]>,
DeclareOpInterfaceMethods<ScalarLoopOpInterface,
["generateScalarImplementation"]>]> {
let summary = "Attention operator";
let description = [{
This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
the attention. Each of the inputs has shape BxNxd where B is the
of the batch dimension, N is the sequence length and d is head dimension.
Typically N >>> d. Mathematically, the attention is defined as
matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
this operator also performs scaling, masking and dropout, but we leave
that out of the current implementation.
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs
);
let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)>
];
let results = (outs Variadic<AnyRankedTensor>:$result);
let assemblyFormat = [{
attr-dict
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
}];
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
Value getQuery() {
return getInputOperand(0)->get();
}
Value getKey() {
return getInputOperand(1)->get();
}
Value getValue() {
return getInputOperand(2)->get();
}
Value getOutput() {
return getOutputOperand(0)->get();
}
ShapedType getQueryType() {
return getQuery().getType().cast<ShapedType>();
}
ShapedType getKeyType() {
return getKey().getType().cast<ShapedType>();
}
ShapedType getValueType() {
return getValue().getType().cast<ShapedType>();
}
ShapedType getOutputType() {
return getOutput().getType().cast<ShapedType>();
}
int64_t getQueryRank() {
return getQueryType().getRank();
}
int64_t getKeyRank() {
return getKeyType().getRank();
}
int64_t getValueRank() {
return getValueType().getRank();
}
int64_t getOutputRank() {
return getOutputType().getRank();
}
int64_t getIterationDomainRank() {
return 2;
};
// Method to implement for specifying output range for
// DestinationStyleOpInterface
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
std::pair<unsigned, unsigned> outputsIndexAndLength =
getODSOperandIndexAndLength(1);
return std::make_pair<int64_t, int64_t>(
outputsIndexAndLength.first,
outputsIndexAndLength.first + outputsIndexAndLength.second);
}
}];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pure ops // Pure ops
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -84,6 +84,243 @@ OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
return builder.getI64IntegerAttr(t.getDimSize(dim)); return builder.getI64IntegerAttr(t.getDimSize(dim));
} }
//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//
LogicalResult AttentionOp::verify() {
Operation *op = getOperation();
ShapedType queryType = getQueryType();
ShapedType keyType = getKeyType();
ShapedType valueType = getValueType();
ShapedType outputType = getOutputType();
ArrayRef<int64_t> queryShape = queryType.getShape();
ArrayRef<int64_t> keyShape = keyType.getShape();
ArrayRef<int64_t> valueShape = valueType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(queryShape, keyShape)))
return op->emitOpError("incompatible key shape");
if (failed(verifyCompatibleShape(queryShape, valueShape)))
return op->emitOpError("incompatible value shape");
if (failed(verifyCompatibleShape(queryShape, outputShape)))
return op->emitOpError("incompatible output shape");
return success();
}
SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
int64_t iterationDomainRank = getIterationDomainRank();
SmallVector<Range> loopBounds(iterationDomainRank);
Location loc = getLoc();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
Value source = getQuery();
for (auto dim : llvm::seq<int64_t>(0, iterationDomainRank)) {
loopBounds[dim].offset = zero;
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
loopBounds[dim].stride = one;
}
return loopBounds;
}
SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
utils::IteratorType::parallel);
return iteratorTypes;
}
bool AttentionOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
Value operand = opOperand->get();
return operand == getQuery() || operand == getKey() || operand == getValue();
}
// Performs a matmul between lhs and rhs
// Note that "transposed" means the last two dims of rhs are swapped
static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
Value rhs, ValueRange rhsSizes, Value output,
ValueRange outputSizes, bool transposed = false) {
auto elementType = lhs.getType().cast<MemRefType>().getElementType();
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
auto rank = outputSizes.size();
Value reductionDimSize = lhsSizes[lhsSizes.size() - 1];
// Loop over output
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(rank, zero), outputSizes,
SmallVector<Value>(rank, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value acc = b.create<arith::ConstantOp>(
loc, elementType, b.getFloatAttr(elementType, 0.0));
Value sum =
b.create<scf::ForOp>(
loc, zero, reductionDimSize, one, SmallVector<Value>{acc},
[&](OpBuilder &b, Location loc, Value i, ValueRange accs) {
SmallVector<Value> lhsIVs(localIVs), rhsIVs(localIVs);
lhsIVs[lhsIVs.size() - 1] = i;
rhsIVs[rhsIVs.size() - 2] = i;
if (transposed)
std::swap(rhsIVs[rhsIVs.size() - 1],
rhsIVs[rhsIVs.size() - 2]);
Value acc = accs[0];
Value rElem = b.create<memref::LoadOp>(loc, lhs, lhsIVs);
Value cElem = b.create<memref::LoadOp>(loc, rhs, rhsIVs);
Value x = b.create<arith::MulFOp>(loc, rElem, cElem);
x = b.create<arith::AddFOp>(loc, x, acc);
b.create<scf::YieldOp>(loc, x);
})
->getResult(0);
b.create<memref::StoreOp>(loc, sum, output, localIVs);
b.create<scf::YieldOp>(loc);
});
}
LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Location loc,
ValueRange ivs) {
Value query = getQuery();
Value key = getKey();
Value value = getValue();
Value output = getOutput();
auto queryType = query.getType().cast<MemRefType>();
auto keyType = key.getType().cast<MemRefType>();
auto valueType = value.getType().cast<MemRefType>();
auto queryRank = queryType.getRank();
auto keyRank = keyType.getRank();
auto valueRank = valueType.getRank();
auto keySizes = keyType.getShape();
Type elementType = queryType.getElementType();
Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
b.getFloatAttr(elementType, 0.0));
SmallVector<Value> queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes;
for (auto i = 0; i < queryRank; i++)
queryDynSizes.push_back(b.create<memref::DimOp>(loc, query, i));
for (auto i = 0; i < keyRank; i++)
keyDynSizes.push_back(b.create<memref::DimOp>(loc, key, i));
for (auto i = 0; i < valueRank; i++)
valueDynSizes.push_back(b.create<memref::DimOp>(loc, value, i));
for (auto i = 0; i < queryRank; i++)
outputDynSizes.push_back(b.create<memref::DimOp>(loc, output, i));
// weight = query @ key
auto weightRank = queryRank;
auto weightSizes = SmallVector<int64_t>(queryType.getShape());
weightSizes[weightRank - 1] = keySizes[keyRank - 2];
auto weightType = MemRefType::get(weightSizes, queryType.getElementType());
SmallVector<Value> weightDynSizes(queryDynSizes);
weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2];
Value weight = b.create<memref::AllocOp>(loc, weightType, weightDynSizes);
matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes,
/*transposed=*/true);
// weight = softmax(weight)
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value dim = weightDynSizes[weightRank - 1];
Value scaleFactor = b.create<math::SqrtOp>(
loc, b.create<arith::UIToFPOp>(
loc, elementType,
b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
queryDynSizes[queryRank - 1])));
// calculate max(weight)
Value init = b.create<memref::LoadOp>(loc, weight,
SmallVector<Value>(weightRank, zero));
Value globalMax =
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
SmallVector<Value>(weightRank, one), init,
[&](OpBuilder &b, Location loc, ValueRange localIVs,
ValueRange accs) {
b.create<scf::ReduceOp>(
loc, init,
[&](OpBuilder &b, Location loc, Value elem, Value acc) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
Value max = b.create<arith::MaxFOp>(loc, x, acc);
b.create<scf::ReduceReturnOp>(loc, max);
});
})
.getResult(0);
// weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
SmallVector<Value>(weightRank, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<arith::SubFOp>(loc, x, globalMax);
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
b.create<scf::YieldOp>(loc);
});
// calculate exp(weight)
SmallVector<Value> min(weightRank, zero),
max(weightDynSizes.begin(), weightDynSizes.end()), steps(weightRank, one);
b.create<scf::ParallelOp>(
loc, min, max, steps,
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<math::ExpOp>(loc, x);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
b.create<scf::YieldOp>(loc);
});
Value expWeightSum = b.create<memref::AllocOp>(
loc,
MemRefType::get(
SmallVector<int64_t>(weightSizes.begin(), weightSizes.end() - 1),
elementType),
SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1});
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank - 1, zero),
SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1},
SmallVector<Value>(weightRank - 1, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
b.create<memref::StoreOp>(loc, zeroF, expWeightSum, localIVs);
});
// Loop over all dims but -1
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank - 1, zero),
SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end() - 1),
SmallVector<Value>(weightRank - 1, one),
[&](OpBuilder &b, Location loc, ValueRange outsideDims) {
// Sum over last dim
b.create<scf::ParallelOp>(
loc, zero, dim, one,
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
SmallVector<Value> coords(outsideDims);
coords.push_back(localIVs[0]);
Value x =
b.create<memref::LoadOp>(loc, expWeightSum, outsideDims);
Value y = b.create<memref::LoadOp>(loc, weight, coords);
Value sum = b.create<arith::AddFOp>(loc, x, y);
b.create<memref::StoreOp>(loc, sum, expWeightSum, outsideDims);
b.create<scf::YieldOp>(loc);
});
});
// calculate exp(weight) / sum(exp(weight))
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank, zero),
SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end()),
SmallVector<Value>(weightRank, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
SmallVector<Value> sumIVs(localIVs);
sumIVs.pop_back();
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
x = b.create<arith::DivFOp>(loc, x, sum);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
b.create<scf::YieldOp>(loc);
});
// output = weight @ value
matmul(b, loc, weight, weightDynSizes, value, valueDynSizes, output,
outputDynSizes, /*transposed=*/false);
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ScanOp // ScanOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -652,6 +889,7 @@ bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
outputBuffers); \ outputBuffers); \
} }
DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(ScanOp) DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(ScatterOp) DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp) DEFINE_OP_GET_EFFECTS(SortOp)

View File

@ -9040,6 +9040,35 @@ def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [
}]; }];
} }
def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$query,
AnyTorchTensorType:$key,
AnyTorchTensorType:$value,
AnyTorchOptionalTensorType:$attn_mask,
Torch_FloatType:$dropout_p,
Torch_BoolType:$is_causal,
AnyTorchOptionalFloatType:$scale
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScaledDotProductAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenScaledDotProductAttentionOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [ def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -13,10 +13,13 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/ValueRange.h" #include "mlir/IR/ValueRange.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
@ -1494,6 +1497,68 @@ public:
}; };
} // namespace } // namespace
namespace {
class ConvertAtenScaledDotProductAttentionOp
: public OpConversionPattern<AtenScaledDotProductAttentionOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value mask = op.getAttnMask();
Value dropoutP = op.getDropoutP();
Value isCausal = op.getIsCausal();
Value scale = op.getScale();
Type elementType =
adaptor.getQuery().getType().cast<ShapedType>().getElementType();
// Verify inputs (only support defaults)
if (!mask.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op.getLoc(),
"attention masking not supported");
double dropout;
if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
dropout > 0.0)
return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
bool causal;
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
return rewriter.notifyMatchFailure(
op.getLoc(), "causal attention masking not supported");
if (!scale.getType().isa<Torch::NoneType>()) {
double scaleFloat;
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
scaleFloat != 1.0)
return rewriter.notifyMatchFailure(op.getLoc(),
"only default scale supported");
}
SmallVector<int64_t> outSizes(
adaptor.getQuery().getType().cast<ShapedType>().getShape());
SmallVector<int64_t> valueSizes(
adaptor.getValue().getType().cast<ShapedType>().getShape());
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
SmallVector<Value> outSizesDynamic(
getTensorSizes(rewriter, op.getLoc(), adaptor.getQuery()));
outSizesDynamic[outSizesDynamic.size() - 1] = getTensorSizes(
rewriter, op.getLoc(), adaptor.getValue())[valueSizes.size() - 1];
Type outType = RankedTensorType::get(outSizes, elementType);
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
elementType);
// Overwrite with tm_tensor::attention
auto attention = rewriter.create<AttentionOp>(
op.getLoc(), outType,
SmallVector<Value>{adaptor.getQuery(), adaptor.getKey(),
adaptor.getValue()},
SmallVector<Value>{output});
rewriter.replaceOp(op, attention.getResult());
return success();
}
};
} // namespace
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// The pass // The pass
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -1516,7 +1581,8 @@ public:
ConversionTarget target(*context); ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect, target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
tensor::TensorDialect, arith::ArithDialect, tensor::TensorDialect, arith::ArithDialect,
Torch::TorchDialect, TMTensorDialect>(); math::MathDialect, Torch::TorchDialect,
TMTensorDialect>();
TypeConverter typeConverter; TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
@ -1536,6 +1602,9 @@ public:
patterns.add<ConvertAtenSortOp>(typeConverter, context); patterns.add<ConvertAtenSortOp>(typeConverter, context);
target.addIllegalOp<AtenCumsumOp>(); target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenCumsumOp>(typeConverter, context); patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
target.addIllegalOp<AtenScaledDotProductAttentionOp>();
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
context);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) std::move(patterns))))

View File

@ -6867,6 +6867,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n"
" %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten._set_item.t %arg0, %int-1, %0 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.zeros\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.zeros\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
@ -8559,6 +8565,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int11 = torch.constant.int 11\n" " %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n" " return %int11 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n" " %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n" " return %int11 : !torch.int\n"

View File

@ -522,6 +522,15 @@ def atenflattenusing_ints〡shape(self: List[int], start_dim: int = 0, end
def atenlinear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]: def atenlinear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
return upstream_shape_functions.linear(input, weight, bias) return upstream_shape_functions.linear(input, weight, bias)
@check_shape_function([
Invocation(TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Same shape
Invocation(TensorOfShape(3, 2, 16, 8), TensorOfShape(3, 2, 8, 8), TensorOfShape(3, 2, 8, 4)), # Different shape
])
def atenscaled_dot_product_attention〡shape(query: List[int], key: List[int], value: List[int], attn_mask: Optional[List[int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> List[int]:
outshape = query
outshape[-1] = value[-1]
return outshape
@check_shape_function([ @check_shape_function([
Invocation([2, 3]), Invocation([2, 3]),
]) ])
@ -1904,6 +1913,11 @@ def atenlogical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
def atenlogical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def atenlogical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.bool return torch.bool
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)]))
def atenscaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int:
_, query_dtype = query_rank_dtype
return query_dtype
@check_dtype_function(_check_two_tensor_op()) @check_dtype_function(_check_two_tensor_op())
def atenlogical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: def atenlogical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
return torch.bool return torch.bool

View File

@ -565,7 +565,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)") emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)") emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)")
emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)")
# Dict ops. # Dict ops.
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True) emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)

View File

@ -3695,6 +3695,51 @@ class MoveDimIntNegativeIndexModule(torch.nn.Module):
def MoveDimIntNegativeIndexModule_basic(module, tu: TestUtils): def MoveDimIntNegativeIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 2)) module.forward(tu.rand(3, 4, 2))
# ==============================================================================
class ScaledDotProductAttentionSameModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True)
])
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule())
def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
query = torch.randn(1, 1, 5, 5, dtype=torch.float32)
key = torch.randn(1, 1, 5, 5, dtype=torch.float32)
value = torch.randn(1, 1, 5, 5, dtype=torch.float32)
module.forward(query, key, value)
class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True)
])
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
query = torch.randn(3, 2, 8, 4, dtype=torch.float32)
key = torch.randn(3, 2, 16, 4, dtype=torch.float32)
value = torch.randn(3, 2, 16, 4, dtype=torch.float32)
module.forward(query, key, value)
# ============================================================================== # ==============================================================================