//===- AdjustCallingConventions.cpp ------------------------*- C++-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; // Map from func name and arg index to the type bound for that arg. // This is needed because to rewrite calls, we need the non-local information // from the func definition. // We also benefit from populating this all at once, which avoids ordering // issues between rewriting of func ops vs call ops. using TypeBoundMap = DenseMap, Type>; namespace { class AdjustCallingConventionForFunc : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(func::FuncOp func, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = func.getContext(); auto typeBoundIdent = StringAttr::get(context, "torch.type_bound"); TypeConverter::SignatureConversion conversion(func.getNumArguments()); // The TypeConverter hooks for type conversion are "context free", so we // cannot use the usual helpers here for populating SignatureConversion and // new result types. // // The incoporation of the torch.type_bound arg attr is context-dependent. for (auto type : llvm::enumerate(func.getArgumentTypes())) { if (isa(type.value())) { auto typeBoundAttr = func.getArgAttrOfType(type.index(), typeBoundIdent); Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type(); if (!isa(bound)) return rewriter.notifyMatchFailure( func, "unimplemented: preserving aliasing for non-value-semantic " "type bounds"); conversion.addInputs(type.index(), typeBoundAttr ? typeBoundAttr.getValue() : type.value()); continue; } else if (auto none = dyn_cast(type.value())) { continue; } // TODO: add tuple type. conversion.addInputs(type.index(), type.value()); } rewriter.applySignatureConversion(&func.getBody().front(), conversion, typeConverter); SmallVector newResultTypes; for (auto type : func.getFunctionType().getResults()) { if (auto none = dyn_cast(type)) { continue; } if (auto tuple = dyn_cast(type)) { llvm::append_range(newResultTypes, tuple.getContainedTypes()); continue; } newResultTypes.push_back(type); } rewriter.modifyOpInPlace(func, [&] { func.setType(FunctionType::get( getContext(), conversion.getConvertedTypes(), newResultTypes)); // Clear out the type bounds, now that the type incorporates them. for (int i = 0, e = func.getNumArguments(); i != e; i++) func.removeArgAttr(i, typeBoundIdent); }); return success(); } }; } // namespace namespace { class AdjustCallingConventionForCall : public OpConversionPattern { public: AdjustCallingConventionForCall(TypeConverter &converter, MLIRContext *context, TypeBoundMap &typeBoundMap) : OpConversionPattern(converter, context), typeBoundMap(typeBoundMap) {} LogicalResult matchAndRewrite(func::CallOp call, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector convertedResults; if (failed(typeConverter->convertTypes(call.getResultTypes(), convertedResults))) return failure(); SmallVector newOperands; for (auto operand : llvm::enumerate(adaptor.getOperands())) { if (isa(operand.value().getType())) continue; auto it = typeBoundMap.find({call.getCallee(), operand.index()}); if (it != typeBoundMap.end()) { if (auto valueTensorType = dyn_cast(it->second)) { newOperands.push_back(copyTensorToType( rewriter, call->getLoc(), valueTensorType, operand.value())); continue; } else { return rewriter.notifyMatchFailure( call, "unimplemented: preserving aliasing for non-value-semantic " "type bounds"); } } newOperands.push_back(operand.value()); } func::CallOp newCall = rewriter.create( call.getLoc(), call.getCallee(), convertedResults, newOperands); int newOpResultIdx = 0; SmallVector newResults; for (auto type : call.getResultTypes()) { if (isa(type)) { newResults.push_back( rewriter.create(call.getLoc(), type)); continue; } if (isa(type)) { newResults.push_back(rewriter.create( call.getLoc(), type, newCall.getResults())); continue; } newResults.push_back(newCall.getResult(newOpResultIdx++)); } rewriter.replaceOp(call, newResults); return success(); } private: TypeBoundMap &typeBoundMap; }; } // namespace namespace { class AdjustCallingConventionForReturn : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector newOperands; for (auto operand : adaptor.getOperands()) { if (!operand) continue; if (isa(operand.getType())) continue; if (auto tuple = dyn_cast(operand.getType())) { Location loc = op.getLoc(); for (auto en : llvm::enumerate(tuple.getContainedTypes())) { auto i = rewriter.create( loc, rewriter.getI64IntegerAttr(en.index())); newOperands.push_back( rewriter.create(loc, en.value(), operand, i)); } continue; } newOperands.push_back(operand); } rewriter.replaceOpWithNewOp(op, newOperands); return success(); } }; } // namespace static LogicalResult adjustCallingConventions(func::FuncOp func, TypeBoundMap &typeBoundMap) { MLIRContext *context = func.getContext(); RewritePatternSet patterns(context); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion( [](Torch::TupleType type, SmallVectorImpl &types) -> LogicalResult { llvm::append_range(types, type.getContainedTypes()); return success(); }); typeConverter.addConversion( [](Torch::NoneType type, SmallVectorImpl &types) -> LogicalResult { return success(); }); typeConverter.addArgumentMaterialization( [](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(isa(inputs[0].getType())); return copyTensorToType(builder, loc, type, inputs[0]); }); patterns.add(typeConverter, context); patterns.add(typeConverter, context, typeBoundMap); patterns.add(typeConverter, context); ConversionTarget target(*context); target.addDynamicallyLegalOp([](func::FuncOp func) { for (int i = 0, e = func.getNumArguments(); i != e; i++) { if (func.getArgAttr(i, "torch.type_bound")) return false; if (isa(func.getArgumentTypes()[i])) return false; } for (int i = 0, e = func.getNumResults(); i != e; i++) { if (isa(func.getFunctionType().getResults()[i])) return false; } return true; }); // The dynamic legality conditions for call and return are a pain to write... // Just run the patterns once and call it a day. // // Bug for doing this better https://bugs.llvm.org/show_bug.cgi?id=49812 DenseSet opsInOriginalProgram; func.walk( [&](func::CallOp op) { opsInOriginalProgram.insert(op.getOperation()); }); func.walk([&](func::ReturnOp op) { opsInOriginalProgram.insert(op.getOperation()); }); target.addDynamicallyLegalOp([&](func::CallOp op) { return !opsInOriginalProgram.contains(op.getOperation()); }); target.addDynamicallyLegalOp([&](func::ReturnOp op) { return !opsInOriginalProgram.contains(op.getOperation()); }); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); // We don't know how to rewrite it, so mark it as illegal. target.addIllegalOp(); if (failed(applyPartialConversion(func.getOperation(), target, std::move(patterns)))) return failure(); return success(); } namespace { class AdjustCallingConventionsPass : public AdjustCallingConventionsBase { void runOnOperation() override { auto module = getOperation(); TypeBoundMap typeBoundMap; for (auto func : module.getOps()) { for (int i = 0, e = func.getNumArguments(); i != e; i++) { auto typeBoundAttr = func.getArgAttrOfType(i, "torch.type_bound"); if (!typeBoundAttr) continue; typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue(); } } for (auto func : module.getOps()) { if (failed(adjustCallingConventions(func, typeBoundMap))) return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::torch::Torch::createAdjustCallingConventionsPass() { return std::make_unique(); }