mirror of https://github.com/llvm/torch-mlir
Bump llvm-project to f66cd9e9556a53142a26a5c21a72e21f1579217c. (#2466)
Picks up DenseResourceElementsAttr python support and fixes minf/maxf C++ rename.pull/2447/merge
parent
b03efdf2e4
commit
278c41e938
|
@ -3,4 +3,4 @@
|
|||
url = https://github.com/llvm/llvm-project.git
|
||||
[submodule "externals/stablehlo"]
|
||||
path = externals/stablehlo
|
||||
url = https://github.com/openxla/stablehlo.git
|
||||
url = https://github.com/shark-infra/stablehlo.git
|
||||
|
|
|
@ -233,7 +233,7 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
|||
loc, init,
|
||||
[&](OpBuilder &b, Location loc, Value elem, Value acc) {
|
||||
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
|
||||
Value max = b.create<arith::MaxFOp>(loc, x, acc);
|
||||
Value max = b.create<arith::MaximumFOp>(loc, x, acc);
|
||||
b.create<scf::ReduceReturnOp>(loc, max);
|
||||
});
|
||||
})
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 4acc3ffbb0af5631bc7916aeff3570f448899647
|
||||
Subproject commit f66cd9e9556a53142a26a5c21a72e21f1579217c
|
|
@ -176,8 +176,8 @@ public:
|
|||
|
||||
Value resultMax, predicate;
|
||||
if (inElementType.isa<mlir::FloatType>()) {
|
||||
resultMax =
|
||||
rewriter.create<arith::MaxFOp>(nestedLoc, newValue, oldValue);
|
||||
resultMax = rewriter.create<arith::MaximumFOp>(nestedLoc, newValue,
|
||||
oldValue);
|
||||
predicate = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||
} else {
|
||||
|
@ -280,7 +280,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
|||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||
Value result = payloadArgs[1];
|
||||
if (resultElementType.isa<mlir::FloatType>())
|
||||
return b.create<arith::MaxFOp>(loc, self, result);
|
||||
return b.create<arith::MaximumFOp>(loc, self, result);
|
||||
else if (resultElementType.isa<mlir::IntegerType>()) {
|
||||
IntegerType intType = max.getSelf()
|
||||
.getType()
|
||||
|
@ -297,7 +297,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
|
|||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||
Value result = payloadArgs[1];
|
||||
if (resultElementType.isa<mlir::FloatType>())
|
||||
return b.create<arith::MinFOp>(loc, self, result);
|
||||
return b.create<arith::MinimumFOp>(loc, self, result);
|
||||
else if (resultElementType.isa<mlir::IntegerType>()) {
|
||||
IntegerType intType = min.getSelf()
|
||||
.getType()
|
||||
|
|
|
@ -1332,7 +1332,7 @@ public:
|
|||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
result = b.create<arith::MaxSIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
result = b.create<arith::MaxFOp>(loc, update, current);
|
||||
result = b.create<arith::MaximumFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
|
@ -1340,7 +1340,7 @@ public:
|
|||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
result = b.create<arith::MinSIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
result = b.create<arith::MinFOp>(loc, update, current);
|
||||
result = b.create<arith::MinimumFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue