diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index bdc28afbe..d4cceb05d 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -42,6 +42,9 @@ struct TorchLoweringPipelineOptions Option decompose{*this, "decompose-complex-ops", llvm::cl::desc("Decompose complex operations."), llvm::cl::init(true)}; + Option shapeDtypeRefine{ + *this, "shape-dtype-refine", + llvm::cl::desc("Do shape and dtype refinement."), llvm::cl::init(true)}; // A list of ops that should be considered legal for the backend. // TODO: The meaning of this list should be formalized. // A sketch of the semantics would be: @@ -130,10 +133,9 @@ createDropAbstractInterpCalculationsPass(); std::unique_ptr> createEraseModuleInitializerPass(); -std::unique_ptr> -createLowerToBackendContractPass(int maxIterations, bool decompose, - ArrayRef backendLegalOps, - StringRef extraLibrary); +std::unique_ptr> createLowerToBackendContractPass( + int maxIterations, bool decompose, bool shapeDtypeRefine, + ArrayRef backendLegalOps, StringRef extraLibrary); std::unique_ptr> createVerifyBackendContractNoDecompositionsPass(); diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 715b2265d..6439feb39 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -362,7 +362,7 @@ def LowerToBackendContract let summary = "Perform simplifications until the backend contract is satisfied."; let constructor = [{ mlir::torch::Torch::createLowerToBackendContractPass( - /*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"") + /*maxIterations=*/10, /*decompose=*/true, /*shapeDtypeRefine=*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"") }]; let description = [{ This pass performs the bulk of the lowering of the 
program's computations @@ -405,6 +405,8 @@ def LowerToBackendContract "Maximum number of invocations of the simplification pipeline.">, Option<"decompose", "decompose", "bool", /*default=*/"true", "Decompose ops.">, + Option<"shapeDtypeRefine", "shape-dtype-refine", "bool", /*default=*/"true", + "Do shape and dtype refinement.">, ListOption<"backendLegalOps", "backend-legal-ops", "std::string", "List of ops to be considered legal for the backend, such as 'aten.foo'.">, Option<"extraLibrary", "extra-library", "std::string", /*default=*/"", diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 3981cff44..bda2d258a 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -263,10 +263,12 @@ class LowerToBackendContractPass public: LowerToBackendContractPass() = default; LowerToBackendContractPass(int maxIterations, bool decompose, + bool shapeDtypeRefine, ArrayRef backendLegalOps, StringRef extraLibrary) { this->maxIterations = maxIterations; this->decompose = decompose; + this->shapeDtypeRefine = shapeDtypeRefine; this->backendLegalOps = backendLegalOps; this->extraLibrary = extraLibrary.str(); } @@ -282,6 +284,7 @@ public: OpPassManager pm(module.getOperationName()); TorchLoweringPipelineOptions options; options.decompose = decompose; + options.shapeDtypeRefine = shapeDtypeRefine; options.backendLegalOps = backendLegalOps; options.extraLibrary = extraLibrary; createTorchSimplificationPipeline(pm, options); @@ -336,10 +339,11 @@ public: std::unique_ptr> mlir::torch::Torch::createLowerToBackendContractPass( - int maxIterations, bool decompose, ArrayRef backendLegalOps, - StringRef extraLibrary) { + int maxIterations, bool decompose, bool shapeDtypeRefine, + ArrayRef backendLegalOps, StringRef extraLibrary) { return std::make_unique( - maxIterations, decompose, backendLegalOps, extraLibrary); + maxIterations, 
decompose, shapeDtypeRefine, backendLegalOps, + extraLibrary); } std::unique_ptr> diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 407e90247..d01eac967 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -66,8 +66,8 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( // Perform the bulk of lowering to the backend contract. // See the pass documentation for more information. pm.addPass(createLowerToBackendContractPass( - options.maxIterations, options.decompose, options.backendLegalOps, - options.extraLibrary)); + options.maxIterations, options.decompose, options.shapeDtypeRefine, + options.backendLegalOps, options.extraLibrary)); } // A simplification pipeline to establish the invariants of the backend @@ -119,11 +119,13 @@ void mlir::torch::Torch::createTorchSimplificationPipeline( // Update the return op to return value tensors. pm.addPass(Torch::createRefinePublicReturnPass()); pm.addNestedPass(createCanonicalizerPass()); - // Do shape and dtype refinement. - // Shape refinement should be run before dtype refinement because Torch type - // promotion rules actually depend on the shape of the operand. - createTorchShapeRefinementPipeline(pm, options); - createTorchDtypeRefinementPipeline(pm, options); + if (options.shapeDtypeRefine) { + // Do shape and dtype refinement. + // Shape refinement should be run before dtype refinement because Torch + // type promotion rules actually depend on the shape of the operand. + createTorchShapeRefinementPipeline(pm, options); + createTorchDtypeRefinementPipeline(pm, options); + } // Propagate to ABI return types the shape/dtype information discovered by // the previous pass. Doing this is ABI-compatible for our backends. 
pm.addPass(Torch::createRefinePublicReturnPass()); diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py index e049a0149..7a6f67b22 100644 --- a/python/torch_mlir/extras/fx_decomp_util.py +++ b/python/torch_mlir/extras/fx_decomp_util.py @@ -44,6 +44,9 @@ DEFAULT_DECOMPOSITIONS = [ torch.ops.aten._log_softmax_backward_data, torch.ops.aten.lift_fresh_copy.default, torch.ops.aten._unsafe_index.Tensor, + torch.ops.aten.linspace.default, + torch.ops.aten.triu.default, + torch.ops.aten.nan_to_num.default, ] diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 8d5c5cb11..834cffd63 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -45,6 +45,8 @@ def _module_lowering( + ",".join(backend_legal_ops) + " extra-library=" + extra_library_file_name + + " shape-dtype-refine=" + + ("false" if not backend_legal_ops and not extra_library_file_name else "true") + "}" ) run_pipeline_with_repro_report(