[MHLO] simplify aten.frobenius_norm.dim's lowering (#1800)

pull/1804/head
Yuanqiang Liu 2023-01-18 05:52:12 +08:00 committed by GitHub
parent e2698433db
commit 0a85033780
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 15 deletions

View File

@ -624,41 +624,38 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
op, "non-const bool `keepdim` is not supported");
}
auto squareOp = rewriter.create<mhlo::MulOp>(op->getLoc(), input, input);
auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
if (!initValue) {
return failure();
}
auto squareSumReduceOp = rewriter.create<mhlo::ReduceOp>(
op->getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
auto reduceOp = rewriter.create<mhlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue,
rewriter.getI64TensorAttr(dims));
Region &region = squareSumReduceOp.getBody();
Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputElemType);
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
auto firstArgument = *block.args_begin();
auto secondArgument = *block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
auto constantOrd2 = rewriter.create<mhlo::ConstantOp>(
op->getLoc(), blockArgumentTy,
DenseElementsAttr::get(blockArgumentTy, llvm::ArrayRef<float>{2.0}));
auto abs = rewriter.create<mhlo::AbsOp>(op->getLoc(), *secondArgument);
auto squareResult = rewriter.create<mhlo::PowOp>(
op->getLoc(), abs, constantOrd2);
auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), squareResult,
*firstArgument);
auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), firstArgument,
secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult());
}
auto output = rewriter.create<mhlo::SqrtOp>(op->getLoc(),
squareSumReduceOp.getResult(0));
auto output =
rewriter.create<mhlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
if (keepDim) {
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);