mirror of https://github.com/llvm/torch-mlir
Fall cleaning: Re-organize sources so that each target type is contained and optional.
While working on https://github.com/openxla/iree/pull/14917, I noticed that it is somewhat hard to take a dependency on torch-mlir such that one only builds deps for the target(s) of interest (in this case Linalg). I noticed that some ifdef'ey optionality was added for stablehlo, but this was not mirrored for the others. Further, it does the switching very deep in the dependency graph vs having top-level directories and defines gating entire features. In addition, I noticed that a lot of things in the Linalg path were broken down to a fine level of detail but were not actually shared/shareable outside of that target. I opted to clump these together into TorchToLinalg. It is easy enough to "promote" them to common with this new organization if the need arises. General approach: * Isolate each conversion target in one of TorchToLinalg, TorchToStablehlo, TorchToTosa. * Gate each by top-level CMake flags and defines. * Common conversions go in a Common/ directory (currently Arith and SCF). * Pull target specific conversions out of TorchConversion/Transforms and put in their top-level directory. * General maintenance on the build graph and registration stuff that had bitrotted and was blocking progress. The main functional change for people taking a source dep is that `#include "torch-mlir/Conversion/Passes.h"` no longer is a one stop shop: For optional conversions, you have to include the dedicated `Passes.h` of each and take a library dep. See `InitAll.cpp` which does it right (and *is* a one stop shop still).isolate_optional_targets
parent
9cb5d38cd1
commit
dd79067571
|
@ -36,10 +36,19 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
|
||||||
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
|
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
|
||||||
endmacro()
|
endmacro()
|
||||||
|
|
||||||
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
|
# Optional conversion targets.
|
||||||
|
if(TORCH_MLIR_ENABLE_LINALG)
|
||||||
|
add_definitions(-DTORCH_MLIR_ENABLE_LINALG)
|
||||||
|
endif()
|
||||||
|
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect conversions" ON)
|
||||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
|
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
endif()
|
endif()
|
||||||
|
option(TORCH_MLIR_ENABLE_TOSA "Add tosa dialect conversions" ON)
|
||||||
|
if(TORCH_MLIR_ENABLE_TOSA)
|
||||||
|
add_definitions(-DTORCH_MLIR_ENABLE_TOSA)
|
||||||
|
endif()
|
||||||
|
option(TORCH_MLIR_ENABLE_LINALG "Add linalg dialect" ON)
|
||||||
|
|
||||||
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
|
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
|
||||||
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
|
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
|
add_public_tablegen_target(TorchMLIRConversionCommonPassIncGen)
|
||||||
else()
|
add_mlir_doc(Passes TorchMLIRConversionCommonPasses ./ -gen-pass-doc)
|
||||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
|
||||||
endif()
|
|
||||||
add_public_tablegen_target(TorchMLIRConversionPassIncGen)
|
|
||||||
|
|
||||||
add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc)
|
if(TORCH_MLIR_ENABLE_LINALG)
|
||||||
|
add_subdirectory(TorchToLinalg)
|
||||||
|
endif()
|
||||||
|
if(TORCH_MLIR_ENABLE_TOSA)
|
||||||
|
add_subdirectory(TorchToTosa)
|
||||||
|
endif()
|
||||||
|
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
|
add_subdirectory(TorchToStablehlo)
|
||||||
|
endif()
|
|
@ -7,16 +7,21 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef TORCHMLIR_CONVERSION_PASSES_H
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#define TORCHMLIR_CONVERSION_PASSES_H
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace torch {
|
namespace torch {
|
||||||
|
|
||||||
/// Registers all torch-mlir conversion passes.
|
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToArithPass();
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();
|
||||||
|
|
||||||
|
// Note that this only registers common conversion passes. Backend
|
||||||
|
// specific passes with their own Passes.h in a subdirectory must be
|
||||||
|
// included/registered explicitly as they are all optional.
|
||||||
void registerConversionPasses();
|
void registerConversionPasses();
|
||||||
|
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_PASSES_H
|
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Torch conversions
|
// Common conversions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> {
|
def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> {
|
||||||
|
@ -26,132 +26,4 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
|
||||||
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
|
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
|
|
||||||
let summary = "Convert recognized Torch ops to Linalg ops";
|
|
||||||
let description = [{
|
|
||||||
Convert ATen ops to linalg ops.
|
|
||||||
|
|
||||||
This pass's main responsibility is to bridge the world between ops
|
|
||||||
that safely terminate the program in case of operand shape mismatches
|
|
||||||
(ATen) and ops where such mismatches are undefined behavior (linalg).
|
|
||||||
|
|
||||||
To model the termination of the program for implementing error guards,
|
|
||||||
we use the `cf.assert` op.
|
|
||||||
This is a design decision that is at variance from other passes in the
|
|
||||||
ecosystem, which use the
|
|
||||||
`shape` dialect's witness system (`shape.cstr_*` family of ops feeding into
|
|
||||||
`shape.assuming` regions). This is a change in design decisions
|
|
||||||
from those passes (which the authors of this pass have contributed to).
|
|
||||||
The reasons for this change are heuristic, but boil down to:
|
|
||||||
1. The modeling of `shape.assuming` is odd, as it uses a region, which is
|
|
||||||
not a good fit for modeling error guards. Regions mark a "start" and an
|
|
||||||
"end" (which is their nesting property). But
|
|
||||||
modeling assertions in the program doesn't fit into that. For assertions,
|
|
||||||
only the "start" matters (once tested, a predicate remains true "forever"
|
|
||||||
-- it doesn't end at the "yield" of the region).
|
|
||||||
Thus, having regions places arbitrary "end"s that just add IR structure
|
|
||||||
that has no semantic value for modeling this problem! (and to make things
|
|
||||||
worse the "end"s, which we don't need, are what require "yielding"
|
|
||||||
values, which interrupts use-def chains). Consider the different
|
|
||||||
structural properties of regions:
|
|
||||||
a. IsolatedFromAbove region:
|
|
||||||
- "start" interrupts use-def chains,
|
|
||||||
- "end" interrupts use-def chains
|
|
||||||
- structurally protects from intra-block upward and downward
|
|
||||||
code motion
|
|
||||||
b. Capturing region (like `shape.assuming`):
|
|
||||||
- "start" does not interrupt use-def chains,
|
|
||||||
- "end" interrupts use-def chains
|
|
||||||
- structurally protects from intra-block upward and downward
|
|
||||||
code motion
|
|
||||||
c. What we "ideally" want:
|
|
||||||
- "start" interrupts use-def chains (can be pruned though)
|
|
||||||
- no "end" IR structure!
|
|
||||||
- structurally protects from intra-block upward code motion
|
|
||||||
(but not downward code motion!)
|
|
||||||
- Observation: We probably can't get all of this, but overall this
|
|
||||||
problem is much better suited for a "MemorySSA"-like
|
|
||||||
abstraction, call it "EffectSSA" which is constructed on-demand
|
|
||||||
based on MLIR's effect modeling system (rather than
|
|
||||||
`shape.assuming`, which only covers the effects the IR creator
|
|
||||||
encoded -- with witnesses/`shape.assuming` -- it is easy to forget
|
|
||||||
to handle effects other than those encoded in the
|
|
||||||
witness structure).
|
|
||||||
2. The presence of `shape.assuming` regions tends to create highly nested
|
|
||||||
IR structures, which don't interoperate well with any other IR
|
|
||||||
structures, and creates very bulky IR (and IR creation code). In general
|
|
||||||
if we are going to do anything with anything (e.g. canonicalize) we
|
|
||||||
end up needing need to either:
|
|
||||||
a. Flatten the `shape.assuming` IR (defeating the purpose of having
|
|
||||||
it).
|
|
||||||
b. Do some sort of shape.assuming "region merging".
|
|
||||||
c. Have special patterns that handle a subset of special cases (looking
|
|
||||||
through "yields" and such) and don't generalize.
|
|
||||||
3. Witnesses tend to encourage non-scalable peephole transformations, which
|
|
||||||
tend to make analyses/transformations non-robust to the presence of
|
|
||||||
control flow and side effecting ops (easy to forget to handle side
|
|
||||||
effects other than those modeled by the witness system).
|
|
||||||
4. All this code operates on ranked tensors, for which using individual
|
|
||||||
SSA values for sizes (rather than a "shape type") seems to
|
|
||||||
work really well at this level of abstraction based on prior experience
|
|
||||||
in other projects. (unranked code tends to benefit from having a discrete
|
|
||||||
"shape type" to model shapes).
|
|
||||||
|
|
||||||
We will see if we end up needing something like `shape.assuming`, but for
|
|
||||||
now, it seems likely we can do something simpler and just bypass it. The
|
|
||||||
design of having an EffectSSA that is constructed on-demand seems very
|
|
||||||
compelling for modeling effects more broadly.
|
|
||||||
}];
|
|
||||||
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
|
|
||||||
}
|
|
||||||
|
|
||||||
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
|
|
||||||
let summary = "Convert Torch ops to TOSA ops";
|
|
||||||
let description = [{
|
|
||||||
This pass assumes that TOSA ops are responsible for emitting error
|
|
||||||
guards in case of shape mismatches.
|
|
||||||
}];
|
|
||||||
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
|
|
||||||
}
|
|
||||||
|
|
||||||
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
|
|
||||||
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
|
|
||||||
let description = [{
|
|
||||||
Convert ATen ops to tmtensor/linalg ops.
|
|
||||||
|
|
||||||
This pass is similar to the TorchToLinalg pass; the difference is that this
|
|
||||||
pass also makes use of TMTensor Dialect, which the former one doesn't.
|
|
||||||
}];
|
|
||||||
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
|
|
||||||
}
|
|
||||||
|
|
||||||
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> {
|
|
||||||
let summary = "Convert recognized TorchConversion ops to MLProgram ops";
|
|
||||||
let description = [{
|
|
||||||
Convert TorchConversion ops to mlprogram ops.
|
|
||||||
}];
|
|
||||||
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
|
|
||||||
let summary = "Convert Torch ops to Stablehlo ops";
|
|
||||||
let description = [{
|
|
||||||
Convert Torch ops to Stablehlo ops.
|
|
||||||
}];
|
|
||||||
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";
|
|
||||||
|
|
||||||
// Specify any options.
|
|
||||||
let options = [
|
|
||||||
Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
|
|
||||||
"Enable static shape conversion">,
|
|
||||||
// The i64 calculation is much slower than i32 on some devices, such as
|
|
||||||
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
|
|
||||||
// are unlikely to exceed the range of i32(4GiB)
|
|
||||||
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
|
|
||||||
"Enable truncate index from i64 to i32(unsafely)">,
|
|
||||||
];
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_PASSES
|
#endif // TORCHMLIR_CONVERSION_PASSES
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
//===------------------------------------------------------------*- C++ -*-===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
|
|
||||||
#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
|
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace torch {
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
|
||||||
createConvertTorchConversionToMLProgramPass();
|
|
||||||
}
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
add_public_tablegen_target(TorchMLIRConversionLinalgPassIncGen)
|
||||||
|
add_mlir_doc(Passes TorchMLIRConversionLinalgPasses ./ -gen-pass-doc)
|
|
@ -0,0 +1,41 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_CONVERSION_LINALG_PASSES_H
|
||||||
|
#define TORCHMLIR_CONVERSION_LINALG_PASSES_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
|
||||||
|
/// Creates a pipeline that lowers from the torch backend contract to the
|
||||||
|
/// linalg-on-tensors backend contract.
|
||||||
|
void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createVerifyLinalgOnTensorsBackendContractPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createConvertTorchConversionToMLProgramPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
|
||||||
|
|
||||||
|
/// Registers all torch-mlir conversion passes.
|
||||||
|
void registerLinalgConversionPasses();
|
||||||
|
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_PASSES_H
|
|
@ -0,0 +1,132 @@
|
||||||
|
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_CONVERSION_LINALG_PASSES
|
||||||
|
#define TORCHMLIR_CONVERSION_LINALG_PASSES
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Torch conversions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> {
|
||||||
|
let summary = "Convert recognized Torch ops to Std ops";
|
||||||
|
let constructor = "mlir::torch::createConvertTorchToArithPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
|
||||||
|
let summary = "Convert recognized Torch ops to SCF ops";
|
||||||
|
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
|
||||||
|
let summary = "Convert recognized Torch ops to Linalg ops";
|
||||||
|
let description = [{
|
||||||
|
Convert ATen ops to linalg ops.
|
||||||
|
|
||||||
|
This pass's main responsibility is to bridge the world between ops
|
||||||
|
that safely terminate the program in case of operand shape mismatches
|
||||||
|
(ATen) and ops where such mismatches are undefined behavior (linalg).
|
||||||
|
|
||||||
|
To model the termination of the program for implementing error guards,
|
||||||
|
we use the `cf.assert` op.
|
||||||
|
This is a design decision that is at variance from other passes in the
|
||||||
|
ecosystem, which use the
|
||||||
|
`shape` dialect's witness system (`shape.cstr_*` family of ops feeding into
|
||||||
|
`shape.assuming` regions). This is a change in design decisions
|
||||||
|
from those passes (which the authors of this pass have contributed to).
|
||||||
|
The reasons for this change are heuristic, but boil down to:
|
||||||
|
1. The modeling of `shape.assuming` is odd, as it uses a region, which is
|
||||||
|
not a good fit for modeling error guards. Regions mark a "start" and an
|
||||||
|
"end" (which is their nesting property). But
|
||||||
|
modeling assertions in the program doesn't fit into that. For assertions,
|
||||||
|
only the "start" matters (once tested, a predicate remains true "forever"
|
||||||
|
-- it doesn't end at the "yield" of the region).
|
||||||
|
Thus, having regions places arbitrary "end"s that just add IR structure
|
||||||
|
that has no semantic value for modeling this problem! (and to make things
|
||||||
|
worse the "end"s, which we don't need, are what require "yielding"
|
||||||
|
values, which interrupts use-def chains). Consider the different
|
||||||
|
structural properties of regions:
|
||||||
|
a. IsolatedFromAbove region:
|
||||||
|
- "start" interrupts use-def chains,
|
||||||
|
- "end" interrupts use-def chains
|
||||||
|
- structurally protects from intra-block upward and downward
|
||||||
|
code motion
|
||||||
|
b. Capturing region (like `shape.assuming`):
|
||||||
|
- "start" does not interrupt use-def chains,
|
||||||
|
- "end" interrupts use-def chains
|
||||||
|
- structurally protects from intra-block upward and downward
|
||||||
|
code motion
|
||||||
|
c. What we "ideally" want:
|
||||||
|
- "start" interrupts use-def chains (can be pruned though)
|
||||||
|
- no "end" IR structure!
|
||||||
|
- structurally protects from intra-block upward code motion
|
||||||
|
(but not downward code motion!)
|
||||||
|
- Observation: We probably can't get all of this, but overall this
|
||||||
|
problem is much better suited for a "MemorySSA"-like
|
||||||
|
abstraction, call it "EffectSSA" which is constructed on-demand
|
||||||
|
based on MLIR's effect modeling system (rather than
|
||||||
|
`shape.assuming`, which only covers the effects the IR creator
|
||||||
|
encoded -- with witnesses/`shape.assuming` -- it is easy to forget
|
||||||
|
to handle effects other than those encoded in the
|
||||||
|
witness structure).
|
||||||
|
2. The presence of `shape.assuming` regions tends to create highly nested
|
||||||
|
IR structures, which don't interoperate well with any other IR
|
||||||
|
structures, and creates very bulky IR (and IR creation code). In general
|
||||||
|
if we are going to do anything with anything (e.g. canonicalize) we
|
||||||
|
end up needing need to either:
|
||||||
|
a. Flatten the `shape.assuming` IR (defeating the purpose of having
|
||||||
|
it).
|
||||||
|
b. Do some sort of shape.assuming "region merging".
|
||||||
|
c. Have special patterns that handle a subset of special cases (looking
|
||||||
|
through "yields" and such) and don't generalize.
|
||||||
|
3. Witnesses tend to encourage non-scalable peephole transformations, which
|
||||||
|
tend to make analyses/transformations non-robust to the presence of
|
||||||
|
control flow and side effecting ops (easy to forget to handle side
|
||||||
|
effects other than those modeled by the witness system).
|
||||||
|
4. All this code operates on ranked tensors, for which using individual
|
||||||
|
SSA values for sizes (rather than a "shape type") seems to
|
||||||
|
work really well at this level of abstraction based on prior experience
|
||||||
|
in other projects. (unranked code tends to benefit from having a discrete
|
||||||
|
"shape type" to model shapes).
|
||||||
|
|
||||||
|
We will see if we end up needing something like `shape.assuming`, but for
|
||||||
|
now, it seems likely we can do something simpler and just bypass it. The
|
||||||
|
design of having an EffectSSA that is constructed on-demand seems very
|
||||||
|
compelling for modeling effects more broadly.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
|
||||||
|
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
|
||||||
|
let description = [{
|
||||||
|
Convert ATen ops to tmtensor/linalg ops.
|
||||||
|
|
||||||
|
This pass is similar to the TorchToLinalg pass; the difference is that this
|
||||||
|
pass also makes use of TMTensor Dialect, which the former one doesn't.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> {
|
||||||
|
let summary = "Convert recognized TorchConversion ops to MLProgram ops";
|
||||||
|
let description = [{
|
||||||
|
Convert TorchConversion ops to mlprogram ops.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
|
||||||
|
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
||||||
|
let constructor = "mlir::torch::createVerifyLinalgOnTensorsBackendContractPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_LINALG_PASSES
|
|
@ -1,24 +0,0 @@
|
||||||
//===------------------------------------------------------------*- C++ -*-===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
|
|
||||||
#define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace torch {
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
|
|
||||||
}
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
|
|
|
@ -1,22 +0,0 @@
|
||||||
//===------------------------------------------------------------*- C++ -*-===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
|
||||||
#define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace torch {
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();
|
|
||||||
}
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
add_public_tablegen_target(TorchMLIRConversionStablehloPassIncGen)
|
||||||
|
|
||||||
|
add_mlir_doc(Passes TorchMLIRConversionStablehloPasses ./ -gen-pass-doc)
|
|
@ -0,0 +1,51 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H
|
||||||
|
#define TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
|
||||||
|
struct StablehloBackendPipelineOptions
|
||||||
|
: public PassPipelineOptions<StablehloBackendPipelineOptions> {
|
||||||
|
Option<bool> enableStaticShape{
|
||||||
|
*this, "enable-static-shape",
|
||||||
|
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
|
||||||
|
// The i64 calculation is much slower than i32 on some devices, such as
|
||||||
|
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
|
||||||
|
// are unlikely to exceed the range of i32(4GiB)
|
||||||
|
Option<bool> enableI32Index{
|
||||||
|
*this, "enable-i32-index",
|
||||||
|
llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
|
};
|
||||||
|
|
||||||
|
void createTorchBackendToStablehloBackendPipeline(
|
||||||
|
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createVerifyStablehloBackendContractPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
createConvertTorchToStablehloPass();
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
|
||||||
|
|
||||||
|
/// Registers all torch-mlir conversion passes.
|
||||||
|
void registerStablehloConversionPasses();
|
||||||
|
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H
|
|
@ -0,0 +1,43 @@
|
||||||
|
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSES
|
||||||
|
#define TORCHMLIR_CONVERSION_STABLEHLO_PASSES
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Torch conversions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
|
||||||
|
let summary = "Convert Torch ops to Stablehlo ops";
|
||||||
|
let description = [{
|
||||||
|
Convert Torch ops to Stablehlo ops.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";
|
||||||
|
|
||||||
|
// Specify any options.
|
||||||
|
let options = [
|
||||||
|
Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
|
||||||
|
"Enable static shape conversion">,
|
||||||
|
// The i64 calculation is much slower than i32 on some devices, such as
|
||||||
|
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
|
||||||
|
// are unlikely to exceed the range of i32(4GiB)
|
||||||
|
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
|
||||||
|
"Enable truncate index from i64 to i32(unsafely)">,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
|
||||||
|
let summary = "Verifies conformity to the stablehlo backend contract";
|
||||||
|
let constructor = "mlir::torch::createVerifyStablehloBackendContractPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSES
|
|
@ -1,26 +0,0 @@
|
||||||
//===------------------------------------------------------------*- C++ -*-===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
|
|
||||||
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace torch {
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
|
||||||
createConvertTorchToStablehloPass();
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
|
||||||
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
|
|
||||||
} // namespace torch
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
add_public_tablegen_target(TorchMLIRConversionTosaPassIncGen)
|
||||||
|
|
||||||
|
add_mlir_doc(Passes TorchMLIRConversionTosaPasses ./ -gen-pass-doc)
|
|
@ -0,0 +1,35 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_CONVERSION_TOSA_PASSES_H
|
||||||
|
#define TORCHMLIR_CONVERSION_TOSA_PASSES_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
|
||||||
|
/// Creates a pipeline that lowers from the torch backend contract to the
|
||||||
|
/// TOSA backend contract.
|
||||||
|
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
|
||||||
|
|
||||||
|
/// Registers all torch-mlir conversion passes.
|
||||||
|
void registerTosaConversionPasses();
|
||||||
|
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_PASSES_H
|
|
@ -0,0 +1,33 @@
|
||||||
|
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_CONVERSION_TOSA_PASSES
|
||||||
|
#define TORCHMLIR_CONVERSION_TOSA_PASSES
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Torch conversions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
|
||||||
|
let summary = "Convert Torch ops to TOSA ops";
|
||||||
|
let description = [{
|
||||||
|
This pass assumes that TOSA ops are responsible for emitting error
|
||||||
|
guards in case of shape mismatches.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> {
|
||||||
|
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
||||||
|
let constructor = "mlir::torch::createVerifyTosaBackendContractPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_TOSA_PASSES
|
|
@ -21,8 +21,6 @@ class ModuleOp;
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace Torch {
|
namespace Torch {
|
||||||
|
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
@ -141,13 +139,6 @@ static const char kTorchOpPrefix[] = R"(torch.)";
|
||||||
/// Registers all Torch transformation passes.
|
/// Registers all Torch transformation passes.
|
||||||
void registerTorchPasses();
|
void registerTorchPasses();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Pass registration
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#define GEN_PASS_REGISTRATION
|
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
|
||||||
|
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -22,36 +22,6 @@ class ModuleOp;
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace TorchConversion {
|
namespace TorchConversion {
|
||||||
|
|
||||||
/// Creates a pipeline that lowers from the torch backend contract to the
|
|
||||||
/// linalg-on-tensors backend contract.
|
|
||||||
void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
|
|
||||||
|
|
||||||
/// Creates a pipeline that lowers from the torch backend contract to the
|
|
||||||
/// TOSA backend contract.
|
|
||||||
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
|
|
||||||
|
|
||||||
// Do not register the stablehlo options if the stablehlo target is disabled
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
struct StablehloBackendPipelineOptions
|
|
||||||
: public PassPipelineOptions<StablehloBackendPipelineOptions> {
|
|
||||||
Option<bool> enableStaticShape{
|
|
||||||
*this, "enable-static-shape",
|
|
||||||
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
|
|
||||||
// The i64 calculation is much slower than i32 on some devices, such as
|
|
||||||
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
|
|
||||||
// are unlikely to exceed the range of i32(4GiB)
|
|
||||||
Option<bool> enableI32Index{
|
|
||||||
*this, "enable-i32-index",
|
|
||||||
llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"),
|
|
||||||
llvm::cl::init(false)};
|
|
||||||
};
|
|
||||||
|
|
||||||
void createTorchBackendToStablehloBackendPipeline(
|
|
||||||
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
|
||||||
createVerifyStablehloBackendContractPass();
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
@ -65,11 +35,6 @@ createFinalizingBackendTypeConversionPass();
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass();
|
std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass();
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass();
|
std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
|
||||||
createVerifyLinalgOnTensorsBackendContractPass();
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
|
|
||||||
|
|
||||||
} // namespace TorchConversion
|
} // namespace TorchConversion
|
||||||
|
|
||||||
/// Registers all Torch transformation passes.
|
/// Registers all Torch transformation passes.
|
||||||
|
|
|
@ -32,23 +32,6 @@ def FinalizingBackendTypeConversion
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
|
|
||||||
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
|
||||||
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
|
|
||||||
}
|
|
||||||
|
|
||||||
def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> {
|
|
||||||
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
|
||||||
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
|
|
||||||
let summary = "Verifies conformity to the stablehlo backend contract";
|
|
||||||
let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
|
|
||||||
}
|
|
||||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
|
|
||||||
// The following passes are for a one-off conversion of a specific kind of quantized group matmul.
|
// The following passes are for a one-off conversion of a specific kind of quantized group matmul.
|
||||||
// They should not be included in default lowering flows until further along.
|
// They should not be included in default lowering flows until further along.
|
||||||
def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> {
|
def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> {
|
||||||
|
|
|
@ -8,6 +8,9 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
|
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TorchMLIRTorchPassIncGen
|
||||||
|
|
||||||
ENABLE_AGGREGATION
|
ENABLE_AGGREGATION
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
Core
|
Core
|
||||||
|
|
|
@ -9,6 +9,11 @@
|
||||||
#include "mlir/CAPI/Pass.h"
|
#include "mlir/CAPI/Pass.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Must include the declarations as they carry important visibility attributes.
|
// Must include the declarations as they carry important visibility attributes.
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
|
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
|
||||||
|
|
||||||
|
|
|
@ -23,8 +23,16 @@ set(LinkedLibs
|
||||||
TorchMLIRRefBackend
|
TorchMLIRRefBackend
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Conditionally link in backends if enabled.
|
||||||
|
if(TORCH_MLIR_ENABLE_LINALG)
|
||||||
|
list(APPEND LinkedLibs TorchMLIRTorchToLinalg)
|
||||||
|
endif()
|
||||||
|
if(TORCH_MLIR_ENABLE_TOSA)
|
||||||
|
list(APPEND LinkedLibs TorchMLIRTorchToTosa)
|
||||||
|
endif()
|
||||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
list(APPEND LinkedLibs
|
list(APPEND LinkedLibs
|
||||||
|
TorchMLIRTorchToStablehlo
|
||||||
MhloPasses
|
MhloPasses
|
||||||
MhloToLinalg
|
MhloToLinalg
|
||||||
StablehloToMhlo
|
StablehloToMhlo
|
||||||
|
|
|
@ -1,36 +1,18 @@
|
||||||
add_subdirectory(TorchToLinalg)
|
if(TORCH_MLIR_ENABLE_LINALG)
|
||||||
add_subdirectory(TorchToSCF)
|
add_subdirectory(TorchToLinalg)
|
||||||
add_subdirectory(TorchToArith)
|
endif()
|
||||||
add_subdirectory(TorchToTosa)
|
if(TORCH_MLIR_ENABLE_TOSA)
|
||||||
|
add_subdirectory(TorchToTosa)
|
||||||
|
endif()
|
||||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
add_subdirectory(TorchToStablehlo)
|
add_subdirectory(TorchToStablehlo)
|
||||||
endif()
|
endif()
|
||||||
add_subdirectory(TorchToTMTensor)
|
|
||||||
add_subdirectory(TorchConversionToMLProgram)
|
|
||||||
add_subdirectory(Utils)
|
add_subdirectory(Utils)
|
||||||
|
|
||||||
# TODO: Automate this with add_torch_mlir_conversion_library.
|
|
||||||
set(linked_libs TorchMLIRTorchToLinalg
|
|
||||||
TorchMLIRTorchToSCF
|
|
||||||
TorchMLIRTorchToArith
|
|
||||||
TorchMLIRTorchToTosa
|
|
||||||
TorchMLIRTorchToTMTensor
|
|
||||||
TorchMLIRTorchConversionToMLProgram
|
|
||||||
TorchMLIRConversionUtils)
|
|
||||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
|
||||||
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_mlir_library(TorchMLIRConversionPasses
|
add_mlir_library(TorchMLIRConversionPasses
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
TorchMLIRConversionPassIncGen
|
|
||||||
|
|
||||||
LINK_COMPONENTS
|
|
||||||
Core
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
${linked_libs}
|
TorchMLIRConversionUtils
|
||||||
#${torch_mlir_conversion_libs}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_subdirectory(Common)
|
|
@ -0,0 +1,28 @@
|
||||||
|
add_mlir_conversion_library(TorchMLIRConversionCommon
|
||||||
|
TorchToArith.cpp
|
||||||
|
TorchToSCF.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TorchMLIRConversionCommonPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRArithDialect
|
||||||
|
MLIRIR
|
||||||
|
MLIRFuncDialect
|
||||||
|
MLIRPass
|
||||||
|
MLIRMathDialect
|
||||||
|
MLIRSCFDialect
|
||||||
|
MLIRTransforms
|
||||||
|
TorchMLIRTorchDialect
|
||||||
|
TorchMLIRTorchConversionDialect
|
||||||
|
TorchMLIRTorchConversionPasses
|
||||||
|
TorchMLIRTorchUtils
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_mlir_target_includes(TorchMLIRConversionCommon)
|
|
@ -7,7 +7,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
#include "torch-mlir/Conversion/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
@ -43,7 +43,8 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenDimOp op, OpAdaptor adaptor,
|
matchAndRewrite(AtenDimOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto rank = rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf());
|
auto rank =
|
||||||
|
rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf());
|
||||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
|
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), rank);
|
op, getTypeConverter()->convertType(op.getType()), rank);
|
||||||
return success();
|
return success();
|
||||||
|
@ -74,7 +75,8 @@ public:
|
||||||
matchAndRewrite(AtenOp op,
|
matchAndRewrite(AtenOp op,
|
||||||
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
|
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(), adaptor.getB());
|
rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(),
|
||||||
|
adaptor.getB());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -112,10 +114,10 @@ public:
|
||||||
typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
|
typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value a =
|
Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(),
|
||||||
convertScalarToDtype(rewriter, loc, adaptor.getA(), rewriter.getF64Type());
|
rewriter.getF64Type());
|
||||||
Value b =
|
Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(),
|
||||||
convertScalarToDtype(rewriter, loc, adaptor.getB(), rewriter.getF64Type());
|
rewriter.getF64Type());
|
||||||
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b);
|
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -180,7 +182,8 @@ public:
|
||||||
}));
|
}));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
|
if (auto elements =
|
||||||
|
op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
|
||||||
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
||||||
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
|
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
|
||||||
Type builtinTensorElemTy =
|
Type builtinTensorElemTy =
|
||||||
|
@ -357,7 +360,8 @@ public:
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith> {
|
class ConvertTorchToArith
|
||||||
|
: public ConvertTorchToArithBase<ConvertTorchToArith> {
|
||||||
public:
|
public:
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<func::FuncDialect>();
|
registry.insert<func::FuncDialect>();
|
|
@ -7,7 +7,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
#include "torch-mlir/Conversion/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
@ -9,22 +9,6 @@
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/Passes.h"
|
#include "torch-mlir/Conversion/Passes.h"
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
|
||||||
#include "transforms/passes.h"
|
|
||||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Pass registration
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
#define GEN_PASS_REGISTRATION
|
#define GEN_PASS_REGISTRATION
|
||||||
#include "torch-mlir/Conversion/Passes.h.inc"
|
#include "torch-mlir/Conversion/Passes.h.inc"
|
||||||
|
|
|
@ -1,22 +0,0 @@
|
||||||
add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram
|
|
||||||
TorchConversionToMLProgram.cpp
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchConversionToMLProgram
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
TorchMLIRConversionPassIncGen
|
|
||||||
|
|
||||||
LINK_COMPONENTS
|
|
||||||
Core
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
MLIRIR
|
|
||||||
MLIRLinalgDialect
|
|
||||||
MLIRMLProgramDialect
|
|
||||||
MLIRMathDialect
|
|
||||||
MLIRPass
|
|
||||||
TorchMLIRTorchDialect
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_mlir_target_includes(TorchMLIRTorchConversionToMLProgram)
|
|
|
@ -1,20 +0,0 @@
|
||||||
add_mlir_conversion_library(TorchMLIRTorchToArith
|
|
||||||
TorchToArith.cpp
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToArith
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
TorchMLIRConversionPassIncGen
|
|
||||||
|
|
||||||
LINK_COMPONENTS
|
|
||||||
Core
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
MLIRIR
|
|
||||||
MLIRPass
|
|
||||||
MLIRFuncDialect
|
|
||||||
TorchMLIRTorchDialect
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_mlir_target_includes(TorchMLIRTorchToArith)
|
|
|
@ -1,4 +1,5 @@
|
||||||
add_mlir_conversion_library(TorchMLIRTorchToLinalg
|
add_mlir_conversion_library(TorchMLIRTorchToLinalg
|
||||||
|
Passes.cpp
|
||||||
DataMovement.cpp
|
DataMovement.cpp
|
||||||
IndirectDataMovement.cpp
|
IndirectDataMovement.cpp
|
||||||
Linear.cpp
|
Linear.cpp
|
||||||
|
@ -7,25 +8,37 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg
|
||||||
Reduction.cpp
|
Reduction.cpp
|
||||||
TensorConstructors.cpp
|
TensorConstructors.cpp
|
||||||
TensorScalarInterop.cpp
|
TensorScalarInterop.cpp
|
||||||
|
TorchConversionToMLProgram.cpp
|
||||||
TorchToLinalg.cpp
|
TorchToLinalg.cpp
|
||||||
|
TorchToTMTensor.cpp
|
||||||
Uncategorized.cpp
|
Uncategorized.cpp
|
||||||
Utils.cpp
|
Utils.cpp
|
||||||
|
VerifyLinalgOnTensorsBackendContract.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
TorchMLIRConversionPassIncGen
|
TorchMLIRConversionLinalgPassIncGen
|
||||||
|
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
Core
|
Core
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
|
MLIRFuncDialect
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRLinalgDialect
|
MLIRLinalgDialect
|
||||||
MLIRMathDialect
|
MLIRMathDialect
|
||||||
|
MLIRMLProgramDialect
|
||||||
|
MLIRSCFDialect
|
||||||
|
MLIRTransforms
|
||||||
|
TorchMLIRConversionCommon
|
||||||
TorchMLIRTorchDialect
|
TorchMLIRTorchDialect
|
||||||
|
TorchMLIRTorchConversionDialect
|
||||||
|
TorchMLIRTorchConversionPasses
|
||||||
|
TorchMLIRTorchUtils
|
||||||
|
TorchMLIRTMTensorDialect
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_mlir_target_includes(TorchMLIRTorchToLinalg)
|
torch_mlir_target_includes(TorchMLIRTorchToLinalg)
|
||||||
|
|
|
@ -7,13 +7,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "./PassDetail.h"
|
||||||
#include "mlir/IR/TypeSupport.h"
|
|
||||||
#include "mlir/Support/LogicalResult.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
@ -21,7 +15,12 @@
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/TypeSupport.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
//===------------------------------------------------------------*- C++ -*-===//
|
//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
@ -7,17 +7,20 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
|
#ifndef TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H
|
||||||
#define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
|
#define TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace torch {
|
namespace torch {
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToArithPass();
|
|
||||||
}
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h.inc"
|
||||||
|
|
||||||
|
} // namespace torch
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H
|
|
@ -0,0 +1,75 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/Transforms/Passes.h"
|
||||||
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Conversion/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass registration
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h.inc"
|
||||||
|
} // end namespace
|
||||||
|
|
||||||
|
void mlir::torch::registerLinalgConversionPasses() {
|
||||||
|
::registerPasses();
|
||||||
|
mlir::PassPipelineRegistration<>(
|
||||||
|
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||||
|
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
|
||||||
|
"contract.",
|
||||||
|
createTorchBackendToLinalgOnTensorsBackendPipeline);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::torch::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||||
|
OpPassManager &pm) {
|
||||||
|
// Lower to linalg + guards which is the input to codegen backends.
|
||||||
|
// We do this first as it tends to involve pattern-matching against constants,
|
||||||
|
// (e.g. dimensions which must be constant in a ranked programming model)
|
||||||
|
// and those constants get somewhat obscured by TorchToArith.
|
||||||
|
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
||||||
|
pm.addPass(createConvertTorchConversionToMLProgramPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
|
||||||
|
|
||||||
|
// Clean up any non-canonical code introduced above..
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
|
// Resolve `dim` ops on tensors (which currently live in the `memref`
|
||||||
|
// dialect for some reason -- we don't have memrefs at this level).
|
||||||
|
pm.addNestedPass<func::FuncOp>(
|
||||||
|
memref::createResolveShapedTypeResultDimsPass());
|
||||||
|
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||||
|
|
||||||
|
// Finish the type conversion from `torch` types to the types of the
|
||||||
|
// linalg-on-tensors backend contract.
|
||||||
|
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(
|
||||||
|
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||||
|
|
||||||
|
// Verify that we have lowered to the form that linalg on tensors backends
|
||||||
|
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
||||||
|
// correct form.
|
||||||
|
pm.addPass(createVerifyLinalgOnTensorsBackendContractPass());
|
||||||
|
}
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
@ -82,7 +82,8 @@ public:
|
||||||
// temp = multiplier * currentSeed + incrementStep
|
// temp = multiplier * currentSeed + incrementStep
|
||||||
Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
|
Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
|
||||||
Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
|
Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
|
||||||
globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
|
globalVar =
|
||||||
|
rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
|
||||||
rewriter.create<ml_program::GlobalStoreOp>(
|
rewriter.create<ml_program::GlobalStoreOp>(
|
||||||
loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
|
loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
|
||||||
globalVar);
|
globalVar);
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "Utils.h"
|
#include "./Utils.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
|
|
|
@ -7,7 +7,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
@ -24,7 +25,6 @@
|
||||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
||||||
|
@ -33,7 +33,6 @@ using namespace mlir::torch;
|
||||||
using namespace mlir::torch::TorchConversion;
|
using namespace mlir::torch::TorchConversion;
|
||||||
using namespace TMTensor;
|
using namespace TMTensor;
|
||||||
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class VerifyLinalgOnTensorsBackendContractPass
|
class VerifyLinalgOnTensorsBackendContractPass
|
||||||
: public VerifyLinalgOnTensorsBackendContractBase<
|
: public VerifyLinalgOnTensorsBackendContractBase<
|
||||||
|
@ -96,7 +95,8 @@ class VerifyLinalgOnTensorsBackendContractPass
|
||||||
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
|
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
|
||||||
// doesn't unnecessarily spew out the entire module.
|
// doesn't unnecessarily spew out the entire module.
|
||||||
emitError(module.getLoc())
|
emitError(module.getLoc())
|
||||||
<< "Module does not conform to the linalg-on-tensors backend contract. "
|
<< "Module does not conform to the linalg-on-tensors backend "
|
||||||
|
"contract. "
|
||||||
"See dialect conversion legality information above.";
|
"See dialect conversion legality information above.";
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
|
@ -105,6 +105,6 @@ class VerifyLinalgOnTensorsBackendContractPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass() {
|
mlir::torch::createVerifyLinalgOnTensorsBackendContractPass() {
|
||||||
return std::make_unique<VerifyLinalgOnTensorsBackendContractPass>();
|
return std::make_unique<VerifyLinalgOnTensorsBackendContractPass>();
|
||||||
}
|
}
|
|
@ -1,22 +0,0 @@
|
||||||
add_mlir_conversion_library(TorchMLIRTorchToSCF
|
|
||||||
TorchToSCF.cpp
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToSCF
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
TorchMLIRConversionPassIncGen
|
|
||||||
|
|
||||||
LINK_COMPONENTS
|
|
||||||
Core
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
MLIRIR
|
|
||||||
MLIRPass
|
|
||||||
MLIRSCFDialect
|
|
||||||
MLIRFuncDialect
|
|
||||||
TorchMLIRTorchDialect
|
|
||||||
TorchMLIRTorchConversionDialect
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_mlir_target_includes(TorchMLIRTorchToSCF)
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
add_mlir_conversion_library(TorchMLIRTorchToStablehlo
|
add_mlir_conversion_library(TorchMLIRTorchToStablehlo
|
||||||
|
Passes.cpp
|
||||||
TorchToStablehlo.cpp
|
TorchToStablehlo.cpp
|
||||||
StablehloLegalizeUtils.cpp
|
StablehloLegalizeUtils.cpp
|
||||||
Basic.cpp
|
Basic.cpp
|
||||||
|
@ -7,21 +8,25 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
|
||||||
ViewLike.cpp
|
ViewLike.cpp
|
||||||
Reduction.cpp
|
Reduction.cpp
|
||||||
Pooling.cpp
|
Pooling.cpp
|
||||||
|
VerifyStablehloBackendContract.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
TorchMLIRConversionPassIncGen
|
TorchMLIRConversionStablehloPassIncGen
|
||||||
|
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
Core
|
Core
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRBufferTransforms
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRBufferTransforms
|
MLIRTransforms
|
||||||
StablehloOps
|
StablehloOps
|
||||||
|
TorchMLIRConversionCommon
|
||||||
|
TorchMLIRTorchConversionPasses
|
||||||
TorchMLIRTorchDialect
|
TorchMLIRTorchDialect
|
||||||
TorchMLIRConversionUtils
|
TorchMLIRConversionUtils
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
//===------------------------------------------------------------*- C++ -*-===//
|
//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
@ -7,17 +7,20 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H
|
||||||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
#define TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace torch {
|
namespace torch {
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
|
|
||||||
}
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h.inc"
|
||||||
|
|
||||||
|
} // namespace torch
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H
|
|
@ -0,0 +1,62 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Conversion/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
|
#include "transforms/passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass registration
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h.inc"
|
||||||
|
} // end namespace
|
||||||
|
|
||||||
|
void mlir::torch::registerStablehloConversionPasses() {
|
||||||
|
::registerPasses();
|
||||||
|
mlir::PassPipelineRegistration<StablehloBackendPipelineOptions>(
|
||||||
|
"torch-backend-to-stablehlo-backend-pipeline",
|
||||||
|
"Pipeline lowering torch backend contract to StableHLO backend "
|
||||||
|
"contract.",
|
||||||
|
createTorchBackendToStablehloBackendPipeline);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::torch::createTorchBackendToStablehloBackendPipeline(
|
||||||
|
OpPassManager &pm, const StablehloBackendPipelineOptions &options) {
|
||||||
|
// Generate Stablehlo ops.
|
||||||
|
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
|
||||||
|
options.enableStaticShape, options.enableI32Index));
|
||||||
|
// Lowering remained ops to Arith
|
||||||
|
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
||||||
|
|
||||||
|
// Clean up any non-canonical code introduced above..
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
|
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||||
|
|
||||||
|
// Finish the type conversion from `torch` types to the types of the
|
||||||
|
// StableHLO backend contract.
|
||||||
|
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(
|
||||||
|
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||||
|
|
||||||
|
// Verify that we have lowered to Stablehlo and Chlo ops.
|
||||||
|
pm.addPass(createVerifyStablehloBackendContractPass());
|
||||||
|
}
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "stablehlo/dialect/StablehloOps.h"
|
#include "stablehlo/dialect/StablehloOps.h"
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -6,8 +6,9 @@
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
#include "PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
@ -18,11 +19,9 @@
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "stablehlo/dialect/ChloOps.h"
|
#include "stablehlo/dialect/ChloOps.h"
|
||||||
#include "stablehlo/dialect/StablehloOps.h"
|
#include "stablehlo/dialect/StablehloOps.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::TorchConversion;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class VerifyStablehloBackendContractPass
|
class VerifyStablehloBackendContractPass
|
||||||
|
@ -45,7 +44,8 @@ class VerifyStablehloBackendContractPass
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
|
|
||||||
// Structural operations.
|
// Structural operations.
|
||||||
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(opHasLegalTypes);
|
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
|
||||||
|
opHasLegalTypes);
|
||||||
// Shape operations.
|
// Shape operations.
|
||||||
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
|
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
|
||||||
|
|
||||||
|
@ -58,7 +58,6 @@ class VerifyStablehloBackendContractPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() {
|
mlir::torch::createVerifyStablehloBackendContractPass() {
|
||||||
return std::make_unique<VerifyStablehloBackendContractPass>();
|
return std::make_unique<VerifyStablehloBackendContractPass>();
|
||||||
}
|
}
|
||||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
|
|
@ -7,9 +7,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
add_mlir_conversion_library(TorchMLIRTorchToTMTensor
|
|
||||||
TorchToTMTensor.cpp
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTMTensor
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
TorchMLIRConversionPassIncGen
|
|
||||||
|
|
||||||
LINK_COMPONENTS
|
|
||||||
Core
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
MLIRIR
|
|
||||||
MLIRPass
|
|
||||||
MLIRLinalgDialect
|
|
||||||
MLIRMathDialect
|
|
||||||
TorchMLIRTorchDialect
|
|
||||||
TorchMLIRTMTensorDialect
|
|
||||||
TorchMLIRTorchUtils
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_mlir_target_includes(TorchMLIRTorchToTMTensor)
|
|
|
@ -1,13 +1,15 @@
|
||||||
add_mlir_conversion_library(TorchMLIRTorchToTosa
|
add_mlir_conversion_library(TorchMLIRTorchToTosa
|
||||||
|
Passes.cpp
|
||||||
TorchToTosa.cpp
|
TorchToTosa.cpp
|
||||||
TosaLegalizeUtils.cpp
|
TosaLegalizeUtils.cpp
|
||||||
TosaLegalizeCommon.cpp
|
TosaLegalizeCommon.cpp
|
||||||
|
VerifyTosaBackendContract.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
TorchMLIRConversionPassIncGen
|
TorchMLIRConversionTosaPassIncGen
|
||||||
|
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
Core
|
Core
|
||||||
|
@ -16,6 +18,8 @@ add_mlir_conversion_library(TorchMLIRTorchToTosa
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRTosaDialect
|
MLIRTosaDialect
|
||||||
|
MLIRTransforms
|
||||||
|
TorchMLIRTorchConversionPasses
|
||||||
TorchMLIRConversionUtils
|
TorchMLIRConversionUtils
|
||||||
TorchMLIRTorchDialect
|
TorchMLIRTorchDialect
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
//===------------------------------------------------------------*- C++ -*-===//
|
//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
@ -7,16 +7,20 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
|
#ifndef TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H
|
||||||
#define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
|
#define TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace torch {
|
namespace torch {
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
|
|
||||||
}
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "torch-mlir/Conversion/TorchToTosa/Passes.h.inc"
|
||||||
|
|
||||||
|
} // namespace torch
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H
|
|
@ -0,0 +1,61 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir/Conversion/TorchToTosa/Passes.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::tosa;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass registration
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "torch-mlir/Conversion/TorchToTosa/Passes.h.inc"
|
||||||
|
} // end namespace
|
||||||
|
|
||||||
|
void mlir::torch::registerTosaConversionPasses() {
|
||||||
|
::registerPasses();
|
||||||
|
mlir::PassPipelineRegistration<>(
|
||||||
|
"torch-backend-to-tosa-backend-pipeline",
|
||||||
|
"Pipeline lowering torch backend contract to TOSA backend "
|
||||||
|
"contract.",
|
||||||
|
createTorchBackendToTosaBackendPipeline);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::torch::createTorchBackendToTosaBackendPipeline(OpPassManager &pm) {
|
||||||
|
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
|
||||||
|
// Perform rank broadcasting so TosaToLinalg pass works
|
||||||
|
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
|
||||||
|
|
||||||
|
// Clean up any non-canonical code introduced above..
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
|
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||||
|
|
||||||
|
// Finish the type conversion from `torch` types to the types of the
|
||||||
|
// TOSA backend contract.
|
||||||
|
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(
|
||||||
|
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||||
|
|
||||||
|
// Verify that we have lowered to the form that TOSA backends
|
||||||
|
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
||||||
|
// correct form.
|
||||||
|
pm.addPass(createVerifyTosaBackendContractPass());
|
||||||
|
}
|
|
@ -7,12 +7,12 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
#include "torch-mlir/Conversion/TorchToTosa/Passes.h"
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
|
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
|
|
@ -7,7 +7,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "./PassDetail.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToTosa/Passes.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
@ -16,11 +17,9 @@
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::TorchConversion;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class VerifyTosaBackendContractPass
|
class VerifyTosaBackendContractPass
|
||||||
|
@ -62,6 +61,6 @@ class VerifyTosaBackendContractPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() {
|
mlir::torch::createVerifyTosaBackendContractPass() {
|
||||||
return std::make_unique<VerifyTosaBackendContractPass>();
|
return std::make_unique<VerifyTosaBackendContractPass>();
|
||||||
}
|
}
|
|
@ -11,8 +11,17 @@
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass registration
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void mlir::torch::registerTorchPasses() {
|
void mlir::torch::registerTorchPasses() {
|
||||||
mlir::torch::registerPasses();
|
::registerPasses();
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torchscript-module-to-torch-backend-pipeline",
|
"torchscript-module-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
||||||
|
|
|
@ -1,20 +1,10 @@
|
||||||
set(LinkedLibs
|
set(LinkedLibs
|
||||||
MLIRFuncTransforms
|
MLIRFuncTransforms
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRLinalgTransforms
|
|
||||||
MLIRMemRefTransforms
|
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRTosaTransforms
|
|
||||||
MLIRVectorTransforms
|
|
||||||
TorchMLIRTorchConversionDialect
|
TorchMLIRTorchConversionDialect
|
||||||
TorchMLIRTorchConversionToMLProgram
|
|
||||||
TorchMLIRTorchDialect
|
TorchMLIRTorchDialect
|
||||||
TorchMLIRTorchPasses
|
TorchMLIRTorchPasses
|
||||||
TorchMLIRTorchToArith
|
|
||||||
TorchMLIRTorchToLinalg
|
|
||||||
TorchMLIRTorchToSCF
|
|
||||||
TorchMLIRTorchToTMTensor
|
|
||||||
TorchMLIRTorchToTosa
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
|
@ -27,9 +17,6 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
ConvertCustomQuantOp.cpp
|
ConvertCustomQuantOp.cpp
|
||||||
UnpackQuantTensor.cpp
|
UnpackQuantTensor.cpp
|
||||||
VerifyLinalgOnTensorsBackendContract.cpp
|
|
||||||
VerifyTosaBackendContract.cpp
|
|
||||||
VerifyStablehloBackendContract.cpp
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
||||||
|
|
|
@ -8,27 +8,10 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
#include "mlir/Conversion/Passes.h"
|
|
||||||
#include "mlir/Dialect/Func/Transforms/Passes.h"
|
|
||||||
#include "mlir/Dialect/Linalg/Passes.h"
|
|
||||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
|
||||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
|
||||||
#endif
|
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::tosa;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pass registration
|
// Pass registration
|
||||||
|
@ -39,111 +22,4 @@ namespace reg {
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc"
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc"
|
||||||
} // end namespace reg
|
} // end namespace reg
|
||||||
|
|
||||||
void mlir::torch::registerTorchConversionPasses() {
|
void mlir::torch::registerTorchConversionPasses() { reg::registerPasses(); }
|
||||||
reg::registerPasses();
|
|
||||||
mlir::PassPipelineRegistration<>(
|
|
||||||
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
|
||||||
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
|
|
||||||
"contract.",
|
|
||||||
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
|
|
||||||
|
|
||||||
mlir::PassPipelineRegistration<>(
|
|
||||||
"torch-backend-to-tosa-backend-pipeline",
|
|
||||||
"Pipeline lowering torch backend contract to TOSA backend "
|
|
||||||
"contract.",
|
|
||||||
TorchConversion::createTorchBackendToTosaBackendPipeline);
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
mlir::PassPipelineRegistration<
|
|
||||||
TorchConversion::StablehloBackendPipelineOptions>(
|
|
||||||
"torch-backend-to-stablehlo-backend-pipeline",
|
|
||||||
"Pipeline lowering torch backend contract to StableHLO backend "
|
|
||||||
"contract.",
|
|
||||||
TorchConversion::createTorchBackendToStablehloBackendPipeline);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
|
||||||
OpPassManager &pm) {
|
|
||||||
// Lower to linalg + guards which is the input to codegen backends.
|
|
||||||
// We do this first as it tends to involve pattern-matching against constants,
|
|
||||||
// (e.g. dimensions which must be constant in a ranked programming model)
|
|
||||||
// and those constants get somewhat obscured by TorchToArith.
|
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
|
||||||
pm.addPass(createConvertTorchConversionToMLProgramPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
|
|
||||||
|
|
||||||
// Clean up any non-canonical code introduced above..
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
|
||||||
// Resolve `dim` ops on tensors (which currently live in the `memref`
|
|
||||||
// dialect for some reason -- we don't have memrefs at this level).
|
|
||||||
pm.addNestedPass<func::FuncOp>(
|
|
||||||
memref::createResolveShapedTypeResultDimsPass());
|
|
||||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
|
||||||
|
|
||||||
// Finish the type conversion from `torch` types to the types of the
|
|
||||||
// linalg-on-tensors backend contract.
|
|
||||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(
|
|
||||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
|
||||||
|
|
||||||
// Verify that we have lowered to the form that linalg on tensors backends
|
|
||||||
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
|
||||||
// correct form.
|
|
||||||
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
|
|
||||||
}
|
|
||||||
|
|
||||||
void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
|
||||||
OpPassManager &pm) {
|
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
|
|
||||||
// Perform rank broadcasting so TosaToLinalg pass works
|
|
||||||
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
|
|
||||||
|
|
||||||
// Clean up any non-canonical code introduced above..
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
|
||||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
|
||||||
|
|
||||||
// Finish the type conversion from `torch` types to the types of the
|
|
||||||
// TOSA backend contract.
|
|
||||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(
|
|
||||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
|
||||||
|
|
||||||
// Verify that we have lowered to the form that TOSA backends
|
|
||||||
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
|
||||||
// correct form.
|
|
||||||
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
|
||||||
OpPassManager &pm,
|
|
||||||
const TorchConversion::StablehloBackendPipelineOptions &options) {
|
|
||||||
// Generate Stablehlo ops.
|
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
|
|
||||||
options.enableStaticShape, options.enableI32Index));
|
|
||||||
// Lowering remained ops to Arith
|
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
|
||||||
|
|
||||||
// Clean up any non-canonical code introduced above..
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
|
||||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
|
||||||
|
|
||||||
// Finish the type conversion from `torch` types to the types of the
|
|
||||||
// StableHLO backend contract.
|
|
||||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
|
||||||
pm.addNestedPass<func::FuncOp>(
|
|
||||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
|
||||||
|
|
||||||
// Verify that we have lowered to Stablehlo and Chlo ops.
|
|
||||||
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
|
@ -21,8 +21,17 @@
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
#include "torch-mlir/RefBackend/Passes.h"
|
#include "torch-mlir/RefBackend/Passes.h"
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_LINALG
|
||||||
|
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
#include "mhlo/transforms/passes.h"
|
#include "mhlo/transforms/passes.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_TOSA
|
||||||
|
#include "torch-mlir/Conversion/TorchToTosa/Passes.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||||
|
@ -41,11 +50,20 @@ void mlir::torch::registerAllPasses() {
|
||||||
mlir::torch::RefBackend::registerRefBackendPasses();
|
mlir::torch::RefBackend::registerRefBackendPasses();
|
||||||
mlir::torch::TMTensor::registerPasses();
|
mlir::torch::TMTensor::registerPasses();
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_LINALG
|
||||||
|
mlir::torch::registerLinalgConversionPasses();
|
||||||
|
#endif // TORCH_MLIR_ENABLE_LINALG
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
mlir::mhlo::registerSymbolicShapeOptimizationPass();
|
mlir::mhlo::registerSymbolicShapeOptimizationPass();
|
||||||
mlir::mhlo::registerStablehloLegalizeToHloPass();
|
mlir::mhlo::registerStablehloLegalizeToHloPass();
|
||||||
mlir::mhlo::registerChloLegalizeToHloPass();
|
mlir::mhlo::registerChloLegalizeToHloPass();
|
||||||
mlir::mhlo::registerHloLegalizeToLinalgPass();
|
mlir::mhlo::registerHloLegalizeToLinalgPass();
|
||||||
mlir::mhlo::registerTestUnfuseBatchNormPass();
|
mlir::mhlo::registerTestUnfuseBatchNormPass();
|
||||||
|
mlir::torch::registerStablehloConversionPasses();
|
||||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_TOSA
|
||||||
|
mlir::torch::registerTosaConversionPasses();
|
||||||
|
#endif // TORCH_MLIR_ENABLE_TOSA
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue