mirror of https://github.com/llvm/torch-mlir
[MHLO] simplify aten.frobenius_norm.dim's lowering (#1800)
parent e2698433db
commit 0a85033780
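For context, the Frobenius norm over the reduced dimensions is the square root of the sum of squared entries, and for real-valued input the absolute value is redundant:

    \|x\|_F = \sqrt{\sum_{i \in \text{dims}} |x_i|^2} = \sqrt{\sum_{i \in \text{dims}} x_i^2}

This identity is what the simplification below exploits: rather than computing abs and pow(·, 2) inside the reduction body, the input is squared once with a multiply, summed with a plain add reduction, and finished with a single sqrt.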
@@ -624,41 +624,38 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
         op, "non-const bool `keepdim` is not supported");
   }
 
+  auto squareOp = rewriter.create<mhlo::MulOp>(op->getLoc(), input, input);
+
   auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
   if (!initValue) {
     return failure();
   }
 
-  auto squareSumReduceOp = rewriter.create<mhlo::ReduceOp>(
-      op->getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
+  auto reduceOp = rewriter.create<mhlo::ReduceOp>(
+      op->getLoc(), squareOp.getResult(), initValue,
+      rewriter.getI64TensorAttr(dims));
 
-  Region &region = squareSumReduceOp.getBody();
+  Region &region = reduceOp.getBody();
   Block &block = region.emplaceBlock();
   auto blockArgumentTy = RankedTensorType::get({}, inputElemType);
 
   block.addArgument(blockArgumentTy, op->getLoc());
   block.addArgument(blockArgumentTy, op->getLoc());
 
-  auto *firstArgument = block.args_begin();
-  auto secondArgument = block.args_rbegin();
+  auto firstArgument = *block.args_begin();
+  auto secondArgument = *block.args_rbegin();
 
   {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToStart(&block);
 
-    auto constantOrd2 = rewriter.create<mhlo::ConstantOp>(
-        op->getLoc(), blockArgumentTy,
-        DenseElementsAttr::get(blockArgumentTy, llvm::ArrayRef<float>{2.0}));
-    auto abs = rewriter.create<mhlo::AbsOp>(op->getLoc(), *secondArgument);
-    auto squareResult = rewriter.create<mhlo::PowOp>(
-        op->getLoc(), abs, constantOrd2);
-    auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), squareResult,
-                                                  *firstArgument);
+    auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), firstArgument,
+                                                  secondArgument);
     rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult());
   }
 
-  auto output = rewriter.create<mhlo::SqrtOp>(op->getLoc(),
-                                              squareSumReduceOp.getResult(0));
+  auto output =
+      rewriter.create<mhlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
 
   if (keepDim) {
     auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
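As a plain reference for what the simplified lowering computes (a sketch assuming a row-major buffer and a single reduced last dimension; the loop form and names here are illustrative, not code from the repository):

#include <cmath>
#include <vector>

// Reference computation for frobenius_norm over the last dimension of a
// [rows x cols] matrix stored row-major: out[r] = sqrt(sum_c in[r*cols+c]^2).
std::vector<float> frobeniusNormLastDim(const std::vector<float> &in,
                                        int rows, int cols) {
  std::vector<float> out(rows, 0.0f);
  for (int r = 0; r < rows; ++r) {
    float sumSq = 0.0f;        // reduction init value (zero)
    for (int c = 0; c < cols; ++c) {
      float v = in[r * cols + c];
      sumSq += v * v;          // square then add, mirroring mul + reduce-add
    }
    out[r] = std::sqrt(sumSq); // final sqrt, mirroring mhlo::SqrtOp
  }
  return out;
}

For example, for the 2x3 input {3, 4, 0, 1, 2, 2} this returns {5, 3}, since sqrt(9+16+0) = 5 and sqrt(1+4+4) = 3.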