torch-mlir/python/torch_mlir/fx.py

119 lines
3.9 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import Optional, Union, Dict, Tuple, Any, Callable
import warnings
import torch
import torch.export
import torch.nn as nn
from torch.export import ExportedProgram
from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.compiler_utils import (
OutputType,
run_pipeline_with_repro_report,
lower_mlir_module,
)
def _module_lowering(
    verbose,
    output_type,
    torch_mod,
    extra_library_file_name=None,
):
    """Lower an imported Torch-dialect module to the requested output type.

    Args:
        verbose: When truthy, print the module on the ``OutputType.TORCH``
            early-return path; otherwise it is forwarded as
            ``enable_ir_printing`` to the pipeline runner and as the first
            argument of ``lower_mlir_module``.
        output_type: Target ``OutputType``. ``OutputType.TORCH`` returns the
            module as imported, without running any pipeline.
        torch_mod: The module produced by the FX importer (presumably an MLIR
            ``ir.Module`` -- confirm against the importer).
        extra_library_file_name: Value placed in the pipeline's
            ``extra-library`` option. ``None`` is treated as an empty string.

    Returns:
        ``torch_mod`` itself for ``OutputType.TORCH``; otherwise whatever
        ``lower_mlir_module`` returns for the given ``output_type``.
    """
    if output_type == OutputType.TORCH:
        # Nothing to lower: hand the module back exactly as it was imported.
        if verbose:
            print(torch_mod)
        return torch_mod

    # TODO: pass extra_library_file_name by caller
    if extra_library_file_name is None:
        extra_library_file_name = ""
    option_string = "{extra-library=" + extra_library_file_name + "}"

    # Run the TorchDynamo-specific pipeline before the generic lowering step.
    run_pipeline_with_repro_report(
        torch_mod,
        f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})",
        "Lowering TorchFX IR -> Torch Backend IR",
        enable_ir_printing=verbose,
    )
    return lower_mlir_module(verbose, output_type, torch_mod)
def export_and_import(
    f: Union[nn.Module, ExportedProgram],
    *args,
    output_type: Union[str, OutputType] = OutputType.TORCH,
    fx_importer: Optional[FxImporter] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    experimental_support_mutation: bool = False,
    import_symbolic_shape_expressions: bool = False,
    hooks: Optional[FxImporterHooks] = None,
    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
    func_name: str = "main",
    enable_graph_printing: bool = False,
    enable_ir_printing: bool = False,
    **kwargs,
):
    """Export a module (or take an ``ExportedProgram``) and import it via FX.

    Args:
        f: Either an ``nn.Module`` to export with ``torch.export.export`` or
            an already exported ``ExportedProgram``, which is used as-is.
        *args: Positional example inputs forwarded to ``torch.export.export``.
            Not used when ``f`` is already an ``ExportedProgram``.
        output_type: Requested output type; converted with
            ``OutputType.get`` before lowering.
        fx_importer: Importer to use. When ``None`` a new ``FxImporter`` is
            created on a fresh context with ``hooks`` attached.
        dynamic_shapes: Forwarded to ``torch.export.export``. Not used when
            ``f`` is already an ``ExportedProgram``.
        experimental_support_mutation: Selects ``import_program`` instead of
            ``import_frozen_program`` on the importer.
        import_symbolic_shape_expressions: Forwarded to the importer call.
        hooks: Only used when this function constructs the ``FxImporter``
            itself, i.e. when ``fx_importer`` is ``None``.
        decomposition_table: Table passed to ``run_decompositions``. ``None``
            selects ``get_decomposition_table()``; any falsy table (for
            example an empty dict) skips the decomposition step.
        func_name: Name given to the imported function.
        enable_graph_printing: Print the (decomposed) graph module before
            importing it.
        enable_ir_printing: Forwarded to ``_module_lowering`` as its
            ``verbose`` flag.
        **kwargs: Keyword example inputs forwarded to ``torch.export.export``.
            Not used when ``f`` is already an ``ExportedProgram``.

    Returns:
        The result of ``_module_lowering`` applied to the importer's module.
    """
    # The context is created unconditionally; it is only handed to the
    # importer when this function constructs the importer itself.
    context = ir.Context()
    torch_d.register_dialect(context)

    if fx_importer is None:
        fx_importer = FxImporter(context=context, hooks=hooks)

    if isinstance(f, ExportedProgram):
        prog = f
    else:
        prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)

    if decomposition_table is None:
        decomposition_table = get_decomposition_table()
    if decomposition_table:
        prog = prog.run_decompositions(decomposition_table)

    if enable_graph_printing:
        prog.graph_module.print_readable()

    if experimental_support_mutation:
        # NOTE(review): this relies on torch.__version__ comparing as a
        # version rather than lexicographically (believed true for torch's
        # TorchVersion string subclass) -- confirm for the minimum supported
        # torch release.
        if torch.__version__ < "2.3.0.dev20240207":
            warnings.warn("Mutable program import only supported on PyTorch 2.3+")
        # The warning is advisory only; the import is still attempted.
        fx_importer.import_program(
            prog,
            func_name=func_name,
            import_symbolic_shape_expressions=import_symbolic_shape_expressions,
        )
    else:
        fx_importer.import_frozen_program(
            prog,
            func_name=func_name,
            import_symbolic_shape_expressions=import_symbolic_shape_expressions,
        )

    return _module_lowering(
        enable_ir_printing, OutputType.get(output_type), fx_importer.module
    )
2024-03-22 05:44:54 +08:00
def stateless_fx_import(
    gm: torch.fx.GraphModule,
    output_type: Union[str, OutputType] = OutputType.TORCH,
    fx_importer: Optional[FxImporter] = None,
    hooks: Optional[FxImporterHooks] = None,
    model_name: str = "main",
    enable_graph_printing: bool = False,
    enable_ir_printing: bool = False,
):
    """Import the graph of an FX ``GraphModule`` and lower it.

    Args:
        gm: Graph module whose ``graph`` is passed to
            ``FxImporter.import_stateless_graph``.
        output_type: Requested output type; converted with
            ``OutputType.get`` before lowering.
        fx_importer: Importer to use. When ``None`` a new ``FxImporter`` is
            created on a fresh context with ``hooks`` attached.
        hooks: Only used when this function constructs the ``FxImporter``
            itself, i.e. when ``fx_importer`` is ``None``.
        model_name: Passed to the importer as ``func_name``.
        enable_graph_printing: Print ``gm`` in readable form before import.
        enable_ir_printing: Forwarded to ``_module_lowering`` as its
            ``verbose`` flag.

    Returns:
        The result of ``_module_lowering`` applied to the importer's module.
    """
    if enable_graph_printing:
        gm.print_readable()

    # The context is created unconditionally; it is only handed to the
    # importer when this function constructs the importer itself.
    context = ir.Context()
    torch_d.register_dialect(context)

    if fx_importer is None:
        fx_importer = FxImporter(context=context, hooks=hooks)

    fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
    return _module_lowering(
        enable_ir_printing, OutputType.get(output_type), fx_importer.module
    )