mirror of https://github.com/llvm/torch-mlir
[torch-mlir earthmoving (1/N)] C/C++ code movement.
This creates the `external/torch-mlir` directory as an LLVM_EXTERNAL_PROJECTS-compatible project (analogous to `iree-dialects`) and completes movement/rename of all pure MLIR C/C++ compiler code into there. The next step will be to move all the Python code / code that links/includes PyTorch C++ code (which currently lives in `frontends/pytorch`) into a subdirectory here. I call this "earthmoving" because it is mostly mechanical changes and renames. As a quick summary (we can change this down the road easily) - C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch` - CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet` - preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_` - CMake `NPCOMPFoo -> TorchMLIRFoo` The goal of this is to create a standalone project creating a center of mass for entry into the MLIR ecosystem from PyTorch, suitable in scope for eventual inclusion/ownership in PyTorch. The idea is that `external/torch-mlir` will some day be pulled out into its own repository, and then npcomp will simply pull it in as a submodule. Layering-wise, what lives in `torch-mlir` lowers code from PyTorch (currently TorchScript, but TorchFX or pytorch/xla-style tracing are possible extensions) down to what we have been calling the "Torch backend contract" which is cleaned up IR (inlining, simplifcation, conversion to value tensors, ...) entirely in the `torch` dialect. This is the branching off point for further lowering, of which npcomp takes one opinion (outside `torch-mlir` of course!), namely the `TorchConversion` dialect/transforms which lower to IR suitable for IREE and other linalg-on-tensors based lower-level compilers. Summary of changes: - move `{include,lib,test}/Dialect/Torch` into `torch-mlir` - move relevant parts of CAPI into `torch-mlir`. - leave a few things related to the `torch-mlir` Python build commented out, which should be resolved in a subsequent change.pull/305/head
parent
28762699b3
commit
28a7738189
|
@ -128,9 +128,10 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||||
"suffix = '${PYTHON_MODULE_SUFFIX}', "
|
"suffix = '${PYTHON_MODULE_SUFFIX}', "
|
||||||
"extension = '${PYTHON_MODULE_EXTENSION}")
|
"extension = '${PYTHON_MODULE_EXTENSION}")
|
||||||
|
|
||||||
# Include the iree-dialects external project.
|
# Include LLVM_EXTERNAL_PROJECTS.
|
||||||
set(LLVM_EXTERNAL_PROJECTS "iree-dialects")
|
set(LLVM_EXTERNAL_PROJECTS "iree-dialects;torch-mlir")
|
||||||
set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects")
|
set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects")
|
||||||
|
set(LLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir")
|
||||||
|
|
||||||
# LLVM configuration.
|
# LLVM configuration.
|
||||||
message(STATUS "*** ADDING LLVM ***")
|
message(STATUS "*** ADDING LLVM ***")
|
||||||
|
@ -183,6 +184,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
|
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include)
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include)
|
||||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include)
|
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include)
|
||||||
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir/include)
|
||||||
|
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/torch-mlir/include)
|
||||||
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
||||||
add_definitions(${LLVM_DEFINITIONS})
|
add_definitions(${LLVM_DEFINITIONS})
|
||||||
set(NPCOMP_TABLEGEN_ARGS "")
|
set(NPCOMP_TABLEGEN_ARGS "")
|
||||||
|
|
|
@ -21,6 +21,7 @@ cd $td/build
|
||||||
|
|
||||||
ninja
|
ninja
|
||||||
ninja check-npcomp
|
ninja check-npcomp
|
||||||
|
ninja check-torch-mlir
|
||||||
ninja check-frontends-pytorch
|
ninja check-frontends-pytorch
|
||||||
|
|
||||||
echo
|
echo
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||||
|
message(FATAL_ERROR
|
||||||
|
"This project is intended to be built as part of LLVM via "
|
||||||
|
"-DLLVM_EXTERNAL_PROJECTS=torch-mlir "
|
||||||
|
"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
|
||||||
|
|
||||||
|
set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
|
set(TORCH_MLIR_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
|
||||||
|
message(STATUS "Building torch-mlir project at ${TORCH_MLIR_SOURCE_DIR} (into ${TORCH_MLIR_BINARY_DIR})")
|
||||||
|
|
||||||
|
# TODO: Fix this upstream so that global include directories are not needed.
|
||||||
|
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
|
||||||
|
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
|
||||||
|
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
|
||||||
|
|
||||||
|
# TODO: Needed for tablegen. Remove.
|
||||||
|
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
|
||||||
|
include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR})
|
||||||
|
include_directories(SYSTEM ${TORCH_MLIR_SOURCE_DIR}/include)
|
||||||
|
|
||||||
|
function(torch_mlir_target_includes target)
|
||||||
|
set(_dirs
|
||||||
|
$<BUILD_INTERFACE:${MLIR_INCLUDE_DIR}>
|
||||||
|
$<BUILD_INTERFACE:${MLIR_GENERATED_INCLUDE_DIR}>
|
||||||
|
$<BUILD_INTERFACE:${TORCH_MLIR_SOURCE_DIR}/include>
|
||||||
|
$<BUILD_INTERFACE:${TORCH_MLIR_BINARY_DIR}/include>
|
||||||
|
)
|
||||||
|
# In LLVM parlance, the actual target may just be an interface and may not
|
||||||
|
# be responsible for actually compiling anything. The corresponding obj.
|
||||||
|
# target, when present, is just used for compilation and does not
|
||||||
|
# contribute to the interface properties.
|
||||||
|
# TODO: Normalize this upstream.
|
||||||
|
target_include_directories(${target} PUBLIC ${_dirs})
|
||||||
|
if(TARGET obj.${target})
|
||||||
|
target_include_directories(obj.${target} PRIVATE ${_dirs})
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
# Configure CMake and tablegen.
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
|
||||||
|
set(MLIR_TABLEGEN_EXE mlir-tblgen)
|
||||||
|
|
||||||
|
include(TableGen)
|
||||||
|
include(AddLLVM)
|
||||||
|
include(AddMLIR)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Setup python.
|
||||||
|
# TODO: Make one upstream macro to do this.
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||||
|
include(MLIRDetectPythonEnv)
|
||||||
|
mlir_detect_pybind11_install()
|
||||||
|
find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION}
|
||||||
|
COMPONENTS Interpreter Development NumPy REQUIRED)
|
||||||
|
find_package(pybind11 2.6 CONFIG REQUIRED)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_subdirectory(include)
|
||||||
|
add_subdirectory(lib)
|
||||||
|
add_subdirectory(test)
|
||||||
|
add_subdirectory(tools)
|
||||||
|
|
||||||
|
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||||
|
#XXX: Enable
|
||||||
|
#add_subdirectory(python)
|
||||||
|
endif()
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(torch-mlir)
|
|
@ -0,0 +1,32 @@
|
||||||
|
/*===-- torch-mlir-c/Registration.h - Registration functions -----*- C -*-===*\
|
||||||
|
|* *|
|
||||||
|
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|
||||||
|
|* Exceptions. *|
|
||||||
|
|* See https://llvm.org/LICENSE.txt for license information. *|
|
||||||
|
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|
||||||
|
|* *|
|
||||||
|
\*===----------------------------------------------------------------------===*/
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_C_REGISTRATION_H
|
||||||
|
#define TORCHMLIR_C_REGISTRATION_H
|
||||||
|
|
||||||
|
#include "mlir-c/IR.h"
|
||||||
|
#include "mlir-c/Support.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/** Registers all dialects with a context.
|
||||||
|
* This is needed before creating IR for these Dialects.
|
||||||
|
*/
|
||||||
|
MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context);
|
||||||
|
|
||||||
|
/** Registers all passes for symbolic access with the global registry. */
|
||||||
|
MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses();
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_C_REGISTRATION_H
|
|
@ -1,4 +1,4 @@
|
||||||
//===-- npcomp-c/TorchTypes.h - C API for torch types -------------*- C -*-===//
|
//===-- torch-mlir-c/TorchTypes.h - C API for torch types ---------*- C -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||||
// Exceptions.
|
// Exceptions.
|
||||||
|
@ -7,8 +7,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_C_TORCHTYPES_H
|
#ifndef TORCHMLIR_C_TORCHTYPES_H
|
||||||
#define NPCOMP_C_TORCHTYPES_H
|
#define TORCHMLIR_C_TORCHTYPES_H
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
#include "mlir-c/Support.h"
|
#include "mlir-c/Support.h"
|
||||||
|
@ -22,32 +22,33 @@ extern "C" {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a torch.nn.Module type
|
/// Checks whether the given type is a torch.nn.Module type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNnModule(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.nn.Module type of the specified class.
|
/// Gets the !torch.nn.Module type of the specified class.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchNnModuleTypeGet(MlirContext context,
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
MlirStringRef className);
|
torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.optional type.
|
// torch.optional type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.optional<T> type
|
/// Checks whether the given type is a !torch.optional<T> type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchOptional(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.optional<T> type with subtype T.
|
/// Gets the !torch.optional<T> type with subtype T.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchOptionalTypeGet(MlirType containedType);
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
|
torchMlirTorchOptionalTypeGet(MlirType containedType);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.tuple<T1, T2, T3> type.
|
// torch.tuple<T1, T2, T3> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.tuple type
|
/// Checks whether the given type is a !torch.tuple type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchTuple(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchTuple(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.tuple type with contained types `containedTypes`.
|
/// Gets the !torch.tuple type with contained types `containedTypes`.
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
npcompTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
|
torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
|
||||||
MlirType const *containedTypes);
|
MlirType const *containedTypes);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -55,77 +56,78 @@ npcompTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.list<T> type
|
/// Checks whether the given type is a !torch.list<T> type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchList(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.list<T> type with contained T.
|
/// Gets the !torch.list<T> type with contained T.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchListTypeGet(MlirType containedType);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.Device type.
|
// torch.Device type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.Device type
|
/// Checks whether the given type is a !torch.Device type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDevice(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.Device type.
|
/// Gets the !torch.Device type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchDeviceTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.bool type.
|
// torch.bool type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.bool type
|
/// Checks whether the given type is a !torch.bool type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchBool(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.bool type.
|
/// Gets the !torch.bool type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchBoolTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.int type.
|
// torch.int type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.int type
|
/// Checks whether the given type is a !torch.int type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchInt(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.int type.
|
/// Gets the !torch.int type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchIntTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.float type.
|
// torch.float type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.float type
|
/// Checks whether the given type is a !torch.float type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchFloat(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.float type.
|
/// Gets the !torch.float type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchFloatTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.LinearParams type.
|
// torch.LinearParams type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.LinearParams type
|
/// Checks whether the given type is a !torch.LinearParams type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchLinearParams(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.LinearParams type.
|
/// Gets the !torch.LinearParams type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchLinearParamsTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
|
torchMlirTorchLinearParamsTypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.qint8 type.
|
// torch.qint8 type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.qint8 type
|
/// Checks whether the given type is a !torch.qint8 type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchQInt8(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.qint8 type.
|
/// Gets the !torch.qint8 type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchQInt8TypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.tensor type.
|
// torch.tensor type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.tensor type
|
/// Checks whether the given type is a !torch.tensor type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNonValueTensor(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);
|
||||||
|
|
||||||
/// Gets a !torch.tensor type.
|
/// Gets a !torch.tensor type.
|
||||||
///
|
///
|
||||||
|
@ -133,24 +135,25 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNonValueTensor(MlirType t);
|
||||||
/// information is present (and `numSizes` is ignored in that case). -
|
/// information is present (and `numSizes` is ignored in that case). -
|
||||||
/// `optionalDtype` is allowed to be null, meaning that no dtype
|
/// `optionalDtype` is allowed to be null, meaning that no dtype
|
||||||
/// information is present.
|
/// information is present.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchNonValueTensorTypeGet(
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGet(
|
||||||
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
|
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
|
||||||
MlirType optionalDtype);
|
MlirType optionalDtype);
|
||||||
|
|
||||||
/// Gets the !torch.tensor type with the least static information.
|
/// Gets the !torch.tensor type with the least static information.
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
||||||
|
MlirContext context);
|
||||||
|
|
||||||
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
|
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
npcompTorchNonValueTensorTypeGetFromShaped(MlirType type);
|
torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.vtensor type.
|
// torch.vtensor type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.vtensor type
|
/// Checks whether the given type is a !torch.vtensor type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchValueTensor(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);
|
||||||
|
|
||||||
/// Gets a !torch.vtensor type.
|
/// Gets a !torch.vtensor type.
|
||||||
///
|
///
|
||||||
|
@ -158,71 +161,71 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchValueTensor(MlirType t);
|
||||||
/// information is present (and `numSizes` is ignored in that case).
|
/// information is present (and `numSizes` is ignored in that case).
|
||||||
/// - `optionalDtype` is allowed to be null, meaning that no dtype
|
/// - `optionalDtype` is allowed to be null, meaning that no dtype
|
||||||
/// information is present.
|
/// information is present.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchValueTensorTypeGet(
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
|
||||||
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
|
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
|
||||||
MlirType optionalDtype);
|
MlirType optionalDtype);
|
||||||
|
|
||||||
/// Gets the !torch.tensor type with the least static information.
|
/// Gets the !torch.tensor type with the least static information.
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||||
|
|
||||||
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
|
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
npcompTorchValueTensorTypeGetFromShaped(MlirType type);
|
torchMlirTorchValueTensorTypeGetFromShaped(MlirType type);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.none type.
|
// !torch.none type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.none type
|
/// Checks whether the given type is a !torch.none type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNone(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.none type.
|
/// Gets the !torch.none type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchNoneTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.str type.
|
// !torch.str type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.str type
|
/// Checks whether the given type is a !torch.str type
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchString(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.str type.
|
/// Gets the !torch.str type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchStringTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.any type.
|
// !torch.any type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.any type.
|
/// Checks whether the given type is a !torch.any type.
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchAny(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.str type.
|
/// Gets the !torch.str type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchAnyTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.number type.
|
// !torch.number type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.number type.
|
/// Checks whether the given type is a !torch.number type.
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNumber(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.number type.
|
/// Gets the !torch.number type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchNumberTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.dict type.
|
// !torch.dict type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Checks whether the given type is a !torch.dict type.
|
/// Checks whether the given type is a !torch.dict type.
|
||||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDict(MlirType t);
|
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t);
|
||||||
|
|
||||||
/// Gets the !torch.dict type.
|
/// Gets the !torch.dict type.
|
||||||
MLIR_CAPI_EXPORTED MlirType npcompTorchDictTypeGet(MlirType keyType,
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType,
|
||||||
MlirType valueType);
|
MlirType valueType);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif // NPCOMP_C_TORCHTYPES_H
|
#endif // TORCHMLIR_C_TORCHTYPES_H
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(Dialect)
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(Torch)
|
|
@ -13,7 +13,7 @@ include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
def Torch_Dialect : Dialect {
|
def Torch_Dialect : Dialect {
|
||||||
let name = "torch";
|
let name = "torch";
|
||||||
let cppNamespace = "::mlir::NPCOMP::Torch";
|
let cppNamespace = "::mlir::torch::Torch";
|
||||||
let description = [{
|
let description = [{
|
||||||
Top-level dialect for interfacing PyTorch and MLIR.
|
Top-level dialect for interfacing PyTorch and MLIR.
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ def Torch_Dialect : Dialect {
|
||||||
|
|
||||||
This dialect also provides transforms that lower it to the
|
This dialect also provides transforms that lower it to the
|
||||||
"Torch backend contract", which is an IR form that we present to
|
"Torch backend contract", which is an IR form that we present to
|
||||||
later conversions, such as conversion to the npcomp backend contract.
|
later conversions.
|
||||||
The Torch backend contract significantly simplifies the IR representation
|
The Torch backend contract significantly simplifies the IR representation
|
||||||
and puts it in a form easier for later lowering to work on. Specifically:
|
and puts it in a form easier for later lowering to work on. Specifically:
|
||||||
- The TorchScript object graph has been flattened to a list of globals (see
|
- The TorchScript object graph has been flattened to a list of globals (see
|
||||||
|
@ -39,11 +39,12 @@ def Torch_Dialect : Dialect {
|
||||||
|
|
||||||
class TorchOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
|
class TorchOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
|
||||||
let trait = name;
|
let trait = name;
|
||||||
let cppNamespace = "::mlir::NPCOMP::Torch::OpTrait";
|
let cppNamespace = "::mlir::torch::Torch::OpTrait";
|
||||||
}
|
}
|
||||||
|
|
||||||
def HasValueSemantics : TorchOpTrait<"HasValueSemantics">;
|
def HasValueSemantics : TorchOpTrait<"HasValueSemantics">;
|
||||||
def IsTrailingUnderscoreInplaceVariant
|
def IsTrailingUnderscoreInplaceVariant
|
||||||
: TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
|
: TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
|
||||||
|
def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">;
|
||||||
|
|
||||||
#endif // TORCH_BASE
|
#endif // TORCH_BASE
|
|
@ -6,11 +6,11 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H
|
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
|
||||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H
|
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
|
||||||
|
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h.inc"
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H
|
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
|
|
@ -6,8 +6,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
|
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
|
||||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
|
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
@ -18,15 +18,14 @@
|
||||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTraits.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTraits.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
#include "npcomp/Interfaces/Traits.h"
|
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h.inc"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace torch {
|
||||||
namespace Torch {
|
namespace Torch {
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
@ -117,11 +116,11 @@ m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
|
||||||
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
|
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
|
||||||
Value tensor);
|
Value tensor);
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace NPCOMP
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
|
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::SlotOp> {
|
||||||
using SlotOp = ::mlir::NPCOMP::Torch::SlotOp;
|
using SlotOp = ::mlir::torch::Torch::SlotOp;
|
||||||
static SlotOp getEmptyKey() {
|
static SlotOp getEmptyKey() {
|
||||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||||
return SlotOp::getFromOpaquePointer(pointer);
|
return SlotOp::getFromOpaquePointer(pointer);
|
||||||
|
@ -136,8 +135,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
|
||||||
static bool isEqual(SlotOp lhs, SlotOp rhs) { return lhs == rhs; }
|
static bool isEqual(SlotOp lhs, SlotOp rhs) { return lhs == rhs; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> {
|
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::NnModuleOp> {
|
||||||
using NnModuleOp = ::mlir::NPCOMP::Torch::NnModuleOp;
|
using NnModuleOp = ::mlir::torch::Torch::NnModuleOp;
|
||||||
static NnModuleOp getEmptyKey() {
|
static NnModuleOp getEmptyKey() {
|
||||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||||
return NnModuleOp::getFromOpaquePointer(pointer);
|
return NnModuleOp::getFromOpaquePointer(pointer);
|
||||||
|
@ -152,8 +151,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> {
|
||||||
static bool isEqual(NnModuleOp lhs, NnModuleOp rhs) { return lhs == rhs; }
|
static bool isEqual(NnModuleOp lhs, NnModuleOp rhs) { return lhs == rhs; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> {
|
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::ClassTypeOp> {
|
||||||
using ClassTypeOp = ::mlir::NPCOMP::Torch::ClassTypeOp;
|
using ClassTypeOp = ::mlir::torch::Torch::ClassTypeOp;
|
||||||
static ClassTypeOp getEmptyKey() {
|
static ClassTypeOp getEmptyKey() {
|
||||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||||
return ClassTypeOp::getFromOpaquePointer(pointer);
|
return ClassTypeOp::getFromOpaquePointer(pointer);
|
||||||
|
@ -168,8 +167,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> {
|
||||||
static bool isEqual(ClassTypeOp lhs, ClassTypeOp rhs) { return lhs == rhs; }
|
static bool isEqual(ClassTypeOp lhs, ClassTypeOp rhs) { return lhs == rhs; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::GlobalSlotOp> {
|
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::GlobalSlotOp> {
|
||||||
using OpTy = ::mlir::NPCOMP::Torch::GlobalSlotOp;
|
using OpTy = ::mlir::torch::Torch::GlobalSlotOp;
|
||||||
static OpTy getEmptyKey() {
|
static OpTy getEmptyKey() {
|
||||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||||
return OpTy::getFromOpaquePointer(pointer);
|
return OpTy::getFromOpaquePointer(pointer);
|
||||||
|
@ -184,4 +183,4 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::GlobalSlotOp> {
|
||||||
static bool isEqual(OpTy lhs, OpTy rhs) { return lhs == rhs; }
|
static bool isEqual(OpTy lhs, OpTy rhs) { return lhs == rhs; }
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
|
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
|
|
@ -9,8 +9,7 @@
|
||||||
#ifndef TORCH_OPS
|
#ifndef TORCH_OPS
|
||||||
#define TORCH_OPS
|
#define TORCH_OPS
|
||||||
|
|
||||||
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
|
||||||
include "npcomp/Interfaces/Traits.td"
|
|
||||||
include "mlir/IR/OpAsmInterface.td"
|
include "mlir/IR/OpAsmInterface.td"
|
||||||
include "mlir/IR/SymbolInterfaces.td"
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
include "mlir/Interfaces/CastInterfaces.td"
|
include "mlir/Interfaces/CastInterfaces.td"
|
||||||
|
@ -22,9 +21,9 @@ class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
: Op<Torch_Dialect, mnemonic, traits> {
|
: Op<Torch_Dialect, mnemonic, traits> {
|
||||||
}
|
}
|
||||||
|
|
||||||
include "npcomp/Dialect/Torch/IR/GeneratedAtenOps.td"
|
include "torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td"
|
||||||
include "npcomp/Dialect/Torch/IR/GeneratedPrimOps.td"
|
include "torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td"
|
||||||
include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"
|
include "torch-mlir/Dialect/Torch/IR/GeneratedQuantizedOps.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TorchScript `torch.nn.Module` object instantiation ops.
|
// TorchScript `torch.nn.Module` object instantiation ops.
|
||||||
|
@ -32,7 +31,7 @@ include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"
|
||||||
|
|
||||||
def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
||||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
|
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
|
||||||
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
|
SingleBlockImplicitTerminator<"::mlir::torch::Torch::NnModuleTerminatorOp">]> {
|
||||||
let summary = "Constructs a torch.nn.Module";
|
let summary = "Constructs a torch.nn.Module";
|
||||||
let description = [{
|
let description = [{
|
||||||
This op is used to represent a torch.nn.Module when importing a
|
This op is used to represent a torch.nn.Module when importing a
|
||||||
|
@ -75,7 +74,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
||||||
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
|
HasParent<"::mlir::torch::Torch::NnModuleOp">]> {
|
||||||
let summary = "Implicit terminator for torch.nn_module";
|
let summary = "Implicit terminator for torch.nn_module";
|
||||||
|
|
||||||
let arguments = (ins);
|
let arguments = (ins);
|
||||||
|
@ -85,7 +84,7 @@ def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_SlotOp : Torch_Op<"slot", [
|
def Torch_SlotOp : Torch_Op<"slot", [
|
||||||
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
|
HasParent<"::mlir::torch::Torch::NnModuleOp">]> {
|
||||||
let summary = "Define the value of a slot of a torch.nn.Module";
|
let summary = "Define the value of a slot of a torch.nn.Module";
|
||||||
let description = [{
|
let description = [{
|
||||||
This op specifies that the initial value of the slot `name` of the
|
This op specifies that the initial value of the slot `name` of the
|
||||||
|
@ -107,7 +106,7 @@ def Torch_SlotOp : Torch_Op<"slot", [
|
||||||
|
|
||||||
def Torch_ClassTypeOp : Torch_Op<"class_type", [
|
def Torch_ClassTypeOp : Torch_Op<"class_type", [
|
||||||
Symbol,
|
Symbol,
|
||||||
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> {
|
SingleBlockImplicitTerminator<"::mlir::torch::Torch::ClassTypeTerminatorOp">]> {
|
||||||
let summary = "Constructs a torch.ClassType";
|
let summary = "Constructs a torch.ClassType";
|
||||||
let description = [{
|
let description = [{
|
||||||
Declares a class type. Class types are the types used to describe
|
Declares a class type. Class types are the types used to describe
|
||||||
|
@ -152,7 +151,7 @@ def Torch_ClassTypeOp : Torch_Op<"class_type", [
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
|
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
|
||||||
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> {
|
HasParent<"::mlir::torch::Torch::ClassTypeOp">]> {
|
||||||
let summary = "Implicit terminator for torch.class_type";
|
let summary = "Implicit terminator for torch.class_type";
|
||||||
|
|
||||||
let arguments = (ins);
|
let arguments = (ins);
|
||||||
|
@ -162,7 +161,7 @@ def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_MethodOp : Torch_Op<"method", [
|
def Torch_MethodOp : Torch_Op<"method", [
|
||||||
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">,
|
HasParent<"::mlir::torch::Torch::ClassTypeOp">,
|
||||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
||||||
]> {
|
]> {
|
||||||
let summary = "Declare a method of a torch.class_type";
|
let summary = "Declare a method of a torch.class_type";
|
||||||
|
@ -193,7 +192,7 @@ def Torch_MethodOp : Torch_Op<"method", [
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AttrOp : Torch_Op<"attr", [
|
def Torch_AttrOp : Torch_Op<"attr", [
|
||||||
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">
|
HasParent<"::mlir::torch::Torch::ClassTypeOp">
|
||||||
]> {
|
]> {
|
||||||
let summary = "Declare an attribute of a torch.class_type";
|
let summary = "Declare an attribute of a torch.class_type";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -231,7 +230,7 @@ def Torch_AttrOp : Torch_Op<"attr", [
|
||||||
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
||||||
Symbol,
|
Symbol,
|
||||||
IsolatedFromAbove,
|
IsolatedFromAbove,
|
||||||
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::GlobalSlotInitOp">
|
SingleBlockImplicitTerminator<"::mlir::torch::Torch::GlobalSlotInitOp">
|
||||||
]> {
|
]> {
|
||||||
let summary = "A slot with global storage";
|
let summary = "A slot with global storage";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -256,7 +255,7 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
||||||
|
|
||||||
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
|
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
|
||||||
Terminator,
|
Terminator,
|
||||||
HasParent<"::mlir::NPCOMP::Torch::GlobalSlotOp">]> {
|
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
|
||||||
let summary = "yield-like terminator for torch.global_slot initializer region";
|
let summary = "yield-like terminator for torch.global_slot initializer region";
|
||||||
let description = [{
|
let description = [{
|
||||||
The operand to this op becomes the initial value of the parent
|
The operand to this op becomes the initial value of the parent
|
||||||
|
@ -463,7 +462,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
|
||||||
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
|
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
|
||||||
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
|
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
|
||||||
Terminator,
|
Terminator,
|
||||||
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
|
HasParent<"::mlir::torch::Torch::PrimLoopOp">]> {
|
||||||
let summary = "yield-like terminator for torch.prim.Loop";
|
let summary = "yield-like terminator for torch.prim.Loop";
|
||||||
let description = [{
|
let description = [{
|
||||||
Does not correspond to any torch prim op directly (the way that they model
|
Does not correspond to any torch prim op directly (the way that they model
|
||||||
|
@ -512,7 +511,7 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
|
||||||
def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
|
def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
|
||||||
Terminator,
|
Terminator,
|
||||||
ReturnLike,
|
ReturnLike,
|
||||||
HasParent<"::mlir::NPCOMP::Torch::PrimIfOp">]> {
|
HasParent<"::mlir::torch::Torch::PrimIfOp">]> {
|
||||||
let summary = "yield-like terminator for torch.prim.If";
|
let summary = "yield-like terminator for torch.prim.If";
|
||||||
let description = [{
|
let description = [{
|
||||||
Does not correspond to any torch prim op directly (the way that they model
|
Does not correspond to any torch prim op directly (the way that they model
|
|
@ -10,15 +10,14 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
|
||||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
|
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
|
||||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
|
|
||||||
|
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace torch {
|
||||||
namespace Torch {
|
namespace Torch {
|
||||||
namespace OpTrait {
|
namespace OpTrait {
|
||||||
|
|
||||||
|
@ -39,9 +38,16 @@ class IsTrailingUnderscoreInplaceVariant
|
||||||
: public ::mlir::OpTrait::TraitBase<ConcreteType,
|
: public ::mlir::OpTrait::TraitBase<ConcreteType,
|
||||||
IsTrailingUnderscoreInplaceVariant> {};
|
IsTrailingUnderscoreInplaceVariant> {};
|
||||||
|
|
||||||
|
// If a Torch op has this trait, it means that the op allows all of its operand
|
||||||
|
// and result types to be refined. That is, a less specific type is allowed to
|
||||||
|
// be replaced by a more specific type, according to PEP 483 subtyping rules.
|
||||||
|
template <typename ConcreteType>
|
||||||
|
class AllowsTypeRefinement
|
||||||
|
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
|
||||||
|
|
||||||
} // namespace OpTrait
|
} // namespace OpTrait
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace NPCOMP
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
|
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
|
|
@ -6,13 +6,13 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace torch {
|
||||||
namespace Torch {
|
namespace Torch {
|
||||||
|
|
||||||
class NonValueTensorType;
|
class NonValueTensorType;
|
||||||
|
@ -87,18 +87,18 @@ public:
|
||||||
ValueTensorType getWithValueSemantics() const;
|
ValueTensorType getWithValueSemantics() const;
|
||||||
};
|
};
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace NPCOMP
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#define GET_TYPEDEF_CLASSES
|
#define GET_TYPEDEF_CLASSES
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h.inc"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Inline definitions
|
// Inline definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace torch {
|
||||||
namespace Torch {
|
namespace Torch {
|
||||||
|
|
||||||
inline Optional<ArrayRef<int64_t>> BaseTensorType::getOptionalSizes() const {
|
inline Optional<ArrayRef<int64_t>> BaseTensorType::getOptionalSizes() const {
|
||||||
|
@ -122,7 +122,7 @@ inline bool BaseTensorType::classof(Type type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace NPCOMP
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
|
|
@ -9,7 +9,7 @@
|
||||||
#ifndef TORCH_TYPES
|
#ifndef TORCH_TYPES
|
||||||
#define TORCH_TYPES
|
#define TORCH_TYPES
|
||||||
|
|
||||||
include "npcomp/Dialect/Torch/IR/TorchBase.td"
|
include "torch-mlir/Dialect/Torch/IR/TorchBase.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Type defs
|
// Type defs
|
||||||
|
@ -83,7 +83,7 @@ class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
|
||||||
}
|
}
|
||||||
|
|
||||||
class AnyTorchTensorType<string name, string typeMnemonic>
|
class AnyTorchTensorType<string name, string typeMnemonic>
|
||||||
: Torch_Type<name, typeMnemonic, "::mlir::NPCOMP::Torch::BaseTensorType"> {
|
: Torch_Type<name, typeMnemonic, "::mlir::torch::Torch::BaseTensorType"> {
|
||||||
let summary = "Multi-dimensional array modeling Torch's Tensor type";
|
let summary = "Multi-dimensional array modeling Torch's Tensor type";
|
||||||
let description = [{
|
let description = [{
|
||||||
Syntax:
|
Syntax:
|
||||||
|
@ -107,8 +107,8 @@ class AnyTorchTensorType<string name, string typeMnemonic>
|
||||||
a strict separation between the value-semantic and potentially-mutating
|
a strict separation between the value-semantic and potentially-mutating
|
||||||
worlds, as one of our main jobs in the compiler is to isolate the mutating
|
worlds, as one of our main jobs in the compiler is to isolate the mutating
|
||||||
parts as much as possible because most lower levels of the compiler stack
|
parts as much as possible because most lower levels of the compiler stack
|
||||||
are expected to require value semantics. E.g. npcomp's backend contract
|
are expected to require value semantics. E.g. many backend contracts
|
||||||
is mostly in terms of linalg-on-tensor for compute-heavy ops, which require
|
mostly use linalg-on-tensor for compute-heavy ops, which require
|
||||||
a conversion to the builtin `tensor` type which has value semantics.
|
a conversion to the builtin `tensor` type which has value semantics.
|
||||||
Some notes about value semantics:
|
Some notes about value semantics:
|
||||||
- Using the type system described in PEP 483 (which TorchScript and other
|
- Using the type system described in PEP 483 (which TorchScript and other
|
||||||
|
@ -165,7 +165,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
|
||||||
|
|
||||||
Note: We avoid the C++ identifier `TensorType` to avoid C++ name ambiguities
|
Note: We avoid the C++ identifier `TensorType` to avoid C++ name ambiguities
|
||||||
with `mlir::TensorType`, since most code is transitively nested in
|
with `mlir::TensorType`, since most code is transitively nested in
|
||||||
both `::mlir` and `::mlir::NPCOMP::Torch` namespaces.
|
both `::mlir` and `::mlir::torch::Torch` namespaces.
|
||||||
|
|
||||||
Note: We use the Torch-aligned terminology "sizes" and "dtype" instead of
|
Note: We use the Torch-aligned terminology "sizes" and "dtype" instead of
|
||||||
the MLIR-aligned terminology "rank/shape" and "element type". The cheat
|
the MLIR-aligned terminology "rank/shape" and "element type". The cheat
|
||||||
|
@ -209,7 +209,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
|
||||||
}
|
}
|
||||||
|
|
||||||
def AnyTorchTensorType : Type<
|
def AnyTorchTensorType : Type<
|
||||||
CPred<"$_self.isa<::mlir::NPCOMP::Torch::BaseTensorType>()">,
|
CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">,
|
||||||
"Any Torch tensor type"
|
"Any Torch tensor type"
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
@ -317,7 +317,7 @@ def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
|
||||||
The whole LinearPackedParamsBase abstraction as it stands in PyTorch is a
|
The whole LinearPackedParamsBase abstraction as it stands in PyTorch is a
|
||||||
very library-call-y, runtime-y thing that embodies a number of assumptions
|
very library-call-y, runtime-y thing that embodies a number of assumptions
|
||||||
about the structure of how the program will be executed, which need not hold
|
about the structure of how the program will be executed, which need not hold
|
||||||
for npcomp backends.
|
for backends.
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -386,12 +386,12 @@ def TorchOptionalBoolType:
|
||||||
def TorchOptionalDeviceType:
|
def TorchOptionalDeviceType:
|
||||||
OptionalOf<Torch_DeviceType, "Optional torch device type">;
|
OptionalOf<Torch_DeviceType, "Optional torch device type">;
|
||||||
|
|
||||||
def IsListTypePred : CPred<"$_self.isa<::mlir::NPCOMP::Torch::ListType>()">;
|
def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
|
||||||
class ListOf<list<Type> allowedTypes, string descr> :
|
class ListOf<list<Type> allowedTypes, string descr> :
|
||||||
ContainerType<AnyTypeOf<allowedTypes>,
|
ContainerType<AnyTypeOf<allowedTypes>,
|
||||||
IsListTypePred,
|
IsListTypePred,
|
||||||
"$_self.cast<::mlir::NPCOMP::Torch::ListType>().getContainedType()",
|
"$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()",
|
||||||
descr, "::mlir::NPCOMP::Torch::ListType">;
|
descr, "::mlir::torch::Torch::ListType">;
|
||||||
|
|
||||||
def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
|
def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
|
||||||
def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">;
|
def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">;
|
|
@ -0,0 +1,5 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
add_public_tablegen_target(TorchMLIRTorchPassIncGen)
|
||||||
|
|
||||||
|
add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)
|
|
@ -6,15 +6,15 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||||
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||||
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace torch {
|
||||||
namespace Torch {
|
namespace Torch {
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
||||||
|
@ -58,7 +58,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
||||||
/// Registers all Torch transformation passes.
|
/// Registers all Torch transformation passes.
|
||||||
void registerTorchPasses();
|
void registerTorchPasses();
|
||||||
|
|
||||||
} // namespace NPCOMP
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
|
@ -6,14 +6,14 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_TORCH_PASSES
|
#ifndef TORCHMLIR_TORCH_PASSES
|
||||||
#define NPCOMP_TORCH_PASSES
|
#define TORCHMLIR_TORCH_PASSES
|
||||||
|
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
||||||
let summary = "Converts TorchScript object graphs to a globalized form";
|
let summary = "Converts TorchScript object graphs to a globalized form";
|
||||||
let constructor = "mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass()";
|
let constructor = "mlir::torch::Torch::createGlobalizeObjectGraphPass()";
|
||||||
let description = [{
|
let description = [{
|
||||||
This pass converts a subset of possible TorchScript modules into a
|
This pass converts a subset of possible TorchScript modules into a
|
||||||
more restrictive lower-level form that strips away the need to be
|
more restrictive lower-level form that strips away the need to be
|
||||||
|
@ -80,7 +80,7 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
||||||
- Rationale: This makes the representation of initial values simpler. Also
|
- Rationale: This makes the representation of initial values simpler. Also
|
||||||
as of Feb 2021, TorchScript won't import into this form except
|
as of Feb 2021, TorchScript won't import into this form except
|
||||||
potentially for Tensors (it has a bug related to the identity of
|
potentially for Tensors (it has a bug related to the identity of
|
||||||
objects). And for tensors, the npcomp IValue importer only supports a
|
objects). And for tensors, the IValue importer only supports a
|
||||||
very restricted form of aliasing anyway for other reasons. We are
|
very restricted form of aliasing anyway for other reasons. We are
|
||||||
waiting for signals that more general handling of object aliasing is
|
waiting for signals that more general handling of object aliasing is
|
||||||
important to devote the effort to it.
|
important to devote the effort to it.
|
||||||
|
@ -90,7 +90,7 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
||||||
def PrepareForGlobalizeObjectGraph
|
def PrepareForGlobalizeObjectGraph
|
||||||
: Pass<"torch-prepare-for-globalize-object-graph", "ModuleOp"> {
|
: Pass<"torch-prepare-for-globalize-object-graph", "ModuleOp"> {
|
||||||
let summary = "Lowering in preparation for globalizing";
|
let summary = "Lowering in preparation for globalizing";
|
||||||
let constructor = "mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass()";
|
let constructor = "mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass()";
|
||||||
let description = [{
|
let description = [{
|
||||||
Establishes and the invariants needed by the
|
Establishes and the invariants needed by the
|
||||||
torch-globalize-object-graph transformation. Fails if that cannot be
|
torch-globalize-object-graph transformation. Fails if that cannot be
|
||||||
|
@ -104,7 +104,7 @@ def PrepareForGlobalizeObjectGraph
|
||||||
def AdjustCallingConventions
|
def AdjustCallingConventions
|
||||||
: Pass<"torch-adjust-calling-conventions", "ModuleOp"> {
|
: Pass<"torch-adjust-calling-conventions", "ModuleOp"> {
|
||||||
let summary = "Adjust the calling conventions of functions";
|
let summary = "Adjust the calling conventions of functions";
|
||||||
let constructor = "mlir::NPCOMP::Torch::createAdjustCallingConventionsPass()";
|
let constructor = "mlir::torch::Torch::createAdjustCallingConventionsPass()";
|
||||||
let description = [{
|
let description = [{
|
||||||
Adjusts the calling conventions of functions in the module, with the aim of
|
Adjusts the calling conventions of functions in the module, with the aim of
|
||||||
preparing them for backends and further lowering passes. As this changes
|
preparing them for backends and further lowering passes. As this changes
|
||||||
|
@ -127,7 +127,7 @@ def AdjustCallingConventions
|
||||||
|
|
||||||
def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
|
def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
|
||||||
let summary = "Refine types";
|
let summary = "Refine types";
|
||||||
let constructor = "mlir::NPCOMP::Torch::createRefineTypesPass()";
|
let constructor = "mlir::torch::Torch::createRefineTypesPass()";
|
||||||
let description = [{
|
let description = [{
|
||||||
Refines types of the program. Currently, this means shapes and dtypes of
|
Refines types of the program. Currently, this means shapes and dtypes of
|
||||||
tensors/arrays.
|
tensors/arrays.
|
||||||
|
@ -136,7 +136,7 @@ def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
|
||||||
|
|
||||||
def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
|
def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
|
||||||
let summary = "Inlines torch.global_slot ops.";
|
let summary = "Inlines torch.global_slot ops.";
|
||||||
let constructor = "mlir::NPCOMP::Torch::createInlineGlobalSlotsPass()";
|
let constructor = "mlir::torch::Torch::createInlineGlobalSlotsPass()";
|
||||||
let description = [{
|
let description = [{
|
||||||
Inlines torch.global_slot ops when it is safe to do so.
|
Inlines torch.global_slot ops when it is safe to do so.
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
|
||||||
|
|
||||||
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
|
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
|
||||||
let summary = "Reduces variants of ops to a smaller set of ops.";
|
let summary = "Reduces variants of ops to a smaller set of ops.";
|
||||||
let constructor = "mlir::NPCOMP::Torch::createReduceOpVariantsPass()";
|
let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
|
||||||
let description = [{
|
let description = [{
|
||||||
Replaces ops with other ops to reduce the number of variants that
|
Replaces ops with other ops to reduce the number of variants that
|
||||||
need to be handled elsewhere in the code.
|
need to be handled elsewhere in the code.
|
||||||
|
@ -181,13 +181,13 @@ def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> {
|
||||||
Also, this pass doesn't currently handle interprocedural rewriting
|
Also, this pass doesn't currently handle interprocedural rewriting
|
||||||
(of private functions), which is even more complex.
|
(of private functions), which is even more complex.
|
||||||
}];
|
}];
|
||||||
let constructor = "mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass()";
|
let constructor = "mlir::torch::Torch::createMaximizeValueSemanticsPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
|
def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
|
||||||
let summary = "Refine public return";
|
let summary = "Refine public return";
|
||||||
let constructor = "mlir::NPCOMP::Torch::createRefinePublicReturnPass()";
|
let constructor = "mlir::torch::Torch::createRefinePublicReturnPass()";
|
||||||
let description = [{
|
let description = [{
|
||||||
Refines types of values returned from public functions based on
|
Refines types of values returned from public functions based on
|
||||||
intraprocedural information.
|
intraprocedural information.
|
||||||
|
@ -214,4 +214,4 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // NPCOMP_TORCH_PASSES
|
#endif // TORCHMLIR_TORCH_PASSES
|
|
@ -0,0 +1,23 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_INITALL_H
|
||||||
|
#define TORCH_MLIR_INITALL_H
|
||||||
|
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
|
||||||
|
void registerAllDialects(mlir::DialectRegistry ®istry);
|
||||||
|
void registerAllPasses();
|
||||||
|
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_INITALL_H
|
|
@ -0,0 +1,18 @@
|
||||||
|
add_mlir_library(TorchMLIRCAPI
|
||||||
|
Registration.cpp
|
||||||
|
TorchTypes.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRSupport
|
||||||
|
TorchMLIRTorchDialect
|
||||||
|
TorchMLIRInitAll
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_mlir_target_includes(TorchMLIRCAPI)
|
|
@ -0,0 +1,25 @@
|
||||||
|
//===- Registration.cpp - C Interface for MLIR Registration ---------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir-c/Registration.h"
|
||||||
|
|
||||||
|
#include "mlir/CAPI/IR.h"
|
||||||
|
#include "mlir/Conversion/Passes.h"
|
||||||
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/InitAll.h"
|
||||||
|
|
||||||
|
void torchMlirRegisterAllDialects(MlirContext context) {
|
||||||
|
mlir::DialectRegistry registry;
|
||||||
|
mlir::torch::registerAllDialects(registry);
|
||||||
|
unwrap(context)->appendDialectRegistry(registry);
|
||||||
|
// TODO: Don't eagerly load once D88162 is in and clients can do this.
|
||||||
|
unwrap(context)->loadAllAvailableDialects();
|
||||||
|
}
|
||||||
|
|
||||||
|
void torchMlirRegisterAllPasses() { mlir::torch::registerAllPasses(); }
|
|
@ -6,25 +6,25 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp-c/TorchTypes.h"
|
#include "torch-mlir-c/TorchTypes.h"
|
||||||
|
|
||||||
#include "mlir/CAPI/IR.h"
|
#include "mlir/CAPI/IR.h"
|
||||||
#include "mlir/CAPI/Support.h"
|
#include "mlir/CAPI/Support.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.nn.Module type.
|
// torch.nn.Module type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchNnModule(MlirType t) {
|
bool torchMlirTypeIsATorchNnModule(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::NnModuleType>();
|
return unwrap(t).isa<Torch::NnModuleType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchNnModuleTypeGet(MlirContext context,
|
MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
|
||||||
MlirStringRef className) {
|
MlirStringRef className) {
|
||||||
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
|
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
|
||||||
}
|
}
|
||||||
|
@ -33,11 +33,11 @@ MlirType npcompTorchNnModuleTypeGet(MlirContext context,
|
||||||
// torch.optional type.
|
// torch.optional type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchOptional(MlirType t) {
|
bool torchMlirTypeIsATorchOptional(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::OptionalType>();
|
return unwrap(t).isa<Torch::OptionalType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
|
MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
|
||||||
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,11 +45,11 @@ MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
|
||||||
// torch.tuple<T1, T2, T3> type.
|
// torch.tuple<T1, T2, T3> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchTuple(MlirType t) {
|
bool torchMlirTypeIsATorchTuple(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::TupleType>();
|
return unwrap(t).isa<Torch::TupleType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchTupleTypeGet(MlirContext context,
|
MlirType torchMlirTorchTupleTypeGet(MlirContext context,
|
||||||
intptr_t numContainedTypes,
|
intptr_t numContainedTypes,
|
||||||
MlirType const *containedTypes) {
|
MlirType const *containedTypes) {
|
||||||
return wrap(Torch::TupleType::get(
|
return wrap(Torch::TupleType::get(
|
||||||
|
@ -63,11 +63,11 @@ MlirType npcompTorchTupleTypeGet(MlirContext context,
|
||||||
// torch.list<T> type.
|
// torch.list<T> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchList(MlirType t) {
|
bool torchMlirTypeIsATorchList(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::ListType>();
|
return unwrap(t).isa<Torch::ListType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchListTypeGet(MlirType containedType) {
|
MlirType torchMlirTorchListTypeGet(MlirType containedType) {
|
||||||
return wrap(Torch::ListType::get(unwrap(containedType)));
|
return wrap(Torch::ListType::get(unwrap(containedType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,11 +75,11 @@ MlirType npcompTorchListTypeGet(MlirType containedType) {
|
||||||
// torch.Device type.
|
// torch.Device type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchDevice(MlirType t) {
|
bool torchMlirTypeIsATorchDevice(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::DeviceType>();
|
return unwrap(t).isa<Torch::DeviceType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchDeviceTypeGet(MlirContext context) {
|
MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::DeviceType::get(unwrap(context)));
|
return wrap(Torch::DeviceType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,11 +87,11 @@ MlirType npcompTorchDeviceTypeGet(MlirContext context) {
|
||||||
// torch.bool type.
|
// torch.bool type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchBool(MlirType t) {
|
bool torchMlirTypeIsATorchBool(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::BoolType>();
|
return unwrap(t).isa<Torch::BoolType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchBoolTypeGet(MlirContext context) {
|
MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::BoolType::get(unwrap(context)));
|
return wrap(Torch::BoolType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,11 +99,11 @@ MlirType npcompTorchBoolTypeGet(MlirContext context) {
|
||||||
// torch.int type.
|
// torch.int type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchInt(MlirType t) {
|
bool torchMlirTypeIsATorchInt(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::IntType>();
|
return unwrap(t).isa<Torch::IntType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchIntTypeGet(MlirContext context) {
|
MlirType torchMlirTorchIntTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::IntType::get(unwrap(context)));
|
return wrap(Torch::IntType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,11 +111,11 @@ MlirType npcompTorchIntTypeGet(MlirContext context) {
|
||||||
// torch.float type.
|
// torch.float type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchFloat(MlirType t) {
|
bool torchMlirTypeIsATorchFloat(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::FloatType>();
|
return unwrap(t).isa<Torch::FloatType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchFloatTypeGet(MlirContext context) {
|
MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::FloatType::get(unwrap(context)));
|
return wrap(Torch::FloatType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -123,11 +123,11 @@ MlirType npcompTorchFloatTypeGet(MlirContext context) {
|
||||||
// torch.LinearParams type.
|
// torch.LinearParams type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchLinearParams(MlirType t) {
|
bool torchMlirTypeIsATorchLinearParams(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::LinearParamsType>();
|
return unwrap(t).isa<Torch::LinearParamsType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
|
MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::LinearParamsType::get(unwrap(context)));
|
return wrap(Torch::LinearParamsType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,11 +135,11 @@ MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
|
||||||
// torch.qint8 type.
|
// torch.qint8 type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchQInt8(MlirType t) {
|
bool torchMlirTypeIsATorchQInt8(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::QInt8Type>();
|
return unwrap(t).isa<Torch::QInt8Type>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchQInt8TypeGet(MlirContext context) {
|
MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
|
||||||
return wrap(Torch::QInt8Type::get(unwrap(context)));
|
return wrap(Torch::QInt8Type::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,11 +147,11 @@ MlirType npcompTorchQInt8TypeGet(MlirContext context) {
|
||||||
// torch.tensor type.
|
// torch.tensor type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchNonValueTensor(MlirType t) {
|
bool torchMlirTypeIsATorchNonValueTensor(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::NonValueTensorType>();
|
return unwrap(t).isa<Torch::NonValueTensorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
|
MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
|
||||||
intptr_t numSizes,
|
intptr_t numSizes,
|
||||||
const int64_t *optionalSizes,
|
const int64_t *optionalSizes,
|
||||||
MlirType optionalDtype) {
|
MlirType optionalDtype) {
|
||||||
|
@ -162,13 +162,13 @@ MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
|
||||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
||||||
MlirContext context) {
|
MlirContext context) {
|
||||||
return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
|
return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
|
||||||
unwrap(context)));
|
unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
|
MlirType torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type) {
|
||||||
return wrap(Torch::NonValueTensorType::getFromShaped(
|
return wrap(Torch::NonValueTensorType::getFromShaped(
|
||||||
unwrap(type).cast<ShapedType>()));
|
unwrap(type).cast<ShapedType>()));
|
||||||
}
|
}
|
||||||
|
@ -177,11 +177,12 @@ MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
|
||||||
// torch.vtensor type.
|
// torch.vtensor type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchValueTensor(MlirType t) {
|
bool torchMlirTypeIsATorchValueTensor(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::ValueTensorType>();
|
return unwrap(t).isa<Torch::ValueTensorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
|
||||||
|
intptr_t numSizes,
|
||||||
const int64_t *optionalSizes,
|
const int64_t *optionalSizes,
|
||||||
MlirType optionalDtype) {
|
MlirType optionalDtype) {
|
||||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||||
|
@ -191,13 +192,13 @@ MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
||||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType
|
MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
|
||||||
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
|
MlirContext context) {
|
||||||
return wrap(
|
return wrap(
|
||||||
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
|
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
|
MlirType torchMlirTorchValueTensorTypeGetFromShaped(MlirType type) {
|
||||||
return wrap(
|
return wrap(
|
||||||
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
|
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
|
||||||
}
|
}
|
||||||
|
@ -206,11 +207,11 @@ MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
|
||||||
// torch.none type.
|
// torch.none type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchNone(MlirType t) {
|
bool torchMlirTypeIsATorchNone(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::NoneType>();
|
return unwrap(t).isa<Torch::NoneType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchNoneTypeGet(MlirContext context) {
|
MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::NoneType::get(unwrap(context)));
|
return wrap(Torch::NoneType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,11 +219,11 @@ MlirType npcompTorchNoneTypeGet(MlirContext context) {
|
||||||
// torch.str type.
|
// torch.str type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchString(MlirType t) {
|
bool torchMlirTypeIsATorchString(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::StringType>();
|
return unwrap(t).isa<Torch::StringType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchStringTypeGet(MlirContext context) {
|
MlirType torchMlirTorchStringTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::StringType::get(unwrap(context)));
|
return wrap(Torch::StringType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,11 +231,11 @@ MlirType npcompTorchStringTypeGet(MlirContext context) {
|
||||||
// torch.any type.
|
// torch.any type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchAny(MlirType t) {
|
bool torchMlirTypeIsATorchAny(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::AnyType>();
|
return unwrap(t).isa<Torch::AnyType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchAnyTypeGet(MlirContext context) {
|
MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::AnyType::get(unwrap(context)));
|
return wrap(Torch::AnyType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -242,11 +243,11 @@ MlirType npcompTorchAnyTypeGet(MlirContext context) {
|
||||||
// torch.number type.
|
// torch.number type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchNumber(MlirType t) {
|
bool torchMlirTypeIsATorchNumber(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::NumberType>();
|
return unwrap(t).isa<Torch::NumberType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchNumberTypeGet(MlirContext context) {
|
MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::NumberType::get(unwrap(context)));
|
return wrap(Torch::NumberType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -254,10 +255,10 @@ MlirType npcompTorchNumberTypeGet(MlirContext context) {
|
||||||
// torch.Dict type.
|
// torch.Dict type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
bool npcompTypeIsATorchDict(MlirType t) {
|
bool torchMlirTypeIsATorchDict(MlirType t) {
|
||||||
return unwrap(t).isa<Torch::DictType>();
|
return unwrap(t).isa<Torch::DictType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType npcompTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
||||||
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
|
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
|
||||||
}
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
add_subdirectory(CAPI)
|
||||||
|
add_subdirectory(Dialect)
|
||||||
|
|
||||||
|
add_mlir_library(TorchMLIRInitAll
|
||||||
|
InitAll.cpp
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRSupport
|
||||||
|
TorchMLIRTorchDialect
|
||||||
|
TorchMLIRTorchPasses
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_mlir_target_includes(TorchMLIRInitAll)
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(Torch)
|
|
@ -1,10 +1,10 @@
|
||||||
add_npcomp_dialect_library(NPCOMPTorchDialect
|
add_mlir_library(TorchMLIRTorchDialect
|
||||||
TorchDialect.cpp
|
TorchDialect.cpp
|
||||||
TorchOps.cpp
|
TorchOps.cpp
|
||||||
TorchTypes.cpp
|
TorchTypes.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRTorchOpsIncGen
|
MLIRTorchOpsIncGen
|
||||||
|
@ -17,6 +17,8 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRSupport
|
MLIRSupport
|
||||||
MLIRControlFlowInterfaces
|
MLIRControlFlowInterfaces
|
||||||
|
MLIRInferTypeOpInterface
|
||||||
MLIRSideEffectInterfaces
|
MLIRSideEffectInterfaces
|
||||||
NPCOMPInterfaces
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
torch_mlir_target_includes(TorchMLIRTorchDialect)
|
|
@ -6,20 +6,20 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "mlir/Transforms/InliningUtils.h"
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.cpp.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.cpp.inc"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Dialect Interfaces
|
// Dialect Interfaces
|
||||||
|
@ -44,7 +44,7 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#define GET_TYPEDEF_CLASSES
|
#define GET_TYPEDEF_CLASSES
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Dialect initialize method.
|
// Dialect initialize method.
|
||||||
|
@ -53,11 +53,11 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
|
||||||
void TorchDialect::initialize() {
|
void TorchDialect::initialize() {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
addTypes<
|
addTypes<
|
||||||
#define GET_TYPEDEF_LIST
|
#define GET_TYPEDEF_LIST
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||||
>();
|
>();
|
||||||
addInterfaces<TorchInlinerInterface>();
|
addInterfaces<TorchInlinerInterface>();
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,6 @@ void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
|
||||||
llvm_unreachable("unknown 'torch' type");
|
llvm_unreachable("unknown 'torch' type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Dialect-level verifiers.
|
// Dialect-level verifiers.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
|
@ -6,7 +6,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
@ -15,8 +15,8 @@
|
||||||
#include "llvm/ADT/StringMap.h"
|
#include "llvm/ADT/StringMap.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28
|
// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28
|
||||||
static int64_t getDtypeIntegerFromMlirType(Type dtype) {
|
static int64_t getDtypeIntegerFromMlirType(Type dtype) {
|
||||||
|
@ -36,7 +36,7 @@ static int64_t getDtypeIntegerFromMlirType(Type dtype) {
|
||||||
// Utilities
|
// Utilities
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Value mlir::NPCOMP::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||||
BaseTensorType newType,
|
BaseTensorType newType,
|
||||||
Value tensor) {
|
Value tensor) {
|
||||||
auto originalType = tensor.getType().cast<BaseTensorType>();
|
auto originalType = tensor.getType().cast<BaseTensorType>();
|
||||||
|
@ -393,7 +393,10 @@ void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
// if not removed eagerly by canonicalizer would prevent ReduceOpVariants
|
// if not removed eagerly by canonicalizer would prevent ReduceOpVariants
|
||||||
// from converting certain tensors value semantics.
|
// from converting certain tensors value semantics.
|
||||||
bool allAllowRefinement =
|
bool allAllowRefinement =
|
||||||
llvm::all_of(op.getResult().getUsers(), allowsTypeRefinement);
|
llvm::all_of(op.getResult().getUsers(), [](Operation *op) {
|
||||||
|
return op
|
||||||
|
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
|
||||||
|
});
|
||||||
if (!allAllowRefinement)
|
if (!allAllowRefinement)
|
||||||
return failure();
|
return failure();
|
||||||
rewriter.replaceOp(op, op.getOperand());
|
rewriter.replaceOp(op, op.getOperand());
|
||||||
|
@ -1004,4 +1007,4 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
|
@ -6,15 +6,15 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TupleType
|
// TupleType
|
|
@ -14,13 +14,13 @@
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
// Map from func name and arg index to the type bound for that arg.
|
// Map from func name and arg index to the type bound for that arg.
|
||||||
// This is needed because to rewrite calls, we need the non-local information
|
// This is needed because to rewrite calls, we need the non-local information
|
||||||
|
@ -251,6 +251,6 @@ class AdjustCallingConventionsPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::NPCOMP::Torch::createAdjustCallingConventionsPass() {
|
mlir::torch::Torch::createAdjustCallingConventionsPass() {
|
||||||
return std::make_unique<AdjustCallingConventionsPass>();
|
return std::make_unique<AdjustCallingConventionsPass>();
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
add_npcomp_conversion_library(NPCOMPTorchPasses
|
add_mlir_library(TorchMLIRTorchPasses
|
||||||
AdjustCallingConventions.cpp
|
AdjustCallingConventions.cpp
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
GlobalizeObjectGraph.cpp
|
GlobalizeObjectGraph.cpp
|
||||||
|
@ -10,10 +10,10 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||||
RefineTypes.cpp
|
RefineTypes.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
NPCOMPTorchPassIncGen
|
TorchMLIRTorchPassIncGen
|
||||||
|
|
||||||
LINK_COMPONENTS
|
LINK_COMPONENTS
|
||||||
Core
|
Core
|
||||||
|
@ -22,6 +22,7 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
NPCOMPTorchDialect
|
TorchMLIRTorchDialect
|
||||||
NPCOMPInterfaces
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
torch_mlir_target_includes(TorchMLIRTorchPasses)
|
|
@ -12,9 +12,9 @@
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
@ -22,8 +22,8 @@
|
||||||
#include "llvm/ADT/StringSet.h"
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) {
|
static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) {
|
||||||
NnModuleOp rootNnModule;
|
NnModuleOp rootNnModule;
|
||||||
|
@ -664,8 +664,6 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
||||||
module.push_back(newFunc);
|
module.push_back(newFunc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for (auto &kv : newFuncs) {
|
for (auto &kv : newFuncs) {
|
||||||
BlockAndValueMapping mapping;
|
BlockAndValueMapping mapping;
|
||||||
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
|
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
|
||||||
|
@ -706,6 +704,6 @@ class GlobalizeObjectGraphPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() {
|
mlir::torch::Torch::createGlobalizeObjectGraphPass() {
|
||||||
return std::make_unique<GlobalizeObjectGraphPass>();
|
return std::make_unique<GlobalizeObjectGraphPass>();
|
||||||
}
|
}
|
|
@ -11,16 +11,16 @@
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class InlineGlobalSlotsPass
|
class InlineGlobalSlotsPass
|
||||||
|
@ -87,6 +87,6 @@ class InlineGlobalSlotsPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::NPCOMP::Torch::createInlineGlobalSlotsPass() {
|
mlir::torch::Torch::createInlineGlobalSlotsPass() {
|
||||||
return std::make_unique<InlineGlobalSlotsPass>();
|
return std::make_unique<InlineGlobalSlotsPass>();
|
||||||
}
|
}
|
|
@ -12,12 +12,12 @@
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
||||||
|
@ -131,6 +131,6 @@ class MaximizeValueSemanticsPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass() {
|
mlir::torch::Torch::createMaximizeValueSemanticsPass() {
|
||||||
return std::make_unique<MaximizeValueSemanticsPass>();
|
return std::make_unique<MaximizeValueSemanticsPass>();
|
||||||
}
|
}
|
|
@ -6,20 +6,20 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||||
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||||
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace torch {
|
||||||
namespace Torch {
|
namespace Torch {
|
||||||
|
|
||||||
#define GEN_PASS_CLASSES
|
#define GEN_PASS_CLASSES
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace NPCOMP
|
} // namespace torch
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
|
@ -6,7 +6,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
@ -16,22 +16,22 @@
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
#define GEN_PASS_REGISTRATION
|
#define GEN_PASS_REGISTRATION
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
void mlir::NPCOMP::registerTorchPasses() {
|
void mlir::torch::registerTorchPasses() {
|
||||||
::registerPasses();
|
::registerPasses();
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torchscript-to-torch-backend-pipeline",
|
"torchscript-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
||||||
mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline);
|
mlir::torch::Torch::createTorchScriptToTorchBackendPipeline);
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torch-globalized-module-to-torch-backend-pipeline",
|
"torch-globalized-module-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering a globalized Torch program to Torch backend form.",
|
"Pipeline lowering a globalized Torch program to Torch backend form.",
|
||||||
mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline);
|
mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline(
|
void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
// When we import TorchScript IR, we import their entire "compilation unit",
|
// When we import TorchScript IR, we import their entire "compilation unit",
|
||||||
// which can contain numerous functions unrelated to the current program,
|
// which can contain numerous functions unrelated to the current program,
|
||||||
|
@ -62,7 +62,7 @@ void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline(
|
||||||
createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline(
|
void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
// General considerations: As a matter of bring-up, we are simultaneously
|
// General considerations: As a matter of bring-up, we are simultaneously
|
||||||
// building out the frontend pipeline and also co-developing the backend
|
// building out the frontend pipeline and also co-developing the backend
|
|
@ -14,13 +14,13 @@
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertPrimCallMethodToCall : public OpRewritePattern<PrimCallMethodOp> {
|
class ConvertPrimCallMethodToCall : public OpRewritePattern<PrimCallMethodOp> {
|
||||||
|
@ -93,9 +93,8 @@ class PrepareForGlobalizeObjectGraphPass
|
||||||
// to the form we want.
|
// to the form we want.
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addIllegalOp<PrimCallMethodOp>();
|
target.addIllegalOp<PrimCallMethodOp>();
|
||||||
target.addDynamicallyLegalOp<ConstantOp>([](ConstantOp op) {
|
target.addDynamicallyLegalOp<ConstantOp>(
|
||||||
return !op.getType().isa<FunctionType>();
|
[](ConstantOp op) { return !op.getType().isa<FunctionType>(); });
|
||||||
});
|
|
||||||
target.addIllegalOp<CallIndirectOp>();
|
target.addIllegalOp<CallIndirectOp>();
|
||||||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||||
|
|
||||||
|
@ -110,6 +109,6 @@ class PrepareForGlobalizeObjectGraphPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass() {
|
mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass() {
|
||||||
return std::make_unique<PrepareForGlobalizeObjectGraphPass>();
|
return std::make_unique<PrepareForGlobalizeObjectGraphPass>();
|
||||||
}
|
}
|
|
@ -9,13 +9,13 @@
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Convert value semantic ops operating on mutable arrays to instead operate on
|
// Convert value semantic ops operating on mutable arrays to instead operate on
|
||||||
|
@ -145,6 +145,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
mlir::NPCOMP::Torch::createReduceOpVariantsPass() {
|
mlir::torch::Torch::createReduceOpVariantsPass() {
|
||||||
return std::make_unique<ReduceOpVariantsPass>();
|
return std::make_unique<ReduceOpVariantsPass>();
|
||||||
}
|
}
|
|
@ -11,12 +11,12 @@
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -86,6 +86,6 @@ class RefinePublicReturnPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::NPCOMP::Torch::createRefinePublicReturnPass() {
|
mlir::torch::Torch::createRefinePublicReturnPass() {
|
||||||
return std::make_unique<RefinePublicReturnPass>();
|
return std::make_unique<RefinePublicReturnPass>();
|
||||||
}
|
}
|
|
@ -15,13 +15,13 @@
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Analysis.
|
// Analysis.
|
||||||
|
@ -264,9 +264,8 @@ public:
|
||||||
} else if (auto sum = dyn_cast<AtenSumOp>(op)) {
|
} else if (auto sum = dyn_cast<AtenSumOp>(op)) {
|
||||||
return visitReductionAlongAllDimsOp(sum, operands);
|
return visitReductionAlongAllDimsOp(sum, operands);
|
||||||
} else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
|
} else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
|
||||||
return visitReductionAlongDimIntListOp(
|
return visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
|
||||||
sumDimIntList, sumDimIntList.dim(), sumDimIntList.keepdim(),
|
sumDimIntList.keepdim(), operands);
|
||||||
operands);
|
|
||||||
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
|
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
|
||||||
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
|
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
|
||||||
meanDim.keepdim(), operands);
|
meanDim.keepdim(), operands);
|
||||||
|
@ -1114,7 +1113,7 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
|
||||||
// the right thing forthose ops.
|
// the right thing forthose ops.
|
||||||
//
|
//
|
||||||
static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) {
|
static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) {
|
||||||
return allowsTypeRefinement(op) ||
|
return op->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>() ||
|
||||||
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
|
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1244,6 +1243,6 @@ class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
mlir::NPCOMP::Torch::createRefineTypesPass() {
|
mlir::torch::Torch::createRefineTypesPass() {
|
||||||
return std::make_unique<RefineTypesPass>();
|
return std::make_unique<RefineTypesPass>();
|
||||||
}
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir/InitAll.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
|
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||||
|
registry.insert<mlir::torch::Torch::TorchDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::torch::registerAllPasses() { mlir::torch::registerTorchPasses(); }
|
|
@ -0,0 +1,30 @@
|
||||||
|
llvm_canonicalize_cmake_booleans(
|
||||||
|
MLIR_ENABLE_BINDINGS_PYTHON
|
||||||
|
)
|
||||||
|
|
||||||
|
configure_lit_site_cfg(
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
|
||||||
|
MAIN_CONFIG
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
|
||||||
|
)
|
||||||
|
|
||||||
|
set(TORCH_MLIR_TEST_DEPENDS
|
||||||
|
FileCheck count not
|
||||||
|
torch-mlir-opt
|
||||||
|
)
|
||||||
|
|
||||||
|
# XXX: Enable
|
||||||
|
#if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||||
|
# list(APPEND TORCH_MLIR_TEST_DEPENDS
|
||||||
|
# TorchMLIRPythonModules
|
||||||
|
# )
|
||||||
|
#endif()
|
||||||
|
|
||||||
|
add_lit_testsuite(check-torch-mlir "Running the torch-mlir regression tests"
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}
|
||||||
|
DEPENDS ${TORCH_MLIR_TEST_DEPENDS}
|
||||||
|
)
|
||||||
|
set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests")
|
||||||
|
|
||||||
|
add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// Basic case.
|
// Basic case.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
||||||
|
|
||||||
torch.class_type @c1 {}
|
torch.class_type @c1 {}
|
||||||
torch.class_type @c2 {}
|
torch.class_type @c2 {}
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
torch.class_type @c {
|
torch.class_type @c {
|
||||||
torch.attr "float" : !torch.float
|
torch.attr "float" : !torch.float
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK that multiple nested initialization ops are properly handled.
|
// CHECK that multiple nested initialization ops are properly handled.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
torch.class_type @c {
|
torch.class_type @c {
|
||||||
torch.attr "float" : !torch.float
|
torch.attr "float" : !torch.float
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file -verify-diagnostics %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file -verify-diagnostics %s
|
||||||
|
|
||||||
torch.class_type @parent {
|
torch.class_type @parent {
|
||||||
torch.method "module_type_return", @module_type_return
|
torch.method "module_type_return", @module_type_return
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
torch.class_type @child {
|
torch.class_type @child {
|
||||||
torch.attr "float" : !torch.float
|
torch.attr "float" : !torch.float
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
||||||
|
|
||||||
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
|
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
|
||||||
return
|
return
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// Tests monomorphization of same function with different instance argument types.
|
// Tests monomorphization of same function with different instance argument types.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
torch.class_type @__torch__.TestModule {
|
torch.class_type @__torch__.TestModule {
|
||||||
torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule">
|
torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule">
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// Check that linkage names consist of the dotted path from the root.
|
// Check that linkage names consist of the dotted path from the root.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
torch.class_type @c {
|
torch.class_type @c {
|
||||||
// CHECK: torch.global_slot "private" @float : !torch.float
|
// CHECK: torch.global_slot "private" @float : !torch.float
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @basic(
|
// CHECK-LABEL: func @basic(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt %s -canonicalize | FileCheck %s
|
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.__is__
|
// CHECK-LABEL: func @torch.aten.__is__
|
||||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-NOT: @readonly
|
// CHECK-NOT: @readonly
|
||||||
torch.global_slot "private" @readonly : !torch.tensor {
|
torch.global_slot "private" @readonly : !torch.tensor {
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
|
// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
|
// RUN: torch-mlir-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.copy.tensor$basic(
|
// CHECK-LABEL: func @torch.copy.tensor$basic(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
|
// RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.operator(
|
// CHECK-LABEL: func @torch.operator(
|
||||||
func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
|
func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
torch.class_type @c {
|
torch.class_type @c {
|
||||||
torch.method "test_call_method", @test_call_method
|
torch.method "test_call_method", @test_call_method
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-reduce-op-variants %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @convert_to_value_semantic_tensors(
|
// CHECK-LABEL: func @convert_to_value_semantic_tensors(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
|
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @basic(
|
// CHECK-LABEL: func @basic(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @f(
|
// CHECK-LABEL: func @f(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
|
@ -0,0 +1,71 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import lit.formats
|
||||||
|
import lit.util
|
||||||
|
|
||||||
|
from lit.llvm import llvm_config
|
||||||
|
from lit.llvm.subst import ToolSubst
|
||||||
|
from lit.llvm.subst import FindTool
|
||||||
|
|
||||||
|
# Configuration file for the 'lit' test runner.
|
||||||
|
|
||||||
|
# name: The name of this test suite.
|
||||||
|
config.name = 'TORCH_MLIR'
|
||||||
|
|
||||||
|
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||||
|
|
||||||
|
# suffixes: A list of file extensions to treat as test files.
|
||||||
|
config.suffixes = ['.mlir', '.py']
|
||||||
|
|
||||||
|
# test_source_root: The root path where tests are located.
|
||||||
|
config.test_source_root = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
# test_exec_root: The root path where tests should be run.
|
||||||
|
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
||||||
|
|
||||||
|
config.substitutions.append(('%PATH%', config.environment['PATH']))
|
||||||
|
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
|
||||||
|
|
||||||
|
llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
|
||||||
|
|
||||||
|
#llvm_config.use_default_substitutions()
|
||||||
|
|
||||||
|
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
|
||||||
|
# subdirectories contain auxiliary inputs for various tests in their parent
|
||||||
|
# directories.
|
||||||
|
config.excludes = [
|
||||||
|
'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
|
||||||
|
'lit.cfg.py', 'lit.site.cfg.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# test_source_root: The root path where tests are located.
|
||||||
|
config.test_source_root = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
# test_exec_root: The root path where tests should be run.
|
||||||
|
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
||||||
|
config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
|
||||||
|
|
||||||
|
# Tweak the PATH to include the tools dir.
|
||||||
|
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
||||||
|
|
||||||
|
tool_dirs = [config.llvm_tools_dir]
|
||||||
|
tools = [
|
||||||
|
ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
|
||||||
|
]
|
||||||
|
|
||||||
|
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||||
|
|
||||||
|
if config.enable_bindings_python:
|
||||||
|
llvm_config.with_environment('PYTHONPATH', [
|
||||||
|
os.path.join(config.torch_mlir_obj_root, 'python_packages',
|
||||||
|
'torch_mlir'),
|
||||||
|
],
|
||||||
|
append_path=True)
|
|
@ -0,0 +1,21 @@
|
||||||
|
@LIT_SITE_CFG_IN_HEADER@
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@
|
||||||
|
config.torch_mlir_obj_root = "@TORCH_MLIR_BINARY_DIR@"
|
||||||
|
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
|
||||||
|
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
|
||||||
|
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
|
||||||
|
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
|
||||||
|
config.llvm_shlib_dir = "@SHLIBDIR@"
|
||||||
|
config.llvm_shlib_ext = "@SHLIBEXT@"
|
||||||
|
config.llvm_exe_ext = "@EXEEXT@"
|
||||||
|
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
|
||||||
|
config.python_executable = sys.executable
|
||||||
|
|
||||||
|
import lit.llvm
|
||||||
|
lit.llvm.initialize(lit_config, config)
|
||||||
|
|
||||||
|
# Let the main config do the real work.
|
||||||
|
lit_config.load_config(config, "@TORCH_MLIR_SOURCE_DIR@/test/lit.cfg.py")
|
|
@ -0,0 +1,2 @@
|
||||||
|
if not config.enable_bindings_python:
|
||||||
|
config.unsupported = True
|
|
@ -0,0 +1,9 @@
|
||||||
|
# RUN: %PYTHON %s
|
||||||
|
# XXX: Fix this
|
||||||
|
# XFAIL: *
|
||||||
|
|
||||||
|
import mlir.ir
|
||||||
|
from mlir.dialects import iree
|
||||||
|
|
||||||
|
with mlir.ir.Context() as ctx:
|
||||||
|
iree.register_iree_dialect(ctx)
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(torch-mlir-opt)
|
|
@ -0,0 +1,13 @@
|
||||||
|
add_executable(torch-mlir-opt torch-mlir-opt.cpp)
|
||||||
|
|
||||||
|
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||||
|
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||||
|
|
||||||
|
target_link_libraries(torch-mlir-opt PRIVATE
|
||||||
|
MLIROptLib
|
||||||
|
TorchMLIRInitAll
|
||||||
|
TorchMLIRTorchDialect
|
||||||
|
TorchMLIRTorchPasses
|
||||||
|
${dialect_libs}
|
||||||
|
${conversion_libs}
|
||||||
|
)
|
|
@ -0,0 +1,27 @@
|
||||||
|
//===- torch-mlir-opt.cpp - MLIR Optimizer Driver -------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/InitAllDialects.h"
|
||||||
|
#include "mlir/InitAllPasses.h"
|
||||||
|
#include "mlir/Support/MlirOptMain.h"
|
||||||
|
#include "torch-mlir/InitAll.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
registerAllPasses();
|
||||||
|
mlir::torch::registerAllPasses();
|
||||||
|
|
||||||
|
DialectRegistry registry;
|
||||||
|
registerAllDialects(registry);
|
||||||
|
mlir::torch::registerAllDialects(registry);
|
||||||
|
|
||||||
|
return mlir::asMainReturnCode(
|
||||||
|
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
|
||||||
|
/*preloadDialectsInContext=*/false));
|
||||||
|
}
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
#include "npcomp-c/TorchTypes.h"
|
#include "torch-mlir-c/TorchTypes.h"
|
||||||
|
|
||||||
#include <ATen/core/function_schema.h>
|
#include <ATen/core/function_schema.h>
|
||||||
#include <ATen/core/ivalue.h>
|
#include <ATen/core/ivalue.h>
|
||||||
|
@ -510,14 +510,14 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
|
||||||
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
|
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
|
||||||
}
|
}
|
||||||
if (ival.isList()) {
|
if (ival.isList()) {
|
||||||
return npcompTorchListTypeGet(
|
return torchMlirTorchListTypeGet(
|
||||||
typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
|
typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
|
||||||
}
|
}
|
||||||
if (ival.isNone()) {
|
if (ival.isNone()) {
|
||||||
return npcompTorchNoneTypeGet(funcBuilder->getContext());
|
return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
|
||||||
}
|
}
|
||||||
if (ival.isDevice()) {
|
if (ival.isDevice()) {
|
||||||
return npcompTorchNoneTypeGet(funcBuilder->getContext());
|
return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
|
||||||
}
|
}
|
||||||
return {nullptr};
|
return {nullptr};
|
||||||
}
|
}
|
||||||
|
@ -527,7 +527,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
||||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||||
MlirOperation tensorOp = createMlirOperationAtEnd(
|
MlirOperation tensorOp = createMlirOperationAtEnd(
|
||||||
funcBuilder->getEntryBlock(), "torch.tensor.literal", loc,
|
funcBuilder->getEntryBlock(), "torch.tensor.literal", loc,
|
||||||
npcompTorchNonValueTensorTypeGetFromShaped(
|
torchMlirTorchNonValueTensorTypeGetFromShaped(
|
||||||
mlirAttributeGetType(denseElements)),
|
mlirAttributeGetType(denseElements)),
|
||||||
toMlirNamedAttribute("value", denseElements));
|
toMlirNamedAttribute("value", denseElements));
|
||||||
MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);
|
MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
#include "mlir-c/Diagnostics.h"
|
#include "mlir-c/Diagnostics.h"
|
||||||
#include "npcomp-c/TorchTypes.h"
|
#include "torch-mlir-c/TorchTypes.h"
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
|
||||||
// represented as one of double or int64_t, with a special tag for whether
|
// represented as one of double or int64_t, with a special tag for whether
|
||||||
// it should be interpreted as a bool.
|
// it should be interpreted as a bool.
|
||||||
if (s.isIntegral(/*includeBool=*/false)) {
|
if (s.isIntegral(/*includeBool=*/false)) {
|
||||||
MlirType t = npcompTorchIntTypeGet(context);
|
MlirType t = torchMlirTorchIntTypeGet(context);
|
||||||
MlirAttribute value =
|
MlirAttribute value =
|
||||||
mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), s.to<int64_t>());
|
mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), s.to<int64_t>());
|
||||||
MlirOperation op = createMlirOperation(
|
MlirOperation op = createMlirOperation(
|
||||||
|
@ -107,7 +107,7 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
|
||||||
return mlirOperationGetResult(op, 0);
|
return mlirOperationGetResult(op, 0);
|
||||||
}
|
}
|
||||||
if (s.isFloatingPoint()) {
|
if (s.isFloatingPoint()) {
|
||||||
MlirType t = npcompTorchFloatTypeGet(context);
|
MlirType t = torchMlirTorchFloatTypeGet(context);
|
||||||
MlirAttribute value = mlirFloatAttrDoubleGet(
|
MlirAttribute value = mlirFloatAttrDoubleGet(
|
||||||
context, mlirF64TypeGet(context), s.to<double>());
|
context, mlirF64TypeGet(context), s.to<double>());
|
||||||
MlirOperation op = createMlirOperation(
|
MlirOperation op = createMlirOperation(
|
||||||
|
@ -133,7 +133,7 @@ MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) {
|
||||||
|
|
||||||
MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
|
MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
|
||||||
std::vector<MlirValue> &elements) {
|
std::vector<MlirValue> &elements) {
|
||||||
MlirType resultType = npcompTorchListTypeGet(elementType);
|
MlirType resultType = torchMlirTorchListTypeGet(elementType);
|
||||||
OperationStateHolder state{"torch.prim.ListConstruct", loc};
|
OperationStateHolder state{"torch.prim.ListConstruct", loc};
|
||||||
mlirOperationStateAddResults(state, 1, &resultType);
|
mlirOperationStateAddResults(state, 1, &resultType);
|
||||||
mlirOperationStateAddOperands(state, elements.size(), elements.data());
|
mlirOperationStateAddOperands(state, elements.size(), elements.data());
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
#include "mlir-c/Diagnostics.h"
|
#include "mlir-c/Diagnostics.h"
|
||||||
#include "npcomp-c/TorchTypes.h"
|
#include "torch-mlir-c/TorchTypes.h"
|
||||||
|
|
||||||
#include "caffe2/core/scope_guard.h"
|
#include "caffe2/core/scope_guard.h"
|
||||||
#include "ATen/native/quantized/cpu/packed_params.h"
|
#include "ATen/native/quantized/cpu/packed_params.h"
|
||||||
|
@ -170,7 +170,7 @@ IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
|
|
||||||
MlirOperation nnModule = createMlirOperation(
|
MlirOperation nnModule = createMlirOperation(
|
||||||
"torch.nn_module", loc,
|
"torch.nn_module", loc,
|
||||||
npcompTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||||
mlirRegionCreate());
|
mlirRegionCreate());
|
||||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
||||||
|
@ -240,7 +240,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
|
||||||
if (ivalue.isBool()) {
|
if (ivalue.isBool()) {
|
||||||
MlirType type = npcompTorchBoolTypeGet(context);
|
MlirType type = torchMlirTorchBoolTypeGet(context);
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.constant.bool", loc, type,
|
importBlock, "torch.constant.bool", loc, type,
|
||||||
toMlirNamedAttribute("value",
|
toMlirNamedAttribute("value",
|
||||||
|
@ -248,7 +248,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
if (ivalue.isDouble()) {
|
if (ivalue.isDouble()) {
|
||||||
MlirType type = npcompTorchFloatTypeGet(context);
|
MlirType type = torchMlirTorchFloatTypeGet(context);
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.constant.float", loc, type,
|
importBlock, "torch.constant.float", loc, type,
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
|
@ -257,7 +257,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
if (ivalue.isInt()) {
|
if (ivalue.isInt()) {
|
||||||
MlirType type = npcompTorchIntTypeGet(context);
|
MlirType type = torchMlirTorchIntTypeGet(context);
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.constant.int", loc, type,
|
importBlock, "torch.constant.int", loc, type,
|
||||||
toMlirNamedAttribute("value",
|
toMlirNamedAttribute("value",
|
||||||
|
@ -273,7 +273,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
}
|
}
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.prim.ListConstruct", loc,
|
importBlock, "torch.prim.ListConstruct", loc,
|
||||||
npcompTorchListTypeGet(
|
torchMlirTorchListTypeGet(
|
||||||
typeMapper.mapFromTorchType(loc, list.elementType())),
|
typeMapper.mapFromTorchType(loc, list.elementType())),
|
||||||
elems);
|
elems);
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
|
@ -288,7 +288,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
}
|
}
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.prim.DictConstruct", loc,
|
importBlock, "torch.prim.DictConstruct", loc,
|
||||||
npcompTorchDictTypeGet(
|
torchMlirTorchDictTypeGet(
|
||||||
typeMapper.mapFromTorchType(loc, dict.keyType()),
|
typeMapper.mapFromTorchType(loc, dict.keyType()),
|
||||||
typeMapper.mapFromTorchType(loc, dict.valueType())),
|
typeMapper.mapFromTorchType(loc, dict.valueType())),
|
||||||
keys, values);
|
keys, values);
|
||||||
|
@ -305,7 +305,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
}
|
}
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.prim.TupleConstruct", loc,
|
importBlock, "torch.prim.TupleConstruct", loc,
|
||||||
npcompTorchTupleTypeGet(context, types.size(), types.data()), operands);
|
torchMlirTorchTupleTypeGet(context, types.size(), types.data()),
|
||||||
|
operands);
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
if (ivalue.isTensor()) {
|
if (ivalue.isTensor()) {
|
||||||
|
@ -317,7 +318,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
if (ivalue.isString()) {
|
if (ivalue.isString()) {
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.constant.str", loc,
|
importBlock, "torch.constant.str", loc,
|
||||||
npcompTorchStringTypeGet(context),
|
torchMlirTorchStringTypeGet(context),
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value",
|
"value",
|
||||||
mlirStringAttrGet(context,
|
mlirStringAttrGet(context,
|
||||||
|
@ -327,7 +328,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
if (ivalue.isNone()) {
|
if (ivalue.isNone()) {
|
||||||
MlirOperation operation =
|
MlirOperation operation =
|
||||||
createMlirOperationAtEnd(importBlock, "torch.constant.none", loc,
|
createMlirOperationAtEnd(importBlock, "torch.constant.none", loc,
|
||||||
npcompTorchNoneTypeGet(context));
|
torchMlirTorchNoneTypeGet(context));
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
if (ivalue.isCustomClass()) {
|
if (ivalue.isCustomClass()) {
|
||||||
|
@ -346,7 +347,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
||||||
}
|
}
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
importBlock, "torch.linear_params.create", loc,
|
importBlock, "torch.linear_params.create", loc,
|
||||||
npcompTorchLinearParamsTypeGet(context), weightValue, biasValue);
|
torchMlirTorchLinearParamsTypeGet(context), weightValue, biasValue);
|
||||||
return mlirOperationGetResult(operation, 0);
|
return mlirOperationGetResult(operation, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -366,7 +367,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
||||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||||
MlirOperation tensorOp =
|
MlirOperation tensorOp =
|
||||||
createMlirOperationAtEnd(importBlock, "torch.tensor.literal", loc,
|
createMlirOperationAtEnd(importBlock, "torch.tensor.literal", loc,
|
||||||
npcompTorchNonValueTensorTypeGetFromShaped(
|
torchMlirTorchNonValueTensorTypeGetFromShaped(
|
||||||
mlirAttributeGetType(denseElements)),
|
mlirAttributeGetType(denseElements)),
|
||||||
toMlirNamedAttribute("value", denseElements));
|
toMlirNamedAttribute("value", denseElements));
|
||||||
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||||
|
@ -381,7 +382,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
||||||
// compiler stages that are building a statically modeled quantization
|
// compiler stages that are building a statically modeled quantization
|
||||||
// representation will need to convert this to their representation.
|
// representation will need to convert this to their representation.
|
||||||
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
||||||
MlirType quantizedTensorType = npcompTorchNonValueTensorTypeGet(
|
MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
|
||||||
context, shape.size(), shape.data(),
|
context, shape.size(), shape.data(),
|
||||||
typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
|
typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
|
||||||
if (tensor.qscheme() == c10::kPerTensorAffine) {
|
if (tensor.qscheme() == c10::kPerTensorAffine) {
|
||||||
|
@ -531,11 +532,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
||||||
int64_t dummy;
|
int64_t dummy;
|
||||||
int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
|
int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
|
||||||
if (hasValueSemantics) {
|
if (hasValueSemantics) {
|
||||||
typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
|
typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(),
|
||||||
shapeData, dtype);
|
shapeData, dtype);
|
||||||
} else {
|
} else {
|
||||||
typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
|
typeBound = torchMlirTorchNonValueTensorTypeGet(
|
||||||
shapeData, dtype);
|
context, shape.size(), shapeData, dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
|
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
#include "mlir-c/Diagnostics.h"
|
#include "mlir-c/Diagnostics.h"
|
||||||
#include "mlir-c/Registration.h"
|
#include "mlir-c/Registration.h"
|
||||||
#include "npcomp-c/Registration.h"
|
#include "npcomp-c/Registration.h"
|
||||||
|
#include "torch-mlir-c/Registration.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
@ -114,7 +115,7 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
||||||
// TODO: Rework this once dialect registration C-APIs are in place.
|
// TODO: Rework this once dialect registration C-APIs are in place.
|
||||||
// https://reviews.llvm.org/D88162
|
// https://reviews.llvm.org/D88162
|
||||||
mlirRegisterAllDialects(context);
|
mlirRegisterAllDialects(context);
|
||||||
npcompRegisterAllDialects(context);
|
torchMlirRegisterAllDialects(context);
|
||||||
|
|
||||||
registerPythonSysStderrDiagnosticHandler(context);
|
registerPythonSysStderrDiagnosticHandler(context);
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
#include "mlir-c/Diagnostics.h"
|
#include "mlir-c/Diagnostics.h"
|
||||||
#include "npcomp-c/TorchTypes.h"
|
#include "torch-mlir-c/TorchTypes.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
@ -150,7 +150,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
||||||
importAttribute(loc, node, c10::attr::value)));
|
importAttribute(loc, node, c10::attr::value)));
|
||||||
} else if (output->type()->cast<c10::StringType>()) {
|
} else if (output->type()->cast<c10::StringType>()) {
|
||||||
op = createMlirOperation(
|
op = createMlirOperation(
|
||||||
"torch.constant.str", loc, npcompTorchStringTypeGet(context),
|
"torch.constant.str", loc, torchMlirTorchStringTypeGet(context),
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
||||||
c10::attr::value)))));
|
c10::attr::value)))));
|
||||||
|
@ -186,7 +186,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
||||||
mlirRegionCreate());
|
mlirRegionCreate());
|
||||||
mapResults(node, operation);
|
mapResults(node, operation);
|
||||||
std::vector<MlirType> terminatorOperandTypes = {
|
std::vector<MlirType> terminatorOperandTypes = {
|
||||||
npcompTorchBoolTypeGet(context)};
|
torchMlirTorchBoolTypeGet(context)};
|
||||||
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
|
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
|
||||||
resultTypes.begin(), resultTypes.end());
|
resultTypes.begin(), resultTypes.end());
|
||||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
#include "mlir-c/Diagnostics.h"
|
#include "mlir-c/Diagnostics.h"
|
||||||
#include "npcomp-c/TorchTypes.h"
|
#include "torch-mlir-c/TorchTypes.h"
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
|
@ -18,11 +18,11 @@ OpBuilder::OpBuilder(MlirContext context) : context(context) {}
|
||||||
|
|
||||||
MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
|
MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
|
||||||
return createMlirOperation("torch.constant.none", loc,
|
return createMlirOperation("torch.constant.none", loc,
|
||||||
npcompTorchNoneTypeGet(context));
|
torchMlirTorchNoneTypeGet(context));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
|
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
|
||||||
return createMlirOperation(
|
return createMlirOperation(
|
||||||
"torch.constant.bool", loc, npcompTorchBoolTypeGet(context),
|
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
|
||||||
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
|
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
#include "mlir-c/Diagnostics.h"
|
#include "mlir-c/Diagnostics.h"
|
||||||
#include "npcomp-c/TorchTypes.h"
|
#include "torch-mlir-c/TorchTypes.h"
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
|
||||||
case ScalarType::Half:
|
case ScalarType::Half:
|
||||||
return mlirF16TypeGet(context);
|
return mlirF16TypeGet(context);
|
||||||
case ScalarType::QInt8:
|
case ScalarType::QInt8:
|
||||||
return npcompTorchQInt8TypeGet(context);
|
return torchMlirTorchQInt8TypeGet(context);
|
||||||
default: {
|
default: {
|
||||||
return {nullptr};
|
return {nullptr};
|
||||||
}
|
}
|
||||||
|
@ -105,7 +105,7 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
|
||||||
|
|
||||||
// Individually handle the custom classes that we know about.
|
// Individually handle the custom classes that we know about.
|
||||||
if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
|
if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
|
||||||
return npcompTorchLinearParamsTypeGet(context);
|
return torchMlirTorchLinearParamsTypeGet(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point, we know that the type is indeed a custom class type, but
|
// At this point, we know that the type is indeed a custom class type, but
|
||||||
|
@ -136,7 +136,7 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
||||||
auto &sizes = tensorType->symbolic_sizes();
|
auto &sizes = tensorType->symbolic_sizes();
|
||||||
if (!sizes.rank()) {
|
if (!sizes.rank()) {
|
||||||
// Unranked.
|
// Unranked.
|
||||||
return npcompTorchNonValueTensorTypeGet(context,
|
return torchMlirTorchNonValueTensorTypeGet(context,
|
||||||
/*numSizes=*/0,
|
/*numSizes=*/0,
|
||||||
/*optionalSizes=*/nullptr,
|
/*optionalSizes=*/nullptr,
|
||||||
/*optionalDtype=*/
|
/*optionalDtype=*/
|
||||||
|
@ -150,28 +150,28 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
||||||
auto shapeSymbol = symbolicShape[i];
|
auto shapeSymbol = symbolicShape[i];
|
||||||
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
||||||
}
|
}
|
||||||
return npcompTorchNonValueTensorTypeGet(context, dims.size(),
|
return torchMlirTorchNonValueTensorTypeGet(context, dims.size(),
|
||||||
/*optionalSizes=*/dims.data(),
|
/*optionalSizes=*/dims.data(),
|
||||||
/*optionalDtype=*/
|
/*optionalDtype=*/
|
||||||
elementType);
|
elementType);
|
||||||
}
|
}
|
||||||
case TypeKind::IntType: {
|
case TypeKind::IntType: {
|
||||||
return npcompTorchIntTypeGet(context);
|
return torchMlirTorchIntTypeGet(context);
|
||||||
}
|
}
|
||||||
case TypeKind::FloatType: {
|
case TypeKind::FloatType: {
|
||||||
return npcompTorchFloatTypeGet(context);
|
return torchMlirTorchFloatTypeGet(context);
|
||||||
}
|
}
|
||||||
case TypeKind::BoolType: {
|
case TypeKind::BoolType: {
|
||||||
return npcompTorchBoolTypeGet(context);
|
return torchMlirTorchBoolTypeGet(context);
|
||||||
}
|
}
|
||||||
case TypeKind::NumberType: {
|
case TypeKind::NumberType: {
|
||||||
return npcompTorchNumberTypeGet(context);
|
return torchMlirTorchNumberTypeGet(context);
|
||||||
}
|
}
|
||||||
case TypeKind::StringType: {
|
case TypeKind::StringType: {
|
||||||
return npcompTorchStringTypeGet(context);
|
return torchMlirTorchStringTypeGet(context);
|
||||||
}
|
}
|
||||||
case TypeKind::OptionalType: {
|
case TypeKind::OptionalType: {
|
||||||
return npcompTorchOptionalTypeGet(mapFromTorchType(
|
return torchMlirTorchOptionalTypeGet(mapFromTorchType(
|
||||||
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
||||||
}
|
}
|
||||||
case TypeKind::TupleType: {
|
case TypeKind::TupleType: {
|
||||||
|
@ -180,25 +180,25 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
||||||
torchType->cast<c10::TupleType>()->containedTypes()) {
|
torchType->cast<c10::TupleType>()->containedTypes()) {
|
||||||
containedTypes.push_back(mapFromTorchType(loc, type));
|
containedTypes.push_back(mapFromTorchType(loc, type));
|
||||||
}
|
}
|
||||||
return npcompTorchTupleTypeGet(context, containedTypes.size(),
|
return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
|
||||||
containedTypes.data());
|
containedTypes.data());
|
||||||
}
|
}
|
||||||
case TypeKind::ListType: {
|
case TypeKind::ListType: {
|
||||||
return npcompTorchListTypeGet(mapFromTorchType(
|
return torchMlirTorchListTypeGet(mapFromTorchType(
|
||||||
loc, torchType->cast<c10::ListType>()->getElementType()));
|
loc, torchType->cast<c10::ListType>()->getElementType()));
|
||||||
}
|
}
|
||||||
case TypeKind::DictType: {
|
case TypeKind::DictType: {
|
||||||
auto dictType = torchType->cast<c10::DictType>();
|
auto dictType = torchType->cast<c10::DictType>();
|
||||||
return npcompTorchDictTypeGet(
|
return torchMlirTorchDictTypeGet(
|
||||||
mapFromTorchType(loc, dictType->getKeyType()),
|
mapFromTorchType(loc, dictType->getKeyType()),
|
||||||
mapFromTorchType(loc, dictType->getValueType()));
|
mapFromTorchType(loc, dictType->getValueType()));
|
||||||
}
|
}
|
||||||
case TypeKind::NoneType: {
|
case TypeKind::NoneType: {
|
||||||
return npcompTorchNoneTypeGet(context);
|
return torchMlirTorchNoneTypeGet(context);
|
||||||
}
|
}
|
||||||
case TypeKind::AnyType: {
|
case TypeKind::AnyType: {
|
||||||
auto anyType = torchType->cast<c10::AnyType>();
|
auto anyType = torchType->cast<c10::AnyType>();
|
||||||
return npcompTorchAnyTypeGet(context);
|
return torchMlirTorchAnyTypeGet(context);
|
||||||
}
|
}
|
||||||
case TypeKind::ClassType: {
|
case TypeKind::ClassType: {
|
||||||
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
|
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
|
||||||
|
@ -208,10 +208,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
||||||
}
|
}
|
||||||
auto maybeName = classType->name();
|
auto maybeName = classType->name();
|
||||||
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
|
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
|
||||||
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
|
return torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(name));
|
||||||
}
|
}
|
||||||
case TypeKind::DeviceObjType: {
|
case TypeKind::DeviceObjType: {
|
||||||
return npcompTorchDeviceTypeGet(context);
|
return torchMlirTorchDeviceTypeGet(context);
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
std::stringstream message;
|
std::stringstream message;
|
||||||
|
@ -226,7 +226,7 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
||||||
if (!tensor.defined()) {
|
if (!tensor.defined()) {
|
||||||
// Undefined tensors are equivalent to None.
|
// Undefined tensors are equivalent to None.
|
||||||
// This may need to be re-evaluated at some point.
|
// This may need to be re-evaluated at some point.
|
||||||
return npcompTorchNoneTypeGet(context);
|
return torchMlirTorchNoneTypeGet(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
|
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
|
||||||
|
@ -234,8 +234,8 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
||||||
// just erase them and let the compiler decide.
|
// just erase them and let the compiler decide.
|
||||||
|
|
||||||
auto sizes = tensor.sizes();
|
auto sizes = tensor.sizes();
|
||||||
return npcompTorchNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
|
return torchMlirTorchNonValueTensorTypeGet(context, sizes.size(),
|
||||||
elementType);
|
sizes.data(), elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType
|
MlirType
|
||||||
|
|
|
@ -2,5 +2,4 @@ add_subdirectory(Basicpy)
|
||||||
add_subdirectory(Numpy)
|
add_subdirectory(Numpy)
|
||||||
add_subdirectory(Refback)
|
add_subdirectory(Refback)
|
||||||
add_subdirectory(Refbackrt)
|
add_subdirectory(Refbackrt)
|
||||||
add_subdirectory(Torch)
|
|
||||||
add_subdirectory(TorchConversion)
|
add_subdirectory(TorchConversion)
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
#include "mlir/IR/SymbolTable.h"
|
#include "mlir/IR/SymbolTable.h"
|
||||||
#include "mlir/Interfaces/CastInterfaces.h"
|
#include "mlir/Interfaces/CastInterfaces.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
#include "npcomp/Interfaces/Traits.h"
|
|
||||||
#include "npcomp/Typing/Analysis/CPA/Interfaces.h"
|
#include "npcomp/Typing/Analysis/CPA/Interfaces.h"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
|
|
|
@ -10,7 +10,6 @@
|
||||||
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS
|
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS
|
||||||
|
|
||||||
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
|
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
|
||||||
include "npcomp/Interfaces/Traits.td"
|
|
||||||
include "npcomp/Typing/Analysis/CPA/Interfaces.td"
|
include "npcomp/Typing/Analysis/CPA/Interfaces.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/CastInterfaces.td"
|
include "mlir/Interfaces/CastInterfaces.td"
|
||||||
|
@ -39,7 +38,6 @@ def Numpy_NarrowOp : Numpy_Op<"narrow", []> {
|
||||||
|
|
||||||
def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
|
def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
|
||||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||||
AllowsTypeRefinement,
|
|
||||||
NoSideEffect]> {
|
NoSideEffect]> {
|
||||||
let summary = "Adds/removes static information from an array type.";
|
let summary = "Adds/removes static information from an array type.";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -60,7 +58,6 @@ def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
|
||||||
|
|
||||||
def Numpy_TensorStaticInfoCastOp : Numpy_Op<"tensor_static_info_cast", [
|
def Numpy_TensorStaticInfoCastOp : Numpy_Op<"tensor_static_info_cast", [
|
||||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||||
AllowsTypeRefinement,
|
|
||||||
NoSideEffect]> {
|
NoSideEffect]> {
|
||||||
let summary = "Adds/removes static information from a tensor type.";
|
let summary = "Adds/removes static information from a tensor type.";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
|
||||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
|
||||||
add_public_tablegen_target(NPCOMPTorchPassIncGen)
|
|
||||||
|
|
||||||
add_mlir_doc(Passes NPCOMPTorchTransforms ./ -gen-pass-doc)
|
|
|
@ -16,8 +16,8 @@
|
||||||
#include "mlir/Interfaces/CastInterfaces.h"
|
#include "mlir/Interfaces/CastInterfaces.h"
|
||||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
|
||||||
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h.inc"
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h.inc"
|
||||||
|
|
|
@ -14,7 +14,7 @@ include "mlir/IR/SymbolInterfaces.td"
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td"
|
include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td"
|
||||||
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
|
||||||
include "iree-dialects/Dialect/IREE/IREEDialect.td"
|
include "iree-dialects/Dialect/IREE/IREEDialect.td"
|
||||||
|
|
||||||
class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>
|
class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
|
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
|
||||||
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
@ -21,7 +21,8 @@ namespace TorchConversion {
|
||||||
/// Creates a pipeline that lowers the object graph IR that is produced by
|
/// Creates a pipeline that lowers the object graph IR that is produced by
|
||||||
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
|
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
|
||||||
void createTorchScriptToNpcompBackendPipeline(
|
void createTorchScriptToNpcompBackendPipeline(
|
||||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options);
|
OpPassManager &pm,
|
||||||
|
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createVerifyInvariantsBeforeBackendLoweringPass();
|
createVerifyInvariantsBeforeBackendLoweringPass();
|
||||||
|
|
|
@ -13,21 +13,7 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
namespace OpTrait {
|
namespace OpTrait {} // namespace OpTrait
|
||||||
|
|
||||||
template <typename ConcreteType>
|
|
||||||
class AllowsTypeRefinement
|
|
||||||
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
|
|
||||||
|
|
||||||
} // namespace OpTrait
|
|
||||||
|
|
||||||
// Check if an operation has the AllowsTypeRefinement trait.
|
|
||||||
//
|
|
||||||
// This function should be used in preference to
|
|
||||||
// `op->hasTrait<AllowsTypeRefinement>()` because this function has knowledge of
|
|
||||||
// some upstream ops that have this property, but which we cannot annotate with
|
|
||||||
// this trait.
|
|
||||||
bool allowsTypeRefinement(Operation *op);
|
|
||||||
|
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -19,13 +19,6 @@ class NpcompOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
|
||||||
let cppNamespace = "::mlir::NPCOMP::OpTrait";
|
let cppNamespace = "::mlir::NPCOMP::OpTrait";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Op allows operand and result types to be refined.
|
// Empty for now. Kept as boilerplate placeholder.
|
||||||
// For example a `tensor<?xf32>` can be refined to `tensor<4xf32>`.
|
|
||||||
//
|
|
||||||
// TODO: Implement RefinableTypeInterface that allows actually modeling
|
|
||||||
// which types are refinements of other types.
|
|
||||||
// See the design in:
|
|
||||||
// https://llvm.discourse.group/t/allow-shape-concretization-or-type-concretization-in-rewrites/3327/3
|
|
||||||
def AllowsTypeRefinement : NpcompOpTrait<"AllowsTypeRefinement">;
|
|
||||||
|
|
||||||
#endif // NPCOMP_INTERFACES_TRAITS
|
#endif // NPCOMP_INTERFACES_TRAITS
|
||||||
|
|
|
@ -10,7 +10,6 @@ add_npcomp_library(NPCOMPCAPI
|
||||||
Registration.cpp
|
Registration.cpp
|
||||||
BasicpyTypes.cpp
|
BasicpyTypes.cpp
|
||||||
NumpyTypes.cpp
|
NumpyTypes.cpp
|
||||||
TorchTypes.cpp
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRExecutionEngine
|
MLIRExecutionEngine
|
||||||
|
@ -21,7 +20,8 @@ add_npcomp_library(NPCOMPCAPI
|
||||||
NPCOMPNumpyDialect
|
NPCOMPNumpyDialect
|
||||||
NPCOMPRefBackendJITHelpers
|
NPCOMPRefBackendJITHelpers
|
||||||
NPCOMPRuntime
|
NPCOMPRuntime
|
||||||
NPCOMPTorchDialect
|
TorchMLIRTorchDialect
|
||||||
|
TorchMLIRInitAll
|
||||||
|
|
||||||
# MLIR CAPI deps
|
# MLIR CAPI deps
|
||||||
MLIRCAPIIR
|
MLIRCAPIIR
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir/Dialect/Linalg/Passes.h"
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "npcomp/InitAll.h"
|
#include "npcomp/InitAll.h"
|
||||||
|
#include "torch-mlir/InitAll.h"
|
||||||
|
|
||||||
void npcompRegisterAllDialects(MlirContext context) {
|
void npcompRegisterAllDialects(MlirContext context) {
|
||||||
mlir::DialectRegistry registry;
|
mlir::DialectRegistry registry;
|
||||||
|
|
|
@ -31,7 +31,7 @@ add_npcomp_library(NPCOMPInitAll
|
||||||
NPCOMPIREEBackend
|
NPCOMPIREEBackend
|
||||||
NPCOMPRefBackend
|
NPCOMPRefBackend
|
||||||
NPCOMPRefbackDialect
|
NPCOMPRefbackDialect
|
||||||
NPCOMPTorchDialect
|
TorchMLIRTorchDialect
|
||||||
NPCOMPTorchConversionDialect
|
NPCOMPTorchConversionDialect
|
||||||
NPCOMPRefbackrtDialect
|
NPCOMPRefbackrtDialect
|
||||||
NPCOMPBasicpyDialect
|
NPCOMPBasicpyDialect
|
||||||
|
|
|
@ -13,7 +13,7 @@ add_npcomp_conversion_library(NPCOMPTorchToIREE
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRPass
|
MLIRPass
|
||||||
NPCOMPTorchDialect
|
TorchMLIRTorchDialect
|
||||||
MLIRStandard
|
MLIRStandard
|
||||||
IREEDialectsIREEDialect
|
IREEDialectsIREEDialect
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,13 +14,13 @@
|
||||||
#include "mlir/Dialect/Traits.h"
|
#include "mlir/Dialect/Traits.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
|
||||||
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
using namespace mlir::NPCOMP::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// The patterns
|
// The patterns
|
||||||
|
|
|
@ -15,5 +15,5 @@ add_npcomp_conversion_library(NPCOMPTorchToLinalg
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRLinalg
|
MLIRLinalg
|
||||||
MLIRMath
|
MLIRMath
|
||||||
NPCOMPTorchDialect
|
TorchMLIRTorchDialect
|
||||||
)
|
)
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue