mirror of https://github.com/llvm/torch-mlir
[torch] Fix tm_tensor.attention for end-to-end (#2907)
Some operations include a backend matcher for specialized operations. We map these back to their generic forms so that they match the high-performance versions. This is done for the attention operation.
pull/2909/head
parent
d6e1d836ca
commit
e9cdd6cbc5
|
@ -313,9 +313,6 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
||||||
int64_t getOutputRank() {
|
int64_t getOutputRank() {
|
||||||
return getOutputType().getRank();
|
return getOutputType().getRank();
|
||||||
}
|
}
|
||||||
int64_t getIterationDomainRank() {
|
|
||||||
return 2;
|
|
||||||
};
|
|
||||||
// Method to implement for specifying output range for
|
// Method to implement for specifying output range for
|
||||||
// DestinationStyleOpInterface
|
// DestinationStyleOpInterface
|
||||||
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
|
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
|
||||||
|
|
|
@ -1600,27 +1600,82 @@ public:
|
||||||
"only default scale supported");
|
"only default scale supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
|
||||||
|
auto query = adaptor.getQuery();
|
||||||
|
auto value = adaptor.getValue();
|
||||||
|
auto key = adaptor.getKey();
|
||||||
|
auto queryTy = cast<ShapedType>(query.getType());
|
||||||
|
auto valueTy = cast<ShapedType>(value.getType());
|
||||||
|
auto keyTy = cast<ShapedType>(key.getType());
|
||||||
|
|
||||||
|
if (queryTy.getRank() != valueTy.getRank() ||
|
||||||
|
queryTy.getRank() != keyTy.getRank())
|
||||||
|
return rewriter.notifyMatchFailure(op, "operand ranks do not match");
|
||||||
|
|
||||||
|
if (queryTy.getRank() < 3)
|
||||||
|
return rewriter.notifyMatchFailure(op, "missing batch dimension");
|
||||||
|
|
||||||
|
llvm::SmallVector<ReassociationIndices, 3> reassociation(3);
|
||||||
|
for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i)
|
||||||
|
reassociation.front().push_back(i);
|
||||||
|
reassociation[1].push_back(valueTy.getRank() - 2);
|
||||||
|
reassociation[2].push_back(valueTy.getRank() - 1);
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto collapseBatch = [&rewriter, &reassociation,
|
||||||
|
loc](Value value) -> Value {
|
||||||
|
auto valueTy = cast<ShapedType>(value.getType());
|
||||||
|
if (valueTy.getRank() == 3)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t, 3> newShape(3, 1);
|
||||||
|
newShape[1] = valueTy.getDimSize(valueTy.getRank() - 2);
|
||||||
|
newShape[2] = valueTy.getDimSize(valueTy.getRank() - 1);
|
||||||
|
|
||||||
|
for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i) {
|
||||||
|
if (valueTy.isDynamicDim(i)) {
|
||||||
|
newShape[0] = ShapedType::kDynamic;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
newShape[0] = newShape[0] * valueTy.getDimSize(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto collapseTy = valueTy.clone(newShape);
|
||||||
|
return rewriter.create<tensor::CollapseShapeOp>(loc, collapseTy, value,
|
||||||
|
reassociation);
|
||||||
|
};
|
||||||
|
|
||||||
|
query = collapseBatch(query);
|
||||||
|
key = collapseBatch(key);
|
||||||
|
value = collapseBatch(value);
|
||||||
|
|
||||||
SmallVector<int64_t> outSizes(
|
SmallVector<int64_t> outSizes(
|
||||||
adaptor.getQuery().getType().cast<ShapedType>().getShape());
|
query.getType().cast<ShapedType>().getShape());
|
||||||
SmallVector<int64_t> valueSizes(
|
SmallVector<int64_t> valueSizes(
|
||||||
adaptor.getValue().getType().cast<ShapedType>().getShape());
|
value.getType().cast<ShapedType>().getShape());
|
||||||
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
|
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
|
||||||
SmallVector<Value> outSizesDynamic(
|
SmallVector<Value> outSizesDynamic(
|
||||||
getTensorSizes(rewriter, op.getLoc(), adaptor.getQuery()));
|
getTensorSizes(rewriter, op.getLoc(), query));
|
||||||
outSizesDynamic[outSizesDynamic.size() - 1] = getTensorSizes(
|
outSizesDynamic[outSizesDynamic.size() - 1] =
|
||||||
rewriter, op.getLoc(), adaptor.getValue())[valueSizes.size() - 1];
|
getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1];
|
||||||
Type outType = RankedTensorType::get(outSizes, elementType);
|
Type outType = RankedTensorType::get(outSizes, elementType);
|
||||||
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
|
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
|
||||||
elementType);
|
elementType);
|
||||||
|
|
||||||
// Overwrite with tm_tensor::attention
|
// Overwrite with tm_tensor::attention
|
||||||
auto attention = rewriter.create<AttentionOp>(
|
Value attention =
|
||||||
op.getLoc(), outType,
|
rewriter
|
||||||
SmallVector<Value>{adaptor.getQuery(), adaptor.getKey(),
|
.create<AttentionOp>(loc, outType,
|
||||||
adaptor.getValue()},
|
SmallVector<Value>{query, key, value},
|
||||||
SmallVector<Value>{output});
|
SmallVector<Value>{output})
|
||||||
|
.getResult()[0];
|
||||||
|
|
||||||
rewriter.replaceOp(op, attention.getResult());
|
if (opTy != outType) {
|
||||||
|
attention = rewriter.create<tensor::ExpandShapeOp>(loc, opTy, attention,
|
||||||
|
reassociation);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, attention);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,31 +94,22 @@ LogicalResult AttentionOp::verify() {
|
||||||
ShapedType keyType = getKeyType();
|
ShapedType keyType = getKeyType();
|
||||||
ArrayRef<int64_t> queryShape = queryType.getShape();
|
ArrayRef<int64_t> queryShape = queryType.getShape();
|
||||||
ArrayRef<int64_t> keyShape = keyType.getShape();
|
ArrayRef<int64_t> keyShape = keyType.getShape();
|
||||||
if (keyShape[0] != queryShape[0])
|
for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
|
||||||
return op->emitOpError("query and key batch mismatch");
|
if (keyShape[i] != queryShape[i])
|
||||||
if (keyShape[2] != queryShape[2])
|
return op->emitOpError("query and key batch mismatch");
|
||||||
|
}
|
||||||
|
if (keyShape.back() != queryShape.back())
|
||||||
return op->emitOpError("query and key head dimension mismatch");
|
return op->emitOpError("query and key head dimension mismatch");
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
|
SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
|
||||||
int64_t iterationDomainRank = getIterationDomainRank();
|
SmallVector<Range> loopBounds;
|
||||||
SmallVector<Range> loopBounds(iterationDomainRank);
|
|
||||||
Location loc = getLoc();
|
|
||||||
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
|
||||||
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
|
|
||||||
Value source = getQuery();
|
|
||||||
for (auto dim : llvm::seq<int64_t>(0, iterationDomainRank)) {
|
|
||||||
loopBounds[dim].offset = zero;
|
|
||||||
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
|
|
||||||
loopBounds[dim].stride = one;
|
|
||||||
}
|
|
||||||
return loopBounds;
|
return loopBounds;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
|
SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
|
||||||
SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
|
SmallVector<utils::IteratorType> iteratorTypes;
|
||||||
utils::IteratorType::parallel);
|
|
||||||
return iteratorTypes;
|
return iteratorTypes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -189,6 +180,8 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
||||||
Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
|
Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
|
||||||
b.getFloatAttr(elementType, 0.0));
|
b.getFloatAttr(elementType, 0.0));
|
||||||
|
|
||||||
|
// TODO: This needs to be fixed, it assumes everything is dynamic however if
|
||||||
|
// any shapes are static the `memref.alloc` generated is illegal.
|
||||||
SmallVector<Value> queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes;
|
SmallVector<Value> queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes;
|
||||||
for (auto i = 0; i < queryRank; i++)
|
for (auto i = 0; i < queryRank; i++)
|
||||||
queryDynSizes.push_back(b.create<memref::DimOp>(loc, query, i));
|
queryDynSizes.push_back(b.create<memref::DimOp>(loc, query, i));
|
||||||
|
@ -204,9 +197,18 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
||||||
auto weightSizes = SmallVector<int64_t>(queryType.getShape());
|
auto weightSizes = SmallVector<int64_t>(queryType.getShape());
|
||||||
weightSizes[weightRank - 1] = keySizes[keyRank - 2];
|
weightSizes[weightRank - 1] = keySizes[keyRank - 2];
|
||||||
auto weightType = MemRefType::get(weightSizes, queryType.getElementType());
|
auto weightType = MemRefType::get(weightSizes, queryType.getElementType());
|
||||||
|
|
||||||
|
// Setup the weight dynamic sizes:
|
||||||
SmallVector<Value> weightDynSizes(queryDynSizes);
|
SmallVector<Value> weightDynSizes(queryDynSizes);
|
||||||
weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2];
|
weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2];
|
||||||
Value weight = b.create<memref::AllocOp>(loc, weightType, weightDynSizes);
|
|
||||||
|
SmallVector<Value> weightFilteredDynSizes;
|
||||||
|
for (int i = 0; i < weightRank; ++i)
|
||||||
|
if (weightSizes[i] == ShapedType::kDynamic)
|
||||||
|
weightFilteredDynSizes.push_back(weightDynSizes[i]);
|
||||||
|
|
||||||
|
Value weight =
|
||||||
|
b.create<memref::AllocOp>(loc, weightType, weightFilteredDynSizes);
|
||||||
matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes,
|
matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes,
|
||||||
/*transposed=*/true);
|
/*transposed=*/true);
|
||||||
|
|
||||||
|
@ -259,12 +261,17 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
||||||
x = b.create<math::ExpOp>(loc, x);
|
x = b.create<math::ExpOp>(loc, x);
|
||||||
b.create<memref::StoreOp>(loc, x, weight, localIVs);
|
b.create<memref::StoreOp>(loc, x, weight, localIVs);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
llvm::SmallVector<Value> expWeightDynDims(weightFilteredDynSizes);
|
||||||
|
if (weightSizes.back() == ShapedType::kDynamic)
|
||||||
|
expWeightDynDims.resize(expWeightDynDims.size() - 1);
|
||||||
|
|
||||||
Value expWeightSum = b.create<memref::AllocOp>(
|
Value expWeightSum = b.create<memref::AllocOp>(
|
||||||
loc,
|
loc,
|
||||||
MemRefType::get(
|
MemRefType::get(
|
||||||
SmallVector<int64_t>(weightSizes.begin(), weightSizes.end() - 1),
|
SmallVector<int64_t>(weightSizes.begin(), weightSizes.end() - 1),
|
||||||
elementType),
|
elementType),
|
||||||
SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1});
|
expWeightDynDims);
|
||||||
b.create<scf::ParallelOp>(
|
b.create<scf::ParallelOp>(
|
||||||
loc, SmallVector<Value>(weightRank - 1, zero),
|
loc, SmallVector<Value>(weightRank - 1, zero),
|
||||||
SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1},
|
SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1},
|
||||||
|
|
|
@ -189,6 +189,78 @@ private:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class TorchMatchSpecializedBackendOp
|
||||||
|
: public OpConversionPattern<Torch::OperatorOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
using HandlerFn = LogicalResult (*)(OperatorOp op,
|
||||||
|
ConversionPatternRewriter &rewriter);
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
if (namedHandlers.contains(op.getNameAttr())) {
|
||||||
|
return namedHandlers.lookup(op.getNameAttr()).front()(op, rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
populateSpecializedConversions(TorchMatchSpecializedBackendOp &matcher);
|
||||||
|
|
||||||
|
static std::unique_ptr<TorchMatchSpecializedBackendOp>
|
||||||
|
getPopulatedMatcher(MLIRContext *context) {
|
||||||
|
auto matcher = std::make_unique<TorchMatchSpecializedBackendOp>(context);
|
||||||
|
populateSpecializedConversions(*matcher);
|
||||||
|
return matcher;
|
||||||
|
};
|
||||||
|
|
||||||
|
void populate(StringRef name, HandlerFn fn) {
|
||||||
|
namedHandlers[StringAttr::get(getContext(), name)].push_back(fn);
|
||||||
|
}
|
||||||
|
|
||||||
|
void populateLegalizedNames(llvm::DenseSet<StringAttr> &set) {
|
||||||
|
for (auto handle : namedHandlers) {
|
||||||
|
set.insert(handle.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DenseMap<StringAttr, SmallVector<HandlerFn, 1>> namedHandlers;
|
||||||
|
};
|
||||||
|
|
||||||
|
void TorchMatchSpecializedBackendOp::populateSpecializedConversions(
|
||||||
|
TorchMatchSpecializedBackendOp &matcher) {
|
||||||
|
matcher.populate(
|
||||||
|
"torch.aten._scaled_dot_product_flash_attention_for_cpu",
|
||||||
|
[](Torch::OperatorOp op,
|
||||||
|
ConversionPatternRewriter &rewriter) -> LogicalResult {
|
||||||
|
auto uses = op.getResult(1).getUses();
|
||||||
|
if (uses.end() == uses.begin()) {
|
||||||
|
auto oldOperands = op->getOperands();
|
||||||
|
llvm::SmallVector<Value> newOperands{
|
||||||
|
oldOperands[0], oldOperands[1], oldOperands[2], oldOperands[5],
|
||||||
|
oldOperands[3], oldOperands[4], oldOperands[6]};
|
||||||
|
|
||||||
|
auto newOp = rewriter.create<Torch::AtenScaledDotProductAttentionOp>(
|
||||||
|
op.getLoc(), op->getResultTypes()[0], newOperands,
|
||||||
|
op->getAttrs());
|
||||||
|
rewriter.replaceAllUsesWith(op.getResult(0), newOp.getResult());
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isSpecializedOperation(Torch::OperatorOp op) { return true; }
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Reduce Ops without value semantics but the corresponding without trailing
|
// Reduce Ops without value semantics but the corresponding without trailing
|
||||||
// underscore variant doesn't exist.
|
// underscore variant doesn't exist.
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -353,12 +425,24 @@ struct ReduceOpVariantsPass
|
||||||
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
|
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
|
||||||
patterns.add<ReduceNonValueSemanticOps>(context);
|
patterns.add<ReduceNonValueSemanticOps>(context);
|
||||||
|
|
||||||
|
// Create specialized matcher:
|
||||||
|
auto specialized =
|
||||||
|
TorchMatchSpecializedBackendOp::getPopulatedMatcher(context);
|
||||||
|
DenseSet<StringAttr> specializedNames;
|
||||||
|
specialized->populateLegalizedNames(specializedNames);
|
||||||
|
patterns.insert(std::move(specialized));
|
||||||
|
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||||
target.addIllegalOp<AtenArangeStartOutOp>();
|
target.addIllegalOp<AtenArangeStartOutOp>();
|
||||||
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
|
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable,
|
||||||
Operation *op) {
|
&specializedNames](Operation *op) {
|
||||||
|
if (isa<OperatorOp>(op)) {
|
||||||
|
if (specializedNames.contains(cast<OperatorOp>(op).getNameAttr())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
|
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
|
||||||
(isa<OperatorOp>(op) &&
|
(isa<OperatorOp>(op) &&
|
||||||
operatorOpHasValueSemantics(cast<OperatorOp>(op),
|
operatorOpHasValueSemantics(cast<OperatorOp>(op),
|
||||||
|
@ -377,6 +461,9 @@ struct ReduceOpVariantsPass
|
||||||
if (op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>()) {
|
if (op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isa<OperatorOp>(op) && isSpecializedOperation(cast<OperatorOp>(op)))
|
||||||
|
return false;
|
||||||
return true;
|
return true;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -303,8 +303,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||||
|
|
||||||
# Exception: Unsupported: node.meta['val'] is not a FakeTensor or list of FakeTensor's: _scaled_dot_product_flash_attention;
|
# AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
|
||||||
"ScaledDotProductAttentionSameModule_basic",
|
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
|
||||||
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
|
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
|
||||||
|
|
|
@ -202,7 +202,6 @@ class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
||||||
An opaque, backend specific compiled artifact object that can be
|
An opaque, backend specific compiled artifact object that can be
|
||||||
passed to `load`.
|
passed to `load`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
imported_module, LOWERING_PIPELINE,
|
imported_module, LOWERING_PIPELINE,
|
||||||
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend",
|
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend",
|
||||||
|
|
|
@ -4517,18 +4517,18 @@ class ScaledDotProductAttentionSameModule(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args([
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True)
|
([-1, -1, -1], torch.float32, True)
|
||||||
])
|
])
|
||||||
def forward(self, query, key, value):
|
def forward(self, query, key, value):
|
||||||
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
|
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule())
|
@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule())
|
||||||
def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
|
def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
|
||||||
query = torch.randn(1, 1, 5, 5, dtype=torch.float32)
|
query = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||||
key = torch.randn(1, 1, 5, 5, dtype=torch.float32)
|
key = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||||
value = torch.randn(1, 1, 5, 5, dtype=torch.float32)
|
value = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||||
module.forward(query, key, value)
|
module.forward(query, key, value)
|
||||||
|
|
||||||
class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
|
class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
|
||||||
|
@ -4539,18 +4539,18 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args([
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([2, 3, 8, 4], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True),
|
([2, 3, 16, 4], torch.float32, True),
|
||||||
([-1, -1, -1, -1], torch.float32, True)
|
([2, 3, 16, 4], torch.float32, True)
|
||||||
])
|
])
|
||||||
def forward(self, query, key, value):
|
def forward(self, query, key, value):
|
||||||
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
|
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
|
@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
|
||||||
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
|
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
|
||||||
query = torch.randn(3, 2, 8, 4, dtype=torch.float32)
|
query = torch.randn(2, 3, 8, 4, dtype=torch.float32)
|
||||||
key = torch.randn(3, 2, 16, 4, dtype=torch.float32)
|
key = torch.randn(2, 3, 16, 4, dtype=torch.float32)
|
||||||
value = torch.randn(3, 2, 16, 4, dtype=torch.float32)
|
value = torch.randn(2, 3, 16, 4, dtype=torch.float32)
|
||||||
module.forward(query, key, value)
|
module.forward(query, key, value)
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-reduce-op-variants --split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors(
|
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||||
|
@ -11,6 +11,8 @@ func.func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !t
|
||||||
return %0 : !torch.tensor<[],f32>
|
return %0 : !torch.tensor<[],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_list(
|
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_list(
|
||||||
// CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor,
|
// CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor,
|
||||||
// CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor {
|
// CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor {
|
||||||
|
@ -40,6 +42,8 @@ func.func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !t
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional(
|
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>,
|
||||||
// CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool,
|
// CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool,
|
||||||
|
@ -83,6 +87,8 @@ func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
|
||||||
return %ret: !torch.tensor
|
return %ret: !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @reduce_trailing_underscore_inplace_variant(
|
// CHECK-LABEL: func.func @reduce_trailing_underscore_inplace_variant(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
|
||||||
|
@ -106,6 +112,7 @@ func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2]
|
||||||
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32>
|
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32>
|
||||||
return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||||
}
|
}
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.tensor.literal() -> !torch.tensor {
|
// CHECK-LABEL: func.func @torch.tensor.literal() -> !torch.tensor {
|
||||||
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
|
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
|
||||||
|
@ -117,6 +124,8 @@ func.func @torch.tensor.literal() -> !torch.tensor {
|
||||||
return %0 : !torch.tensor
|
return %0 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list(
|
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list(
|
||||||
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
|
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
|
||||||
// CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
|
// CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
|
||||||
|
@ -134,6 +143,8 @@ func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(
|
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(
|
||||||
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
|
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
|
||||||
// CHECK-SAME: %[[INDICES_0:.*]]: !torch.tensor<[2,3],si64>,
|
// CHECK-SAME: %[[INDICES_0:.*]]: !torch.tensor<[2,3],si64>,
|
||||||
|
@ -155,6 +166,8 @@ func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(%se
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.bernoulli_.float(
|
// CHECK-LABEL: func.func @torch.aten.bernoulli_.float(
|
||||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
|
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
|
||||||
// CHECK: %[[GENERATOR:.*]] = torch.constant.none
|
// CHECK: %[[GENERATOR:.*]] = torch.constant.none
|
||||||
|
@ -171,3 +184,22 @@ func.func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
|
||||||
%ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor
|
%ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @scaled_dot_product_flash_attention_for_cpu
|
||||||
|
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG1:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG2:.+]]: !torch.vtensor<[1,1,5,5],f32>
|
||||||
|
// CHECK: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[NONE0:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[NONE1:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[ATTEN:.+]] = torch.aten.scaled_dot_product_attention %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[NONE0]], %[[ZERO]], %[[FALSE]], %[[NONE1]]
|
||||||
|
// CHECK: return %[[ATTEN]]
|
||||||
|
func.func @scaled_dot_product_flash_attention_for_cpu(%arg0: !torch.vtensor<[1,1,5,5],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> {
|
||||||
|
%float0.000000e00 = torch.constant.float 0.000000e+00
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%none = torch.constant.none
|
||||||
|
%none_0 = torch.constant.none
|
||||||
|
%0:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%arg0, %arg1, %arg2, %float0.000000e00, %false, %none, %none_0) : (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5],f32>)
|
||||||
|
return %0#0 : !torch.vtensor<[1,1,5,5],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue