mirror of https://github.com/llvm/torch-mlir
[FxImporter] Eliminate the dependency on the refinement pass (#3309)
parent
afe87d62b4
commit
64b59c7fc3
|
@ -42,6 +42,9 @@ struct TorchLoweringPipelineOptions
|
|||
Option<bool> decompose{*this, "decompose-complex-ops",
|
||||
llvm::cl::desc("Decompose complex operations."),
|
||||
llvm::cl::init(true)};
|
||||
Option<bool> shapeDtypeRefine{
|
||||
*this, "shape-dtype-refine",
|
||||
llvm::cl::desc("Do shape and dtype refinement."), llvm::cl::init(true)};
|
||||
// A list of ops that should be considered legal for the backend.
|
||||
// TODO: The meaning of this list should be formalized.
|
||||
// A sketch of the semantics would be:
|
||||
|
@ -130,10 +133,9 @@ createDropAbstractInterpCalculationsPass();
|
|||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createEraseModuleInitializerPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createLowerToBackendContractPass(int maxIterations, bool decompose,
|
||||
ArrayRef<std::string> backendLegalOps,
|
||||
StringRef extraLibrary);
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLowerToBackendContractPass(
|
||||
int maxIterations, bool decompose, bool shapeDtypeRefine,
|
||||
ArrayRef<std::string> backendLegalOps, StringRef extraLibrary);
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createVerifyBackendContractNoDecompositionsPass();
|
||||
|
|
|
@ -362,7 +362,7 @@ def LowerToBackendContract
|
|||
let summary = "Perform simplifications until the backend contract is satisfied.";
|
||||
let constructor = [{
|
||||
mlir::torch::Torch::createLowerToBackendContractPass(
|
||||
/*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"")
|
||||
/*maxIterations=*/10, /*decompose=*/true, /*shapeDtypeRefine*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"")
|
||||
}];
|
||||
let description = [{
|
||||
This pass performs the bulk of the lowering of the program's computations
|
||||
|
@ -405,6 +405,8 @@ def LowerToBackendContract
|
|||
"Maximum number of invocations of the simplification pipeline.">,
|
||||
Option<"decompose", "decompose", "bool", /*default=*/"true",
|
||||
"Decompose ops.">,
|
||||
Option<"shapeDtypeRefine", "shape-dtype-refine", "bool", /*default=*/"true",
|
||||
"Do shape and dtype refinement.">,
|
||||
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
|
||||
"List of ops to be considered legal for the backend, such as 'aten.foo'.">,
|
||||
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
|
||||
|
|
|
@ -263,10 +263,12 @@ class LowerToBackendContractPass
|
|||
public:
|
||||
LowerToBackendContractPass() = default;
|
||||
LowerToBackendContractPass(int maxIterations, bool decompose,
|
||||
bool shapeDtypeRefine,
|
||||
ArrayRef<std::string> backendLegalOps,
|
||||
StringRef extraLibrary) {
|
||||
this->maxIterations = maxIterations;
|
||||
this->decompose = decompose;
|
||||
this->shapeDtypeRefine = shapeDtypeRefine;
|
||||
this->backendLegalOps = backendLegalOps;
|
||||
this->extraLibrary = extraLibrary.str();
|
||||
}
|
||||
|
@ -282,6 +284,7 @@ public:
|
|||
OpPassManager pm(module.getOperationName());
|
||||
TorchLoweringPipelineOptions options;
|
||||
options.decompose = decompose;
|
||||
options.shapeDtypeRefine = shapeDtypeRefine;
|
||||
options.backendLegalOps = backendLegalOps;
|
||||
options.extraLibrary = extraLibrary;
|
||||
createTorchSimplificationPipeline(pm, options);
|
||||
|
@ -336,10 +339,11 @@ public:
|
|||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::Torch::createLowerToBackendContractPass(
|
||||
int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps,
|
||||
StringRef extraLibrary) {
|
||||
int maxIterations, bool decompose, bool shapeDtypeRefine,
|
||||
ArrayRef<std::string> backendLegalOps, StringRef extraLibrary) {
|
||||
return std::make_unique<LowerToBackendContractPass>(
|
||||
maxIterations, decompose, backendLegalOps, extraLibrary);
|
||||
maxIterations, decompose, shapeDtypeRefine, backendLegalOps,
|
||||
extraLibrary);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
|
|
|
@ -66,8 +66,8 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
|||
// Perform the bulk of lowering to the backend contract.
|
||||
// See the pass documentation for more information.
|
||||
pm.addPass(createLowerToBackendContractPass(
|
||||
options.maxIterations, options.decompose, options.backendLegalOps,
|
||||
options.extraLibrary));
|
||||
options.maxIterations, options.decompose, options.shapeDtypeRefine,
|
||||
options.backendLegalOps, options.extraLibrary));
|
||||
}
|
||||
|
||||
// A simplification pipeline to establish the invariants of the backend
|
||||
|
@ -119,11 +119,13 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
|||
// Update the return op to return value tensors.
|
||||
pm.addPass(Torch::createRefinePublicReturnPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// Do shape and dtype refinement.
|
||||
// Shape refinement should be run before dtype refinement because Torch type
|
||||
// promotion rules actually depend on the shape of the operand.
|
||||
createTorchShapeRefinementPipeline(pm, options);
|
||||
createTorchDtypeRefinementPipeline(pm, options);
|
||||
if (options.shapeDtypeRefine) {
|
||||
// Do shape and dtype refinement.
|
||||
// Shape refinement should be run before dtype refinement because Torch type
|
||||
// promotion rules actually depend on the shape of the operand.
|
||||
createTorchShapeRefinementPipeline(pm, options);
|
||||
createTorchDtypeRefinementPipeline(pm, options);
|
||||
}
|
||||
// Propagate to ABI return types the shape/dtype information discovered by
|
||||
// the previous pass. Doing this is ABI-compatible for our backends.
|
||||
pm.addPass(Torch::createRefinePublicReturnPass());
|
||||
|
|
|
@ -44,6 +44,9 @@ DEFAULT_DECOMPOSITIONS = [
|
|||
torch.ops.aten._log_softmax_backward_data,
|
||||
torch.ops.aten.lift_fresh_copy.default,
|
||||
torch.ops.aten._unsafe_index.Tensor,
|
||||
torch.ops.aten.linspace.default,
|
||||
torch.ops.aten.triu.default,
|
||||
torch.ops.aten.nan_to_num.default,
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -45,6 +45,8 @@ def _module_lowering(
|
|||
+ ",".join(backend_legal_ops)
|
||||
+ " extra-library="
|
||||
+ extra_library_file_name
|
||||
+ " shape-dtype-refine="
|
||||
+ ("false" if not backend_legal_ops and not extra_library_file_name else "true")
|
||||
+ "}"
|
||||
)
|
||||
run_pipeline_with_repro_report(
|
||||
|
|
Loading…
Reference in New Issue