mirror of https://github.com/llvm/torch-mlir
[torch-mlir earthmoving (1/N)] C/C++ code movement.
This creates the `external/torch-mlir` directory as an LLVM_EXTERNAL_PROJECTS-compatible project (analogous to `iree-dialects`) and completes movement/rename of all pure MLIR C/C++ compiler code into there. The next step will be to move all the Python code / code that links/includes PyTorch C++ code (which currently lives in `frontends/pytorch`) into a subdirectory here. I call this "earthmoving" because it is mostly mechanical changes and renames. As a quick summary (we can change this down the road easily) - C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch` - CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet` - preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_` - CMake `NPCOMPFoo -> TorchMLIRFoo` The goal of this is to create a standalone project creating a center of mass for entry into the MLIR ecosystem from PyTorch, suitable in scope for eventual inclusion/ownership in PyTorch. The idea is that `external/torch-mlir` will some day be pulled out into its own repository, and then npcomp will simply pull it in as a submodule. Layering-wise, what lives in `torch-mlir` lowers code from PyTorch (currently TorchScript, but TorchFX or pytorch/xla-style tracing are possible extensions) down to what we have been calling the "Torch backend contract" which is cleaned up IR (inlining, simplifcation, conversion to value tensors, ...) entirely in the `torch` dialect. This is the branching off point for further lowering, of which npcomp takes one opinion (outside `torch-mlir` of course!), namely the `TorchConversion` dialect/transforms which lower to IR suitable for IREE and other linalg-on-tensors based lower-level compilers. Summary of changes: - move `{include,lib,test}/Dialect/Torch` into `torch-mlir` - move relevant parts of CAPI into `torch-mlir`. - leave a few things related to the `torch-mlir` Python build commented out, which should be resolved in a subsequent change.pull/305/head
parent
28762699b3
commit
28a7738189
|
@ -128,9 +128,10 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
|||
"suffix = '${PYTHON_MODULE_SUFFIX}', "
|
||||
"extension = '${PYTHON_MODULE_EXTENSION}")
|
||||
|
||||
# Include the iree-dialects external project.
|
||||
set(LLVM_EXTERNAL_PROJECTS "iree-dialects")
|
||||
# Include LLVM_EXTERNAL_PROJECTS.
|
||||
set(LLVM_EXTERNAL_PROJECTS "iree-dialects;torch-mlir")
|
||||
set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects")
|
||||
set(LLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir")
|
||||
|
||||
# LLVM configuration.
|
||||
message(STATUS "*** ADDING LLVM ***")
|
||||
|
@ -183,6 +184,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
|||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include)
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir/include)
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/torch-mlir/include)
|
||||
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
||||
add_definitions(${LLVM_DEFINITIONS})
|
||||
set(NPCOMP_TABLEGEN_ARGS "")
|
||||
|
|
|
@ -21,6 +21,7 @@ cd $td/build
|
|||
|
||||
ninja
|
||||
ninja check-npcomp
|
||||
ninja check-torch-mlir
|
||||
ninja check-frontends-pytorch
|
||||
|
||||
echo
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
message(FATAL_ERROR
|
||||
"This project is intended to be built as part of LLVM via "
|
||||
"-DLLVM_EXTERNAL_PROJECTS=torch-mlir "
|
||||
"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
endif()
|
||||
|
||||
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
|
||||
|
||||
set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
set(TORCH_MLIR_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
|
||||
message(STATUS "Building torch-mlir project at ${TORCH_MLIR_SOURCE_DIR} (into ${TORCH_MLIR_BINARY_DIR})")
|
||||
|
||||
# TODO: Fix this upstream so that global include directories are not needed.
|
||||
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
|
||||
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
|
||||
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
|
||||
|
||||
# TODO: Needed for tablegen. Remove.
|
||||
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
|
||||
include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR})
|
||||
include_directories(SYSTEM ${TORCH_MLIR_SOURCE_DIR}/include)
|
||||
|
||||
function(torch_mlir_target_includes target)
|
||||
set(_dirs
|
||||
$<BUILD_INTERFACE:${MLIR_INCLUDE_DIR}>
|
||||
$<BUILD_INTERFACE:${MLIR_GENERATED_INCLUDE_DIR}>
|
||||
$<BUILD_INTERFACE:${TORCH_MLIR_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:${TORCH_MLIR_BINARY_DIR}/include>
|
||||
)
|
||||
# In LLVM parlance, the actual target may just be an interface and may not
|
||||
# be responsible for actually compiling anything. The corresponding obj.
|
||||
# target, when present, is just used for compilation and does not
|
||||
# contribute to the interface properties.
|
||||
# TODO: Normalize this upstream.
|
||||
target_include_directories(${target} PUBLIC ${_dirs})
|
||||
if(TARGET obj.${target})
|
||||
target_include_directories(obj.${target} PRIVATE ${_dirs})
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# Configure CMake and tablegen.
|
||||
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
|
||||
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
|
||||
set(MLIR_TABLEGEN_EXE mlir-tblgen)
|
||||
|
||||
include(TableGen)
|
||||
include(AddLLVM)
|
||||
include(AddMLIR)
|
||||
|
||||
################################################################################
|
||||
# Setup python.
|
||||
# TODO: Make one upstream macro to do this.
|
||||
################################################################################
|
||||
|
||||
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||
include(MLIRDetectPythonEnv)
|
||||
mlir_detect_pybind11_install()
|
||||
find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION}
|
||||
COMPONENTS Interpreter Development NumPy REQUIRED)
|
||||
find_package(pybind11 2.6 CONFIG REQUIRED)
|
||||
endif()
|
||||
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
add_subdirectory(test)
|
||||
add_subdirectory(tools)
|
||||
|
||||
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||
#XXX: Enable
|
||||
#add_subdirectory(python)
|
||||
endif()
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(torch-mlir)
|
|
@ -0,0 +1,32 @@
|
|||
/*===-- torch-mlir-c/Registration.h - Registration functions -----*- C -*-===*\
|
||||
|* *|
|
||||
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|
||||
|* Exceptions. *|
|
||||
|* See https://llvm.org/LICENSE.txt for license information. *|
|
||||
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|
||||
|* *|
|
||||
\*===----------------------------------------------------------------------===*/
|
||||
|
||||
#ifndef TORCHMLIR_C_REGISTRATION_H
|
||||
#define TORCHMLIR_C_REGISTRATION_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/** Registers all dialects with a context.
|
||||
* This is needed before creating IR for these Dialects.
|
||||
*/
|
||||
MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context);
|
||||
|
||||
/** Registers all passes for symbolic access with the global registry. */
|
||||
MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // TORCHMLIR_C_REGISTRATION_H
|
|
@ -1,4 +1,4 @@
|
|||
//===-- npcomp-c/TorchTypes.h - C API for torch types -------------*- C -*-===//
|
||||
//===-- torch-mlir-c/TorchTypes.h - C API for torch types ---------*- C -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
// Exceptions.
|
||||
|
@ -7,8 +7,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_C_TORCHTYPES_H
|
||||
#define NPCOMP_C_TORCHTYPES_H
|
||||
#ifndef TORCHMLIR_C_TORCHTYPES_H
|
||||
#define TORCHMLIR_C_TORCHTYPES_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
@ -22,110 +22,112 @@ extern "C" {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a torch.nn.Module type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNnModule(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t);
|
||||
|
||||
/// Gets the !torch.nn.Module type of the specified class.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchNnModuleTypeGet(MlirContext context,
|
||||
MlirStringRef className);
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.optional type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.optional<T> type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchOptional(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t);
|
||||
|
||||
/// Gets the !torch.optional<T> type with subtype T.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchOptionalTypeGet(MlirType containedType);
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
torchMlirTorchOptionalTypeGet(MlirType containedType);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.tuple<T1, T2, T3> type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.tuple type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchTuple(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchTuple(MlirType t);
|
||||
|
||||
/// Gets the !torch.tuple type with contained types `containedTypes`.
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
npcompTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
|
||||
MlirType const *containedTypes);
|
||||
torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
|
||||
MlirType const *containedTypes);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.list<T> type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.list<T> type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchList(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t);
|
||||
|
||||
/// Gets the !torch.list<T> type with contained T.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchListTypeGet(MlirType containedType);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.Device type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.Device type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDevice(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
|
||||
|
||||
/// Gets the !torch.Device type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchDeviceTypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.bool type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.bool type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchBool(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t);
|
||||
|
||||
/// Gets the !torch.bool type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchBoolTypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.int type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.int type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchInt(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t);
|
||||
|
||||
/// Gets the !torch.int type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchIntTypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.float type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.float type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchFloat(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t);
|
||||
|
||||
/// Gets the !torch.float type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchFloatTypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.LinearParams type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.LinearParams type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchLinearParams(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t);
|
||||
|
||||
/// Gets the !torch.LinearParams type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchLinearParamsTypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
torchMlirTorchLinearParamsTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.qint8 type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.qint8 type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchQInt8(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
|
||||
|
||||
/// Gets the !torch.qint8 type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchQInt8TypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.tensor type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.tensor type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNonValueTensor(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);
|
||||
|
||||
/// Gets a !torch.tensor type.
|
||||
///
|
||||
|
@ -133,24 +135,25 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNonValueTensor(MlirType t);
|
|||
/// information is present (and `numSizes` is ignored in that case). -
|
||||
/// `optionalDtype` is allowed to be null, meaning that no dtype
|
||||
/// information is present.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchNonValueTensorTypeGet(
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGet(
|
||||
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
|
||||
MlirType optionalDtype);
|
||||
|
||||
/// Gets the !torch.tensor type with the least static information.
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||
torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
||||
MlirContext context);
|
||||
|
||||
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
npcompTorchNonValueTensorTypeGetFromShaped(MlirType type);
|
||||
torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.vtensor type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.vtensor type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchValueTensor(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);
|
||||
|
||||
/// Gets a !torch.vtensor type.
|
||||
///
|
||||
|
@ -158,71 +161,71 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchValueTensor(MlirType t);
|
|||
/// information is present (and `numSizes` is ignored in that case).
|
||||
/// - `optionalDtype` is allowed to be null, meaning that no dtype
|
||||
/// information is present.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchValueTensorTypeGet(
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
|
||||
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
|
||||
MlirType optionalDtype);
|
||||
|
||||
/// Gets the !torch.tensor type with the least static information.
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||
torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||
|
||||
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
npcompTorchValueTensorTypeGetFromShaped(MlirType type);
|
||||
torchMlirTorchValueTensorTypeGetFromShaped(MlirType type);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.none type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.none type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNone(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t);
|
||||
|
||||
/// Gets the !torch.none type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchNoneTypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.str type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.str type
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchString(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t);
|
||||
|
||||
/// Gets the !torch.str type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchStringTypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.any type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.any type.
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchAny(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t);
|
||||
|
||||
/// Gets the !torch.str type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchAnyTypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.number type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.number type.
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNumber(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t);
|
||||
|
||||
/// Gets the !torch.number type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchNumberTypeGet(MlirContext context);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.dict type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks whether the given type is a !torch.dict type.
|
||||
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDict(MlirType t);
|
||||
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t);
|
||||
|
||||
/// Gets the !torch.dict type.
|
||||
MLIR_CAPI_EXPORTED MlirType npcompTorchDictTypeGet(MlirType keyType,
|
||||
MlirType valueType);
|
||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType,
|
||||
MlirType valueType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // NPCOMP_C_TORCHTYPES_H
|
||||
#endif // TORCHMLIR_C_TORCHTYPES_H
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(Dialect)
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(Torch)
|
|
@ -13,7 +13,7 @@ include "mlir/IR/OpBase.td"
|
|||
|
||||
def Torch_Dialect : Dialect {
|
||||
let name = "torch";
|
||||
let cppNamespace = "::mlir::NPCOMP::Torch";
|
||||
let cppNamespace = "::mlir::torch::Torch";
|
||||
let description = [{
|
||||
Top-level dialect for interfacing PyTorch and MLIR.
|
||||
|
||||
|
@ -21,7 +21,7 @@ def Torch_Dialect : Dialect {
|
|||
|
||||
This dialect also provides transforms that lower it to the
|
||||
"Torch backend contract", which is an IR form that we present to
|
||||
later conversions, such as conversion to the npcomp backend contract.
|
||||
later conversions.
|
||||
The Torch backend contract significantly simplifies the IR representation
|
||||
and puts it in a form easier for later lowering to work on. Specifically:
|
||||
- The TorchScript object graph has been flattened to a list of globals (see
|
||||
|
@ -39,11 +39,12 @@ def Torch_Dialect : Dialect {
|
|||
|
||||
class TorchOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
|
||||
let trait = name;
|
||||
let cppNamespace = "::mlir::NPCOMP::Torch::OpTrait";
|
||||
let cppNamespace = "::mlir::torch::Torch::OpTrait";
|
||||
}
|
||||
|
||||
def HasValueSemantics : TorchOpTrait<"HasValueSemantics">;
|
||||
def IsTrailingUnderscoreInplaceVariant
|
||||
: TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
|
||||
def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">;
|
||||
|
||||
#endif // TORCH_BASE
|
|
@ -6,11 +6,11 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H
|
||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H
|
||||
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
|
||||
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h.inc"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h.inc"
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H
|
||||
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
|
|
@ -6,8 +6,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
|
||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
|
||||
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
|
||||
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
@ -18,15 +18,14 @@
|
|||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTraits.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "npcomp/Interfaces/Traits.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTraits.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
|
||||
namespace detail {
|
||||
|
@ -117,11 +116,11 @@ m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
|
|||
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
|
||||
Value tensor);
|
||||
} // namespace Torch
|
||||
} // namespace NPCOMP
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
|
||||
using SlotOp = ::mlir::NPCOMP::Torch::SlotOp;
|
||||
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::SlotOp> {
|
||||
using SlotOp = ::mlir::torch::Torch::SlotOp;
|
||||
static SlotOp getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return SlotOp::getFromOpaquePointer(pointer);
|
||||
|
@ -136,8 +135,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
|
|||
static bool isEqual(SlotOp lhs, SlotOp rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> {
|
||||
using NnModuleOp = ::mlir::NPCOMP::Torch::NnModuleOp;
|
||||
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::NnModuleOp> {
|
||||
using NnModuleOp = ::mlir::torch::Torch::NnModuleOp;
|
||||
static NnModuleOp getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return NnModuleOp::getFromOpaquePointer(pointer);
|
||||
|
@ -152,8 +151,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> {
|
|||
static bool isEqual(NnModuleOp lhs, NnModuleOp rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> {
|
||||
using ClassTypeOp = ::mlir::NPCOMP::Torch::ClassTypeOp;
|
||||
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::ClassTypeOp> {
|
||||
using ClassTypeOp = ::mlir::torch::Torch::ClassTypeOp;
|
||||
static ClassTypeOp getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return ClassTypeOp::getFromOpaquePointer(pointer);
|
||||
|
@ -168,8 +167,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> {
|
|||
static bool isEqual(ClassTypeOp lhs, ClassTypeOp rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::GlobalSlotOp> {
|
||||
using OpTy = ::mlir::NPCOMP::Torch::GlobalSlotOp;
|
||||
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::GlobalSlotOp> {
|
||||
using OpTy = ::mlir::torch::Torch::GlobalSlotOp;
|
||||
static OpTy getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return OpTy::getFromOpaquePointer(pointer);
|
||||
|
@ -184,4 +183,4 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::GlobalSlotOp> {
|
|||
static bool isEqual(OpTy lhs, OpTy rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
|
||||
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
|
|
@ -9,8 +9,7 @@
|
|||
#ifndef TORCH_OPS
|
||||
#define TORCH_OPS
|
||||
|
||||
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
||||
include "npcomp/Interfaces/Traits.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
|
@ -22,9 +21,9 @@ class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
|||
: Op<Torch_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
include "npcomp/Dialect/Torch/IR/GeneratedAtenOps.td"
|
||||
include "npcomp/Dialect/Torch/IR/GeneratedPrimOps.td"
|
||||
include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/GeneratedQuantizedOps.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TorchScript `torch.nn.Module` object instantiation ops.
|
||||
|
@ -32,7 +31,7 @@ include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"
|
|||
|
||||
def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
|
||||
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
|
||||
SingleBlockImplicitTerminator<"::mlir::torch::Torch::NnModuleTerminatorOp">]> {
|
||||
let summary = "Constructs a torch.nn.Module";
|
||||
let description = [{
|
||||
This op is used to represent a torch.nn.Module when importing a
|
||||
|
@ -75,7 +74,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
|||
}
|
||||
|
||||
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
||||
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
|
||||
HasParent<"::mlir::torch::Torch::NnModuleOp">]> {
|
||||
let summary = "Implicit terminator for torch.nn_module";
|
||||
|
||||
let arguments = (ins);
|
||||
|
@ -85,7 +84,7 @@ def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
|
|||
}
|
||||
|
||||
def Torch_SlotOp : Torch_Op<"slot", [
|
||||
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
|
||||
HasParent<"::mlir::torch::Torch::NnModuleOp">]> {
|
||||
let summary = "Define the value of a slot of a torch.nn.Module";
|
||||
let description = [{
|
||||
This op specifies that the initial value of the slot `name` of the
|
||||
|
@ -107,7 +106,7 @@ def Torch_SlotOp : Torch_Op<"slot", [
|
|||
|
||||
def Torch_ClassTypeOp : Torch_Op<"class_type", [
|
||||
Symbol,
|
||||
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> {
|
||||
SingleBlockImplicitTerminator<"::mlir::torch::Torch::ClassTypeTerminatorOp">]> {
|
||||
let summary = "Constructs a torch.ClassType";
|
||||
let description = [{
|
||||
Declares a class type. Class types are the types used to describe
|
||||
|
@ -152,7 +151,7 @@ def Torch_ClassTypeOp : Torch_Op<"class_type", [
|
|||
}
|
||||
|
||||
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
|
||||
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> {
|
||||
HasParent<"::mlir::torch::Torch::ClassTypeOp">]> {
|
||||
let summary = "Implicit terminator for torch.class_type";
|
||||
|
||||
let arguments = (ins);
|
||||
|
@ -162,7 +161,7 @@ def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
|
|||
}
|
||||
|
||||
def Torch_MethodOp : Torch_Op<"method", [
|
||||
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">,
|
||||
HasParent<"::mlir::torch::Torch::ClassTypeOp">,
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
||||
]> {
|
||||
let summary = "Declare a method of a torch.class_type";
|
||||
|
@ -193,7 +192,7 @@ def Torch_MethodOp : Torch_Op<"method", [
|
|||
}
|
||||
|
||||
def Torch_AttrOp : Torch_Op<"attr", [
|
||||
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">
|
||||
HasParent<"::mlir::torch::Torch::ClassTypeOp">
|
||||
]> {
|
||||
let summary = "Declare an attribute of a torch.class_type";
|
||||
let description = [{
|
||||
|
@ -231,7 +230,7 @@ def Torch_AttrOp : Torch_Op<"attr", [
|
|||
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
||||
Symbol,
|
||||
IsolatedFromAbove,
|
||||
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::GlobalSlotInitOp">
|
||||
SingleBlockImplicitTerminator<"::mlir::torch::Torch::GlobalSlotInitOp">
|
||||
]> {
|
||||
let summary = "A slot with global storage";
|
||||
let description = [{
|
||||
|
@ -256,7 +255,7 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
|||
|
||||
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
|
||||
Terminator,
|
||||
HasParent<"::mlir::NPCOMP::Torch::GlobalSlotOp">]> {
|
||||
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
|
||||
let summary = "yield-like terminator for torch.global_slot initializer region";
|
||||
let description = [{
|
||||
The operand to this op becomes the initial value of the parent
|
||||
|
@ -463,7 +462,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
|
|||
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
|
||||
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
|
||||
Terminator,
|
||||
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
|
||||
HasParent<"::mlir::torch::Torch::PrimLoopOp">]> {
|
||||
let summary = "yield-like terminator for torch.prim.Loop";
|
||||
let description = [{
|
||||
Does not correspond to any torch prim op directly (the way that they model
|
||||
|
@ -512,7 +511,7 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
|
|||
def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
|
||||
Terminator,
|
||||
ReturnLike,
|
||||
HasParent<"::mlir::NPCOMP::Torch::PrimIfOp">]> {
|
||||
HasParent<"::mlir::torch::Torch::PrimIfOp">]> {
|
||||
let summary = "yield-like terminator for torch.prim.If";
|
||||
let description = [{
|
||||
Does not correspond to any torch prim op directly (the way that they model
|
|
@ -10,15 +10,14 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
|
||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
|
||||
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
|
||||
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
namespace OpTrait {
|
||||
|
||||
|
@ -39,9 +38,16 @@ class IsTrailingUnderscoreInplaceVariant
|
|||
: public ::mlir::OpTrait::TraitBase<ConcreteType,
|
||||
IsTrailingUnderscoreInplaceVariant> {};
|
||||
|
||||
// If a Torch op has this trait, it means that the op allows all of its operand
|
||||
// and result types to be refined. That is, a less specific type is allowed to
|
||||
// be replaced by a more specific type, according to PEP 483 subtyping rules.
|
||||
template <typename ConcreteType>
|
||||
class AllowsTypeRefinement
|
||||
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace Torch
|
||||
} // namespace NPCOMP
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
|
||||
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
|
|
@ -6,13 +6,13 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
|
||||
class NonValueTensorType;
|
||||
|
@ -87,18 +87,18 @@ public:
|
|||
ValueTensorType getWithValueSemantics() const;
|
||||
};
|
||||
} // namespace Torch
|
||||
} // namespace NPCOMP
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h.inc"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Inline definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
|
||||
inline Optional<ArrayRef<int64_t>> BaseTensorType::getOptionalSizes() const {
|
||||
|
@ -122,7 +122,7 @@ inline bool BaseTensorType::classof(Type type) {
|
|||
}
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace NPCOMP
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
|
|
@ -9,7 +9,7 @@
|
|||
#ifndef TORCH_TYPES
|
||||
#define TORCH_TYPES
|
||||
|
||||
include "npcomp/Dialect/Torch/IR/TorchBase.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/TorchBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type defs
|
||||
|
@ -83,7 +83,7 @@ class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
|
|||
}
|
||||
|
||||
class AnyTorchTensorType<string name, string typeMnemonic>
|
||||
: Torch_Type<name, typeMnemonic, "::mlir::NPCOMP::Torch::BaseTensorType"> {
|
||||
: Torch_Type<name, typeMnemonic, "::mlir::torch::Torch::BaseTensorType"> {
|
||||
let summary = "Multi-dimensional array modeling Torch's Tensor type";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
@ -107,8 +107,8 @@ class AnyTorchTensorType<string name, string typeMnemonic>
|
|||
a strict separation between the value-semantic and potentially-mutating
|
||||
worlds, as one of our main jobs in the compiler is to isolate the mutating
|
||||
parts as much as possible because most lower levels of the compiler stack
|
||||
are expected to require value semantics. E.g. npcomp's backend contract
|
||||
is mostly in terms of linalg-on-tensor for compute-heavy ops, which require
|
||||
are expected to require value semantics. E.g. many backend contracts
|
||||
mostly use linalg-on-tensor for compute-heavy ops, which require
|
||||
a conversion to the builtin `tensor` type which has value semantics.
|
||||
Some notes about value semantics:
|
||||
- Using the type system described in PEP 483 (which TorchScript and other
|
||||
|
@ -165,7 +165,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
|
|||
|
||||
Note: We avoid the C++ identifier `TensorType` to avoid C++ name ambiguities
|
||||
with `mlir::TensorType`, since most code is transitively nested in
|
||||
both `::mlir` and `::mlir::NPCOMP::Torch` namespaces.
|
||||
both `::mlir` and `::mlir::torch::Torch` namespaces.
|
||||
|
||||
Note: We use the Torch-aligned terminology "sizes" and "dtype" instead of
|
||||
the MLIR-aligned terminology "rank/shape" and "element type". The cheat
|
||||
|
@ -209,7 +209,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
|
|||
}
|
||||
|
||||
def AnyTorchTensorType : Type<
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Torch::BaseTensorType>()">,
|
||||
CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">,
|
||||
"Any Torch tensor type"
|
||||
>;
|
||||
|
||||
|
@ -317,7 +317,7 @@ def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
|
|||
The whole LinearPackedParamsBase abstraction as it stands in PyTorch is a
|
||||
very library-call-y, runtime-y thing that embodies a number of assumptions
|
||||
about the structure of how the program will be executed, which need not hold
|
||||
for npcomp backends.
|
||||
for backends.
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -386,12 +386,12 @@ def TorchOptionalBoolType:
|
|||
def TorchOptionalDeviceType:
|
||||
OptionalOf<Torch_DeviceType, "Optional torch device type">;
|
||||
|
||||
def IsListTypePred : CPred<"$_self.isa<::mlir::NPCOMP::Torch::ListType>()">;
|
||||
def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
|
||||
class ListOf<list<Type> allowedTypes, string descr> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>,
|
||||
IsListTypePred,
|
||||
"$_self.cast<::mlir::NPCOMP::Torch::ListType>().getContainedType()",
|
||||
descr, "::mlir::NPCOMP::Torch::ListType">;
|
||||
"$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()",
|
||||
descr, "::mlir::torch::Torch::ListType">;
|
||||
|
||||
def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
|
||||
def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">;
|
|
@ -0,0 +1,5 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(TorchMLIRTorchPassIncGen)
|
||||
|
||||
add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)
|
|
@ -6,15 +6,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
||||
|
@ -58,7 +58,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
|||
/// Registers all Torch transformation passes.
|
||||
void registerTorchPasses();
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
|
@ -6,14 +6,14 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_TORCH_PASSES
|
||||
#define NPCOMP_TORCH_PASSES
|
||||
#ifndef TORCHMLIR_TORCH_PASSES
|
||||
#define TORCHMLIR_TORCH_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
||||
let summary = "Converts TorchScript object graphs to a globalized form";
|
||||
let constructor = "mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass()";
|
||||
let constructor = "mlir::torch::Torch::createGlobalizeObjectGraphPass()";
|
||||
let description = [{
|
||||
This pass converts a subset of possible TorchScript modules into a
|
||||
more restrictive lower-level form that strips away the need to be
|
||||
|
@ -80,7 +80,7 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
|||
- Rationale: This makes the representation of initial values simpler. Also
|
||||
as of Feb 2021, TorchScript won't import into this form except
|
||||
potentially for Tensors (it has a bug related to the identity of
|
||||
objects). And for tensors, the npcomp IValue importer only supports a
|
||||
objects). And for tensors, the IValue importer only supports a
|
||||
very restricted form of aliasing anyway for other reasons. We are
|
||||
waiting for signals that more general handling of object aliasing is
|
||||
important to devote the effort to it.
|
||||
|
@ -90,7 +90,7 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
|||
def PrepareForGlobalizeObjectGraph
|
||||
: Pass<"torch-prepare-for-globalize-object-graph", "ModuleOp"> {
|
||||
let summary = "Lowering in preparation for globalizing";
|
||||
let constructor = "mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass()";
|
||||
let constructor = "mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass()";
|
||||
let description = [{
|
||||
Establishes and the invariants needed by the
|
||||
torch-globalize-object-graph transformation. Fails if that cannot be
|
||||
|
@ -104,7 +104,7 @@ def PrepareForGlobalizeObjectGraph
|
|||
def AdjustCallingConventions
|
||||
: Pass<"torch-adjust-calling-conventions", "ModuleOp"> {
|
||||
let summary = "Adjust the calling conventions of functions";
|
||||
let constructor = "mlir::NPCOMP::Torch::createAdjustCallingConventionsPass()";
|
||||
let constructor = "mlir::torch::Torch::createAdjustCallingConventionsPass()";
|
||||
let description = [{
|
||||
Adjusts the calling conventions of functions in the module, with the aim of
|
||||
preparing them for backends and further lowering passes. As this changes
|
||||
|
@ -127,7 +127,7 @@ def AdjustCallingConventions
|
|||
|
||||
def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
|
||||
let summary = "Refine types";
|
||||
let constructor = "mlir::NPCOMP::Torch::createRefineTypesPass()";
|
||||
let constructor = "mlir::torch::Torch::createRefineTypesPass()";
|
||||
let description = [{
|
||||
Refines types of the program. Currently, this means shapes and dtypes of
|
||||
tensors/arrays.
|
||||
|
@ -136,7 +136,7 @@ def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
|
|||
|
||||
def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
|
||||
let summary = "Inlines torch.global_slot ops.";
|
||||
let constructor = "mlir::NPCOMP::Torch::createInlineGlobalSlotsPass()";
|
||||
let constructor = "mlir::torch::Torch::createInlineGlobalSlotsPass()";
|
||||
let description = [{
|
||||
Inlines torch.global_slot ops when it is safe to do so.
|
||||
|
||||
|
@ -150,7 +150,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
|
|||
|
||||
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
|
||||
let summary = "Reduces variants of ops to a smaller set of ops.";
|
||||
let constructor = "mlir::NPCOMP::Torch::createReduceOpVariantsPass()";
|
||||
let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
|
||||
let description = [{
|
||||
Replaces ops with other ops to reduce the number of variants that
|
||||
need to be handled elsewhere in the code.
|
||||
|
@ -181,13 +181,13 @@ def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> {
|
|||
Also, this pass doesn't currently handle interprocedural rewriting
|
||||
(of private functions), which is even more complex.
|
||||
}];
|
||||
let constructor = "mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass()";
|
||||
let constructor = "mlir::torch::Torch::createMaximizeValueSemanticsPass()";
|
||||
}
|
||||
|
||||
|
||||
def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
|
||||
let summary = "Refine public return";
|
||||
let constructor = "mlir::NPCOMP::Torch::createRefinePublicReturnPass()";
|
||||
let constructor = "mlir::torch::Torch::createRefinePublicReturnPass()";
|
||||
let description = [{
|
||||
Refines types of values returned from public functions based on
|
||||
intraprocedural information.
|
||||
|
@ -214,4 +214,4 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
|
|||
}];
|
||||
}
|
||||
|
||||
#endif // NPCOMP_TORCH_PASSES
|
||||
#endif // TORCHMLIR_TORCH_PASSES
|
|
@ -0,0 +1,23 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCH_MLIR_INITALL_H
|
||||
#define TORCH_MLIR_INITALL_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
|
||||
void registerAllDialects(mlir::DialectRegistry ®istry);
|
||||
void registerAllPasses();
|
||||
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCH_MLIR_INITALL_H
|
|
@ -0,0 +1,18 @@
|
|||
add_mlir_library(TorchMLIRCAPI
|
||||
Registration.cpp
|
||||
TorchTypes.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRInitAll
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRCAPI)
|
|
@ -0,0 +1,25 @@
|
|||
//===- Registration.cpp - C Interface for MLIR Registration ---------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir-c/Registration.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "torch-mlir/InitAll.h"
|
||||
|
||||
void torchMlirRegisterAllDialects(MlirContext context) {
|
||||
mlir::DialectRegistry registry;
|
||||
mlir::torch::registerAllDialects(registry);
|
||||
unwrap(context)->appendDialectRegistry(registry);
|
||||
// TODO: Don't eagerly load once D88162 is in and clients can do this.
|
||||
unwrap(context)->loadAllAvailableDialects();
|
||||
}
|
||||
|
||||
void torchMlirRegisterAllPasses() { mlir::torch::registerAllPasses(); }
|
|
@ -6,26 +6,26 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::torch;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.nn.Module type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchNnModule(MlirType t) {
|
||||
bool torchMlirTypeIsATorchNnModule(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NnModuleType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchNnModuleTypeGet(MlirContext context,
|
||||
MlirStringRef className) {
|
||||
MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
|
||||
MlirStringRef className) {
|
||||
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
|
||||
}
|
||||
|
||||
|
@ -33,11 +33,11 @@ MlirType npcompTorchNnModuleTypeGet(MlirContext context,
|
|||
// torch.optional type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchOptional(MlirType t) {
|
||||
bool torchMlirTypeIsATorchOptional(MlirType t) {
|
||||
return unwrap(t).isa<Torch::OptionalType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
|
||||
MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
|
||||
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
||||
}
|
||||
|
||||
|
@ -45,13 +45,13 @@ MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
|
|||
// torch.tuple<T1, T2, T3> type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchTuple(MlirType t) {
|
||||
bool torchMlirTypeIsATorchTuple(MlirType t) {
|
||||
return unwrap(t).isa<Torch::TupleType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchTupleTypeGet(MlirContext context,
|
||||
intptr_t numContainedTypes,
|
||||
MlirType const *containedTypes) {
|
||||
MlirType torchMlirTorchTupleTypeGet(MlirContext context,
|
||||
intptr_t numContainedTypes,
|
||||
MlirType const *containedTypes) {
|
||||
return wrap(Torch::TupleType::get(
|
||||
unwrap(context),
|
||||
llvm::to_vector<6>(
|
||||
|
@ -63,11 +63,11 @@ MlirType npcompTorchTupleTypeGet(MlirContext context,
|
|||
// torch.list<T> type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchList(MlirType t) {
|
||||
bool torchMlirTypeIsATorchList(MlirType t) {
|
||||
return unwrap(t).isa<Torch::ListType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchListTypeGet(MlirType containedType) {
|
||||
MlirType torchMlirTorchListTypeGet(MlirType containedType) {
|
||||
return wrap(Torch::ListType::get(unwrap(containedType)));
|
||||
}
|
||||
|
||||
|
@ -75,11 +75,11 @@ MlirType npcompTorchListTypeGet(MlirType containedType) {
|
|||
// torch.Device type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchDevice(MlirType t) {
|
||||
bool torchMlirTypeIsATorchDevice(MlirType t) {
|
||||
return unwrap(t).isa<Torch::DeviceType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchDeviceTypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
|
||||
return wrap(Torch::DeviceType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -87,11 +87,11 @@ MlirType npcompTorchDeviceTypeGet(MlirContext context) {
|
|||
// torch.bool type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchBool(MlirType t) {
|
||||
bool torchMlirTypeIsATorchBool(MlirType t) {
|
||||
return unwrap(t).isa<Torch::BoolType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchBoolTypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
|
||||
return wrap(Torch::BoolType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -99,11 +99,11 @@ MlirType npcompTorchBoolTypeGet(MlirContext context) {
|
|||
// torch.int type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchInt(MlirType t) {
|
||||
bool torchMlirTypeIsATorchInt(MlirType t) {
|
||||
return unwrap(t).isa<Torch::IntType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchIntTypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchIntTypeGet(MlirContext context) {
|
||||
return wrap(Torch::IntType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -111,11 +111,11 @@ MlirType npcompTorchIntTypeGet(MlirContext context) {
|
|||
// torch.float type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchFloat(MlirType t) {
|
||||
bool torchMlirTypeIsATorchFloat(MlirType t) {
|
||||
return unwrap(t).isa<Torch::FloatType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchFloatTypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
|
||||
return wrap(Torch::FloatType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -123,11 +123,11 @@ MlirType npcompTorchFloatTypeGet(MlirContext context) {
|
|||
// torch.LinearParams type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchLinearParams(MlirType t) {
|
||||
bool torchMlirTypeIsATorchLinearParams(MlirType t) {
|
||||
return unwrap(t).isa<Torch::LinearParamsType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
|
||||
return wrap(Torch::LinearParamsType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -135,11 +135,11 @@ MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
|
|||
// torch.qint8 type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchQInt8(MlirType t) {
|
||||
bool torchMlirTypeIsATorchQInt8(MlirType t) {
|
||||
return unwrap(t).isa<Torch::QInt8Type>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchQInt8TypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
|
||||
return wrap(Torch::QInt8Type::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -147,14 +147,14 @@ MlirType npcompTorchQInt8TypeGet(MlirContext context) {
|
|||
// torch.tensor type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchNonValueTensor(MlirType t) {
|
||||
bool torchMlirTypeIsATorchNonValueTensor(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NonValueTensorType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
|
||||
intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
|
||||
intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||
if (optionalSizes)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
|
@ -162,13 +162,13 @@ MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
|
|||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
}
|
||||
|
||||
MlirType npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
||||
MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
||||
MlirContext context) {
|
||||
return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
|
||||
unwrap(context)));
|
||||
}
|
||||
|
||||
MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
|
||||
MlirType torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type) {
|
||||
return wrap(Torch::NonValueTensorType::getFromShaped(
|
||||
unwrap(type).cast<ShapedType>()));
|
||||
}
|
||||
|
@ -177,13 +177,14 @@ MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
|
|||
// torch.vtensor type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchValueTensor(MlirType t) {
|
||||
bool torchMlirTypeIsATorchValueTensor(MlirType t) {
|
||||
return unwrap(t).isa<Torch::ValueTensorType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
|
||||
intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||
if (optionalSizes)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
|
@ -191,13 +192,13 @@ MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
|||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
}
|
||||
|
||||
MlirType
|
||||
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
|
||||
MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
|
||||
MlirContext context) {
|
||||
return wrap(
|
||||
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
|
||||
}
|
||||
|
||||
MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
|
||||
MlirType torchMlirTorchValueTensorTypeGetFromShaped(MlirType type) {
|
||||
return wrap(
|
||||
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
|
||||
}
|
||||
|
@ -206,11 +207,11 @@ MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
|
|||
// torch.none type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchNone(MlirType t) {
|
||||
bool torchMlirTypeIsATorchNone(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NoneType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchNoneTypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
|
||||
return wrap(Torch::NoneType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -218,11 +219,11 @@ MlirType npcompTorchNoneTypeGet(MlirContext context) {
|
|||
// torch.str type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchString(MlirType t) {
|
||||
bool torchMlirTypeIsATorchString(MlirType t) {
|
||||
return unwrap(t).isa<Torch::StringType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchStringTypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchStringTypeGet(MlirContext context) {
|
||||
return wrap(Torch::StringType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -230,11 +231,11 @@ MlirType npcompTorchStringTypeGet(MlirContext context) {
|
|||
// torch.any type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchAny(MlirType t) {
|
||||
bool torchMlirTypeIsATorchAny(MlirType t) {
|
||||
return unwrap(t).isa<Torch::AnyType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchAnyTypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
|
||||
return wrap(Torch::AnyType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -242,11 +243,11 @@ MlirType npcompTorchAnyTypeGet(MlirContext context) {
|
|||
// torch.number type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchNumber(MlirType t) {
|
||||
bool torchMlirTypeIsATorchNumber(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NumberType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchNumberTypeGet(MlirContext context) {
|
||||
MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
|
||||
return wrap(Torch::NumberType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -254,10 +255,10 @@ MlirType npcompTorchNumberTypeGet(MlirContext context) {
|
|||
// torch.Dict type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchDict(MlirType t) {
|
||||
bool torchMlirTypeIsATorchDict(MlirType t) {
|
||||
return unwrap(t).isa<Torch::DictType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
||||
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
||||
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
add_subdirectory(CAPI)
|
||||
add_subdirectory(Dialect)
|
||||
|
||||
add_mlir_library(TorchMLIRInitAll
|
||||
InitAll.cpp
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRTorchPasses
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRInitAll)
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(Torch)
|
|
@ -1,10 +1,10 @@
|
|||
add_npcomp_dialect_library(NPCOMPTorchDialect
|
||||
add_mlir_library(TorchMLIRTorchDialect
|
||||
TorchDialect.cpp
|
||||
TorchOps.cpp
|
||||
TorchTypes.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch
|
||||
|
||||
DEPENDS
|
||||
MLIRTorchOpsIncGen
|
||||
|
@ -17,6 +17,8 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
|
|||
MLIRIR
|
||||
MLIRSupport
|
||||
MLIRControlFlowInterfaces
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRSideEffectInterfaces
|
||||
NPCOMPInterfaces
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchDialect)
|
|
@ -6,20 +6,20 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.cpp.inc"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Interfaces
|
||||
|
@ -44,7 +44,7 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect initialize method.
|
||||
|
@ -53,11 +53,11 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
|
|||
void TorchDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||
>();
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||
>();
|
||||
addInterfaces<TorchInlinerInterface>();
|
||||
}
|
||||
|
@ -84,7 +84,6 @@ void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
|
|||
llvm_unreachable("unknown 'torch' type");
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect-level verifiers.
|
||||
//===----------------------------------------------------------------------===//
|
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
@ -15,8 +15,8 @@
|
|||
#include "llvm/ADT/StringMap.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28
|
||||
static int64_t getDtypeIntegerFromMlirType(Type dtype) {
|
||||
|
@ -36,9 +36,9 @@ static int64_t getDtypeIntegerFromMlirType(Type dtype) {
|
|||
// Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Value mlir::NPCOMP::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||
BaseTensorType newType,
|
||||
Value tensor) {
|
||||
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||
BaseTensorType newType,
|
||||
Value tensor) {
|
||||
auto originalType = tensor.getType().cast<BaseTensorType>();
|
||||
// Adjust the static information in the type to match between the original and
|
||||
// new types.
|
||||
|
@ -393,7 +393,10 @@ void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
// if not removed eagerly by canonicalizer would prevent ReduceOpVariants
|
||||
// from converting certain tensors value semantics.
|
||||
bool allAllowRefinement =
|
||||
llvm::all_of(op.getResult().getUsers(), allowsTypeRefinement);
|
||||
llvm::all_of(op.getResult().getUsers(), [](Operation *op) {
|
||||
return op
|
||||
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
|
||||
});
|
||||
if (!allAllowRefinement)
|
||||
return failure();
|
||||
rewriter.replaceOp(op, op.getOperand());
|
||||
|
@ -1004,4 +1007,4 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
|
|
@ -6,15 +6,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TupleType
|
|
@ -14,20 +14,20 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// Map from func name and arg index to the type bound for that arg.
|
||||
// This is needed because to rewrite calls, we need the non-local information
|
||||
// from the func definition.
|
||||
// We also benefit from populating this all at once, which avoids ordering
|
||||
// issues between rewriting of func ops vs call ops.
|
||||
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type> ;
|
||||
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>;
|
||||
|
||||
namespace {
|
||||
class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
|
||||
|
@ -136,8 +136,8 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
TypeBoundMap &typeBoundMap;
|
||||
private:
|
||||
TypeBoundMap &typeBoundMap;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -251,6 +251,6 @@ class AdjustCallingConventionsPass
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Torch::createAdjustCallingConventionsPass() {
|
||||
mlir::torch::Torch::createAdjustCallingConventionsPass() {
|
||||
return std::make_unique<AdjustCallingConventionsPass>();
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||
add_mlir_library(TorchMLIRTorchPasses
|
||||
AdjustCallingConventions.cpp
|
||||
Passes.cpp
|
||||
GlobalizeObjectGraph.cpp
|
||||
|
@ -10,10 +10,10 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
|
|||
RefineTypes.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms
|
||||
|
||||
DEPENDS
|
||||
NPCOMPTorchPassIncGen
|
||||
TorchMLIRTorchPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
@ -22,6 +22,7 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
|
|||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
NPCOMPTorchDialect
|
||||
NPCOMPInterfaces
|
||||
TorchMLIRTorchDialect
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchPasses)
|
|
@ -12,9 +12,9 @@
|
|||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -22,8 +22,8 @@
|
|||
#include "llvm/ADT/StringSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) {
|
||||
NnModuleOp rootNnModule;
|
||||
|
@ -664,8 +664,6 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
|||
module.push_back(newFunc);
|
||||
}
|
||||
|
||||
|
||||
|
||||
for (auto &kv : newFuncs) {
|
||||
BlockAndValueMapping mapping;
|
||||
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
|
||||
|
@ -706,6 +704,6 @@ class GlobalizeObjectGraphPass
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() {
|
||||
mlir::torch::Torch::createGlobalizeObjectGraphPass() {
|
||||
return std::make_unique<GlobalizeObjectGraphPass>();
|
||||
}
|
|
@ -11,16 +11,16 @@
|
|||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
class InlineGlobalSlotsPass
|
||||
|
@ -87,6 +87,6 @@ class InlineGlobalSlotsPass
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Torch::createInlineGlobalSlotsPass() {
|
||||
mlir::torch::Torch::createInlineGlobalSlotsPass() {
|
||||
return std::make_unique<InlineGlobalSlotsPass>();
|
||||
}
|
|
@ -12,12 +12,12 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
||||
|
@ -131,6 +131,6 @@ class MaximizeValueSemanticsPass
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass() {
|
||||
mlir::torch::Torch::createMaximizeValueSemanticsPass() {
|
||||
return std::make_unique<MaximizeValueSemanticsPass>();
|
||||
}
|
|
@ -6,20 +6,20 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace NPCOMP
|
||||
} // namespace torch
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
|
@ -16,22 +16,22 @@
|
|||
|
||||
namespace {
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
} // end namespace
|
||||
|
||||
void mlir::NPCOMP::registerTorchPasses() {
|
||||
void mlir::torch::registerTorchPasses() {
|
||||
::registerPasses();
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torchscript-to-torch-backend-pipeline",
|
||||
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
||||
mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline);
|
||||
mlir::torch::Torch::createTorchScriptToTorchBackendPipeline);
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torch-globalized-module-to-torch-backend-pipeline",
|
||||
"Pipeline lowering a globalized Torch program to Torch backend form.",
|
||||
mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline);
|
||||
mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline);
|
||||
}
|
||||
|
||||
void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline(
|
||||
void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||
// When we import TorchScript IR, we import their entire "compilation unit",
|
||||
// which can contain numerous functions unrelated to the current program,
|
||||
|
@ -62,7 +62,7 @@ void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline(
|
|||
createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
||||
}
|
||||
|
||||
void mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline(
|
||||
void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
|
||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||
// General considerations: As a matter of bring-up, we are simultaneously
|
||||
// building out the frontend pipeline and also co-developing the backend
|
|
@ -14,13 +14,13 @@
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
class ConvertPrimCallMethodToCall : public OpRewritePattern<PrimCallMethodOp> {
|
||||
|
@ -93,16 +93,15 @@ class PrepareForGlobalizeObjectGraphPass
|
|||
// to the form we want.
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<PrimCallMethodOp>();
|
||||
target.addDynamicallyLegalOp<ConstantOp>([](ConstantOp op) {
|
||||
return !op.getType().isa<FunctionType>();
|
||||
});
|
||||
target.addDynamicallyLegalOp<ConstantOp>(
|
||||
[](ConstantOp op) { return !op.getType().isa<FunctionType>(); });
|
||||
target.addIllegalOp<CallIndirectOp>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||
|
||||
RewritePatternSet dummyPatterns(context);
|
||||
|
||||
if (failed(applyFullConversion(getOperation(), target,
|
||||
std::move(dummyPatterns)))) {
|
||||
std::move(dummyPatterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -110,6 +109,6 @@ class PrepareForGlobalizeObjectGraphPass
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass() {
|
||||
mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass() {
|
||||
return std::make_unique<PrepareForGlobalizeObjectGraphPass>();
|
||||
}
|
|
@ -9,13 +9,13 @@
|
|||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
// Convert value semantic ops operating on mutable arrays to instead operate on
|
||||
|
@ -145,6 +145,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::Torch::createReduceOpVariantsPass() {
|
||||
mlir::torch::Torch::createReduceOpVariantsPass() {
|
||||
return std::make_unique<ReduceOpVariantsPass>();
|
||||
}
|
|
@ -11,12 +11,12 @@
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -86,6 +86,6 @@ class RefinePublicReturnPass
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Torch::createRefinePublicReturnPass() {
|
||||
mlir::torch::Torch::createRefinePublicReturnPass() {
|
||||
return std::make_unique<RefinePublicReturnPass>();
|
||||
}
|
|
@ -15,13 +15,13 @@
|
|||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Analysis.
|
||||
|
@ -264,9 +264,8 @@ public:
|
|||
} else if (auto sum = dyn_cast<AtenSumOp>(op)) {
|
||||
return visitReductionAlongAllDimsOp(sum, operands);
|
||||
} else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
|
||||
return visitReductionAlongDimIntListOp(
|
||||
sumDimIntList, sumDimIntList.dim(), sumDimIntList.keepdim(),
|
||||
operands);
|
||||
return visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
|
||||
sumDimIntList.keepdim(), operands);
|
||||
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
|
||||
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
|
||||
meanDim.keepdim(), operands);
|
||||
|
@ -1114,7 +1113,7 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
|
|||
// the right thing forthose ops.
|
||||
//
|
||||
static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) {
|
||||
return allowsTypeRefinement(op) ||
|
||||
return op->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>() ||
|
||||
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
|
||||
}
|
||||
|
||||
|
@ -1244,6 +1243,6 @@ class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::Torch::createRefineTypesPass() {
|
||||
mlir::torch::Torch::createRefineTypesPass() {
|
||||
return std::make_unique<RefineTypesPass>();
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/InitAll.h"
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||
registry.insert<mlir::torch::Torch::TorchDialect>();
|
||||
}
|
||||
|
||||
void mlir::torch::registerAllPasses() { mlir::torch::registerTorchPasses(); }
|
|
@ -0,0 +1,30 @@
|
|||
llvm_canonicalize_cmake_booleans(
|
||||
MLIR_ENABLE_BINDINGS_PYTHON
|
||||
)
|
||||
|
||||
configure_lit_site_cfg(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
|
||||
MAIN_CONFIG
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
|
||||
)
|
||||
|
||||
set(TORCH_MLIR_TEST_DEPENDS
|
||||
FileCheck count not
|
||||
torch-mlir-opt
|
||||
)
|
||||
|
||||
# XXX: Enable
|
||||
#if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||
# list(APPEND TORCH_MLIR_TEST_DEPENDS
|
||||
# TorchMLIRPythonModules
|
||||
# )
|
||||
#endif()
|
||||
|
||||
add_lit_testsuite(check-torch-mlir "Running the torch-mlir regression tests"
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${TORCH_MLIR_TEST_DEPENDS}
|
||||
)
|
||||
set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests")
|
||||
|
||||
add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
// Basic case.
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
||||
|
||||
torch.class_type @c1 {}
|
||||
torch.class_type @c2 {}
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
torch.class_type @c {
|
||||
torch.attr "float" : !torch.float
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK that multiple nested initialization ops are properly handled.
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
torch.class_type @c {
|
||||
torch.attr "float" : !torch.float
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file -verify-diagnostics %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file -verify-diagnostics %s
|
||||
|
||||
torch.class_type @parent {
|
||||
torch.method "module_type_return", @module_type_return
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
torch.class_type @child {
|
||||
torch.attr "float" : !torch.float
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
||||
|
||||
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
|
||||
return
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
// Tests monomorphization of same function with different instance argument types.
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
torch.class_type @__torch__.TestModule {
|
||||
torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule">
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
// Check that linkage names consist of the dotted path from the root.
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
torch.class_type @c {
|
||||
// CHECK: torch.global_slot "private" @float : !torch.float
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt %s -canonicalize | FileCheck %s
|
||||
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.__is__
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-NOT: @readonly
|
||||
torch.global_slot "private" @readonly : !torch.tensor {
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
|
||||
// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics
|
||||
|
||||
// -----
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
|
||||
// RUN: torch-mlir-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @torch.copy.tensor$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
|
||||
// RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @torch.operator(
|
||||
func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
torch.class_type @c {
|
||||
torch.method "test_call_method", @test_call_method
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-reduce-op-variants %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @convert_to_value_semantic_tensors(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
|
||||
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @basic(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @f(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
|
|
@ -0,0 +1,71 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
import lit.formats
|
||||
import lit.util
|
||||
|
||||
from lit.llvm import llvm_config
|
||||
from lit.llvm.subst import ToolSubst
|
||||
from lit.llvm.subst import FindTool
|
||||
|
||||
# Configuration file for the 'lit' test runner.
|
||||
|
||||
# name: The name of this test suite.
|
||||
config.name = 'TORCH_MLIR'
|
||||
|
||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||
|
||||
# suffixes: A list of file extensions to treat as test files.
|
||||
config.suffixes = ['.mlir', '.py']
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
||||
# test_exec_root: The root path where tests should be run.
|
||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
||||
|
||||
config.substitutions.append(('%PATH%', config.environment['PATH']))
|
||||
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
|
||||
|
||||
llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
|
||||
|
||||
#llvm_config.use_default_substitutions()
|
||||
|
||||
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
|
||||
# subdirectories contain auxiliary inputs for various tests in their parent
|
||||
# directories.
|
||||
config.excludes = [
|
||||
'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
|
||||
'lit.cfg.py', 'lit.site.cfg.py'
|
||||
]
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
||||
# test_exec_root: The root path where tests should be run.
|
||||
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
|
||||
config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
|
||||
|
||||
# Tweak the PATH to include the tools dir.
|
||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
||||
|
||||
tool_dirs = [config.llvm_tools_dir]
|
||||
tools = [
|
||||
ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
|
||||
]
|
||||
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
||||
if config.enable_bindings_python:
|
||||
llvm_config.with_environment('PYTHONPATH', [
|
||||
os.path.join(config.torch_mlir_obj_root, 'python_packages',
|
||||
'torch_mlir'),
|
||||
],
|
||||
append_path=True)
|
|
@ -0,0 +1,21 @@
|
|||
@LIT_SITE_CFG_IN_HEADER@
|
||||
|
||||
import sys
|
||||
|
||||
config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@
|
||||
config.torch_mlir_obj_root = "@TORCH_MLIR_BINARY_DIR@"
|
||||
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
|
||||
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
|
||||
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
|
||||
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
|
||||
config.llvm_shlib_dir = "@SHLIBDIR@"
|
||||
config.llvm_shlib_ext = "@SHLIBEXT@"
|
||||
config.llvm_exe_ext = "@EXEEXT@"
|
||||
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
|
||||
config.python_executable = sys.executable
|
||||
|
||||
import lit.llvm
|
||||
lit.llvm.initialize(lit_config, config)
|
||||
|
||||
# Let the main config do the real work.
|
||||
lit_config.load_config(config, "@TORCH_MLIR_SOURCE_DIR@/test/lit.cfg.py")
|
|
@ -0,0 +1,2 @@
|
|||
if not config.enable_bindings_python:
|
||||
config.unsupported = True
|
|
@ -0,0 +1,9 @@
|
|||
# RUN: %PYTHON %s
|
||||
# XXX: Fix this
|
||||
# XFAIL: *
|
||||
|
||||
import mlir.ir
|
||||
from mlir.dialects import iree
|
||||
|
||||
with mlir.ir.Context() as ctx:
|
||||
iree.register_iree_dialect(ctx)
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(torch-mlir-opt)
|
|
@ -0,0 +1,13 @@
|
|||
add_executable(torch-mlir-opt torch-mlir-opt.cpp)
|
||||
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||
|
||||
target_link_libraries(torch-mlir-opt PRIVATE
|
||||
MLIROptLib
|
||||
TorchMLIRInitAll
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRTorchPasses
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
)
|
|
@ -0,0 +1,27 @@
|
|||
//===- torch-mlir-opt.cpp - MLIR Optimizer Driver -------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
#include "torch-mlir/InitAll.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
registerAllPasses();
|
||||
mlir::torch::registerAllPasses();
|
||||
|
||||
DialectRegistry registry;
|
||||
registerAllDialects(registry);
|
||||
mlir::torch::registerAllDialects(registry);
|
||||
|
||||
return mlir::asMainReturnCode(
|
||||
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
|
||||
/*preloadDialectsInContext=*/false));
|
||||
}
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
#include <ATen/core/function_schema.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
|
@ -510,14 +510,14 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
|
|||
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
|
||||
}
|
||||
if (ival.isList()) {
|
||||
return npcompTorchListTypeGet(
|
||||
return torchMlirTorchListTypeGet(
|
||||
typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
|
||||
}
|
||||
if (ival.isNone()) {
|
||||
return npcompTorchNoneTypeGet(funcBuilder->getContext());
|
||||
return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
|
||||
}
|
||||
if (ival.isDevice()) {
|
||||
return npcompTorchNoneTypeGet(funcBuilder->getContext());
|
||||
return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
|
||||
}
|
||||
return {nullptr};
|
||||
}
|
||||
|
@ -527,7 +527,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
|||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||
MlirOperation tensorOp = createMlirOperationAtEnd(
|
||||
funcBuilder->getEntryBlock(), "torch.tensor.literal", loc,
|
||||
npcompTorchNonValueTensorTypeGetFromShaped(
|
||||
torchMlirTorchNonValueTensorTypeGetFromShaped(
|
||||
mlirAttributeGetType(denseElements)),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
|
@ -98,7 +98,7 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
|
|||
// represented as one of double or int64_t, with a special tag for whether
|
||||
// it should be interpreted as a bool.
|
||||
if (s.isIntegral(/*includeBool=*/false)) {
|
||||
MlirType t = npcompTorchIntTypeGet(context);
|
||||
MlirType t = torchMlirTorchIntTypeGet(context);
|
||||
MlirAttribute value =
|
||||
mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), s.to<int64_t>());
|
||||
MlirOperation op = createMlirOperation(
|
||||
|
@ -107,7 +107,7 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
|
|||
return mlirOperationGetResult(op, 0);
|
||||
}
|
||||
if (s.isFloatingPoint()) {
|
||||
MlirType t = npcompTorchFloatTypeGet(context);
|
||||
MlirType t = torchMlirTorchFloatTypeGet(context);
|
||||
MlirAttribute value = mlirFloatAttrDoubleGet(
|
||||
context, mlirF64TypeGet(context), s.to<double>());
|
||||
MlirOperation op = createMlirOperation(
|
||||
|
@ -133,7 +133,7 @@ MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) {
|
|||
|
||||
MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
|
||||
std::vector<MlirValue> &elements) {
|
||||
MlirType resultType = npcompTorchListTypeGet(elementType);
|
||||
MlirType resultType = torchMlirTorchListTypeGet(elementType);
|
||||
OperationStateHolder state{"torch.prim.ListConstruct", loc};
|
||||
mlirOperationStateAddResults(state, 1, &resultType);
|
||||
mlirOperationStateAddOperands(state, elements.size(), elements.data());
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
#include "caffe2/core/scope_guard.h"
|
||||
#include "ATen/native/quantized/cpu/packed_params.h"
|
||||
|
@ -170,7 +170,7 @@ IValueImporter::importModule(torch::jit::Module currentModule) {
|
|||
|
||||
MlirOperation nnModule = createMlirOperation(
|
||||
"torch.nn_module", loc,
|
||||
npcompTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||
mlirRegionCreate());
|
||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
||||
|
@ -240,7 +240,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
|
||||
if (ivalue.isBool()) {
|
||||
MlirType type = npcompTorchBoolTypeGet(context);
|
||||
MlirType type = torchMlirTorchBoolTypeGet(context);
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.constant.bool", loc, type,
|
||||
toMlirNamedAttribute("value",
|
||||
|
@ -248,7 +248,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isDouble()) {
|
||||
MlirType type = npcompTorchFloatTypeGet(context);
|
||||
MlirType type = torchMlirTorchFloatTypeGet(context);
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.constant.float", loc, type,
|
||||
toMlirNamedAttribute(
|
||||
|
@ -257,7 +257,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isInt()) {
|
||||
MlirType type = npcompTorchIntTypeGet(context);
|
||||
MlirType type = torchMlirTorchIntTypeGet(context);
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.constant.int", loc, type,
|
||||
toMlirNamedAttribute("value",
|
||||
|
@ -273,7 +273,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
}
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.prim.ListConstruct", loc,
|
||||
npcompTorchListTypeGet(
|
||||
torchMlirTorchListTypeGet(
|
||||
typeMapper.mapFromTorchType(loc, list.elementType())),
|
||||
elems);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
|
@ -288,7 +288,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
}
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.prim.DictConstruct", loc,
|
||||
npcompTorchDictTypeGet(
|
||||
torchMlirTorchDictTypeGet(
|
||||
typeMapper.mapFromTorchType(loc, dict.keyType()),
|
||||
typeMapper.mapFromTorchType(loc, dict.valueType())),
|
||||
keys, values);
|
||||
|
@ -305,7 +305,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
}
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.prim.TupleConstruct", loc,
|
||||
npcompTorchTupleTypeGet(context, types.size(), types.data()), operands);
|
||||
torchMlirTorchTupleTypeGet(context, types.size(), types.data()),
|
||||
operands);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isTensor()) {
|
||||
|
@ -317,7 +318,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
if (ivalue.isString()) {
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.constant.str", loc,
|
||||
npcompTorchStringTypeGet(context),
|
||||
torchMlirTorchStringTypeGet(context),
|
||||
toMlirNamedAttribute(
|
||||
"value",
|
||||
mlirStringAttrGet(context,
|
||||
|
@ -327,7 +328,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
if (ivalue.isNone()) {
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(importBlock, "torch.constant.none", loc,
|
||||
npcompTorchNoneTypeGet(context));
|
||||
torchMlirTorchNoneTypeGet(context));
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isCustomClass()) {
|
||||
|
@ -346,7 +347,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
}
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.linear_params.create", loc,
|
||||
npcompTorchLinearParamsTypeGet(context), weightValue, biasValue);
|
||||
torchMlirTorchLinearParamsTypeGet(context), weightValue, biasValue);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
}
|
||||
|
@ -366,7 +367,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||
MlirOperation tensorOp =
|
||||
createMlirOperationAtEnd(importBlock, "torch.tensor.literal", loc,
|
||||
npcompTorchNonValueTensorTypeGetFromShaped(
|
||||
torchMlirTorchNonValueTensorTypeGetFromShaped(
|
||||
mlirAttributeGetType(denseElements)),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||
|
@ -381,7 +382,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
// compiler stages that are building a statically modeled quantization
|
||||
// representation will need to convert this to their representation.
|
||||
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
||||
MlirType quantizedTensorType = npcompTorchNonValueTensorTypeGet(
|
||||
MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
|
||||
context, shape.size(), shape.data(),
|
||||
typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
|
||||
if (tensor.qscheme() == c10::kPerTensorAffine) {
|
||||
|
@ -531,11 +532,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
|||
int64_t dummy;
|
||||
int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
|
||||
if (hasValueSemantics) {
|
||||
typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
|
||||
shapeData, dtype);
|
||||
} else {
|
||||
typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
|
||||
typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(),
|
||||
shapeData, dtype);
|
||||
} else {
|
||||
typeBound = torchMlirTorchNonValueTensorTypeGet(
|
||||
context, shape.size(), shapeData, dtype);
|
||||
}
|
||||
|
||||
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir-c/Diagnostics.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "npcomp-c/Registration.h"
|
||||
#include "torch-mlir-c/Registration.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace torch_mlir;
|
||||
|
@ -114,7 +115,7 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
|||
// TODO: Rework this once dialect registration C-APIs are in place.
|
||||
// https://reviews.llvm.org/D88162
|
||||
mlirRegisterAllDialects(context);
|
||||
npcompRegisterAllDialects(context);
|
||||
torchMlirRegisterAllDialects(context);
|
||||
|
||||
registerPythonSysStderrDiagnosticHandler(context);
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace torch_mlir;
|
||||
|
@ -150,7 +150,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
importAttribute(loc, node, c10::attr::value)));
|
||||
} else if (output->type()->cast<c10::StringType>()) {
|
||||
op = createMlirOperation(
|
||||
"torch.constant.str", loc, npcompTorchStringTypeGet(context),
|
||||
"torch.constant.str", loc, torchMlirTorchStringTypeGet(context),
|
||||
toMlirNamedAttribute(
|
||||
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
||||
c10::attr::value)))));
|
||||
|
@ -186,7 +186,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
mlirRegionCreate());
|
||||
mapResults(node, operation);
|
||||
std::vector<MlirType> terminatorOperandTypes = {
|
||||
npcompTorchBoolTypeGet(context)};
|
||||
torchMlirTorchBoolTypeGet(context)};
|
||||
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
|
||||
resultTypes.begin(), resultTypes.end());
|
||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
|
@ -18,11 +18,11 @@ OpBuilder::OpBuilder(MlirContext context) : context(context) {}
|
|||
|
||||
MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
|
||||
return createMlirOperation("torch.constant.none", loc,
|
||||
npcompTorchNoneTypeGet(context));
|
||||
torchMlirTorchNoneTypeGet(context));
|
||||
}
|
||||
|
||||
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
|
||||
return createMlirOperation(
|
||||
"torch.constant.bool", loc, npcompTorchBoolTypeGet(context),
|
||||
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
|
||||
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
|
@ -66,7 +66,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
|
|||
case ScalarType::Half:
|
||||
return mlirF16TypeGet(context);
|
||||
case ScalarType::QInt8:
|
||||
return npcompTorchQInt8TypeGet(context);
|
||||
return torchMlirTorchQInt8TypeGet(context);
|
||||
default: {
|
||||
return {nullptr};
|
||||
}
|
||||
|
@ -105,7 +105,7 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
|
|||
|
||||
// Individually handle the custom classes that we know about.
|
||||
if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
|
||||
return npcompTorchLinearParamsTypeGet(context);
|
||||
return torchMlirTorchLinearParamsTypeGet(context);
|
||||
}
|
||||
|
||||
// At this point, we know that the type is indeed a custom class type, but
|
||||
|
@ -136,11 +136,11 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
auto &sizes = tensorType->symbolic_sizes();
|
||||
if (!sizes.rank()) {
|
||||
// Unranked.
|
||||
return npcompTorchNonValueTensorTypeGet(context,
|
||||
/*numSizes=*/0,
|
||||
/*optionalSizes=*/nullptr,
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
return torchMlirTorchNonValueTensorTypeGet(context,
|
||||
/*numSizes=*/0,
|
||||
/*optionalSizes=*/nullptr,
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
}
|
||||
// Ranked with possibly dynamic dims.
|
||||
auto &symbolicShape = tensorType->symbolic_sizes();
|
||||
|
@ -150,28 +150,28 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
auto shapeSymbol = symbolicShape[i];
|
||||
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
||||
}
|
||||
return npcompTorchNonValueTensorTypeGet(context, dims.size(),
|
||||
/*optionalSizes=*/dims.data(),
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
return torchMlirTorchNonValueTensorTypeGet(context, dims.size(),
|
||||
/*optionalSizes=*/dims.data(),
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
}
|
||||
case TypeKind::IntType: {
|
||||
return npcompTorchIntTypeGet(context);
|
||||
return torchMlirTorchIntTypeGet(context);
|
||||
}
|
||||
case TypeKind::FloatType: {
|
||||
return npcompTorchFloatTypeGet(context);
|
||||
return torchMlirTorchFloatTypeGet(context);
|
||||
}
|
||||
case TypeKind::BoolType: {
|
||||
return npcompTorchBoolTypeGet(context);
|
||||
return torchMlirTorchBoolTypeGet(context);
|
||||
}
|
||||
case TypeKind::NumberType: {
|
||||
return npcompTorchNumberTypeGet(context);
|
||||
return torchMlirTorchNumberTypeGet(context);
|
||||
}
|
||||
case TypeKind::StringType: {
|
||||
return npcompTorchStringTypeGet(context);
|
||||
return torchMlirTorchStringTypeGet(context);
|
||||
}
|
||||
case TypeKind::OptionalType: {
|
||||
return npcompTorchOptionalTypeGet(mapFromTorchType(
|
||||
return torchMlirTorchOptionalTypeGet(mapFromTorchType(
|
||||
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
||||
}
|
||||
case TypeKind::TupleType: {
|
||||
|
@ -180,25 +180,25 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
torchType->cast<c10::TupleType>()->containedTypes()) {
|
||||
containedTypes.push_back(mapFromTorchType(loc, type));
|
||||
}
|
||||
return npcompTorchTupleTypeGet(context, containedTypes.size(),
|
||||
containedTypes.data());
|
||||
return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
|
||||
containedTypes.data());
|
||||
}
|
||||
case TypeKind::ListType: {
|
||||
return npcompTorchListTypeGet(mapFromTorchType(
|
||||
return torchMlirTorchListTypeGet(mapFromTorchType(
|
||||
loc, torchType->cast<c10::ListType>()->getElementType()));
|
||||
}
|
||||
case TypeKind::DictType: {
|
||||
auto dictType = torchType->cast<c10::DictType>();
|
||||
return npcompTorchDictTypeGet(
|
||||
return torchMlirTorchDictTypeGet(
|
||||
mapFromTorchType(loc, dictType->getKeyType()),
|
||||
mapFromTorchType(loc, dictType->getValueType()));
|
||||
}
|
||||
case TypeKind::NoneType: {
|
||||
return npcompTorchNoneTypeGet(context);
|
||||
return torchMlirTorchNoneTypeGet(context);
|
||||
}
|
||||
case TypeKind::AnyType: {
|
||||
auto anyType = torchType->cast<c10::AnyType>();
|
||||
return npcompTorchAnyTypeGet(context);
|
||||
return torchMlirTorchAnyTypeGet(context);
|
||||
}
|
||||
case TypeKind::ClassType: {
|
||||
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
|
||||
|
@ -208,10 +208,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
}
|
||||
auto maybeName = classType->name();
|
||||
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
|
||||
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
|
||||
return torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(name));
|
||||
}
|
||||
case TypeKind::DeviceObjType: {
|
||||
return npcompTorchDeviceTypeGet(context);
|
||||
return torchMlirTorchDeviceTypeGet(context);
|
||||
}
|
||||
default: {
|
||||
std::stringstream message;
|
||||
|
@ -226,7 +226,7 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
|||
if (!tensor.defined()) {
|
||||
// Undefined tensors are equivalent to None.
|
||||
// This may need to be re-evaluated at some point.
|
||||
return npcompTorchNoneTypeGet(context);
|
||||
return torchMlirTorchNoneTypeGet(context);
|
||||
}
|
||||
|
||||
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
|
||||
|
@ -234,8 +234,8 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
|||
// just erase them and let the compiler decide.
|
||||
|
||||
auto sizes = tensor.sizes();
|
||||
return npcompTorchNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
|
||||
elementType);
|
||||
return torchMlirTorchNonValueTensorTypeGet(context, sizes.size(),
|
||||
sizes.data(), elementType);
|
||||
}
|
||||
|
||||
MlirType
|
||||
|
|
|
@ -2,5 +2,4 @@ add_subdirectory(Basicpy)
|
|||
add_subdirectory(Numpy)
|
||||
add_subdirectory(Refback)
|
||||
add_subdirectory(Refbackrt)
|
||||
add_subdirectory(Torch)
|
||||
add_subdirectory(TorchConversion)
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "npcomp/Interfaces/Traits.h"
|
||||
#include "npcomp/Typing/Analysis/CPA/Interfaces.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS
|
||||
|
||||
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
|
||||
include "npcomp/Interfaces/Traits.td"
|
||||
include "npcomp/Typing/Analysis/CPA/Interfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
|
@ -39,7 +38,6 @@ def Numpy_NarrowOp : Numpy_Op<"narrow", []> {
|
|||
|
||||
def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
AllowsTypeRefinement,
|
||||
NoSideEffect]> {
|
||||
let summary = "Adds/removes static information from an array type.";
|
||||
let description = [{
|
||||
|
@ -60,7 +58,6 @@ def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
|
|||
|
||||
def Numpy_TensorStaticInfoCastOp : Numpy_Op<"tensor_static_info_cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
AllowsTypeRefinement,
|
||||
NoSideEffect]> {
|
||||
let summary = "Adds/removes static information from a tensor type.";
|
||||
let description = [{
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(NPCOMPTorchPassIncGen)
|
||||
|
||||
add_mlir_doc(Passes NPCOMPTorchTransforms ./ -gen-pass-doc)
|
|
@ -16,8 +16,8 @@
|
|||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h.inc"
|
||||
|
|
|
@ -14,7 +14,7 @@ include "mlir/IR/SymbolInterfaces.td"
|
|||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td"
|
||||
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
|
||||
include "iree-dialects/Dialect/IREE/IREEDialect.td"
|
||||
|
||||
class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
|
@ -21,7 +21,8 @@ namespace TorchConversion {
|
|||
/// Creates a pipeline that lowers the object graph IR that is produced by
|
||||
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
|
||||
void createTorchScriptToNpcompBackendPipeline(
|
||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options);
|
||||
OpPassManager &pm,
|
||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createVerifyInvariantsBeforeBackendLoweringPass();
|
||||
|
|
|
@ -13,21 +13,7 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace OpTrait {
|
||||
|
||||
template <typename ConcreteType>
|
||||
class AllowsTypeRefinement
|
||||
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
|
||||
|
||||
} // namespace OpTrait
|
||||
|
||||
// Check if an operation has the AllowsTypeRefinement trait.
|
||||
//
|
||||
// This function should be used in preference to
|
||||
// `op->hasTrait<AllowsTypeRefinement>()` because this function has knowledge of
|
||||
// some upstream ops that have this property, but which we cannot annotate with
|
||||
// this trait.
|
||||
bool allowsTypeRefinement(Operation *op);
|
||||
namespace OpTrait {} // namespace OpTrait
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
|
|
@ -19,13 +19,6 @@ class NpcompOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
|
|||
let cppNamespace = "::mlir::NPCOMP::OpTrait";
|
||||
}
|
||||
|
||||
// Op allows operand and result types to be refined.
|
||||
// For example a `tensor<?xf32>` can be refined to `tensor<4xf32>`.
|
||||
//
|
||||
// TODO: Implement RefinableTypeInterface that allows actually modeling
|
||||
// which types are refinements of other types.
|
||||
// See the design in:
|
||||
// https://llvm.discourse.group/t/allow-shape-concretization-or-type-concretization-in-rewrites/3327/3
|
||||
def AllowsTypeRefinement : NpcompOpTrait<"AllowsTypeRefinement">;
|
||||
// Empty for now. Kept as boilerplate placeholder.
|
||||
|
||||
#endif // NPCOMP_INTERFACES_TRAITS
|
||||
|
|
|
@ -10,7 +10,6 @@ add_npcomp_library(NPCOMPCAPI
|
|||
Registration.cpp
|
||||
BasicpyTypes.cpp
|
||||
NumpyTypes.cpp
|
||||
TorchTypes.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRExecutionEngine
|
||||
|
@ -21,7 +20,8 @@ add_npcomp_library(NPCOMPCAPI
|
|||
NPCOMPNumpyDialect
|
||||
NPCOMPRefBackendJITHelpers
|
||||
NPCOMPRuntime
|
||||
NPCOMPTorchDialect
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRInitAll
|
||||
|
||||
# MLIR CAPI deps
|
||||
MLIRCAPIIR
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "npcomp/InitAll.h"
|
||||
#include "torch-mlir/InitAll.h"
|
||||
|
||||
void npcompRegisterAllDialects(MlirContext context) {
|
||||
mlir::DialectRegistry registry;
|
||||
|
|
|
@ -31,7 +31,7 @@ add_npcomp_library(NPCOMPInitAll
|
|||
NPCOMPIREEBackend
|
||||
NPCOMPRefBackend
|
||||
NPCOMPRefbackDialect
|
||||
NPCOMPTorchDialect
|
||||
TorchMLIRTorchDialect
|
||||
NPCOMPTorchConversionDialect
|
||||
NPCOMPRefbackrtDialect
|
||||
NPCOMPBasicpyDialect
|
||||
|
|
|
@ -13,7 +13,7 @@ add_npcomp_conversion_library(NPCOMPTorchToIREE
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
NPCOMPTorchDialect
|
||||
TorchMLIRTorchDialect
|
||||
MLIRStandard
|
||||
IREEDialectsIREEDialect
|
||||
)
|
||||
|
|
|
@ -14,13 +14,13 @@
|
|||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The patterns
|
||||
|
|
|
@ -15,5 +15,5 @@ add_npcomp_conversion_library(NPCOMPTorchToLinalg
|
|||
MLIRPass
|
||||
MLIRLinalg
|
||||
MLIRMath
|
||||
NPCOMPTorchDialect
|
||||
TorchMLIRTorchDialect
|
||||
)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue