//===- AdjustCallingConventions.cpp ------------------------*- C++-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyOps.h" #include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::NPCOMP; using namespace mlir::NPCOMP::Torch; // Map from func name and arg index to the type bound for that arg. // This is needed because to rewrite calls, we need the non-local information // from the func definition. // We also benefit from populating this all at once, which avoids ordering // issues between rewriting of func ops vs call ops. using TypeBoundMap = DenseMap, Type> ; namespace { class AdjustCallingConventionForFunc : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(FuncOp func, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = func.getContext(); auto typeBoundIdent = Identifier::get("torch.type_bound", context); TypeConverter::SignatureConversion conversion(func.getNumArguments()); // The TypeConverter hooks for type conversion are "context free", so we // cannot use the usual helpers here for populating SignatureConversion and // new result types. // // The incoporation of the torch.type_bound arg attr is context-dependent. for (auto type : llvm::enumerate(func.getArgumentTypes())) { if (auto ndarray = type.value().dyn_cast()) { auto typeBoundAttr = func.getArgAttrOfType(type.index(), typeBoundIdent); conversion.addInputs(type.index(), typeBoundAttr ? typeBoundAttr.getValue() : type.value()); continue; // type is attached to ndarray type. // TODO: check if more specific? } else if (auto none = type.value().dyn_cast()) { continue; } // TODO: add tuple type. conversion.addInputs(type.index(), type.value()); } SmallVector newResultTypes; for (auto type : func.getType().getResults()) { if (auto none = type.dyn_cast()) { continue; } newResultTypes.push_back(type); } rewriter.applySignatureConversion(&func.getBody(), conversion, typeConverter); rewriter.updateRootInPlace(func, [&] { func.setType(FunctionType::get( getContext(), conversion.getConvertedTypes(), newResultTypes)); // Clear out the type bounds, now that the type incorporates them. for (int i = 0, e = func.getNumArguments(); i != e; i++) func.removeArgAttr(i, typeBoundIdent); }); return success(); } }; } // namespace namespace { class AdjustCallingConventionForCall : public OpConversionPattern { public: AdjustCallingConventionForCall(TypeConverter &converter, MLIRContext *context, TypeBoundMap &typeBoundMap) : OpConversionPattern(converter, context), typeBoundMap(typeBoundMap) {} LogicalResult matchAndRewrite(CallOp call, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { SmallVector convertedResults; if (failed(typeConverter->convertTypes(call.getResultTypes(), convertedResults))) return failure(); SmallVector newOperands; for (auto operand : llvm::enumerate(operands)) { if (operand.value().getType().isa()) continue; auto it = typeBoundMap.find({call.callee(), operand.index()}); if (it != typeBoundMap.end()) { newOperands.push_back(rewriter.create( call.getLoc(), it->second, operand.value())); continue; } newOperands.push_back(operand.value()); } CallOp newCall = rewriter.create(call.getLoc(), call.callee(), convertedResults, newOperands); int newOpResultIdx = 0; SmallVector newResults; for (auto type : call.getResultTypes()) { if (type.isa()) { newResults.push_back( rewriter.create(call.getLoc(), type)); continue; } newResults.push_back(newCall.getResult(newOpResultIdx++)); } rewriter.replaceOp(call, newResults); return success(); } private: TypeBoundMap &typeBoundMap; }; } // namespace namespace { class AdjustCallingConventionForReturn : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { SmallVector newOperands; for (auto operand : llvm::enumerate(operands)) { if (!operand.value()) continue; if (operand.value().getType().isa()) continue; newOperands.push_back(operand.value()); } rewriter.replaceOpWithNewOp(op, newOperands); return success(); } }; } // namespace static LogicalResult adjustCallingConventions(FuncOp func, TypeBoundMap &typeBoundMap) { MLIRContext *context = func.getContext(); RewritePatternSet patterns(context); // TODO: TupleTypes TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion( [](Basicpy::NoneType type, SmallVectorImpl &types) -> Optional { return success(); }); typeConverter.addArgumentMaterialization( [](OpBuilder &builder, Numpy::NdArrayType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(inputs[0].getType().isa()); return builder.create(loc, type, inputs[0]); }); patterns.add(typeConverter, context); patterns.add(typeConverter, context, typeBoundMap); patterns.add(typeConverter, context); ConversionTarget target(*context); target.addDynamicallyLegalOp([](FuncOp func) { for (int i = 0, e = func.getNumArguments(); i != e; i++) { if (func.getArgAttr(i, "torch.type_bound")) return false; if (func.getArgumentTypes()[i].isa()) return false; } for (int i = 0, e = func.getNumResults(); i != e; i++) { if (func.getType().getResults()[i].isa()) return false; } return true; }); // The dynamic legality conditions for call and return are a pain to write... // Just run the patterns once and call it a day. // // Bug for doing this better https://bugs.llvm.org/show_bug.cgi?id=49812 DenseSet opsInOriginalProgram; func.walk([&](CallOp op) { opsInOriginalProgram.insert(op.getOperation()); }); func.walk( [&](ReturnOp op) { opsInOriginalProgram.insert(op.getOperation()); }); target.addDynamicallyLegalOp([&](CallOp op) { return !opsInOriginalProgram.contains(op.getOperation()); }); target.addDynamicallyLegalOp([&](ReturnOp op) { return !opsInOriginalProgram.contains(op.getOperation()); }); target.addLegalOp(); target.addLegalOp(); // We don't know how to rewrite it, so mark it as illegal. target.addIllegalOp(); if (failed(applyPartialConversion(func.getOperation(), target, std::move(patterns)))) return failure(); return success(); } namespace { class AdjustCallingConventionsPass : public AdjustCallingConventionsBase { void runOnOperation() override { auto module = getOperation(); TypeBoundMap typeBoundMap; for (auto func : module.getOps()) { for (int i = 0, e = func.getNumArguments(); i != e; i++) { auto typeBoundAttr = func.getArgAttrOfType(i, "torch.type_bound"); if (!typeBoundAttr) continue; typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue(); } } for (auto func : module.getOps()) { if (failed(adjustCallingConventions(func, typeBoundMap))) return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::NPCOMP::Torch::createAdjustCallingConventionsPass() { return std::make_unique(); }