torch-mlir/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp

298 lines
12 KiB
C++

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Parser.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static Value adjustShapeFunctionArg(Value operand, Type desiredType,
OpBuilder &b, Location loc);
static Value adjustListArg(Value operand, Torch::ListType desiredType,
OpBuilder &b, Location loc) {
auto providedType = operand.getType().cast<Torch::ListType>();
// Pseudocode:
//
// operand = ...
// adjusted_list = []
// for i in range(len(operand)):
// adjusted_list.append(adjust(operand[i]))
// return adjusted_list
Value adjustedList =
b.create<PrimListConstructOp>(loc, desiredType, ValueRange({}));
// Create a for-like PrimLoopOp.
Value maxTripCount = b.create<AtenLenTOp>(loc, operand);
Value cTrue = b.create<Torch::ConstantBoolOp>(loc, true);
auto loop = b.create<PrimLoopOp>(loc, TypeRange({}), maxTripCount,
/*initialCondition=*/cTrue,
/*iterArgsInit=*/ValueRange({}));
// Create the loop body.
{
OpBuilder::InsertionGuard guard(b);
Block *body =
b.createBlock(&loop.region(), loop.region().begin(),
TypeRange({b.getType<Torch::IntType>()}), {loc});
Value iterationNumber = body->getArgument(0);
Value element = b.create<Aten__Getitem__TOp>(
loc, providedType.getContainedType(), operand, iterationNumber);
Value adjustedElement =
adjustShapeFunctionArg(element, desiredType.getContainedType(), b, loc);
b.create<AtenAppendTOp>(loc, adjustedList.getType(), adjustedList,
adjustedElement);
b.create<PrimLoopConditionOp>(loc, /*shouldContinue=*/cTrue,
/*iterArgs=*/ValueRange({}));
}
return adjustedList;
}
static Value adjustShapeFunctionArg(Value operand, Type desiredType,
OpBuilder &b, Location loc) {
auto operandType = operand.getType();
// No need for adjustment if they already match.
if (operandType == desiredType)
return operand;
if (desiredType.isa<Torch::AnyType>()) {
// Generator's are currently passed as Any because TorchScript cannot
// compile a function with Generator type arguments.
// Ignoring that hack, this is a correct handling of Any type should we need
// to actually support it in the future.
return b.create<DerefineOp>(loc, desiredType, operand);
}
// If the operand is NoneType, then we just need to derefine it to the
// optional type in the shape function signature.
if (operandType.isa<Torch::NoneType>()) {
assert(desiredType.isa<Torch::OptionalType>() &&
"Don't expect shape functions to have NoneType parameters");
return b.create<DerefineOp>(loc, desiredType, operand);
}
// If the operand type is statically !torch.optional, then we need to do
// different things for the None and non-None cases.
// For the None case, we just need to derefine it to the desired type.
// For the non-None case, we need to unwrap the optional type and then adjust
// it recursively (which also takes care of derefining it to ultimate desired
// type).
// A case where this happens is `!torch.optional<!torch.vtensor>` ->
// `!torch.optional<!torch.list<!torch.int>>>`.
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) {
if (desiredType.isa<Torch::OptionalType>()) {
// if optional is None:
// return derefine(None)
// else:
// return adjust(unchecked_cast(optional))
auto none = b.create<ConstantNoneOp>(loc);
auto isNone = b.create<Aten__Is__Op>(loc, operand, none);
auto primIf = b.create<PrimIfOp>(loc, desiredType, isNone);
{
Region &thenRegion = primIf.thenRegion();
b.createBlock(&thenRegion, thenRegion.end());
auto derefineNone = b.create<DerefineOp>(loc, desiredType, none);
b.create<PrimIfYieldOp>(loc, ValueRange{derefineNone});
}
{
Region &elseRegion = primIf.elseRegion();
b.createBlock(&elseRegion, elseRegion.end());
auto downcasted = b.create<PrimUncheckedCastOp>(
loc, operandOptionalType.getContainedType(), operand);
auto adjusted = adjustShapeFunctionArg(downcasted, desiredType, b, loc);
b.create<PrimIfYieldOp>(loc, adjusted);
}
b.setInsertionPointAfter(primIf);
return primIf.getResult(0);
}
}
// If the desired type is OptionalType, then recursively adjust the operand to
// the contained type, then derefine it to `!torch.optional`. For example,
// `!torch.vtensor -> !torch.optional<!torch.list<!torch.int>>>`.
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) {
auto adjusted = adjustShapeFunctionArg(
operand, desiredOptionalType.getContainedType(), b, loc);
return b.create<DerefineOp>(loc, desiredType, adjusted);
}
// The shape library functions have tensor operands replaced with
// `!torch.list<!torch.int>` types for the shape. Get the sizes.
if (operand.getType().isa<Torch::BaseTensorType>()) {
assert(desiredType.isa<Torch::ListType>() &&
"Don't expect shape functions to have tensor parameters");
return b.create<AtenSizeOp>(loc, desiredType, operand);
}
// Run this after `operand.getType().isa<Torch::BaseTensorType>()` so that
// `!torch.vtensor` -> `!torch.list<!torch.int>` is handled there specially
// first.
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) {
return adjustListArg(operand, desiredListType, b, loc);
}
// The shape library functions use `float` where the operator
// signature uses `Scalar` (see comments in torch_ods_gen.py for
// explanation).
if (desiredType.isa<Torch::FloatType>() &&
operand.getType().isa<Torch::IntType>()) {
return b.create<AtenFloatScalarOp>(loc, desiredType, operand);
}
// Pass the operand as-is.
return operand;
}
// Populates the shape calculation region with a call to the shape function
// from the shape library.
static LogicalResult
populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands,
mlir::FuncOp shapeFunction) {
// Create a call to the shape function in the `shapeCalculation` region.
// We will import the callee from the shape library later.
OpBuilder b(op.getContext());
Location loc = op->getLoc();
b.createBlock(&op.shapeCalculation());
// Massage the op operands to match the shape function signature.
// The shape function generally takes the same operands as the op, with a few
// systematic modifications, such as replacing tensors with their shapes.
SmallVector<Value> shapeFunctionArgs;
for (auto operandAndDesiredType :
llvm::zip(originalOperands, shapeFunction.getArgumentTypes())) {
Value operand;
Type desiredType;
std::tie(operand, desiredType) = operandAndDesiredType;
Value shapeFunctionArg =
adjustShapeFunctionArg(operand, desiredType, b, loc);
if (!shapeFunctionArg)
return failure();
shapeFunctionArgs.push_back(shapeFunctionArg);
}
// Create the call to the shape function!
auto call = b.create<mlir::CallOp>(loc, shapeFunction, shapeFunctionArgs);
// Python models multiple results with a tuple, so we need to unpack it
// if the op has multiple results.
SmallVector<Value> unpackedResults;
assert(call.getNumResults() == 1 &&
"Multiple results are packed in a tuple in Python!");
Value result = call.getResult(0);
if (auto tupleType = result.getType().dyn_cast<Torch::TupleType>()) {
auto unpack =
b.create<PrimTupleUnpackOp>(loc, tupleType.getContainedTypes(), result);
llvm::append_range(unpackedResults, unpack.getResults());
} else {
unpackedResults.push_back(result);
}
// Terminate the region.
b.create<ShapeCalculateYieldShapesOp>(loc, unpackedResults);
return success();
}
namespace {
class ReifyShapeCalculationsPass
: public ReifyShapeCalculationsBase<ReifyShapeCalculationsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
// TODO: Find a way to not have to parse this every time.
// The shape library is O(#ops we know about), and this pass should be
// O(#ops in the program) ideally.
auto shapeLibrary = parseSourceString(getShapeLibrary(), context);
// Walk all the operations, and if we have a shape function, wrap the op
// in a `torch.shape.calculate` op.
SmallVector<std::string> neededShapeFunctions;
bool hadError = false;
module.walk([&](Operation *op) {
Location loc = op->getLoc();
auto name = op->getName().stripDialect();
// For pseudo-ops (ops that are mechanically consistent with existing
// torch conventions, but simply not present, such as a missing in-place
// or out-of-place variant), remove the pseudo prefix.
if (name.startswith("pseudo."))
name = name.drop_front(strlen("pseudo."));
auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
auto shapeFunction =
shapeLibrary->lookupSymbol<FuncOp>(shapeFunctionName);
if (!shapeFunction)
return;
neededShapeFunctions.push_back(shapeFunctionName);
auto shapeCalculate =
OpBuilder(op).create<ShapeCalculateOp>(loc, op->getResultTypes());
op->replaceAllUsesWith(shapeCalculate);
{
// Move the op into the body of the `torch.shape.calculate` op and yield
// its results.
OpBuilder b(context);
Block *block = b.createBlock(&shapeCalculate.body());
op->moveBefore(block, block->end());
b.setInsertionPointAfter(op);
b.create<ShapeCalculateYieldOp>(loc, op->getResults());
}
if (failed(populateShapeCalculationRegion(
shapeCalculate, op->getOperands(), shapeFunction))) {
hadError = true;
return;
}
});
if (hadError)
return signalPassFailure();
// Import just the functions we need. This includes transitive callees,
// so we use a worklist algorithm.
llvm::StringSet<> importedFunctions;
SmallVector<std::string> worklist;
llvm::append_range(worklist, neededShapeFunctions);
while (!worklist.empty()) {
auto symName = worklist.pop_back_val();
if (importedFunctions.count(symName))
continue;
auto func = shapeLibrary->lookupSymbol<mlir::FuncOp>(symName);
assert(func && "broken shape library");
// Move the shape function from the library to the module this pass
// is running on. (this mutates the library, but we re-parse it each time
// so this is safe to do).
func->moveBefore(&module.getBody()->front());
// Set the visibility to private so that the shape functions go away
// nicely after we are done with them.
func.setVisibility(SymbolTable::Visibility::Private);
// Continue the DFS.
importedFunctions.insert(symName);
func.walk([&](CallOp op) { worklist.push_back(op.getCallee().str()); });
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createReifyShapeCalculationsPass() {
return std::make_unique<ReifyShapeCalculationsPass>();
}