[torch-mlir earthmoving (1/N)] C/C++ code movement.

This creates the `external/torch-mlir` directory as an
LLVM_EXTERNAL_PROJECTS-compatible project (analogous to
`iree-dialects`) and completes movement/rename of all pure MLIR C/C++
compiler code into there. The next step will be to move all the Python
code / code that links/includes PyTorch C++ code (which currently lives
in `frontends/pytorch`) into a subdirectory here.

I call this "earthmoving" because it is mostly mechanical changes and
renames. As a quick summary (we can change this down the road easily)
- C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch`
- CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet`
- preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_`
- CMake `NPCOMPFoo -> TorchMLIRFoo`

The goal of this is to create a standalone project creating a center of
mass for entry into the MLIR ecosystem from PyTorch, suitable in scope
for eventual inclusion/ownership in PyTorch. The idea is that
`external/torch-mlir` will some day be pulled out into its own
repository, and then npcomp will simply pull it in as a submodule.

Layering-wise, what lives in `torch-mlir` lowers code from PyTorch
(currently TorchScript, but TorchFX or pytorch/xla-style tracing are
possible extensions) down to what we have been calling the "Torch
backend contract" which is cleaned up IR (inlining, simplification,
conversion to value tensors, ...) entirely in the `torch` dialect. This
is the branching off point for further lowering, of which npcomp takes
one opinion (outside `torch-mlir` of course!), namely the
`TorchConversion` dialect/transforms which lower to IR suitable for IREE
and other linalg-on-tensors based lower-level compilers.

Summary of changes:
- move `{include,lib,test}/Dialect/Torch` into `torch-mlir`
- move relevant parts of CAPI into `torch-mlir`.
- leave a few things related to the `torch-mlir` Python build commented
  out, which should be resolved in a subsequent change.
pull/305/head
Sean Silva 2021-09-09 19:24:10 +00:00
parent 28762699b3
commit 28a7738189
122 changed files with 826 additions and 436 deletions

View File

@ -128,9 +128,10 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
"suffix = '${PYTHON_MODULE_SUFFIX}', " "suffix = '${PYTHON_MODULE_SUFFIX}', "
"extension = '${PYTHON_MODULE_EXTENSION}") "extension = '${PYTHON_MODULE_EXTENSION}")
# Include the iree-dialects external project. # Include LLVM_EXTERNAL_PROJECTS.
set(LLVM_EXTERNAL_PROJECTS "iree-dialects") set(LLVM_EXTERNAL_PROJECTS "iree-dialects;torch-mlir")
set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects") set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects")
set(LLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir")
# LLVM configuration. # LLVM configuration.
message(STATUS "*** ADDING LLVM ***") message(STATUS "*** ADDING LLVM ***")
@ -183,6 +184,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/torch-mlir/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR}) link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS}) add_definitions(${LLVM_DEFINITIONS})
set(NPCOMP_TABLEGEN_ARGS "") set(NPCOMP_TABLEGEN_ARGS "")

View File

@ -21,6 +21,7 @@ cd $td/build
ninja ninja
ninja check-npcomp ninja check-npcomp
ninja check-torch-mlir
ninja check-frontends-pytorch ninja check-frontends-pytorch
echo echo

View File

@ -0,0 +1,72 @@
# Guard: torch-mlir is not buildable standalone; it must be configured as part
# of an LLVM build via the LLVM_EXTERNAL_PROJECTS mechanism shown below.
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
message(FATAL_ERROR
"This project is intended to be built as part of LLVM via "
"-DLLVM_EXTERNAL_PROJECTS=torch-mlir "
"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR}")
endif()
# Python bindings are opt-in; when ON, Python3 and pybind11 are located below.
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TORCH_MLIR_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS "Building torch-mlir project at ${TORCH_MLIR_SOURCE_DIR} (into ${TORCH_MLIR_BINARY_DIR})")
# MLIR source/generated-header locations are derived relative to the LLVM tree,
# since this project is configured from within the LLVM build.
# TODO: Fix this upstream so that global include directories are not needed.
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
# TODO: Needed for tablegen. Remove.
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR})
include_directories(SYSTEM ${TORCH_MLIR_SOURCE_DIR}/include)
# Attaches the MLIR and torch-mlir include directories (build-interface only)
# to `target`, and privately to its obj.<target> companion when one exists.
function(torch_mlir_target_includes target)
set(_dirs
$<BUILD_INTERFACE:${MLIR_INCLUDE_DIR}>
$<BUILD_INTERFACE:${MLIR_GENERATED_INCLUDE_DIR}>
$<BUILD_INTERFACE:${TORCH_MLIR_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${TORCH_MLIR_BINARY_DIR}/include>
)
# In LLVM parlance, the actual target may just be an interface and may not
# be responsible for actually compiling anything. The corresponding obj.
# target, when present, is just used for compilation and does not
# contribute to the interface properties.
# TODO: Normalize this upstream.
target_include_directories(${target} PUBLIC ${_dirs})
if(TARGET obj.${target})
target_include_directories(obj.${target} PRIVATE ${_dirs})
endif()
endfunction()
# Configure CMake and tablegen.
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
set(MLIR_TABLEGEN_EXE mlir-tblgen)
include(TableGen)
include(AddLLVM)
include(AddMLIR)
################################################################################
# Setup python.
# TODO: Make one upstream macro to do this.
################################################################################
if(MLIR_ENABLE_BINDINGS_PYTHON)
include(MLIRDetectPythonEnv)
mlir_detect_pybind11_install()
find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION}
COMPONENTS Interpreter Development NumPy REQUIRED)
find_package(pybind11 2.6 CONFIG REQUIRED)
endif()
# Subdirectories: tablegen'd headers, compiler libraries, lit tests, and tools.
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
add_subdirectory(tools)
if(MLIR_ENABLE_BINDINGS_PYTHON)
#XXX: Enable
#add_subdirectory(python)
endif()

View File

@ -0,0 +1 @@
add_subdirectory(torch-mlir)

View File

@ -0,0 +1,32 @@
/*===-- torch-mlir-c/Registration.h - Registration functions -----*- C -*-===*\
|* *|
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|* Exceptions. *|
|* See https://llvm.org/LICENSE.txt for license information. *|
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|* *|
\*===----------------------------------------------------------------------===*/
#ifndef TORCHMLIR_C_REGISTRATION_H
#define TORCHMLIR_C_REGISTRATION_H
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
/* Declarations below use C linkage so this header is usable from both C and
 * C++ translation units. */
#ifdef __cplusplus
extern "C" {
#endif
/** Registers all dialects with a context.
 * This is needed before creating IR for these Dialects.
 */
MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context);
/** Registers all passes for symbolic access with the global registry. */
MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses();
#ifdef __cplusplus
}
#endif
#endif // TORCHMLIR_C_REGISTRATION_H

View File

@ -1,4 +1,4 @@
//===-- npcomp-c/TorchTypes.h - C API for torch types -------------*- C -*-===// //===-- torch-mlir-c/TorchTypes.h - C API for torch types ---------*- C -*-===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM // Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. // Exceptions.
@ -7,8 +7,8 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef NPCOMP_C_TORCHTYPES_H #ifndef TORCHMLIR_C_TORCHTYPES_H
#define NPCOMP_C_TORCHTYPES_H #define TORCHMLIR_C_TORCHTYPES_H
#include "mlir-c/IR.h" #include "mlir-c/IR.h"
#include "mlir-c/Support.h" #include "mlir-c/Support.h"
@ -22,110 +22,112 @@ extern "C" {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a torch.nn.Module type /// Checks whether the given type is a torch.nn.Module type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNnModule(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t);
/// Gets the !torch.nn.Module type of the specified class. /// Gets the !torch.nn.Module type of the specified class.
MLIR_CAPI_EXPORTED MlirType npcompTorchNnModuleTypeGet(MlirContext context, MLIR_CAPI_EXPORTED MlirType
MlirStringRef className); torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.optional type. // torch.optional type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.optional<T> type /// Checks whether the given type is a !torch.optional<T> type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchOptional(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t);
/// Gets the !torch.optional<T> type with subtype T. /// Gets the !torch.optional<T> type with subtype T.
MLIR_CAPI_EXPORTED MlirType npcompTorchOptionalTypeGet(MlirType containedType); MLIR_CAPI_EXPORTED MlirType
torchMlirTorchOptionalTypeGet(MlirType containedType);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type. // torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.tuple type /// Checks whether the given type is a !torch.tuple type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchTuple(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchTuple(MlirType t);
/// Gets the !torch.tuple type with contained types `containedTypes`. /// Gets the !torch.tuple type with contained types `containedTypes`.
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
npcompTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes, torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
MlirType const *containedTypes); MlirType const *containedTypes);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.list<T> type. // torch.list<T> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.list<T> type /// Checks whether the given type is a !torch.list<T> type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchList(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t);
/// Gets the !torch.list<T> type with contained T. /// Gets the !torch.list<T> type with contained T.
MLIR_CAPI_EXPORTED MlirType npcompTorchListTypeGet(MlirType containedType); MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.Device type. // torch.Device type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.Device type /// Checks whether the given type is a !torch.Device type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDevice(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
/// Gets the !torch.Device type. /// Gets the !torch.Device type.
MLIR_CAPI_EXPORTED MlirType npcompTorchDeviceTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.bool type. // torch.bool type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.bool type /// Checks whether the given type is a !torch.bool type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchBool(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t);
/// Gets the !torch.bool type. /// Gets the !torch.bool type.
MLIR_CAPI_EXPORTED MlirType npcompTorchBoolTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.int type. // torch.int type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.int type /// Checks whether the given type is a !torch.int type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchInt(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t);
/// Gets the !torch.int type. /// Gets the !torch.int type.
MLIR_CAPI_EXPORTED MlirType npcompTorchIntTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.float type. // torch.float type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.float type /// Checks whether the given type is a !torch.float type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchFloat(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t);
/// Gets the !torch.float type. /// Gets the !torch.float type.
MLIR_CAPI_EXPORTED MlirType npcompTorchFloatTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.LinearParams type. // torch.LinearParams type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.LinearParams type /// Checks whether the given type is a !torch.LinearParams type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchLinearParams(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t);
/// Gets the !torch.LinearParams type. /// Gets the !torch.LinearParams type.
MLIR_CAPI_EXPORTED MlirType npcompTorchLinearParamsTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType
torchMlirTorchLinearParamsTypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.qint8 type. // torch.qint8 type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.qint8 type /// Checks whether the given type is a !torch.qint8 type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchQInt8(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
/// Gets the !torch.qint8 type. /// Gets the !torch.qint8 type.
MLIR_CAPI_EXPORTED MlirType npcompTorchQInt8TypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.tensor type. // torch.tensor type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.tensor type /// Checks whether the given type is a !torch.tensor type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNonValueTensor(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);
/// Gets a !torch.tensor type. /// Gets a !torch.tensor type.
/// ///
@ -133,24 +135,25 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNonValueTensor(MlirType t);
/// information is present (and `numSizes` is ignored in that case). - /// information is present (and `numSizes` is ignored in that case). -
/// `optionalDtype` is allowed to be null, meaning that no dtype /// `optionalDtype` is allowed to be null, meaning that no dtype
/// information is present. /// information is present.
MLIR_CAPI_EXPORTED MlirType npcompTorchNonValueTensorTypeGet( MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGet(
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes, MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
MlirType optionalDtype); MlirType optionalDtype);
/// Gets the !torch.tensor type with the least static information. /// Gets the !torch.tensor type with the least static information.
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context); torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
MlirContext context);
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. /// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
npcompTorchNonValueTensorTypeGetFromShaped(MlirType type); torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.vtensor type. // torch.vtensor type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.vtensor type /// Checks whether the given type is a !torch.vtensor type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchValueTensor(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);
/// Gets a !torch.vtensor type. /// Gets a !torch.vtensor type.
/// ///
@ -158,71 +161,71 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchValueTensor(MlirType t);
/// information is present (and `numSizes` is ignored in that case). /// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype /// - `optionalDtype` is allowed to be null, meaning that no dtype
/// information is present. /// information is present.
MLIR_CAPI_EXPORTED MlirType npcompTorchValueTensorTypeGet( MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes, MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
MlirType optionalDtype); MlirType optionalDtype);
/// Gets the !torch.tensor type with the least static information. /// Gets the !torch.tensor type with the least static information.
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context); torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. /// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
npcompTorchValueTensorTypeGetFromShaped(MlirType type); torchMlirTorchValueTensorTypeGetFromShaped(MlirType type);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.none type. // !torch.none type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.none type /// Checks whether the given type is a !torch.none type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNone(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t);
/// Gets the !torch.none type. /// Gets the !torch.none type.
MLIR_CAPI_EXPORTED MlirType npcompTorchNoneTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.str type. // !torch.str type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.str type /// Checks whether the given type is a !torch.str type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchString(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t);
/// Gets the !torch.str type. /// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType npcompTorchStringTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.any type. // !torch.any type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.any type. /// Checks whether the given type is a !torch.any type.
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchAny(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t);
/// Gets the !torch.str type. /// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType npcompTorchAnyTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.number type. // !torch.number type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.number type. /// Checks whether the given type is a !torch.number type.
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNumber(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t);
/// Gets the !torch.number type. /// Gets the !torch.number type.
MLIR_CAPI_EXPORTED MlirType npcompTorchNumberTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.dict type. // !torch.dict type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.dict type. /// Checks whether the given type is a !torch.dict type.
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDict(MlirType t); MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t);
/// Gets the !torch.dict type. /// Gets the !torch.dict type.
MLIR_CAPI_EXPORTED MlirType npcompTorchDictTypeGet(MlirType keyType, MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType,
MlirType valueType); MlirType valueType);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
#endif // NPCOMP_C_TORCHTYPES_H #endif // TORCHMLIR_C_TORCHTYPES_H

View File

@ -0,0 +1 @@
add_subdirectory(Dialect)

View File

@ -0,0 +1 @@
add_subdirectory(Torch)

View File

@ -13,7 +13,7 @@ include "mlir/IR/OpBase.td"
def Torch_Dialect : Dialect { def Torch_Dialect : Dialect {
let name = "torch"; let name = "torch";
let cppNamespace = "::mlir::NPCOMP::Torch"; let cppNamespace = "::mlir::torch::Torch";
let description = [{ let description = [{
Top-level dialect for interfacing PyTorch and MLIR. Top-level dialect for interfacing PyTorch and MLIR.
@ -21,7 +21,7 @@ def Torch_Dialect : Dialect {
This dialect also provides transforms that lower it to the This dialect also provides transforms that lower it to the
"Torch backend contract", which is an IR form that we present to "Torch backend contract", which is an IR form that we present to
later conversions, such as conversion to the npcomp backend contract. later conversions.
The Torch backend contract significantly simplifies the IR representation The Torch backend contract significantly simplifies the IR representation
and puts it in a form easier for later lowering to work on. Specifically: and puts it in a form easier for later lowering to work on. Specifically:
- The TorchScript object graph has been flattened to a list of globals (see - The TorchScript object graph has been flattened to a list of globals (see
@ -39,11 +39,12 @@ def Torch_Dialect : Dialect {
class TorchOpTrait<string name> : OpTrait, NativeTrait<"", ""> { class TorchOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
let trait = name; let trait = name;
let cppNamespace = "::mlir::NPCOMP::Torch::OpTrait"; let cppNamespace = "::mlir::torch::Torch::OpTrait";
} }
def HasValueSemantics : TorchOpTrait<"HasValueSemantics">; def HasValueSemantics : TorchOpTrait<"HasValueSemantics">;
def IsTrailingUnderscoreInplaceVariant def IsTrailingUnderscoreInplaceVariant
: TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">; : TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">;
#endif // TORCH_BASE #endif // TORCH_BASE

View File

@ -6,11 +6,11 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H #ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H #define TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h.inc" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h.inc"
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H #endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H

View File

@ -6,8 +6,8 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H #ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H #define TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
@ -18,15 +18,14 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTraits.h" #include "torch-mlir/Dialect/Torch/IR/TorchTraits.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "npcomp/Interfaces/Traits.h"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h.inc"
namespace mlir { namespace mlir {
namespace NPCOMP { namespace torch {
namespace Torch { namespace Torch {
namespace detail { namespace detail {
@ -117,11 +116,11 @@ m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType, Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
Value tensor); Value tensor);
} // namespace Torch } // namespace Torch
} // namespace NPCOMP } // namespace torch
} // namespace mlir } // namespace mlir
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> { template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::SlotOp> {
using SlotOp = ::mlir::NPCOMP::Torch::SlotOp; using SlotOp = ::mlir::torch::Torch::SlotOp;
static SlotOp getEmptyKey() { static SlotOp getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return SlotOp::getFromOpaquePointer(pointer); return SlotOp::getFromOpaquePointer(pointer);
@ -136,8 +135,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
static bool isEqual(SlotOp lhs, SlotOp rhs) { return lhs == rhs; } static bool isEqual(SlotOp lhs, SlotOp rhs) { return lhs == rhs; }
}; };
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> { template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::NnModuleOp> {
using NnModuleOp = ::mlir::NPCOMP::Torch::NnModuleOp; using NnModuleOp = ::mlir::torch::Torch::NnModuleOp;
static NnModuleOp getEmptyKey() { static NnModuleOp getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return NnModuleOp::getFromOpaquePointer(pointer); return NnModuleOp::getFromOpaquePointer(pointer);
@ -152,8 +151,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> {
static bool isEqual(NnModuleOp lhs, NnModuleOp rhs) { return lhs == rhs; } static bool isEqual(NnModuleOp lhs, NnModuleOp rhs) { return lhs == rhs; }
}; };
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> { template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::ClassTypeOp> {
using ClassTypeOp = ::mlir::NPCOMP::Torch::ClassTypeOp; using ClassTypeOp = ::mlir::torch::Torch::ClassTypeOp;
static ClassTypeOp getEmptyKey() { static ClassTypeOp getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return ClassTypeOp::getFromOpaquePointer(pointer); return ClassTypeOp::getFromOpaquePointer(pointer);
@ -168,8 +167,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> {
static bool isEqual(ClassTypeOp lhs, ClassTypeOp rhs) { return lhs == rhs; } static bool isEqual(ClassTypeOp lhs, ClassTypeOp rhs) { return lhs == rhs; }
}; };
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::GlobalSlotOp> { template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::GlobalSlotOp> {
using OpTy = ::mlir::NPCOMP::Torch::GlobalSlotOp; using OpTy = ::mlir::torch::Torch::GlobalSlotOp;
static OpTy getEmptyKey() { static OpTy getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return OpTy::getFromOpaquePointer(pointer); return OpTy::getFromOpaquePointer(pointer);
@ -184,4 +183,4 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::GlobalSlotOp> {
static bool isEqual(OpTy lhs, OpTy rhs) { return lhs == rhs; } static bool isEqual(OpTy lhs, OpTy rhs) { return lhs == rhs; }
}; };
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H #endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H

View File

@ -9,8 +9,7 @@
#ifndef TORCH_OPS #ifndef TORCH_OPS
#define TORCH_OPS #define TORCH_OPS
include "npcomp/Dialect/Torch/IR/TorchTypes.td" include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
include "npcomp/Interfaces/Traits.td"
include "mlir/IR/OpAsmInterface.td" include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/CastInterfaces.td"
@ -22,9 +21,9 @@ class Torch_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Torch_Dialect, mnemonic, traits> { : Op<Torch_Dialect, mnemonic, traits> {
} }
include "npcomp/Dialect/Torch/IR/GeneratedAtenOps.td" include "torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td"
include "npcomp/Dialect/Torch/IR/GeneratedPrimOps.td" include "torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td"
include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td" include "torch-mlir/Dialect/Torch/IR/GeneratedQuantizedOps.td"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TorchScript `torch.nn.Module` object instantiation ops. // TorchScript `torch.nn.Module` object instantiation ops.
@ -32,7 +31,7 @@ include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"
def Torch_NnModuleOp : Torch_Op<"nn_module", [ def Torch_NnModuleOp : Torch_Op<"nn_module", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>, DeclareOpInterfaceMethods<SymbolUserOpInterface>,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> { SingleBlockImplicitTerminator<"::mlir::torch::Torch::NnModuleTerminatorOp">]> {
let summary = "Constructs a torch.nn.Module"; let summary = "Constructs a torch.nn.Module";
let description = [{ let description = [{
This op is used to represent a torch.nn.Module when importing a This op is used to represent a torch.nn.Module when importing a
@ -75,7 +74,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
} }
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator, def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> { HasParent<"::mlir::torch::Torch::NnModuleOp">]> {
let summary = "Implicit terminator for torch.nn_module"; let summary = "Implicit terminator for torch.nn_module";
let arguments = (ins); let arguments = (ins);
@ -85,7 +84,7 @@ def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
} }
def Torch_SlotOp : Torch_Op<"slot", [ def Torch_SlotOp : Torch_Op<"slot", [
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> { HasParent<"::mlir::torch::Torch::NnModuleOp">]> {
let summary = "Define the value of a slot of a torch.nn.Module"; let summary = "Define the value of a slot of a torch.nn.Module";
let description = [{ let description = [{
This op specifies that the initial value of the slot `name` of the This op specifies that the initial value of the slot `name` of the
@ -107,7 +106,7 @@ def Torch_SlotOp : Torch_Op<"slot", [
def Torch_ClassTypeOp : Torch_Op<"class_type", [ def Torch_ClassTypeOp : Torch_Op<"class_type", [
Symbol, Symbol,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> { SingleBlockImplicitTerminator<"::mlir::torch::Torch::ClassTypeTerminatorOp">]> {
let summary = "Constructs a torch.ClassType"; let summary = "Constructs a torch.ClassType";
let description = [{ let description = [{
Declares a class type. Class types are the types used to describe Declares a class type. Class types are the types used to describe
@ -152,7 +151,7 @@ def Torch_ClassTypeOp : Torch_Op<"class_type", [
} }
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator, def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> { HasParent<"::mlir::torch::Torch::ClassTypeOp">]> {
let summary = "Implicit terminator for torch.class_type"; let summary = "Implicit terminator for torch.class_type";
let arguments = (ins); let arguments = (ins);
@ -162,7 +161,7 @@ def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
} }
def Torch_MethodOp : Torch_Op<"method", [ def Torch_MethodOp : Torch_Op<"method", [
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">, HasParent<"::mlir::torch::Torch::ClassTypeOp">,
DeclareOpInterfaceMethods<SymbolUserOpInterface> DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> { ]> {
let summary = "Declare a method of a torch.class_type"; let summary = "Declare a method of a torch.class_type";
@ -193,7 +192,7 @@ def Torch_MethodOp : Torch_Op<"method", [
} }
def Torch_AttrOp : Torch_Op<"attr", [ def Torch_AttrOp : Torch_Op<"attr", [
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp"> HasParent<"::mlir::torch::Torch::ClassTypeOp">
]> { ]> {
let summary = "Declare an attribute of a torch.class_type"; let summary = "Declare an attribute of a torch.class_type";
let description = [{ let description = [{
@ -231,7 +230,7 @@ def Torch_AttrOp : Torch_Op<"attr", [
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
Symbol, Symbol,
IsolatedFromAbove, IsolatedFromAbove,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::GlobalSlotInitOp"> SingleBlockImplicitTerminator<"::mlir::torch::Torch::GlobalSlotInitOp">
]> { ]> {
let summary = "A slot with global storage"; let summary = "A slot with global storage";
let description = [{ let description = [{
@ -256,7 +255,7 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [ def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
Terminator, Terminator,
HasParent<"::mlir::NPCOMP::Torch::GlobalSlotOp">]> { HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
let summary = "yield-like terminator for torch.global_slot initializer region"; let summary = "yield-like terminator for torch.global_slot initializer region";
let description = [{ let description = [{
The operand to this op becomes the initial value of the parent The operand to this op becomes the initial value of the parent
@ -463,7 +462,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [ def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>, DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
Terminator, Terminator,
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> { HasParent<"::mlir::torch::Torch::PrimLoopOp">]> {
let summary = "yield-like terminator for torch.prim.Loop"; let summary = "yield-like terminator for torch.prim.Loop";
let description = [{ let description = [{
Does not correspond to any torch prim op directly (the way that they model Does not correspond to any torch prim op directly (the way that they model
@ -512,7 +511,7 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [ def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
Terminator, Terminator,
ReturnLike, ReturnLike,
HasParent<"::mlir::NPCOMP::Torch::PrimIfOp">]> { HasParent<"::mlir::torch::Torch::PrimIfOp">]> {
let summary = "yield-like terminator for torch.prim.If"; let summary = "yield-like terminator for torch.prim.If";
let description = [{ let description = [{
Does not correspond to any torch prim op directly (the way that they model Does not correspond to any torch prim op directly (the way that they model

View File

@ -10,15 +10,14 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H #define TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
namespace mlir { namespace mlir {
namespace NPCOMP { namespace torch {
namespace Torch { namespace Torch {
namespace OpTrait { namespace OpTrait {
@ -39,9 +38,16 @@ class IsTrailingUnderscoreInplaceVariant
: public ::mlir::OpTrait::TraitBase<ConcreteType, : public ::mlir::OpTrait::TraitBase<ConcreteType,
IsTrailingUnderscoreInplaceVariant> {}; IsTrailingUnderscoreInplaceVariant> {};
// If a Torch op has this trait, it means that the op allows all of its operand
// and result types to be refined. That is, a less specific type is allowed to
// be replaced by a more specific type, according to PEP 483 subtyping rules.
template <typename ConcreteType>
class AllowsTypeRefinement
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
} // namespace OpTrait } // namespace OpTrait
} // namespace Torch } // namespace Torch
} // namespace NPCOMP } // namespace torch
} // namespace mlir } // namespace mlir
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H #endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H

View File

@ -6,13 +6,13 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H #ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H #define TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
namespace mlir { namespace mlir {
namespace NPCOMP { namespace torch {
namespace Torch { namespace Torch {
class NonValueTensorType; class NonValueTensorType;
@ -87,18 +87,18 @@ public:
ValueTensorType getWithValueSemantics() const; ValueTensorType getWithValueSemantics() const;
}; };
} // namespace Torch } // namespace Torch
} // namespace NPCOMP } // namespace torch
} // namespace mlir } // namespace mlir
#define GET_TYPEDEF_CLASSES #define GET_TYPEDEF_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchTypes.h.inc" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h.inc"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Inline definitions // Inline definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace mlir { namespace mlir {
namespace NPCOMP { namespace torch {
namespace Torch { namespace Torch {
inline Optional<ArrayRef<int64_t>> BaseTensorType::getOptionalSizes() const { inline Optional<ArrayRef<int64_t>> BaseTensorType::getOptionalSizes() const {
@ -122,7 +122,7 @@ inline bool BaseTensorType::classof(Type type) {
} }
} // namespace Torch } // namespace Torch
} // namespace NPCOMP } // namespace torch
} // namespace mlir } // namespace mlir
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H #endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H

View File

@ -9,7 +9,7 @@
#ifndef TORCH_TYPES #ifndef TORCH_TYPES
#define TORCH_TYPES #define TORCH_TYPES
include "npcomp/Dialect/Torch/IR/TorchBase.td" include "torch-mlir/Dialect/Torch/IR/TorchBase.td"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Type defs // Type defs
@ -83,7 +83,7 @@ class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
} }
class AnyTorchTensorType<string name, string typeMnemonic> class AnyTorchTensorType<string name, string typeMnemonic>
: Torch_Type<name, typeMnemonic, "::mlir::NPCOMP::Torch::BaseTensorType"> { : Torch_Type<name, typeMnemonic, "::mlir::torch::Torch::BaseTensorType"> {
let summary = "Multi-dimensional array modeling Torch's Tensor type"; let summary = "Multi-dimensional array modeling Torch's Tensor type";
let description = [{ let description = [{
Syntax: Syntax:
@ -107,8 +107,8 @@ class AnyTorchTensorType<string name, string typeMnemonic>
a strict separation between the value-semantic and potentially-mutating a strict separation between the value-semantic and potentially-mutating
worlds, as one of our main jobs in the compiler is to isolate the mutating worlds, as one of our main jobs in the compiler is to isolate the mutating
parts as much as possible because most lower levels of the compiler stack parts as much as possible because most lower levels of the compiler stack
are expected to require value semantics. E.g. npcomp's backend contract are expected to require value semantics. E.g. many backend contracts
is mostly in terms of linalg-on-tensor for compute-heavy ops, which require mostly use linalg-on-tensor for compute-heavy ops, which require
a conversion to the builtin `tensor` type which has value semantics. a conversion to the builtin `tensor` type which has value semantics.
Some notes about value semantics: Some notes about value semantics:
- Using the type system described in PEP 483 (which TorchScript and other - Using the type system described in PEP 483 (which TorchScript and other
@ -165,7 +165,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
Note: We avoid the C++ identifier `TensorType` to avoid C++ name ambiguities Note: We avoid the C++ identifier `TensorType` to avoid C++ name ambiguities
with `mlir::TensorType`, since most code is transitively nested in with `mlir::TensorType`, since most code is transitively nested in
both `::mlir` and `::mlir::NPCOMP::Torch` namespaces. both `::mlir` and `::mlir::torch::Torch` namespaces.
Note: We use the Torch-aligned terminology "sizes" and "dtype" instead of Note: We use the Torch-aligned terminology "sizes" and "dtype" instead of
the MLIR-aligned terminology "rank/shape" and "element type". The cheat the MLIR-aligned terminology "rank/shape" and "element type". The cheat
@ -209,7 +209,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
} }
def AnyTorchTensorType : Type< def AnyTorchTensorType : Type<
CPred<"$_self.isa<::mlir::NPCOMP::Torch::BaseTensorType>()">, CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">,
"Any Torch tensor type" "Any Torch tensor type"
>; >;
@ -317,7 +317,7 @@ def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
The whole LinearPackedParamsBase abstraction as it stands in PyTorch is a The whole LinearPackedParamsBase abstraction as it stands in PyTorch is a
very library-call-y, runtime-y thing that embodies a number of assumptions very library-call-y, runtime-y thing that embodies a number of assumptions
about the structure of how the program will be executed, which need not hold about the structure of how the program will be executed, which need not hold
for npcomp backends. for backends.
}]; }];
} }
@ -386,12 +386,12 @@ def TorchOptionalBoolType:
def TorchOptionalDeviceType: def TorchOptionalDeviceType:
OptionalOf<Torch_DeviceType, "Optional torch device type">; OptionalOf<Torch_DeviceType, "Optional torch device type">;
def IsListTypePred : CPred<"$_self.isa<::mlir::NPCOMP::Torch::ListType>()">; def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
class ListOf<list<Type> allowedTypes, string descr> : class ListOf<list<Type> allowedTypes, string descr> :
ContainerType<AnyTypeOf<allowedTypes>, ContainerType<AnyTypeOf<allowedTypes>,
IsListTypePred, IsListTypePred,
"$_self.cast<::mlir::NPCOMP::Torch::ListType>().getContainedType()", "$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()",
descr, "::mlir::NPCOMP::Torch::ListType">; descr, "::mlir::torch::Torch::ListType">;
def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">; def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">; def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">;

View File

@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TorchMLIRTorchPassIncGen)
add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)

View File

@ -6,15 +6,15 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H #ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H #define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory> #include <memory>
namespace mlir { namespace mlir {
namespace NPCOMP { namespace torch {
namespace Torch { namespace Torch {
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass(); std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
@ -58,7 +58,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
/// Registers all Torch transformation passes. /// Registers all Torch transformation passes.
void registerTorchPasses(); void registerTorchPasses();
} // namespace NPCOMP } // namespace torch
} // namespace mlir } // namespace mlir
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H #endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H

View File

@ -6,14 +6,14 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef NPCOMP_TORCH_PASSES #ifndef TORCHMLIR_TORCH_PASSES
#define NPCOMP_TORCH_PASSES #define TORCHMLIR_TORCH_PASSES
include "mlir/Pass/PassBase.td" include "mlir/Pass/PassBase.td"
def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> { def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
let summary = "Converts TorchScript object graphs to a globalized form"; let summary = "Converts TorchScript object graphs to a globalized form";
let constructor = "mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass()"; let constructor = "mlir::torch::Torch::createGlobalizeObjectGraphPass()";
let description = [{ let description = [{
This pass converts a subset of possible TorchScript modules into a This pass converts a subset of possible TorchScript modules into a
more restrictive lower-level form that strips away the need to be more restrictive lower-level form that strips away the need to be
@ -80,7 +80,7 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
- Rationale: This makes the representation of initial values simpler. Also - Rationale: This makes the representation of initial values simpler. Also
as of Feb 2021, TorchScript won't import into this form except as of Feb 2021, TorchScript won't import into this form except
potentially for Tensors (it has a bug related to the identity of potentially for Tensors (it has a bug related to the identity of
objects). And for tensors, the npcomp IValue importer only supports a objects). And for tensors, the IValue importer only supports a
very restricted form of aliasing anyway for other reasons. We are very restricted form of aliasing anyway for other reasons. We are
waiting for signals that more general handling of object aliasing is waiting for signals that more general handling of object aliasing is
important to devote the effort to it. important to devote the effort to it.
@ -90,7 +90,7 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
def PrepareForGlobalizeObjectGraph def PrepareForGlobalizeObjectGraph
: Pass<"torch-prepare-for-globalize-object-graph", "ModuleOp"> { : Pass<"torch-prepare-for-globalize-object-graph", "ModuleOp"> {
let summary = "Lowering in preparation for globalizing"; let summary = "Lowering in preparation for globalizing";
let constructor = "mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass()"; let constructor = "mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass()";
let description = [{ let description = [{
Establishes and the invariants needed by the Establishes and the invariants needed by the
torch-globalize-object-graph transformation. Fails if that cannot be torch-globalize-object-graph transformation. Fails if that cannot be
@ -104,7 +104,7 @@ def PrepareForGlobalizeObjectGraph
def AdjustCallingConventions def AdjustCallingConventions
: Pass<"torch-adjust-calling-conventions", "ModuleOp"> { : Pass<"torch-adjust-calling-conventions", "ModuleOp"> {
let summary = "Adjust the calling conventions of functions"; let summary = "Adjust the calling conventions of functions";
let constructor = "mlir::NPCOMP::Torch::createAdjustCallingConventionsPass()"; let constructor = "mlir::torch::Torch::createAdjustCallingConventionsPass()";
let description = [{ let description = [{
Adjusts the calling conventions of functions in the module, with the aim of Adjusts the calling conventions of functions in the module, with the aim of
preparing them for backends and further lowering passes. As this changes preparing them for backends and further lowering passes. As this changes
@ -127,7 +127,7 @@ def AdjustCallingConventions
def RefineTypes : Pass<"torch-refine-types", "FuncOp"> { def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
let summary = "Refine types"; let summary = "Refine types";
let constructor = "mlir::NPCOMP::Torch::createRefineTypesPass()"; let constructor = "mlir::torch::Torch::createRefineTypesPass()";
let description = [{ let description = [{
Refines types of the program. Currently, this means shapes and dtypes of Refines types of the program. Currently, this means shapes and dtypes of
tensors/arrays. tensors/arrays.
@ -136,7 +136,7 @@ def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> { def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
let summary = "Inlines torch.global_slot ops."; let summary = "Inlines torch.global_slot ops.";
let constructor = "mlir::NPCOMP::Torch::createInlineGlobalSlotsPass()"; let constructor = "mlir::torch::Torch::createInlineGlobalSlotsPass()";
let description = [{ let description = [{
Inlines torch.global_slot ops when it is safe to do so. Inlines torch.global_slot ops when it is safe to do so.
@ -150,7 +150,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> { def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
let summary = "Reduces variants of ops to a smaller set of ops."; let summary = "Reduces variants of ops to a smaller set of ops.";
let constructor = "mlir::NPCOMP::Torch::createReduceOpVariantsPass()"; let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
let description = [{ let description = [{
Replaces ops with other ops to reduce the number of variants that Replaces ops with other ops to reduce the number of variants that
need to be handled elsewhere in the code. need to be handled elsewhere in the code.
@ -181,13 +181,13 @@ def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> {
Also, this pass doesn't currently handle interprocedural rewriting Also, this pass doesn't currently handle interprocedural rewriting
(of private functions), which is even more complex. (of private functions), which is even more complex.
}]; }];
let constructor = "mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass()"; let constructor = "mlir::torch::Torch::createMaximizeValueSemanticsPass()";
} }
def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> { def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
let summary = "Refine public return"; let summary = "Refine public return";
let constructor = "mlir::NPCOMP::Torch::createRefinePublicReturnPass()"; let constructor = "mlir::torch::Torch::createRefinePublicReturnPass()";
let description = [{ let description = [{
Refines types of values returned from public functions based on Refines types of values returned from public functions based on
intraprocedural information. intraprocedural information.
@ -214,4 +214,4 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
}]; }];
} }
#endif // NPCOMP_TORCH_PASSES #endif // TORCHMLIR_TORCH_PASSES

View File

@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_INITALL_H
#define TORCH_MLIR_INITALL_H
#include "mlir/IR/Dialect.h"
namespace mlir {
namespace torch {
void registerAllDialects(mlir::DialectRegistry &registry);
void registerAllPasses();
} // namespace torch
} // namespace mlir
#endif // TORCH_MLIR_INITALL_H

View File

@ -0,0 +1,18 @@
add_mlir_library(TorchMLIRCAPI
Registration.cpp
TorchTypes.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
TorchMLIRTorchDialect
TorchMLIRInitAll
)
torch_mlir_target_includes(TorchMLIRCAPI)

View File

@ -0,0 +1,25 @@
//===- Registration.cpp - C Interface for MLIR Registration ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "torch-mlir-c/Registration.h"
#include "mlir/CAPI/IR.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/InitAll.h"
void torchMlirRegisterAllDialects(MlirContext context) {
mlir::DialectRegistry registry;
mlir::torch::registerAllDialects(registry);
unwrap(context)->appendDialectRegistry(registry);
// TODO: Don't eagerly load once D88162 is in and clients can do this.
unwrap(context)->loadAllAvailableDialects();
}
void torchMlirRegisterAllPasses() { mlir::torch::registerAllPasses(); }

View File

@ -6,26 +6,26 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp-c/TorchTypes.h" #include "torch-mlir-c/TorchTypes.h"
#include "mlir/CAPI/IR.h" #include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h" #include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.nn.Module type. // torch.nn.Module type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNnModule(MlirType t) { bool torchMlirTypeIsATorchNnModule(MlirType t) {
return unwrap(t).isa<Torch::NnModuleType>(); return unwrap(t).isa<Torch::NnModuleType>();
} }
MlirType npcompTorchNnModuleTypeGet(MlirContext context, MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
MlirStringRef className) { MlirStringRef className) {
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className))); return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
} }
@ -33,11 +33,11 @@ MlirType npcompTorchNnModuleTypeGet(MlirContext context,
// torch.optional type. // torch.optional type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchOptional(MlirType t) { bool torchMlirTypeIsATorchOptional(MlirType t) {
return unwrap(t).isa<Torch::OptionalType>(); return unwrap(t).isa<Torch::OptionalType>();
} }
MlirType npcompTorchOptionalTypeGet(MlirType containedType) { MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
return wrap(Torch::OptionalType::get(unwrap(containedType))); return wrap(Torch::OptionalType::get(unwrap(containedType)));
} }
@ -45,13 +45,13 @@ MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
// torch.tuple<T1, T2, T3> type. // torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchTuple(MlirType t) { bool torchMlirTypeIsATorchTuple(MlirType t) {
return unwrap(t).isa<Torch::TupleType>(); return unwrap(t).isa<Torch::TupleType>();
} }
MlirType npcompTorchTupleTypeGet(MlirContext context, MlirType torchMlirTorchTupleTypeGet(MlirContext context,
intptr_t numContainedTypes, intptr_t numContainedTypes,
MlirType const *containedTypes) { MlirType const *containedTypes) {
return wrap(Torch::TupleType::get( return wrap(Torch::TupleType::get(
unwrap(context), unwrap(context),
llvm::to_vector<6>( llvm::to_vector<6>(
@ -63,11 +63,11 @@ MlirType npcompTorchTupleTypeGet(MlirContext context,
// torch.list<T> type. // torch.list<T> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchList(MlirType t) { bool torchMlirTypeIsATorchList(MlirType t) {
return unwrap(t).isa<Torch::ListType>(); return unwrap(t).isa<Torch::ListType>();
} }
MlirType npcompTorchListTypeGet(MlirType containedType) { MlirType torchMlirTorchListTypeGet(MlirType containedType) {
return wrap(Torch::ListType::get(unwrap(containedType))); return wrap(Torch::ListType::get(unwrap(containedType)));
} }
@ -75,11 +75,11 @@ MlirType npcompTorchListTypeGet(MlirType containedType) {
// torch.Device type. // torch.Device type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchDevice(MlirType t) { bool torchMlirTypeIsATorchDevice(MlirType t) {
return unwrap(t).isa<Torch::DeviceType>(); return unwrap(t).isa<Torch::DeviceType>();
} }
MlirType npcompTorchDeviceTypeGet(MlirContext context) { MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
return wrap(Torch::DeviceType::get(unwrap(context))); return wrap(Torch::DeviceType::get(unwrap(context)));
} }
@ -87,11 +87,11 @@ MlirType npcompTorchDeviceTypeGet(MlirContext context) {
// torch.bool type. // torch.bool type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchBool(MlirType t) { bool torchMlirTypeIsATorchBool(MlirType t) {
return unwrap(t).isa<Torch::BoolType>(); return unwrap(t).isa<Torch::BoolType>();
} }
MlirType npcompTorchBoolTypeGet(MlirContext context) { MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
return wrap(Torch::BoolType::get(unwrap(context))); return wrap(Torch::BoolType::get(unwrap(context)));
} }
@ -99,11 +99,11 @@ MlirType npcompTorchBoolTypeGet(MlirContext context) {
// torch.int type. // torch.int type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchInt(MlirType t) { bool torchMlirTypeIsATorchInt(MlirType t) {
return unwrap(t).isa<Torch::IntType>(); return unwrap(t).isa<Torch::IntType>();
} }
MlirType npcompTorchIntTypeGet(MlirContext context) { MlirType torchMlirTorchIntTypeGet(MlirContext context) {
return wrap(Torch::IntType::get(unwrap(context))); return wrap(Torch::IntType::get(unwrap(context)));
} }
@ -111,11 +111,11 @@ MlirType npcompTorchIntTypeGet(MlirContext context) {
// torch.float type. // torch.float type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchFloat(MlirType t) { bool torchMlirTypeIsATorchFloat(MlirType t) {
return unwrap(t).isa<Torch::FloatType>(); return unwrap(t).isa<Torch::FloatType>();
} }
MlirType npcompTorchFloatTypeGet(MlirContext context) { MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
return wrap(Torch::FloatType::get(unwrap(context))); return wrap(Torch::FloatType::get(unwrap(context)));
} }
@ -123,11 +123,11 @@ MlirType npcompTorchFloatTypeGet(MlirContext context) {
// torch.LinearParams type. // torch.LinearParams type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchLinearParams(MlirType t) { bool torchMlirTypeIsATorchLinearParams(MlirType t) {
return unwrap(t).isa<Torch::LinearParamsType>(); return unwrap(t).isa<Torch::LinearParamsType>();
} }
MlirType npcompTorchLinearParamsTypeGet(MlirContext context) { MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
return wrap(Torch::LinearParamsType::get(unwrap(context))); return wrap(Torch::LinearParamsType::get(unwrap(context)));
} }
@ -135,11 +135,11 @@ MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
// torch.qint8 type. // torch.qint8 type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchQInt8(MlirType t) { bool torchMlirTypeIsATorchQInt8(MlirType t) {
return unwrap(t).isa<Torch::QInt8Type>(); return unwrap(t).isa<Torch::QInt8Type>();
} }
MlirType npcompTorchQInt8TypeGet(MlirContext context) { MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
return wrap(Torch::QInt8Type::get(unwrap(context))); return wrap(Torch::QInt8Type::get(unwrap(context)));
} }
@ -147,14 +147,14 @@ MlirType npcompTorchQInt8TypeGet(MlirContext context) {
// torch.tensor type. // torch.tensor type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNonValueTensor(MlirType t) { bool torchMlirTypeIsATorchNonValueTensor(MlirType t) {
return unwrap(t).isa<Torch::NonValueTensorType>(); return unwrap(t).isa<Torch::NonValueTensorType>();
} }
MlirType npcompTorchNonValueTensorTypeGet(MlirContext context, MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
intptr_t numSizes, intptr_t numSizes,
const int64_t *optionalSizes, const int64_t *optionalSizes,
MlirType optionalDtype) { MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None; Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes) if (optionalSizes)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
@ -162,13 +162,13 @@ MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
} }
MlirType npcompTorchNonValueTensorTypeGetWithLeastStaticInformation( MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
MlirContext context) { MlirContext context) {
return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation( return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
unwrap(context))); unwrap(context)));
} }
MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) { MlirType torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type) {
return wrap(Torch::NonValueTensorType::getFromShaped( return wrap(Torch::NonValueTensorType::getFromShaped(
unwrap(type).cast<ShapedType>())); unwrap(type).cast<ShapedType>()));
} }
@ -177,13 +177,14 @@ MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
// torch.vtensor type. // torch.vtensor type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchValueTensor(MlirType t) { bool torchMlirTypeIsATorchValueTensor(MlirType t) {
return unwrap(t).isa<Torch::ValueTensorType>(); return unwrap(t).isa<Torch::ValueTensorType>();
} }
MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes, MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
const int64_t *optionalSizes, intptr_t numSizes,
MlirType optionalDtype) { const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None; Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes) if (optionalSizes)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
@ -191,13 +192,13 @@ MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
} }
MlirType MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context) { MlirContext context) {
return wrap( return wrap(
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context))); Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
} }
MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) { MlirType torchMlirTorchValueTensorTypeGetFromShaped(MlirType type) {
return wrap( return wrap(
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>())); Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
} }
@ -206,11 +207,11 @@ MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
// torch.none type. // torch.none type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNone(MlirType t) { bool torchMlirTypeIsATorchNone(MlirType t) {
return unwrap(t).isa<Torch::NoneType>(); return unwrap(t).isa<Torch::NoneType>();
} }
MlirType npcompTorchNoneTypeGet(MlirContext context) { MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
return wrap(Torch::NoneType::get(unwrap(context))); return wrap(Torch::NoneType::get(unwrap(context)));
} }
@ -218,11 +219,11 @@ MlirType npcompTorchNoneTypeGet(MlirContext context) {
// torch.str type. // torch.str type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchString(MlirType t) { bool torchMlirTypeIsATorchString(MlirType t) {
return unwrap(t).isa<Torch::StringType>(); return unwrap(t).isa<Torch::StringType>();
} }
MlirType npcompTorchStringTypeGet(MlirContext context) { MlirType torchMlirTorchStringTypeGet(MlirContext context) {
return wrap(Torch::StringType::get(unwrap(context))); return wrap(Torch::StringType::get(unwrap(context)));
} }
@ -230,11 +231,11 @@ MlirType npcompTorchStringTypeGet(MlirContext context) {
// torch.any type. // torch.any type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchAny(MlirType t) { bool torchMlirTypeIsATorchAny(MlirType t) {
return unwrap(t).isa<Torch::AnyType>(); return unwrap(t).isa<Torch::AnyType>();
} }
MlirType npcompTorchAnyTypeGet(MlirContext context) { MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
return wrap(Torch::AnyType::get(unwrap(context))); return wrap(Torch::AnyType::get(unwrap(context)));
} }
@ -242,11 +243,11 @@ MlirType npcompTorchAnyTypeGet(MlirContext context) {
// torch.number type. // torch.number type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNumber(MlirType t) { bool torchMlirTypeIsATorchNumber(MlirType t) {
return unwrap(t).isa<Torch::NumberType>(); return unwrap(t).isa<Torch::NumberType>();
} }
MlirType npcompTorchNumberTypeGet(MlirContext context) { MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
return wrap(Torch::NumberType::get(unwrap(context))); return wrap(Torch::NumberType::get(unwrap(context)));
} }
@ -254,10 +255,10 @@ MlirType npcompTorchNumberTypeGet(MlirContext context) {
// torch.Dict type. // torch.Dict type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
bool npcompTypeIsATorchDict(MlirType t) { bool torchMlirTypeIsATorchDict(MlirType t) {
return unwrap(t).isa<Torch::DictType>(); return unwrap(t).isa<Torch::DictType>();
} }
MlirType npcompTorchDictTypeGet(MlirType keyType, MlirType valueType) { MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType))); return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
} }

View File

@ -0,0 +1,17 @@
add_subdirectory(CAPI)
add_subdirectory(Dialect)
add_mlir_library(TorchMLIRInitAll
InitAll.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
TorchMLIRTorchDialect
TorchMLIRTorchPasses
)
torch_mlir_target_includes(TorchMLIRInitAll)

View File

@ -0,0 +1 @@
add_subdirectory(Torch)

View File

@ -1,10 +1,10 @@
add_npcomp_dialect_library(NPCOMPTorchDialect add_mlir_library(TorchMLIRTorchDialect
TorchDialect.cpp TorchDialect.cpp
TorchOps.cpp TorchOps.cpp
TorchTypes.cpp TorchTypes.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch
DEPENDS DEPENDS
MLIRTorchOpsIncGen MLIRTorchOpsIncGen
@ -17,6 +17,8 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
MLIRIR MLIRIR
MLIRSupport MLIRSupport
MLIRControlFlowInterfaces MLIRControlFlowInterfaces
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces MLIRSideEffectInterfaces
NPCOMPInterfaces
) )
torch_mlir_target_includes(TorchMLIRTorchDialect)

View File

@ -6,20 +6,20 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/TypeSwitch.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
#include "npcomp/Dialect/Torch/IR/TorchDialect.cpp.inc" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.cpp.inc"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Dialect Interfaces // Dialect Interfaces
@ -44,7 +44,7 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES #define GET_TYPEDEF_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Dialect initialize method. // Dialect initialize method.
@ -53,11 +53,11 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
void TorchDialect::initialize() { void TorchDialect::initialize() {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc" #include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
>(); >();
addTypes< addTypes<
#define GET_TYPEDEF_LIST #define GET_TYPEDEF_LIST
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
>(); >();
addInterfaces<TorchInlinerInterface>(); addInterfaces<TorchInlinerInterface>();
} }
@ -84,7 +84,6 @@ void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
llvm_unreachable("unknown 'torch' type"); llvm_unreachable("unknown 'torch' type");
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Dialect-level verifiers. // Dialect-level verifiers.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -6,7 +6,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@ -15,8 +15,8 @@
#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringMap.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28 // see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28
static int64_t getDtypeIntegerFromMlirType(Type dtype) { static int64_t getDtypeIntegerFromMlirType(Type dtype) {
@ -36,9 +36,9 @@ static int64_t getDtypeIntegerFromMlirType(Type dtype) {
// Utilities // Utilities
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Value mlir::NPCOMP::Torch::copyTensorToType(OpBuilder &builder, Location loc, Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
BaseTensorType newType, BaseTensorType newType,
Value tensor) { Value tensor) {
auto originalType = tensor.getType().cast<BaseTensorType>(); auto originalType = tensor.getType().cast<BaseTensorType>();
// Adjust the static information in the type to match between the original and // Adjust the static information in the type to match between the original and
// new types. // new types.
@ -393,7 +393,10 @@ void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// if not removed eagerly by canonicalizer would prevent ReduceOpVariants // if not removed eagerly by canonicalizer would prevent ReduceOpVariants
// from converting certain tensors value semantics. // from converting certain tensors value semantics.
bool allAllowRefinement = bool allAllowRefinement =
llvm::all_of(op.getResult().getUsers(), allowsTypeRefinement); llvm::all_of(op.getResult().getUsers(), [](Operation *op) {
return op
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
});
if (!allAllowRefinement) if (!allAllowRefinement)
return failure(); return failure();
rewriter.replaceOp(op, op.getOperand()); rewriter.replaceOp(op, op.getOperand());
@ -1004,4 +1007,4 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
return nullptr; return nullptr;
} }
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc" #include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"

View File

@ -6,15 +6,15 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TupleType // TupleType

View File

@ -14,20 +14,20 @@
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
// Map from func name and arg index to the type bound for that arg. // Map from func name and arg index to the type bound for that arg.
// This is needed because to rewrite calls, we need the non-local information // This is needed because to rewrite calls, we need the non-local information
// from the func definition. // from the func definition.
// We also benefit from populating this all at once, which avoids ordering // We also benefit from populating this all at once, which avoids ordering
// issues between rewriting of func ops vs call ops. // issues between rewriting of func ops vs call ops.
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type> ; using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>;
namespace { namespace {
class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> { class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
@ -136,8 +136,8 @@ public:
return success(); return success();
} }
private: private:
TypeBoundMap &typeBoundMap; TypeBoundMap &typeBoundMap;
}; };
} // namespace } // namespace
@ -251,6 +251,6 @@ class AdjustCallingConventionsPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createAdjustCallingConventionsPass() { mlir::torch::Torch::createAdjustCallingConventionsPass() {
return std::make_unique<AdjustCallingConventionsPass>(); return std::make_unique<AdjustCallingConventionsPass>();
} }

View File

@ -1,4 +1,4 @@
add_npcomp_conversion_library(NPCOMPTorchPasses add_mlir_library(TorchMLIRTorchPasses
AdjustCallingConventions.cpp AdjustCallingConventions.cpp
Passes.cpp Passes.cpp
GlobalizeObjectGraph.cpp GlobalizeObjectGraph.cpp
@ -10,10 +10,10 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
RefineTypes.cpp RefineTypes.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms
DEPENDS DEPENDS
NPCOMPTorchPassIncGen TorchMLIRTorchPassIncGen
LINK_COMPONENTS LINK_COMPONENTS
Core Core
@ -22,6 +22,7 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
MLIRIR MLIRIR
MLIRPass MLIRPass
MLIRTransforms MLIRTransforms
NPCOMPTorchDialect TorchMLIRTorchDialect
NPCOMPInterfaces
) )
torch_mlir_target_includes(TorchMLIRTorchPasses)

View File

@ -12,9 +12,9 @@
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
@ -22,8 +22,8 @@
#include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringSet.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) { static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) {
NnModuleOp rootNnModule; NnModuleOp rootNnModule;
@ -664,8 +664,6 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
module.push_back(newFunc); module.push_back(newFunc);
} }
for (auto &kv : newFuncs) { for (auto &kv : newFuncs) {
BlockAndValueMapping mapping; BlockAndValueMapping mapping;
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping))) if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
@ -706,6 +704,6 @@ class GlobalizeObjectGraphPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() { mlir::torch::Torch::createGlobalizeObjectGraphPass() {
return std::make_unique<GlobalizeObjectGraphPass>(); return std::make_unique<GlobalizeObjectGraphPass>();
} }

View File

@ -11,16 +11,16 @@
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
namespace { namespace {
class InlineGlobalSlotsPass class InlineGlobalSlotsPass
@ -87,6 +87,6 @@ class InlineGlobalSlotsPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createInlineGlobalSlotsPass() { mlir::torch::Torch::createInlineGlobalSlotsPass() {
return std::make_unique<InlineGlobalSlotsPass>(); return std::make_unique<InlineGlobalSlotsPass>();
} }

View File

@ -12,12 +12,12 @@
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
namespace { namespace {
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
@ -131,6 +131,6 @@ class MaximizeValueSemanticsPass
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass() { mlir::torch::Torch::createMaximizeValueSemanticsPass() {
return std::make_unique<MaximizeValueSemanticsPass>(); return std::make_unique<MaximizeValueSemanticsPass>();
} }

View File

@ -6,20 +6,20 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H #ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H #define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace NPCOMP { namespace torch {
namespace Torch { namespace Torch {
#define GEN_PASS_CLASSES #define GEN_PASS_CLASSES
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // namespace Torch } // namespace Torch
} // namespace NPCOMP } // namespace torch
} // end namespace mlir } // end namespace mlir
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H #endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H

View File

@ -6,7 +6,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
@ -16,22 +16,22 @@
namespace { namespace {
#define GEN_PASS_REGISTRATION #define GEN_PASS_REGISTRATION
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // end namespace } // end namespace
void mlir::NPCOMP::registerTorchPasses() { void mlir::torch::registerTorchPasses() {
::registerPasses(); ::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-to-torch-backend-pipeline", "torchscript-to-torch-backend-pipeline",
"Pipeline lowering TorchScript object graph IR to Torch backend form.", "Pipeline lowering TorchScript object graph IR to Torch backend form.",
mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline); mlir::torch::Torch::createTorchScriptToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-globalized-module-to-torch-backend-pipeline", "torch-globalized-module-to-torch-backend-pipeline",
"Pipeline lowering a globalized Torch program to Torch backend form.", "Pipeline lowering a globalized Torch program to Torch backend form.",
mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline); mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline);
} }
void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline( void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) { OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// When we import TorchScript IR, we import their entire "compilation unit", // When we import TorchScript IR, we import their entire "compilation unit",
// which can contain numerous functions unrelated to the current program, // which can contain numerous functions unrelated to the current program,
@ -62,7 +62,7 @@ void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline(
createGlobalizedModuleToTorchBackendPipeline(pm, options); createGlobalizedModuleToTorchBackendPipeline(pm, options);
} }
void mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline( void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) { OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// General considerations: As a matter of bring-up, we are simultaneously // General considerations: As a matter of bring-up, we are simultaneously
// building out the frontend pipeline and also co-developing the backend // building out the frontend pipeline and also co-developing the backend

View File

@ -14,13 +14,13 @@
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
namespace { namespace {
class ConvertPrimCallMethodToCall : public OpRewritePattern<PrimCallMethodOp> { class ConvertPrimCallMethodToCall : public OpRewritePattern<PrimCallMethodOp> {
@ -93,16 +93,15 @@ class PrepareForGlobalizeObjectGraphPass
// to the form we want. // to the form we want.
ConversionTarget target(*context); ConversionTarget target(*context);
target.addIllegalOp<PrimCallMethodOp>(); target.addIllegalOp<PrimCallMethodOp>();
target.addDynamicallyLegalOp<ConstantOp>([](ConstantOp op) { target.addDynamicallyLegalOp<ConstantOp>(
return !op.getType().isa<FunctionType>(); [](ConstantOp op) { return !op.getType().isa<FunctionType>(); });
});
target.addIllegalOp<CallIndirectOp>(); target.addIllegalOp<CallIndirectOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
RewritePatternSet dummyPatterns(context); RewritePatternSet dummyPatterns(context);
if (failed(applyFullConversion(getOperation(), target, if (failed(applyFullConversion(getOperation(), target,
std::move(dummyPatterns)))) { std::move(dummyPatterns)))) {
return signalPassFailure(); return signalPassFailure();
} }
} }
@ -110,6 +109,6 @@ class PrepareForGlobalizeObjectGraphPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass() { mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass() {
return std::make_unique<PrepareForGlobalizeObjectGraphPass>(); return std::make_unique<PrepareForGlobalizeObjectGraphPass>();
} }

View File

@ -9,13 +9,13 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
namespace { namespace {
// Convert value semantic ops operating on mutable arrays to instead operate on // Convert value semantic ops operating on mutable arrays to instead operate on
@ -145,6 +145,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::Torch::createReduceOpVariantsPass() { mlir::torch::Torch::createReduceOpVariantsPass() {
return std::make_unique<ReduceOpVariantsPass>(); return std::make_unique<ReduceOpVariantsPass>();
} }

View File

@ -11,12 +11,12 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
namespace { namespace {
@ -86,6 +86,6 @@ class RefinePublicReturnPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createRefinePublicReturnPass() { mlir::torch::Torch::createRefinePublicReturnPass() {
return std::make_unique<RefinePublicReturnPass>(); return std::make_unique<RefinePublicReturnPass>();
} }

View File

@ -15,13 +15,13 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::torch;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Analysis. // Analysis.
@ -264,9 +264,8 @@ public:
} else if (auto sum = dyn_cast<AtenSumOp>(op)) { } else if (auto sum = dyn_cast<AtenSumOp>(op)) {
return visitReductionAlongAllDimsOp(sum, operands); return visitReductionAlongAllDimsOp(sum, operands);
} else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) { } else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
return visitReductionAlongDimIntListOp( return visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
sumDimIntList, sumDimIntList.dim(), sumDimIntList.keepdim(), sumDimIntList.keepdim(), operands);
operands);
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) { } else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(), return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
meanDim.keepdim(), operands); meanDim.keepdim(), operands);
@ -1114,7 +1113,7 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
// the right thing forthose ops. // the right thing forthose ops.
// //
static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) { static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) {
return allowsTypeRefinement(op) || return op->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>() ||
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op); isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
} }
@ -1244,6 +1243,6 @@ class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::Torch::createRefineTypesPass() { mlir::torch::Torch::createRefineTypesPass() {
return std::make_unique<RefineTypesPass>(); return std::make_unique<RefineTypesPass>();
} }

View File

@ -0,0 +1,19 @@
//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/InitAll.h"
#include "mlir/IR/Dialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::torch::Torch::TorchDialect>();
}
void mlir::torch::registerAllPasses() { mlir::torch::registerTorchPasses(); }

View File

@ -0,0 +1,30 @@
llvm_canonicalize_cmake_booleans(
MLIR_ENABLE_BINDINGS_PYTHON
)
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
MAIN_CONFIG
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
)
set(TORCH_MLIR_TEST_DEPENDS
FileCheck count not
torch-mlir-opt
)
# XXX: Enable
#if(MLIR_ENABLE_BINDINGS_PYTHON)
# list(APPEND TORCH_MLIR_TEST_DEPENDS
# TorchMLIRPythonModules
# )
#endif()
add_lit_testsuite(check-torch-mlir "Running the torch-mlir regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${TORCH_MLIR_TEST_DEPENDS}
)
set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests")
add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// Basic case. // Basic case.

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s // RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
torch.class_type @c1 {} torch.class_type @c1 {}
torch.class_type @c2 {} torch.class_type @c2 {}

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c { torch.class_type @c {
torch.attr "float" : !torch.float torch.attr "float" : !torch.float

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// CHECK that multiple nested initialization ops are properly handled. // CHECK that multiple nested initialization ops are properly handled.

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c { torch.class_type @c {
torch.attr "float" : !torch.float torch.attr "float" : !torch.float

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file -verify-diagnostics %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file -verify-diagnostics %s
torch.class_type @parent { torch.class_type @parent {
torch.method "module_type_return", @module_type_return torch.method "module_type_return", @module_type_return

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @child { torch.class_type @child {
torch.attr "float" : !torch.float torch.attr "float" : !torch.float

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s // RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) { func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
return return

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// Tests monomorphization of same function with different instance argument types. // Tests monomorphization of same function with different instance argument types.

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @__torch__.TestModule { torch.class_type @__torch__.TestModule {
torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule"> torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule">

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// Check that linkage names consist of the dotted path from the root. // Check that linkage names consist of the dotted path from the root.

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c { torch.class_type @c {
// CHECK: torch.global_slot "private" @float : !torch.float // CHECK: torch.global_slot "private" @float : !torch.float

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @basic( // CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt %s -canonicalize | FileCheck %s // RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
// CHECK-LABEL: func @torch.aten.__is__ // CHECK-LABEL: func @torch.aten.__is__
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
// CHECK-NOT: @readonly // CHECK-NOT: @readonly
torch.global_slot "private" @readonly : !torch.tensor { torch.global_slot "private" @readonly : !torch.tensor {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics // RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics
// ----- // -----

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s // RUN: torch-mlir-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
// CHECK-LABEL: func @torch.copy.tensor$basic( // CHECK-LABEL: func @torch.copy.tensor$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s // RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s
// CHECK-LABEL: func @torch.operator( // CHECK-LABEL: func @torch.operator(
func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor { func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c { torch.class_type @c {
torch.method "test_call_method", @test_call_method torch.method "test_call_method", @test_call_method

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-reduce-op-variants %s | FileCheck %s // RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s
// CHECK-LABEL: func @convert_to_value_semantic_tensors( // CHECK-LABEL: func @convert_to_value_semantic_tensors(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s // RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
// CHECK-LABEL: func @basic( // CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
// ----- // -----

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @f( // CHECK-LABEL: func @f(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {

View File

@ -0,0 +1,71 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import os
import platform
import re
import subprocess
import tempfile
import lit.formats
import lit.util
from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
from lit.llvm.subst import FindTool
# Configuration file for the 'lit' test runner.
# name: The name of this test suite.
config.name = 'TORCH_MLIR'
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir', '.py']
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
#llvm_config.use_default_substitutions()
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = [
'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
'lit.cfg.py', 'lit.site.cfg.py'
]
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')
config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
tool_dirs = [config.llvm_tools_dir]
tools = [
ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
]
llvm_config.add_tool_substitutions(tools, tool_dirs)
if config.enable_bindings_python:
llvm_config.with_environment('PYTHONPATH', [
os.path.join(config.torch_mlir_obj_root, 'python_packages',
'torch_mlir'),
],
append_path=True)

View File

@ -0,0 +1,21 @@
@LIT_SITE_CFG_IN_HEADER@
import sys
config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@
config.torch_mlir_obj_root = "@TORCH_MLIR_BINARY_DIR@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.python_executable = sys.executable
import lit.llvm
lit.llvm.initialize(lit_config, config)
# Let the main config do the real work.
lit_config.load_config(config, "@TORCH_MLIR_SOURCE_DIR@/test/lit.cfg.py")

View File

@ -0,0 +1,2 @@
if not config.enable_bindings_python:
config.unsupported = True

View File

@ -0,0 +1,9 @@
# RUN: %PYTHON %s
# XXX: Fix this
# XFAIL: *
import mlir.ir
from mlir.dialects import iree
with mlir.ir.Context() as ctx:
iree.register_iree_dialect(ctx)

View File

@ -0,0 +1 @@
add_subdirectory(torch-mlir-opt)

View File

@ -0,0 +1,13 @@
add_executable(torch-mlir-opt torch-mlir-opt.cpp)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
target_link_libraries(torch-mlir-opt PRIVATE
MLIROptLib
TorchMLIRInitAll
TorchMLIRTorchDialect
TorchMLIRTorchPasses
${dialect_libs}
${conversion_libs}
)

View File

@ -0,0 +1,27 @@
//===- torch-mlir-opt.cpp - MLIR Optimizer Driver -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
#include "torch-mlir/InitAll.h"
using namespace mlir;
int main(int argc, char **argv) {
registerAllPasses();
mlir::torch::registerAllPasses();
DialectRegistry registry;
registerAllDialects(registry);
mlir::torch::registerAllDialects(registry);
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
/*preloadDialectsInContext=*/false));
}

View File

@ -11,7 +11,7 @@
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h" #include "mlir-c/BuiltinTypes.h"
#include "npcomp-c/TorchTypes.h" #include "torch-mlir-c/TorchTypes.h"
#include <ATen/core/function_schema.h> #include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h> #include <ATen/core/ivalue.h>
@ -510,14 +510,14 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
return mlirIntegerTypeGet(funcBuilder->getContext(), 1); return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
} }
if (ival.isList()) { if (ival.isList()) {
return npcompTorchListTypeGet( return torchMlirTorchListTypeGet(
typeMapper.mapFromTorchType(loc, ival.toList().elementType())); typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
} }
if (ival.isNone()) { if (ival.isNone()) {
return npcompTorchNoneTypeGet(funcBuilder->getContext()); return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
} }
if (ival.isDevice()) { if (ival.isDevice()) {
return npcompTorchNoneTypeGet(funcBuilder->getContext()); return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
} }
return {nullptr}; return {nullptr};
} }
@ -527,7 +527,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc); MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp = createMlirOperationAtEnd( MlirOperation tensorOp = createMlirOperationAtEnd(
funcBuilder->getEntryBlock(), "torch.tensor.literal", loc, funcBuilder->getEntryBlock(), "torch.tensor.literal", loc,
npcompTorchNonValueTensorTypeGetFromShaped( torchMlirTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)), mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements)); toMlirNamedAttribute("value", denseElements));
MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0); MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);

View File

@ -12,7 +12,7 @@
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h" #include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h" #include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h" #include "torch-mlir-c/TorchTypes.h"
using namespace torch_mlir; using namespace torch_mlir;
@ -98,7 +98,7 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
// represented as one of double or int64_t, with a special tag for whether // represented as one of double or int64_t, with a special tag for whether
// it should be interpreted as a bool. // it should be interpreted as a bool.
if (s.isIntegral(/*includeBool=*/false)) { if (s.isIntegral(/*includeBool=*/false)) {
MlirType t = npcompTorchIntTypeGet(context); MlirType t = torchMlirTorchIntTypeGet(context);
MlirAttribute value = MlirAttribute value =
mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), s.to<int64_t>()); mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), s.to<int64_t>());
MlirOperation op = createMlirOperation( MlirOperation op = createMlirOperation(
@ -107,7 +107,7 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
return mlirOperationGetResult(op, 0); return mlirOperationGetResult(op, 0);
} }
if (s.isFloatingPoint()) { if (s.isFloatingPoint()) {
MlirType t = npcompTorchFloatTypeGet(context); MlirType t = torchMlirTorchFloatTypeGet(context);
MlirAttribute value = mlirFloatAttrDoubleGet( MlirAttribute value = mlirFloatAttrDoubleGet(
context, mlirF64TypeGet(context), s.to<double>()); context, mlirF64TypeGet(context), s.to<double>());
MlirOperation op = createMlirOperation( MlirOperation op = createMlirOperation(
@ -133,7 +133,7 @@ MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) {
MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType, MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
std::vector<MlirValue> &elements) { std::vector<MlirValue> &elements) {
MlirType resultType = npcompTorchListTypeGet(elementType); MlirType resultType = torchMlirTorchListTypeGet(elementType);
OperationStateHolder state{"torch.prim.ListConstruct", loc}; OperationStateHolder state{"torch.prim.ListConstruct", loc};
mlirOperationStateAddResults(state, 1, &resultType); mlirOperationStateAddResults(state, 1, &resultType);
mlirOperationStateAddOperands(state, elements.size(), elements.data()); mlirOperationStateAddOperands(state, elements.size(), elements.data());

View File

@ -16,7 +16,7 @@
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h" #include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h" #include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h" #include "torch-mlir-c/TorchTypes.h"
#include "caffe2/core/scope_guard.h" #include "caffe2/core/scope_guard.h"
#include "ATen/native/quantized/cpu/packed_params.h" #include "ATen/native/quantized/cpu/packed_params.h"
@ -170,7 +170,7 @@ IValueImporter::importModule(torch::jit::Module currentModule) {
MlirOperation nnModule = createMlirOperation( MlirOperation nnModule = createMlirOperation(
"torch.nn_module", loc, "torch.nn_module", loc,
npcompTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)), torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
mlirRegionCreate()); mlirRegionCreate());
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0); MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr)); mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
@ -240,7 +240,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirLocation loc = mlirLocationUnknownGet(context); MlirLocation loc = mlirLocationUnknownGet(context);
if (ivalue.isBool()) { if (ivalue.isBool()) {
MlirType type = npcompTorchBoolTypeGet(context); MlirType type = torchMlirTorchBoolTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.bool", loc, type, importBlock, "torch.constant.bool", loc, type,
toMlirNamedAttribute("value", toMlirNamedAttribute("value",
@ -248,7 +248,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isDouble()) { if (ivalue.isDouble()) {
MlirType type = npcompTorchFloatTypeGet(context); MlirType type = torchMlirTorchFloatTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.float", loc, type, importBlock, "torch.constant.float", loc, type,
toMlirNamedAttribute( toMlirNamedAttribute(
@ -257,7 +257,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isInt()) { if (ivalue.isInt()) {
MlirType type = npcompTorchIntTypeGet(context); MlirType type = torchMlirTorchIntTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.int", loc, type, importBlock, "torch.constant.int", loc, type,
toMlirNamedAttribute("value", toMlirNamedAttribute("value",
@ -273,7 +273,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
} }
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.ListConstruct", loc, importBlock, "torch.prim.ListConstruct", loc,
npcompTorchListTypeGet( torchMlirTorchListTypeGet(
typeMapper.mapFromTorchType(loc, list.elementType())), typeMapper.mapFromTorchType(loc, list.elementType())),
elems); elems);
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
@ -288,7 +288,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
} }
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.DictConstruct", loc, importBlock, "torch.prim.DictConstruct", loc,
npcompTorchDictTypeGet( torchMlirTorchDictTypeGet(
typeMapper.mapFromTorchType(loc, dict.keyType()), typeMapper.mapFromTorchType(loc, dict.keyType()),
typeMapper.mapFromTorchType(loc, dict.valueType())), typeMapper.mapFromTorchType(loc, dict.valueType())),
keys, values); keys, values);
@ -305,7 +305,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
} }
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.TupleConstruct", loc, importBlock, "torch.prim.TupleConstruct", loc,
npcompTorchTupleTypeGet(context, types.size(), types.data()), operands); torchMlirTorchTupleTypeGet(context, types.size(), types.data()),
operands);
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isTensor()) { if (ivalue.isTensor()) {
@ -317,7 +318,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
if (ivalue.isString()) { if (ivalue.isString()) {
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.str", loc, importBlock, "torch.constant.str", loc,
npcompTorchStringTypeGet(context), torchMlirTorchStringTypeGet(context),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", "value",
mlirStringAttrGet(context, mlirStringAttrGet(context,
@ -327,7 +328,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
if (ivalue.isNone()) { if (ivalue.isNone()) {
MlirOperation operation = MlirOperation operation =
createMlirOperationAtEnd(importBlock, "torch.constant.none", loc, createMlirOperationAtEnd(importBlock, "torch.constant.none", loc,
npcompTorchNoneTypeGet(context)); torchMlirTorchNoneTypeGet(context));
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
if (ivalue.isCustomClass()) { if (ivalue.isCustomClass()) {
@ -346,7 +347,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
} }
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.linear_params.create", loc, importBlock, "torch.linear_params.create", loc,
npcompTorchLinearParamsTypeGet(context), weightValue, biasValue); torchMlirTorchLinearParamsTypeGet(context), weightValue, biasValue);
return mlirOperationGetResult(operation, 0); return mlirOperationGetResult(operation, 0);
} }
} }
@ -366,7 +367,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc); MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp = MlirOperation tensorOp =
createMlirOperationAtEnd(importBlock, "torch.tensor.literal", loc, createMlirOperationAtEnd(importBlock, "torch.tensor.literal", loc,
npcompTorchNonValueTensorTypeGetFromShaped( torchMlirTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)), mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements)); toMlirNamedAttribute("value", denseElements));
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0); MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
@ -381,7 +382,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
// compiler stages that are building a statically modeled quantization // compiler stages that are building a statically modeled quantization
// representation will need to convert this to their representation. // representation will need to convert this to their representation.
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end()); std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType quantizedTensorType = npcompTorchNonValueTensorTypeGet( MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shape.data(), context, shape.size(), shape.data(),
typeMapper.mapFromTorchScalarType(tensor.scalar_type())); typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
if (tensor.qscheme() == c10::kPerTensorAffine) { if (tensor.qscheme() == c10::kPerTensorAffine) {
@ -531,11 +532,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
int64_t dummy; int64_t dummy;
int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data(); int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
if (hasValueSemantics) { if (hasValueSemantics) {
typeBound = npcompTorchValueTensorTypeGet(context, shape.size(), typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(),
shapeData, dtype);
} else {
typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
shapeData, dtype); shapeData, dtype);
} else {
typeBound = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shapeData, dtype);
} }
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute( MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(

View File

@ -16,6 +16,7 @@
#include "mlir-c/Diagnostics.h" #include "mlir-c/Diagnostics.h"
#include "mlir-c/Registration.h" #include "mlir-c/Registration.h"
#include "npcomp-c/Registration.h" #include "npcomp-c/Registration.h"
#include "torch-mlir-c/Registration.h"
namespace py = pybind11; namespace py = pybind11;
using namespace torch_mlir; using namespace torch_mlir;
@ -114,7 +115,7 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
// TODO: Rework this once dialect registration C-APIs are in place. // TODO: Rework this once dialect registration C-APIs are in place.
// https://reviews.llvm.org/D88162 // https://reviews.llvm.org/D88162
mlirRegisterAllDialects(context); mlirRegisterAllDialects(context);
npcompRegisterAllDialects(context); torchMlirRegisterAllDialects(context);
registerPythonSysStderrDiagnosticHandler(context); registerPythonSysStderrDiagnosticHandler(context);

View File

@ -15,7 +15,7 @@
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h" #include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h" #include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h" #include "torch-mlir-c/TorchTypes.h"
namespace py = pybind11; namespace py = pybind11;
using namespace torch_mlir; using namespace torch_mlir;
@ -150,7 +150,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
importAttribute(loc, node, c10::attr::value))); importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::StringType>()) { } else if (output->type()->cast<c10::StringType>()) {
op = createMlirOperation( op = createMlirOperation(
"torch.constant.str", loc, npcompTorchStringTypeGet(context), "torch.constant.str", loc, torchMlirTorchStringTypeGet(context),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", mlirStringAttrGet(context, toMlirStringRef(node->s( "value", mlirStringAttrGet(context, toMlirStringRef(node->s(
c10::attr::value))))); c10::attr::value)))));
@ -186,7 +186,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
mlirRegionCreate()); mlirRegionCreate());
mapResults(node, operation); mapResults(node, operation);
std::vector<MlirType> terminatorOperandTypes = { std::vector<MlirType> terminatorOperandTypes = {
npcompTorchBoolTypeGet(context)}; torchMlirTorchBoolTypeGet(context)};
terminatorOperandTypes.insert(terminatorOperandTypes.end(), terminatorOperandTypes.insert(terminatorOperandTypes.end(),
resultTypes.begin(), resultTypes.end()); resultTypes.begin(), resultTypes.end());
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues, auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,

View File

@ -10,7 +10,7 @@
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h" #include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h" #include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h" #include "torch-mlir-c/TorchTypes.h"
using namespace torch_mlir; using namespace torch_mlir;
@ -18,11 +18,11 @@ OpBuilder::OpBuilder(MlirContext context) : context(context) {}
MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) { MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
return createMlirOperation("torch.constant.none", loc, return createMlirOperation("torch.constant.none", loc,
npcompTorchNoneTypeGet(context)); torchMlirTorchNoneTypeGet(context));
} }
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) { MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
return createMlirOperation( return createMlirOperation(
"torch.constant.bool", loc, npcompTorchBoolTypeGet(context), "torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value))); toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
} }

View File

@ -15,7 +15,7 @@
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h" #include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h" #include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h" #include "torch-mlir-c/TorchTypes.h"
using namespace torch_mlir; using namespace torch_mlir;
@ -66,7 +66,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
case ScalarType::Half: case ScalarType::Half:
return mlirF16TypeGet(context); return mlirF16TypeGet(context);
case ScalarType::QInt8: case ScalarType::QInt8:
return npcompTorchQInt8TypeGet(context); return torchMlirTorchQInt8TypeGet(context);
default: { default: {
return {nullptr}; return {nullptr};
} }
@ -105,7 +105,7 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
// Individually handle the custom classes that we know about. // Individually handle the custom classes that we know about.
if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") { if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
return npcompTorchLinearParamsTypeGet(context); return torchMlirTorchLinearParamsTypeGet(context);
} }
// At this point, we know that the type is indeed a custom class type, but // At this point, we know that the type is indeed a custom class type, but
@ -136,11 +136,11 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
auto &sizes = tensorType->symbolic_sizes(); auto &sizes = tensorType->symbolic_sizes();
if (!sizes.rank()) { if (!sizes.rank()) {
// Unranked. // Unranked.
return npcompTorchNonValueTensorTypeGet(context, return torchMlirTorchNonValueTensorTypeGet(context,
/*numSizes=*/0, /*numSizes=*/0,
/*optionalSizes=*/nullptr, /*optionalSizes=*/nullptr,
/*optionalDtype=*/ /*optionalDtype=*/
elementType); elementType);
} }
// Ranked with possibly dynamic dims. // Ranked with possibly dynamic dims.
auto &symbolicShape = tensorType->symbolic_sizes(); auto &symbolicShape = tensorType->symbolic_sizes();
@ -150,28 +150,28 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
auto shapeSymbol = symbolicShape[i]; auto shapeSymbol = symbolicShape[i];
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1; dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
} }
return npcompTorchNonValueTensorTypeGet(context, dims.size(), return torchMlirTorchNonValueTensorTypeGet(context, dims.size(),
/*optionalSizes=*/dims.data(), /*optionalSizes=*/dims.data(),
/*optionalDtype=*/ /*optionalDtype=*/
elementType); elementType);
} }
case TypeKind::IntType: { case TypeKind::IntType: {
return npcompTorchIntTypeGet(context); return torchMlirTorchIntTypeGet(context);
} }
case TypeKind::FloatType: { case TypeKind::FloatType: {
return npcompTorchFloatTypeGet(context); return torchMlirTorchFloatTypeGet(context);
} }
case TypeKind::BoolType: { case TypeKind::BoolType: {
return npcompTorchBoolTypeGet(context); return torchMlirTorchBoolTypeGet(context);
} }
case TypeKind::NumberType: { case TypeKind::NumberType: {
return npcompTorchNumberTypeGet(context); return torchMlirTorchNumberTypeGet(context);
} }
case TypeKind::StringType: { case TypeKind::StringType: {
return npcompTorchStringTypeGet(context); return torchMlirTorchStringTypeGet(context);
} }
case TypeKind::OptionalType: { case TypeKind::OptionalType: {
return npcompTorchOptionalTypeGet(mapFromTorchType( return torchMlirTorchOptionalTypeGet(mapFromTorchType(
loc, torchType->cast<c10::OptionalType>()->getElementType())); loc, torchType->cast<c10::OptionalType>()->getElementType()));
} }
case TypeKind::TupleType: { case TypeKind::TupleType: {
@ -180,25 +180,25 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
torchType->cast<c10::TupleType>()->containedTypes()) { torchType->cast<c10::TupleType>()->containedTypes()) {
containedTypes.push_back(mapFromTorchType(loc, type)); containedTypes.push_back(mapFromTorchType(loc, type));
} }
return npcompTorchTupleTypeGet(context, containedTypes.size(), return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
containedTypes.data()); containedTypes.data());
} }
case TypeKind::ListType: { case TypeKind::ListType: {
return npcompTorchListTypeGet(mapFromTorchType( return torchMlirTorchListTypeGet(mapFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType())); loc, torchType->cast<c10::ListType>()->getElementType()));
} }
case TypeKind::DictType: { case TypeKind::DictType: {
auto dictType = torchType->cast<c10::DictType>(); auto dictType = torchType->cast<c10::DictType>();
return npcompTorchDictTypeGet( return torchMlirTorchDictTypeGet(
mapFromTorchType(loc, dictType->getKeyType()), mapFromTorchType(loc, dictType->getKeyType()),
mapFromTorchType(loc, dictType->getValueType())); mapFromTorchType(loc, dictType->getValueType()));
} }
case TypeKind::NoneType: { case TypeKind::NoneType: {
return npcompTorchNoneTypeGet(context); return torchMlirTorchNoneTypeGet(context);
} }
case TypeKind::AnyType: { case TypeKind::AnyType: {
auto anyType = torchType->cast<c10::AnyType>(); auto anyType = torchType->cast<c10::AnyType>();
return npcompTorchAnyTypeGet(context); return torchMlirTorchAnyTypeGet(context);
} }
case TypeKind::ClassType: { case TypeKind::ClassType: {
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>(); const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
@ -208,10 +208,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
} }
auto maybeName = classType->name(); auto maybeName = classType->name();
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class"; std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name)); return torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(name));
} }
case TypeKind::DeviceObjType: { case TypeKind::DeviceObjType: {
return npcompTorchDeviceTypeGet(context); return torchMlirTorchDeviceTypeGet(context);
} }
default: { default: {
std::stringstream message; std::stringstream message;
@ -226,7 +226,7 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
if (!tensor.defined()) { if (!tensor.defined()) {
// Undefined tensors are equivalent to None. // Undefined tensors are equivalent to None.
// This may need to be re-evaluated at some point. // This may need to be re-evaluated at some point.
return npcompTorchNoneTypeGet(context); return torchMlirTorchNoneTypeGet(context);
} }
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type()); MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
@ -234,8 +234,8 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
// just erase them and let the compiler decide. // just erase them and let the compiler decide.
auto sizes = tensor.sizes(); auto sizes = tensor.sizes();
return npcompTorchNonValueTensorTypeGet(context, sizes.size(), sizes.data(), return torchMlirTorchNonValueTensorTypeGet(context, sizes.size(),
elementType); sizes.data(), elementType);
} }
MlirType MlirType

View File

@ -2,5 +2,4 @@ add_subdirectory(Basicpy)
add_subdirectory(Numpy) add_subdirectory(Numpy)
add_subdirectory(Refback) add_subdirectory(Refback)
add_subdirectory(Refbackrt) add_subdirectory(Refbackrt)
add_subdirectory(Torch)
add_subdirectory(TorchConversion) add_subdirectory(TorchConversion)

View File

@ -17,7 +17,6 @@
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Interfaces/Traits.h"
#include "npcomp/Typing/Analysis/CPA/Interfaces.h" #include "npcomp/Typing/Analysis/CPA/Interfaces.h"
#define GET_OP_CLASSES #define GET_OP_CLASSES

View File

@ -10,7 +10,6 @@
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS #define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td" include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
include "npcomp/Interfaces/Traits.td"
include "npcomp/Typing/Analysis/CPA/Interfaces.td" include "npcomp/Typing/Analysis/CPA/Interfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/CastInterfaces.td"
@ -39,7 +38,6 @@ def Numpy_NarrowOp : Numpy_Op<"narrow", []> {
def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [ def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
DeclareOpInterfaceMethods<CastOpInterface>, DeclareOpInterfaceMethods<CastOpInterface>,
AllowsTypeRefinement,
NoSideEffect]> { NoSideEffect]> {
let summary = "Adds/removes static information from an array type."; let summary = "Adds/removes static information from an array type.";
let description = [{ let description = [{
@ -60,7 +58,6 @@ def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
def Numpy_TensorStaticInfoCastOp : Numpy_Op<"tensor_static_info_cast", [ def Numpy_TensorStaticInfoCastOp : Numpy_Op<"tensor_static_info_cast", [
DeclareOpInterfaceMethods<CastOpInterface>, DeclareOpInterfaceMethods<CastOpInterface>,
AllowsTypeRefinement,
NoSideEffect]> { NoSideEffect]> {
let summary = "Adds/removes static information from a tensor type."; let summary = "Adds/removes static information from a tensor type.";
let description = [{ let description = [{

View File

@ -1,5 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(NPCOMPTorchPassIncGen)
add_mlir_doc(Passes NPCOMPTorchTransforms ./ -gen-pass-doc)

View File

@ -16,8 +16,8 @@
#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h.inc" #include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h.inc"

View File

@ -14,7 +14,7 @@ include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td" include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td"
include "npcomp/Dialect/Torch/IR/TorchTypes.td" include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
include "iree-dialects/Dialect/IREE/IREEDialect.td" include "iree-dialects/Dialect/IREE/IREEDialect.td"
class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []> class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>

View File

@ -10,7 +10,7 @@
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H #define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include <memory> #include <memory>
@ -21,7 +21,8 @@ namespace TorchConversion {
/// Creates a pipeline that lowers the object graph IR that is produced by /// Creates a pipeline that lowers the object graph IR that is produced by
/// TorchScript import into the form expected by npcomp-verify-backend-contract. /// TorchScript import into the form expected by npcomp-verify-backend-contract.
void createTorchScriptToNpcompBackendPipeline( void createTorchScriptToNpcompBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options); OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
createVerifyInvariantsBeforeBackendLoweringPass(); createVerifyInvariantsBeforeBackendLoweringPass();

View File

@ -13,21 +13,7 @@
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
namespace OpTrait { namespace OpTrait {} // namespace OpTrait
template <typename ConcreteType>
class AllowsTypeRefinement
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
} // namespace OpTrait
// Check if an operation has the AllowsTypeRefinement trait.
//
// This function should be used in preference to
// `op->hasTrait<AllowsTypeRefinement>()` because this function has knowledge of
// some upstream ops that have this property, but which we cannot annotate with
// this trait.
bool allowsTypeRefinement(Operation *op);
} // namespace NPCOMP } // namespace NPCOMP
} // namespace mlir } // namespace mlir

View File

@ -19,13 +19,6 @@ class NpcompOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
let cppNamespace = "::mlir::NPCOMP::OpTrait"; let cppNamespace = "::mlir::NPCOMP::OpTrait";
} }
// Op allows operand and result types to be refined. // Empty for now. Kept as boilerplate placeholder.
// For example a `tensor<?xf32>` can be refined to `tensor<4xf32>`.
//
// TODO: Implement RefinableTypeInterface that allows actually modeling
// which types are refinements of other types.
// See the design in:
// https://llvm.discourse.group/t/allow-shape-concretization-or-type-concretization-in-rewrites/3327/3
def AllowsTypeRefinement : NpcompOpTrait<"AllowsTypeRefinement">;
#endif // NPCOMP_INTERFACES_TRAITS #endif // NPCOMP_INTERFACES_TRAITS

View File

@ -10,7 +10,6 @@ add_npcomp_library(NPCOMPCAPI
Registration.cpp Registration.cpp
BasicpyTypes.cpp BasicpyTypes.cpp
NumpyTypes.cpp NumpyTypes.cpp
TorchTypes.cpp
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRExecutionEngine MLIRExecutionEngine
@ -21,7 +20,8 @@ add_npcomp_library(NPCOMPCAPI
NPCOMPNumpyDialect NPCOMPNumpyDialect
NPCOMPRefBackendJITHelpers NPCOMPRefBackendJITHelpers
NPCOMPRuntime NPCOMPRuntime
NPCOMPTorchDialect TorchMLIRTorchDialect
TorchMLIRInitAll
# MLIR CAPI deps # MLIR CAPI deps
MLIRCAPIIR MLIRCAPIIR

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "npcomp/InitAll.h" #include "npcomp/InitAll.h"
#include "torch-mlir/InitAll.h"
void npcompRegisterAllDialects(MlirContext context) { void npcompRegisterAllDialects(MlirContext context) {
mlir::DialectRegistry registry; mlir::DialectRegistry registry;

View File

@ -31,7 +31,7 @@ add_npcomp_library(NPCOMPInitAll
NPCOMPIREEBackend NPCOMPIREEBackend
NPCOMPRefBackend NPCOMPRefBackend
NPCOMPRefbackDialect NPCOMPRefbackDialect
NPCOMPTorchDialect TorchMLIRTorchDialect
NPCOMPTorchConversionDialect NPCOMPTorchConversionDialect
NPCOMPRefbackrtDialect NPCOMPRefbackrtDialect
NPCOMPBasicpyDialect NPCOMPBasicpyDialect

View File

@ -13,7 +13,7 @@ add_npcomp_conversion_library(NPCOMPTorchToIREE
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRIR MLIRIR
MLIRPass MLIRPass
NPCOMPTorchDialect TorchMLIRTorchDialect
MLIRStandard MLIRStandard
IREEDialectsIREEDialect IREEDialectsIREEDialect
) )

View File

@ -14,13 +14,13 @@
#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch; using namespace mlir::torch::Torch;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// The patterns // The patterns

View File

@ -15,5 +15,5 @@ add_npcomp_conversion_library(NPCOMPTorchToLinalg
MLIRPass MLIRPass
MLIRLinalg MLIRLinalg
MLIRMath MLIRMath
NPCOMPTorchDialect TorchMLIRTorchDialect
) )

Some files were not shown because too many files have changed in this diff Show More