mirror of https://github.com/llvm/torch-mlir
[torch] Fix tm_tensor.attention for end-to-end (#2907)
Some operations include a backend matcher for specialized operations. We map these back to generics so they appropriately match to the high performance versions. This is done for the attention operation.pull/2909/head
parent
d6e1d836ca
commit
e9cdd6cbc5
|
@ -313,9 +313,6 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
|||
int64_t getOutputRank() {
|
||||
return getOutputType().getRank();
|
||||
}
|
||||
int64_t getIterationDomainRank() {
|
||||
return 2;
|
||||
};
|
||||
// Method to implement for specifying output range for
|
||||
// DestinationStyleOpInterface
|
||||
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
|
||||
|
|
|
@ -1600,27 +1600,82 @@ public:
|
|||
"only default scale supported");
|
||||
}
|
||||
|
||||
auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
|
||||
auto query = adaptor.getQuery();
|
||||
auto value = adaptor.getValue();
|
||||
auto key = adaptor.getKey();
|
||||
auto queryTy = cast<ShapedType>(query.getType());
|
||||
auto valueTy = cast<ShapedType>(value.getType());
|
||||
auto keyTy = cast<ShapedType>(key.getType());
|
||||
|
||||
if (queryTy.getRank() != valueTy.getRank() ||
|
||||
queryTy.getRank() != keyTy.getRank())
|
||||
return rewriter.notifyMatchFailure(op, "operand ranks do not match");
|
||||
|
||||
if (queryTy.getRank() < 3)
|
||||
return rewriter.notifyMatchFailure(op, "missing batch dimension");
|
||||
|
||||
llvm::SmallVector<ReassociationIndices, 3> reassociation(3);
|
||||
for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i)
|
||||
reassociation.front().push_back(i);
|
||||
reassociation[1].push_back(valueTy.getRank() - 2);
|
||||
reassociation[2].push_back(valueTy.getRank() - 1);
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto collapseBatch = [&rewriter, &reassociation,
|
||||
loc](Value value) -> Value {
|
||||
auto valueTy = cast<ShapedType>(value.getType());
|
||||
if (valueTy.getRank() == 3)
|
||||
return value;
|
||||
|
||||
llvm::SmallVector<int64_t, 3> newShape(3, 1);
|
||||
newShape[1] = valueTy.getDimSize(valueTy.getRank() - 2);
|
||||
newShape[2] = valueTy.getDimSize(valueTy.getRank() - 1);
|
||||
|
||||
for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i) {
|
||||
if (valueTy.isDynamicDim(i)) {
|
||||
newShape[0] = ShapedType::kDynamic;
|
||||
break;
|
||||
}
|
||||
newShape[0] = newShape[0] * valueTy.getDimSize(i);
|
||||
}
|
||||
|
||||
auto collapseTy = valueTy.clone(newShape);
|
||||
return rewriter.create<tensor::CollapseShapeOp>(loc, collapseTy, value,
|
||||
reassociation);
|
||||
};
|
||||
|
||||
query = collapseBatch(query);
|
||||
key = collapseBatch(key);
|
||||
value = collapseBatch(value);
|
||||
|
||||
SmallVector<int64_t> outSizes(
|
||||
adaptor.getQuery().getType().cast<ShapedType>().getShape());
|
||||
query.getType().cast<ShapedType>().getShape());
|
||||
SmallVector<int64_t> valueSizes(
|
||||
adaptor.getValue().getType().cast<ShapedType>().getShape());
|
||||
value.getType().cast<ShapedType>().getShape());
|
||||
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
|
||||
SmallVector<Value> outSizesDynamic(
|
||||
getTensorSizes(rewriter, op.getLoc(), adaptor.getQuery()));
|
||||
outSizesDynamic[outSizesDynamic.size() - 1] = getTensorSizes(
|
||||
rewriter, op.getLoc(), adaptor.getValue())[valueSizes.size() - 1];
|
||||
getTensorSizes(rewriter, op.getLoc(), query));
|
||||
outSizesDynamic[outSizesDynamic.size() - 1] =
|
||||
getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1];
|
||||
Type outType = RankedTensorType::get(outSizes, elementType);
|
||||
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
|
||||
elementType);
|
||||
|
||||
// Overwrite with tm_tensor::attention
|
||||
auto attention = rewriter.create<AttentionOp>(
|
||||
op.getLoc(), outType,
|
||||
SmallVector<Value>{adaptor.getQuery(), adaptor.getKey(),
|
||||
adaptor.getValue()},
|
||||
SmallVector<Value>{output});
|
||||
Value attention =
|
||||
rewriter
|
||||
.create<AttentionOp>(loc, outType,
|
||||
SmallVector<Value>{query, key, value},
|
||||
SmallVector<Value>{output})
|
||||
.getResult()[0];
|
||||
|
||||
rewriter.replaceOp(op, attention.getResult());
|
||||
if (opTy != outType) {
|
||||
attention = rewriter.create<tensor::ExpandShapeOp>(loc, opTy, attention,
|
||||
reassociation);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, attention);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -94,31 +94,22 @@ LogicalResult AttentionOp::verify() {
|
|||
ShapedType keyType = getKeyType();
|
||||
ArrayRef<int64_t> queryShape = queryType.getShape();
|
||||
ArrayRef<int64_t> keyShape = keyType.getShape();
|
||||
if (keyShape[0] != queryShape[0])
|
||||
for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
|
||||
if (keyShape[i] != queryShape[i])
|
||||
return op->emitOpError("query and key batch mismatch");
|
||||
if (keyShape[2] != queryShape[2])
|
||||
}
|
||||
if (keyShape.back() != queryShape.back())
|
||||
return op->emitOpError("query and key head dimension mismatch");
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
|
||||
int64_t iterationDomainRank = getIterationDomainRank();
|
||||
SmallVector<Range> loopBounds(iterationDomainRank);
|
||||
Location loc = getLoc();
|
||||
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
|
||||
Value source = getQuery();
|
||||
for (auto dim : llvm::seq<int64_t>(0, iterationDomainRank)) {
|
||||
loopBounds[dim].offset = zero;
|
||||
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
|
||||
loopBounds[dim].stride = one;
|
||||
}
|
||||
SmallVector<Range> loopBounds;
|
||||
return loopBounds;
|
||||
}
|
||||
|
||||
SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
|
||||
SmallVector<utils::IteratorType> iteratorTypes(getIterationDomainRank(),
|
||||
utils::IteratorType::parallel);
|
||||
SmallVector<utils::IteratorType> iteratorTypes;
|
||||
return iteratorTypes;
|
||||
}
|
||||
|
||||
|
@ -189,6 +180,8 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
|||
Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
|
||||
b.getFloatAttr(elementType, 0.0));
|
||||
|
||||
// TODO: This needs to be fixed, it assumes everything is dynamic however if
|
||||
// any shapes are static the `memref.alloc` generated is illegal.
|
||||
SmallVector<Value> queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes;
|
||||
for (auto i = 0; i < queryRank; i++)
|
||||
queryDynSizes.push_back(b.create<memref::DimOp>(loc, query, i));
|
||||
|
@ -204,9 +197,18 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
|||
auto weightSizes = SmallVector<int64_t>(queryType.getShape());
|
||||
weightSizes[weightRank - 1] = keySizes[keyRank - 2];
|
||||
auto weightType = MemRefType::get(weightSizes, queryType.getElementType());
|
||||
|
||||
// Setup the weight dynamic sizes:
|
||||
SmallVector<Value> weightDynSizes(queryDynSizes);
|
||||
weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2];
|
||||
Value weight = b.create<memref::AllocOp>(loc, weightType, weightDynSizes);
|
||||
|
||||
SmallVector<Value> weightFilteredDynSizes;
|
||||
for (int i = 0; i < weightRank; ++i)
|
||||
if (weightSizes[i] == ShapedType::kDynamic)
|
||||
weightFilteredDynSizes.push_back(weightDynSizes[i]);
|
||||
|
||||
Value weight =
|
||||
b.create<memref::AllocOp>(loc, weightType, weightFilteredDynSizes);
|
||||
matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes,
|
||||
/*transposed=*/true);
|
||||
|
||||
|
@ -259,12 +261,17 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
|||
x = b.create<math::ExpOp>(loc, x);
|
||||
b.create<memref::StoreOp>(loc, x, weight, localIVs);
|
||||
});
|
||||
|
||||
llvm::SmallVector<Value> expWeightDynDims(weightFilteredDynSizes);
|
||||
if (weightSizes.back() == ShapedType::kDynamic)
|
||||
expWeightDynDims.resize(expWeightDynDims.size() - 1);
|
||||
|
||||
Value expWeightSum = b.create<memref::AllocOp>(
|
||||
loc,
|
||||
MemRefType::get(
|
||||
SmallVector<int64_t>(weightSizes.begin(), weightSizes.end() - 1),
|
||||
elementType),
|
||||
SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1});
|
||||
expWeightDynDims);
|
||||
b.create<scf::ParallelOp>(
|
||||
loc, SmallVector<Value>(weightRank - 1, zero),
|
||||
SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1},
|
||||
|
|
|
@ -189,6 +189,78 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
class TorchMatchSpecializedBackendOp
|
||||
: public OpConversionPattern<Torch::OperatorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
using HandlerFn = LogicalResult (*)(OperatorOp op,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (namedHandlers.contains(op.getNameAttr())) {
|
||||
return namedHandlers.lookup(op.getNameAttr()).front()(op, rewriter);
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
static void
|
||||
populateSpecializedConversions(TorchMatchSpecializedBackendOp &matcher);
|
||||
|
||||
static std::unique_ptr<TorchMatchSpecializedBackendOp>
|
||||
getPopulatedMatcher(MLIRContext *context) {
|
||||
auto matcher = std::make_unique<TorchMatchSpecializedBackendOp>(context);
|
||||
populateSpecializedConversions(*matcher);
|
||||
return matcher;
|
||||
};
|
||||
|
||||
void populate(StringRef name, HandlerFn fn) {
|
||||
namedHandlers[StringAttr::get(getContext(), name)].push_back(fn);
|
||||
}
|
||||
|
||||
void populateLegalizedNames(llvm::DenseSet<StringAttr> &set) {
|
||||
for (auto handle : namedHandlers) {
|
||||
set.insert(handle.first);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DenseMap<StringAttr, SmallVector<HandlerFn, 1>> namedHandlers;
|
||||
};
|
||||
|
||||
void TorchMatchSpecializedBackendOp::populateSpecializedConversions(
|
||||
TorchMatchSpecializedBackendOp &matcher) {
|
||||
matcher.populate(
|
||||
"torch.aten._scaled_dot_product_flash_attention_for_cpu",
|
||||
[](Torch::OperatorOp op,
|
||||
ConversionPatternRewriter &rewriter) -> LogicalResult {
|
||||
auto uses = op.getResult(1).getUses();
|
||||
if (uses.end() == uses.begin()) {
|
||||
auto oldOperands = op->getOperands();
|
||||
llvm::SmallVector<Value> newOperands{
|
||||
oldOperands[0], oldOperands[1], oldOperands[2], oldOperands[5],
|
||||
oldOperands[3], oldOperands[4], oldOperands[6]};
|
||||
|
||||
auto newOp = rewriter.create<Torch::AtenScaledDotProductAttentionOp>(
|
||||
op.getLoc(), op->getResultTypes()[0], newOperands,
|
||||
op->getAttrs());
|
||||
rewriter.replaceAllUsesWith(op.getResult(0), newOp.getResult());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
bool isSpecializedOperation(Torch::OperatorOp op) { return true; }
|
||||
} // namespace
|
||||
|
||||
// Reduce Ops without value semantics but the corresponding without trailing
|
||||
// underscore variant doesn't exist.
|
||||
namespace {
|
||||
|
@ -353,12 +425,24 @@ struct ReduceOpVariantsPass
|
|||
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
|
||||
patterns.add<ReduceNonValueSemanticOps>(context);
|
||||
|
||||
// Create specialized matcher:
|
||||
auto specialized =
|
||||
TorchMatchSpecializedBackendOp::getPopulatedMatcher(context);
|
||||
DenseSet<StringAttr> specializedNames;
|
||||
specialized->populateLegalizedNames(specializedNames);
|
||||
patterns.insert(std::move(specialized));
|
||||
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||
target.addIllegalOp<AtenArangeStartOutOp>();
|
||||
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
|
||||
Operation *op) {
|
||||
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable,
|
||||
&specializedNames](Operation *op) {
|
||||
if (isa<OperatorOp>(op)) {
|
||||
if (specializedNames.contains(cast<OperatorOp>(op).getNameAttr())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
|
||||
(isa<OperatorOp>(op) &&
|
||||
operatorOpHasValueSemantics(cast<OperatorOp>(op),
|
||||
|
@ -377,6 +461,9 @@ struct ReduceOpVariantsPass
|
|||
if (op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isa<OperatorOp>(op) && isSpecializedOperation(cast<OperatorOp>(op)))
|
||||
return false;
|
||||
return true;
|
||||
});
|
||||
|
||||
|
|
|
@ -303,8 +303,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
|
||||
# Exception: Unsupported: node.meta['val'] is not a FakeTensor or list of FakeTensor's: _scaled_dot_product_flash_attention;
|
||||
"ScaledDotProductAttentionSameModule_basic",
|
||||
# AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
|
||||
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
|
||||
|
|
|
@ -202,7 +202,6 @@ class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
|
|||
An opaque, backend specific compiled artifact object that can be
|
||||
passed to `load`.
|
||||
"""
|
||||
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module, LOWERING_PIPELINE,
|
||||
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend",
|
||||
|
|
|
@ -4517,18 +4517,18 @@ class ScaledDotProductAttentionSameModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True)
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True)
|
||||
])
|
||||
def forward(self, query, key, value):
|
||||
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
|
||||
|
||||
@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule())
|
||||
def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
|
||||
query = torch.randn(1, 1, 5, 5, dtype=torch.float32)
|
||||
key = torch.randn(1, 1, 5, 5, dtype=torch.float32)
|
||||
value = torch.randn(1, 1, 5, 5, dtype=torch.float32)
|
||||
query = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||
key = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||
value = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||
module.forward(query, key, value)
|
||||
|
||||
class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
|
||||
|
@ -4539,18 +4539,18 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True)
|
||||
([2, 3, 8, 4], torch.float32, True),
|
||||
([2, 3, 16, 4], torch.float32, True),
|
||||
([2, 3, 16, 4], torch.float32, True)
|
||||
])
|
||||
def forward(self, query, key, value):
|
||||
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
|
||||
|
||||
@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
|
||||
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
|
||||
query = torch.randn(3, 2, 8, 4, dtype=torch.float32)
|
||||
key = torch.randn(3, 2, 16, 4, dtype=torch.float32)
|
||||
value = torch.randn(3, 2, 16, 4, dtype=torch.float32)
|
||||
query = torch.randn(2, 3, 8, 4, dtype=torch.float32)
|
||||
key = torch.randn(2, 3, 16, 4, dtype=torch.float32)
|
||||
value = torch.randn(2, 3, 16, 4, dtype=torch.float32)
|
||||
module.forward(query, key, value)
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-reduce-op-variants --split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
|
@ -11,6 +11,8 @@ func.func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !t
|
|||
return %0 : !torch.tensor<[],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_list(
|
||||
// CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor,
|
||||
// CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor {
|
||||
|
@ -40,6 +42,8 @@ func.func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !t
|
|||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>,
|
||||
// CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool,
|
||||
|
@ -83,6 +87,8 @@ func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
|
|||
return %ret: !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @reduce_trailing_underscore_inplace_variant(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
|
||||
|
@ -106,6 +112,7 @@ func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2]
|
|||
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32>
|
||||
return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
}
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.tensor.literal() -> !torch.tensor {
|
||||
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
|
||||
|
@ -117,6 +124,8 @@ func.func @torch.tensor.literal() -> !torch.tensor {
|
|||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
|
||||
// CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
|
||||
|
@ -134,6 +143,8 @@ func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<
|
|||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
|
||||
// CHECK-SAME: %[[INDICES_0:.*]]: !torch.tensor<[2,3],si64>,
|
||||
|
@ -155,6 +166,8 @@ func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(%se
|
|||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.bernoulli_.float(
|
||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: %[[GENERATOR:.*]] = torch.constant.none
|
||||
|
@ -171,3 +184,22 @@ func.func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
|
|||
%ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @scaled_dot_product_flash_attention_for_cpu
|
||||
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG1:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG2:.+]]: !torch.vtensor<[1,1,5,5],f32>
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE0:.+]] = torch.constant.none
|
||||
// CHECK: %[[NONE1:.+]] = torch.constant.none
|
||||
// CHECK: %[[ATTEN:.+]] = torch.aten.scaled_dot_product_attention %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[NONE0]], %[[ZERO]], %[[FALSE]], %[[NONE1]]
|
||||
// CHECK: return %[[ATTEN]]
|
||||
func.func @scaled_dot_product_flash_attention_for_cpu(%arg0: !torch.vtensor<[1,1,5,5],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> {
|
||||
%float0.000000e00 = torch.constant.float 0.000000e+00
|
||||
%false = torch.constant.bool false
|
||||
%none = torch.constant.none
|
||||
%none_0 = torch.constant.none
|
||||
%0:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%arg0, %arg1, %arg2, %float0.000000e00, %false, %none, %none_0) : (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5],f32>)
|
||||
return %0#0 : !torch.vtensor<[1,1,5,5],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue