mirror of https://github.com/llvm/torch-mlir
handles 2,3,4 from https://github.com/llvm/torch-mlir/issues/1963 (#1964)
parent
a7449785ec
commit
953ea39cb5
|
@ -56,7 +56,12 @@ struct TorchLoweringPipelineOptions
|
||||||
// to check for a specific set of legal ops to stop its iteration.
|
// to check for a specific set of legal ops to stop its iteration.
|
||||||
ListOption<std::string> backendLegalOps{
|
ListOption<std::string> backendLegalOps{
|
||||||
*this, "backend-legal-ops",
|
*this, "backend-legal-ops",
|
||||||
llvm::cl::desc("List of ops to be considered legal for the backend.")};
|
llvm::cl::desc("List of ops to be considered legal for the backend, such "
|
||||||
|
"as 'aten.foo'.")};
|
||||||
|
|
||||||
|
Option<std::string> extraLibrary{
|
||||||
|
*this, "extra-library",
|
||||||
|
llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")};
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Creates a pipeline that lowers the object graph IR that is produced by
|
/// Creates a pipeline that lowers the object graph IR that is produced by
|
||||||
|
@ -78,10 +83,12 @@ void createTorchSimplificationPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
/// Creates a pipeline that refines shapes of tensor operations in the program.
|
/// Creates a pipeline that refines shapes of tensor operations in the program.
|
||||||
void createTorchShapeRefinementPipeline(OpPassManager &pm);
|
void createTorchShapeRefinementPipeline(
|
||||||
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
/// Creates a pipeline that refines dtype of tensor operations in the program.
|
/// Creates a pipeline that refines dtype of tensor operations in the program.
|
||||||
void createTorchDtypeRefinementPipeline(OpPassManager &pm);
|
void createTorchDtypeRefinementPipeline(
|
||||||
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
||||||
|
|
||||||
|
@ -89,7 +96,8 @@ std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createReduceOpVariantsPass();
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
createReduceOpVariantsPass(StringRef extraLibrary);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
|
std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
|
||||||
|
|
||||||
|
@ -100,14 +108,14 @@ createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
|
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createReifyShapeCalculationsPass(StringRef extraLibrary);
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
createSimplifyShapeCalculationsPass();
|
createSimplifyShapeCalculationsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createReifyDtypeCalculationsPass();
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createReifyDtypeCalculationsPass(StringRef extraLibrary);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
createSimplifyDtypeCalculationsPass();
|
createSimplifyDtypeCalculationsPass();
|
||||||
|
@ -120,13 +128,16 @@ createEraseModuleInitializerPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createLowerToBackendContractPass(int maxIterations, bool decompose,
|
createLowerToBackendContractPass(int maxIterations, bool decompose,
|
||||||
ArrayRef<std::string> backendLegalOps);
|
ArrayRef<std::string> backendLegalOps,
|
||||||
|
StringRef extraLibrary);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createVerifyBackendContractNoDecompositionsPass();
|
createVerifyBackendContractNoDecompositionsPass();
|
||||||
|
|
||||||
StringRef getAbstractInterpLibrary();
|
StringRef getAbstractInterpLibrary();
|
||||||
|
|
||||||
|
static const char kTorchOpPrefix[] = R"(torch.)";
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
|
|
||||||
/// Registers all Torch transformation passes.
|
/// Registers all Torch transformation passes.
|
||||||
|
|
|
@ -151,7 +151,13 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
|
||||||
|
|
||||||
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
|
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
|
||||||
let summary = "Reduces variants of ops to a smaller set of ops.";
|
let summary = "Reduces variants of ops to a smaller set of ops.";
|
||||||
let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
|
let constructor = [{
|
||||||
|
mlir::torch::Torch::createReduceOpVariantsPass(/*extraLibrary=*/"")
|
||||||
|
}];
|
||||||
|
let options = [
|
||||||
|
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
|
||||||
|
"MLIR module for verifying custom op value semantics">,
|
||||||
|
];
|
||||||
let description = [{
|
let description = [{
|
||||||
Replaces ops with other ops to reduce the number of variants that
|
Replaces ops with other ops to reduce the number of variants that
|
||||||
need to be handled elsewhere in the code.
|
need to be handled elsewhere in the code.
|
||||||
|
@ -240,7 +246,13 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
|
||||||
|
|
||||||
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
|
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
|
||||||
let summary = "Reify shape calculations.";
|
let summary = "Reify shape calculations.";
|
||||||
let constructor = "mlir::torch::Torch::createReifyShapeCalculationsPass()";
|
let constructor = [{
|
||||||
|
mlir::torch::Torch::createReifyShapeCalculationsPass(/*extraLibrary=*/"")
|
||||||
|
}];
|
||||||
|
let options = [
|
||||||
|
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
|
||||||
|
"MLIR module for splicing into the shape library">,
|
||||||
|
];
|
||||||
let description = [{
|
let description = [{
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
@ -255,7 +267,13 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func:
|
||||||
|
|
||||||
def ReifyDtypeCalculations : Pass<"torch-reify-dtype-calculations", "ModuleOp"> {
|
def ReifyDtypeCalculations : Pass<"torch-reify-dtype-calculations", "ModuleOp"> {
|
||||||
let summary = "Reify dtype calculations.";
|
let summary = "Reify dtype calculations.";
|
||||||
let constructor = "mlir::torch::Torch::createReifyDtypeCalculationsPass()";
|
let constructor = [{
|
||||||
|
mlir::torch::Torch::createReifyDtypeCalculationsPass(/*extraLibrary=*/"")
|
||||||
|
}];
|
||||||
|
let options = [
|
||||||
|
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
|
||||||
|
"MLIR module for splicing into the dtype library">,
|
||||||
|
];
|
||||||
let description = [{
|
let description = [{
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
@ -291,7 +309,7 @@ def LowerToBackendContract
|
||||||
let summary = "Perform simplifications until the backend contract is satisfied.";
|
let summary = "Perform simplifications until the backend contract is satisfied.";
|
||||||
let constructor = [{
|
let constructor = [{
|
||||||
mlir::torch::Torch::createLowerToBackendContractPass(
|
mlir::torch::Torch::createLowerToBackendContractPass(
|
||||||
/*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{})
|
/*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"")
|
||||||
}];
|
}];
|
||||||
let description = [{
|
let description = [{
|
||||||
This pass performs the bulk of the lowering of the program's computations
|
This pass performs the bulk of the lowering of the program's computations
|
||||||
|
@ -335,7 +353,9 @@ def LowerToBackendContract
|
||||||
Option<"decompose", "decompose", "bool", /*default=*/"true",
|
Option<"decompose", "decompose", "bool", /*default=*/"true",
|
||||||
"Decompose ops.">,
|
"Decompose ops.">,
|
||||||
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
|
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
|
||||||
"List of ops to be considered legal for the backend.">
|
"List of ops to be considered legal for the backend, such as 'aten.foo'.">,
|
||||||
|
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
|
||||||
|
"MLIR module for splicing into the abstract interpretation library">,
|
||||||
|
|
||||||
];
|
];
|
||||||
// TODO: Debug why this is needed, even though the input program has func.func
|
// TODO: Debug why this is needed, even though the input program has func.func
|
||||||
|
|
|
@ -3890,7 +3890,7 @@ private:
|
||||||
// on `Operation *` are not allowed, since there is no way of telling if
|
// on `Operation *` are not allowed, since there is no way of telling if
|
||||||
// that pattern will match on an op in the `legalOpsSet` or not.
|
// that pattern will match on an op in the `legalOpsSet` or not.
|
||||||
assert(opName && "All decomposition patterns must target a single op");
|
assert(opName && "All decomposition patterns must target a single op");
|
||||||
if (!legalOpsSet.contains(opName->getStringRef()))
|
if (!legalOpsSet.contains(opName->getStringRef().ltrim(kTorchOpPrefix)))
|
||||||
patterns.add<DecomposePattern>(context);
|
patterns.add<DecomposePattern>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "torch-lower-to-backend-contract"
|
#define DEBUG_TYPE "torch-lower-to-backend-contract"
|
||||||
|
|
||||||
|
@ -31,7 +32,7 @@ using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
ConversionTarget &target,
|
ConversionTarget &target,
|
||||||
ArrayRef<std::string> backendLegalOps);
|
llvm::StringSet<> backendLegalOps);
|
||||||
|
|
||||||
static LogicalResult checkType(Operation *op, Type type,
|
static LogicalResult checkType(Operation *op, Type type,
|
||||||
bool actuallyEmitDiagnostics) {
|
bool actuallyEmitDiagnostics) {
|
||||||
|
@ -246,11 +247,11 @@ static bool satisfiesBackendContract(ModuleOp module,
|
||||||
// Explicitly set ops and dialects allowed and not allowed in backend contract.
|
// Explicitly set ops and dialects allowed and not allowed in backend contract.
|
||||||
static ConversionTarget
|
static ConversionTarget
|
||||||
getBackendContractTarget(MLIRContext *context, bool decompose,
|
getBackendContractTarget(MLIRContext *context, bool decompose,
|
||||||
ArrayRef<std::string> backendLegalOps) {
|
llvm::StringSet<> backendLegalOpsSet) {
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
|
target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
|
||||||
if (decompose)
|
if (decompose)
|
||||||
markDecomposedOpsAsIllegal(context, target, backendLegalOps);
|
markDecomposedOpsAsIllegal(context, target, backendLegalOpsSet);
|
||||||
return target;
|
return target;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,21 +261,27 @@ class LowerToBackendContractPass
|
||||||
public:
|
public:
|
||||||
LowerToBackendContractPass() = default;
|
LowerToBackendContractPass() = default;
|
||||||
LowerToBackendContractPass(int maxIterations, bool decompose,
|
LowerToBackendContractPass(int maxIterations, bool decompose,
|
||||||
ArrayRef<std::string> backendLegalOps) {
|
ArrayRef<std::string> backendLegalOps,
|
||||||
|
StringRef extraLibrary) {
|
||||||
this->maxIterations = maxIterations;
|
this->maxIterations = maxIterations;
|
||||||
this->decompose = decompose;
|
this->decompose = decompose;
|
||||||
this->backendLegalOps = backendLegalOps;
|
this->backendLegalOps = backendLegalOps;
|
||||||
|
this->extraLibrary = extraLibrary.str();
|
||||||
}
|
}
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
|
|
||||||
|
backendLegalOpsSet.clear();
|
||||||
|
backendLegalOpsSet.insert(backendLegalOps.begin(), backendLegalOps.end());
|
||||||
ConversionTarget target =
|
ConversionTarget target =
|
||||||
getBackendContractTarget(context, decompose, backendLegalOps);
|
getBackendContractTarget(context, decompose, backendLegalOpsSet);
|
||||||
|
|
||||||
OpPassManager pm(module.getOperationName());
|
OpPassManager pm(module.getOperationName());
|
||||||
TorchLoweringPipelineOptions options;
|
TorchLoweringPipelineOptions options;
|
||||||
options.decompose = decompose;
|
options.decompose = decompose;
|
||||||
options.backendLegalOps = backendLegalOps;
|
options.backendLegalOps = backendLegalOps;
|
||||||
|
options.extraLibrary = extraLibrary;
|
||||||
createTorchSimplificationPipeline(pm, options);
|
createTorchSimplificationPipeline(pm, options);
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
@ -301,6 +308,8 @@ public:
|
||||||
<< " iterations of the simplification pipeline\n";
|
<< " iterations of the simplification pipeline\n";
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
private:
|
||||||
|
llvm::StringSet<> backendLegalOpsSet;
|
||||||
};
|
};
|
||||||
|
|
||||||
class VerifyBackendContractNoDecompositionsPass
|
class VerifyBackendContractNoDecompositionsPass
|
||||||
|
@ -312,7 +321,7 @@ public:
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target =
|
ConversionTarget target =
|
||||||
getBackendContractTarget(context, /*decompose*/false,
|
getBackendContractTarget(context, /*decompose*/false,
|
||||||
/*backendLegalOps*/{});
|
/*backendLegalOpsSet*/{});
|
||||||
|
|
||||||
if (!satisfiesBackendContract(getOperation(), target,
|
if (!satisfiesBackendContract(getOperation(), target,
|
||||||
/*actuallyEmitDiagnostics=*/true)) {
|
/*actuallyEmitDiagnostics=*/true)) {
|
||||||
|
@ -324,9 +333,10 @@ public:
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::torch::Torch::createLowerToBackendContractPass(
|
mlir::torch::Torch::createLowerToBackendContractPass(
|
||||||
int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps) {
|
int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps,
|
||||||
return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
|
StringRef extraLibrary) {
|
||||||
backendLegalOps);
|
return std::make_unique<LowerToBackendContractPass>(
|
||||||
|
maxIterations, decompose, backendLegalOps, extraLibrary);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
@ -337,9 +347,9 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
|
||||||
// The backend contract guarantees that ops with decompositions available will
|
// The backend contract guarantees that ops with decompositions available will
|
||||||
// be decomposed. The only way to have an op reach the backend contract without
|
// be decomposed. The only way to have an op reach the backend contract without
|
||||||
// getting decomposed is by having the user explicitly specify that op in the
|
// getting decomposed is by having the user explicitly specify that op in the
|
||||||
// `backendLegalOps` argument to the `LowerToBackendContractPass`. Therefore,
|
// `backendLegalOpsSet` argument to the `LowerToBackendContractPass`. Therefore,
|
||||||
// here we mark as illegal all ops with decompositions except for those in
|
// here we mark as illegal all ops with decompositions except for those in
|
||||||
// `backendLegalOps`.
|
// `backendLegalOpsSet`.
|
||||||
//
|
//
|
||||||
// The legality check takes place here instead of in the `DecomposeComplexOps`
|
// The legality check takes place here instead of in the `DecomposeComplexOps`
|
||||||
// pass for two reasons:
|
// pass for two reasons:
|
||||||
|
@ -352,7 +362,7 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
|
||||||
// decompositions explicit in this file
|
// decompositions explicit in this file
|
||||||
static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
ConversionTarget &target,
|
ConversionTarget &target,
|
||||||
ArrayRef<std::string> backendLegalOps) {
|
llvm::StringSet<> backendLegalOpsSet) {
|
||||||
target.addIllegalOp<AtenSoftmaxIntOp>();
|
target.addIllegalOp<AtenSoftmaxIntOp>();
|
||||||
target.addIllegalOp<Aten_SoftmaxOp>();
|
target.addIllegalOp<Aten_SoftmaxOp>();
|
||||||
target.addIllegalOp<Aten_LogSoftmaxOp>();
|
target.addIllegalOp<Aten_LogSoftmaxOp>();
|
||||||
|
@ -463,7 +473,13 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenVarMeanOp>();
|
target.addIllegalOp<AtenVarMeanOp>();
|
||||||
target.addIllegalOp<AtenNewEmptyStridedOp>();
|
target.addIllegalOp<AtenNewEmptyStridedOp>();
|
||||||
target.addIllegalOp<AtenBucketizeTensorOp>();
|
target.addIllegalOp<AtenBucketizeTensorOp>();
|
||||||
for (std::string opName : backendLegalOps) {
|
for (auto &opName : backendLegalOpsSet) {
|
||||||
target.addLegalOp(OperationName(opName, context));
|
target.addLegalOp(
|
||||||
|
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
||||||
}
|
}
|
||||||
|
target.addDynamicallyLegalOp<OperatorOp>(
|
||||||
|
[backendLegalOpsSet](OperatorOp opOp) {
|
||||||
|
auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
|
||||||
|
return backendLegalOpsSet.contains(opName);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ void mlir::torch::registerTorchPasses() {
|
||||||
"torch-simplification-pipeline",
|
"torch-simplification-pipeline",
|
||||||
"Pipeline simplifying computations in the program.",
|
"Pipeline simplifying computations in the program.",
|
||||||
mlir::torch::Torch::createTorchSimplificationPipeline);
|
mlir::torch::Torch::createTorchSimplificationPipeline);
|
||||||
mlir::PassPipelineRegistration<>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.",
|
"torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.",
|
||||||
mlir::torch::Torch::createTorchShapeRefinementPipeline);
|
mlir::torch::Torch::createTorchShapeRefinementPipeline);
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,8 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
||||||
// Perform the bulk of lowering to the backend contract.
|
// Perform the bulk of lowering to the backend contract.
|
||||||
// See the pass documentation for more information.
|
// See the pass documentation for more information.
|
||||||
pm.addPass(createLowerToBackendContractPass(
|
pm.addPass(createLowerToBackendContractPass(
|
||||||
options.maxIterations, options.decompose, options.backendLegalOps));
|
options.maxIterations, options.decompose, options.backendLegalOps,
|
||||||
|
options.extraLibrary));
|
||||||
}
|
}
|
||||||
|
|
||||||
// A simplification pipeline to establish the invariants of the backend
|
// A simplification pipeline to establish the invariants of the backend
|
||||||
|
@ -108,7 +109,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
|
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
|
||||||
// Reduce variants of ops to a smaller set of primitives.
|
// Reduce variants of ops to a smaller set of primitives.
|
||||||
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
|
pm.addNestedPass<func::FuncOp>(
|
||||||
|
createReduceOpVariantsPass(options.extraLibrary));
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
// Remove dead global slots.
|
// Remove dead global slots.
|
||||||
pm.addPass(createSymbolDCEPass());
|
pm.addPass(createSymbolDCEPass());
|
||||||
|
@ -121,8 +123,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
||||||
// This should be run before RefineTypes (which primarily does dtype
|
// This should be run before RefineTypes (which primarily does dtype
|
||||||
// inference), because Torch type promotion rules actually depend on the shape
|
// inference), because Torch type promotion rules actually depend on the shape
|
||||||
// of the operand.
|
// of the operand.
|
||||||
createTorchShapeRefinementPipeline(pm);
|
createTorchShapeRefinementPipeline(pm, options);
|
||||||
createTorchDtypeRefinementPipeline(pm);
|
createTorchDtypeRefinementPipeline(pm, options);
|
||||||
// Refine types in the program, which mainly means inferring dtypes of ops.
|
// Refine types in the program, which mainly means inferring dtypes of ops.
|
||||||
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
|
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
|
||||||
// Propagate to ABI return types the shape/dtype information discovered by
|
// Propagate to ABI return types the shape/dtype information discovered by
|
||||||
|
@ -141,13 +143,15 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
||||||
|
|
||||||
static void createRefinementPipeline(
|
static void createRefinementPipeline(
|
||||||
mlir::OpPassManager &pm,
|
mlir::OpPassManager &pm,
|
||||||
llvm::function_ref<std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>()>
|
llvm::function_ref<
|
||||||
|
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>(llvm::StringRef)>
|
||||||
reifyCalculationsPass,
|
reifyCalculationsPass,
|
||||||
llvm::function_ref<
|
llvm::function_ref<
|
||||||
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>()>
|
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>()>
|
||||||
simplifyCalculationsPass) {
|
simplifyCalculationsPass,
|
||||||
|
const mlir::torch::Torch::TorchLoweringPipelineOptions &options) {
|
||||||
// Reify the library functions for each op that is present in the library.
|
// Reify the library functions for each op that is present in the library.
|
||||||
pm.addPass(reifyCalculationsPass());
|
pm.addPass(reifyCalculationsPass(options.extraLibrary));
|
||||||
|
|
||||||
// Inline the library functions to enable analysis and transformation.
|
// Inline the library functions to enable analysis and transformation.
|
||||||
// TODO: Only inline library functions (this will currently inline
|
// TODO: Only inline library functions (this will currently inline
|
||||||
|
@ -168,12 +172,14 @@ static void createRefinementPipeline(
|
||||||
mlir::torch::Torch::createDropAbstractInterpCalculationsPass());
|
mlir::torch::Torch::createDropAbstractInterpCalculationsPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
|
void mlir::torch::Torch::createTorchShapeRefinementPipeline(
|
||||||
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
createRefinementPipeline(pm, Torch::createReifyShapeCalculationsPass,
|
createRefinementPipeline(pm, Torch::createReifyShapeCalculationsPass,
|
||||||
Torch::createSimplifyShapeCalculationsPass);
|
Torch::createSimplifyShapeCalculationsPass, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::torch::Torch::createTorchDtypeRefinementPipeline(OpPassManager &pm) {
|
void mlir::torch::Torch::createTorchDtypeRefinementPipeline(
|
||||||
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
createRefinementPipeline(pm, Torch::createReifyDtypeCalculationsPass,
|
createRefinementPipeline(pm, Torch::createReifyDtypeCalculationsPass,
|
||||||
Torch::createSimplifyDtypeCalculationsPass);
|
Torch::createSimplifyDtypeCalculationsPass, options);
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "ReifyAbstractInterpCalculationsUtils.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -52,17 +53,39 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool
|
||||||
|
operatorOpHasValueSemantics(OperatorOp opOp,
|
||||||
|
std::optional<SymbolTable> extraLibrary) {
|
||||||
|
if (!extraLibrary.has_value())
|
||||||
|
return false;
|
||||||
|
auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
|
||||||
|
std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix(
|
||||||
|
LibraryFunctionKind::HasValueSemantics) +
|
||||||
|
Twine(opName))
|
||||||
|
.str();
|
||||||
|
auto libFunc = extraLibrary->lookup<func::FuncOp>(libFuncName);
|
||||||
|
return bool(libFunc);
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Convert value semantic ops operating on mutable arrays to instead operate on
|
// Convert value semantic ops operating on mutable arrays to instead operate on
|
||||||
// immutable tensors.
|
// immutable tensors.
|
||||||
class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
|
class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
|
||||||
public:
|
public:
|
||||||
ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context)
|
ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context,
|
||||||
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
|
const std::optional<SymbolTable>& extraLibrary)
|
||||||
|
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
|
||||||
|
this->extraLibrary = extraLibrary;
|
||||||
|
}
|
||||||
LogicalResult matchAndRewrite(Operation *op,
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>())
|
if (isa<OperatorOp>(op)) {
|
||||||
|
if (!operatorOpHasValueSemantics(cast<OperatorOp>(op), extraLibrary)) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "does not have value semantics");
|
||||||
|
}
|
||||||
|
} else if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
||||||
return rewriter.notifyMatchFailure(op, "does not have value semantics");
|
return rewriter.notifyMatchFailure(op, "does not have value semantics");
|
||||||
|
}
|
||||||
|
|
||||||
rewriter.startRootUpdate(op);
|
rewriter.startRootUpdate(op);
|
||||||
// Convert all operands.
|
// Convert all operands.
|
||||||
|
@ -160,6 +183,8 @@ public:
|
||||||
rewriter.finalizeRootUpdate(op);
|
rewriter.finalizeRootUpdate(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
private:
|
||||||
|
std::optional<SymbolTable> extraLibrary;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -241,11 +266,30 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
struct ReduceOpVariantsPass
|
||||||
|
: public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||||
|
ReduceOpVariantsPass() = default;
|
||||||
|
ReduceOpVariantsPass(StringRef extraLibrary) {
|
||||||
|
this->extraLibrary = extraLibrary.str();
|
||||||
|
}
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(context);
|
OwningOpRef<ModuleOp> extraLibraryModule =
|
||||||
|
ModuleOp::create(UnknownLoc::get(context));
|
||||||
|
std::optional<SymbolTable> extraLibraryModuleSymTable = std::nullopt;
|
||||||
|
if (!extraLibrary.empty()) {
|
||||||
|
if (failed(loadExtraLibrary(extraLibrary, extraLibraryModule))) {
|
||||||
|
emitError(getOperation()->getLoc(),
|
||||||
|
"Failed to load extra-library file at " + extraLibrary);
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
extraLibraryModuleSymTable =
|
||||||
|
SymbolTable(extraLibraryModule->getOperation());
|
||||||
|
}
|
||||||
|
patterns.add<ConvertHasValueSemanticsOpsToValueTensors>(
|
||||||
|
context, extraLibraryModuleSymTable);
|
||||||
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
|
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
|
||||||
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
|
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
|
||||||
patterns.add<ReduceNonValueSemanticOps>(context);
|
patterns.add<ReduceNonValueSemanticOps>(context);
|
||||||
|
@ -253,8 +297,12 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
|
||||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
Operation *op) {
|
||||||
|
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
|
||||||
|
(isa<OperatorOp>(op) &&
|
||||||
|
operatorOpHasValueSemantics(cast<OperatorOp>(op),
|
||||||
|
extraLibraryModuleSymTable))) {
|
||||||
auto hasValueSemantics = [](Type t) {
|
auto hasValueSemantics = [](Type t) {
|
||||||
// TODO: Make this an allowlist based on a closed torch dialect
|
// TODO: Make this an allowlist based on a closed torch dialect
|
||||||
// type system.
|
// type system.
|
||||||
|
@ -281,6 +329,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
mlir::torch::Torch::createReduceOpVariantsPass() {
|
mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) {
|
||||||
return std::make_unique<ReduceOpVariantsPass>();
|
return std::make_unique<ReduceOpVariantsPass>(extraLibrary);
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,18 +8,25 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "ReifyAbstractInterpCalculationsUtils.h"
|
#include "ReifyAbstractInterpCalculationsUtils.h"
|
||||||
|
#include "mlir/Parser/Parser.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "llvm/ADT/StringSet.h"
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
#include "llvm/Support/ErrorOr.h"
|
||||||
|
#include "llvm/Support/MemoryBuffer.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
static std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
|
std::string
|
||||||
|
mlir::torch::Torch::getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
|
||||||
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
|
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
|
||||||
return "__torch_mlir_shape_fn.";
|
return "__torch_mlir_shape_fn.";
|
||||||
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
|
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
|
||||||
return "__torch_mlir_dtype_fn.";
|
return "__torch_mlir_dtype_fn.";
|
||||||
|
else if (libFuncKind == LibraryFunctionKind::HasValueSemantics)
|
||||||
|
return "__torch_mlir_has_value_semantics_fn.";
|
||||||
llvm_unreachable(
|
llvm_unreachable(
|
||||||
"`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`");
|
"`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`");
|
||||||
}
|
}
|
||||||
|
@ -73,6 +80,8 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
|
||||||
// looking them up in the library.
|
// looking them up in the library.
|
||||||
if (name.startswith("valsem."))
|
if (name.startswith("valsem."))
|
||||||
name = name.drop_front(strlen("valsem."));
|
name = name.drop_front(strlen("valsem."));
|
||||||
|
if (isa<OperatorOp>(op))
|
||||||
|
name = cast<OperatorOp>(op)->getAttr("name").cast<StringAttr>().getValue();
|
||||||
std::string libFuncName =
|
std::string libFuncName =
|
||||||
(getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
|
(getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
|
||||||
auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
|
auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
|
||||||
|
@ -288,3 +297,39 @@ FailureOr<Value> Torch::adjustFunctionArg(
|
||||||
// Pass the operand as-is.
|
// Pass the operand as-is.
|
||||||
return operand;
|
return operand;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
mlir::torch::Torch::loadExtraLibrary(const std::string &filename,
|
||||||
|
OwningOpRef<ModuleOp> &moduleToAppendTo) {
|
||||||
|
auto ctx = moduleToAppendTo->getContext();
|
||||||
|
assert(ctx && "Module should be fully initialized.");
|
||||||
|
|
||||||
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
|
llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
||||||
|
if (std::error_code ec = fileOrErr.getError()) {
|
||||||
|
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SourceMgr sourceMgr;
|
||||||
|
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||||
|
OwningOpRef<ModuleOp> module_ =
|
||||||
|
mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, ctx);
|
||||||
|
if (!module_) {
|
||||||
|
llvm::errs() << "Error can't load file " << filename << "\n";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
assert((moduleToAppendTo->getBodyRegion().empty() ||
|
||||||
|
moduleToAppendTo->getBodyRegion().hasOneBlock()) &&
|
||||||
|
"Module should have at most one block.");
|
||||||
|
if (moduleToAppendTo->getBodyRegion().empty()) {
|
||||||
|
moduleToAppendTo = std::move(module_);
|
||||||
|
} else {
|
||||||
|
Block *block = moduleToAppendTo->getBody(0);
|
||||||
|
block->getOperations().splice(block->end(),
|
||||||
|
module_->getBody(0)->getOperations());
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
|
@ -22,7 +22,12 @@ namespace mlir {
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace Torch {
|
namespace Torch {
|
||||||
|
|
||||||
enum class LibraryFunctionKind { ShapeFunction, DtypeFunction, Decomposition };
|
enum class LibraryFunctionKind {
|
||||||
|
ShapeFunction,
|
||||||
|
DtypeFunction,
|
||||||
|
Decomposition,
|
||||||
|
HasValueSemantics
|
||||||
|
};
|
||||||
|
|
||||||
// Searches the function library for an abstract interpretation function for
|
// Searches the function library for an abstract interpretation function for
|
||||||
// `op`. If one is found, wraps the op in a `CalculateOp`, with the op placed in
|
// `op`. If one is found, wraps the op in a `CalculateOp`, with the op placed in
|
||||||
|
@ -60,6 +65,16 @@ FailureOr<Value> adjustFunctionArg(
|
||||||
OpBuilder &b, Location loc, Value operand, Type desiredType,
|
OpBuilder &b, Location loc, Value operand, Type desiredType,
|
||||||
function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation =
|
function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation =
|
||||||
[](OpBuilder &, Location, Value operand, Type) { return operand; });
|
[](OpBuilder &, Location, Value operand, Type) { return operand; });
|
||||||
|
|
||||||
|
std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind);
|
||||||
|
|
||||||
|
// Parse MLIR module at `filename` into a ModuleOp that will then
|
||||||
|
// be appended to an existing, fully hydrated, ModuleOp; note the module
|
||||||
|
// should have been instantiated with an associated context like so:
|
||||||
|
// `OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));`
|
||||||
|
LogicalResult loadExtraLibrary(const std::string &filename,
|
||||||
|
OwningOpRef<ModuleOp> &moduleToAppendTo);
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -61,13 +61,23 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ReifyDtypeCalculationsPass
|
struct ReifyDtypeCalculationsPass
|
||||||
: public ReifyDtypeCalculationsBase<ReifyDtypeCalculationsPass> {
|
: public ReifyDtypeCalculationsBase<ReifyDtypeCalculationsPass> {
|
||||||
|
ReifyDtypeCalculationsPass() = default;
|
||||||
|
ReifyDtypeCalculationsPass(StringRef extraLibrary) {
|
||||||
|
this->extraLibrary = extraLibrary.str();
|
||||||
|
}
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
OwningOpRef<ModuleOp> library =
|
OwningOpRef<ModuleOp> library =
|
||||||
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
|
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
|
||||||
|
if (!extraLibrary.empty())
|
||||||
|
if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) {
|
||||||
|
emitError(module->getLoc(),
|
||||||
|
"Failed to load extra-library file at " + extraLibrary);
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
// Walk all the operations, and if we have a dtype function, wrap the op
|
// Walk all the operations, and if we have a dtype function, wrap the op
|
||||||
// in a `torch.dtype.calculate` op.
|
// in a `torch.dtype.calculate` op.
|
||||||
|
@ -86,6 +96,6 @@ class ReifyDtypeCalculationsPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
Torch::createReifyDtypeCalculationsPass() {
|
Torch::createReifyDtypeCalculationsPass(StringRef extraLibrary) {
|
||||||
return std::make_unique<ReifyDtypeCalculationsPass>();
|
return std::make_unique<ReifyDtypeCalculationsPass>(extraLibrary);
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "llvm/Support/MemoryBuffer.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -55,8 +56,12 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ReifyShapeCalculationsPass
|
struct ReifyShapeCalculationsPass
|
||||||
: public ReifyShapeCalculationsBase<ReifyShapeCalculationsPass> {
|
: public ReifyShapeCalculationsBase<ReifyShapeCalculationsPass> {
|
||||||
|
ReifyShapeCalculationsPass() = default;
|
||||||
|
ReifyShapeCalculationsPass(StringRef extraLibrary) {
|
||||||
|
this->extraLibrary = extraLibrary.str();
|
||||||
|
}
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
|
@ -66,6 +71,12 @@ class ReifyShapeCalculationsPass
|
||||||
// O(#ops in the program) ideally.
|
// O(#ops in the program) ideally.
|
||||||
OwningOpRef<ModuleOp> library =
|
OwningOpRef<ModuleOp> library =
|
||||||
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
|
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
|
||||||
|
if (!extraLibrary.empty())
|
||||||
|
if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) {
|
||||||
|
emitError(module->getLoc(),
|
||||||
|
"Failed to load extra-library file at " + extraLibrary);
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
// Walk all the operations, and if we have a shape function, wrap the op
|
// Walk all the operations, and if we have a shape function, wrap the op
|
||||||
// in a `torch.shape.calculate` op.
|
// in a `torch.shape.calculate` op.
|
||||||
|
@ -84,6 +95,6 @@ class ReifyShapeCalculationsPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::torch::Torch::createReifyShapeCalculationsPass() {
|
mlir::torch::Torch::createReifyShapeCalculationsPass(StringRef extraLibrary) {
|
||||||
return std::make_unique<ReifyShapeCalculationsPass>();
|
return std::make_unique<ReifyShapeCalculationsPass>(extraLibrary);
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,6 @@ class AddmmModule(torch.nn.Module):
|
||||||
example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
|
example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
|
||||||
|
|
||||||
print(torch_mlir.compile(AddmmModule(), example_args,
|
print(torch_mlir.compile(AddmmModule(), example_args,
|
||||||
output_type="torch", backend_legal_ops=["torch.aten.addmm"]))
|
output_type="torch", backend_legal_ops=["aten.addmm"]))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.addmm
|
# CHECK: torch.aten.addmm
|
||||||
|
|
|
@ -240,8 +240,8 @@ class ExampleArgs:
|
||||||
# ops in the backend contract, and move these lists somewhere deeper in the
|
# ops in the backend contract, and move these lists somewhere deeper in the
|
||||||
# compiler where each backend can "own" its set of legal ops.
|
# compiler where each backend can "own" its set of legal ops.
|
||||||
BACKEND_LEGAL_OPS = {
|
BACKEND_LEGAL_OPS = {
|
||||||
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'],
|
||||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
|
OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints', ],
|
||||||
OutputType.STABLEHLO: [],
|
OutputType.STABLEHLO: [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,6 +252,7 @@ def compile(model: torch.nn.Module,
|
||||||
use_tracing: bool = False,
|
use_tracing: bool = False,
|
||||||
ignore_traced_shapes=False,
|
ignore_traced_shapes=False,
|
||||||
backend_legal_ops: Optional[Sequence[str]] = None,
|
backend_legal_ops: Optional[Sequence[str]] = None,
|
||||||
|
_completely_unsupported_in_progress_extra_library: Optional[str] = None,
|
||||||
verbose: bool = False):
|
verbose: bool = False):
|
||||||
"""Convert a PyTorch model to MLIR.
|
"""Convert a PyTorch model to MLIR.
|
||||||
|
|
||||||
|
@ -367,7 +368,11 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||||
if output_type == OutputType.RAW:
|
if output_type == OutputType.RAW:
|
||||||
return mb.module
|
return mb.module
|
||||||
|
|
||||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
|
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + (
|
||||||
|
(" extra-library=" + _completely_unsupported_in_progress_extra_library)
|
||||||
|
if (_completely_unsupported_in_progress_extra_library is not None)
|
||||||
|
else ""
|
||||||
|
) + "}"
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
mb.module,
|
mb.module,
|
||||||
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
|
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=torch.aten.softmax.int" -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=aten.softmax.int" -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim
|
// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim
|
||||||
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
|
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax})' -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.square
|
// CHECK-LABEL: func.func @torch.aten.square
|
||||||
func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.cpp_extension
|
||||||
|
import torch_mlir
|
||||||
|
from torch_mlir_e2e_test.annotations import export, annotate_args
|
||||||
|
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
|
||||||
|
def identity(x: torch.Tensor):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
goofy_lib = torch.library.Library("goofy", "DEF")
|
||||||
|
goofy_lib.define("identity(Tensor t) -> Tensor")
|
||||||
|
goofy_lib.impl("identity", identity)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomOpExampleModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a):
|
||||||
|
b = 2 * a
|
||||||
|
return torch.ops.goofy.identity(b)
|
||||||
|
|
||||||
|
|
||||||
|
mod = CustomOpExampleModule()
|
||||||
|
mod.eval()
|
||||||
|
|
||||||
|
abstract_interp_src = """\
|
||||||
|
func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
return %arg0 : !torch.list<int>
|
||||||
|
}
|
||||||
|
func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple<int, int>) -> !torch.int {
|
||||||
|
%0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
|
||||||
|
return %0#1 : !torch.int
|
||||||
|
}
|
||||||
|
func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
|
||||||
|
"""
|
||||||
|
|
||||||
|
with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp:
|
||||||
|
tmp.write(abstract_interp_src)
|
||||||
|
|
||||||
|
module = torch_mlir.compile(
|
||||||
|
mod,
|
||||||
|
torch.ones(3, 4),
|
||||||
|
output_type="torch",
|
||||||
|
backend_legal_ops=["goofy.identity"],
|
||||||
|
_completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(module)
|
||||||
|
|
||||||
|
# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} {
|
||||||
|
# CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||||
|
# CHECK: %{{.*}} = torch.constant.int 2
|
||||||
|
# CHECK: %{{.*}} = torch.aten.mul.Scalar %{{.*}}, %{{.*}} : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
|
||||||
|
# CHECK: %{{.*}} = torch.operator "goofy.identity"(%{{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
|
||||||
|
# CHECK: return %1 : !torch.vtensor<[3,4],f32>
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: }
|
Loading…
Reference in New Issue