Maksim Levental 2023-03-24 21:50:01 -05:00 committed by GitHub
parent a7449785ec
commit 953ea39cb5
15 changed files with 323 additions and 64 deletions

View File

@@ -56,7 +56,12 @@ struct TorchLoweringPipelineOptions
   // to check for a specific set of legal ops to stop its iteration.
   ListOption<std::string> backendLegalOps{
       *this, "backend-legal-ops",
-      llvm::cl::desc("List of ops to be considered legal for the backend.")};
+      llvm::cl::desc("List of ops to be considered legal for the backend, such "
+                     "as 'aten.foo'.")};
+  Option<std::string> extraLibrary{
+      *this, "extra-library",
+      llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")};
 };

 /// Creates a pipeline that lowers the object graph IR that is produced by
@@ -78,10 +83,12 @@ void createTorchSimplificationPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);

 /// Creates a pipeline that refines shapes of tensor operations in the program.
-void createTorchShapeRefinementPipeline(OpPassManager &pm);
+void createTorchShapeRefinementPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options);

 /// Creates a pipeline that refines dtype of tensor operations in the program.
-void createTorchDtypeRefinementPipeline(OpPassManager &pm);
+void createTorchDtypeRefinementPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options);

 std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
@@ -89,7 +96,8 @@ std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
 std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();

-std::unique_ptr<OperationPass<func::FuncOp>> createReduceOpVariantsPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createReduceOpVariantsPass(StringRef extraLibrary);

 std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
@@ -100,14 +108,14 @@ createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
 std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();

-std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
-
-std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createReifyShapeCalculationsPass(StringRef extraLibrary);

 std::unique_ptr<OperationPass<func::FuncOp>>
 createSimplifyShapeCalculationsPass();

-std::unique_ptr<OperationPass<ModuleOp>> createReifyDtypeCalculationsPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createReifyDtypeCalculationsPass(StringRef extraLibrary);

 std::unique_ptr<OperationPass<func::FuncOp>>
 createSimplifyDtypeCalculationsPass();
@@ -120,13 +128,16 @@ createEraseModuleInitializerPass();
 std::unique_ptr<OperationPass<ModuleOp>>
 createLowerToBackendContractPass(int maxIterations, bool decompose,
-                                 ArrayRef<std::string> backendLegalOps);
+                                 ArrayRef<std::string> backendLegalOps,
+                                 StringRef extraLibrary);

 std::unique_ptr<OperationPass<ModuleOp>>
 createVerifyBackendContractNoDecompositionsPass();

 StringRef getAbstractInterpLibrary();

+static const char kTorchOpPrefix[] = R"(torch.)";
+
 } // namespace Torch

 /// Registers all Torch transformation passes.

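Note: the new `extra-library` option sits alongside `backend-legal-ops` in the pipeline option string, with options inside the braces separated by spaces. A minimal invocation sketch (the input file name is hypothetical; the op name and library path are borrowed from the new end-to-end test at the bottom of this diff):

  # Sketch: splice a user-provided abstract interpretation library into the pipeline.
  torch-mlir-opt \
    -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=goofy.identity extra-library=/tmp/custom_op_shape_dtype_fn.mlir})' \
    model.mlir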
View File

@@ -151,7 +151,13 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
 def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
   let summary = "Reduces variants of ops to a smaller set of ops.";
-  let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
+  let constructor = [{
+    mlir::torch::Torch::createReduceOpVariantsPass(/*extraLibrary=*/"")
+  }];
+  let options = [
+    Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
+           "MLIR module for verifying custom op value semantics">,
+  ];
   let description = [{
     Replaces ops with other ops to reduce the number of variants that
     need to be handled elsewhere in the code.
@@ -240,7 +246,13 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
 def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
   let summary = "Reify shape calculations.";
-  let constructor = "mlir::torch::Torch::createReifyShapeCalculationsPass()";
+  let constructor = [{
+    mlir::torch::Torch::createReifyShapeCalculationsPass(/*extraLibrary=*/"")
+  }];
+  let options = [
+    Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
+           "MLIR module for splicing into the shape library">,
+  ];
   let description = [{
   }];
 }
@@ -255,7 +267,13 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func::FuncOp"> {
 def ReifyDtypeCalculations : Pass<"torch-reify-dtype-calculations", "ModuleOp"> {
   let summary = "Reify dtype calculations.";
-  let constructor = "mlir::torch::Torch::createReifyDtypeCalculationsPass()";
+  let constructor = [{
+    mlir::torch::Torch::createReifyDtypeCalculationsPass(/*extraLibrary=*/"")
+  }];
+  let options = [
+    Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
+           "MLIR module for splicing into the dtype library">,
+  ];
   let description = [{
   }];
 }
@@ -291,7 +309,7 @@ def LowerToBackendContract
   let summary = "Perform simplifications until the backend contract is satisfied.";
   let constructor = [{
     mlir::torch::Torch::createLowerToBackendContractPass(
-      /*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{})
+      /*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"")
   }];
   let description = [{
     This pass performs the bulk of the lowering of the program's computations
@@ -335,7 +353,9 @@ def LowerToBackendContract
     Option<"decompose", "decompose", "bool", /*default=*/"true",
            "Decompose ops.">,
     ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
-               "List of ops to be considered legal for the backend.">
+               "List of ops to be considered legal for the backend, such as 'aten.foo'.">,
+    Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
+           "MLIR module for splicing into the abstract interpretation library">,
   ];

   // TODO: Debug why this is needed, even though the input program has func.func

View File

@@ -3890,7 +3890,7 @@ private:
       // on `Operation *` are not allowed, since there is no way of telling if
       // that pattern will match on an op in the `legalOpsSet` or not.
       assert(opName && "All decomposition patterns must target a single op");
-      if (!legalOpsSet.contains(opName->getStringRef()))
+      if (!legalOpsSet.contains(opName->getStringRef().ltrim(kTorchOpPrefix)))
         patterns.add<DecomposePattern>(context);
     }

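Note: stripping `kTorchOpPrefix` (`torch.`) here means legal op names are now spelled without the dialect prefix, which is why the tests below change `torch.aten.addmm` to `aten.addmm` and so on. Sketch of the before/after spelling on the command line (hypothetical `input.mlir`):

  # before this commit
  torch-mlir-opt -torch-decompose-complex-ops="legal-ops=torch.aten.softmax.int" input.mlir
  # after this commit
  torch-mlir-opt -torch-decompose-complex-ops="legal-ops=aten.softmax.int" input.mlir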
View File

@@ -18,6 +18,7 @@
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/ADT/StringSet.h"

 #define DEBUG_TYPE "torch-lower-to-backend-contract"
@@ -31,7 +32,7 @@ using namespace mlir::torch::Torch;
 static void markDecomposedOpsAsIllegal(MLIRContext *context,
                                        ConversionTarget &target,
-                                       ArrayRef<std::string> backendLegalOps);
+                                       llvm::StringSet<> backendLegalOps);

 static LogicalResult checkType(Operation *op, Type type,
                                bool actuallyEmitDiagnostics) {
@@ -246,11 +247,11 @@ static bool satisfiesBackendContract(ModuleOp module,
 // Explicitly set ops and dialects allowed and not allowed in backend contract.
 static ConversionTarget
 getBackendContractTarget(MLIRContext *context, bool decompose,
-                         ArrayRef<std::string> backendLegalOps) {
+                         llvm::StringSet<> backendLegalOpsSet) {
   ConversionTarget target(*context);
   target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
   if (decompose)
-    markDecomposedOpsAsIllegal(context, target, backendLegalOps);
+    markDecomposedOpsAsIllegal(context, target, backendLegalOpsSet);
   return target;
 }
@@ -260,21 +261,27 @@ class LowerToBackendContractPass
 public:
   LowerToBackendContractPass() = default;
   LowerToBackendContractPass(int maxIterations, bool decompose,
-                             ArrayRef<std::string> backendLegalOps) {
+                             ArrayRef<std::string> backendLegalOps,
+                             StringRef extraLibrary) {
     this->maxIterations = maxIterations;
     this->decompose = decompose;
     this->backendLegalOps = backendLegalOps;
+    this->extraLibrary = extraLibrary.str();
   }
   void runOnOperation() override {
     ModuleOp module = getOperation();
     MLIRContext *context = &getContext();
+
+    backendLegalOpsSet.clear();
+    backendLegalOpsSet.insert(backendLegalOps.begin(), backendLegalOps.end());
     ConversionTarget target =
-        getBackendContractTarget(context, decompose, backendLegalOps);
+        getBackendContractTarget(context, decompose, backendLegalOpsSet);

     OpPassManager pm(module.getOperationName());
     TorchLoweringPipelineOptions options;
     options.decompose = decompose;
     options.backendLegalOps = backendLegalOps;
+    options.extraLibrary = extraLibrary;
     createTorchSimplificationPipeline(pm, options);

     int i = 0;
@@ -301,6 +308,8 @@ public:
              << " iterations of the simplification pipeline\n";
     });
   }
+private:
+  llvm::StringSet<> backendLegalOpsSet;
 };

 class VerifyBackendContractNoDecompositionsPass
@@ -312,7 +321,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target =
         getBackendContractTarget(context, /*decompose*/false,
-                                 /*backendLegalOps*/{});
+                                 /*backendLegalOpsSet*/{});

     if (!satisfiesBackendContract(getOperation(), target,
                                   /*actuallyEmitDiagnostics=*/true)) {
@@ -324,9 +333,10 @@ public:
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::torch::Torch::createLowerToBackendContractPass(
-    int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps) {
-  return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
-                                                      backendLegalOps);
+    int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps,
+    StringRef extraLibrary) {
+  return std::make_unique<LowerToBackendContractPass>(
+      maxIterations, decompose, backendLegalOps, extraLibrary);
 }

 std::unique_ptr<OperationPass<ModuleOp>>
@@ -337,9 +347,9 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
 // The backend contract guarantees that ops with decompositions available will
 // be decomposed. The only way to have an op reach the backend contract without
 // getting decomposed is by having the user explicitly specify that op in the
-// `backendLegalOps` argument to the `LowerToBackendContractPass`. Therefore,
+// `backendLegalOpsSet` argument to the `LowerToBackendContractPass`. Therefore,
 // here we mark as illegal all ops with decompositions except for those in
-// `backendLegalOps`.
+// `backendLegalOpsSet`.
 //
 // The legality check takes place here instead of in the `DecomposeComplexOps`
 // pass for two reasons:
@@ -352,7 +362,7 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
 // decompositions explicit in this file
 static void markDecomposedOpsAsIllegal(MLIRContext *context,
                                        ConversionTarget &target,
-                                       ArrayRef<std::string> backendLegalOps) {
+                                       llvm::StringSet<> backendLegalOpsSet) {
   target.addIllegalOp<AtenSoftmaxIntOp>();
   target.addIllegalOp<Aten_SoftmaxOp>();
   target.addIllegalOp<Aten_LogSoftmaxOp>();
@@ -463,7 +473,13 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenVarMeanOp>();
   target.addIllegalOp<AtenNewEmptyStridedOp>();
   target.addIllegalOp<AtenBucketizeTensorOp>();
-  for (std::string opName : backendLegalOps) {
-    target.addLegalOp(OperationName(opName, context));
+  for (auto &opName : backendLegalOpsSet) {
+    target.addLegalOp(
+        OperationName(kTorchOpPrefix + opName.first().str(), context));
   }
+  target.addDynamicallyLegalOp<OperatorOp>(
+      [backendLegalOpsSet](OperatorOp opOp) {
+        auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
+        return backendLegalOpsSet.contains(opName);
+      });
 }

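Note: besides accepting un-prefixed op names, the target now treats a `torch.operator` op as legal whenever its `name` attribute is in the backend-legal set, which is what lets custom ops without a registered ODS counterpart reach the backend contract. A rough sketch of the effect, reusing the hypothetical `goofy.identity` op and the extra-library file from the new test at the bottom of this diff (the op still needs shape/dtype functions from that library to satisfy the contract):

  cat > /tmp/operator_legal.mlir <<'EOF'
  func.func @forward(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
    %0 = torch.operator "goofy.identity"(%arg0) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
    return %0 : !torch.vtensor<[3,4],f32>
  }
  EOF
  torch-mlir-opt \
    -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=goofy.identity extra-library=/tmp/custom_op_shape_dtype_fn.mlir})' \
    /tmp/operator_legal.mlir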
View File

@@ -25,7 +25,7 @@ void mlir::torch::registerTorchPasses() {
       "torch-simplification-pipeline",
       "Pipeline simplifying computations in the program.",
       mlir::torch::Torch::createTorchSimplificationPipeline);
-  mlir::PassPipelineRegistration<>(
+  mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
       "torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.",
       mlir::torch::Torch::createTorchShapeRefinementPipeline);
 }
@@ -66,7 +66,8 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
   // Perform the bulk of lowering to the backend contract.
   // See the pass documentation for more information.
   pm.addPass(createLowerToBackendContractPass(
-      options.maxIterations, options.decompose, options.backendLegalOps));
+      options.maxIterations, options.decompose, options.backendLegalOps,
+      options.extraLibrary));
 }

 // A simplification pipeline to establish the invariants of the backend
@@ -108,7 +109,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
   // Reduce variants of ops to a smaller set of primitives.
-  pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
+  pm.addNestedPass<func::FuncOp>(
+      createReduceOpVariantsPass(options.extraLibrary));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // Remove dead global slots.
   pm.addPass(createSymbolDCEPass());
@@ -121,8 +123,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   // This should be run before RefineTypes (which primarily does dtype
   // inference), because Torch type promotion rules actually depend on the shape
   // of the operand.
-  createTorchShapeRefinementPipeline(pm);
-  createTorchDtypeRefinementPipeline(pm);
+  createTorchShapeRefinementPipeline(pm, options);
+  createTorchDtypeRefinementPipeline(pm, options);
   // Refine types in the program, which mainly means inferring dtypes of ops.
   pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
   // Propagate to ABI return types the shape/dtype information discovered by
@@ -141,13 +143,15 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
 static void createRefinementPipeline(
     mlir::OpPassManager &pm,
-    llvm::function_ref<std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>()>
+    llvm::function_ref<
+        std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>(llvm::StringRef)>
         reifyCalculationsPass,
     llvm::function_ref<
         std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>()>
-        simplifyCalculationsPass) {
+        simplifyCalculationsPass,
+    const mlir::torch::Torch::TorchLoweringPipelineOptions &options) {
   // Reify the library functions for each op that is present in the library.
-  pm.addPass(reifyCalculationsPass());
+  pm.addPass(reifyCalculationsPass(options.extraLibrary));

   // Inline the library functions to enable analysis and transformation.
   // TODO: Only inline library functions (this will currently inline
@@ -168,12 +172,14 @@ static void createRefinementPipeline(
       mlir::torch::Torch::createDropAbstractInterpCalculationsPass());
 }

-void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
+void mlir::torch::Torch::createTorchShapeRefinementPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
   createRefinementPipeline(pm, Torch::createReifyShapeCalculationsPass,
-                           Torch::createSimplifyShapeCalculationsPass);
+                           Torch::createSimplifyShapeCalculationsPass, options);
 }

-void mlir::torch::Torch::createTorchDtypeRefinementPipeline(OpPassManager &pm) {
+void mlir::torch::Torch::createTorchDtypeRefinementPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
   createRefinementPipeline(pm, Torch::createReifyDtypeCalculationsPass,
-                           Torch::createSimplifyDtypeCalculationsPass);
+                           Torch::createSimplifyDtypeCalculationsPass, options);
 }

View File

@@ -12,6 +12,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+#include "ReifyAbstractInterpCalculationsUtils.h"
 #include "llvm/ADT/StringExtras.h"

 using namespace mlir;
@@ -52,17 +53,39 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
   }
 }

+static bool
+operatorOpHasValueSemantics(OperatorOp opOp,
+                            std::optional<SymbolTable> extraLibrary) {
+  if (!extraLibrary.has_value())
+    return false;
+  auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
+  std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix(
+                                 LibraryFunctionKind::HasValueSemantics) +
+                             Twine(opName))
+                                .str();
+  auto libFunc = extraLibrary->lookup<func::FuncOp>(libFuncName);
+  return bool(libFunc);
+}
+
 namespace {
 // Convert value semantic ops operating on mutable arrays to instead operate on
 // immutable tensors.
 class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
 public:
-  ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context,
+                                            const std::optional<SymbolTable>& extraLibrary)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
+    this->extraLibrary = extraLibrary;
+  }
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>())
+    if (isa<OperatorOp>(op)) {
+      if (!operatorOpHasValueSemantics(cast<OperatorOp>(op), extraLibrary)) {
+        return rewriter.notifyMatchFailure(op, "does not have value semantics");
+      }
+    } else if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
       return rewriter.notifyMatchFailure(op, "does not have value semantics");
+    }

     rewriter.startRootUpdate(op);
     // Convert all operands.
@@ -160,6 +183,8 @@ public:
     rewriter.finalizeRootUpdate(op);
     return success();
   }
+private:
+  std::optional<SymbolTable> extraLibrary;
 };
 } // namespace

@@ -241,11 +266,30 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
 }

 namespace {
-class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
+struct ReduceOpVariantsPass
+    : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
+  ReduceOpVariantsPass() = default;
+  ReduceOpVariantsPass(StringRef extraLibrary) {
+    this->extraLibrary = extraLibrary.str();
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
+    OwningOpRef<ModuleOp> extraLibraryModule =
+        ModuleOp::create(UnknownLoc::get(context));
+    std::optional<SymbolTable> extraLibraryModuleSymTable = std::nullopt;
+    if (!extraLibrary.empty()) {
+      if (failed(loadExtraLibrary(extraLibrary, extraLibraryModule))) {
+        emitError(getOperation()->getLoc(),
+                  "Failed to load extra-library file at " + extraLibrary);
+        return signalPassFailure();
+      }
+      extraLibraryModuleSymTable =
+          SymbolTable(extraLibraryModule->getOperation());
+    }
+    patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(
+        context, extraLibraryModuleSymTable);
     patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
     patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
     patterns.add<ReduceNonValueSemanticOps>(context);
@@ -253,8 +297,12 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
     ConversionTarget target(*context);
     target.addIllegalOp<NonValueTensorLiteralOp>();
     target.addIllegalOp<AtenBernoulli_FloatOp>();
-    target.markUnknownOpDynamicallyLegal([](Operation *op) {
-      if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
+    target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
+                                             Operation *op) {
+      if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
+          (isa<OperatorOp>(op) &&
+           operatorOpHasValueSemantics(cast<OperatorOp>(op),
+                                       extraLibraryModuleSymTable))) {
         auto hasValueSemantics = [](Type t) {
           // TODO: Make this an allowlist based on a closed torch dialect
           // type system.
@@ -281,6 +329,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
 } // namespace

 std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::torch::Torch::createReduceOpVariantsPass() {
-  return std::make_unique<ReduceOpVariantsPass>();
+mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) {
+  return std::make_unique<ReduceOpVariantsPass>(extraLibrary);
 }

View File

@@ -8,18 +8,25 @@
 //===----------------------------------------------------------------------===//

 #include "ReifyAbstractInterpCalculationsUtils.h"
+#include "mlir/Parser/Parser.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"

 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;

-static std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
+std::string
+mlir::torch::Torch::getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
   if (libFuncKind == LibraryFunctionKind::ShapeFunction)
     return "__torch_mlir_shape_fn.";
   else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
     return "__torch_mlir_dtype_fn.";
+  else if (libFuncKind == LibraryFunctionKind::HasValueSemantics)
+    return "__torch_mlir_has_value_semantics_fn.";
   llvm_unreachable(
       "`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`");
 }
@@ -73,6 +80,8 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
   // looking them up in the library.
   if (name.startswith("valsem."))
     name = name.drop_front(strlen("valsem."));
+  if (isa<OperatorOp>(op))
+    name = cast<OperatorOp>(op)->getAttr("name").cast<StringAttr>().getValue();
   std::string libFuncName =
       (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
   auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
@@ -288,3 +297,39 @@ FailureOr<Value> Torch::adjustFunctionArg(
   // Pass the operand as-is.
   return operand;
 }
+
+LogicalResult
+mlir::torch::Torch::loadExtraLibrary(const std::string &filename,
+                                     OwningOpRef<ModuleOp> &moduleToAppendTo) {
+  auto ctx = moduleToAppendTo->getContext();
+  assert(ctx && "Module should be fully initialized.");
+
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(filename);
+  if (std::error_code ec = fileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
+    return failure();
+  }
+
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+  OwningOpRef<ModuleOp> module_ =
+      mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, ctx);
+  if (!module_) {
+    llvm::errs() << "Error can't load file " << filename << "\n";
+    return failure();
+  }
+
+  assert((moduleToAppendTo->getBodyRegion().empty() ||
+          moduleToAppendTo->getBodyRegion().hasOneBlock()) &&
+         "Module should have at most one block.");
+  if (moduleToAppendTo->getBodyRegion().empty()) {
+    moduleToAppendTo = std::move(module_);
+  } else {
+    Block *block = moduleToAppendTo->getBody(0);
+    block->getOperations().splice(block->end(),
+                                  module_->getBody(0)->getOperations());
+  }
+
+  return success();
+}

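Note: library functions are looked up purely by name, as prefix plus op name, so an extra library hooks into shape, dtype, and value-semantics handling just by naming its functions accordingly. A minimal sketch of an extra-library file for a hypothetical `goofy.identity` op (mirroring the abstract interpretation source in the new test at the bottom of this diff), in a form `loadExtraLibrary` can splice in:

  # One library function per LibraryFunctionKind, named prefix + op name.
  cat > /tmp/custom_op_shape_dtype_fn.mlir <<'EOF'
  func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
    return %arg0 : !torch.list<int>
  }
  func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0: !torch.tuple<int, int>) -> !torch.int {
    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
    return %0#1 : !torch.int
  }
  func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
  EOF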
View File

@@ -22,7 +22,12 @@ namespace mlir {
 namespace torch {
 namespace Torch {

-enum class LibraryFunctionKind { ShapeFunction, DtypeFunction, Decomposition };
+enum class LibraryFunctionKind {
+  ShapeFunction,
+  DtypeFunction,
+  Decomposition,
+  HasValueSemantics
+};

 // Searches the function library for an abstract interpretation function for
 // `op`. If one is found, wraps the op in a `CalculateOp`, with the op placed in
@@ -60,6 +65,16 @@ FailureOr<Value> adjustFunctionArg(
     OpBuilder &b, Location loc, Value operand, Type desiredType,
     function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation =
         [](OpBuilder &, Location, Value operand, Type) { return operand; });
+
+std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind);
+
+// Parse MLIR module at `filename` into a ModuleOp that will then
+// be appended to an existing, fully hydrated, ModuleOp; note the module
+// should have been instantiated with an associated context like so:
+// `OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));`
+LogicalResult loadExtraLibrary(const std::string &filename,
+                               OwningOpRef<ModuleOp> &moduleToAppendTo);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir

View File

@@ -61,13 +61,23 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
 }

 namespace {
-class ReifyDtypeCalculationsPass
+struct ReifyDtypeCalculationsPass
     : public ReifyDtypeCalculationsBase<ReifyDtypeCalculationsPass> {
+  ReifyDtypeCalculationsPass() = default;
+  ReifyDtypeCalculationsPass(StringRef extraLibrary) {
+    this->extraLibrary = extraLibrary.str();
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp module = getOperation();
     OwningOpRef<ModuleOp> library =
         parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
+    if (!extraLibrary.empty())
+      if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) {
+        emitError(module->getLoc(),
+                  "Failed to load extra-library file at " + extraLibrary);
+        return signalPassFailure();
+      }

     // Walk all the operations, and if we have a dtype function, wrap the op
     // in a `torch.dtype.calculate` op.
@@ -86,6 +96,6 @@ class ReifyDtypeCalculationsPass
 } // namespace

 std::unique_ptr<OperationPass<ModuleOp>>
-Torch::createReifyDtypeCalculationsPass() {
-  return std::make_unique<ReifyDtypeCalculationsPass>();
+Torch::createReifyDtypeCalculationsPass(StringRef extraLibrary) {
+  return std::make_unique<ReifyDtypeCalculationsPass>(extraLibrary);
 }

View File

@@ -14,6 +14,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+#include "llvm/Support/MemoryBuffer.h"

 using namespace mlir;
 using namespace mlir::torch;
@@ -55,8 +56,12 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
 }

 namespace {
-class ReifyShapeCalculationsPass
+struct ReifyShapeCalculationsPass
     : public ReifyShapeCalculationsBase<ReifyShapeCalculationsPass> {
+  ReifyShapeCalculationsPass() = default;
+  ReifyShapeCalculationsPass(StringRef extraLibrary) {
+    this->extraLibrary = extraLibrary.str();
+  }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp module = getOperation();
@@ -66,6 +71,12 @@ class ReifyShapeCalculationsPass
     // O(#ops in the program) ideally.
     OwningOpRef<ModuleOp> library =
         parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
+    if (!extraLibrary.empty())
+      if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) {
+        emitError(module->getLoc(),
+                  "Failed to load extra-library file at " + extraLibrary);
+        return signalPassFailure();
+      }

     // Walk all the operations, and if we have a shape function, wrap the op
     // in a `torch.shape.calculate` op.
@@ -84,6 +95,6 @@ class ReifyShapeCalculationsPass
 } // namespace

 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::torch::Torch::createReifyShapeCalculationsPass() {
-  return std::make_unique<ReifyShapeCalculationsPass>();
+mlir::torch::Torch::createReifyShapeCalculationsPass(StringRef extraLibrary) {
+  return std::make_unique<ReifyShapeCalculationsPass>(extraLibrary);
 }

View File

@@ -18,6 +18,6 @@ class AddmmModule(torch.nn.Module):

 example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
 print(torch_mlir.compile(AddmmModule(), example_args,
-                         output_type="torch", backend_legal_ops=["torch.aten.addmm"]))
+                         output_type="torch", backend_legal_ops=["aten.addmm"]))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.addmm

View File

@@ -240,8 +240,8 @@ class ExampleArgs:
 # ops in the backend contract, and move these lists somewhere deeper in the
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
-    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
+    OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'],
+    OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints', ],
     OutputType.STABLEHLO: [],
 }

@@ -252,6 +252,7 @@ def compile(model: torch.nn.Module,
             use_tracing: bool = False,
             ignore_traced_shapes=False,
             backend_legal_ops: Optional[Sequence[str]] = None,
+            _completely_unsupported_in_progress_extra_library: Optional[str] = None,
             verbose: bool = False):
     """Convert a PyTorch model to MLIR.

@@ -367,7 +368,11 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
     if output_type == OutputType.RAW:
         return mb.module

-    option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
+    option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + (
+        (" extra-library=" + _completely_unsupported_in_progress_extra_library)
+        if (_completely_unsupported_in_progress_extra_library is not None)
+        else ""
+    ) + "}"
     run_pipeline_with_repro_report(
         mb.module,
         f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",

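Note: with both arguments supplied, the wrapper assembles a pipeline string of roughly this shape (a sketch; the op name and file path come from the new test below):

  builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=goofy.identity extra-library=/tmp/custom_op_shape_dtype_fn.mlir})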
View File

@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=torch.aten.softmax.int" -split-input-file %s | FileCheck %s
+// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=aten.softmax.int" -split-input-file %s | FileCheck %s

 // CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim
 func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {

View File

@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax})' -split-input-file %s | FileCheck %s
+// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s

 // CHECK-LABEL: func.func @torch.aten.square
 func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {

View File

@@ -0,0 +1,72 @@
+import os
+import tempfile
+
+import torch
+import torch.utils.cpp_extension
+import torch_mlir
+from torch_mlir_e2e_test.annotations import export, annotate_args
+
+# RUN: %PYTHON %s | FileCheck %s
+
+
+def identity(x: torch.Tensor):
+    return x
+
+
+goofy_lib = torch.library.Library("goofy", "DEF")
+goofy_lib.define("identity(Tensor t) -> Tensor")
+goofy_lib.impl("identity", identity)
+
+
+class CustomOpExampleModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        b = 2 * a
+        return torch.ops.goofy.identity(b)
+
+
+mod = CustomOpExampleModule()
+mod.eval()
+
+abstract_interp_src = """\
+func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
+  return %arg0 : !torch.list<int>
+}
+func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple<int, int>) -> !torch.int {
+  %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
+  return %0#1 : !torch.int
+}
+func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
+"""
+
+with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp:
+    tmp.write(abstract_interp_src)
+
+module = torch_mlir.compile(
+    mod,
+    torch.ones(3, 4),
+    output_type="torch",
+    backend_legal_ops=["goofy.identity"],
+    _completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir",
+)
+
+print(module)
+
+# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
+# CHECK:   func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+# CHECK:     %{{.*}} = torch.constant.int 2
+# CHECK:     %{{.*}} = torch.aten.mul.Scalar %{{.*}}, %{{.*}} : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
+# CHECK:     %{{.*}} = torch.operator "goofy.identity"(%{{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+# CHECK:     return %1 : !torch.vtensor<[3,4],f32>
+# CHECK:   }
+# CHECK: }