//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
#include "npcomp/E2E/E2E.h"

using namespace mlir;
using namespace mlir::NPCOMP;

static bool isSimpleElementwiseLinalgGeneric(linalg::GenericOp op) {
  // Only handle generic ops where all operands and results are tensors.
  if (!llvm::all_of(op.getOperandTypes(),
                    [](Type type) { return type.isa<RankedTensorType>(); })) {
    return false;
  }
  if (!llvm::all_of(op.getResultTypes(),
                    [](Type type) { return type.isa<RankedTensorType>(); })) {
    return false;
  }

  // TODO: Loosen restrictions on indexing maps.
  // This will require more principled handling of shape reification
  // earlier in the compilation stack, as in general output shapes of a
  // linalg.generic cannot be inferred easily.
  // See:
  // https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
  if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
        return map.cast<AffineMapAttr>().getValue().isIdentity();
      })) {
    return false;
  }
  if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
        return str.cast<StringAttr>().getValue() ==
               getParallelIteratorTypeName();
      })) {
    return false;
  }
  return true;
}

// TODO: Don't just open-code all shape transfer functions here.
// Note: for now, we can't just rely on an OpInterface, since OpInterfaces
// cannot be "externally applied". E.g. we can't change the definition of
// linalg::GenericOp.
static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
  OpBuilder builder(&op);

  if (auto linalgGeneric = dyn_cast<linalg::GenericOp>(op)) {
    // TODO: Avoid this excessive restriction.
    // This will require more principled handling of the lowering to
    // linalg.generic -- it should generally happen after this pass, because in
    // general output shapes of a linalg.generic cannot be inferred easily.
    // See:
    // https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
    if (!isSimpleElementwiseLinalgGeneric(linalgGeneric))
      return {};
    // All shapes of all operands and results are the same for now. So
    // arbitrarily pick the first operand.
    return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
  }

  if (auto broadcastTo = dyn_cast<tcp::BroadcastToOp>(op)) {
    return {broadcastTo.shape()};
  }

  // No shape transfer function.
  return {};
}

namespace {
// TODO: There is a coupling between this pass and LowerShapedResults.
// Any op that is wrapped in tcp.shaped_results here must have a known
// lowering in LowerShapedResults.
class BypassShapes : public BypassShapesBase<BypassShapes> {
  void runOnOperation() {
    auto func = getOperation();
    func.walk([&](Operation *opPtr) {
      Operation &op = *opPtr;
      SmallVector<Value, 6> resultShapes = bypassResultShapes(op);
      if (resultShapes.empty())
        return;
      // We have result shapes, so wrap this op in a tcp.shaped_results op.
      OpBuilder builder(&op);
      auto shapedResults = builder.create<tcp::ShapedResultsOp>(
          op.getLoc(), op.getResultTypes(), resultShapes);
      op.replaceAllUsesWith(shapedResults);

      // Move the op into the body and yield the results.
      Block *body = builder.createBlock(&shapedResults.body());
      op.moveBefore(body, body->end());
      builder.create<tcp::YieldOp>(op.getLoc(), op.getResults());
    });
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::NPCOMP::createBypassShapesPass() {
  return std::make_unique<BypassShapes>();
}
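
// For illustration, a rough sketch of the transformation this pass performs.
// The IR below is hypothetical (op names and the exact tcp.shaped_results
// syntax are approximations, not verbatim dialect syntax). Given an op whose
// result shape can be computed by a shape transfer function:
//
//   %shape = shape.shape_of %arg0 : tensor<?xf32>
//   %0 = "some.elementwise_op"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
//
// the op is moved into the body of a tcp.shaped_results op, which makes the
// result shape available up front, bypassing the op itself:
//
//   %0 = tcp.shaped_results %shape {
//     %1 = "some.elementwise_op"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
//     tcp.yield %1 : tensor<?xf32>
//   } : tensor<?xf32>
//
// A later pass (LowerShapedResults) can then allocate result buffers from
// %shape without having to infer shapes from the wrapped op.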