[torch] Fix tm_tensor.attention for end-to-end (#2907)

Some operations include a backend matcher for specialized operations. We
map these back to generics so they appropriately match to the high
performance versions. This is done for the attention operation.
pull/2909/head
Rob Suderman 2024-02-13 21:18:01 -08:00 committed by GitHub
parent d6e1d836ca
commit e9cdd6cbc5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 226 additions and 50 deletions

View File

@ -313,9 +313,6 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
int64_t getOutputRank() {
return getOutputType().getRank();
}
int64_t getIterationDomainRank() {
return 2;
};
// Method to implement for specifying output range for
// DestinationStyleOpInterface
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {

View File

@ -1600,27 +1600,82 @@ public:
"only default scale supported");
}
auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
auto query = adaptor.getQuery();
auto value = adaptor.getValue();
auto key = adaptor.getKey();
auto queryTy = cast<ShapedType>(query.getType());
auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType());
if (queryTy.getRank() != valueTy.getRank() ||
queryTy.getRank() != keyTy.getRank())
return rewriter.notifyMatchFailure(op, "operand ranks do not match");
if (queryTy.getRank() < 3)
return rewriter.notifyMatchFailure(op, "missing batch dimension");
llvm::SmallVector<ReassociationIndices, 3> reassociation(3);
for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i)
reassociation.front().push_back(i);
reassociation[1].push_back(valueTy.getRank() - 2);
reassociation[2].push_back(valueTy.getRank() - 1);
auto loc = op.getLoc();
auto collapseBatch = [&rewriter, &reassociation,
loc](Value value) -> Value {
auto valueTy = cast<ShapedType>(value.getType());
if (valueTy.getRank() == 3)
return value;
llvm::SmallVector<int64_t, 3> newShape(3, 1);
newShape[1] = valueTy.getDimSize(valueTy.getRank() - 2);
newShape[2] = valueTy.getDimSize(valueTy.getRank() - 1);
for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i) {
if (valueTy.isDynamicDim(i)) {
newShape[0] = ShapedType::kDynamic;
break;
}
newShape[0] = newShape[0] * valueTy.getDimSize(i);
}
auto collapseTy = valueTy.clone(newShape);
return rewriter.create<tensor::CollapseShapeOp>(loc, collapseTy, value,
reassociation);
};
query = collapseBatch(query);
key = collapseBatch(key);
value = collapseBatch(value);
SmallVector<int64_t> outSizes(
adaptor.getQuery().getType().cast<ShapedType>().getShape());
query.getType().cast<ShapedType>().getShape());
SmallVector<int64_t> valueSizes(
adaptor.getValue().getType().cast<ShapedType>().getShape());
value.getType().cast<ShapedType>().getShape());
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
SmallVector<Value> outSizesDynamic(
getTensorSizes(rewriter, op.getLoc(), adaptor.getQuery()));
outSizesDynamic[outSizesDynamic.size() - 1] = getTensorSizes(
rewriter, op.getLoc(), adaptor.getValue())[valueSizes.size() - 1];
getTensorSizes(rewriter, op.getLoc(), query));
outSizesDynamic[outSizesDynamic.size() - 1] =
getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1];
Type outType = RankedTensorType::get(outSizes, elementType);
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
elementType);
// Overwrite with tm_tensor::attention
auto attention = rewriter.create<AttentionOp>(
op.getLoc(), outType,
SmallVector<Value>{adaptor.getQuery(), adaptor.getKey(),
adaptor.getValue()},
SmallVector<Value>{output});
Value attention =
rewriter
.create<AttentionOp>(loc, outType,
SmallVector<Value>{query, key, value},
SmallVector<Value>{output})
.getResult()[0];
rewriter.replaceOp(op, attention.getResult());
if (opTy != outType) {
attention = rewriter.create<tensor::ExpandShapeOp>(loc, opTy, attention,
reassociation);
}
rewriter.replaceOp(op, attention);
return success();
}

View File

