[torch-mlir earthmoving (1/N)] C/C++ code movement.

This creates the `external/torch-mlir` directory as an
LLVM_EXTERNAL_PROJECTS-compatible project (analogous to
`iree-dialects`) and completes movement/rename of all pure MLIR C/C++
compiler code into there. The next step will be to move all the Python
code / code that links/includes PyTorch C++ code (which currently lives
in `frontends/pytorch`) into a subdirectory here.

I call this "earthmoving" because it is mostly mechanical changes and
renames. As a quick summary (we can change this down the road easily)
- C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch`
- CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet`
- preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_`
- CMake `NPCOMPFoo -> TorchMLIRFoo`

The goal of this is to create a standalone project creating a center of
mass for entry into the MLIR ecosystem from PyTorch, suitable in scope
for eventual inclusion/ownership in PyTorch. The idea is that
`external/torch-mlir` will some day be pulled out into its own
repository, and then npcomp will simply pull it in as a submodule.

Layering-wise, what lives in `torch-mlir` lowers code from PyTorch
(currently TorchScript, but TorchFX or pytorch/xla-style tracing are
possible extensions) down to what we have been calling the "Torch
backend contract" which is cleaned up IR (inlining, simplification,
conversion to value tensors, ...) entirely in the `torch` dialect. This
is the branching off point for further lowering, of which npcomp takes
one opinion (outside `torch-mlir` of course!), namely the
`TorchConversion` dialect/transforms which lower to IR suitable for IREE
and other linalg-on-tensors based lower-level compilers.

Summary of changes:
- move `{include,lib,test}/Dialect/Torch` into `torch-mlir`
- move relevant parts of CAPI into `torch-mlir`.
- leave a few things related to the `torch-mlir` Python build commented
  out, which should be resolved in a subsequent change.
pull/305/head
Sean Silva 2021-09-09 19:24:10 +00:00
parent 28762699b3
commit 28a7738189
122 changed files with 826 additions and 436 deletions

View File

@ -128,9 +128,10 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
"suffix = '${PYTHON_MODULE_SUFFIX}', "
"extension = '${PYTHON_MODULE_EXTENSION}")
# Include the iree-dialects external project.
set(LLVM_EXTERNAL_PROJECTS "iree-dialects")
# Include LLVM_EXTERNAL_PROJECTS.
set(LLVM_EXTERNAL_PROJECTS "iree-dialects;torch-mlir")
set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects")
set(LLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir")
# LLVM configuration.
message(STATUS "*** ADDING LLVM ***")
@ -183,6 +184,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/torch-mlir/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})
set(NPCOMP_TABLEGEN_ARGS "")

View File

@ -21,6 +21,7 @@ cd $td/build
ninja
ninja check-npcomp
ninja check-torch-mlir
ninja check-frontends-pytorch
echo

View File

@ -0,0 +1,72 @@
# Guard: torch-mlir is only buildable as an LLVM external project
# (via LLVM_EXTERNAL_PROJECTS); a standalone configure is rejected.
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
message(FATAL_ERROR
"This project is intended to be built as part of LLVM via "
"-DLLVM_EXTERNAL_PROJECTS=torch-mlir "
"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR}")
endif()
# Python bindings are opt-in; the python subdirectory below is additionally
# gated on this option.
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TORCH_MLIR_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS "Building torch-mlir project at ${TORCH_MLIR_SOURCE_DIR} (into ${TORCH_MLIR_BINARY_DIR})")
# TODO: Fix this upstream so that global include directories are not needed.
# MLIR's source and generated include dirs are located relative to
# LLVM_MAIN_SRC_DIR / LLVM_BINARY_DIR, which the enclosing LLVM build defines.
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
# TODO: Needed for tablegen. Remove.
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR})
include_directories(SYSTEM ${TORCH_MLIR_SOURCE_DIR}/include)
# Attaches the MLIR and torch-mlir include directories (build-interface only)
# to `target`, and privately to its companion obj.<target> when one exists.
function(torch_mlir_target_includes target)
set(_dirs
$<BUILD_INTERFACE:${MLIR_INCLUDE_DIR}>
$<BUILD_INTERFACE:${MLIR_GENERATED_INCLUDE_DIR}>
$<BUILD_INTERFACE:${TORCH_MLIR_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${TORCH_MLIR_BINARY_DIR}/include>
)
# In LLVM parlance, the actual target may just be an interface and may not
# be responsible for actually compiling anything. The corresponding obj.
# target, when present, is just used for compilation and does not
# contribute to the interface properties.
# TODO: Normalize this upstream.
target_include_directories(${target} PUBLIC ${_dirs})
if(TARGET obj.${target})
target_include_directories(obj.${target} PRIVATE ${_dirs})
endif()
endfunction()
# Configure CMake and tablegen.
# The module path additions must precede the include()s below so that
# TableGen/AddLLVM/AddMLIR resolve from the MLIR/LLVM source trees.
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
set(MLIR_TABLEGEN_EXE mlir-tblgen)
include(TableGen)
include(AddLLVM)
include(AddMLIR)
################################################################################
# Setup python.
# TODO: Make one upstream macro to do this.
################################################################################
if(MLIR_ENABLE_BINDINGS_PYTHON)
include(MLIRDetectPythonEnv)
mlir_detect_pybind11_install()
find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION}
COMPONENTS Interpreter Development NumPy REQUIRED)
find_package(pybind11 2.6 CONFIG REQUIRED)
endif()
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
add_subdirectory(tools)
if(MLIR_ENABLE_BINDINGS_PYTHON)
# Intentionally left disabled in this change; to be re-enabled in a
# follow-up (see commit message).
#XXX: Enable
#add_subdirectory(python)
endif()

View File

@ -0,0 +1 @@
# Recurse into the torch-mlir subdirectory.
add_subdirectory(torch-mlir)

View File

@ -0,0 +1,32 @@
/*===-- torch-mlir-c/Registration.h - Registration functions -----*- C -*-===*\
|* *|
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|* Exceptions. *|
|* See https://llvm.org/LICENSE.txt for license information. *|
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|* *|
\*===----------------------------------------------------------------------===*/
/* C API entry points for registering the torch-mlir dialects and passes.   */
#ifndef TORCHMLIR_C_REGISTRATION_H
#define TORCHMLIR_C_REGISTRATION_H
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
/* C linkage so the CAPI is callable from pure-C (and FFI) consumers. */
#ifdef __cplusplus
extern "C" {
#endif
/** Registers all dialects with a context.
 * This is needed before creating IR for these Dialects.
 */
MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context);
/** Registers all passes for symbolic access with the global registry. */
MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses();
#ifdef __cplusplus
}
#endif
#endif // TORCHMLIR_C_REGISTRATION_H

View File

@ -1,4 +1,4 @@
//===-- npcomp-c/TorchTypes.h - C API for torch types -------------*- C -*-===//
//===-- torch-mlir-c/TorchTypes.h - C API for torch types ---------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_C_TORCHTYPES_H
#define NPCOMP_C_TORCHTYPES_H
#ifndef TORCHMLIR_C_TORCHTYPES_H
#define TORCHMLIR_C_TORCHTYPES_H
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
@ -22,110 +22,112 @@ extern "C" {
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a torch.nn.Module type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNnModule(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t);
/// Gets the !torch.nn.Module type of the specified class.
MLIR_CAPI_EXPORTED MlirType npcompTorchNnModuleTypeGet(MlirContext context,
MlirStringRef className);
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
//===----------------------------------------------------------------------===//
// torch.optional type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.optional<T> type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchOptional(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t);
/// Gets the !torch.optional<T> type with subtype T.
MLIR_CAPI_EXPORTED MlirType npcompTorchOptionalTypeGet(MlirType containedType);
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchOptionalTypeGet(MlirType containedType);
//===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.tuple type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchTuple(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchTuple(MlirType t);
/// Gets the !torch.tuple type with contained types `containedTypes`.
MLIR_CAPI_EXPORTED MlirType
npcompTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
MlirType const *containedTypes);
torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
MlirType const *containedTypes);
//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.list<T> type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchList(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t);
/// Gets the !torch.list<T> type with contained T.
MLIR_CAPI_EXPORTED MlirType npcompTorchListTypeGet(MlirType containedType);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
//===----------------------------------------------------------------------===//
// torch.Device type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.Device type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDevice(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
/// Gets the !torch.Device type.
MLIR_CAPI_EXPORTED MlirType npcompTorchDeviceTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.bool type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.bool type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchBool(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t);
/// Gets the !torch.bool type.
MLIR_CAPI_EXPORTED MlirType npcompTorchBoolTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.int type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.int type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchInt(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t);
/// Gets the !torch.int type.
MLIR_CAPI_EXPORTED MlirType npcompTorchIntTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.float type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.float type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchFloat(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t);
/// Gets the !torch.float type.
MLIR_CAPI_EXPORTED MlirType npcompTorchFloatTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.LinearParams type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.LinearParams type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchLinearParams(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t);
/// Gets the !torch.LinearParams type.
MLIR_CAPI_EXPORTED MlirType npcompTorchLinearParamsTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchLinearParamsTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.qint8 type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.qint8 type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchQInt8(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
/// Gets the !torch.qint8 type.
MLIR_CAPI_EXPORTED MlirType npcompTorchQInt8TypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.tensor type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNonValueTensor(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);
/// Gets a !torch.tensor type.
///
@ -133,24 +135,25 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNonValueTensor(MlirType t);
/// information is present (and `numSizes` is ignored in that case). -
/// `optionalDtype` is allowed to be null, meaning that no dtype
/// information is present.
MLIR_CAPI_EXPORTED MlirType npcompTorchNonValueTensorTypeGet(
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGet(
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
MlirType optionalDtype);
/// Gets the !torch.tensor type with the least static information.
MLIR_CAPI_EXPORTED MlirType
npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
MlirContext context);
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
MLIR_CAPI_EXPORTED MlirType
npcompTorchNonValueTensorTypeGetFromShaped(MlirType type);
torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type);
//===----------------------------------------------------------------------===//
// torch.vtensor type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.vtensor type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchValueTensor(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);
/// Gets a !torch.vtensor type.
///
@ -158,71 +161,71 @@ MLIR_CAPI_EXPORTED bool npcompTypeIsATorchValueTensor(MlirType t);
/// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype
/// information is present.
MLIR_CAPI_EXPORTED MlirType npcompTorchValueTensorTypeGet(
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
MlirContext context, intptr_t numSizes, const int64_t *optionalSizes,
MlirType optionalDtype);
/// Gets the !torch.tensor type with the least static information.
MLIR_CAPI_EXPORTED MlirType
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
MLIR_CAPI_EXPORTED MlirType
npcompTorchValueTensorTypeGetFromShaped(MlirType type);
torchMlirTorchValueTensorTypeGetFromShaped(MlirType type);
//===----------------------------------------------------------------------===//
// !torch.none type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.none type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNone(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t);
/// Gets the !torch.none type.
MLIR_CAPI_EXPORTED MlirType npcompTorchNoneTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !torch.str type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.str type
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchString(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t);
/// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType npcompTorchStringTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !torch.any type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.any type.
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchAny(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t);
/// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType npcompTorchAnyTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !torch.number type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.number type.
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchNumber(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t);
/// Gets the !torch.number type.
MLIR_CAPI_EXPORTED MlirType npcompTorchNumberTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !torch.dict type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.dict type.
MLIR_CAPI_EXPORTED bool npcompTypeIsATorchDict(MlirType t);
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t);
/// Gets the !torch.dict type.
MLIR_CAPI_EXPORTED MlirType npcompTorchDictTypeGet(MlirType keyType,
MlirType valueType);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType,
MlirType valueType);
#ifdef __cplusplus
}
#endif
#endif // NPCOMP_C_TORCHTYPES_H
#endif // TORCHMLIR_C_TORCHTYPES_H

View File

@ -0,0 +1 @@
# Recurse into the Dialect subdirectory.
add_subdirectory(Dialect)

View File

@ -0,0 +1 @@
# Recurse into the Torch dialect subdirectory.
add_subdirectory(Torch)

View File

@ -13,7 +13,7 @@ include "mlir/IR/OpBase.td"
def Torch_Dialect : Dialect {
let name = "torch";
let cppNamespace = "::mlir::NPCOMP::Torch";
let cppNamespace = "::mlir::torch::Torch";
let description = [{
Top-level dialect for interfacing PyTorch and MLIR.
@ -21,7 +21,7 @@ def Torch_Dialect : Dialect {
This dialect also provides transforms that lower it to the
"Torch backend contract", which is an IR form that we present to
later conversions, such as conversion to the npcomp backend contract.
later conversions.
The Torch backend contract significantly simplifies the IR representation
and puts it in a form easier for later lowering to work on. Specifically:
- The TorchScript object graph has been flattened to a list of globals (see
@ -39,11 +39,12 @@ def Torch_Dialect : Dialect {
class TorchOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
let trait = name;
let cppNamespace = "::mlir::NPCOMP::Torch::OpTrait";
let cppNamespace = "::mlir::torch::Torch::OpTrait";
}
def HasValueSemantics : TorchOpTrait<"HasValueSemantics">;
def IsTrailingUnderscoreInplaceVariant
: TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">;
#endif // TORCH_BASE

View File

@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H
#include "mlir/IR/Dialect.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h.inc"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h.inc"
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHDIALECT_H
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHDIALECT_H

View File

@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@ -18,15 +18,14 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTraits.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "npcomp/Interfaces/Traits.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTraits.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h.inc"
namespace mlir {
namespace NPCOMP {
namespace torch {
namespace Torch {
namespace detail {
@ -117,11 +116,11 @@ m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
Value tensor);
} // namespace Torch
} // namespace NPCOMP
} // namespace torch
} // namespace mlir
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
using SlotOp = ::mlir::NPCOMP::Torch::SlotOp;
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::SlotOp> {
using SlotOp = ::mlir::torch::Torch::SlotOp;
static SlotOp getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return SlotOp::getFromOpaquePointer(pointer);
@ -136,8 +135,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
static bool isEqual(SlotOp lhs, SlotOp rhs) { return lhs == rhs; }
};
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> {
using NnModuleOp = ::mlir::NPCOMP::Torch::NnModuleOp;
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::NnModuleOp> {
using NnModuleOp = ::mlir::torch::Torch::NnModuleOp;
static NnModuleOp getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return NnModuleOp::getFromOpaquePointer(pointer);
@ -152,8 +151,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> {
static bool isEqual(NnModuleOp lhs, NnModuleOp rhs) { return lhs == rhs; }
};
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> {
using ClassTypeOp = ::mlir::NPCOMP::Torch::ClassTypeOp;
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::ClassTypeOp> {
using ClassTypeOp = ::mlir::torch::Torch::ClassTypeOp;
static ClassTypeOp getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return ClassTypeOp::getFromOpaquePointer(pointer);
@ -168,8 +167,8 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> {
static bool isEqual(ClassTypeOp lhs, ClassTypeOp rhs) { return lhs == rhs; }
};
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::GlobalSlotOp> {
using OpTy = ::mlir::NPCOMP::Torch::GlobalSlotOp;
template <> struct llvm::DenseMapInfo<::mlir::torch::Torch::GlobalSlotOp> {
using OpTy = ::mlir::torch::Torch::GlobalSlotOp;
static OpTy getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return OpTy::getFromOpaquePointer(pointer);
@ -184,4 +183,4 @@ template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::GlobalSlotOp> {
static bool isEqual(OpTy lhs, OpTy rhs) { return lhs == rhs; }
};
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H

View File

@ -9,8 +9,7 @@
#ifndef TORCH_OPS
#define TORCH_OPS
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "npcomp/Interfaces/Traits.td"
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
@ -22,9 +21,9 @@ class Torch_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Torch_Dialect, mnemonic, traits> {
}
include "npcomp/Dialect/Torch/IR/GeneratedAtenOps.td"
include "npcomp/Dialect/Torch/IR/GeneratedPrimOps.td"
include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"
include "torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td"
include "torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td"
include "torch-mlir/Dialect/Torch/IR/GeneratedQuantizedOps.td"
//===----------------------------------------------------------------------===//
// TorchScript `torch.nn.Module` object instantiation ops.
@ -32,7 +31,7 @@ include "npcomp/Dialect/Torch/IR/GeneratedQuantizedOps.td"
def Torch_NnModuleOp : Torch_Op<"nn_module", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::NnModuleTerminatorOp">]> {
SingleBlockImplicitTerminator<"::mlir::torch::Torch::NnModuleTerminatorOp">]> {
let summary = "Constructs a torch.nn.Module";
let description = [{
This op is used to represent a torch.nn.Module when importing a
@ -75,7 +74,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
}
def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
HasParent<"::mlir::torch::Torch::NnModuleOp">]> {
let summary = "Implicit terminator for torch.nn_module";
let arguments = (ins);
@ -85,7 +84,7 @@ def Torch_NnModuleTerminatorOp : Torch_Op<"nn_module_terminator", [Terminator,
}
def Torch_SlotOp : Torch_Op<"slot", [
HasParent<"::mlir::NPCOMP::Torch::NnModuleOp">]> {
HasParent<"::mlir::torch::Torch::NnModuleOp">]> {
let summary = "Define the value of a slot of a torch.nn.Module";
let description = [{
This op specifies that the initial value of the slot `name` of the
@ -107,7 +106,7 @@ def Torch_SlotOp : Torch_Op<"slot", [
def Torch_ClassTypeOp : Torch_Op<"class_type", [
Symbol,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::ClassTypeTerminatorOp">]> {
SingleBlockImplicitTerminator<"::mlir::torch::Torch::ClassTypeTerminatorOp">]> {
let summary = "Constructs a torch.ClassType";
let description = [{
Declares a class type. Class types are the types used to describe
@ -152,7 +151,7 @@ def Torch_ClassTypeOp : Torch_Op<"class_type", [
}
def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">]> {
HasParent<"::mlir::torch::Torch::ClassTypeOp">]> {
let summary = "Implicit terminator for torch.class_type";
let arguments = (ins);
@ -162,7 +161,7 @@ def Torch_ClassTypeTerminatorOp : Torch_Op<"class_type_terminator", [Terminator,
}
def Torch_MethodOp : Torch_Op<"method", [
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">,
HasParent<"::mlir::torch::Torch::ClassTypeOp">,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Declare a method of a torch.class_type";
@ -193,7 +192,7 @@ def Torch_MethodOp : Torch_Op<"method", [
}
def Torch_AttrOp : Torch_Op<"attr", [
HasParent<"::mlir::NPCOMP::Torch::ClassTypeOp">
HasParent<"::mlir::torch::Torch::ClassTypeOp">
]> {
let summary = "Declare an attribute of a torch.class_type";
let description = [{
@ -231,7 +230,7 @@ def Torch_AttrOp : Torch_Op<"attr", [
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
Symbol,
IsolatedFromAbove,
SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::GlobalSlotInitOp">
SingleBlockImplicitTerminator<"::mlir::torch::Torch::GlobalSlotInitOp">
]> {
let summary = "A slot with global storage";
let description = [{
@ -256,7 +255,7 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
Terminator,
HasParent<"::mlir::NPCOMP::Torch::GlobalSlotOp">]> {
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
let summary = "yield-like terminator for torch.global_slot initializer region";
let description = [{
The operand to this op becomes the initial value of the parent
@ -463,7 +462,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
Terminator,
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
HasParent<"::mlir::torch::Torch::PrimLoopOp">]> {
let summary = "yield-like terminator for torch.prim.Loop";
let description = [{
Does not correspond to any torch prim op directly (the way that they model
@ -512,7 +511,7 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
Terminator,
ReturnLike,
HasParent<"::mlir::NPCOMP::Torch::PrimIfOp">]> {
HasParent<"::mlir::torch::Torch::PrimIfOp">]> {
let summary = "yield-like terminator for torch.prim.If";
let description = [{
Does not correspond to any torch prim op directly (the way that they model

View File

@ -10,15 +10,14 @@
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace NPCOMP {
namespace torch {
namespace Torch {
namespace OpTrait {
@ -39,9 +38,16 @@ class IsTrailingUnderscoreInplaceVariant
: public ::mlir::OpTrait::TraitBase<ConcreteType,
IsTrailingUnderscoreInplaceVariant> {};
// If a Torch op has this trait, it means that the op allows all of its operand
// and result types to be refined. That is, a less specific type is allowed to
// be replaced by a more specific type, according to PEP 483 subtyping rules.
template <typename ConcreteType>
class AllowsTypeRefinement
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
} // namespace OpTrait
} // namespace Torch
} // namespace NPCOMP
} // namespace torch
} // namespace mlir
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTRAITS_H
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHTRAITS_H

View File

@ -6,13 +6,13 @@
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H
#include "mlir/IR/BuiltinTypes.h"
namespace mlir {
namespace NPCOMP {
namespace torch {
namespace Torch {
class NonValueTensorType;
@ -87,18 +87,18 @@ public:
ValueTensorType getWithValueSemantics() const;
};
} // namespace Torch
} // namespace NPCOMP
} // namespace torch
} // namespace mlir
#define GET_TYPEDEF_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchTypes.h.inc"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h.inc"
//===----------------------------------------------------------------------===//
// Inline definitions
//===----------------------------------------------------------------------===//
namespace mlir {
namespace NPCOMP {
namespace torch {
namespace Torch {
inline Optional<ArrayRef<int64_t>> BaseTensorType::getOptionalSizes() const {
@ -122,7 +122,7 @@ inline bool BaseTensorType::classof(Type type) {
}
} // namespace Torch
} // namespace NPCOMP
} // namespace torch
} // namespace mlir
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
#endif // TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H

View File

@ -9,7 +9,7 @@
#ifndef TORCH_TYPES
#define TORCH_TYPES
include "npcomp/Dialect/Torch/IR/TorchBase.td"
include "torch-mlir/Dialect/Torch/IR/TorchBase.td"
//===----------------------------------------------------------------------===//
// Type defs
@ -83,7 +83,7 @@ class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
}
class AnyTorchTensorType<string name, string typeMnemonic>
: Torch_Type<name, typeMnemonic, "::mlir::NPCOMP::Torch::BaseTensorType"> {
: Torch_Type<name, typeMnemonic, "::mlir::torch::Torch::BaseTensorType"> {
let summary = "Multi-dimensional array modeling Torch's Tensor type";
let description = [{
Syntax:
@ -107,8 +107,8 @@ class AnyTorchTensorType<string name, string typeMnemonic>
a strict separation between the value-semantic and potentially-mutating
worlds, as one of our main jobs in the compiler is to isolate the mutating
parts as much as possible because most lower levels of the compiler stack
are expected to require value semantics. E.g. npcomp's backend contract
is mostly in terms of linalg-on-tensor for compute-heavy ops, which require
are expected to require value semantics. E.g. many backend contracts
mostly use linalg-on-tensor for compute-heavy ops, which require
a conversion to the builtin `tensor` type which has value semantics.
Some notes about value semantics:
- Using the type system described in PEP 483 (which TorchScript and other
@ -165,7 +165,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
Note: We avoid the C++ identifier `TensorType` to avoid C++ name ambiguities
with `mlir::TensorType`, since most code is transitively nested in
both `::mlir` and `::mlir::NPCOMP::Torch` namespaces.
both `::mlir` and `::mlir::torch::Torch` namespaces.
Note: We use the Torch-aligned terminology "sizes" and "dtype" instead of
the MLIR-aligned terminology "rank/shape" and "element type". The cheat
@ -209,7 +209,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
}
def AnyTorchTensorType : Type<
CPred<"$_self.isa<::mlir::NPCOMP::Torch::BaseTensorType>()">,
CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">,
"Any Torch tensor type"
>;
@ -317,7 +317,7 @@ def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> {
The whole LinearPackedParamsBase abstraction as it stands in PyTorch is a
very library-call-y, runtime-y thing that embodies a number of assumptions
about the structure of how the program will be executed, which need not hold
for npcomp backends.
for backends.
}];
}
@ -386,12 +386,12 @@ def TorchOptionalBoolType:
def TorchOptionalDeviceType:
OptionalOf<Torch_DeviceType, "Optional torch device type">;
def IsListTypePred : CPred<"$_self.isa<::mlir::NPCOMP::Torch::ListType>()">;
def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
class ListOf<list<Type> allowedTypes, string descr> :
ContainerType<AnyTypeOf<allowedTypes>,
IsListTypePred,
"$_self.cast<::mlir::NPCOMP::Torch::ListType>().getContainedType()",
descr, "::mlir::NPCOMP::Torch::ListType">;
"$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()",
descr, "::mlir::torch::Torch::ListType">;
def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">;

View File

@ -0,0 +1,5 @@
# Tablegen driver for the Torch dialect transformation passes.
set(LLVM_TARGET_DEFINITIONS Passes.td)
# Emit Passes.h.inc (GEN_PASS_DECLS / GEN_PASS_CLASSES / GEN_PASS_REGISTRATION
# sections) consumed by Passes.h, PassDetail.h, and Passes.cpp.
mlir_tablegen(Passes.h.inc -gen-pass-decls)
# Named target other libraries can DEPENDS on to order generation before use.
add_public_tablegen_target(TorchMLIRTorchPassIncGen)
# Generate the user-facing pass documentation page from the same .td file.
add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)

View File

@ -6,15 +6,15 @@
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace NPCOMP {
namespace torch {
namespace Torch {
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
@ -58,7 +58,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
/// Registers all Torch transformation passes.
void registerTorchPasses();
} // namespace NPCOMP
} // namespace torch
} // namespace mlir
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSES_H
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H

View File

@ -6,14 +6,14 @@
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_TORCH_PASSES
#define NPCOMP_TORCH_PASSES
#ifndef TORCHMLIR_TORCH_PASSES
#define TORCHMLIR_TORCH_PASSES
include "mlir/Pass/PassBase.td"
def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
let summary = "Converts TorchScript object graphs to a globalized form";
let constructor = "mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass()";
let constructor = "mlir::torch::Torch::createGlobalizeObjectGraphPass()";
let description = [{
This pass converts a subset of possible TorchScript modules into a
more restrictive lower-level form that strips away the need to be
@ -80,7 +80,7 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
- Rationale: This makes the representation of initial values simpler. Also
as of Feb 2021, TorchScript won't import into this form except
potentially for Tensors (it has a bug related to the identity of
objects). And for tensors, the npcomp IValue importer only supports a
objects). And for tensors, the IValue importer only supports a
very restricted form of aliasing anyway for other reasons. We are
waiting for signals that more general handling of object aliasing is
important to devote the effort to it.
@ -90,7 +90,7 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
def PrepareForGlobalizeObjectGraph
: Pass<"torch-prepare-for-globalize-object-graph", "ModuleOp"> {
let summary = "Lowering in preparation for globalizing";
let constructor = "mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass()";
let constructor = "mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass()";
let description = [{
Establishes the invariants needed by the
torch-globalize-object-graph transformation. Fails if that cannot be
@ -104,7 +104,7 @@ def PrepareForGlobalizeObjectGraph
def AdjustCallingConventions
: Pass<"torch-adjust-calling-conventions", "ModuleOp"> {
let summary = "Adjust the calling conventions of functions";
let constructor = "mlir::NPCOMP::Torch::createAdjustCallingConventionsPass()";
let constructor = "mlir::torch::Torch::createAdjustCallingConventionsPass()";
let description = [{
Adjusts the calling conventions of functions in the module, with the aim of
preparing them for backends and further lowering passes. As this changes
@ -127,7 +127,7 @@ def AdjustCallingConventions
def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
let summary = "Refine types";
let constructor = "mlir::NPCOMP::Torch::createRefineTypesPass()";
let constructor = "mlir::torch::Torch::createRefineTypesPass()";
let description = [{
Refines types of the program. Currently, this means shapes and dtypes of
tensors/arrays.
@ -136,7 +136,7 @@ def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
let summary = "Inlines torch.global_slot ops.";
let constructor = "mlir::NPCOMP::Torch::createInlineGlobalSlotsPass()";
let constructor = "mlir::torch::Torch::createInlineGlobalSlotsPass()";
let description = [{
Inlines torch.global_slot ops when it is safe to do so.
@ -150,7 +150,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
let summary = "Reduces variants of ops to a smaller set of ops.";
let constructor = "mlir::NPCOMP::Torch::createReduceOpVariantsPass()";
let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
let description = [{
Replaces ops with other ops to reduce the number of variants that
need to be handled elsewhere in the code.
@ -181,13 +181,13 @@ def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> {
Also, this pass doesn't currently handle interprocedural rewriting
(of private functions), which is even more complex.
}];
let constructor = "mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass()";
let constructor = "mlir::torch::Torch::createMaximizeValueSemanticsPass()";
}
def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
let summary = "Refine public return";
let constructor = "mlir::NPCOMP::Torch::createRefinePublicReturnPass()";
let constructor = "mlir::torch::Torch::createRefinePublicReturnPass()";
let description = [{
Refines types of values returned from public functions based on
intraprocedural information.
@ -214,4 +214,4 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
}];
}
#endif // NPCOMP_TORCH_PASSES
#endif // TORCHMLIR_TORCH_PASSES

View File

@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry points for registering everything torch-mlir provides with an MLIR
// DialectRegistry / the global pass registry. Wrapped by the C API in
// torch-mlir-c/Registration.h.
//
//===----------------------------------------------------------------------===//

// NOTE(review): guard renamed from TORCH_MLIR_INITALL_H to TORCHMLIR_INITALL_H
// for consistency with the TORCHMLIR_* include-guard convention used by every
// other header in this project (e.g. TORCHMLIR_DIALECT_TORCH_IR_TORCHTYPES_H).
#ifndef TORCHMLIR_INITALL_H
#define TORCHMLIR_INITALL_H

#include "mlir/IR/Dialect.h"

namespace mlir {
namespace torch {

// Adds all dialects defined by torch-mlir to `registry`.
void registerAllDialects(mlir::DialectRegistry &registry);

// Registers all torch-mlir passes with the global pass registry.
void registerAllPasses();

} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_INITALL_H

View File

@ -0,0 +1,18 @@
# C API library for torch-mlir: the torchMlir* C entry points for
# dialect/pass registration (Registration.cpp) and Torch type
# predicates/constructors (TorchTypes.cpp).
add_mlir_library(TorchMLIRCAPI
  Registration.cpp
  TorchTypes.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/torch-mlir-c/

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRSupport
  TorchMLIRTorchDialect
  # Needed by Registration.cpp for mlir::torch::registerAllDialects/Passes.
  TorchMLIRInitAll
)

# Project helper (defined at the top level of this build) that attaches the
# torch-mlir include directories to the target.
torch_mlir_target_includes(TorchMLIRCAPI)

View File

@ -0,0 +1,25 @@
//===- Registration.cpp - C Interface for MLIR Registration ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "torch-mlir-c/Registration.h"

#include "mlir/CAPI/IR.h"
// NOTE(review): the next three pass headers are not referenced anywhere in
// this translation unit — presumably leftovers from the code move; confirm
// whether they can be dropped.
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/InitAll.h"

// C API entry point: registers all torch-mlir dialects with `context` and
// eagerly loads every available dialect into it.
void torchMlirRegisterAllDialects(MlirContext context) {
  mlir::DialectRegistry registry;
  mlir::torch::registerAllDialects(registry);
  // Merge into the context's own registry so the dialects become loadable.
  unwrap(context)->appendDialectRegistry(registry);
  // TODO: Don't eagerly load once D88162 is in and clients can do this.
  unwrap(context)->loadAllAvailableDialects();
}

// C API entry point: registers all torch-mlir passes with the global registry.
void torchMlirRegisterAllPasses() { mlir::torch::registerAllPasses(); }

View File

@ -6,26 +6,26 @@
//
//===----------------------------------------------------------------------===//
#include "npcomp-c/TorchTypes.h"
#include "torch-mlir-c/TorchTypes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::torch;
//===----------------------------------------------------------------------===//
// torch.nn.Module type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNnModule(MlirType t) {
bool torchMlirTypeIsATorchNnModule(MlirType t) {
return unwrap(t).isa<Torch::NnModuleType>();
}
MlirType npcompTorchNnModuleTypeGet(MlirContext context,
MlirStringRef className) {
MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
MlirStringRef className) {
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
}
@ -33,11 +33,11 @@ MlirType npcompTorchNnModuleTypeGet(MlirContext context,
// torch.optional type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchOptional(MlirType t) {
bool torchMlirTypeIsATorchOptional(MlirType t) {
return unwrap(t).isa<Torch::OptionalType>();
}
MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
return wrap(Torch::OptionalType::get(unwrap(containedType)));
}
@ -45,13 +45,13 @@ MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
// torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchTuple(MlirType t) {
bool torchMlirTypeIsATorchTuple(MlirType t) {
return unwrap(t).isa<Torch::TupleType>();
}
MlirType npcompTorchTupleTypeGet(MlirContext context,
intptr_t numContainedTypes,
MlirType const *containedTypes) {
MlirType torchMlirTorchTupleTypeGet(MlirContext context,
intptr_t numContainedTypes,
MlirType const *containedTypes) {
return wrap(Torch::TupleType::get(
unwrap(context),
llvm::to_vector<6>(
@ -63,11 +63,11 @@ MlirType npcompTorchTupleTypeGet(MlirContext context,
// torch.list<T> type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchList(MlirType t) {
bool torchMlirTypeIsATorchList(MlirType t) {
return unwrap(t).isa<Torch::ListType>();
}
MlirType npcompTorchListTypeGet(MlirType containedType) {
MlirType torchMlirTorchListTypeGet(MlirType containedType) {
return wrap(Torch::ListType::get(unwrap(containedType)));
}
@ -75,11 +75,11 @@ MlirType npcompTorchListTypeGet(MlirType containedType) {
// torch.Device type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchDevice(MlirType t) {
bool torchMlirTypeIsATorchDevice(MlirType t) {
return unwrap(t).isa<Torch::DeviceType>();
}
MlirType npcompTorchDeviceTypeGet(MlirContext context) {
MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
return wrap(Torch::DeviceType::get(unwrap(context)));
}
@ -87,11 +87,11 @@ MlirType npcompTorchDeviceTypeGet(MlirContext context) {
// torch.bool type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchBool(MlirType t) {
bool torchMlirTypeIsATorchBool(MlirType t) {
return unwrap(t).isa<Torch::BoolType>();
}
MlirType npcompTorchBoolTypeGet(MlirContext context) {
MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
return wrap(Torch::BoolType::get(unwrap(context)));
}
@ -99,11 +99,11 @@ MlirType npcompTorchBoolTypeGet(MlirContext context) {
// torch.int type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchInt(MlirType t) {
bool torchMlirTypeIsATorchInt(MlirType t) {
return unwrap(t).isa<Torch::IntType>();
}
MlirType npcompTorchIntTypeGet(MlirContext context) {
MlirType torchMlirTorchIntTypeGet(MlirContext context) {
return wrap(Torch::IntType::get(unwrap(context)));
}
@ -111,11 +111,11 @@ MlirType npcompTorchIntTypeGet(MlirContext context) {
// torch.float type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchFloat(MlirType t) {
bool torchMlirTypeIsATorchFloat(MlirType t) {
return unwrap(t).isa<Torch::FloatType>();
}
MlirType npcompTorchFloatTypeGet(MlirContext context) {
MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
return wrap(Torch::FloatType::get(unwrap(context)));
}
@ -123,11 +123,11 @@ MlirType npcompTorchFloatTypeGet(MlirContext context) {
// torch.LinearParams type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchLinearParams(MlirType t) {
bool torchMlirTypeIsATorchLinearParams(MlirType t) {
return unwrap(t).isa<Torch::LinearParamsType>();
}
MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
return wrap(Torch::LinearParamsType::get(unwrap(context)));
}
@ -135,11 +135,11 @@ MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
// torch.qint8 type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchQInt8(MlirType t) {
bool torchMlirTypeIsATorchQInt8(MlirType t) {
return unwrap(t).isa<Torch::QInt8Type>();
}
MlirType npcompTorchQInt8TypeGet(MlirContext context) {
MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
return wrap(Torch::QInt8Type::get(unwrap(context)));
}
@ -147,14 +147,14 @@ MlirType npcompTorchQInt8TypeGet(MlirContext context) {
// torch.tensor type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNonValueTensor(MlirType t) {
bool torchMlirTypeIsATorchNonValueTensor(MlirType t) {
return unwrap(t).isa<Torch::NonValueTensorType>();
}
MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype) {
MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
@ -162,13 +162,13 @@ MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}
MlirType npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(
MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
MlirContext context) {
return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
unwrap(context)));
}
MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
MlirType torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type) {
return wrap(Torch::NonValueTensorType::getFromShaped(
unwrap(type).cast<ShapedType>()));
}
@ -177,13 +177,14 @@ MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
// torch.vtensor type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchValueTensor(MlirType t) {
bool torchMlirTypeIsATorchValueTensor(MlirType t) {
return unwrap(t).isa<Torch::ValueTensorType>();
}
MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype) {
MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
@ -191,13 +192,13 @@ MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}
MlirType
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
MlirContext context) {
return wrap(
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
}
MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
MlirType torchMlirTorchValueTensorTypeGetFromShaped(MlirType type) {
return wrap(
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
}
@ -206,11 +207,11 @@ MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
// torch.none type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNone(MlirType t) {
bool torchMlirTypeIsATorchNone(MlirType t) {
return unwrap(t).isa<Torch::NoneType>();
}
MlirType npcompTorchNoneTypeGet(MlirContext context) {
MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
return wrap(Torch::NoneType::get(unwrap(context)));
}
@ -218,11 +219,11 @@ MlirType npcompTorchNoneTypeGet(MlirContext context) {
// torch.str type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchString(MlirType t) {
bool torchMlirTypeIsATorchString(MlirType t) {
return unwrap(t).isa<Torch::StringType>();
}
MlirType npcompTorchStringTypeGet(MlirContext context) {
MlirType torchMlirTorchStringTypeGet(MlirContext context) {
return wrap(Torch::StringType::get(unwrap(context)));
}
@ -230,11 +231,11 @@ MlirType npcompTorchStringTypeGet(MlirContext context) {
// torch.any type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchAny(MlirType t) {
bool torchMlirTypeIsATorchAny(MlirType t) {
return unwrap(t).isa<Torch::AnyType>();
}
MlirType npcompTorchAnyTypeGet(MlirContext context) {
MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
return wrap(Torch::AnyType::get(unwrap(context)));
}
@ -242,11 +243,11 @@ MlirType npcompTorchAnyTypeGet(MlirContext context) {
// torch.number type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNumber(MlirType t) {
bool torchMlirTypeIsATorchNumber(MlirType t) {
return unwrap(t).isa<Torch::NumberType>();
}
MlirType npcompTorchNumberTypeGet(MlirContext context) {
MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
return wrap(Torch::NumberType::get(unwrap(context)));
}
@ -254,10 +255,10 @@ MlirType npcompTorchNumberTypeGet(MlirContext context) {
// torch.Dict type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchDict(MlirType t) {
bool torchMlirTypeIsATorchDict(MlirType t) {
return unwrap(t).isa<Torch::DictType>();
}
MlirType npcompTorchDictTypeGet(MlirType keyType, MlirType valueType) {
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
}

View File

@ -0,0 +1,17 @@
add_subdirectory(CAPI)
add_subdirectory(Dialect)

# Aggregate registration library: InitAll.cpp implements the
# mlir::torch::registerAllDialects / registerAllPasses entry points declared
# in torch-mlir/InitAll.h.
add_mlir_library(TorchMLIRInitAll
  InitAll.cpp

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRSupport
  TorchMLIRTorchDialect
  TorchMLIRTorchPasses
)

# Project helper (defined at the top level of this build) that attaches the
# torch-mlir include directories to the target.
torch_mlir_target_includes(TorchMLIRInitAll)

View File

@ -0,0 +1 @@
# The Torch dialect is currently the only dialect provided by torch-mlir.
add_subdirectory(Torch)

View File

@ -1,10 +1,10 @@
add_npcomp_dialect_library(NPCOMPTorchDialect
add_mlir_library(TorchMLIRTorchDialect
TorchDialect.cpp
TorchOps.cpp
TorchTypes.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch
DEPENDS
MLIRTorchOpsIncGen
@ -17,6 +17,8 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
MLIRIR
MLIRSupport
MLIRControlFlowInterfaces
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces
NPCOMPInterfaces
)
torch_mlir_target_includes(TorchMLIRTorchDialect)

View File

@ -6,20 +6,20 @@
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
#include "npcomp/Dialect/Torch/IR/TorchDialect.cpp.inc"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// Dialect Interfaces
@ -44,7 +44,7 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
//===----------------------------------------------------------------------===//
// Dialect initialize method.
@ -53,11 +53,11 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
void TorchDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.cpp.inc"
>();
addInterfaces<TorchInlinerInterface>();
}
@ -84,7 +84,6 @@ void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
llvm_unreachable("unknown 'torch' type");
}
//===----------------------------------------------------------------------===//
// Dialect-level verifiers.
//===----------------------------------------------------------------------===//

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@ -15,8 +15,8 @@
#include "llvm/ADT/StringMap.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28
static int64_t getDtypeIntegerFromMlirType(Type dtype) {
@ -36,9 +36,9 @@ static int64_t getDtypeIntegerFromMlirType(Type dtype) {
// Utilities
//===----------------------------------------------------------------------===//
Value mlir::NPCOMP::Torch::copyTensorToType(OpBuilder &builder, Location loc,
BaseTensorType newType,
Value tensor) {
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
BaseTensorType newType,
Value tensor) {
auto originalType = tensor.getType().cast<BaseTensorType>();
// Adjust the static information in the type to match between the original and
// new types.
@ -393,7 +393,10 @@ void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// if not removed eagerly by canonicalizer would prevent ReduceOpVariants
// from converting certain tensors value semantics.
bool allAllowRefinement =
llvm::all_of(op.getResult().getUsers(), allowsTypeRefinement);
llvm::all_of(op.getResult().getUsers(), [](Operation *op) {
return op
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
});
if (!allAllowRefinement)
return failure();
rewriter.replaceOp(op, op.getOperand());
@ -1004,4 +1007,4 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"

View File

@ -6,15 +6,15 @@
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
//===----------------------------------------------------------------------===//
// TupleType

View File

@ -14,20 +14,20 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Map from func name and arg index to the type bound for that arg.
// This is needed because to rewrite calls, we need the non-local information
// from the func definition.
// We also benefit from populating this all at once, which avoids ordering
// issues between rewriting of func ops vs call ops.
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type> ;
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>;
namespace {
class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
@ -136,8 +136,8 @@ public:
return success();
}
private:
TypeBoundMap &typeBoundMap;
private:
TypeBoundMap &typeBoundMap;
};
} // namespace
@ -251,6 +251,6 @@ class AdjustCallingConventionsPass
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createAdjustCallingConventionsPass() {
mlir::torch::Torch::createAdjustCallingConventionsPass() {
return std::make_unique<AdjustCallingConventionsPass>();
}

View File

@ -1,4 +1,4 @@
add_npcomp_conversion_library(NPCOMPTorchPasses
add_mlir_library(TorchMLIRTorchPasses
AdjustCallingConventions.cpp
Passes.cpp
GlobalizeObjectGraph.cpp
@ -10,10 +10,10 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
RefineTypes.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms
DEPENDS
NPCOMPTorchPassIncGen
TorchMLIRTorchPassIncGen
LINK_COMPONENTS
Core
@ -22,6 +22,7 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
MLIRIR
MLIRPass
MLIRTransforms
NPCOMPTorchDialect
NPCOMPInterfaces
TorchMLIRTorchDialect
)
torch_mlir_target_includes(TorchMLIRTorchPasses)

View File

@ -12,9 +12,9 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
@ -22,8 +22,8 @@
#include "llvm/ADT/StringSet.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) {
NnModuleOp rootNnModule;
@ -664,8 +664,6 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
module.push_back(newFunc);
}
for (auto &kv : newFuncs) {
BlockAndValueMapping mapping;
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
@ -706,6 +704,6 @@ class GlobalizeObjectGraphPass
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() {
mlir::torch::Torch::createGlobalizeObjectGraphPass() {
return std::make_unique<GlobalizeObjectGraphPass>();
}

View File

@ -11,16 +11,16 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class InlineGlobalSlotsPass
@ -87,6 +87,6 @@ class InlineGlobalSlotsPass
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createInlineGlobalSlotsPass() {
mlir::torch::Torch::createInlineGlobalSlotsPass() {
return std::make_unique<InlineGlobalSlotsPass>();
}

View File

@ -12,12 +12,12 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
@ -131,6 +131,6 @@ class MaximizeValueSemanticsPass
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass() {
mlir::torch::Torch::createMaximizeValueSemanticsPass() {
return std::make_unique<MaximizeValueSemanticsPass>();
}

View File

@ -6,20 +6,20 @@
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#define NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace NPCOMP {
namespace torch {
namespace Torch {
#define GEN_PASS_CLASSES
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // namespace Torch
} // namespace NPCOMP
} // namespace torch
} // end namespace mlir
#endif // NPCOMP_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
@ -16,22 +16,22 @@
namespace {
#define GEN_PASS_REGISTRATION
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // end namespace
void mlir::NPCOMP::registerTorchPasses() {
void mlir::torch::registerTorchPasses() {
::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-to-torch-backend-pipeline",
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline);
mlir::torch::Torch::createTorchScriptToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-globalized-module-to-torch-backend-pipeline",
"Pipeline lowering a globalized Torch program to Torch backend form.",
mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline);
mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline);
}
void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline(
void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// When we import TorchScript IR, we import their entire "compilation unit",
// which can contain numerous functions unrelated to the current program,
@ -62,7 +62,7 @@ void mlir::NPCOMP::Torch::createTorchScriptToTorchBackendPipeline(
createGlobalizedModuleToTorchBackendPipeline(pm, options);
}
void mlir::NPCOMP::Torch::createGlobalizedModuleToTorchBackendPipeline(
void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// General considerations: As a matter of bring-up, we are simultaneously
// building out the frontend pipeline and also co-developing the backend

View File

@ -14,13 +14,13 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class ConvertPrimCallMethodToCall : public OpRewritePattern<PrimCallMethodOp> {
@ -93,16 +93,15 @@ class PrepareForGlobalizeObjectGraphPass
// to the form we want.
ConversionTarget target(*context);
target.addIllegalOp<PrimCallMethodOp>();
target.addDynamicallyLegalOp<ConstantOp>([](ConstantOp op) {
return !op.getType().isa<FunctionType>();
});
target.addDynamicallyLegalOp<ConstantOp>(
[](ConstantOp op) { return !op.getType().isa<FunctionType>(); });
target.addIllegalOp<CallIndirectOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
RewritePatternSet dummyPatterns(context);
if (failed(applyFullConversion(getOperation(), target,
std::move(dummyPatterns)))) {
std::move(dummyPatterns)))) {
return signalPassFailure();
}
}
@ -110,6 +109,6 @@ class PrepareForGlobalizeObjectGraphPass
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass() {
mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass() {
return std::make_unique<PrepareForGlobalizeObjectGraphPass>();
}

View File

@ -9,13 +9,13 @@
#include "PassDetail.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
// Convert value semantic ops operating on mutable arrays to instead operate on
@ -145,6 +145,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::Torch::createReduceOpVariantsPass() {
mlir::torch::Torch::createReduceOpVariantsPass() {
return std::make_unique<ReduceOpVariantsPass>();
}

View File

@ -11,12 +11,12 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
@ -86,6 +86,6 @@ class RefinePublicReturnPass
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createRefinePublicReturnPass() {
mlir::torch::Torch::createRefinePublicReturnPass() {
return std::make_unique<RefinePublicReturnPass>();
}

View File

@ -15,13 +15,13 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// -----------------------------------------------------------------------------
// Analysis.
@ -264,9 +264,8 @@ public:
} else if (auto sum = dyn_cast<AtenSumOp>(op)) {
return visitReductionAlongAllDimsOp(sum, operands);
} else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
return visitReductionAlongDimIntListOp(
sumDimIntList, sumDimIntList.dim(), sumDimIntList.keepdim(),
operands);
return visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
sumDimIntList.keepdim(), operands);
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
meanDim.keepdim(), operands);
@ -1114,7 +1113,7 @@ static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
// the right thing forthose ops.
//
static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) {
return allowsTypeRefinement(op) ||
return op->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>() ||
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
}
@ -1244,6 +1243,6 @@ class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::Torch::createRefineTypesPass() {
mlir::torch::Torch::createRefineTypesPass() {
return std::make_unique<RefineTypesPass>();
}

View File

@ -0,0 +1,19 @@
//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/InitAll.h"
#include "mlir/IR/Dialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
// Registers every dialect owned by the torch-mlir project (currently only
// the `torch` dialect) into `registry`, so host tools can load them lazily.
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
  registry.insert<mlir::torch::Torch::TorchDialect>();
}
void mlir::torch::registerAllPasses() { mlir::torch::registerTorchPasses(); }

View File

@ -0,0 +1,30 @@
# Lit test-suite setup for torch-mlir, following the standard LLVM pattern:
# canonicalize booleans for the lit site config, generate lit.site.cfg.py,
# and register the `check-torch-mlir` target.

# Convert CMake truthiness (ON/TRUE/1/...) to Python True/False before the
# value is substituted into lit.site.cfg.py.in.
llvm_canonicalize_cmake_booleans(
MLIR_ENABLE_BINDINGS_PYTHON
)
# Generate lit.site.cfg.py in the build tree; it bootstraps the main
# lit.cfg.py from the source tree with build-specific paths filled in.
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
MAIN_CONFIG
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
)
# Tools that must be built before the test suite can run.
set(TORCH_MLIR_TEST_DEPENDS
FileCheck count not
torch-mlir-opt
)
# XXX: Enable
# NOTE(review): intentionally left disabled by the earthmoving commit; the
# Python build for torch-mlir is to be resolved in a subsequent change.
#if(MLIR_ENABLE_BINDINGS_PYTHON)
# list(APPEND TORCH_MLIR_TEST_DEPENDS
# TorchMLIRPythonModules
# )
#endif()
# `check-torch-mlir` runs all regression tests under this directory.
add_lit_testsuite(check-torch-mlir "Running the torch-mlir regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${TORCH_MLIR_TEST_DEPENDS}
)
set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests")
# Also register per-subdirectory check-* targets for convenience.
add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// Basic case.

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
torch.class_type @c1 {}
torch.class_type @c2 {}

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c {
torch.attr "float" : !torch.float

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// CHECK that multiple nested initialization ops are properly handled.

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c {
torch.attr "float" : !torch.float

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file -verify-diagnostics %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file -verify-diagnostics %s
torch.class_type @parent {
torch.method "module_type_return", @module_type_return

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @child {
torch.attr "float" : !torch.float

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
return

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// Tests monomorphization of same function with different instance argument types.

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @__torch__.TestModule {
torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule">

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// Check that linkage names consist of the dotted path from the root.

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c {
// CHECK: torch.global_slot "private" @float : !torch.float

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt %s -canonicalize | FileCheck %s
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
// CHECK-LABEL: func @torch.aten.__is__
// CHECK: %[[FALSE:.*]] = torch.constant.bool false

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
// CHECK-NOT: @readonly
torch.global_slot "private" @readonly : !torch.tensor {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics
// -----

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
// RUN: torch-mlir-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
// CHECK-LABEL: func @torch.copy.tensor$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
// RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s
// CHECK-LABEL: func @torch.operator(
func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s
torch.class_type @c {
torch.method "test_call_method", @test_call_method

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-reduce-op-variants %s | FileCheck %s
// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s
// CHECK-LABEL: func @convert_to_value_semantic_tensors(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
// -----

View File

@ -1,4 +1,4 @@
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {

View File

@ -0,0 +1,71 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Configuration file for the 'lit' test runner. `config` and `llvm_config`
# are injected by lit via lit.site.cfg.py before this file executes.

import os
import platform
import re
import subprocess
import tempfile

import lit.formats
import lit.util

from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
from lit.llvm.subst import FindTool

# name: The name of this test suite.
config.name = 'TORCH_MLIR'

config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)

# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir', '.py']

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)

# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.torch_mlir_obj_root, 'test')

config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))

llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])

#llvm_config.use_default_substitutions()

# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = [
    'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
    'lit.cfg.py', 'lit.site.cfg.py'
]

# NOTE(review): the original file assigned test_source_root and
# test_exec_root a second time here with identical values; the redundant
# duplicate assignments have been removed.

config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')

# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)

tool_dirs = [config.llvm_tools_dir]
tools = [
    ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
]
llvm_config.add_tool_substitutions(tools, tool_dirs)

# Make the (not-yet-enabled) torch_mlir Python packages importable from tests.
if config.enable_bindings_python:
    llvm_config.with_environment('PYTHONPATH', [
        os.path.join(config.torch_mlir_obj_root, 'python_packages',
                     'torch_mlir'),
    ],
                                 append_path=True)

View File

@ -0,0 +1,21 @@
@LIT_SITE_CFG_IN_HEADER@

# Site-specific lit configuration template. The @...@ placeholders are
# substituted by CMake's configure_lit_site_cfg at configure time.

import sys

config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@
config.torch_mlir_obj_root = "@TORCH_MLIR_BINARY_DIR@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
# Use the same interpreter that is running lit itself.
config.python_executable = sys.executable

import lit.llvm
lit.llvm.initialize(lit_config, config)

# Let the main config do the real work.
lit_config.load_config(config, "@TORCH_MLIR_SOURCE_DIR@/test/lit.cfg.py")

View File

@ -0,0 +1,2 @@
# Skip every test in this directory unless the MLIR Python bindings were built.
if not config.enable_bindings_python:
  config.unsupported = True

View File

@ -0,0 +1,9 @@
# RUN: %PYTHON %s
# XXX: Fix this
# XFAIL: *

# Smoke test for dialect registration via the Python bindings.
# NOTE(review): this still imports and registers the `iree` dialect —
# presumably a leftover from the iree-dialects project this layout was
# modeled on; it is XFAIL'd above until it is ported to torch-mlir.

import mlir.ir
from mlir.dialects import iree

with mlir.ir.Context() as ctx:
  iree.register_iree_dialect(ctx)

View File

@ -0,0 +1 @@
add_subdirectory(torch-mlir-opt)

View File

@ -0,0 +1,13 @@
# torch-mlir-opt: an mlir-opt-style driver with the torch-mlir dialects and
# passes registered in addition to all upstream MLIR dialects/conversions.
add_executable(torch-mlir-opt torch-mlir-opt.cpp)

# Library lists populated globally by MLIR's add_mlir_dialect_library /
# add_mlir_conversion_library machinery.
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

target_link_libraries(torch-mlir-opt PRIVATE
MLIROptLib
TorchMLIRInitAll
TorchMLIRTorchDialect
TorchMLIRTorchPasses
${dialect_libs}
${conversion_libs}
)

View File

@ -0,0 +1,27 @@
//===- torch-mlir-opt.cpp - MLIR Optimizer Driver -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
#include "torch-mlir/InitAll.h"
using namespace mlir;
// Entry point: a standard MlirOptMain driver. Passes and dialects are only
// *registered* here (made available on the command line / to the registry);
// dialects are loaded on demand because preloadDialectsInContext is false.
int main(int argc, char **argv) {
  // Upstream MLIR passes plus the torch-mlir passes.
  registerAllPasses();
  mlir::torch::registerAllPasses();
  // Upstream MLIR dialects plus the torch-mlir dialects.
  DialectRegistry registry;
  registerAllDialects(registry);
  mlir::torch::registerAllDialects(registry);
  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
                        /*preloadDialectsInContext=*/false));
}

View File

@ -11,7 +11,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "npcomp-c/TorchTypes.h"
#include "torch-mlir-c/TorchTypes.h"
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
@ -510,14 +510,14 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
}
if (ival.isList()) {
return npcompTorchListTypeGet(
return torchMlirTorchListTypeGet(
typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
}
if (ival.isNone()) {
return npcompTorchNoneTypeGet(funcBuilder->getContext());
return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
}
if (ival.isDevice()) {
return npcompTorchNoneTypeGet(funcBuilder->getContext());
return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
}
return {nullptr};
}
@ -527,7 +527,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp = createMlirOperationAtEnd(
funcBuilder->getEntryBlock(), "torch.tensor.literal", loc,
npcompTorchNonValueTensorTypeGetFromShaped(
torchMlirTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements));
MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);

View File

@ -12,7 +12,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h"
#include "torch-mlir-c/TorchTypes.h"
using namespace torch_mlir;
@ -98,7 +98,7 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
// represented as one of double or int64_t, with a special tag for whether
// it should be interpreted as a bool.
if (s.isIntegral(/*includeBool=*/false)) {
MlirType t = npcompTorchIntTypeGet(context);
MlirType t = torchMlirTorchIntTypeGet(context);
MlirAttribute value =
mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), s.to<int64_t>());
MlirOperation op = createMlirOperation(
@ -107,7 +107,7 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
return mlirOperationGetResult(op, 0);
}
if (s.isFloatingPoint()) {
MlirType t = npcompTorchFloatTypeGet(context);
MlirType t = torchMlirTorchFloatTypeGet(context);
MlirAttribute value = mlirFloatAttrDoubleGet(
context, mlirF64TypeGet(context), s.to<double>());
MlirOperation op = createMlirOperation(
@ -133,7 +133,7 @@ MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) {
MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
std::vector<MlirValue> &elements) {
MlirType resultType = npcompTorchListTypeGet(elementType);
MlirType resultType = torchMlirTorchListTypeGet(elementType);
OperationStateHolder state{"torch.prim.ListConstruct", loc};
mlirOperationStateAddResults(state, 1, &resultType);
mlirOperationStateAddOperands(state, elements.size(), elements.data());

View File

@ -16,7 +16,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h"
#include "torch-mlir-c/TorchTypes.h"
#include "caffe2/core/scope_guard.h"
#include "ATen/native/quantized/cpu/packed_params.h"
@ -170,7 +170,7 @@ IValueImporter::importModule(torch::jit::Module currentModule) {
MlirOperation nnModule = createMlirOperation(
"torch.nn_module", loc,
npcompTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
mlirRegionCreate());
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
@ -240,7 +240,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirLocation loc = mlirLocationUnknownGet(context);
if (ivalue.isBool()) {
MlirType type = npcompTorchBoolTypeGet(context);
MlirType type = torchMlirTorchBoolTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.bool", loc, type,
toMlirNamedAttribute("value",
@ -248,7 +248,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isDouble()) {
MlirType type = npcompTorchFloatTypeGet(context);
MlirType type = torchMlirTorchFloatTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.float", loc, type,
toMlirNamedAttribute(
@ -257,7 +257,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isInt()) {
MlirType type = npcompTorchIntTypeGet(context);
MlirType type = torchMlirTorchIntTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.int", loc, type,
toMlirNamedAttribute("value",
@ -273,7 +273,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.ListConstruct", loc,
npcompTorchListTypeGet(
torchMlirTorchListTypeGet(
typeMapper.mapFromTorchType(loc, list.elementType())),
elems);
return mlirOperationGetResult(operation, 0);
@ -288,7 +288,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.DictConstruct", loc,
npcompTorchDictTypeGet(
torchMlirTorchDictTypeGet(
typeMapper.mapFromTorchType(loc, dict.keyType()),
typeMapper.mapFromTorchType(loc, dict.valueType())),
keys, values);
@ -305,7 +305,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.TupleConstruct", loc,
npcompTorchTupleTypeGet(context, types.size(), types.data()), operands);
torchMlirTorchTupleTypeGet(context, types.size(), types.data()),
operands);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTensor()) {
@ -317,7 +318,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
if (ivalue.isString()) {
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.constant.str", loc,
npcompTorchStringTypeGet(context),
torchMlirTorchStringTypeGet(context),
toMlirNamedAttribute(
"value",
mlirStringAttrGet(context,
@ -327,7 +328,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
if (ivalue.isNone()) {
MlirOperation operation =
createMlirOperationAtEnd(importBlock, "torch.constant.none", loc,
npcompTorchNoneTypeGet(context));
torchMlirTorchNoneTypeGet(context));
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isCustomClass()) {
@ -346,7 +347,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.linear_params.create", loc,
npcompTorchLinearParamsTypeGet(context), weightValue, biasValue);
torchMlirTorchLinearParamsTypeGet(context), weightValue, biasValue);
return mlirOperationGetResult(operation, 0);
}
}
@ -366,7 +367,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp =
createMlirOperationAtEnd(importBlock, "torch.tensor.literal", loc,
npcompTorchNonValueTensorTypeGetFromShaped(
torchMlirTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements));
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
@ -381,7 +382,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
// compiler stages that are building a statically modeled quantization
// representation will need to convert this to their representation.
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType quantizedTensorType = npcompTorchNonValueTensorTypeGet(
MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shape.data(),
typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
if (tensor.qscheme() == c10::kPerTensorAffine) {
@ -531,11 +532,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
int64_t dummy;
int64_t *shapeData = shape.size() == 0 ? &dummy : shape.data();
if (hasValueSemantics) {
typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
shapeData, dtype);
} else {
typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
typeBound = torchMlirTorchValueTensorTypeGet(context, shape.size(),
shapeData, dtype);
} else {
typeBound = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shapeData, dtype);
}
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(

View File

@ -16,6 +16,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/Registration.h"
#include "npcomp-c/Registration.h"
#include "torch-mlir-c/Registration.h"
namespace py = pybind11;
using namespace torch_mlir;
@ -114,7 +115,7 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
// TODO: Rework this once dialect registration C-APIs are in place.
// https://reviews.llvm.org/D88162
mlirRegisterAllDialects(context);
npcompRegisterAllDialects(context);
torchMlirRegisterAllDialects(context);
registerPythonSysStderrDiagnosticHandler(context);

View File

@ -15,7 +15,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h"
#include "torch-mlir-c/TorchTypes.h"
namespace py = pybind11;
using namespace torch_mlir;
@ -150,7 +150,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::StringType>()) {
op = createMlirOperation(
"torch.constant.str", loc, npcompTorchStringTypeGet(context),
"torch.constant.str", loc, torchMlirTorchStringTypeGet(context),
toMlirNamedAttribute(
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
c10::attr::value)))));
@ -186,7 +186,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
mlirRegionCreate());
mapResults(node, operation);
std::vector<MlirType> terminatorOperandTypes = {
npcompTorchBoolTypeGet(context)};
torchMlirTorchBoolTypeGet(context)};
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
resultTypes.begin(), resultTypes.end());
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,

View File

@ -10,7 +10,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h"
#include "torch-mlir-c/TorchTypes.h"
using namespace torch_mlir;
@ -18,11 +18,11 @@ OpBuilder::OpBuilder(MlirContext context) : context(context) {}
MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
return createMlirOperation("torch.constant.none", loc,
npcompTorchNoneTypeGet(context));
torchMlirTorchNoneTypeGet(context));
}
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
return createMlirOperation(
"torch.constant.bool", loc, npcompTorchBoolTypeGet(context),
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
}

View File

@ -15,7 +15,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/TorchTypes.h"
#include "torch-mlir-c/TorchTypes.h"
using namespace torch_mlir;
@ -66,7 +66,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
case ScalarType::Half:
return mlirF16TypeGet(context);
case ScalarType::QInt8:
return npcompTorchQInt8TypeGet(context);
return torchMlirTorchQInt8TypeGet(context);
default: {
return {nullptr};
}
@ -105,7 +105,7 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
// Individually handle the custom classes that we know about.
if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
return npcompTorchLinearParamsTypeGet(context);
return torchMlirTorchLinearParamsTypeGet(context);
}
// At this point, we know that the type is indeed a custom class type, but
@ -136,11 +136,11 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
auto &sizes = tensorType->symbolic_sizes();
if (!sizes.rank()) {
// Unranked.
return npcompTorchNonValueTensorTypeGet(context,
/*numSizes=*/0,
/*optionalSizes=*/nullptr,
/*optionalDtype=*/
elementType);
return torchMlirTorchNonValueTensorTypeGet(context,
/*numSizes=*/0,
/*optionalSizes=*/nullptr,
/*optionalDtype=*/
elementType);
}
// Ranked with possibly dynamic dims.
auto &symbolicShape = tensorType->symbolic_sizes();
@ -150,28 +150,28 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
auto shapeSymbol = symbolicShape[i];
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
}
return npcompTorchNonValueTensorTypeGet(context, dims.size(),
/*optionalSizes=*/dims.data(),
/*optionalDtype=*/
elementType);
return torchMlirTorchNonValueTensorTypeGet(context, dims.size(),
/*optionalSizes=*/dims.data(),
/*optionalDtype=*/
elementType);
}
case TypeKind::IntType: {
return npcompTorchIntTypeGet(context);
return torchMlirTorchIntTypeGet(context);
}
case TypeKind::FloatType: {
return npcompTorchFloatTypeGet(context);
return torchMlirTorchFloatTypeGet(context);
}
case TypeKind::BoolType: {
return npcompTorchBoolTypeGet(context);
return torchMlirTorchBoolTypeGet(context);
}
case TypeKind::NumberType: {
return npcompTorchNumberTypeGet(context);
return torchMlirTorchNumberTypeGet(context);
}
case TypeKind::StringType: {
return npcompTorchStringTypeGet(context);
return torchMlirTorchStringTypeGet(context);
}
case TypeKind::OptionalType: {
return npcompTorchOptionalTypeGet(mapFromTorchType(
return torchMlirTorchOptionalTypeGet(mapFromTorchType(
loc, torchType->cast<c10::OptionalType>()->getElementType()));
}
case TypeKind::TupleType: {
@ -180,25 +180,25 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
torchType->cast<c10::TupleType>()->containedTypes()) {
containedTypes.push_back(mapFromTorchType(loc, type));
}
return npcompTorchTupleTypeGet(context, containedTypes.size(),
containedTypes.data());
return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
containedTypes.data());
}
case TypeKind::ListType: {
return npcompTorchListTypeGet(mapFromTorchType(
return torchMlirTorchListTypeGet(mapFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType()));
}
case TypeKind::DictType: {
auto dictType = torchType->cast<c10::DictType>();
return npcompTorchDictTypeGet(
return torchMlirTorchDictTypeGet(
mapFromTorchType(loc, dictType->getKeyType()),
mapFromTorchType(loc, dictType->getValueType()));
}
case TypeKind::NoneType: {
return npcompTorchNoneTypeGet(context);
return torchMlirTorchNoneTypeGet(context);
}
case TypeKind::AnyType: {
auto anyType = torchType->cast<c10::AnyType>();
return npcompTorchAnyTypeGet(context);
return torchMlirTorchAnyTypeGet(context);
}
case TypeKind::ClassType: {
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
@ -208,10 +208,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
}
auto maybeName = classType->name();
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
return torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(name));
}
case TypeKind::DeviceObjType: {
return npcompTorchDeviceTypeGet(context);
return torchMlirTorchDeviceTypeGet(context);
}
default: {
std::stringstream message;
@ -226,7 +226,7 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
if (!tensor.defined()) {
// Undefined tensors are equivalent to None.
// This may need to be re-evaluated at some point.
return npcompTorchNoneTypeGet(context);
return torchMlirTorchNoneTypeGet(context);
}
MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
@ -234,8 +234,8 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
// just erase them and let the compiler decide.
auto sizes = tensor.sizes();
return npcompTorchNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
elementType);
return torchMlirTorchNonValueTensorTypeGet(context, sizes.size(),
sizes.data(), elementType);
}
MlirType

View File

@ -2,5 +2,4 @@ add_subdirectory(Basicpy)
add_subdirectory(Numpy)
add_subdirectory(Refback)
add_subdirectory(Refbackrt)
add_subdirectory(Torch)
add_subdirectory(TorchConversion)

View File

@ -17,7 +17,6 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Interfaces/Traits.h"
#include "npcomp/Typing/Analysis/CPA/Interfaces.h"
#define GET_OP_CLASSES

View File

@ -10,7 +10,6 @@
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
include "npcomp/Interfaces/Traits.td"
include "npcomp/Typing/Analysis/CPA/Interfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
@ -39,7 +38,6 @@ def Numpy_NarrowOp : Numpy_Op<"narrow", []> {
def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
AllowsTypeRefinement,
NoSideEffect]> {
let summary = "Adds/removes static information from an array type.";
let description = [{
@ -60,7 +58,6 @@ def Numpy_StaticInfoCastOp : Numpy_Op<"static_info_cast", [
def Numpy_TensorStaticInfoCastOp : Numpy_Op<"tensor_static_info_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
AllowsTypeRefinement,
NoSideEffect]> {
let summary = "Adds/removes static information from a tensor type.";
let description = [{

View File

@ -1,5 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(NPCOMPTorchPassIncGen)
add_mlir_doc(Passes NPCOMPTorchTransforms ./ -gen-pass-doc)

View File

@ -16,8 +16,8 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#define GET_OP_CLASSES
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h.inc"

View File

@ -14,7 +14,7 @@ include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td"
include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
include "iree-dialects/Dialect/IREE/IREEDialect.td"
class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>

View File

@ -10,7 +10,7 @@
#define NPCOMP_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include <memory>
@ -21,7 +21,8 @@ namespace TorchConversion {
/// Creates a pipeline that lowers the object graph IR that is produced by
/// TorchScript import into the form expected by npcomp-verify-backend-contract.
void createTorchScriptToNpcompBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options);
OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyInvariantsBeforeBackendLoweringPass();

View File

@ -13,21 +13,7 @@
namespace mlir {
namespace NPCOMP {
namespace OpTrait {
template <typename ConcreteType>
class AllowsTypeRefinement
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
} // namespace OpTrait
// Check if an operation has the AllowsTypeRefinement trait.
//
// This function should be used in preference to
// `op->hasTrait<AllowsTypeRefinement>()` because this function has knowledge of
// some upstream ops that have this property, but which we cannot annotate with
// this trait.
bool allowsTypeRefinement(Operation *op);
namespace OpTrait {} // namespace OpTrait
} // namespace NPCOMP
} // namespace mlir

View File

@ -19,13 +19,6 @@ class NpcompOpTrait<string name> : OpTrait, NativeTrait<"", ""> {
let cppNamespace = "::mlir::NPCOMP::OpTrait";
}
// Op allows operand and result types to be refined.
// For example a `tensor<?xf32>` can be refined to `tensor<4xf32>`.
//
// TODO: Implement RefinableTypeInterface that allows actually modeling
// which types are refinements of other types.
// See the design in:
// https://llvm.discourse.group/t/allow-shape-concretization-or-type-concretization-in-rewrites/3327/3
def AllowsTypeRefinement : NpcompOpTrait<"AllowsTypeRefinement">;
// Empty for now. Kept as boilerplate placeholder.
#endif // NPCOMP_INTERFACES_TRAITS

View File

@ -10,7 +10,6 @@ add_npcomp_library(NPCOMPCAPI
Registration.cpp
BasicpyTypes.cpp
NumpyTypes.cpp
TorchTypes.cpp
LINK_LIBS PUBLIC
MLIRExecutionEngine
@ -21,7 +20,8 @@ add_npcomp_library(NPCOMPCAPI
NPCOMPNumpyDialect
NPCOMPRefBackendJITHelpers
NPCOMPRuntime
NPCOMPTorchDialect
TorchMLIRTorchDialect
TorchMLIRInitAll
# MLIR CAPI deps
MLIRCAPIIR

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/InitAll.h"
#include "torch-mlir/InitAll.h"
void npcompRegisterAllDialects(MlirContext context) {
mlir::DialectRegistry registry;

View File

@ -31,7 +31,7 @@ add_npcomp_library(NPCOMPInitAll
NPCOMPIREEBackend
NPCOMPRefBackend
NPCOMPRefbackDialect
NPCOMPTorchDialect
TorchMLIRTorchDialect
NPCOMPTorchConversionDialect
NPCOMPRefbackrtDialect
NPCOMPBasicpyDialect

View File

@ -13,7 +13,7 @@ add_npcomp_conversion_library(NPCOMPTorchToIREE
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
NPCOMPTorchDialect
TorchMLIRTorchDialect
MLIRStandard
IREEDialectsIREEDialect
)

View File

@ -14,13 +14,13 @@
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
using namespace mlir::torch::Torch;
//===----------------------------------------------------------------------===//
// The patterns

View File

@ -15,5 +15,5 @@ add_npcomp_conversion_library(NPCOMPTorchToLinalg
MLIRPass
MLIRLinalg
MLIRMath
NPCOMPTorchDialect
TorchMLIRTorchDialect
)

Some files were not shown because too many files have changed in this diff Show More