mirror of https://github.com/llvm/torch-mlir
[LINALG] Make `AtenMaxDimOp` use `arith.maxf` to calculate maximum (#1466)
This commit updates the linalg conversion of `AtenMaxDimOp` to use `arith.maxf` instead of `arith.select` to calculate the maximum. This allows better vectorization further downstream, since the operation can be converted to a simple max reduction when the `indices` result is not used. See: https://github.com/iree-org/iree/issues/10666.pull/1467/head
parent
e7b2b84a66
commit
8201e7b067
|
@ -152,12 +152,10 @@ public:
|
||||||
nestedLoc, oldIndex.getType(),
|
nestedLoc, oldIndex.getType(),
|
||||||
rewriter.create<linalg::IndexOp>(loc, dim));
|
rewriter.create<linalg::IndexOp>(loc, dim));
|
||||||
|
|
||||||
Value predicate;
|
auto resultMax = rewriter.create<arith::MaxFOp>(
|
||||||
if (inElementType.isa<mlir::FloatType>())
|
nestedLoc, newValue, oldValue);
|
||||||
predicate = rewriter.create<arith::CmpFOp>(
|
Value predicate = rewriter.create<arith::CmpFOp>(
|
||||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||||
auto resultMax = rewriter.create<arith::SelectOp>(
|
|
||||||
nestedLoc, predicate, newValue, oldValue);
|
|
||||||
auto resultIndex = rewriter.create<arith::SelectOp>(
|
auto resultIndex = rewriter.create<arith::SelectOp>(
|
||||||
nestedLoc, predicate, newIndex, oldIndex);
|
nestedLoc, predicate, newIndex, oldIndex);
|
||||||
nestedBuilder.create<linalg::YieldOp>(
|
nestedBuilder.create<linalg::YieldOp>(
|
||||||
|
|
Loading…
Reference in New Issue