Maksim Levental 2023-03-24 21:50:01 -05:00 committed by GitHub
parent a7449785ec
commit 953ea39cb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 323 additions and 64 deletions

View File

@ -56,7 +56,12 @@ struct TorchLoweringPipelineOptions
// to check for a specific set of legal ops to stop its iteration.
ListOption<std::string> backendLegalOps{
*this, "backend-legal-ops",
llvm::cl::desc("List of ops to be considered legal for the backend.")};
llvm::cl::desc("List of ops to be considered legal for the backend, such "
"as 'aten.foo'.")};
Option<std::string> extraLibrary{
*this, "extra-library",
llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")};
};
/// Creates a pipeline that lowers the object graph IR that is produced by
@ -78,10 +83,12 @@ void createTorchSimplificationPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
/// Creates a pipeline that refines shapes of tensor operations in the program.
void createTorchShapeRefinementPipeline(OpPassManager &pm);
void createTorchShapeRefinementPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
/// Creates a pipeline that refines dtype of tensor operations in the program.
void createTorchDtypeRefinementPipeline(OpPassManager &pm);
void createTorchDtypeRefinementPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
@ -89,7 +96,8 @@ std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReduceOpVariantsPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createReduceOpVariantsPass(StringRef extraLibrary);
std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
@ -100,14 +108,14 @@ createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createReifyShapeCalculationsPass(StringRef extraLibrary);
std::unique_ptr<OperationPass<func::FuncOp>>
createSimplifyShapeCalculationsPass();
std::unique_ptr<OperationPass<ModuleOp>> createReifyDtypeCalculationsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createReifyDtypeCalculationsPass(StringRef extraLibrary);
std::unique_ptr<OperationPass<func::FuncOp>>
createSimplifyDtypeCalculationsPass();
@ -120,13 +128,16 @@ createEraseModuleInitializerPass();
std::unique_ptr<OperationPass<ModuleOp>>
createLowerToBackendContractPass(int maxIterations, bool decompose,
ArrayRef<std::string> backendLegalOps);
ArrayRef<std::string> backendLegalOps,
StringRef extraLibrary);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyBackendContractNoDecompositionsPass();
StringRef getAbstractInterpLibrary();
static const char kTorchOpPrefix[] = R"(torch.)";
} // namespace Torch
/// Registers all Torch transformation passes.

View File

