mirror of https://github.com/llvm/torch-mlir
Bump llvm-project to 6b65d79fbb4682468333cea42b62f15c2dffd8f3 (#2723)
Co-authored-by: hanhanW <hanhan0912@gmail.com>
parent aa7e95f7c8
commit fb1dfa3126
@@ -1 +1 @@
-Subproject commit 99045b60b57571079f9cb4aea57870692523fbe8
+Subproject commit 6b65d79fbb4682468333cea42b62f15c2dffd8f3
@@ -166,7 +166,6 @@ static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
                 })
                 ->getResult(0);
         b.create<memref::StoreOp>(loc, sum, output, localIVs);
-        b.create<scf::YieldOp>(loc);
       });
 }

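Note on the deleted scf::YieldOp here (and on the identical deletions in the attention hunks below): with the bumped llvm-project, the scf::ParallelOp builder inserts the loop terminator itself once the body callback returns, so the explicit empty scf::YieldOp at the end of each body is no longer needed. A minimal sketch of the post-bump pattern; lbs, ubs, steps, and buffer are placeholder values, not names from this patch:

    // Building an scf.parallel loop after the bump: the body callback emits
    // only the payload, and the builder adds the terminator automatically.
    b.create<scf::ParallelOp>(
        loc, /*lowerBounds=*/lbs, /*upperBounds=*/ubs, /*steps=*/steps,
        [&](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
          Value v = nested.create<memref::LoadOp>(nestedLoc, buffer, ivs);
          nested.create<memref::StoreOp>(nestedLoc, v, buffer, ivs);
          // No trailing nested.create<scf::YieldOp>(nestedLoc) anymore.
        });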
@@ -229,13 +228,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
           SmallVector<Value>(weightRank, one), init,
           [&](OpBuilder &b, Location loc, ValueRange localIVs,
               ValueRange accs) {
-            b.create<scf::ReduceOp>(
-                loc, init,
-                [&](OpBuilder &b, Location loc, Value elem, Value acc) {
-                  Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
-                  Value max = b.create<arith::MaximumFOp>(loc, x, acc);
-                  b.create<scf::ReduceReturnOp>(loc, max);
-                });
+            auto reduceOp = b.create<scf::ReduceOp>(loc, init);
+            // Build reduce body.
+            Block &reductionBody = reduceOp.getReductions()[0].front();
+            auto bodyBuilder = OpBuilder::atBlockEnd(&reductionBody);
+            Value acc = reductionBody.getArgument(0);
+            Value x =
+                bodyBuilder.create<memref::LoadOp>(loc, weight, localIVs);
+            Value max = bodyBuilder.create<arith::MaximumFOp>(loc, x, acc);
+            bodyBuilder.create<scf::ReduceReturnOp>(loc, max);
           })
           .getResult(0);
   // weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
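For context on the rewrite above: the bump changed how scf::ReduceOp bodies are built. The old builder took a body-builder lambda; the new op is created from just its init operand(s), which materializes one reduction region per operand, and the caller then fills in the pre-created block whose arguments are the two partial values to combine. A standalone sketch of the new pattern, using a placeholder operand and a plain float add as the combiner:

    // scf.reduce after the bump: create the op, then populate its body block.
    auto reduceOp = b.create<scf::ReduceOp>(loc, operand);
    Block &body = reduceOp.getReductions()[0].front();
    auto bodyBuilder = OpBuilder::atBlockEnd(&body);
    Value lhs = body.getArgument(0);
    Value rhs = body.getArgument(1);
    Value combined = bodyBuilder.create<arith::AddFOp>(loc, lhs, rhs);
    bodyBuilder.create<scf::ReduceReturnOp>(loc, combined);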
@@ -247,7 +248,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
         x = b.create<arith::SubFOp>(loc, x, globalMax);
         x = b.create<arith::DivFOp>(loc, x, scaleFactor);
         b.create<memref::StoreOp>(loc, x, weight, localIVs);
-        b.create<scf::YieldOp>(loc);
       });
   // calculate exp(weight)
   SmallVector<Value> min(weightRank, zero),
@@ -258,7 +258,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
         Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
         x = b.create<math::ExpOp>(loc, x);
         b.create<memref::StoreOp>(loc, x, weight, localIVs);
-        b.create<scf::YieldOp>(loc);
       });
   Value expWeightSum = b.create<memref::AllocOp>(
       loc,
@@ -290,7 +289,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
           Value y = b.create<memref::LoadOp>(loc, weight, coords);
           Value sum = b.create<arith::AddFOp>(loc, x, y);
           b.create<memref::StoreOp>(loc, sum, expWeightSum, outsideDims);
-          b.create<scf::YieldOp>(loc);
         });
       });
   // calculate exp(weight) / sum(exp(weight))
@@ -305,7 +303,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
         Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
         x = b.create<arith::DivFOp>(loc, x, sum);
         b.create<memref::StoreOp>(loc, x, weight, localIVs);
-        b.create<scf::YieldOp>(loc);
       });

   // output = weight @ value
@@ -715,6 +715,8 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//

 OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
+  if (getOperand().getType() != getResult().getType())
+    return nullptr;
   if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
     if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
       return getOperand();
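The two-line guard added here, and the matching guards in the folders below, all follow one pattern: a fold may only replace an op's result with a value of the same type, and that rule is presumably enforced more strictly after this bump, so each folder bails out when the value it would return does not match the result type. A minimal sketch of the pattern on a hypothetical ExampleOp (not a real torch-mlir op):

    OpFoldResult ExampleOp::fold(FoldAdaptor adaptor) {
      // Returning the operand would silently change the IR type; refuse.
      if (getOperand().getType() != getResult().getType())
        return nullptr;
      // ... op-specific conditions under which the op acts as identity ...
      return getOperand();
    }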
@@ -727,6 +729,8 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//

 OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
+  if (getOperand(0).getType() != getResult().getType())
+    return nullptr;
   if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
     if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
       return getOperand(0);
@@ -739,6 +743,8 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//

 OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
+  if (getSelf().getType() != getResult().getType())
+    return nullptr;
   if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
     if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
       return getSelf();
@@ -911,6 +917,8 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
   auto resType = getType().dyn_cast<BaseTensorType>();
   if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
     return nullptr;
+  if (inputType != resType)
+    return nullptr;
   // Fold when both the input tensor and result are unity rank tensors.
   return getOperand(0);
 }
@@ -2441,6 +2449,8 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
   auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
   if (!list || !list->hasOneUse() || list.getElements().size() != 1)
     return nullptr;
+  if (list.getElements()[0].getType() != getResult().getType())
+    return nullptr;
   return list.getElements()[0];
 }

@@ -2451,6 +2461,8 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
 OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
   auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
   auto outType = getResult().getType().dyn_cast<BaseTensorType>();
+  if (inType != outType)
+    return nullptr;
   if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
     return nullptr;
   if (inType.getSizes().size() != outType.getSizes().size() ||
@@ -2480,6 +2492,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {

   auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
   auto outType = getResult().getType().dyn_cast<BaseTensorType>();
+  if (inType != outType)
+    return nullptr;
   if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
     return nullptr;
   if (inType.getSizes().size() != outType.getSizes().size() ||
@@ -95,7 +95,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
 class InlineGlobalSlotsAnalysisState : public AnalysisState {
 public:
   InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
-    setSafe();
+    (void)setSafe();
   }

   void print(raw_ostream &os) const override {
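The only change in this final hunk is the (void) cast: setSafe() here presumably returns the dataflow framework's ChangeResult, which upstream marks [[nodiscard]], so a constructor that deliberately ignores the result has to discard it explicitly. A self-contained illustration of the idiom; the names are placeholders, not the torch-mlir API:

    // A nodiscard result enum, as in MLIR's dataflow framework.
    enum class [[nodiscard]] ChangeResult { NoChange, Change };

    ChangeResult markSafe() { return ChangeResult::Change; }

    void initialize() {
      // Without the cast, warnings (errors under -Werror) flag the
      // ignored nodiscard value; (void) states the intent explicitly.
      (void)markSafe();
    }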