2022-03-10 08:44:22 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "PassDetail.h"
|
|
|
|
|
2022-12-14 00:25:41 +08:00
|
|
|
#include "ReifyAbstractInterpCalculationsUtils.h"
|
2022-03-16 18:44:23 +08:00
|
|
|
#include "mlir/Parser/Parser.h"
|
2022-03-10 08:44:22 +08:00
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
2023-03-25 10:50:01 +08:00
|
|
|
#include "llvm/Support/MemoryBuffer.h"
|
2022-03-10 08:44:22 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
2022-12-14 00:25:41 +08:00
|
|
|
// Builds the argument list for a shape function from an op's operands.
//
// Operands are paired positionally with the parameter types of `shapeFunc`
// and each pair is run through `adjustFunctionArg`. The callback supplied to
// that helper replaces a tensor operand with the result of an `aten.size` op
// whenever the shape function expects a `!torch.list<int>` in that position
// (shape library functions take a tensor's sizes rather than the tensor); it
// returns every other operand unchanged.
//
// Returns failure if adjusting any operand fails.
static FailureOr<SmallVector<Value>>
shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
                         ValueRange originalOperands, func::FuncOp shapeFunc) {
  // Swap a tensor for its sizes when the target parameter is a list of ints;
  // otherwise hand the value back as-is.
  auto tensorToSizeList = [](OpBuilder &builder, Location opLoc, Value value,
                             Type targetType) -> Value {
    auto targetListType = dyn_cast<Torch::ListType>(targetType);
    // Check the list type first so getContainedType() is never called on null.
    bool wantsSizesOfTensor =
        targetListType && isa<Torch::BaseTensorType>(value.getType()) &&
        isa<Torch::IntType>(targetListType.getContainedType());
    if (wantsSizesOfTensor)
      return builder.create<AtenSizeOp>(opLoc, targetType, value);
    return value;
  };

  SmallVector<Value> shapeFuncArgs;
  // Note: llvm::zip stops at the end of the shorter of the two ranges.
  for (auto [operand, desiredType] :
       llvm::zip(originalOperands, shapeFunc.getArgumentTypes())) {
    FailureOr<Value> adjusted =
        adjustFunctionArg(b, loc, operand, desiredType, tensorToSizeList);
    if (failed(adjusted))
      return failure();
    shapeFuncArgs.push_back(*adjusted);
  }
  return shapeFuncArgs;
}
|
|
|
|
|
|
|
|
namespace {
|
2023-03-25 10:50:01 +08:00
|
|
|
struct ReifyShapeCalculationsPass
|
2022-03-10 08:44:22 +08:00
|
|
|
: public ReifyShapeCalculationsBase<ReifyShapeCalculationsPass> {
|
2023-03-25 10:50:01 +08:00
|
|
|
ReifyShapeCalculationsPass() = default;
|
|
|
|
ReifyShapeCalculationsPass(StringRef extraLibrary) {
|
|
|
|
this->extraLibrary = extraLibrary.str();
|
|
|
|
}
|
2022-03-10 08:44:22 +08:00
|
|
|
void runOnOperation() override {
|
|
|
|
MLIRContext *context = &getContext();
|
|
|
|
ModuleOp module = getOperation();
|
|
|
|
|
|
|
|
// TODO: Find a way to not have to parse this every time.
|
2022-12-14 00:25:41 +08:00
|
|
|
// The library is O(#ops we know about), and this pass should be
|
2022-03-10 08:44:22 +08:00
|
|
|
// O(#ops in the program) ideally.
|
2022-12-14 00:25:41 +08:00
|
|
|
OwningOpRef<ModuleOp> library =
|
|
|
|
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
|
2023-03-25 10:50:01 +08:00
|
|
|
if (!extraLibrary.empty())
|
|
|
|
if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) {
|
|
|
|
emitError(module->getLoc(),
|
|
|
|
"Failed to load extra-library file at " + extraLibrary);
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
2022-03-10 08:44:22 +08:00
|
|
|
|
|
|
|
// Walk all the operations, and if we have a shape function, wrap the op
|
|
|
|
// in a `torch.shape.calculate` op.
|
2022-12-14 00:25:41 +08:00
|
|
|
SmallVector<std::string> functionsNeeded;
|
|
|
|
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
|
|
|
|
return wrapWithCalculateOpIfLibraryFunctionAvailable(
|
|
|
|
op, *library, LibraryFunctionKind::ShapeFunction, functionsNeeded,
|
|
|
|
shapeFunctionArgsBuilder);
|
2022-03-10 08:44:22 +08:00
|
|
|
});
|
|
|
|
|
2022-12-14 00:25:41 +08:00
|
|
|
if (walkResult.wasInterrupted())
|
2022-03-10 08:44:22 +08:00
|
|
|
return signalPassFailure();
|
2022-12-14 00:25:41 +08:00
|
|
|
importLibraryFunctions(module, *library, std::move(functionsNeeded));
|
2022-03-10 08:44:22 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/// Factory for the shape-reification pass. A non-empty `extraLibrary` names an
/// additional library file that the pass loads alongside the built-in one.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createReifyShapeCalculationsPass(StringRef extraLibrary) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ReifyShapeCalculationsPass>(extraLibrary);
  return pass;
}
|