@ -94,31 +94,22 @@ LogicalResult AttentionOp::verify() {
ShapedType keyType = getKeyType();
ArrayRef<int64_t> queryShape = queryType.getShape();
ArrayRef<int64_t> keyShape = keyType.getShape();
if (keyShape[0] != queryShape[0])
return op->emitOpError("query and key batch mismatch");
if (keyShape[2] != queryShape[2])
for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
if (keyShape[i] != queryShape[i])
return op->emitOpError("query and key batch mismatch");
}
if (keyShape.back() != queryShape.back())
return op->emitOpError("query and key head dimension mismatch");
return success();
}
SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
int64_t iterationDomainRank = getIterationDomainRank();
SmallVector<Range> loopBounds(iterationDomainRank);
Location loc = getLoc();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
Value source = getQuery();
for (auto dim : llvm::seq<int64_t>(0, iterationDomainRank)) {
loopBounds[dim].offset = zero;
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
loopBounds[dim].stride = one;
}
SmallVector<Range> loopBounds;
return loopBounds;
}
SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
utils::IteratorType::parallel);
SmallVector<utils::IteratorType> iteratorTypes;
return iteratorTypes;
}
@ -189,6 +180,8 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
b.getFloatAttr(elementType, 0.0));
// TODO: This needs to be fixed, it assumes everything is dynamic however if
// any shapes are static the `memref.alloc` generated is illegal.
SmallVector<Value> queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes;
for (auto i = 0; i < queryRank; i++)
queryDynSizes.push_back(b.create<memref::DimOp>(loc, query, i));
@ -204,9 +197,18 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
auto weightSizes = SmallVector<int64_t>(queryType.getShape());
weightSizes[weightRank - 1] = keySizes[keyRank - 2];
auto weightType = MemRefType::get(weightSizes, queryType.getElementType());
// Setup the weight dynamic sizes:
SmallVector<Value> weightDynSizes(queryDynSizes);
weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2];
Value weight = b.create<memref::AllocOp>(loc, weightType, weightDynSizes);
SmallVector<Value> weightFilteredDynSizes;
for (int i = 0; i < weightRank; ++i)
if (weightSizes[i] == ShapedType::kDynamic)
weightFilteredDynSizes.push_back(weightDynSizes[i]);
Value weight =
b.create<memref::AllocOp>(loc, weightType, weightFilteredDynSizes);
matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes,
/*transposed=*/true);
@ -259,12 +261,17 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
x = b.create<math::ExpOp>(loc, x);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
});
llvm::SmallVector<Value> expWeightDynDims(weightFilteredDynSizes);
if (weightSizes.back() == ShapedType::kDynamic)
expWeightDynDims.resize(expWeightDynDims.size() - 1);
Value expWeightSum = b.create<memref::AllocOp>(
loc,
MemRefType::get(
SmallVector<int64_t>(weightSizes.begin(), weightSizes.end() - 1),
elementType),
SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1});
expWeightDynDims);
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank - 1, zero),
SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1},

View File

