mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add torch-onnx-to-torch-backend pipeline (#3801)
This commit adds the torch-onnx-to-torch-backend pipeline which converts the Torch Onnx IR to Torch Backend IR. This commit also moves the `ScalarizeShapes` pass from the `torch-backend-to-linalg-on-tensors-backend-pipeline` to the `torch-onnx-to-torch-backend` pipeline since the primary goal of this pass is to scalarize the shapes in the IR coming from the Onnx models.pull/3811/head
parent
d2330df58f
commit
fa4794dae2
|
@ -84,6 +84,11 @@ void createTorchDynamoExportToTorchBackendPipeline(
|
|||
void createTorchFunctionToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||
|
||||
/// Creates a pipeline that lowers the torch Onnx IR that is produced by
|
||||
/// Onnx import into the form expected by torch-verify-backend-contract.
|
||||
void createTorchOnnxToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||
|
||||
/// Creates a pipeline that simplifies the computations in the program.
|
||||
/// This pass does not do any global program restructuring -- it works entirely
|
||||
/// within a single semantic model of a `builtin.module` with
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
|
||||
|
||||
void mlir::torch::registerTorchPasses() {
|
||||
mlir::torch::registerPasses();
|
||||
|
@ -25,6 +26,10 @@ void mlir::torch::registerTorchPasses() {
|
|||
"torch-function-to-torch-backend-pipeline",
|
||||
"Pipeline lowering a Torch function to Torch backend form.",
|
||||
mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torch-onnx-to-torch-backend-pipeline",
|
||||
"Pipeline lowering Torch Onnx IR to Torch backend form.",
|
||||
mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline);
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torch-simplification-pipeline",
|
||||
"Pipeline simplifying computations in the program.",
|
||||
|
@ -86,6 +91,37 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
|||
options.backendLegalOps, options.extraLibrary));
|
||||
}
|
||||
|
||||
void mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||
pm.addNestedPass<func::FuncOp>(onnx_c::createTorchOnnxToTorchPass());
|
||||
// The above pass just converts the torch onnx IR to torch, hence the given
|
||||
// pipeline will make sure that the IR is transformed such that it satisfies
|
||||
// the backend contract.
|
||||
if (options.decompose) {
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
}
|
||||
// TODO: Move the combination of two passes i.e., ScalarizeShapes and
|
||||
// TorchShapeRefinementPipeline out of here and create an onnx shape
|
||||
// refinement pipeline which runs iteratively over the IR.
|
||||
createTorchShapeRefinementPipeline(pm, options);
|
||||
// This pass scalarizes the tensor shape computations.
|
||||
pm.addNestedPass<mlir::func::FuncOp>(
|
||||
mlir::torch::Torch::createScalarizeShapesPass());
|
||||
createTorchShapeRefinementPipeline(pm, options);
|
||||
pm.addPass(Torch::createRefinePublicReturnPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// The decompose pass is run again here since the scalarize shapes pass and
|
||||
// shape refinement pipeline might create some ops for which decomposition
|
||||
// exists.
|
||||
if (options.decompose) {
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
}
|
||||
}
|
||||
|
||||
// A simplification pipeline to establish the invariants of the backend
|
||||
// contract (see `satisfiedBackendContract` in `LowerToBackendContract`).
|
||||
//
|
||||
|
|
|
@ -70,7 +70,6 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
|||
|
||||
// We want to fuse quantized operations together before lowering to linalg.
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());
|
||||
|
||||
// Lower to linalg + guards which is the input to codegen backends.
|
||||
// We do this first as it tends to involve pattern-matching against constants,
|
||||
|
|
|
@ -100,33 +100,25 @@ def _module_lowering(
|
|||
print("ONNX RAW IR")
|
||||
print(torch_mod)
|
||||
|
||||
# Lower from ONNX to Torch
|
||||
run_pipeline_with_repro_report(
|
||||
torch_mod,
|
||||
# The importer may produce additional MLIR functions corresponding to
|
||||
# ONNX operators that are functions. In some cases they need to be
|
||||
# inlined to avoid the backend choking on them.
|
||||
f"builtin.module(inline, func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
|
||||
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("TorchFX IR")
|
||||
print(torch_mod)
|
||||
|
||||
backend_legal_ops = [
|
||||
"aten.flatten.using_ints",
|
||||
"aten.adaptive_avg_pool1d",
|
||||
"aten.unflatten.int",
|
||||
]
|
||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
|
||||
|
||||
# Lower from ONNX to Torch
|
||||
run_pipeline_with_repro_report(
|
||||
torch_mod,
|
||||
f"builtin.module(torch-lower-to-backend-contract{option_string})",
|
||||
"Lowering TorchFX IR -> Torch Backend IR",
|
||||
f"builtin.module(torch-onnx-to-torch-backend-pipeline{option_string})",
|
||||
"Lowering Onnx Raw IR -> Torch Backend IR",
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("Torch IR")
|
||||
print(torch_mod)
|
||||
|
||||
return lower_mlir_module(verbose, output_type, torch_mod)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-onnx-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints,aten.unflatten.int})' -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @test_reshape_negative_dim_decompose
|
||||
func.func @test_reshape_negative_dim_decompose(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[INT6:.+]] = torch.constant.int 6
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.view %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,6,2],f32>
|
||||
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32>
|
||||
return %0 : !torch.vtensor<[2,6,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_triu_decompose
|
||||
func.func @test_triu_decompose(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[ZERO_TENSOR:.+]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[INT4:.+]] = torch.constant.int 4
|
||||
// CHECK: %[[INT5:.+]] = torch.constant.int 5
|
||||
// CHECK: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT4]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64>
|
||||
// CHECK: %[[ARANGE_0:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT5]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64>
|
||||
// CHECK: %[[UNSQUEEZE:.+]] = torch.aten.unsqueeze %[[ARANGE]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
|
||||
// CHECK: %[[UNSQUEEZE_0:.+]] = torch.aten.unsqueeze %[[ARANGE_0]], %[[INT0]] : !torch.vtensor<[5],si64>, !torch.int -> !torch.vtensor<[1,5],si64>
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[UNSQUEEZE]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64>
|
||||
// CHECK: %[[COND:.+]] = torch.aten.ge.Tensor %[[UNSQUEEZE_0]], %[[ADD]] : !torch.vtensor<[1,5],si64>, !torch.vtensor<[4,1],si64> -> !torch.vtensor<[4,5],i1>
|
||||
// CHECK: %[[RESULT:.+]] = torch.aten.where.self %[[COND]], %arg0, %[[ZERO_TENSOR]] : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[4,5],si64>
|
||||
%0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64>
|
||||
return %0 : !torch.vtensor<[4,5],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// CHECK-LABEL: func.func @test_scalarize
|
||||
func.func @test_scalarize(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} {
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT2]], %[[INT3]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
|
||||
%0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64>
|
||||
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__21> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||
%2 = torch.operator "onnx.Gather"(%0, %1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
|
||||
%3 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64>
|
||||
%4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__22> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||
%5 = torch.operator "onnx.Gather"(%3, %4) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
|
||||
%6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
|
||||
%7 = torch.operator "onnx.Unsqueeze"(%2, %6) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64>
|
||||
%8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
|
||||
%9 = torch.operator "onnx.Unsqueeze"(%5, %8) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64>
|
||||
%10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_3209> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
|
||||
%11 = torch.operator "onnx.Concat"(%7, %9, %10) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3],si64>
|
||||
%12 = torch.operator "onnx.Reshape"(%arg0, %11) : (!torch.vtensor<[?,?,16,64],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32>
|
||||
return %12 : !torch.vtensor<[?,?,?],f32>
|
||||
}
|
||||
}
|
||||
|
||||
{-#
|
||||
dialect_resources: {
|
||||
builtin: {
|
||||
__21: "0x080000000000000000000000",
|
||||
__22: "0x080000000100000000000000",
|
||||
_onnx__Concat_3209: "0x080000000004000000000000"
|
||||
}
|
||||
}
|
||||
#-}
|
Loading…
Reference in New Issue