[FxImporter] Eliminate the dependency on the refinement pass (#3309)

pull/3321/head
penguin_wwy 2024-05-10 02:44:36 +08:00 committed by GitHub
parent afe87d62b4
commit 64b59c7fc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 30 additions and 15 deletions

View File

@@ -42,6 +42,9 @@ struct TorchLoweringPipelineOptions
Option<bool> decompose{*this, "decompose-complex-ops",
llvm::cl::desc("Decompose complex operations."),
llvm::cl::init(true)};
Option<bool> shapeDtypeRefine{
*this, "shape-dtype-refine",
llvm::cl::desc("Do shape and dtype refinement."), llvm::cl::init(true)};
// A list of ops that should be considered legal for the backend.
// TODO: The meaning of this list should be formalized.
// A sketch of the semantics would be:
@@ -130,10 +133,9 @@ createDropAbstractInterpCalculationsPass();
std::unique_ptr<OperationPass<ModuleOp>> createEraseModuleInitializerPass();
std::unique_ptr<OperationPass<ModuleOp>>
createLowerToBackendContractPass(int maxIterations, bool decompose,
ArrayRef<std::string> backendLegalOps,
StringRef extraLibrary);
std::unique_ptr<OperationPass<ModuleOp>> createLowerToBackendContractPass(
int maxIterations, bool decompose, bool shapeDtypeRefine,
ArrayRef<std::string> backendLegalOps, StringRef extraLibrary);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyBackendContractNoDecompositionsPass();

View File

@@ -362,7 +362,7 @@ def LowerToBackendContract
let summary = "Perform simplifications until the backend contract is satisfied.";
let constructor = [{
mlir::torch::Torch::createLowerToBackendContractPass(
/*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"")
/*maxIterations=*/10, /*decompose=*/true, /*shapeDtypeRefine*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"")
}];
let description = [{
This pass performs the bulk of the lowering of the program's computations
@@ -405,6 +405,8 @@ def LowerToBackendContract
"Maximum number of invocations of the simplification pipeline.">,
Option<"decompose", "decompose", "bool", /*default=*/"true",
"Decompose ops.">,
Option<"shapeDtypeRefine", "shape-dtype-refine", "bool", /*default=*/"true",
"Do shape and dtype refinement.">,
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
"List of ops to be considered legal for the backend, such as 'aten.foo'.">,
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",

View File

@@ -263,10 +263,12 @@ class LowerToBackendContractPass
public:
LowerToBackendContractPass() = default;
// Construct the pass with explicit values for each pipeline knob.
// NOTE(review): these assignments target the pass's declared options
// (max-iterations, decompose, shape-dtype-refine, backend-legal-ops,
// extra-library) — presumably generated from the corresponding .td
// Option/ListOption declarations; confirm against the generated base class.
LowerToBackendContractPass(int maxIterations, bool decompose,
bool shapeDtypeRefine,
ArrayRef<std::string> backendLegalOps,
StringRef extraLibrary) {
this->maxIterations = maxIterations;
this->decompose = decompose;
// Gates the shape/dtype refinement pipelines in the simplification run.
this->shapeDtypeRefine = shapeDtypeRefine;
this->backendLegalOps = backendLegalOps;
// StringRef does not own its characters; materialize a std::string copy
// so the option outlives the caller's buffer.
this->extraLibrary = extraLibrary.str();
}
@@ -282,6 +284,7 @@ public:
OpPassManager pm(module.getOperationName());
TorchLoweringPipelineOptions options;
options.decompose = decompose;
options.shapeDtypeRefine = shapeDtypeRefine;
options.backendLegalOps = backendLegalOps;
options.extraLibrary = extraLibrary;
createTorchSimplificationPipeline(pm, options);
@@ -336,10 +339,11 @@ public:
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createLowerToBackendContractPass(
int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps,
StringRef extraLibrary) {
int maxIterations, bool decompose, bool shapeDtypeRefine,
ArrayRef<std::string> backendLegalOps, StringRef extraLibrary) {
return std::make_unique<LowerToBackendContractPass>(
maxIterations, decompose, backendLegalOps, extraLibrary);
maxIterations, decompose, shapeDtypeRefine, backendLegalOps,
extraLibrary);
}
std::unique_ptr<OperationPass<ModuleOp>>

View File

@@ -66,8 +66,8 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// Perform the bulk of lowering to the backend contract.
// See the pass documentation for more information.
pm.addPass(createLowerToBackendContractPass(
options.maxIterations, options.decompose, options.backendLegalOps,
options.extraLibrary));
options.maxIterations, options.decompose, options.shapeDtypeRefine,
options.backendLegalOps, options.extraLibrary));
}
// A simplification pipeline to establish the invariants of the backend
@@ -119,11 +119,13 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// Update the return op to return value tensors.
pm.addPass(Torch::createRefinePublicReturnPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Do shape and dtype refinement.
// Shape refinement should be run before dtype refinement because Torch type
// promotion rules actually depend on the shape of the operand.
createTorchShapeRefinementPipeline(pm, options);
createTorchDtypeRefinementPipeline(pm, options);
if (options.shapeDtypeRefine) {
// Do shape and dtype refinement.
// Shape refinement should be run before dtype refinement because Torch type
// promotion rules actually depend on the shape of the operand.
createTorchShapeRefinementPipeline(pm, options);
createTorchDtypeRefinementPipeline(pm, options);
}
// Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends.
pm.addPass(Torch::createRefinePublicReturnPass());

View File

@@ -44,6 +44,9 @@ DEFAULT_DECOMPOSITIONS = [
torch.ops.aten._log_softmax_backward_data,
torch.ops.aten.lift_fresh_copy.default,
torch.ops.aten._unsafe_index.Tensor,
torch.ops.aten.linspace.default,
torch.ops.aten.triu.default,
torch.ops.aten.nan_to_num.default,
]

View File

@@ -45,6 +45,8 @@ def _module_lowering(
+ ",".join(backend_legal_ops)
+ " extra-library="
+ extra_library_file_name
+ " shape-dtype-refine="
+ ("false" if not backend_legal_ops and not extra_library_file_name else "true")
+ "}"
)
run_pipeline_with_repro_report(