Add Custom Op example (#1391)

Co-authored-by: nithinsubbiah <nithinsubbiah@gmail.com>
custom-op-example
powderluv 2022-09-20 09:59:54 -07:00 committed by GitHub
parent 0e2e94d542
commit eccf145542
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 216 additions and 27 deletions

View File

@ -49,6 +49,20 @@ else()
set(ENV{TORCH_MLIR_ENABLE_LTC} 0)
endif()
option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON)
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
add_definitions(-DTORCH_MLIR_USE_INSTALLED_PYTORCH)
else()
set(TORCH_MLIR_USE_INSTALLED_PYTORCH OFF)
endif()
option(TORCH_MLIR_CUSTOM_OP_EXAMPLE "Builds custom op example" OFF)
if(NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
add_definitions(-DTORCH_MLIR_CUSTOM_OP_EXAMPLE)
set(TORCH_MLIR_CUSTOM_OP_EXAMPLE ON)
endif()
message(STATUS "TORCH_MLIR_CUSTOM_OP_EXAMPLE:" ${TORCH_MLIR_CUSTOM_OP_EXAMPLE})
torch_mlir_add_llvm_external_project(
torch-mlir-dialects
TORCH_MLIR_DIALECTS

View File

@ -22,11 +22,16 @@ if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
fi
TORCH_MLIR_EXT_MODULES="${TORCH_MLIR_EXT_MODULES:-""}"
include_custom_op_example="${INCLUDE_CUSTOM_OP:-OFF}"
if [ "$include_custom_op_example" == "ON" ]; then
ext_module="torch_mlir._torch_mlir_custom_op_example"
fi
if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
ext_module="${TORCH_MLIR_EXT_MODULES} "
ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES} "
fi
PYTHONPATH="${pypath}" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
--pytorch_op_extensions=${ext_module:-""} \
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}" \
--include_custom_op_example="${include_custom_op_example}"

View File

@ -22,13 +22,17 @@ if [ ! -z ${TORCH_MLIR_EXT_PYTHONPATH} ]; then
pypath="${pypath}:${TORCH_MLIR_EXT_PYTHONPATH}"
fi
TORCH_MLIR_EXT_MODULES="${TORCH_MLIR_EXT_MODULES:-""}"
ext_module="${ext_module:-""}"
include_custom_op_example="${INCLUDE_CUSTOM_OP:-OFF}"
if [ "$include_custom_op_example" == "ON" ]; then
ext_module="torch_mlir._torch_mlir_custom_op_example"
fi
if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
ext_module="${TORCH_MLIR_EXT_MODULES}"
ext_module="${ext_module},${TORCH_MLIR_EXT_MODULES}"
fi
PYTHONPATH="${pypath}" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
--torch_ir_include_dir="${torch_ir_include_dir}" \
--pytorch_op_extensions="${ext_module}" \
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"
--pytorch_op_extensions="${ext_module:-""}" \
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt" \
--include_custom_op_example="${include_custom_op_example}"

View File

@ -30,9 +30,13 @@ from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsT
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
import os
include_custom_op = False
if os.getenv('INCLUDE_CUSTOM_OP') == 'ON':
include_custom_op = True
# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()
register_all_tests(include_custom_op)
def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core']

View File

