Eliminate almost all mentions of IREE.

A few remain in examples/docs that will naturally be updated in due
time.

This regresses list support and the general direction of more widely
supported control flow, lists/dicts/globals that we were going for with
the TorchScript path. The idea is that we are deferring that work to
make torch-mlir a very clean standalone thing. We will reboot it,
probably using some of the tools of iree_pydm to make it simpler, and in
a more natural place (such as an iree-torch repo that depends on IREE and
torch-mlir to build a working PyTorch frontend solution for IREE -- it
was really weird that npcomp depended on IREE).
pull/321/head
Sean Silva 2021-09-22 22:48:39 +00:00
parent 8779d920b2
commit 1a0b953ea7
75 changed files with 11 additions and 2531 deletions

View File

@ -22,7 +22,6 @@ endif()
# Options and settings
#-------------------------------------------------------------------------------
set(NPCOMP_MINIMUM_PYTHON_VERSION 3.6)
set(NPCOMP_IREE_BUILDDIR "../iree-build" CACHE STRING "If building IREE, then setting this elects to build from a source directory (versus installed package)")
# Turn on -gsplit-dwarf if requested in debug builds.
if (NPCOMP_USE_SPLIT_DWARF AND
@ -135,8 +134,7 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
"extension = '${PYTHON_MODULE_EXTENSION}")
# Include LLVM_EXTERNAL_PROJECTS.
set(LLVM_EXTERNAL_PROJECTS "iree-dialects;torch-mlir")
set(LLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects")
set(LLVM_EXTERNAL_PROJECTS "torch-mlir")
set(LLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir")
# LLVM configuration.
@ -188,8 +186,6 @@ include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/iree-dialects/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/iree-dialects/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/torch-mlir/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/llvm/tools/torch-mlir/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR})

View File

@ -1,237 +0,0 @@
#!/usr/bin/env python3
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Find common version hashes for dependent projects.
Sample usage:
./build_tools/find_version_hashes.py --iree_dir=${IREE_DIR}
This script will fetch dependent projects and seek back over the last
--revision-depth commits against their respective version files in order to
find common revisions of each that share a same common LLVM hash, reporting
all such hashes.
Note that this procedure is not guaranteed to work or produce a recent
version. It has a reasonable probability of working since the non-LLVM
dependencies are published by Google at regular intervals and at common LLVM
commits.
In general, unless the versions found by this script are too old, integrating
at its recommendation will increase the probability that dependencies are
actually mutually compatible with each other and make for an easier time
upgrading. It is experimental and subject to change.
"""
import argparse
import collections
import os
import subprocess
import sys
TOP_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
def create_argument_parser():
parser = argparse.ArgumentParser(
prog="find_version_hashes.py",
description="Finds common version hashes for sub-projects")
parser.add_argument("--llvm-dir",
help="Directory of the llvm-project repo",
default="external/llvm-project")
parser.add_argument("--mlir-hlo-dir",
help="Directory of the MLIR HLO project checkout",
default=os.path.join(TOP_DIR, "external", "mlir-hlo"))
parser.add_argument("--iree-dir",
help="Directory of the IREE project checkout (optional)",
default=None)
parser.add_argument(
"--revision-depth",
type=int,
help="The number of revisions to search back for a common join",
default=50)
parser.add_argument("--no-fetch",
help="Do not fetch child repositories",
action="store_true")
return parser
ChildRevisionMap = collections.namedtuple(
"ChildRevisionMap", "parent_revision,parent_date,child_revision,child_date")
ParentChildJoin = collections.namedtuple(
"ParentChildJoin", "parent_revision,parent_date,child_revision_maps")
def main(args):
if not args.no_fetch:
fetch(args.llvm_dir)
llvm_revision_maps = {}
if args.mlir_hlo_dir:
llvm_revision_maps["mhlo"] = get_mhlo_llvm_history(args)
if args.iree_dir:
llvm_revision_maps["iree"] = get_iree_llvm_history(args)
# Join the LLVM revision.
join_results = join_child_revision_maps(llvm_revision_maps)
if not join_results:
print("No common LLVM version found (TODO print a better report)s.")
print(llvm_revision_maps)
sys.exit(1)
# Report.
print("COMMON LLVM REVISION: {} (at {})".format(join_results.parent_revision,
join_results.parent_date))
for child_key, child_revision_map in join_results.child_revision_maps.items():
print(" - {}: {} (at {})".format(child_key,
child_revision_map.child_revision,
child_revision_map.child_date))
def join_child_revision_maps(revision_maps):
"""Joins dicts of child_key -> [ChildRevisionMap].
Returns:
Return ParentChildJoin or None if no revisions found.
"""
parent_revision_dates = dict() # Dates of each parent revision.
parent_revisions = dict() # Of parent_revision -> count of agreements.
for child_key, child_maps in revision_maps.items():
for child_map in child_maps:
parent_revision_dates[child_map.parent_revision] = child_map.parent_date
count = parent_revisions.get(child_map.parent_revision)
parent_revisions[child_map.parent_revision] = (
(0 if count is None else count) + 1)
def build_child_map(parent_revision):
child_map = dict()
for child_key, child_revision_map in revision_maps.items():
for single_child_revision_map in child_revision_map:
if single_child_revision_map.parent_revision == parent_revision:
child_map[child_key] = single_child_revision_map
break
return child_map
# Find the most recent parent commit where all children agree.
expected_children = len(revision_maps)
for parent_revision, count in parent_revisions.items():
if count == expected_children:
# Found it!
return ParentChildJoin(parent_revision,
parent_revision_dates[parent_revision],
build_child_map(parent_revision))
return None
def get_mhlo_llvm_history(args):
"""Mlir-hlo stores its llvm commit hash in a text file which is parsed.
Returns:
List of ChildRevisionMap.
"""
if not args.no_fetch:
fetch(args.mlir_hlo_dir)
mlir_hlo_revisions = get_file_revisions(args.mlir_hlo_dir,
"build_tools/llvm_version.txt",
revision_depth=args.revision_depth)
# Re-arrange into (llvm_revision, llvm_date, child_revision, child_date)
llvm_history = []
for child_revision, child_date, contents in mlir_hlo_revisions:
llvm_revision = contents.decode("UTF-8").strip()
llvm_date = get_commit_date(args.llvm_dir, llvm_revision)
llvm_history.append(
ChildRevisionMap(llvm_revision, llvm_date, child_revision, child_date))
return llvm_history
def get_iree_llvm_history(args):
"""Gets the IREE LLVM history by parsing the SUBMODULE_VERSIONS file.
Returns:
List of ChildRevisionMap.
"""
if not args.no_fetch:
fetch(args.iree_dir)
iree_revisions = get_file_revisions(args.iree_dir,
"SUBMODULE_VERSIONS",
revision_depth=args.revision_depth)
def get_llvm_revision(submodule_versions):
# Each line is "hash path/to/module"
for line in submodule_versions.decode("UTF-8").strip().splitlines():
revision, path = line.split(" ", maxsplit=1)
if path == "third_party/llvm-project":
return revision
return None
llvm_history = []
for child_revision, child_date, contents in iree_revisions:
llvm_revision = get_llvm_revision(contents)
if llvm_revision is None:
print(
"Could not find llvm-project revision in IREE SUBMODULE_VERSIONS:\n" +
contents.decode("UTF-8"),
file=sys.stderr)
llvm_date = get_commit_date(args.llvm_dir, llvm_revision)
llvm_history.append(
ChildRevisionMap(llvm_revision, llvm_date, child_revision, child_date))
return llvm_history
def get_commit_date(repo_path, revision):
"""Gets the date of a commit."""
return subprocess_call(
["git", "log", "-n", "1", "--pretty=format:%ci", revision],
cwd=repo_path,
capture_output=True).decode("UTF-8").strip()
def get_file_revisions(repo_path, file_path, revision_depth):
"""Gets the file contents at the last `revision-depth` commits.
Returns:
A tuple of (revision, date, contents).
"""
revisions = subprocess_call([
"git", "log", "--pretty=format:%H %ci", "-n",
str(revision_depth), "origin/HEAD", "--", file_path
],
cwd=repo_path,
capture_output=True).decode("UTF-8").splitlines()
# Split on space.
revisions = [r.split(" ", maxsplit=1) for r in revisions]
# Generate the revision tuple (revision, date, contents).
def show_contents(revision):
return subprocess_call(["git", "show", "{}:{}".format(revision, file_path)],
cwd=repo_path,
capture_output=True)
revision_contents = [
(revision, date, show_contents(revision)) for revision, date in revisions
]
return revision_contents
def fetch(repo_path):
print("Fetching", repo_path, "...", file=sys.stderr)
subprocess_call(["git", "fetch", "--recurse-submodules=no"], cwd=repo_path)
def subprocess_call(args, cwd, capture_output=False, **kwargs):
"""Calls a subprocess, possibly capturing output."""
try:
if capture_output:
return subprocess.check_output(args, cwd=cwd, **kwargs)
else:
return subprocess.check_call(args, cwd=cwd, **kwargs)
except subprocess.CalledProcessError:
print("ERROR executing subprocess (from {}):\n {}".format(
cwd, " ".join(args)),
file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main(create_argument_parser().parse_args(sys.argv[1:]))
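To make the join step above concrete, here is a minimal self-contained sketch (toy revision data, not tied to any real repositories) of the idea behind join_child_revision_maps: each child project records which LLVM (parent) revision its recent commits pinned, and a usable integration point is any parent revision that every child agrees on.

import collections

ChildRevisionMap = collections.namedtuple(
    "ChildRevisionMap", "parent_revision,parent_date,child_revision,child_date")

# Toy histories: both children pin LLVM revision "bbb" at some point.
revision_maps = {
    "mhlo": [
        ChildRevisionMap("ccc", "2021-09-20", "mhlo-2", "2021-09-21"),
        ChildRevisionMap("bbb", "2021-09-15", "mhlo-1", "2021-09-16"),
    ],
    "iree": [
        ChildRevisionMap("bbb", "2021-09-15", "iree-7", "2021-09-17"),
        ChildRevisionMap("aaa", "2021-09-10", "iree-6", "2021-09-11"),
    ],
}

# Count how many children pin each parent revision; a revision pinned by
# every child is a common integration point.
counts = collections.Counter(
    m.parent_revision for maps in revision_maps.values() for m in maps)
common = [rev for rev, n in counts.items() if n == len(revision_maps)]
print(common)  # ['bbb']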

View File

@ -1,16 +0,0 @@
#!/bin/bash
set -euo pipefail
if [ "$#" -ne 1 ]; then
echo "Usage: $0 <iree_src_root>"
echo 'Description:
iree_src_root: root directory of IREE source checkout
'
exit 1
fi
npcomp_src_root="$(realpath $(dirname $0)/..)"
iree_src_root=$1
rm -rf "${npcomp_src_root}/external/iree-dialects"
cp -a "${iree_src_root}/llvm-external-projects/iree-dialects" "${npcomp_src_root}/external"

View File

@ -1,25 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
from npcomp_torchscript.e2e_test.framework import TestUtils
from npcomp_torchscript.e2e_test.registry import register_test_case
from npcomp_torchscript.annotations import annotate_args, export
# ==============================================================================
class ListLiteralModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
def forward(self, x: int):
return [x, x]
@register_test_case(module_factory=lambda: ListLiteralModule())
def ListLiteralModule_basic(module, tu: TestUtils):
module.forward(3)
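For comparison, tensor-valued test cases in this harness follow the same @export / @register_test_case pattern; a minimal hedged sketch (the module, test name, and input values are illustrative only and not part of this change):

import torch
from npcomp_torchscript.e2e_test.framework import TestUtils
from npcomp_torchscript.e2e_test.registry import register_test_case
from npcomp_torchscript.annotations import export

class AddSelfModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    def forward(self, x):
        # Elementwise add of the tensor with itself.
        return x + x

@register_test_case(module_factory=lambda: AddSelfModule())
def AddSelfModule_basic(module, tu: TestUtils):
    module.forward(torch.rand(3, 4))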

View File

@ -17,10 +17,6 @@ from npcomp_torchscript_e2e_test_configs import (
NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
)
from npcomp.compiler.pytorch.backend import is_iree_enabled
IREE_ENABLED = is_iree_enabled()
if IREE_ENABLED:
from npcomp.compiler.pytorch.backend.iree import IreeNpcompBackend
from npcomp.compiler.pytorch.backend.refbackend import RefBackendNpcompBackend
from .xfail_sets import XFAIL_SETS
@ -35,13 +31,12 @@ from . import conv
from . import batchnorm
from . import quantized_models
from . import elementwise
from . import list_programs
from . import reduction
def _get_argparse():
# TODO: Allow pulling in an out-of-tree backend, so downstream can easily
# plug into the e2e tests.
config_choices = ['native_torch', 'torchscript', 'refbackend']
if IREE_ENABLED:
config_choices += ['iree']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
@ -49,7 +44,6 @@ def _get_argparse():
help=f'''
Meaning of options:
"refbackend": run through npcomp's RefBackend.
"iree"{'' if IREE_ENABLED else '(disabled)'}: run through npcomp's IREE backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
''')
@ -76,8 +70,6 @@ def main():
# Find the selected config.
if args.config == 'refbackend':
config = NpcompBackendTestConfig(RefBackendNpcompBackend())
elif args.config == 'iree':
config = NpcompBackendTestConfig(IreeNpcompBackend())
elif args.config == 'native_torch':
config = NativeTorchTestConfig()
elif args.config == 'torchscript':

View File

@ -18,16 +18,7 @@ _common_npcomp_lowering_xfails = {
'QuantizedMLP_basic',
}
# Any test expected to fail on backends that don't support non-tensor types
# should be listed here.
_common_non_tensor_type_xfails = {
'ListLiteralModule_basic',
}
XFAIL_SETS['refbackend'] = (_common_npcomp_lowering_xfails
| _common_non_tensor_type_xfails)
XFAIL_SETS['iree'] = _common_npcomp_lowering_xfails
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
XFAIL_SETS['torchscript'] = {}
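As a rough illustration (a sketch, not the actual harness code), an xfail set like the ones above is typically consumed by comparing each test's raw outcome against whether its name is listed for the selected config:

def classify_result(test_name, passed, xfail_set):
    """Map a raw pass/fail outcome to a reported status."""
    expected_failure = test_name in xfail_set
    if passed:
        return "UNEXPECTED_PASS" if expected_failure else "PASS"
    return "XFAIL" if expected_failure else "FAIL"

# e.g. classify_result('QuantizedMLP_basic', False, XFAIL_SETS['refbackend'])
# returns "XFAIL".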

View File

@ -13,7 +13,7 @@ import torch_mlir
import npcomp
from npcomp.passmanager import PassManager
from npcomp.compiler.pytorch.backend import refbackend, iree
from npcomp.compiler.pytorch.backend import refbackend
from npcomp.compiler.utils import logging
mb = torch_mlir.ModuleBuilder()

View File

@ -1 +0,0 @@
/build/

View File

@ -1,87 +0,0 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"],
)
exports_files(glob(["include/iree-dialects/Dialect/IREE/*.td"]))
filegroup(
name = "TdFilegroup",
srcs = glob(["include/iree-dialects/Dialect/IREE/*.td"]),
)
td_library(
name = "TdFiles",
srcs = glob(["include/iree-dialects/Dialect/IREE/*.td"]),
includes = ["include"],
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
],
)
gentbl_cc_library(
name = "IREEOpsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-dialect-decls"],
"include/iree-dialects/Dialect/IREE/IREEOpsDialect.h.inc",
),
(
["-gen-dialect-defs"],
"include/iree-dialects/Dialect/IREE/IREEOpsDialect.cpp.inc",
),
(
["-gen-op-decls"],
"include/iree-dialects/Dialect/IREE/IREEOps.h.inc",
),
(
["-gen-op-defs"],
"include/iree-dialects/Dialect/IREE/IREEOps.cpp.inc",
),
(
["-gen-typedef-decls"],
"include/iree-dialects/Dialect/IREE/IREEOpsTypes.h.inc",
),
(
["-gen-typedef-defs"],
"include/iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/iree-dialects/Dialect/IREE/IREEOps.td",
deps = [":TdFiles"],
)
cc_library(
name = "IREEDialect",
srcs = glob([
"lib/Dialect/IREE/*.cpp",
]),
hdrs = glob(["include/iree-dialects/Dialect/IREE/*.h"]),
includes = ["include"],
deps = [
":IREEOpsIncGen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "CAPI",
srcs = [
"lib/CAPI/Dialects.cpp",
],
hdrs = [
"include/iree-dialects-c/Dialects.h",
],
deps = [
":IREEDialect",
"@llvm-project//mlir:CAPIIR",
],
)

View File

@ -1,70 +0,0 @@
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
message(FATAL_ERROR
"This project is intended to be built as part of LLVM via "
"-DLLVM_EXTERNAL_PROJECTS=iree-dialects "
"-DLLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR}")
endif()
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
set(IREE_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(IREE_DIALECTS_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS "Building iree-dialects project at ${IREE_DIALECTS_SOURCE_DIR} (into ${IREE_DIALECTS_BINARY_DIR})")
# TODO: Fix this upstream so that global include directories are not needed.
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
# TODO: Needed for tablegen. Remove.
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR})
include_directories(SYSTEM ${IREE_DIALECTS_SOURCE_DIR}/include)
function(iree_dialects_target_includes target)
set(_dirs
$<BUILD_INTERFACE:${MLIR_INCLUDE_DIR}>
$<BUILD_INTERFACE:${MLIR_GENERATED_INCLUDE_DIR}>
$<BUILD_INTERFACE:${IREE_DIALECTS_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${IREE_DIALECTS_BINARY_DIR}/include>
)
# In LLVM parlance, the actual target may just be an interface and may not
# be responsible for actually compiling anything. The corresponding obj.
# target, when present, is just used for compilation and does not
# contribute to the interface properties.
# TODO: Normalize this upstream.
target_include_directories(${target} PUBLIC ${_dirs})
if(TARGET obj.${target})
target_include_directories(obj.${target} PRIVATE ${_dirs})
endif()
endfunction()
# Configure CMake and tablegen.
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
set(MLIR_TABLEGEN_EXE mlir-tblgen)
include(TableGen)
include(AddLLVM)
include(AddMLIR)
################################################################################
# Setup python.
# TODO: Make one upstream macro to do this.
################################################################################
if(MLIR_ENABLE_BINDINGS_PYTHON)
include(MLIRDetectPythonEnv)
mlir_detect_pybind11_install()
find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION}
COMPONENTS Interpreter Development NumPy REQUIRED)
find_package(pybind11 2.6 CONFIG REQUIRED)
endif()
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
if(MLIR_ENABLE_BINDINGS_PYTHON)
add_subdirectory(python)
endif()

View File

@ -1,11 +0,0 @@
# IREE Dialects Project
Sources for IREE's public dialects (containing ops/types/attributes that are
unique to IREE and can appear in compiler inputs).
This project is intended to be used via LLVM's external projects setup:
* `-DLLVM_EXTERNAL_PROJECTS=iree-dialects`
* `-DLLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR={this_directory}`
It depends on the `mlir` project.

View File

@ -1,27 +0,0 @@
#!/bin/bash
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Simple script that does a CMake configure of this project as an external
# LLVM project so it can be tested in isolation to larger assemblies.
# This is meant for CI's and project maintainers.
set -eu -o errtrace
project_dir="$(cd $(dirname $0)/.. && pwd)"
repo_root="$(cd "$project_dir"/../.. && pwd)"
llvm_project_dir="$repo_root/third_party/llvm-project"
build_dir="$project_dir/build"
cmake -GNinja -B"$build_dir" "$llvm_project_dir/llvm" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_EXTERNAL_PROJECTS=iree-dialects \
-DLLVM_EXTERNAL_IREE_DIALECTS_SOURCE_DIR="$project_dir" \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON
cd "$build_dir"
ninja tools/iree-dialects/all

View File

@ -1 +0,0 @@
add_subdirectory(iree-dialects)

View File

@ -1,22 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_DIALECTS_H
#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_DIALECTS_H
#include "mlir-c/Registration.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREE, iree);
#ifdef __cplusplus
}
#endif
#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_DIALECTS_H

View File

@ -1 +0,0 @@
add_subdirectory(Dialect)

View File

@ -1 +0,0 @@
add_subdirectory(IREE)

View File

@ -1,3 +0,0 @@
add_mlir_dialect(IREEOps iree)
add_mlir_doc(IREEDialect IREEDialect IREE/ -gen-dialect-doc)
add_mlir_doc(IREEOps IREEOps IREE/ -gen-op-doc)

View File

@ -1,115 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_BASE_TD
#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_BASE_TD
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def IREE_Dialect : Dialect {
let name = "iree";
let summary = "Public ops/type/attributes legal for input to IREE's compiler";
let description = [{
IREE's compiler allows as input a number of common dialects. This dialect
contains structural and unique ops that do not exist elsewhere or that IREE
has an interest in maintaining as a stable set.
The contents of this dialect often mirror various constructs in IREE's
internal implementation. The focus here is on simplicity and stability
over time. Generally, this dialect does not use "advanced" features and
should be broadly source compatible over a range of LLVM versions. There
are, of course, limits, and source-compatibility is not guaranteed, since
LLVM/MLIR's API surface is itself unstable.
}];
let cppNamespace = "::mlir::iree";
}
class IREE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<IREE_Dialect, mnemonic, traits>;
class IREE_PureOp<string mnemonic, list<OpTrait> traits = []> :
Op<IREE_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])>;
class IREE_Type<string name> : TypeDef<IREE_Dialect, name>;
//===----------------------------------------------------------------------===//
// Predicates
//===----------------------------------------------------------------------===//
class IREE_AliasedSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
"symbol reference attribute"> {
let storageType = [{ FlatSymbolRefAttr }];
let returnType = [{ StringRef }];
let valueType = NoneType;
let constBuilderCall = "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
}
class IREE_AnyPtrOf<list<Type> types> :
Type<And<[
CPred<"$_self.isa<::mlir::iree::PtrType>()">,
Or<!foreach(type, types,
SubstLeaves<
"$_self",
"$_self.cast<::mlir::iree::PtrType>().getTargetType()",
type.predicate>)>,
]>, !interleave(!foreach(type, types, type.summary), " or ")> {
string builderCall = "";
}
def IREE_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat]>;
def IREE_Tensor : TypeAlias<AnyRankedTensor>;
def IREE_AnyList : DialectType<
IREE_Dialect,
CPred<"$_self.isa<::mlir::iree::ListType>()">,
"list"> {
let description = [{
A mutable, resizable list of some type.
}];
}
class IREE_ListOf<Type type> :
Type<And<[
CPred<"$_self.isa<::mlir::iree::ListType>()">,
SubstLeaves<"$_self",
"$_self.cast<::mlir::iree::ListType>().getElementType()",
type.predicate>
]>, "list<" # type.summary # ">"> {
// Set the builder call if the base type has a builder call.
string builderCall = !if(!empty(type.builderCall),
"", "::mlir::iree::ListType::get(" # type.builderCall # ")");
}
def IREE_ElementTypeParameter : TypeParameter<
"::mlir::Type", "A type suitable as an element type of a container">;
def IREE_PtrTargetTypeParameter : TypeParameter<
"::mlir::Type", "A type suitable as a target type of a pointer">;
def IREE_Dim : TypeAlias<Index>;
def IREE_Dims : Variadic<IREE_Dim>;
def IREE_Shape : Variadic<IREE_Dim>;
def IREE_ShapeDynamicDims : Variadic<IREE_Dim>;
def IREE_GlobalRefAttr : IREE_AliasedSymbolRefAttr;
def IREE_AnyGlobalPtr : IREE_AnyPtrOf<[IREE_Tensor, IREE_PrimitiveType]>;
class IREE_IndexAttrBase<string descr> :
TypedAttrBase<
Index, "IntegerAttr",
And<[
CPred<"$_self.isa<IntegerAttr>()">,
CPred<"$_self.cast<IntegerAttr>().getType().isIndex()">,
]>,
descr> {
let returnType = [{ APInt }];
}
def IREE_IndexAttr : IREE_IndexAttrBase<"size_t">;
def IREE_TiedOpStorageAttr :
TypedArrayAttrBase<IREE_IndexAttr, "64-bit integer array attribute"> {
let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
}
#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_BASE_TD

View File

@ -1,19 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_H
#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_H
#include "mlir/IR/Dialect.h"
// Include generated dialect code (this comment blocks clang-format from
// clobbering order).
#include "iree-dialects/Dialect/IREE/IREEOpsDialect.h.inc"
#define GET_TYPEDEF_CLASSES
#include "iree-dialects/Dialect/IREE/IREEOpsTypes.h.inc"
#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_H

View File

@ -1,110 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_TD
#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_TD
include "iree-dialects/Dialect/IREE/IREEBase.td"
//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//
def IREE_BufferViewType : IREE_Type<"BufferView"> {
let mnemonic = "buffer_view";
let summary = "View into a buffer, with runtime shape and element type";
let description = [{
BufferViews represent views onto backing IREE runtime Buffer objects,
adding runtime shape and element type parameters to the backing buffer.
BufferViews are typically accepted and returned at boundaries with
external code.
In the runtime and lower-level compiler, BufferViews are fully modeled;
however, as boundary types, not all features are exposed publicly. Since
within compiled tensor programs, it is typical to operate in terms of
fully typed tensors, the primary mechanism for getting or using a
BufferView at the high level is by casting to/from a tensor. It is left
to higher level code to ensure that aliasing rules are enforced at such
boundaries.
}];
let printer = [{
$_printer << "buffer_view";
}];
let parser = [{
return get($_ctxt);
}];
}
def IREE_VariantType : IREE_Type<"Variant"> {
let mnemonic = "variant";
let summary = "Represents any legal or reference type in the IREE runtime";
let description = [{
The variant type is typically used to parameterize container types that
can contain any legal primitive, reference or null in the IREE type system.
}];
let printer = [{
$_printer << "variant";
}];
let parser = [{
return get($_ctxt);
}];
}
def IREE_ListType : IREE_Type<"List"> {
let mnemonic = "list";
let summary = "A one dimensional list of runtime values";
let description = [{
Represents a list of arbitrary type. Primitive types can be expected to
be efficiently stored in an unboxed form. Reference types and variants
are permitted.
Lists can either be homogeneous, with a fixed element type, or heterogeneous
by parameterizing them with a VariantType.
}];
let parameters = (ins IREE_ElementTypeParameter:$elementType);
let printer = [{
$_printer << "list<" << getElementType() << ">";
}];
let parser = [{
Type elementType;
if ($_parser.parseLess() || $_parser.parseType(elementType) ||
$_parser.parseGreater())
return Type();
return get($_ctxt, elementType);
}];
}
def IREE_PtrType : IREE_Type<"Ptr"> {
let mnemonic = "ptr";
let summary = "Pointer to a concrete type";
let parameters = (ins IREE_PtrTargetTypeParameter:$targetType);
let printer = [{
$_printer << "ptr<" << getTargetType() << ">";
}];
let parser = [{
Type targetType;
if ($_parser.parseLess() || $_parser.parseType(targetType) ||
$_parser.parseGreater())
return Type();
return get($_ctxt, targetType);
}];
}
#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_TD

View File

@ -1,21 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_H
#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_H
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/IREE/IREEOps.h.inc"
#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_H

View File

@ -1,525 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_TD
#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_TD
include "iree-dialects/Dialect/IREE/IREEDialect.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
def IREE_NullOp : IREE_PureOp<"null"> {
let summary = "a null value";
let description = [{
Initializes reference and variant types with a null value.
}];
let results = (outs
AnyType:$result
);
let assemblyFormat = [{
attr-dict `:` type($result)
}];
}
//===----------------------------------------------------------------------===//
// Casts
//===----------------------------------------------------------------------===//
def IREE_TensorToBufferViewOp : IREE_PureOp<"cast.tensor_to_buffer_view"> {
let summary = "Casts a tensor to a BufferView, capturing dynamic dims";
let arguments = (ins
IREE_Tensor:$source,
IREE_ShapeDynamicDims:$source_dims
);
let results = (outs IREE_BufferViewType:$target);
let assemblyFormat = [{
$source `:` type($source) (`{` $source_dims^ `}`)? `->` type($target)
attr-dict-with-keyword
}];
}
def IREE_BufferViewToTensorOp : IREE_PureOp<"cast.buffer_view_to_tensor"> {
let summary = "Casts a BufferView to a tensor, providing dynamic dims";
let arguments = (ins
IREE_BufferViewType:$source,
IREE_ShapeDynamicDims:$target_dims
);
let results = (outs IREE_Tensor:$target);
let assemblyFormat = [{
$source `:` type($source) `->` type($target) (`{` $target_dims^ `}`)?
attr-dict-with-keyword
}];
}
//===----------------------------------------------------------------------===//
// Global variables
//===----------------------------------------------------------------------===//
def IREE_GlobalOp : IREE_Op<"global", [
Symbol,
]> {
let summary = [{stateful global variable declaration}];
let description = [{
Declares a global variable that maintains its value across invocations.
The value is tied to the execution context of the module and different
contexts will have different global storage.
}];
let arguments = (ins
OptionalAttr<StrAttr>:$sym_visibility,
SymbolNameAttr:$sym_name,
TypeAttr:$type,
UnitAttr:$is_mutable,
OptionalAttr<FlatSymbolRefAttr>:$initializer,
OptionalAttr<AnyAttr>:$initial_value
);
let assemblyFormat = [{
custom<SymbolVisibility>($sym_visibility)
(`mutable` $is_mutable^)?
$sym_name
attr-dict
(`initializer` `(` $initializer^ `)`):(``)?
custom<TypeOrAttr>($type, $initial_value)
}];
}
def IREE_GlobalAddressOp : IREE_PureOp<"global.address"> {
let summary = [{returns an address reference to a global}];
let description = [{
Returns the address of a global as a typed reference. Can be used with the
global load and store indirect ops.
}];
let arguments = (ins
IREE_GlobalRefAttr:$global
);
let results = (outs
IREE_AnyGlobalPtr:$result
);
let assemblyFormat = [{
$global attr-dict `:` type($result)
}];
}
def IREE_GlobalLoadOp : IREE_Op<"global.load"> {
let summary = [{loads a value from a global variable}];
let description = [{
Returns a copy of the global value.
}];
let arguments = (ins
IREE_GlobalRefAttr:$global
);
let results = (outs
AnyType:$result
);
let assemblyFormat = [{
$global attr-dict `:` type($result)
}];
}
def IREE_GlobalLoadIndirectOp : IREE_Op<"global.load.indirect"> {
let summary = [{loads a value from a global variable}];
let description = [{
Returns a copy of the global value.
}];
let arguments = (ins
IREE_AnyGlobalPtr:$global
);
let results = (outs
AnyType:$result
);
let assemblyFormat = [{
$global attr-dict `:` type($global) `->` type($result)
}];
}
def IREE_GlobalStoreOp : IREE_Op<"global.store"> {
let summary = [{stores a value into a global variable}];
let description = [{
Stores a copy of the value into a global.
}];
let arguments = (ins
AnyType:$value,
IREE_GlobalRefAttr:$global
);
let assemblyFormat = [{
$value `,` $global attr-dict `:` type($value)
}];
}
def IREE_GlobalStoreIndirectOp : IREE_Op<"global.store.indirect"> {
let summary = [{stores a value into a global variable}];
let description = [{
Stores a copy of the value into a global.
}];
let arguments = (ins
AnyType:$value,
IREE_AnyGlobalPtr:$global
);
let assemblyFormat = [{
$value `,` $global attr-dict `:` type($value) `->` type($global)
}];
}
//===----------------------------------------------------------------------===//
// Buffer Views
//===----------------------------------------------------------------------===//
def IREE_BufferViewRankOp : IREE_PureOp<"buffer_view.rank"> {
let summary = [{buffer view rank query}];
let description = [{
Returns the rank of the buffer view.
}];
let arguments = (ins
IREE_BufferViewType:$buffer_view
);
let results = (outs
IREE_Dim:$result
);
let assemblyFormat = [{
$buffer_view attr-dict `:` type($result)
}];
}
def IREE_BufferViewDimOp : IREE_PureOp<"buffer_view.dim"> {
let summary = [{buffer view dimension value query}];
let description = [{
Returns the value of the given dimension.
}];
let arguments = (ins
IREE_BufferViewType:$buffer_view,
IndexAttr:$index
);
let results = (outs
IREE_Dim:$result
);
let assemblyFormat = [{
$buffer_view `,` $index attr-dict `:` type($result)
}];
}
//===----------------------------------------------------------------------===//
// Mutable Lists
//===----------------------------------------------------------------------===//
def IREE_ListCreateOp : IREE_PureOp<
"list.create", [MemoryEffects<[MemAlloc]>]> {
let summary = [{creates a new empty list}];
let description = [{
Creates a new empty list with an optional initial capacity.
}];
let arguments = (ins
Optional<Index>:$initial_capacity
);
let results = (outs
IREE_AnyList:$result
);
let assemblyFormat = "($initial_capacity^)? attr-dict `:` type($result)";
}
def IREE_ListSizeOp : IREE_Op<"list.size", [MemoryEffects<[MemRead]>]> {
let summary = [{the size of the list in elements}];
let description = [{
Returns the current size of the list in elements.
}];
let arguments = (ins
IREE_AnyList:$list
);
let results = (outs
Index:$result
);
let assemblyFormat = "operands attr-dict `:` type($list)";
}
def IREE_ListResizeOp : IREE_Op<"list.resize", [MemoryEffects<[MemWrite]>]> {
let summary = [{resizes the list to a new count in elements}];
let description = [{
Resizes the list to contain `new_size` elements. This will either truncate
the list if the existing size is greater than `new_size` or extend the list
with the default list value of the element type.
}];
let arguments = (ins
IREE_AnyList:$list,
Index:$new_size
);
let assemblyFormat = "operands attr-dict `:` type($list)";
}
def IREE_ListGetOp : IREE_Op<"list.get", [MemoryEffects<[MemRead]>]> {
let summary = [{element accessor}];
let description = [{
Returns the value of the element at the given index. Note that the value
may be null if the element is null or the type does not match.
}];
let arguments = (ins
IREE_AnyList:$list,
Index:$index
);
let results = (outs
AnyType:$result
);
let assemblyFormat = "$list `[` $index `]` attr-dict `:` type($list) `->` type($result)";
}
def IREE_ListSetOp : IREE_Op<"list.set", [MemoryEffects<[MemWrite]>]> {
let summary = [{element mutator}];
let description = [{
Sets the element at the given index to the new value.
}];
let arguments = (ins
IREE_AnyList:$list,
Index:$index,
AnyType:$value
);
let assemblyFormat = "$list `[` $index `]` `,` $value attr-dict `:` type($list) `,` type($value)";
}
//===----------------------------------------------------------------------===//
// Tensor ops
//===----------------------------------------------------------------------===//
def IREE_TensorReshapeOp : IREE_PureOp<"tensor.reshape", [
AllElementTypesMatch<["source", "result"]>,
AttrSizedOperandSegments,
]> {
let summary = [{reshapes a tensor}];
let description = [{
Reshapes a tensor to a new shape without modifying the contents.
}];
let arguments = (ins
IREE_Tensor:$source,
IREE_ShapeDynamicDims:$source_dims,
IREE_ShapeDynamicDims:$result_dims
);
let results = (outs
IREE_Tensor:$result
);
let assemblyFormat = [{
$source `:`
type($source) (`{` $source_dims^ `}`)? `->`
type($result) (`{` $result_dims^ `}`)?
attr-dict-with-keyword
}];
}
def IREE_TensorLoadOp : IREE_PureOp<"tensor.load", [
TypesMatchWith<"value type matches element type of target operand",
"source", "result",
"$_self.cast<ShapedType>().getElementType()">,
AttrSizedOperandSegments,
]> {
let summary = [{loads a value from a tensor element}];
let description = [{
Returns the element at the given location from within the tensor.
}];
let arguments = (ins
IREE_Tensor:$source,
IREE_ShapeDynamicDims:$source_dims,
Variadic<IREE_Dim>:$indices
);
let results = (outs
AnyTypeOf<[IREE_PrimitiveType, AnyVector]>:$result
);
let assemblyFormat = [{
$source (`[` $indices^ `]`)? `:`
type($source) (`{` $source_dims^ `}`)?
attr-dict-with-keyword
}];
}
def IREE_TensorStoreOp : IREE_PureOp<"tensor.store", [
AllTypesMatch<["target", "result"]>,
TypesMatchWith<"value type matches element type of target operand",
"target", "value",
"$_self.cast<ShapedType>().getElementType()">,
AttrSizedOperandSegments,
]> {
let summary = [{stores a value into a tensor element}];
let description = [{
Returns a tensor with the element at the given index set to the given value.
}];
let arguments = (ins
AnyTypeOf<[IREE_PrimitiveType, AnyVector]>:$value,
IREE_Tensor:$target,
IREE_ShapeDynamicDims:$target_dims,
Variadic<IREE_Dim>:$indices
);
let results = (outs
IREE_Tensor:$result
);
let assemblyFormat = [{
$value `,` $target (`[` $indices^ `]`)? `:`
type($target) (`{` $target_dims^ `}`)?
attr-dict-with-keyword
}];
}
def IREE_TensorSplatOp : IREE_PureOp<"tensor.splat", [
TypesMatchWith<"value type matches element type of result",
"result", "value",
"$_self.cast<ShapedType>().getElementType()">,
]> {
let summary = [{splats a value into a shaped tensor}];
let description = [{
Returns a tensor initialized to the given primitive value.
}];
let arguments = (ins
IREE_PrimitiveType:$value,
IREE_ShapeDynamicDims:$result_dims
);
let results = (outs
IREE_Tensor:$result
);
let assemblyFormat = [{
$value `:` type($result) (`{` $result_dims^ `}`)?
attr-dict-with-keyword
}];
}
def IREE_TensorCloneOp : IREE_PureOp<"tensor.clone", [
AllTypesMatch<["operand", "result"]>,
]> {
let summary = [{performs a full tensor clone operation}];
let description = [{
Clones the input tensor into an identical output tensor.
}];
let arguments = (ins
IREE_Tensor:$operand,
IREE_ShapeDynamicDims:$operand_dims
);
let results = (outs
IREE_Tensor:$result
);
let assemblyFormat = [{
$operand `:` type($result) (`{` $operand_dims^ `}`)?
attr-dict-with-keyword
}];
}
def IREE_TensorSliceOp : IREE_PureOp<"tensor.slice", [
AllRanksMatch<["source", "result"]>,
AllElementTypesMatch<["source", "result"]>,
AttrSizedOperandSegments,
]> {
let summary = [{slices out a subregion of a tensor}];
let description = [{
Clones a subregion of a tensor.
}];
let arguments = (ins
IREE_Tensor:$source,
IREE_ShapeDynamicDims:$source_dims,
Variadic<IREE_Dim>:$start_indices,
Variadic<IREE_Dim>:$lengths,
IREE_ShapeDynamicDims:$result_dims
);
let results = (outs
IREE_Tensor:$result
);
let assemblyFormat = [{
$source `[` $start_indices `for` $lengths `]` `:`
type($source) (`{` $source_dims^ `}`)? `->`
type($result) (`{` $result_dims^ `}`)?
attr-dict-with-keyword
}];
}
def IREE_TensorUpdateOp : IREE_PureOp<"tensor.update", [
AllRanksMatch<["update", "target", "result"]>,
AllTypesMatch<["target", "result"]>,
AllElementTypesMatch<["update", "target", "result"]>,
AttrSizedOperandSegments,
]> {
let summary = [{updates a tensor with the contents of another tensor}];
let description = [{
Updates the target tensor with the contents of the update tensor at the
given offset indices.
}];
let arguments = (ins
IREE_Tensor:$target,
IREE_ShapeDynamicDims:$target_dims,
Variadic<IREE_Dim>:$start_indices,
IREE_Tensor:$update,
IREE_ShapeDynamicDims:$update_dims
);
let results = (outs
IREE_Tensor:$result
);
let assemblyFormat = [{
$update `,` $target `[` $start_indices `]` `:`
type($update) (`{` $update_dims^ `}`)? `->`
type($result) (`{` $target_dims^ `}`)?
attr-dict-with-keyword
}];
let builders = [
OpBuilder<(ins
"Value":$target,
"ValueRange":$start_indices,
"Value":$update)>,
];
}
def IREE_TensorTraceOp : IREE_Op<"tensor.trace", []> {
let summary = [{trace value(s) operation}];
let description = [{
Traces out to a runtime trace sink (console, log file, etc) the given
tensors and titles them with the given key. The key is informational only
and useful for titling/marking specific sets of tensors for easier
searching.
}];
let arguments = (ins
StrAttr:$key,
Variadic<IREE_Tensor>:$operands
);
let assemblyFormat = "$key attr-dict ($operands^ `:` type($operands))?";
}
#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_TD

View File

@ -1,7 +0,0 @@
add_mlir_public_c_api_library(IREEDialectsCAPI
Dialects.cpp
LINK_LIBS PUBLIC
IREEDialectsIREEDialect
)
iree_dialects_target_includes(IREEDialectsCAPI)

View File

@ -1,12 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects-c/Dialects.h"
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "mlir/CAPI/Registration.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(IREE, iree, mlir::iree::IREEDialect)

View File

@ -1,2 +0,0 @@
add_subdirectory(CAPI)
add_subdirectory(Dialect)

View File

@ -1 +0,0 @@
add_subdirectory(IREE)

View File

@ -1,16 +0,0 @@
add_mlir_library(IREEDialectsIREEDialect
IREEDialect.cpp
IREEOps.cpp
ADDITIONAL_HEADER_DIRS
${IREE_DIALECTS_SOURCE_DIR}/include
DEPENDS
MLIRIREEOpsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRSideEffectInterfaces
)
iree_dialects_target_includes(IREEDialectsIREEDialect)

View File

@ -1,43 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "iree-dialects/Dialect/IREE/IREEOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
using namespace mlir::iree;
#include "iree-dialects/Dialect/IREE/IREEOpsDialect.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc"
void IREEDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc"
>();
}
Type IREEDialect::parseType(DialectAsmParser &parser) const {
StringRef typeTag;
Type genType;
if (succeeded(parser.parseKeyword(&typeTag)))
generatedTypeParser(getContext(), parser, typeTag, genType);
return genType;
}
void IREEDialect::printType(Type type, DialectAsmPrinter &printer) const {
(void)generatedTypePrinter(type, printer);
}

View File

@ -1,94 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/IREE/IREEOps.h"
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
using namespace mlir::iree;
//===----------------------------------------------------------------------===//
// custom<SymbolVisibility>($sym_visibility)
//===----------------------------------------------------------------------===//
// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
// ->
// some.op @foo
// some.op private @foo
static ParseResult parseSymbolVisibility(OpAsmParser &parser,
StringAttr &symVisibilityAttr) {
StringRef symVisibility;
parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"});
if (!symVisibility.empty()) {
symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
}
return success();
}
static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
StringAttr symVisibilityAttr) {
if (!symVisibilityAttr) {
p << "public";
} else {
p << symVisibilityAttr.getValue();
}
}
//===----------------------------------------------------------------------===//
// custom<TypeOrAttr>($type, $attr)
//===----------------------------------------------------------------------===//
// some.op custom<TypeOrAttr>($type, $attr)
// ->
// some.op : i32
// some.op = 42 : i32
// some.op : i32 = 42 : index
static ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &attr) {
if (succeeded(parser.parseOptionalEqual())) {
if (failed(parser.parseAttribute(attr))) {
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
typeAttr = TypeAttr::get(attr.getType());
return success();
}
Type type;
if (failed(parser.parseColonType(type))) {
return parser.emitError(parser.getCurrentLocation()) << "expected type";
}
typeAttr = TypeAttr::get(type);
if (succeeded(parser.parseOptionalEqual())) {
if (failed(parser.parseAttribute(attr))) {
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
}
return success();
}
static void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr) {
if (!attr || attr.getType() != type.getValue()) {
p << " : ";
p.printAttribute(type);
}
if (attr) {
p << " = ";
p.printAttribute(attr);
}
}
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc"

View File

@ -1,65 +0,0 @@
include(AddMLIRPython)
################################################################################
# Sources
################################################################################
declare_mlir_python_sources(IREEDialectsPythonSources)
declare_mlir_python_sources(IREEDialectsPythonExtensions)
declare_mlir_python_sources(IREEDialectsPythonSources.Dialects
ADD_TO_PARENT IREEDialectsPythonSources
)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT IREEDialectsPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/IreeBinding.td
SOURCES dialects/iree.py
DIALECT_NAME iree
)
################################################################################
# Extensions
################################################################################
declare_mlir_python_extension(IREEDialectsPythonExtensions.Main
MODULE_NAME _ireeDialects
ADD_TO_PARENT IREEDialectsPythonExtensions
SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/IREEDialectsModule.cpp
EMBED_CAPI_LINK_LIBS
IREEDialectsCAPI
PRIVATE_LINK_LIBS
LLVMSupport
)
################################################################################
# Generate packages and shared library
# Downstreams typically will not use these, but they are useful for local
# testing.
################################################################################
set(_source_components
# TODO: Core is now implicitly building/registering all dialects, increasing
# build burden by ~5x. Make it stop.
MLIRPythonSources.Core
IREEDialectsPythonSources
IREEDialectsPythonExtensions
)
add_mlir_python_common_capi_library(IREEDialectsAggregateCAPI
INSTALL_COMPONENT IREEDialectsPythonModules
INSTALL_DESTINATION python_packages/iree_dialects/mlir/_mlir_libs
OUTPUT_DIRECTORY "${IREE_DIALECTS_BINARY_DIR}/python_packages/iree_dialects/mlir/_mlir_libs"
RELATIVE_INSTALL_ROOT "../../../.."
DECLARED_SOURCES ${_source_components}
)
add_mlir_python_modules(IREEDialectsPythonModules
ROOT_PREFIX "${IREE_DIALECTS_BINARY_DIR}/python_packages/iree_dialects/mlir"
INSTALL_PREFIX "python_packages/iree_dialects/mlir"
DECLARED_SOURCES ${_source_components}
COMMON_CAPI_LINK_LIBS
IREEDialectsAggregateCAPI
)

View File

@ -1,27 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects-c/Dialects.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Registration.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
PYBIND11_MODULE(_ireeDialects, m) {
m.doc() = "iree-dialects main python extension";
m.def(
"register_iree_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle handle = mlirGetDialectHandle__iree__();
mlirDialectHandleRegisterDialect(handle, context);
if (load) {
mlirDialectHandleLoadDialect(handle, context);
}
},
py::arg("context"), py::arg("load") = true);
}

View File

@ -1,13 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef PYTHON_BINDINGS_IREE_OPS
#define PYTHON_BINDINGS_IREE_OPS
include "mlir/Bindings/Python/Attributes.td"
include "iree-dialects/Dialect/IREE/IREEOps.td"
#endif // PYTHON_BINDINGS_IREE_OPS

View File

@ -1,8 +0,0 @@
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._iree_ops_gen import *
from .._mlir_libs._ireeDialects import register_iree_dialect

View File

@ -1,28 +0,0 @@
llvm_canonicalize_cmake_booleans(
MLIR_ENABLE_BINDINGS_PYTHON
)
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
MAIN_CONFIG
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
)
set(IREE_DIALECTS_TEST_DEPENDS
FileCheck count not
)
if(MLIR_ENABLE_BINDINGS_PYTHON)
list(APPEND IREE_DIALECTS_TEST_DEPENDS
IREEDialectsPythonModules
)
endif()
add_lit_testsuite(check-iree-dialects "Running the iree-dialects regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${IREE_DIALECTS_TEST_DEPENDS}
)
set_target_properties(check-iree-dialects PROPERTIES FOLDER "Tests")
add_lit_testsuites(IREE_DIALECTS ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${IREE_DIALECTS_TEST_DEPENDS})

View File

@ -1,4 +0,0 @@
# RUN: %PYTHON %s
# This test does nothing. It is just here so that if python bindings tests
# are excluded, the test suite is not empty.

View File

@ -1,74 +0,0 @@
# -*- Python -*-
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import os
import platform
import re
import subprocess
import tempfile
import lit.formats
import lit.util
from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
from lit.llvm.subst import FindTool
# Configuration file for the 'lit' test runner.
# name: The name of this test suite.
config.name = 'IREE_DIALECTS'
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir', '.py']
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.iree_dialects_obj_root, 'test')
config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
#llvm_config.use_default_substitutions()
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = [
'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
'lit.cfg.py', 'lit.site.cfg.py'
]
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.iree_dialects_obj_root, 'test')
config.standalone_tools_dir = os.path.join(config.iree_dialects_obj_root, 'bin')
# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
tool_dirs = [config.llvm_tools_dir]
tools = [
ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
]
llvm_config.add_tool_substitutions(tools, tool_dirs)
if config.enable_bindings_python:
llvm_config.with_environment('PYTHONPATH', [
os.path.join(config.iree_dialects_obj_root, 'python_packages',
'iree_dialects'),
],
append_path=True)

View File

@ -1,21 +0,0 @@
@LIT_SITE_CFG_IN_HEADER@
import sys
config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@
config.iree_dialects_obj_root = "@IREE_DIALECTS_BINARY_DIR@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.python_executable = sys.executable
import lit.llvm
lit.llvm.initialize(lit_config, config)
# Let the main config do the real work.
lit_config.load_config(config, "@IREE_DIALECTS_SOURCE_DIR@/test/lit.cfg.py")

View File

@ -1,2 +0,0 @@
if not config.enable_bindings_python:
config.unsupported = True

View File

@ -1,7 +0,0 @@
# RUN: %PYTHON %s
import mlir.ir
from mlir.dialects import iree
with mlir.ir.Context() as ctx:
iree.register_iree_dialect(ctx)
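A slightly larger smoke test could also round-trip one of the dialect's types through the parser/printer declared in IREEDialect.td; a hedged sketch (it assumes the dialect is loaded, which register_iree_dialect does by default per the binding above, and that mlir.ir.Type.parse is available in these bindings):

# RUN: %PYTHON %s
import mlir.ir
from mlir.dialects import iree

with mlir.ir.Context() as ctx:
    iree.register_iree_dialect(ctx)
    # Parse and re-print an !iree.list type via the custom parser/printer
    # defined for ListType in IREEDialect.td.
    t = mlir.ir.Type.parse("!iree.list<i32>")
    print(t)  # expected: !iree.list<i32>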

View File

@ -1,6 +1 @@
add_subdirectory(Common)
# We include this unconditionally, because we use the `iree-dialects` public
# op surface area to target IREE. No "external" C++ dependency is needed
# because we copy those files into npcomp and build them into our compiler
add_subdirectory(IREE)

View File

@ -20,8 +20,6 @@ void registerCommonBackendPasses();
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
std::unique_ptr<OperationPass<FuncOp>> createDeleteDeadIREEListsPass();
} // namespace CommonBackend
} // namespace NPCOMP
} // namespace mlir

View File

@ -16,17 +16,4 @@ def VerifyBackendContract : Pass<"npcomp-verify-backend-contract", "ModuleOp"> {
let constructor = "mlir::NPCOMP::CommonBackend::createVerifyBackendContractPass()";
}
def DeleteDeadIREELists : Pass<"delete-dead-iree-lists", "FuncOp"> {
let summary = "Deletes dead `!iree.list` values";
let description = [{
Some backends cannot handle `!iree.list` values which might be incidentally
created during type conversion. This pass deletes them so those backends
can still run programs that don't use lists in an essential way.
For example, list arguments to convolution ops (such as strides) turn into
dead lists with our current lowerings.
}];
let constructor = "mlir::NPCOMP::CommonBackend::createDeleteDeadIREEListsPass()";
}
#endif // NPCOMP_BACKEND_COMMON_PASSES

View File

@ -1,5 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(NPCOMPIREEBackendPassIncGen)
add_mlir_doc(Passes IREEBackendPasses ./ -gen-pass-doc)

View File

@ -1,29 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_BACKEND_IREE_PASSES_H
#define NPCOMP_BACKEND_IREE_PASSES_H
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
namespace mlir {
namespace NPCOMP {
namespace IREEBackend {
/// Registers all IREEBackend passes.
void registerIREEBackendPasses();
/// Create a pipeline that runs all passes needed to lower the npcomp backend
/// contract to IREE's frontend contract.
void createNpcompBackendToIreeFrontendPipeline(OpPassManager &pm);
} // namespace IREEBackend
} // namespace NPCOMP
} // namespace mlir
#endif // NPCOMP_BACKEND_IREE_PASSES_H

@ -1,14 +0,0 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_BACKEND_IREE_PASSES
#define NPCOMP_BACKEND_IREE_PASSES
include "mlir/Pass/PassBase.td"
#endif // NPCOMP_BACKEND_IREE_PASSES

@ -1,8 +0,0 @@
# IREE Backend
Passes/utilities for preparing input to IREE.
For now, this directory doesn't have a C++-level dependency on IREE, since
it only performs a trivial transformation. Eventually, if we lower
nontrivial constructs to IREE (such as a list type to `!iree.list`),
we will need to take that dependency, and it will be isolated to this directory.

@ -104,12 +104,4 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
let constructor = "mlir::NPCOMP::createConvertTorchToLinalgPass()";
}
def ConvertTorchToIREE : Pass<"convert-torch-to-iree", "FuncOp"> {
let summary = "Convert recognized Torch ops to IREE ops";
let description = [{
TODO
}];
let constructor = "mlir::NPCOMP::createConvertTorchToIREEPass()";
}
#endif // NPCOMP_CONVERSION_PASSES

@ -1,21 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
#define NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToIREEPass();
}
} // namespace mlir
#endif // NPCOMP_CONVERSION_TORCHTOIREE_TORCHTOIREE_H

@ -21,8 +21,8 @@ def TorchConversion_Dialect : Dialect {
This mainly consists of converting ops and types from `torch` dialect
to the mix of dialects of the npcomp backend contract, such as tensor
ops being converted to linalg-on-tensors, lists being converted to IREE lists,
and !torch.float being converted to `f64`.
ops being converted to linalg-on-tensors and !torch.float being converted to
`f64`.
}];
}
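
As a rough illustration of the glue this dialect provides, a hypothetical function using one of its materialization ops might look like the sketch below (the `torch_c.to_f64` spelling is taken from the conversion test cases later in this diff; the exact printed form is not guaranteed here):

// Bridge a `!torch.float` value into the builtin `f64` type expected by the
// npcomp backend contract.
builtin.func @example(%arg0: !torch.float) -> f64 {
  %0 = torch_c.to_f64 %arg0
  return %0 : f64
}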

@ -9,7 +9,6 @@
#ifndef NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H
#define NPCOMP_DIALECT_TORCHCONVERSION_IR_TORCHOPS_H
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"

@ -15,7 +15,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "npcomp/Dialect/TorchConversion/IR/TorchConversionBase.td"
include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
include "iree-dialects/Dialect/IREE/IREEDialect.td"
class TorchConversion_Op<string mnemonic, list<OpTrait> traits = []>
: Op<TorchConversion_Dialect, mnemonic, traits> {
@ -171,37 +170,4 @@ def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [
}];
}
// TODO: Verify the element types match.
def TorchConversion_ToIREEListOp : TorchConversion_Op<"to_iree_list", [
]> {
let summary = "Convert a `!torch.list` to a `!iree.list`";
let description = [{
}];
let arguments = (ins
Torch_ListType:$operand
);
let results = (outs
IREE_ListType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
def TorchConversion_FromIREEListOp : TorchConversion_Op<"from_iree_list", [
]> {
let summary = "Convert a `!iree.list` to a `!torch.list`";
let description = [{
}];
let arguments = (ins
IREE_ListType:$operand
);
let results = (outs
Torch_ListType:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
#endif // TORCHCONVERSION_OPS

@ -32,9 +32,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
std::unique_ptr<OperationPass<FuncOp>>
createFinalizingBackendTypeConversionPass();
std::unique_ptr<OperationPass<ModuleOp>> createAnnotateABIPass();
} // namespace TorchConversion
/// Registers all Torch transformation passes.

@ -52,26 +52,4 @@ def FinalizingBackendTypeConversion
}];
}
def AnnotateABI : Pass<"torch-annotate-abi", "ModuleOp"> {
let summary = "Annotate `torch` types before lowering to backend types";
let description = [{
Populates `iree.abi` metadata to allow runtime reflection of
arguments and results.
See IREE's `docs/developers/design_docs/function_abi.md` for information
about this annotation format.
This information must be annotated before we lower types to the backend
contract, since that lowering is not generally reversible to recover the
correct Python signature.
TODO: Reconsider the passes in the Torch lowering pipeline in light of this.
We want to provide a faithful ABI up to the user, so the None handling (and
unimplemented tuple handling) in AdjustCallingConventions and the ClassType
handling in GlobalizeObjectGraph will need to be considered.
}];
let constructor = "mlir::NPCOMP::TorchConversion::createAnnotateABIPass()";
}
#endif // NPCOMP_TORCHCONVERSION_PASSES
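
Concretely, the pass serializes a small JSON object onto each public function as a string attribute named `iree.abi`. A minimal before/after sketch, modeled on the `basic_arg_and_ret` test case later in this diff (the function name and body here are hypothetical):

// Input: a public function with torch-level types.
builtin.func @example(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float {
  return %arg0 : !torch.float
}
// After -torch-annotate-abi: the attribute records argument and result types.
// The JSON is {"a":["f64","i64"],"r":["f64"],"v":1}, with \22 escaping the
// double quotes inside the attribute string.
builtin.func @example(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float
    attributes {iree.abi = "{\22a\22:[\22f64\22,\22i64\22],\22r\22:[\22f64\22],\22v\22:1}"} {
  return %arg0 : !torch.float
}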

@ -1,6 +1 @@
add_subdirectory(Common)
# We include this unconditionally, because we use the `iree-dialects` public
# op surface area to target IREE. No "external" C++ dependency is needed
# because we copy those files into npcomp and build them into our compiler
add_subdirectory(IREE)

@ -1,5 +1,4 @@
add_npcomp_library(NPCOMPCommonBackend
DeleteDeadIREELists.cpp
VerifyBackendContract.cpp
Passes.cpp
@ -18,7 +17,6 @@ add_npcomp_library(NPCOMPCommonBackend
MLIRTensor
MLIRStandard
MLIRMath
IREEDialectsIREEDialect
)
mlir_check_all_link_libraries(NPCOMPCommonBackend)

@ -1,54 +0,0 @@
//===- DeleteDeadIREELists.cpp -----------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "iree-dialects/Dialect/IREE/IREEOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "npcomp/Backend/Common/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::CommonBackend;
namespace {
class DeleteDeadIREEListsPass
: public DeleteDeadIREEListsBase<DeleteDeadIREEListsPass> {
void runOnOperation() override {
SmallVector<Operation *> toErase;
// Delete lists that are only set (but not read from).
// This is created by our lowering for torch.prim.ListConstruct.
// Until IREE can run such ops e2e (or delete them itself), we need to
// do this cleanup.
// TODO: Add support to IREE to run these ops E2E.
getOperation().walk([&](iree::ListCreateOp op) {
SmallVector<Operation *> deadOps;
deadOps.push_back(op);
for (auto &use : op.getResult().getUses()) {
if (isa<iree::ListSetOp, iree::ListResizeOp>(use.getOwner())) {
deadOps.push_back(use.getOwner());
} else {
// We can't analyze the list op if it is used by something else.
return;
}
}
llvm::append_range(toErase, deadOps);
});
for (auto *op : toErase) {
op->dropAllDefinedValueUses();
op->erase();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::CommonBackend::createDeleteDeadIREEListsPass() {
return std::make_unique<DeleteDeadIREEListsPass>();
}

@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "iree-dialects/Dialect/IREE/IREEOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -33,7 +32,6 @@ class VerifyBackendContractPass
return type;
return nullptr;
});
converter.addConversion([](iree::ListType type) { return type; });
TypeConverter scalarConverter;
for (TypeConverter *c : {&converter, &scalarConverter}) {
c->addConversion([](FloatType type) { return type; });
@ -61,7 +59,6 @@ class VerifyBackendContractPass
// Tensor operations should go through linalg and the tensor dialect.
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<tensor::TensorDialect>(opHasLegalTypes);
target.addDynamicallyLegalDialect<iree::IREEDialect>(opHasLegalTypes);
// AssertOp is used to terminate the program for error guards.
target.addLegalOp<AssertOp>();

@ -1,17 +0,0 @@
add_npcomp_library(NPCOMPIREEBackend
Passes.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SRC_DIR}/include/npcomp/Backend/IREE
DEPENDS
NPCOMPIREEBackendPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
)
mlir_check_all_link_libraries(NPCOMPIREEBackend)

@ -1,25 +0,0 @@
//===- PassDetail.h - Pass class details ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef BACKEND_IREE_PASSDETAIL_H
#define BACKEND_IREE_PASSDETAIL_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace NPCOMP {
namespace IREEBackend {
#define GEN_PASS_CLASSES
#include "npcomp/Backend/IREE/Passes.h.inc"
} // namespace IREEBackend
} // namespace NPCOMP
} // end namespace mlir
#endif // BACKEND_IREE_PASSDETAIL_H

@ -1,34 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "npcomp/Backend/IREE/Passes.h"
#include "mlir/IR/BuiltinOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::IREEBackend;
namespace {
#define GEN_PASS_REGISTRATION
#include "npcomp/Backend/IREE/Passes.h.inc"
} // end namespace
void mlir::NPCOMP::IREEBackend::createNpcompBackendToIreeFrontendPipeline(
OpPassManager &pm) {}
void mlir::NPCOMP::IREEBackend::registerIREEBackendPasses() {
::registerPasses();
mlir::PassPipelineRegistration<>(
"npcomp-backend-to-iree-frontend-pipeline",
"Pipeline lowering the npcomp backend contract IR to IREE's frontend "
"contract.",
mlir::NPCOMP::IREEBackend::createNpcompBackendToIreeFrontendPipeline);
}

@ -26,11 +26,9 @@ add_npcomp_library(NPCOMPInitAll
PUBLIC
# Local depends
NPCOMPCommonBackend
NPCOMPIREEBackend
TorchMLIRTorchDialect
NPCOMPTorchConversionDialect
NPCOMPConversionPasses
IREEDialectsIREEDialect
# TODO: We shouldn't need npcomp_conversion_libs here, but we have
# some dialect transform libraries accumulating into that property.

@ -1,4 +1,3 @@
add_subdirectory(TorchToIREE)
add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToStd)

@ -8,7 +8,6 @@
#include "npcomp/Conversion/Passes.h"
#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h"
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"

@ -1,19 +0,0 @@
add_npcomp_conversion_library(NPCOMPTorchToIREE
TorchToIREE.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToIREE
DEPENDS
NPCOMPConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
TorchMLIRTorchDialect
MLIRStandard
IREEDialectsIREEDialect
)

@ -1,90 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h"
#include "../PassDetail.h"
#include "iree-dialects/Dialect/IREE/IREEOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::torch::Torch;
//===----------------------------------------------------------------------===//
// The patterns
//===----------------------------------------------------------------------===//
namespace {
class ConvertPrimListConstructOp
: public OpConversionPattern<PrimListConstructOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PrimListConstructOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = getTypeConverter()->convertType(op.getType());
auto size =
rewriter.create<ConstantIndexOp>(op.getLoc(), op->getNumOperands());
auto ireeList =
rewriter.replaceOpWithNewOp<iree::ListCreateOp>(op, type, size);
rewriter.create<iree::ListResizeOp>(op.getLoc(), ireeList, size);
for (int i = 0, e = operands.size(); i != e; ++i) {
auto index = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
rewriter.create<iree::ListSetOp>(op.getLoc(), ireeList, index,
operands[i]);
}
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// The pass
//===----------------------------------------------------------------------===//
namespace {
class ConvertTorchToIREE : public ConvertTorchToIREEBase<ConvertTorchToIREE> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<iree::IREEDialect>();
target.addLegalDialect<StandardOpsDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context);
patterns.add<ConvertPrimListConstructOp>(typeConverter, context);
target.addIllegalOp<PrimListConstructOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTorchToIREEPass() {
return std::make_unique<ConvertTorchToIREE>();
}

@ -1,132 +0,0 @@
//===- AnnotateABI.cpp -------------------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "iree-dialects/Dialect/IREE/IREEOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "llvm/Support/JSON.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::TorchConversion;
using namespace mlir::torch;
namespace json = llvm::json;
static json::Value
convertTypeToIREEABIJSON(Type type,
llvm::function_ref<InFlightDiagnostic()> emitError) {
if (auto tensorType = type.dyn_cast<Torch::BaseTensorType>()) {
// TODO: Support unranked and unknown dtype when we actually have examples
// that need it.
if (tensorType.hasSizes() && tensorType.hasDtype()) {
json::Array typeRecord{"ndarray"};
typeRecord.push_back(
convertTypeToIREEABIJSON(tensorType.getDtype(), emitError));
typeRecord.push_back(json::Value(tensorType.getSizes().size()));
for (auto size : tensorType.getSizes()) {
if (size == Torch::kUnknownSize)
typeRecord.push_back(json::Value(nullptr));
else
typeRecord.push_back(json::Value(size));
}
return typeRecord;
}
} else if (auto boolType = type.dyn_cast<Torch::BoolType>()) {
return json::Value("i1");
} else if (auto intType = type.dyn_cast<Torch::IntType>()) {
return json::Value("i64");
} else if (auto floatType = type.dyn_cast<Torch::FloatType>()) {
return json::Value("f64");
} else if (auto listType = type.dyn_cast<Torch::ListType>()) {
return json::Array{
json::Value("py_uniform_list"),
convertTypeToIREEABIJSON(listType.getContainedType(), emitError)};
} else if (auto dictType = type.dyn_cast<Torch::DictType>()) {
return json::Array{
json::Value("py_uniform_dict"),
convertTypeToIREEABIJSON(dictType.getKeyType(), emitError),
convertTypeToIREEABIJSON(dictType.getValueType(), emitError)};
} else if (auto tupleType = type.dyn_cast<Torch::TupleType>()) {
auto typeRecord = json::Array{"pytuple"};
for (auto containedType : tupleType.getContainedTypes())
typeRecord.push_back(convertTypeToIREEABIJSON(containedType, emitError));
return typeRecord;
} else if (auto strType = type.dyn_cast<Torch::StringType>()) {
return json::Value("pystr");
} else if (auto integerType = type.dyn_cast<mlir::IntegerType>()) {
// Only used in recursive calls for tensor dtypes.
return json::Value(("i" + Twine(integerType.getWidth())).str());
} else if (auto floatType = type.dyn_cast<mlir::FloatType>()) {
// Only used in recursive calls for tensor dtypes.
if (floatType.isa<BFloat16Type>())
return json::Value("bf16");
return json::Value(("f" + Twine(floatType.getWidth())).str());
}
emitError() << "unimplemented: ABI annotation for type " << type;
return json::Value("error: unimplemented type");
}
namespace {
class AnnotateABIPass : public AnnotateABIBase<AnnotateABIPass> {
void runOnOperation() override {
auto module = getOperation();
bool hadError = false;
module.walk([&](FuncOp func) {
if (func.getVisibility() != SymbolTable::Visibility::Public)
return;
func.getArgumentTypes();
json::Array abiArgs;
json::Array abiResults;
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
auto emitError = [&]() {
hadError = true;
return func.emitError()
<< "at function argument " << type.index() << ": ";
};
abiArgs.push_back(convertTypeToIREEABIJSON(type.value(), emitError));
}
for (auto type : llvm::enumerate(func.getCallableResults())) {
auto emitError = [&]() {
hadError = true;
return func.emitError()
<< "at function result " << type.index() << ": ";
};
abiResults.push_back(convertTypeToIREEABIJSON(type.value(), emitError));
}
if (hadError)
return;
json::Object abiDict;
abiDict["v"] = json::Value(1);
abiDict["a"] = json::Value(std::move(abiArgs));
abiDict["r"] = json::Value(std::move(abiResults));
std::string buf;
llvm::raw_string_ostream os(buf);
os << json::Value(std::move(abiDict));
func->setAttr("iree.abi", Builder(func).getStringAttr(os.str()));
});
if (hadError)
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::TorchConversion::createAnnotateABIPass() {
return std::make_unique<AnnotateABIPass>();
}

@ -8,7 +8,6 @@
#include "PassDetail.h"
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -27,7 +26,6 @@ using namespace mlir::NPCOMP::TorchConversion;
void mlir::NPCOMP::TorchConversion::getBackendTypeConversionDependentDialects(
DialectRegistry &registry) {
registry.insert<TorchConversionDialect>();
registry.insert<iree::IREEDialect>();
}
//===----------------------------------------------------------------------===//
@ -136,39 +134,12 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
typeConverter.addArgumentMaterialization(sourceMaterialization);
}
static void setupTorchListToIREEListConversion(ConversionTarget &target,
TypeConverter &typeConverter) {
target.addLegalOp<TorchConversion::ToIREEListOp,
TorchConversion::FromIREEListOp>();
typeConverter.addConversion([&](Torch::ListType type) -> Optional<Type> {
return iree::ListType::get(
type.getContext(), typeConverter.convertType(type.getContainedType()));
});
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, iree::ListType type, ValueRange inputs,
Location loc) -> Optional<Value> {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<Torch::ListType>());
return builder.create<ToIREEListOp>(loc, type, inputs[0]).getResult();
});
auto sourceMaterialization = [](OpBuilder &builder, Torch::ListType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<iree::ListType>());
return builder.create<FromIREEListOp>(loc, type, inputs[0]);
};
typeConverter.addSourceMaterialization(sourceMaterialization);
typeConverter.addArgumentMaterialization(sourceMaterialization);
}
void mlir::NPCOMP::TorchConversion::setupBackendTypeConversion(
ConversionTarget &target, TypeConverter &typeConverter) {
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
setupTorchBoolToI1Conversion(target, typeConverter);
setupTorchIntToI64Conversion(target, typeConverter);
setupTorchFloatToF64Conversion(target, typeConverter);
// TODO: Remove list support entirely.
// setupTorchListToIREEListConversion(target, typeConverter);
}
//===----------------------------------------------------------------------===//
@ -283,8 +254,8 @@ struct FinalizingBackendTypeConversionPass
// Mark materializations as illegal in this pass (since we are finalizing)
// and add patterns that eliminate them.
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
FromI64Op, ToI64Op, FromF64Op, ToF64Op, FromIREEListOp,
ToIREEListOp>(target, patterns, typeConverter);
FromI64Op, ToI64Op, FromF64Op, ToF64Op>(target, patterns,
typeConverter);
// If all result types are legal, and all block arguments are legal, then
// all types in the program are legal.

@ -1,5 +1,4 @@
add_npcomp_conversion_library(NPCOMPTorchConversionPasses
AnnotateABI.cpp
BackendTypeConversion.cpp
Passes.cpp
VerifyInvariantsBeforeBackendLowering.cpp
@ -20,7 +19,6 @@ add_npcomp_conversion_library(NPCOMPTorchConversionPasses
NPCOMPTorchConversionDialect
TorchMLIRTorchDialect
TorchMLIRTorchPasses
NPCOMPTorchToIREE
NPCOMPTorchToLinalg
NPCOMPTorchToStd
NPCOMPTorchToSCF

View File

@ -13,7 +13,6 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Conversion/TorchToIREE/TorchToIREE.h"
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
@ -47,10 +46,6 @@ void mlir::NPCOMP::TorchConversion::createTorchScriptToNpcompBackendPipeline(
// contract.
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
// Annotate the ABI of the original Torch functions before we lower them and
// lose information.
pm.addPass(TorchConversion::createAnnotateABIPass());
// Check some invariants to catch errors in a clear way.
pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
@ -62,17 +57,6 @@ void mlir::NPCOMP::TorchConversion::createTorchScriptToNpcompBackendPipeline(
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
// Lists and other concepts that don't exist in upstream go through the IREE
// dialect, which we treat as a reasonably well designed interim placeholder
// for the set of ops that we think makes sense in the npcomp backend
// contract. We expect to co-evolve this dialect with npcomp needs, as a lot
// of what we are doing here in npcomp is breaking new ground w.r.t.
// expressiveness and program generality for tensor compilers.
//
// We lower lists last because the lowered form is much harder to reason about
// than the original form.
// TODO: Remove list support entirely.
// pm.addNestedPass<FuncOp>(createConvertTorchToIREEPass());
pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
if (options.optimize) {

@ -8,24 +8,18 @@
#include "npcomp/InitAll.h"
#include "iree-dialects/Dialect/IREE/IREEDialect.h"
#include "mlir/IR/Dialect.h"
#include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Backend/IREE/Passes.h"
#include "npcomp/Conversion/Passes.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry &registry) {
// clang-format off
registry.insert<mlir::NPCOMP::TorchConversion::TorchConversionDialect,
iree::IREEDialect>();
// clang-format on
registry.insert<mlir::NPCOMP::TorchConversion::TorchConversionDialect>();
}
void mlir::NPCOMP::registerAllPasses() {
mlir::NPCOMP::registerConversionPasses();
mlir::NPCOMP::registerTorchConversionPasses();
mlir::NPCOMP::IREEBackend::registerIREEBackendPasses();
mlir::NPCOMP::CommonBackend::registerCommonBackendPasses();
}

@ -1,7 +0,0 @@
def is_iree_enabled():
try:
import iree.runtime
import iree.compiler
except:
return False
return True

@ -1,104 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import os
import torch
import numpy as np
from npcomp.ir import *
from npcomp.passmanager import *
from npcomp.compiler.utils import logging
import iree.runtime as ireert
import iree.compiler as ireec
from .abc import NpcompBackend
__all__ = [
"IreeNpcompBackend",
]
class IreeModuleInvoker:
"""Wrapper around a native IREE module for calling functions."""
def __init__(self, iree_module):
super().__init__()
self._iree_module = iree_module
def __getattr__(self, function_name):
return self.__getitem__(function_name)
def __getitem__(self, function_name):
def invoke(*args):
results = self._iree_module[function_name](*args)
return results
invoke.__isnpcomp__ = True
return invoke
class TorchIreeModuleInvoker(IreeModuleInvoker):
"""Allows torch.Tensor inputs to be passed to module invocations."""
def __getitem__(self, function_name: str):
numpy_invoke = super().__getitem__(function_name)
def invoke(*args):
args = tuple(
arg.numpy() if isinstance(arg, torch.Tensor) else arg for arg in args)
return numpy_invoke(*args)
return invoke
class IreeNpcompBackend(NpcompBackend):
"""Main entry-point for the backend."""
def __init__(self):
super().__init__()
self._debug = logging.debug_enabled()
def compile(self, imported_module: Module):
"""Compiles an imported module, with a flat list of functions.
The module is expected to conform to the npcomp backend contract.
See the VerifyBackendContract pass for more details.
Args:
imported_module: The MLIR module consisting of funcs in the torch
dialect.
Returns:
An opaque, backend specific module object that can be passed to load.
The object may actually be something more specific to the backend (i.e.
for IREE, it is a serialized VM flatbuffer) but the contract is that
it is operated on by methods on this class.
"""
with imported_module.context as context:
if self._debug:
logging.debug("IR passed to IREE compiler backend:\n{}",
imported_module)
pipeline_str = "npcomp-backend-to-iree-frontend-pipeline"
if self._debug:
logging.debug("Running Prepare For IREE pipeline '{}'", pipeline_str)
pm = PassManager.parse(pipeline_str)
pm.run(imported_module)
if self._debug:
logging.debug(
"IREE Input IR (this is what IREE's compiler will see):\n{}",
imported_module)
# Backend.
binary = ireec.compile_str(str(imported_module),
target_backends=["dylib-llvm-aot"])
return binary
def load(self, iree_module) -> TorchIreeModuleInvoker:
"""Loads a compiled artifact into the runtime."""
vm_module = ireert.VmModule.from_flatbuffer(iree_module)
iree_config = ireert.Config(driver_name="dylib")
ctx = ireert.SystemContext(config=iree_config)
ctx.add_vm_module(vm_module)
return TorchIreeModuleInvoker(ctx.modules.module)

@ -1,7 +1,7 @@
// RUN: npcomp-opt -npcomp-verify-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
// CHECK: func @mm
func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> attributes {iree.module.export} {
func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%cst = constant 0.000000e+00 : f32

@ -1,21 +0,0 @@
// RUN: npcomp-opt <%s -convert-torch-to-iree -split-input-file -verify-diagnostics | FileCheck %s
// XFAIL: *
// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[ARG_TORCH:.*]]: !torch.float) -> !torch.list<!torch.float> {
// CHECK: %[[ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]
// CHECK: %[[ALSO_ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[LIST:.*]] = iree.list.create %[[C2]] : !iree.list<f64>
// CHECK: iree.list.resize %[[LIST]], %[[C2]] : !iree.list<f64>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: iree.list.set %[[LIST]][%[[C0]]], %[[ARG]] : !iree.list<f64>, f64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: iree.list.set %[[LIST]][%[[C1]]], %[[ALSO_ARG]] : !iree.list<f64>, f64
// CHECK: %[[LIST_TORCH:.*]] = torch_c.from_iree_list %[[LIST]] : !iree.list<f64> -> !torch.list<!torch.float>
// CHECK: return %[[LIST_TORCH]] : !torch.list<!torch.float>
builtin.func @forward(%arg0: !torch.float) -> !torch.list<!torch.float> {
%0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.float, !torch.float) -> !torch.list<!torch.float>
return %0 : !torch.list<!torch.float>
}

@ -1,48 +0,0 @@
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s -torch-annotate-abi
// -----
// CHECK-LABEL: builtin.func @basic_arg_and_ret(
// CHECK-SAME: attributes {iree.abi = "{\22a\22:[\22f64\22,\22i64\22,\22i1\22],\22r\22:[\22f64\22,\22i64\22,\22i1\22],\22v\22:1}"} {
builtin.func @basic_arg_and_ret(%arg0: !torch.float, %arg1: !torch.int, %arg2: !torch.bool) -> (!torch.float, !torch.int, !torch.bool) {
return %arg0, %arg1, %arg2 : !torch.float, !torch.int, !torch.bool
}
// -----
// CHECK-LABEL: builtin.func @list(
// CHECK-SAME: attributes {iree.abi = "{\22a\22:{{\[\[}}\22py_uniform_list\22,\22f64\22]],\22r\22:[],\22v\22:1}"} {
builtin.func @list(%arg0: !torch.list<!torch.float>) {
return
}
// -----
// CHECK-LABEL: builtin.func @tuple(
// CHECK-SAME: attributes {iree.abi = "{\22a\22:{{\[\[}}\22pytuple\22,\22f64\22,\22i64\22]],\22r\22:[],\22v\22:1}"} {
builtin.func @tuple(%arg0: !torch.tuple<!torch.float, !torch.int>) {
return
}
// -----
// CHECK-LABEL: builtin.func @dict(
// CHECK-SAME: attributes {iree.abi = "{\22a\22:{{\[\[}}\22py_uniform_dict\22,\22pystr\22,\22f64\22]],\22r\22:[],\22v\22:1}"} {
builtin.func @dict(%arg0: !torch.dict<!torch.str, !torch.float>) {
return
}
// -----
// CHECK-LABEL: builtin.func @tensor(
// CHECK-SAME: attributes {iree.abi = "{\22a\22:{{\[\[}}\22ndarray\22,\22f32\22,2,null,3]],\22r\22:[],\22v\22:1}"} {
builtin.func @tensor(%arg0: !torch.tensor<[?,3],f32>) {
return
}
// -----
// expected-error @+1 {{at function argument 0: unimplemented: ABI annotation for type '!torch.any'}}
builtin.func @unsupported(%arg0: !torch.any) {
return
}