mirror of https://github.com/llvm/torch-mlir
Add Custom Op example (#1391)
Co-authored-by: nithinsubbiah <nithinsubbiah@gmail.com>custom-op-example
parent
0e2e94d542
commit
eccf145542
|
@ -49,6 +49,20 @@ else()
|
|||
set(ENV{TORCH_MLIR_ENABLE_LTC} 0)
|
||||
endif()
|
||||
|
||||
option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON)
|
||||
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||
add_definitions(-DTORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||
else()
|
||||
set(TORCH_MLIR_USE_INSTALLED_PYTORCH OFF)
|
||||
endif()
|
||||
|
||||
option(TORCH_MLIR_CUSTOM_OP_EXAMPLE "Builds custom op example" OFF)
|
||||
if(NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||
add_definitions(-DTORCH_MLIR_CUSTOM_OP_EXAMPLE)
|
||||
set(TORCH_MLIR_CUSTOM_OP_EXAMPLE ON)
|
||||
endif()
|
||||
message(STATUS "TORCH_MLIR_CUSTOM_OP_EXAMPLE:" ${TORCH_MLIR_CUSTOM_OP_EXAMPLE})
|
||||
|
||||
torch_mlir_add_llvm_external_project(
|
||||
torch-mlir-dialects
|
||||
TORCH_MLIR_DIALECTS
|
||||
|
|
|
@ -22,11 +22,16 @@ if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
|
|||
pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
|
||||
fi
|
||||
TORCH_MLIR_EXT_MODULES="${TORCH_MLIR_EXT_MODULES:-""}"
|
||||
include_custom_op_example="${INCLUDE_CUSTOM_OP:-OFF}"
|
||||
if [ "$include_custom_op_example" == "ON" ]; then
|
||||
ext_module="torch_mlir._torch_mlir_custom_op_example"
|
||||
fi
|
||||
if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
|
||||
ext_module="${TORCH_MLIR_EXT_MODULES} "
|
||||
ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES} "
|
||||
fi
|
||||
|
||||
PYTHONPATH="${pypath}" python \
|
||||
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
|
||||
--pytorch_op_extensions=${ext_module:-""} \
|
||||
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
|
||||
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}" \
|
||||
--include_custom_op_example="${include_custom_op_example}"
|
||||
|
|
|
@ -22,13 +22,17 @@ if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
|
|||
pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
|
||||
fi
|
||||
TORCH_MLIR_EXT_MODULES="${TORCH_MLIR_EXT_MODULES:-""}"
|
||||
ext_module="${ext_module:-""}"
|
||||
include_custom_op_example="${INCLUDE_CUSTOM_OP:-OFF}"
|
||||
if [ "$include_custom_op_example" == "ON" ]; then
|
||||
ext_module="torch_mlir._torch_mlir_custom_op_example"
|
||||
fi
|
||||
if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
|
||||
ext_module="${TORCH_MLIR_EXT_MODULES}"
|
||||
ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES}"
|
||||
fi
|
||||
|
||||
PYTHONPATH="${pypath}" python \
|
||||
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
|
||||
--torch_ir_include_dir="${torch_ir_include_dir}" \
|
||||
--pytorch_op_extensions="${ext_module}" \
|
||||
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"
|
||||
--pytorch_op_extensions="${ext_module:-""}" \
|
||||
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt" \
|
||||
--include_custom_op_example="${include_custom_op_example}"
|
||||
|
|
|
@ -30,9 +30,13 @@ from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsT
|
|||
|
||||
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
|
||||
|
||||
import os
|
||||
include_custom_op = False
|
||||
if os.getenv('INCLUDE_CUSTOM_OP') == 'ON':
|
||||
include_custom_op = True
|
||||
# Import tests to register them in the global registry.
|
||||
from torch_mlir_e2e_test.test_suite import register_all_tests
|
||||
register_all_tests()
|
||||
register_all_tests(include_custom_op)
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core']
|
||||
|
|
|
@ -9728,3 +9728,25 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_TorchMlirCustomOpExampleIdentityOp : Torch_Op<"_torch_mlir_custom_op_example.identity", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `_torch_mlir_custom_op_example::identity : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$t
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult TorchMlirCustomOpExampleIdentityOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void TorchMlirCustomOpExampleIdentityOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
|
|
@ -18,6 +18,9 @@ set(linked_libs TorchMLIRTorchToLinalg
|
|||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
list(APPEND linked_libs TorchMLIRTorchToMhlo)
|
||||
endif()
|
||||
if (TORCH_MLIR_CUSTOM_OP_EXAMPLE)
|
||||
list(APPEND linked_libs TorchMLIRTorchToLinalgCustomOp)
|
||||
endif()
|
||||
|
||||
add_mlir_library(TorchMLIRConversionPasses
|
||||
Passes.cpp
|
||||
|
|
|
@ -29,3 +29,7 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg
|
|||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToLinalg)
|
||||
|
||||
if(TORCH_MLIR_CUSTOM_OP_EXAMPLE)
|
||||
add_subdirectory(CustomOp)
|
||||
endif()
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
add_mlir_conversion_library(TorchMLIRTorchToLinalgCustomOp
|
||||
CustomOpExample.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRLinalgDialect
|
||||
MLIRMathDialect
|
||||
TorchMLIRTorchDialect
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToLinalgCustomOp)
|
|
@ -0,0 +1,54 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
|
||||
#include "../../PassDetail.h"
|
||||
#include "../PopulatePatterns.h"
|
||||
#include "../Utils.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
class ConvertCustomOpExample
|
||||
: public OpConversionPattern<TorchMlirCustomOpExampleIdentityOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(TorchMlirCustomOpExampleIdentityOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter
|
||||
) const override {
|
||||
// Type checks.
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
// Since the example op does nothing, we simply replace the uses of the
|
||||
// return value with its argument, then remove the op.
|
||||
rewriter.replaceOp(op, op->getOperands());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populateCustomOpExamplePatternsAndLegality(
|
||||
TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<TorchMlirCustomOpExampleIdentityOp>();
|
||||
patterns.add<ConvertCustomOpExample>(typeConverter, context);
|
||||
}
|
|
@ -63,6 +63,11 @@ void populateIndirectDataMovementPatternsAndLegality(
|
|||
void populateTensorConstructorsPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
#ifdef TORCH_MLIR_CUSTOM_OP_EXAMPLE
|
||||
void populateCustomOpExamplePatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
#endif // TORCH_MLIR_CUSTOM_OP_EXAMPLE
|
||||
|
||||
} // namespace torch_to_linalg
|
||||
} // namespace torch
|
||||
|
|
|
@ -61,7 +61,10 @@ public:
|
|||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
#ifdef TORCH_MLIR_CUSTOM_OP_EXAMPLE
|
||||
torch_to_linalg::populateCustomOpExamplePatternsAndLegality(
|
||||
typeConverter, patterns, target);
|
||||
#endif // TORCH_MLIR_CUSTOM_OP_EXAMPLE
|
||||
torch_to_linalg::populateTensorScalarInteropPatternsAndLegality(
|
||||
typeConverter, patterns, target);
|
||||
torch_to_linalg::populateLinearPatternsAndLegality(typeConverter, patterns,
|
||||
|
|
|
@ -701,7 +701,11 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
|
||||
AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp,
|
||||
AtenIndexTensorHackedTwinOp>(op)) {
|
||||
AtenIndexTensorHackedTwinOp
|
||||
#ifdef TORCH_MLIR_CUSTOM_OP_EXAMPLE
|
||||
, TorchMlirCustomOpExampleIdentityOp
|
||||
#endif // TORCH_MLIR_CUSTOM_OP_EXAMPLE
|
||||
>(op)) {
|
||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||
}
|
||||
|
||||
|
|
|
@ -7815,6 +7815,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
|||
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn._torch_mlir_custom_op_example.identity\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"";
|
||||
// clang-format on
|
||||
|
|
|
@ -17,8 +17,6 @@ add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")
|
|||
# PyTorch
|
||||
################################################################################
|
||||
|
||||
option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON)
|
||||
|
||||
if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||
# Source builds
|
||||
set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO})
|
||||
|
@ -65,6 +63,15 @@ declare_mlir_dialect_python_bindings(
|
|||
DIALECT_NAME torch
|
||||
)
|
||||
|
||||
if(TORCH_MLIR_CUSTOM_OP_EXAMPLE)
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.CustomOp
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES_GLOB
|
||||
_torch_mlir_custom_op_example/__init__.py
|
||||
)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
# Extensions
|
||||
################################################################################
|
||||
|
@ -111,7 +118,9 @@ add_subdirectory(torch_mlir/eager_mode)
|
|||
# Required for running the update_torch_ods.sh and update_shape_lib.sh scripts.
|
||||
################################################################################
|
||||
|
||||
# add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)
|
||||
if(TORCH_MLIR_CUSTOM_OP_EXAMPLE)
|
||||
add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
# Generate packages and shared library
|
||||
|
@ -165,4 +174,8 @@ if(TORCH_MLIR_ENABLE_LTC)
|
|||
add_dependencies(TorchMLIRPythonModules reference_lazy_backend)
|
||||
endif()
|
||||
|
||||
if(TORCH_MLIR_CUSTOM_OP_EXAMPLE)
|
||||
add_dependencies(TorchMLIRPythonModules torch_mlir_custom_op_example)
|
||||
endif()
|
||||
|
||||
add_subdirectory(test)
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
###########################################################################
|
||||
# Setup PyTorch
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules")
|
||||
include(TorchMLIRPyTorch)
|
||||
TorchMLIRProbeForPyTorchInstall()
|
||||
find_package(Torch 1.8 REQUIRED)
|
||||
TorchMLIRConfigurePyTorch()
|
||||
###########################################################################
|
||||
|
||||
# Python sources
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.CustomOp
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES_GLOB
|
||||
_torch_mlir_custom_op_example/__init__.py
|
||||
)
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
|
||||
include(TorchMLIRPyTorch)
|
||||
|
||||
TorchMLIRProbeForPyTorchInstall()
|
||||
set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../libtorch/share/cmake/Torch")
|
||||
|
||||
find_package(Torch 1.11 REQUIRED)
|
||||
|
||||
###########################################################################
|
||||
# Library definition
|
||||
###########################################################################
|
||||
|
||||
# C++ extension
|
||||
include_directories(BEFORE
|
||||
|
|
|
@ -1192,6 +1192,9 @@ def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[Lis
|
|||
def aten〇frobenius_norm〇dim(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
|
||||
|
||||
def _torch_mlir_custom_op_example〇identity(t: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(t)
|
||||
|
||||
# ==============================================================================
|
||||
# Shape library generator main().
|
||||
# ==============================================================================
|
||||
|
@ -1230,6 +1233,8 @@ def main(args):
|
|||
for k, v in globals().items():
|
||||
if "〇" not in k:
|
||||
continue
|
||||
if k == "_torch_mlir_custom_op_example〇identity" and args.include_custom_op_example == "OFF":
|
||||
continue
|
||||
if not hasattr(v, "_not_present_in_registry"):
|
||||
_verify_signature_matches_registry(v, registry)
|
||||
# Add it to the compilation unit.
|
||||
|
@ -1305,6 +1310,10 @@ def _create_argparse() -> argparse.ArgumentParser:
|
|||
"--torch_transforms_cpp_dir",
|
||||
required=True,
|
||||
help="Directory containing the Torch transforms cpp files")
|
||||
parser.add_argument(
|
||||
"--include_custom_op_example",
|
||||
type=str,
|
||||
help="String value to denote if custom_op_example has to be included")
|
||||
parser.add_argument(
|
||||
"--pytorch_op_extensions",
|
||||
type=str,
|
||||
|
|
|
@ -218,7 +218,7 @@ def emit_op(operator: JitOperator,
|
|||
has_canonicalizer=has_canonicalizer)
|
||||
|
||||
|
||||
def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||
def emit_ops(emitter_td: TextEmitter, registry: Registry, include_custom_op_example: str):
|
||||
def emit(key, **kwargs):
|
||||
emit_op(registry[key], emitter_td, **kwargs)
|
||||
|
||||
|
@ -644,6 +644,16 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
|
||||
traits=["HasValueSemantics"])
|
||||
|
||||
# ==========================================================================
|
||||
# `_torch_mlir_custom_op_example::` namespace.
|
||||
#
|
||||
# This is a demonstration of supporting an operation defined in a PyTorch
|
||||
# extension.
|
||||
# ==========================================================================
|
||||
|
||||
if include_custom_op_example == "ON":
|
||||
emit("_torch_mlir_custom_op_example::identity : (Tensor) -> (Tensor)")
|
||||
|
||||
|
||||
def dump_registered_ops(outfile: TextIO, registry: Registry):
|
||||
for _, v in sorted(registry.by_unique_key.items()):
|
||||
|
@ -668,7 +678,7 @@ def main(args: argparse.Namespace):
|
|||
with open(td_path, "w") as f_td:
|
||||
emitter_td = TextEmitter(f_td)
|
||||
emitter_td.print(ODS_BANNER)
|
||||
emit_ops(emitter_td, registry)
|
||||
emit_ops(emitter_td, registry, args.include_custom_op_example)
|
||||
|
||||
|
||||
def _create_argparse() -> argparse.ArgumentParser:
|
||||
|
@ -680,6 +690,10 @@ def _create_argparse() -> argparse.ArgumentParser:
|
|||
parser.add_argument(
|
||||
"--debug_registry_dump",
|
||||
help="File to dump the the PyTorch JIT operator registry into")
|
||||
parser.add_argument(
|
||||
"--include_custom_op_example",
|
||||
type=str,
|
||||
help="String value to denote if custom_op_example has to be included")
|
||||
parser.add_argument(
|
||||
"--pytorch_op_extensions",
|
||||
type=str,
|
||||
|
|
|
@ -17,7 +17,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
"MaxPool2dWithIndicesWith3dInputModule_basic",
|
||||
}
|
||||
|
||||
def register_all_tests():
|
||||
def register_all_tests(include_custom_op: bool):
|
||||
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
|
||||
# Side-effecting import statements.
|
||||
from . import basic
|
||||
|
@ -53,3 +53,5 @@ def register_all_tests():
|
|||
from . import return_types
|
||||
from . import control_flow
|
||||
from . import stats
|
||||
if include_custom_op:
|
||||
from . import custom_op_example
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
set -euo pipefail
|
||||
|
||||
src_dir="$(realpath "$(dirname "$0")"/..)"
|
||||
build_dir="${src_dir}/build"
|
||||
|
||||
export INCLUDE_CUSTOM_OP="${INCLUDE_CUSTOM_OP:-OFF}"
|
||||
|
||||
cd "$src_dir"
|
||||
|
||||
|
|
Loading…
Reference in New Issue