// torch-mlir/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp

//===- AdjustCallingConventions.cpp ------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
// Map from func name and arg index to the type bound for that arg.
// This is needed because to rewrite calls, we need the non-local information
// from the func definition.
// We also benefit from populating this all at once, which avoids ordering
// issues between rewriting of func ops vs call ops.
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type> ;
namespace {
// Rewrites a func's signature so that:
//   - !numpy.ndarray args annotated with a `torch.type_bound` arg attr take
//     the bound type directly (the attr is then erased), and
//   - NoneType args and results are dropped entirely.
class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(FuncOp func, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *context = func.getContext();
    auto typeBoundIdent = Identifier::get("torch.type_bound", context);
    TypeConverter::SignatureConversion conversion(func.getNumArguments());

    // The TypeConverter hooks for type conversion are "context free", so we
    // cannot use the usual helpers here for populating SignatureConversion and
    // new result types.
    //
    // The incorporation of the torch.type_bound arg attr is context-dependent.

    for (auto type : llvm::enumerate(func.getArgumentTypes())) {
      if (auto ndarray = type.value().dyn_cast<Numpy::NdArrayType>()) {
        // If a type bound is attached, the bound becomes the new arg type;
        // otherwise the ndarray type is kept as-is.
        // TODO: check if the bound is more specific than the ndarray type?
        auto typeBoundAttr =
            func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
        conversion.addInputs(type.index(), typeBoundAttr
                                               ? typeBoundAttr.getValue()
                                               : type.value());
        continue;
      } else if (auto none = type.value().dyn_cast<Basicpy::NoneType>()) {
        // NoneType args are erased: no remapped input is registered for this
        // original arg index.
        continue;
      }
      // TODO: add tuple type.
      conversion.addInputs(type.index(), type.value());
    }

    // NoneType results are erased as well; call sites re-materialize them
    // with basicpy.singleton (see AdjustCallingConventionForCall).
    SmallVector<Type> newResultTypes;
    for (auto type : func.getType().getResults()) {
      if (auto none = type.dyn_cast<Basicpy::NoneType>()) {
        continue;
      }
      newResultTypes.push_back(type);
    }

    // Rewrite the entry block's arguments first, then update the function
    // type to match.
    rewriter.applySignatureConversion(&func.getBody(), conversion,
                                      typeConverter);
    rewriter.updateRootInPlace(func, [&] {
      func.setType(FunctionType::get(
          getContext(), conversion.getConvertedTypes(), newResultTypes));
      // Clear out the type bounds, now that the type incorporates them.
      for (int i = 0, e = func.getNumArguments(); i != e; i++)
        func.removeArgAttr(i, typeBoundIdent);
    });
    return success();
  }
};
} // namespace
namespace {
// Rewrites call sites to match the adjusted callee conventions: NoneType
// operands are dropped, operands whose callee arg carries a type bound are
// cast to the bound type, and erased NoneType results are re-materialized as
// basicpy.singleton values.
class AdjustCallingConventionForCall : public OpConversionPattern<CallOp> {
public:
  AdjustCallingConventionForCall(TypeConverter &converter, MLIRContext *context,
                                 TypeBoundMap &typeBoundMap)
      : OpConversionPattern<CallOp>(converter, context),
        typeBoundMap(typeBoundMap) {}
  LogicalResult
  matchAndRewrite(CallOp call, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = call.getLoc();

    // Result types after conversion (NoneType results get erased).
    SmallVector<Type> adjustedResultTypes;
    if (failed(typeConverter->convertTypes(call.getResultTypes(),
                                           adjustedResultTypes)))
      return failure();

    // Build the operand list for the adjusted callee.
    SmallVector<Value> adjustedOperands;
    for (auto indexedOperand : llvm::enumerate(operands)) {
      Value value = indexedOperand.value();
      // None operands vanish from the adjusted signature.
      if (value.getType().isa<Basicpy::NoneType>())
        continue;
      auto boundIt =
          typeBoundMap.find({call.callee(), indexedOperand.index()});
      if (boundIt == typeBoundMap.end()) {
        adjustedOperands.push_back(value);
      } else {
        // The callee now expects the bound type directly; cast to it.
        adjustedOperands.push_back(rewriter.create<Numpy::StaticInfoCastOp>(
            loc, boundIt->second, value));
      }
    }

    auto adjustedCall = rewriter.create<CallOp>(
        loc, call.callee(), adjustedResultTypes, adjustedOperands);

    // Map each original result either to a fresh singleton None or to the
    // next result of the adjusted call.
    SmallVector<Value> replacementValues;
    int nextAdjustedResult = 0;
    for (Type resultType : call.getResultTypes()) {
      if (resultType.isa<Basicpy::NoneType>()) {
        replacementValues.push_back(
            rewriter.create<Basicpy::SingletonOp>(loc, resultType));
      } else {
        replacementValues.push_back(
            adjustedCall.getResult(nextAdjustedResult++));
      }
    }
    rewriter.replaceOp(call, replacementValues);
    return success();
  }

private:
  TypeBoundMap &typeBoundMap;
};
} // namespace
namespace {
// Drops NoneType (and null) operands from return ops, matching the adjusted
// function signature, which no longer carries NoneType results.
class AdjustCallingConventionForReturn : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> keptOperands;
    for (Value value : operands) {
      bool dropped = !value || value.getType().isa<Basicpy::NoneType>();
      if (!dropped)
        keptOperands.push_back(value);
    }
    rewriter.replaceOpWithNewOp<ReturnOp>(op, keptOperands);
    return success();
  }
};
} // namespace
// Runs the calling-convention adjustment on a single function:
//   - the func's signature loses NoneType args/results and absorbs
//     torch.type_bound annotations,
//   - calls within it are rewritten to match the adjusted callee conventions
//     (using the module-wide `typeBoundMap`),
//   - returns drop NoneType operands.
static LogicalResult adjustCallingConventions(FuncOp func,
                                              TypeBoundMap &typeBoundMap) {
  MLIRContext *context = func.getContext();
  RewritePatternSet patterns(context);
  // TODO: TupleTypes
  TypeConverter typeConverter;
  // Default: identity conversion for all types.
  typeConverter.addConversion([](Type type) { return type; });
  // NoneType converts 1:0 -- returning success without appending anything to
  // `types` erases the type from signatures.
  typeConverter.addConversion(
      [](Basicpy::NoneType type,
         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
        return success();
      });
  // When a block arg's ndarray type is tightened to its bound, reconcile
  // remaining uses of the old ndarray type via numpy.static_info_cast.
  typeConverter.addArgumentMaterialization(
      [](OpBuilder &builder, Numpy::NdArrayType type, ValueRange inputs,
         Location loc) -> Value {
        assert(inputs.size() == 1);
        assert(inputs[0].getType().isa<Numpy::NdArrayType>());
        return builder.create<Numpy::StaticInfoCastOp>(loc, type, inputs[0]);
      });
  patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);
  patterns.add<AdjustCallingConventionForCall>(typeConverter, context,
                                               typeBoundMap);
  patterns.add<AdjustCallingConventionForReturn>(typeConverter, context);
  ConversionTarget target(*context);
  // A func is legal once no type bounds remain and no NoneType appears in its
  // args or results.
  target.addDynamicallyLegalOp<FuncOp>([](FuncOp func) {
    for (int i = 0, e = func.getNumArguments(); i != e; i++) {
      if (func.getArgAttr(i, "torch.type_bound"))
        return false;
      if (func.getArgumentTypes()[i].isa<Basicpy::NoneType>())
        return false;
    }
    for (int i = 0, e = func.getNumResults(); i != e; i++) {
      if (func.getType().getResults()[i].isa<Basicpy::NoneType>())
        return false;
    }
    return true;
  });
  // The dynamic legality conditions for call and return are a pain to write...
  // Just run the patterns once and call it a day.
  //
  // Bug for doing this better https://bugs.llvm.org/show_bug.cgi?id=49812
  DenseSet<Operation *> opsInOriginalProgram;
  func.walk([&](CallOp op) { opsInOriginalProgram.insert(op.getOperation()); });
  func.walk(
      [&](ReturnOp op) { opsInOriginalProgram.insert(op.getOperation()); });
  // Pre-existing calls/returns are illegal (each gets rewritten exactly once);
  // ops the patterns create are legal.
  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
    return !opsInOriginalProgram.contains(op.getOperation());
  });
  target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
    return !opsInOriginalProgram.contains(op.getOperation());
  });
  target.addLegalOp<Numpy::StaticInfoCastOp>();
  target.addLegalOp<Basicpy::SingletonOp>();
  // We don't know how to rewrite it, so mark it as illegal.
  target.addIllegalOp<CallIndirectOp>();
  if (failed(applyPartialConversion(func.getOperation(), target,
                                    std::move(patterns))))
    return failure();
  return success();
}
namespace {
// Module pass driving the per-function calling-convention adjustment.
class AdjustCallingConventionsPass
    : public AdjustCallingConventionsBase<AdjustCallingConventionsPass> {
  void runOnOperation() override {
    auto module = getOperation();
    // First pass: collect all torch.type_bound annotations module-wide, since
    // rewriting a call site needs the bound from the callee's definition.
    TypeBoundMap typeBoundMap;
    for (auto func : module.getOps<FuncOp>()) {
      for (int argIndex = 0, numArgs = func.getNumArguments();
           argIndex != numArgs; argIndex++) {
        if (auto boundAttr = func.getArgAttrOfType<TypeAttr>(
                argIndex, "torch.type_bound"))
          typeBoundMap[{func.getName(), argIndex}] = boundAttr.getValue();
      }
    }
    // Second pass: adjust each function (and the calls/returns inside it).
    for (auto func : module.getOps<FuncOp>()) {
      if (failed(adjustCallingConventions(func, typeBoundMap)))
        return signalPassFailure();
    }
  }
};
} // namespace
// Factory for the pass; ownership transfers to the caller.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createAdjustCallingConventionsPass() {
  auto pass = std::make_unique<AdjustCallingConventionsPass>();
  return pass;
}