@ -151,7 +151,13 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
let summary = "Reduces variants of ops to a smaller set of ops.";
let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
let constructor = [{
mlir::torch::Torch::createReduceOpVariantsPass(/*extraLibrary=*/"")
}];
let options = [
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
"MLIR module for verifying custom op value semantics">,
];
let description = [{
Replaces ops with other ops to reduce the number of variants that
need to be handled elsewhere in the code.
@ -240,7 +246,13 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
let summary = "Reify shape calculations.";
let constructor = "mlir::torch::Torch::createReifyShapeCalculationsPass()";
let constructor = [{
mlir::torch::Torch::createReifyShapeCalculationsPass(/*extraLibrary=*/"")
}];
let options = [
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
"MLIR module for splicing into the shape library">,
];
let description = [{
}];
}
@ -255,7 +267,13 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func:
def ReifyDtypeCalculations : Pass<"torch-reify-dtype-calculations", "ModuleOp"> {
let summary = "Reify dtype calculations.";
let constructor = "mlir::torch::Torch::createReifyDtypeCalculationsPass()";
let constructor = [{
mlir::torch::Torch::createReifyDtypeCalculationsPass(/*extraLibrary=*/"")
}];
let options = [
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
"MLIR module for splicing into the dtype library">,
];
let description = [{
}];
}
@ -291,7 +309,7 @@ def LowerToBackendContract
let summary = "Perform simplifications until the backend contract is satisfied.";
let constructor = [{
mlir::torch::Torch::createLowerToBackendContractPass(
/*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{})
/*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"")
}];
let description = [{
This pass performs the bulk of the lowering of the program's computations
@ -335,7 +353,9 @@ def LowerToBackendContract
Option<"decompose", "decompose", "bool", /*default=*/"true",
"Decompose ops.">,
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
"List of ops to be considered legal for the backend.">
"List of ops to be considered legal for the backend, such as 'aten.foo'.">,
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
"MLIR module for splicing into the abstract interpretation library">,
];
// TODO: Debug why this is needed, even though the input program has func.func

View File

@ -3890,7 +3890,7 @@ private:
// on `Operation *` are not allowed, since there is no way of telling if
// that pattern will match on an op in the `legalOpsSet` or not.
assert(opName && "All decomposition patterns must target a single op");
if (!legalOpsSet.contains(opName->getStringRef()))
if (!legalOpsSet.contains(opName->getStringRef().ltrim(kTorchOpPrefix)))
patterns.add<DecomposePattern>(context);
}

View File

@ -18,6 +18,7 @@
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "llvm/ADT/StringSet.h"
#define DEBUG_TYPE "torch-lower-to-backend-contract"
@ -31,7 +32,7 @@ using namespace mlir::torch::Torch;
static void markDecomposedOpsAsIllegal(MLIRContext *context,
ConversionTarget &target,
ArrayRef<std::string> backendLegalOps);
llvm::StringSet<> backendLegalOps);
static LogicalResult checkType(Operation *op, Type type,
bool actuallyEmitDiagnostics) {
@ -246,11 +247,11 @@ static bool satisfiesBackendContract(ModuleOp module,
// Explicitly set ops and dialects allowed and not allowed in backend contract.
static ConversionTarget
getBackendContractTarget(MLIRContext *context, bool decompose,
ArrayRef<std::string> backendLegalOps) {
llvm::StringSet<> backendLegalOpsSet) {
ConversionTarget target(*context);
target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
if (decompose)
markDecomposedOpsAsIllegal(context, target, backendLegalOps);
markDecomposedOpsAsIllegal(context, target, backendLegalOpsSet);
return target;
}
@ -260,21 +261,27 @@ class LowerToBackendContractPass
public:
LowerToBackendContractPass() = default;
LowerToBackendContractPass(int maxIterations, bool decompose,
ArrayRef<std::string> backendLegalOps) {
ArrayRef<std::string> backendLegalOps,
StringRef extraLibrary) {
this->maxIterations = maxIterations;
this->decompose = decompose;
this->backendLegalOps = backendLegalOps;
this->extraLibrary = extraLibrary.str();
}
void runOnOperation() override {
ModuleOp module = getOperation();
MLIRContext *context = &getContext();
backendLegalOpsSet.clear();
backendLegalOpsSet.insert(backendLegalOps.begin(), backendLegalOps.end());
ConversionTarget target =
getBackendContractTarget(context, decompose, backendLegalOps);
getBackendContractTarget(context, decompose, backendLegalOpsSet);
OpPassManager pm(module.getOperationName());
TorchLoweringPipelineOptions options;
options.decompose = decompose;
options.backendLegalOps = backendLegalOps;
options.extraLibrary = extraLibrary;
createTorchSimplificationPipeline(pm, options);
int i = 0;
@ -301,6 +308,8 @@ public:
<< " iterations of the simplification pipeline\n";
});
}
private:
llvm::StringSet<> backendLegalOpsSet;
};
class VerifyBackendContractNoDecompositionsPass
@ -312,7 +321,7 @@ public:
MLIRContext *context = &getContext();
ConversionTarget target =
getBackendContractTarget(context, /*decompose*/false,
/*backendLegalOps*/{});
/*backendLegalOpsSet*/{});
if (!satisfiesBackendContract(getOperation(), target,
/*actuallyEmitDiagnostics=*/true)) {
@ -324,9 +333,10 @@ public:
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createLowerToBackendContractPass(
int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps) {
return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
backendLegalOps);
int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps,
StringRef extraLibrary) {
return std::make_unique<LowerToBackendContractPass>(
maxIterations, decompose, backendLegalOps, extraLibrary);
}
std::unique_ptr<OperationPass<ModuleOp>>
@ -337,9 +347,9 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
// The backend contract guarantees that ops with decompositions available will
// be decomposed. The only way to have an op reach the backend contract without
// getting decomposed is by having the user explicitly specify that op in the
// `backendLegalOps` argument to the `LowerToBackendContractPass`. Therefore,
// `backendLegalOpsSet` argument to the `LowerToBackendContractPass`. Therefore,
// here we mark as illegal all ops with decompositions except for those in
// `backendLegalOps`.
// `backendLegalOpsSet`.
//
// The legality check takes place here instead of in the `DecomposeComplexOps`
// pass for two reasons:
@ -352,7 +362,7 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
// decompositions explicit in this file
static void markDecomposedOpsAsIllegal(MLIRContext *context,
ConversionTarget &target,
ArrayRef<std::string> backendLegalOps) {
llvm::StringSet<> backendLegalOpsSet) {
target.addIllegalOp<AtenSoftmaxIntOp>();
target.addIllegalOp<Aten_SoftmaxOp>();
target.addIllegalOp<Aten_LogSoftmaxOp>();
@ -463,7 +473,13 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenVarMeanOp>();
target.addIllegalOp<AtenNewEmptyStridedOp>();
target.addIllegalOp<AtenBucketizeTensorOp>();
for (std::string opName : backendLegalOps) {
target.addLegalOp(OperationName(opName, context));
for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context));
}
target.addDynamicallyLegalOp<OperatorOp>(
[backendLegalOpsSet](OperatorOp opOp) {
auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
return backendLegalOpsSet.contains(opName);
});
}

View File

@ -25,7 +25,7 @@ void mlir::torch::registerTorchPasses() {
"torch-simplification-pipeline",
"Pipeline simplifying computations in the program.",
mlir::torch::Torch::createTorchSimplificationPipeline);
mlir::PassPipelineRegistration<>(
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.",
mlir::torch::Torch::createTorchShapeRefinementPipeline);
}
@ -66,7 +66,8 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// Perform the bulk of lowering to the backend contract.
// See the pass documentation for more information.
pm.addPass(createLowerToBackendContractPass(
options.maxIterations, options.decompose, options.backendLegalOps));
options.maxIterations, options.decompose, options.backendLegalOps,
options.extraLibrary));
}
// A simplification pipeline to establish the invariants of the backend
@ -108,7 +109,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Remove dead global slots.
pm.addPass(createSymbolDCEPass());
@ -121,8 +123,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// This should be run before RefineTypes (which primarily does dtype
// inference), because Torch type promotion rules actually depend on the shape
// of the operand.
createTorchShapeRefinementPipeline(pm);
createTorchDtypeRefinementPipeline(pm);
createTorchShapeRefinementPipeline(pm, options);
createTorchDtypeRefinementPipeline(pm, options);
// Refine types in the program, which mainly means inferring dtypes of ops.
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by
@ -141,13 +143,15 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
static void createRefinementPipeline(
mlir::OpPassManager &pm,
llvm::function_ref<std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>()>
llvm::function_ref<
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>(llvm::StringRef)>
reifyCalculationsPass,
llvm::function_ref<
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>()>
simplifyCalculationsPass) {
simplifyCalculationsPass,
const mlir::torch::Torch::TorchLoweringPipelineOptions &options) {
// Reify the library functions for each op that is present in the library.
pm.addPass(reifyCalculationsPass());
pm.addPass(reifyCalculationsPass(options.extraLibrary));
// Inline the library functions to enable analysis and transformation.
// TODO: Only inline library functions (this will currently inline
@ -168,12 +172,14 @@ static void createRefinementPipeline(
mlir::torch::Torch::createDropAbstractInterpCalculationsPass());
}
void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
void mlir::torch::Torch::createTorchShapeRefinementPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
createRefinementPipeline(pm, Torch::createReifyShapeCalculationsPass,
Torch::createSimplifyShapeCalculationsPass);
Torch::createSimplifyShapeCalculationsPass, options);
}
void mlir::torch::Torch::createTorchDtypeRefinementPipeline(OpPassManager &pm) {
void mlir::torch::Torch::createTorchDtypeRefinementPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
createRefinementPipeline(pm, Torch::createReifyDtypeCalculationsPass,
Torch::createSimplifyDtypeCalculationsPass);
Torch::createSimplifyDtypeCalculationsPass, options);
}

