Bump llvm-project to 6b65d79fbb4682468333cea42b62f15c2dffd8f3 (#2723)

Co-authored-by: hanhanW <hanhan0912@gmail.com>
pull/2728/head
Kunwar Grover 2024-01-05 04:03:41 +05:30 committed by GitHub
parent aa7e95f7c8
commit fb1dfa3126
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 14 deletions

@ -1 +1 @@
Subproject commit 99045b60b57571079f9cb4aea57870692523fbe8
Subproject commit 6b65d79fbb4682468333cea42b62f15c2dffd8f3

View File

@ -166,7 +166,6 @@ static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
})
->getResult(0);
b.create<memref::StoreOp>(loc, sum, output, localIVs);
b.create<scf::YieldOp>(loc);
});
}
@ -229,13 +228,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
SmallVector<Value>(weightRank, one), init,
[&](OpBuilder &b, Location loc, ValueRange localIVs,
ValueRange accs) {
b.create<scf::ReduceOp>(
loc, init,
[&](OpBuilder &b, Location loc, Value elem, Value acc) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
Value max = b.create<arith::MaximumFOp>(loc, x, acc);
b.create<scf::ReduceReturnOp>(loc, max);
});
auto reduceOp = b.create<scf::ReduceOp>(loc, init);
// Build reduce body.
Block &reductionBody = reduceOp.getReductions()[0].front();
auto bodyBuilder = OpBuilder::atBlockEnd(&reductionBody);
Value acc = reductionBody.getArgument(0);
Value x =
bodyBuilder.create<memref::LoadOp>(loc, weight, localIVs);
Value max = bodyBuilder.create<arith::MaximumFOp>(loc, x, acc);
bodyBuilder.create<scf::ReduceReturnOp>(loc, max);
})
.getResult(0);
// weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
@ -247,7 +248,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
x = b.create<arith::SubFOp>(loc, x, globalMax);
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
b.create<scf::YieldOp>(loc);
});
// calculate exp(weight)
SmallVector<Value> min(weightRank, zero),
@ -258,7 +258,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<math::ExpOp>(loc, x);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
b.create<scf::YieldOp>(loc);
});
Value expWeightSum = b.create<memref::AllocOp>(
loc,
@ -290,7 +289,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value y = b.create<memref::LoadOp>(loc, weight, coords);
Value sum = b.create<arith::AddFOp>(loc, x, y);
b.create<memref::StoreOp>(loc, sum, expWeightSum, outsideDims);
b.create<scf::YieldOp>(loc);
});
});
// calculate exp(weight) / sum(exp(weight))
@ -305,7 +303,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
x = b.create<arith::DivFOp>(loc, x, sum);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
b.create<scf::YieldOp>(loc);
});
// output = weight @ value

View File

@ -715,6 +715,8 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
if (getOperand().getType() != getResult().getType())
return nullptr;
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand();
@ -727,6 +729,8 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
if (getOperand(0).getType() != getResult().getType())
return nullptr;
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(0);
@ -739,6 +743,8 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
if (getSelf().getType() != getResult().getType())
return nullptr;
if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
return getSelf();
@ -911,6 +917,8 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
auto resType = getType().dyn_cast<BaseTensorType>();
if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
return nullptr;
if (inputType != resType)
return nullptr;
// Fold when both the input tensor and result are unity rank tensors.
return getOperand(0);
}
@ -2441,6 +2449,8 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr;
if (list.getElements()[0].getType() != getResult().getType())
return nullptr;
return list.getElements()[0];
}
@ -2451,6 +2461,8 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
if (inType != outType)
return nullptr;
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size() ||
@ -2480,6 +2492,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
if (inType != outType)
return nullptr;
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size() ||

View File

@ -95,7 +95,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
class InlineGlobalSlotsAnalysisState : public AnalysisState {
public:
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
setSafe();
(void)setSafe();
}
void print(raw_ostream &os) const override {