torch-mlir/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculat...

291 lines
12 KiB
C++

//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
return "__torch_mlir_shape_fn.";
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
return "__torch_mlir_dtype_fn.";
llvm_unreachable(
"`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`");
}
static Operation *createCalculateOp(OpBuilder &b, Location loc,
TypeRange resultTypes,
LibraryFunctionKind libFuncKind) {
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
return b.create<ShapeCalculateOp>(loc, resultTypes);
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
return b.create<DtypeCalculateOp>(loc, resultTypes);
llvm_unreachable(
"`createCalculateOp` called with an unsupported `LibraryFunctionKind`");
}
static Operation *createCalculateYieldOp(OpBuilder &b, Location loc,
ValueRange results,
LibraryFunctionKind libFuncKind) {
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
return b.create<ShapeCalculateYieldOp>(loc, results);
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
return b.create<DtypeCalculateYieldOp>(loc, results);
llvm_unreachable("`createCalculateYieldOp` called with an unsupported "
"`LibraryFunctionKind`");
}
static Operation *
createCalculateYieldCalculationOp(OpBuilder &b, Location loc,
ValueRange results,
LibraryFunctionKind libFuncKind) {
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
return b.create<ShapeCalculateYieldShapesOp>(loc, results);
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
return b.create<DtypeCalculateYieldDtypesOp>(loc, results);
llvm_unreachable("`createCalculateYieldCalculationOp` called with an "
"unsupported `LibraryFunctionKind`");
}
LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
Operation *op, ModuleOp library, LibraryFunctionKind libFuncKind,
SmallVector<std::string> &libFuncNamesUsed,
function_ref<FailureOr<SmallVector<Value>>(OpBuilder &, Location,
ValueRange, func::FuncOp)>
libFuncArgsBuilder) {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
auto name = op->getName().stripDialect();
// For value-semantic variant ops, i.e. valsem-ops (ops that are
// mechanically consistent with existing torch conventions of in-place vs.
// out-of-place (value-semantic) variants), remove the prefix when
// looking them up in the library.
if (name.startswith("valsem."))
name = name.drop_front(strlen("valsem."));
std::string libFuncName =
(getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
if (!libFunc)
return success();
libFuncNamesUsed.push_back(libFuncName);
OpBuilder b(op);
Operation *calculate =
createCalculateOp(b, loc, op->getResultTypes(), libFuncKind);
op->replaceAllUsesWith(calculate);
{
// Move the op into the body of the `torch.{libFuncType}.calculate` op
// and yield its results.
OpBuilder b(context);
Block *bodyBlock = b.createBlock(&calculate->getRegion(0));
op->moveBefore(bodyBlock, bodyBlock->end());
b.setInsertionPointAfter(op);
createCalculateYieldOp(b, loc, op->getResults(), libFuncKind);
}
{
OpBuilder b(context);
b.createBlock(&calculate->getRegion(1));
// Create the call to the library function!
FailureOr<SmallVector<Value>> libFuncArgs =
libFuncArgsBuilder(b, loc, op->getOperands(), libFunc);
if (failed(libFuncArgs))
return failure();
auto call = b.create<mlir::func::CallOp>(loc, libFunc, *libFuncArgs);
// Python models multiple results with a tuple, so we need to unpack it
// if the op has multiple results.
SmallVector<Value> unpackedResults;
assert(call.getNumResults() == 1 &&
"Multiple results are packed in a tuple in Python!");
Value result = call.getResult(0);
if (auto tupleType = result.getType().dyn_cast<Torch::TupleType>()) {
auto unpack = b.create<PrimTupleUnpackOp>(
loc, tupleType.getContainedTypes(), result);
llvm::append_range(unpackedResults, unpack.getResults());
} else {
unpackedResults.push_back(result);
}
// Terminate the region.
createCalculateYieldCalculationOp(b, loc, unpackedResults, libFuncKind);
}
return success();
}
void Torch::importLibraryFunctions(ModuleOp module, ModuleOp library,
SmallVector<std::string> functionsNeeded) {
// Import just the functions we need. This includes transitive callees,
// so we use a worklist algorithm.
llvm::StringSet<> importedFunctions;
while (!functionsNeeded.empty()) {
std::string symName = functionsNeeded.pop_back_val();
if (importedFunctions.contains(symName))
continue;
auto func = library.lookupSymbol<func::FuncOp>(symName);
assert(func && "broken library");
// Move the function from the library to the module this pass
// is running on. (this mutates the library, but we re-parse it each time
// so this is safe to do).
func->moveBefore(&module.getBody()->front());
// Set the visibility to private so that the functions go away
// nicely after we are done with them.
func.setVisibility(SymbolTable::Visibility::Private);
// Continue the DFS.
importedFunctions.insert(symName);
func.walk([&](func::CallOp op) {
functionsNeeded.push_back(op.getCallee().str());
});
}
}
FailureOr<Value> Torch::adjustFunctionArg(
OpBuilder &b, Location loc, Value operand, Type desiredType,
function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation) {
operand = baseTransformation(b, loc, operand, desiredType);
// No need for adjustment if they already match.
auto operandType = operand.getType();
if (operandType == desiredType)
return operand;
if (desiredType.isa<Torch::AnyType>()) {
// Generator's are currently passed as Any because TorchScript cannot
// compile a function with Generator type arguments.
// Ignoring that hack, this is a correct handling of Any type should we need
// to actually support it in the future.
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// !torch.union<int, float> is the type used for `Scalar` inputs. At
// compile time, such inputs will usually be resolved to an `int` or a `float`
// so we need to derefine to match the library function signature.
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType.isa<Torch::IntType, Torch::FloatType>();
}))
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// If the operand is NoneType, then we just need to derefine it to the
// optional type in the function signature.
if (operandType.isa<Torch::NoneType>()) {
assert(desiredType.isa<Torch::OptionalType>() &&
"Don't expect library functions to have NoneType parameters");
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// If the operand type is statically !torch.optional, then we need to do
// different things for the None and non-None cases.
// For the None case, we just need to derefine it to the desired type.
// For the non-None case, we need to unwrap the optional type and then adjust
// it recursively (which also takes care of derefining it to ultimate desired
// type).
// A case where this happens is `!torch.optional<vtensor>` ->
// `!torch.optional<list<int>>>`.
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) {
if (desiredType.isa<Torch::OptionalType>()) {
// if optional is None:
// return derefine(None)
// else:
// return adjust(unchecked_cast(optional))
auto none = b.create<ConstantNoneOp>(loc);
auto isNone = b.create<Aten__Is__Op>(loc, operand, none);
auto primIf = b.create<PrimIfOp>(loc, desiredType, isNone);
{
Region &thenRegion = primIf.getThenRegion();
b.createBlock(&thenRegion, thenRegion.end());
auto derefineNone = b.create<DerefineOp>(loc, desiredType, none);
b.create<PrimIfYieldOp>(loc, ValueRange{derefineNone});
}
{
Region &elseRegion = primIf.getElseRegion();
b.createBlock(&elseRegion, elseRegion.end());
auto downcasted = b.create<PrimUncheckedCastOp>(
loc, operandOptionalType.getContainedType(), operand);
FailureOr<Value> adjusted = adjustFunctionArg(
b, loc, downcasted, desiredType, baseTransformation);
if (failed(adjusted))
return failure();
b.create<PrimIfYieldOp>(loc, *adjusted);
}
b.setInsertionPointAfter(primIf);
return primIf.getResult(0);
}
}
// If the desired type is OptionalType, then recursively adjust the operand to
// the contained type, then derefine it to `!torch.optional`. For example,
// `!torch.vtensor -> !torch.optional<list<int>>>`.
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) {
FailureOr<Value> adjusted = adjustFunctionArg(
b, loc, operand, desiredOptionalType.getContainedType(),
baseTransformation);
if (failed(adjusted))
return failure();
return b.create<DerefineOp>(loc, desiredType, *adjusted).getResult();
}
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) {
// Pseudocode:
//
// operand = ...
// adjusted_list = []
// for i in range(len(operand)):
// adjusted_list.append(adjust(operand[i]))
// return adjusted_list
auto providedType = operand.getType().cast<Torch::ListType>();
Value adjustedList =
b.create<PrimListConstructOp>(loc, desiredListType, ValueRange({}));
// Create a for-like PrimLoopOp.
Value maxTripCount = b.create<AtenLenTOp>(loc, operand);
Value cTrue = b.create<Torch::ConstantBoolOp>(loc, true);
auto loop = b.create<PrimLoopOp>(loc, TypeRange({}), maxTripCount,
/*initialCondition=*/cTrue,
/*iterArgsInit=*/ValueRange({}));
// Create the loop body.
{
OpBuilder::InsertionGuard guard(b);
Block *body =
b.createBlock(&loop.getRegion(), loop.getRegion().begin(),
TypeRange({b.getType<Torch::IntType>()}), {loc});
Value iterationNumber = body->getArgument(0);
Value element = b.create<Aten__Getitem__TOp>(
loc, providedType.getContainedType(), operand, iterationNumber);
FailureOr<Value> adjustedElement =
adjustFunctionArg(b, loc, element, desiredListType.getContainedType(),
baseTransformation);
if (failed(adjustedElement))
return failure();
b.create<AtenAppendTOp>(loc, adjustedList.getType(), adjustedList,
*adjustedElement);
b.create<PrimLoopConditionOp>(loc, /*shouldContinue=*/cTrue,
/*iterArgs=*/ValueRange({}));
}
return adjustedList;
}
// The library functions use `float` where the operator
// signature uses `Scalar` (see comments in torch_ods_gen.py for
// explanation).
if (desiredType.isa<Torch::FloatType>() &&
operand.getType().isa<Torch::IntType>()) {
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
}
// Pass the operand as-is.
return operand;
}