View File

@ -12,6 +12,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
@ -52,17 +53,39 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
}
}
static bool
operatorOpHasValueSemantics(OperatorOp opOp,
std::optional<SymbolTable> extraLibrary) {
if (!extraLibrary.has_value())
return false;
auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix(
LibraryFunctionKind::HasValueSemantics) +
Twine(opName))
.str();
auto libFunc = extraLibrary->lookup<func::FuncOp>(libFuncName);
return bool(libFunc);
}
namespace {
// Convert value semantic ops operating on mutable arrays to instead operate on
// immutable tensors.
class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
public:
ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context,
const std::optional<SymbolTable>& extraLibrary)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
this->extraLibrary = extraLibrary;
}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>())
if (isa<OperatorOp>(op)) {
if (!operatorOpHasValueSemantics(cast<OperatorOp>(op), extraLibrary)) {
return rewriter.notifyMatchFailure(op, "does not have value semantics");
}
} else if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
return rewriter.notifyMatchFailure(op, "does not have value semantics");
}
rewriter.startRootUpdate(op);
// Convert all operands.
@ -160,6 +183,8 @@ public:
rewriter.finalizeRootUpdate(op);
return success();
}
private:
std::optional<SymbolTable> extraLibrary;
};
} // namespace
@ -241,11 +266,30 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
}
namespace {
class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
struct ReduceOpVariantsPass
: public ReduceOpVariantsBase<ReduceOpVariantsPass> {
ReduceOpVariantsPass() = default;
ReduceOpVariantsPass(StringRef extraLibrary) {
this->extraLibrary = extraLibrary.str();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
OwningOpRef<ModuleOp> extraLibraryModule =
ModuleOp::create(UnknownLoc::get(context));
std::optional<SymbolTable> extraLibraryModuleSymTable = std::nullopt;
if (!extraLibrary.empty()) {
if (failed(loadExtraLibrary(extraLibrary, extraLibraryModule))) {
emitError(getOperation()->getLoc(),
"Failed to load extra-library file at " + extraLibrary);
return signalPassFailure();
}
extraLibraryModuleSymTable =
SymbolTable(extraLibraryModule->getOperation());
}
patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(
context, extraLibraryModuleSymTable);
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
patterns.add<ReduceNonValueSemanticOps>(context);
@ -253,8 +297,12 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
ConversionTarget target(*context);
target.addIllegalOp<NonValueTensorLiteralOp>();
target.addIllegalOp<AtenBernoulli_FloatOp>();
target.markUnknownOpDynamicallyLegal([](Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
(isa<OperatorOp>(op) &&
operatorOpHasValueSemantics(cast<OperatorOp>(op),
extraLibraryModuleSymTable))) {
auto hasValueSemantics = [](Type t) {
// TODO: Make this an allowlist based on a closed torch dialect
// type system.
@ -281,6 +329,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createReduceOpVariantsPass() {
return std::make_unique<ReduceOpVariantsPass>();
mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) {
return std::make_unique<ReduceOpVariantsPass>(extraLibrary);
}

View File

@ -8,18 +8,25 @@
//===----------------------------------------------------------------------===//
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "mlir/Parser/Parser.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
std::string
mlir::torch::Torch::getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
return "__torch_mlir_shape_fn.";
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
return "__torch_mlir_dtype_fn.";
else if (libFuncKind == LibraryFunctionKind::HasValueSemantics)
return "__torch_mlir_has_value_semantics_fn.";
llvm_unreachable(
"`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`");
}
@ -73,6 +80,8 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
// looking them up in the library.
if (name.startswith("valsem."))
name = name.drop_front(strlen("valsem."));
if (isa<OperatorOp>(op))
name = cast<OperatorOp>(op)->getAttr("name").cast<StringAttr>().getValue();
std::string libFuncName =
(getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
@ -288,3 +297,39 @@ FailureOr<Value> Torch::adjustFunctionArg(
// Pass the operand as-is.
return operand;
}
LogicalResult
mlir::torch::Torch::loadExtraLibrary(const std::string &filename,
OwningOpRef<ModuleOp> &moduleToAppendTo) {
auto ctx = moduleToAppendTo->getContext();
assert(ctx && "Module should be fully initialized.");
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code ec = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return failure();
}
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
OwningOpRef<ModuleOp> module_ =
mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, ctx);
if (!module_) {
llvm::errs() << "Error can't load file " << filename << "\n";
return failure();
}
assert((moduleToAppendTo->getBodyRegion().empty() ||
moduleToAppendTo->getBodyRegion().hasOneBlock()) &&
"Module should have at most one block.");
if (moduleToAppendTo->getBodyRegion().empty()) {
moduleToAppendTo = std::move(module_);
} else {
Block *block = moduleToAppendTo->getBody(0);
block->getOperations().splice(block->end(),
module_->getBody(0)->getOperations());
}
return success();
}

View File

@ -22,7 +22,12 @@ namespace mlir {
namespace torch {
namespace Torch {
enum class LibraryFunctionKind { ShapeFunction, DtypeFunction, Decomposition };
enum class LibraryFunctionKind {
ShapeFunction,
DtypeFunction,
Decomposition,
HasValueSemantics
};
// Searches the function library for an abstract interpretation function for
// `op`. If one is found, wraps the op in a `CalculateOp`, with the op placed in
@ -60,6 +65,16 @@ FailureOr<Value> adjustFunctionArg(
OpBuilder &b, Location loc, Value operand, Type desiredType,
function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation =
[](OpBuilder &, Location, Value operand, Type) { return operand; });
std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind);
// Parse MLIR module at `filename` into a ModuleOp that will then
// be appended to an existing, fully hydrated, ModuleOp; note the module
// should have been instantiated with an associated context like so:
// `OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));`
LogicalResult loadExtraLibrary(const std::string &filename,
OwningOpRef<ModuleOp> &moduleToAppendTo);
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@ -61,13 +61,23 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
}
namespace {
class ReifyDtypeCalculationsPass
struct ReifyDtypeCalculationsPass
: public ReifyDtypeCalculationsBase<ReifyDtypeCalculationsPass> {
ReifyDtypeCalculationsPass() = default;
ReifyDtypeCalculationsPass(StringRef extraLibrary) {
this->extraLibrary = extraLibrary.str();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
OwningOpRef<ModuleOp> library =
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
if (!extraLibrary.empty())
if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) {
emitError(module->getLoc(),
"Failed to load extra-library file at " + extraLibrary);
return signalPassFailure();
}
// Walk all the operations, and if we have a dtype function, wrap the op
// in a `torch.dtype.calculate` op.
@ -86,6 +96,6 @@ class ReifyDtypeCalculationsPass
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
Torch::createReifyDtypeCalculationsPass() {
return std::make_unique<ReifyDtypeCalculationsPass>();
Torch::createReifyDtypeCalculationsPass(StringRef extraLibrary) {
return std::make_unique<ReifyDtypeCalculationsPass>(extraLibrary);
}

View File

@ -14,6 +14,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/Support/MemoryBuffer.h"
using namespace mlir;
using namespace mlir::torch;
@ -55,8 +56,12 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
}
namespace {
class ReifyShapeCalculationsPass
struct ReifyShapeCalculationsPass
: public ReifyShapeCalculationsBase<ReifyShapeCalculationsPass> {
ReifyShapeCalculationsPass() = default;
ReifyShapeCalculationsPass(StringRef extraLibrary) {
this->extraLibrary = extraLibrary.str();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
@ -66,6 +71,12 @@ class ReifyShapeCalculationsPass
// O(#ops in the program) ideally.
OwningOpRef<ModuleOp> library =
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
if (!extraLibrary.empty())
if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) {
emitError(module->getLoc(),
"Failed to load extra-library file at " + extraLibrary);
return signalPassFailure();
}
// Walk all the operations, and if we have a shape function, wrap the op
// in a `torch.shape.calculate` op.
@ -84,6 +95,6 @@ class ReifyShapeCalculationsPass
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createReifyShapeCalculationsPass() {
return std::make_unique<ReifyShapeCalculationsPass>();
mlir::torch::Torch::createReifyShapeCalculationsPass(StringRef extraLibrary) {
return std::make_unique<ReifyShapeCalculationsPass>(extraLibrary);
}

View File

@ -18,6 +18,6 @@ class AddmmModule(torch.nn.Module):
example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
print(torch_mlir.compile(AddmmModule(), example_args,
output_type="torch", backend_legal_ops=["torch.aten.addmm"]))
output_type="torch", backend_legal_ops=["aten.addmm"]))
# CHECK-LABEL: @forward
# CHECK: torch.aten.addmm

View File

@ -240,8 +240,8 @@ class ExampleArgs:
# ops in the backend contract, and move these lists somewhere deeper in the
# compiler where each backend can "own" its set of legal ops.
BACKEND_LEGAL_OPS = {
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'],
OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints', ],
OutputType.STABLEHLO: [],
}
@ -252,6 +252,7 @@ def compile(model: torch.nn.Module,
use_tracing: bool = False,
ignore_traced_shapes=False,
backend_legal_ops: Optional[Sequence[str]] = None,
_completely_unsupported_in_progress_extra_library: Optional[str] = None,
verbose: bool = False):
"""Convert a PyTorch model to MLIR.
@ -367,7 +368,11 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
if output_type == OutputType.RAW:
return mb.module
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + (
(" extra-library=" + _completely_unsupported_in_progress_extra_library)
if (_completely_unsupported_in_progress_extra_library is not None)
else ""
) + "}"
run_pipeline_with_repro_report(
mb.module,
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=torch.aten.softmax.int" -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=aten.softmax.int" -split-input-file %s | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax})' -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.square
func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {

View File

@ -0,0 +1,72 @@
import os
import tempfile
import torch
import torch.utils.cpp_extension
import torch_mlir
from torch_mlir_e2e_test.annotations import export, annotate_args
# RUN: %PYTHON %s | FileCheck %s
def identity(x: torch.Tensor):
return x
goofy_lib = torch.library.Library("goofy", "DEF")
goofy_lib.define("identity(Tensor t) -> Tensor")
goofy_lib.impl("identity", identity)
class CustomOpExampleModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, a):
b = 2 * a
return torch.ops.goofy.identity(b)
mod = CustomOpExampleModule()
mod.eval()
abstract_interp_src = """\
func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}
func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple<int, int>) -> !torch.int {
%0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
return %0#1 : !torch.int
}
func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
"""
with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp:
tmp.write(abstract_interp_src)
module = torch_mlir.compile(
mod,
torch.ones(3, 4),
output_type="torch",
backend_legal_ops=["goofy.identity"],
_completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir",
)
print(module)
# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
# CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
# CHECK: %{{.*}} = torch.constant.int 2
# CHECK: %{{.*}} = torch.aten.mul.Scalar %{{.*}}, %{{.*}} : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
# CHECK: %{{.*}} = torch.operator "goofy.identity"(%{{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK: return %1 : !torch.vtensor<[3,4],f32>
# CHECK: }
# CHECK: }