@ -189,6 +189,78 @@ private:
};
} // namespace
namespace {
class TorchMatchSpecializedBackendOp
: public OpConversionPattern<Torch::OperatorOp> {
public:
using OpConversionPattern::OpConversionPattern;
using HandlerFn = LogicalResult (*)(OperatorOp op,
ConversionPatternRewriter &rewriter);
LogicalResult
matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (namedHandlers.contains(op.getNameAttr())) {
return namedHandlers.lookup(op.getNameAttr()).front()(op, rewriter);
}
return failure();
}
static void
populateSpecializedConversions(TorchMatchSpecializedBackendOp &matcher);
static std::unique_ptr<TorchMatchSpecializedBackendOp>
getPopulatedMatcher(MLIRContext *context) {
auto matcher = std::make_unique<TorchMatchSpecializedBackendOp>(context);
populateSpecializedConversions(*matcher);
return matcher;
};
void populate(StringRef name, HandlerFn fn) {
namedHandlers[StringAttr::get(getContext(), name)].push_back(fn);
}
void populateLegalizedNames(llvm::DenseSet<StringAttr> &set) {
for (auto handle : namedHandlers) {
set.insert(handle.first);
}
}
private:
DenseMap<StringAttr, SmallVector<HandlerFn, 1>> namedHandlers;
};
void TorchMatchSpecializedBackendOp::populateSpecializedConversions(
TorchMatchSpecializedBackendOp &matcher) {
matcher.populate(
"torch.aten._scaled_dot_product_flash_attention_for_cpu",
[](Torch::OperatorOp op,
ConversionPatternRewriter &rewriter) -> LogicalResult {
auto uses = op.getResult(1).getUses();
if (uses.end() == uses.begin()) {
auto oldOperands = op->getOperands();
llvm::SmallVector<Value> newOperands{
oldOperands[0], oldOperands[1], oldOperands[2], oldOperands[5],
oldOperands[3], oldOperands[4], oldOperands[6]};
auto newOp = rewriter.create<Torch::AtenScaledDotProductAttentionOp>(
op.getLoc(), op->getResultTypes()[0], newOperands,
op->getAttrs());
rewriter.replaceAllUsesWith(op.getResult(0), newOp.getResult());
rewriter.eraseOp(op);
return success();
}
return failure();
});
}
bool isSpecializedOperation(Torch::OperatorOp op) { return true; }
} // namespace
// Reduce Ops without value semantics but the corresponding without trailing
// underscore variant doesn't exist.
namespace {
@ -353,12 +425,24 @@ struct ReduceOpVariantsPass
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
patterns.add<ReduceNonValueSemanticOps>(context);
// Create specialized matcher:
auto specialized =
TorchMatchSpecializedBackendOp::getPopulatedMatcher(context);
DenseSet<StringAttr> specializedNames;
specialized->populateLegalizedNames(specializedNames);
patterns.insert(std::move(specialized));
ConversionTarget target(*context);
target.addIllegalOp<NonValueTensorLiteralOp>();
target.addIllegalOp<AtenBernoulli_FloatOp>();
target.addIllegalOp<AtenArangeStartOutOp>();
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
Operation *op) {
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable,
&specializedNames](Operation *op) {
if (isa<OperatorOp>(op)) {
if (specializedNames.contains(cast<OperatorOp>(op).getNameAttr())) {
return false;
}
}
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
(isa<OperatorOp>(op) &&
operatorOpHasValueSemantics(cast<OperatorOp>(op),
@ -377,6 +461,9 @@ struct ReduceOpVariantsPass
if (op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>()) {
return false;
}
if (isa<OperatorOp>(op) && isSpecializedOperation(cast<OperatorOp>(op)))
return false;
return true;
});

View File

@ -303,8 +303,7 @@ TORCHDYNAMO_XFAIL_SET = {
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
# Exception: Unsupported: node.meta['val'] is not a FakeTensor or list of FakeTensor's: _scaled_dot_product_flash_attention;
"ScaledDotProductAttentionSameModule_basic",
# AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
"ScaledDotProductAttentionDifferentModule_basic",
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only

View File

@ -202,7 +202,6 @@ class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
An opaque, backend specific compiled artifact object that can be
passed to `load`.
"""
run_pipeline_with_repro_report(
imported_module, LOWERING_PIPELINE,
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend",

View File

@ -4517,18 +4517,18 @@ class ScaledDotProductAttentionSameModule(torch.nn.Module):
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True)
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True)
])
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule())
def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
query = torch.randn(1, 1, 5, 5, dtype=torch.float32)
key = torch.randn(1, 1, 5, 5, dtype=torch.float32)
value = torch.randn(1, 1, 5, 5, dtype=torch.float32)
query = torch.randn(1, 5, 5, dtype=torch.float32)
key = torch.randn(1, 5, 5, dtype=torch.float32)
value = torch.randn(1, 5, 5, dtype=torch.float32)
module.forward(query, key, value)
class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
@ -4539,18 +4539,18 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True)
([2, 3, 8, 4], torch.float32, True),
([2, 3, 16, 4], torch.float32, True),
([2, 3, 16, 4], torch.float32, True)
])
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
query = torch.randn(3, 2, 8, 4, dtype=torch.float32)
key = torch.randn(3, 2, 16, 4, dtype=torch.float32)
value = torch.randn(3, 2, 16, 4, dtype=torch.float32)
query = torch.randn(2, 3, 8, 4, dtype=torch.float32)
key = torch.randn(2, 3, 16, 4, dtype=torch.float32)
value = torch.randn(2, 3, 16, 4, dtype=torch.float32)
module.forward(query, key, value)
# ==============================================================================

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s
// RUN: torch-mlir-opt -torch-reduce-op-variants --split-input-file %s | FileCheck %s
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
@ -11,6 +11,8 @@ func.func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !t
return %0 : !torch.tensor<[],f32>
}
// -----
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_list(
// CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor {
@ -40,6 +42,8 @@ func.func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !t
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>,
// CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool,
@ -83,6 +87,8 @@ func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
return %ret: !torch.tensor
}
// -----
// CHECK-LABEL: func.func @reduce_trailing_underscore_inplace_variant(
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
@ -106,6 +112,7 @@ func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2]
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32>
return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
}
// -----
// CHECK-LABEL: func.func @torch.tensor.literal() -> !torch.tensor {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
@ -117,6 +124,8 @@ func.func @torch.tensor.literal() -> !torch.tensor {
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list(
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
// CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
@ -134,6 +143,8 @@ func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
// CHECK-SAME: %[[INDICES_0:.*]]: !torch.tensor<[2,3],si64>,
@ -155,6 +166,8 @@ func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(%se
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.bernoulli_.float(
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[GENERATOR:.*]] = torch.constant.none
@ -171,3 +184,22 @@ func.func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
%ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @scaled_dot_product_flash_attention_for_cpu
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG1:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG2:.+]]: !torch.vtensor<[1,1,5,5],f32>
// CHECK: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE0:.+]] = torch.constant.none
// CHECK: %[[NONE1:.+]] = torch.constant.none
// CHECK: %[[ATTEN:.+]] = torch.aten.scaled_dot_product_attention %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[NONE0]], %[[ZERO]], %[[FALSE]], %[[NONE1]]
// CHECK: return %[[ATTEN]]
func.func @scaled_dot_product_flash_attention_for_cpu(%arg0: !torch.vtensor<[1,1,5,5],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> {
%float0.000000e00 = torch.constant.float 0.000000e+00
%false = torch.constant.bool false
%none = torch.constant.none
%none_0 = torch.constant.none
%0:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%arg0, %arg1, %arg2, %float0.000000e00, %false, %none, %none_0) : (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5],f32>)
return %0#0 : !torch.vtensor<[1,1,5,5],f32>
}