//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// // // The torch-mlir "reference backend" requires a few passes to glue things // together so that the final IR will work with ExecutionEngine. // // There is no actual "backend". // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Approximation.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/RefBackend/Passes.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::RefBackend; //===----------------------------------------------------------------------===// // Pass registration //===----------------------------------------------------------------------===// namespace { #define GEN_PASS_REGISTRATION #include "torch-mlir/RefBackend/Passes.h.inc" } // end namespace void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); } //===----------------------------------------------------------------------===// // MungeCallingConventions //===----------------------------------------------------------------------===// static bool isArgMemRefTypeValid(Type type) { if (auto memRefType = type.dyn_cast()) { Type elemTy = memRefType.getElementType(); if (elemTy.isa()) { return true; } else if (elemTy.isa()) { return true; } else if (auto integerTy = elemTy.dyn_cast()) { if (integerTy.isSignlessInteger(64)) return true; } } return false; } static void addEmitCInterfaceAttr(FuncOp func) { func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext())); } static Type getAbiTypeForMemRef(Type type) { return UnrankedMemRefType::get(type.cast().getElementType(), 0); } static LogicalResult mungeFunction( FuncOp func, DenseMap consumeFuncReturnFuncs) { // Add `llvm.emit_c_interface`. // This allows ExecutionEngine to resolve the symbol properly. addEmitCInterfaceAttr(func); // Rewrite the function as follows: // - replace all memref arguments with unranked memref // - replace all returns with a call to a function, which is going to be // supplied by the code setting up the ExecutionEngine to process the // result. Additionally, ensure that all results are passed as unranked // memrefs. // - replace the function signature accordingly (unranked inputs, no returns). OpBuilder b(func.getBody()); SmallVector newArgTypes; for (auto arg : func.getArguments()) { auto type = arg.getType(); if (!isArgMemRefTypeValid(type)) return emitError(arg.getLoc(), "argument must be a memref of f32, f64, i64"); auto cast = b.create(arg.getLoc(), arg, type); arg.replaceAllUsesExcept(cast, cast); arg.setType(getAbiTypeForMemRef(type)); newArgTypes.push_back(arg.getType()); } SmallVector toErase; bool hadError = false; func.walk([&](ReturnOp op) { auto memRefType = op.getOperandTypes()[0].dyn_cast(); if (!memRefType) { hadError = true; op.emitError("return value must be memref type"); return; } auto returnType = memRefType.getElementType(); auto it = consumeFuncReturnFuncs.find(returnType); if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) { hadError = true; op.emitError("must have one return value: a memref of f32, i64 or f64"); return; } b.setInsertionPoint(op); auto cast = b.create(op.getLoc(), op.getOperand(0), getAbiTypeForMemRef(op.getOperandTypes()[0])); b.create(op.getLoc(), consumeFuncReturnFuncs[returnType], cast.getResult()); b.create(op.getLoc()); toErase.push_back(op); }); if (hadError) return failure(); func.setType(FunctionType::get(func.getContext(), newArgTypes, {})); for (Operation *op : toErase) op->erase(); return success(); } namespace { class MungeCallingConventions : public MungeCallingConventionsBase { void runOnOperation() override { auto module = getOperation(); OpBuilder b(module.getBodyRegion()); DenseMap consumeFuncReturnFuncs; DenseSet consumeFuncReturnFuncsSet; auto createConsumeFuncReturnFunc = [&](Type elemTy, std::string funcName) { auto consumeFuncReturnFunc = b.create( module.getLoc(), funcName, FunctionType::get(module.getContext(), UnrankedMemRefType::get(elemTy, /*memorySpace=*/0), {}), b.getStringAttr("private")); addEmitCInterfaceAttr(consumeFuncReturnFunc); consumeFuncReturnFuncs[elemTy] = consumeFuncReturnFunc; consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc); }; createConsumeFuncReturnFunc(b.getI64Type(), "refbackend_consume_int64_func_return"); createConsumeFuncReturnFunc(b.getF32Type(), "refbackend_consume_float32_func_return"); createConsumeFuncReturnFunc(b.getF64Type(), "refbackend_consume_float64_func_return"); for (auto func : module.getOps()) { if (consumeFuncReturnFuncsSet.contains(func)) continue; if (failed(mungeFunction(func, consumeFuncReturnFuncs))) return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::torch::RefBackend::createMungeCallingConventionsPass() { return std::make_unique(); } //===----------------------------------------------------------------------===// // ExpandOpsForLLVM //===----------------------------------------------------------------------===// namespace { class ExpandOpsForLLVM : public ExpandOpsForLLVMBase { void runOnOperation() override { auto func = getOperation(); auto *context = &getContext(); RewritePatternSet patterns(context); populateExpandTanhPattern(patterns); patterns.add(patterns.getContext()); ConversionTarget target(*context); target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addIllegalOp(); target.addIllegalOp(); if (failed(applyPartialConversion(func, target, std::move(patterns)))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::torch::RefBackend::createExpandOpsForLLVMPass() { return std::make_unique(); }