@ -9728,3 +9728,25 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
}];
}
def Torch_TorchMlirCustomOpExampleIdentityOp : Torch_Op<"_torch_mlir_custom_op_example.identity", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `_torch_mlir_custom_op_example::identity : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$t
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult TorchMlirCustomOpExampleIdentityOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void TorchMlirCustomOpExampleIdentityOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

View File

@ -18,6 +18,9 @@ set(linked_libs TorchMLIRTorchToLinalg
if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND linked_libs TorchMLIRTorchToMhlo)
endif()
if (TORCH_MLIR_CUSTOM_OP_EXAMPLE)
list(APPEND linked_libs TorchMLIRTorchToLinalgCustomOp)
endif()
add_mlir_library(TorchMLIRConversionPasses
Passes.cpp

View File

@ -29,3 +29,7 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg
)
torch_mlir_target_includes(TorchMLIRTorchToLinalg)
if(TORCH_MLIR_CUSTOM_OP_EXAMPLE)
add_subdirectory(CustomOp)
endif()

View File

@ -0,0 +1,21 @@
add_mlir_conversion_library(TorchMLIRTorchToLinalgCustomOp
CustomOpExample.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRLinalgDialect
MLIRMathDialect
TorchMLIRTorchDialect
)
torch_mlir_target_includes(TorchMLIRTorchToLinalgCustomOp)

View File

@ -0,0 +1,54 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "../../PassDetail.h"
#include "../PopulatePatterns.h"
#include "../Utils.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class ConvertCustomOpExample
: public OpConversionPattern<TorchMlirCustomOpExampleIdentityOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(TorchMlirCustomOpExampleIdentityOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter
) const override {
// Type checks.
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// Since the example op does nothing, we simply replace the uses of the
// return value with its argument, then remove the op.
rewriter.replaceOp(op, op->getOperands());
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populateCustomOpExamplePatternsAndLegality(
TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<TorchMlirCustomOpExampleIdentityOp>();
patterns.add<ConvertCustomOpExample>(typeConverter, context);
}

View File

@ -63,6 +63,11 @@ void populateIndirectDataMovementPatternsAndLegality(
void populateTensorConstructorsPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
#ifdef TORCH_MLIR_CUSTOM_OP_EXAMPLE
void populateCustomOpExamplePatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
#endif // TORCH_MLIR_CUSTOM_OP_EXAMPLE
} // namespace torch_to_linalg
} // namespace torch

View File

@ -61,7 +61,10 @@ public:
TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context);
#ifdef TORCH_MLIR_CUSTOM_OP_EXAMPLE
torch_to_linalg::populateCustomOpExamplePatternsAndLegality(
typeConverter, patterns, target);
#endif // TORCH_MLIR_CUSTOM_OP_EXAMPLE
torch_to_linalg::populateTensorScalarInteropPatternsAndLegality(
typeConverter, patterns, target);
torch_to_linalg::populateLinearPatternsAndLegality(typeConverter, patterns,

View File

@ -701,7 +701,11 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp,
AtenIndexTensorHackedTwinOp>(op)) {
AtenIndexTensorHackedTwinOp
#ifdef TORCH_MLIR_CUSTOM_OP_EXAMPLE
, TorchMlirCustomOpExampleIdentityOp
#endif // TORCH_MLIR_CUSTOM_OP_EXAMPLE
>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

View File

@ -7815,6 +7815,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn._torch_mlir_custom_op_example.identity\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
"}\n"
"";
// clang-format on

View File

@ -17,8 +17,6 @@ add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")
# PyTorch
################################################################################
option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON)
if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
# Source builds
set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO})
@ -65,6 +63,15 @@ declare_mlir_dialect_python_bindings(
DIALECT_NAME torch
)
if(TORCH_MLIR_CUSTOM_OP_EXAMPLE)
declare_mlir_python_sources(TorchMLIRPythonSources.CustomOp
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES_GLOB
_torch_mlir_custom_op_example/__init__.py
)
endif()
################################################################################
# Extensions
################################################################################
@ -111,7 +118,9 @@ add_subdirectory(torch_mlir/eager_mode)
# Required for running the update_torch_ods.sh and update_shape_lib.sh scripts.
################################################################################
# add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)
if(TORCH_MLIR_CUSTOM_OP_EXAMPLE)
add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)
endif()
################################################################################
# Generate packages and shared library
@ -165,4 +174,8 @@ if(TORCH_MLIR_ENABLE_LTC)
add_dependencies(TorchMLIRPythonModules reference_lazy_backend)
endif()
if(TORCH_MLIR_CUSTOM_OP_EXAMPLE)
add_dependencies(TorchMLIRPythonModules torch_mlir_custom_op_example)
endif()
add_subdirectory(test)

View File

@ -1,17 +1,18 @@
###########################################################################
# Setup PyTorch
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules")
include(TorchMLIRPyTorch)
TorchMLIRProbeForPyTorchInstall()
find_package(Torch 1.8 REQUIRED)
TorchMLIRConfigurePyTorch()
###########################################################################
# Python sources
declare_mlir_python_sources(TorchMLIRPythonSources.CustomOp
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES_GLOB
_torch_mlir_custom_op_example/__init__.py
)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
include(TorchMLIRPyTorch)
TorchMLIRProbeForPyTorchInstall()
set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../libtorch/share/cmake/Torch")
find_package(Torch 1.11 REQUIRED)
###########################################################################
# Library definition
###########################################################################
# C++ extension
include_directories(BEFORE

View File

@ -1192,6 +1192,9 @@ def atenlinalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[Lis
def atenfrobenius_normdim(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
def _torch_mlir_custom_op_exampleidentity(t: List[int]) -> List[int]:
return upstream_shape_functions.unary(t)
# ==============================================================================
# Shape library generator main().
# ==============================================================================
@ -1230,6 +1233,8 @@ def main(args):
for k, v in globals().items():
if "" not in k:
continue
if k == "_torch_mlir_custom_op_exampleidentity" and args.include_custom_op_example == "OFF":
continue
if not hasattr(v, "_not_present_in_registry"):
_verify_signature_matches_registry(v, registry)
# Add it to the compilation unit.
@ -1305,6 +1310,10 @@ def _create_argparse() -> argparse.ArgumentParser:
"--torch_transforms_cpp_dir",
required=True,
help="Directory containing the Torch transforms cpp files")
parser.add_argument(
"--include_custom_op_example",
type=str,
help="String value to denote if custom_op_example has to be included")
parser.add_argument(
"--pytorch_op_extensions",
type=str,

View File

@ -218,7 +218,7 @@ def emit_op(operator: JitOperator,
has_canonicalizer=has_canonicalizer)
def emit_ops(emitter_td: TextEmitter, registry: Registry):
def emit_ops(emitter_td: TextEmitter, registry: Registry, include_custom_op_example: str):
def emit(key, **kwargs):
emit_op(registry[key], emitter_td, **kwargs)
@ -644,6 +644,16 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"quantized::linear : (Tensor, __torch__.torch.classes.quantized.LinearPackedParamsBase, float, int) -> (Tensor)",
traits=["HasValueSemantics"])
# ==========================================================================
# `_torch_mlir_custom_op_example::` namespace.
#
# This is a demonstration of supporting an operation defined in a PyTorch
# extension.
# ==========================================================================
if include_custom_op_example == "ON":
emit("_torch_mlir_custom_op_example::identity : (Tensor) -> (Tensor)")
def dump_registered_ops(outfile: TextIO, registry: Registry):
for _, v in sorted(registry.by_unique_key.items()):
@ -668,7 +678,7 @@ def main(args: argparse.Namespace):
with open(td_path, "w") as f_td:
emitter_td = TextEmitter(f_td)
emitter_td.print(ODS_BANNER)
emit_ops(emitter_td, registry)
emit_ops(emitter_td, registry, args.include_custom_op_example)
def _create_argparse() -> argparse.ArgumentParser:
@ -680,6 +690,10 @@ def _create_argparse() -> argparse.ArgumentParser:
parser.add_argument(
"--debug_registry_dump",
help="File to dump the the PyTorch JIT operator registry into")
parser.add_argument(
"--include_custom_op_example",
type=str,
help="String value to denote if custom_op_example has to be included")
parser.add_argument(
"--pytorch_op_extensions",
type=str,

View File

@ -17,7 +17,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"MaxPool2dWithIndicesWith3dInputModule_basic",
}
def register_all_tests():
def register_all_tests(include_custom_op: bool):
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
# Side-effecting import statements.
from . import basic
@ -53,3 +53,5 @@ def register_all_tests():
from . import return_types
from . import control_flow
from . import stats
if include_custom_op:
from . import custom_op_example

View File

@ -2,6 +2,9 @@
set -euo pipefail
src_dir="$(realpath "$(dirname "$0")"/..)"
build_dir="${src_dir}/build"
export INCLUDE_CUSTOM_OP="${INCLUDE_CUSTOM_OP:-OFF}"
cd "$src_dir"