mirror of https://github.com/llvm/torch-mlir
[LINALG] Make `AtenMaxDimOp` use `arith.maxf` to calculate maximum (#1466)
This commit updates the linalg conversion of `AtenMaxDimOp` to use `arith.maxf` instead of `arith.select` to calculate the maximum. This allows better vectorization further downstream, since the operation can be converted to a simple max reduction when the `indices` result is not used. See: https://github.com/iree-org/iree/issues/10666.pull/1467/head
parent
e7b2b84a66
commit
8201e7b067
|
@ -152,12 +152,10 @@ public:
|
|||
nestedLoc, oldIndex.getType(),
|
||||
rewriter.create<linalg::IndexOp>(loc, dim));
|
||||
|
||||
Value predicate;
|
||||
if (inElementType.isa<mlir::FloatType>())
|
||||
predicate = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||
auto resultMax = rewriter.create<arith::SelectOp>(
|
||||
nestedLoc, predicate, newValue, oldValue);
|
||||
auto resultMax = rewriter.create<arith::MaxFOp>(
|
||||
nestedLoc, newValue, oldValue);
|
||||
Value predicate = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||
auto resultIndex = rewriter.create<arith::SelectOp>(
|
||||
nestedLoc, predicate, newIndex, oldIndex);
|
||||
nestedBuilder.create<linalg::YieldOp>(
|
||||
|
|
Loading…
Reference in New Issue