mirror of https://github.com/llvm/torch-mlir
Add TMTensor dialect to torch-mlir
This is intended to explore support for non-structured ops that can't be modeled by Linalg dialect. `tm_tensor.scan` and `tm_tensor.scatter` are added as the first such ops. The dialect should aim to be upstreamed in the future.pull/603/head
parent
cd21dda867
commit
869daf3c22
|
@ -53,8 +53,9 @@ jobs:
|
||||||
-DPython3_EXECUTABLE=$(which python) \
|
-DPython3_EXECUTABLE=$(which python) \
|
||||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||||
-DLLVM_ENABLE_PROJECTS=mlir \
|
-DLLVM_ENABLE_PROJECTS=mlir \
|
||||||
-DLLVM_EXTERNAL_PROJECTS=torch-mlir \
|
-DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
|
||||||
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \
|
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \
|
||||||
|
-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/external/llvm-external-projects/torch-mlir-dialects" \
|
||||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
||||||
-DLLVM_TARGETS_TO_BUILD=host
|
-DLLVM_TARGETS_TO_BUILD=host
|
||||||
ninja check-torch-mlir-all
|
ninja check-torch-mlir-all
|
||||||
|
|
|
@ -25,6 +25,22 @@ project(torch-mlir LANGUAGES CXX C)
|
||||||
set(CMAKE_C_STANDARD 11)
|
set(CMAKE_C_STANDARD 11)
|
||||||
set(CMAKE_CXX_STANDARD 14)
|
set(CMAKE_CXX_STANDARD 14)
|
||||||
|
|
||||||
|
macro(torch_mlir_add_llvm_external_project name identifier location)
|
||||||
|
message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}")
|
||||||
|
if(NOT EXISTS "${location}/CMakeLists.txt")
|
||||||
|
message(FATAL_ERROR "External project location ${location} is not valid")
|
||||||
|
endif()
|
||||||
|
list(APPEND LLVM_EXTERNAL_PROJECTS ${name})
|
||||||
|
list(REMOVE_DUPLICATES LLVM_EXTERNAL_PROJECTS)
|
||||||
|
set(LLVM_EXTERNAL_${identifier}_SOURCE_DIR ${location} CACHE STRING "" FORCE)
|
||||||
|
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
|
||||||
|
endmacro()
|
||||||
|
|
||||||
|
torch_mlir_add_llvm_external_project(
|
||||||
|
torch-mlir-dialects
|
||||||
|
TORCH_MLIR_DIALECTS
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/external/llvm-external-projects/torch-mlir-dialects)
|
||||||
|
|
||||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||||
# Out-of-tree build
|
# Out-of-tree build
|
||||||
|
|
||||||
|
@ -129,6 +145,7 @@ add_subdirectory(tools)
|
||||||
add_custom_target(check-torch-mlir-all)
|
add_custom_target(check-torch-mlir-all)
|
||||||
add_dependencies(check-torch-mlir-all
|
add_dependencies(check-torch-mlir-all
|
||||||
check-torch-mlir
|
check-torch-mlir
|
||||||
|
check-torch-mlir-dialects
|
||||||
)
|
)
|
||||||
|
|
||||||
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||||
|
|
|
@ -62,8 +62,9 @@ cmake -GNinja -Bbuild \
|
||||||
-DCMAKE_CXX_COMPILER=clang++ \
|
-DCMAKE_CXX_COMPILER=clang++ \
|
||||||
-DPython3_FIND_VIRTUALENV=ONLY \
|
-DPython3_FIND_VIRTUALENV=ONLY \
|
||||||
-DLLVM_ENABLE_PROJECTS=mlir \
|
-DLLVM_ENABLE_PROJECTS=mlir \
|
||||||
-DLLVM_EXTERNAL_PROJECTS=torch-mlir \
|
-DLLVM_EXTERNAL_PROJECTS=torch-mlir;torch-mlir-dialects \
|
||||||
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
|
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
|
||||||
|
-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=`pwd`/external/llvm-external-projects/torch-mlir-dialects \
|
||||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
||||||
-DLLVM_TARGETS_TO_BUILD=host \
|
-DLLVM_TARGETS_TO_BUILD=host \
|
||||||
external/llvm-project/llvm
|
external/llvm-project/llvm
|
||||||
|
|
|
@ -19,8 +19,9 @@ cmake -GNinja -B"$build_dir" "$llvm_project_dir/llvm" \
|
||||||
-DCMAKE_BUILD_TYPE=Release \
|
-DCMAKE_BUILD_TYPE=Release \
|
||||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
|
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
|
||||||
-DLLVM_ENABLE_PROJECTS=mlir \
|
-DLLVM_ENABLE_PROJECTS=mlir \
|
||||||
-DLLVM_EXTERNAL_PROJECTS=torch-mlir \
|
-DLLVM_EXTERNAL_PROJECTS=torch-mlir;torch-mlir-dialects \
|
||||||
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$project_dir" \
|
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$project_dir" \
|
||||||
|
-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=${project_dir}/external/llvm-external-projects/torch-mlir-dialects \
|
||||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
||||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||||
-DLLVM_TARGETS_TO_BUILD=host
|
-DLLVM_TARGETS_TO_BUILD=host
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||||
|
message(FATAL_ERROR
|
||||||
|
"This project is intended to be built as part of LLVM via "
|
||||||
|
"-DLLVM_EXTERNAL_PROJECTS=torch-mlir-dialects "
|
||||||
|
"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
|
||||||
|
|
||||||
|
set(TORCH_MLIR_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
|
set(TORCH_MLIR_DIALECTS_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
|
||||||
|
message(STATUS "Building torch-mlir-dialects project at ${TORCH_MLIR_DIALECTS_SOURCE_DIR} (into ${TORCH_MLIR_DIALECTS_BINARY_DIR})")
|
||||||
|
|
||||||
|
# TODO: Fix this upstream so that global include directories are not needed.
|
||||||
|
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
|
||||||
|
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
|
||||||
|
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
|
||||||
|
|
||||||
|
# TODO: Needed for tablegen. Remove.
|
||||||
|
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
|
||||||
|
include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR})
|
||||||
|
include_directories(SYSTEM ${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include)
|
||||||
|
|
||||||
|
function(torch_mlir_dialects_target_includes target)
|
||||||
|
set(_dirs
|
||||||
|
$<BUILD_INTERFACE:${MLIR_INCLUDE_DIR}>
|
||||||
|
$<BUILD_INTERFACE:${MLIR_GENERATED_INCLUDE_DIR}>
|
||||||
|
$<BUILD_INTERFACE:${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include>
|
||||||
|
$<BUILD_INTERFACE:${TORCH_MLIR_DIALECTS_BINARY_DIR}/include>
|
||||||
|
)
|
||||||
|
# In LLVM parlance, the actual target may just be an interface and may not
|
||||||
|
# be responsible for actually compiling anything. The corresponding obj.
|
||||||
|
# target, when present, is just used for compilation and does not
|
||||||
|
# contribute to the interface properties.
|
||||||
|
# TODO: Normalize this upstream.
|
||||||
|
target_include_directories(${target} PUBLIC ${_dirs})
|
||||||
|
if(TARGET obj.${target})
|
||||||
|
target_include_directories(obj.${target} PRIVATE ${_dirs})
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
# Configure CMake and tablegen.
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
|
||||||
|
set(MLIR_TABLEGEN_EXE mlir-tblgen)
|
||||||
|
|
||||||
|
include(TableGen)
|
||||||
|
include(AddLLVM)
|
||||||
|
include(AddMLIR)
|
||||||
|
|
||||||
|
add_subdirectory(include)
|
||||||
|
add_subdirectory(lib)
|
||||||
|
add_subdirectory(tools)
|
||||||
|
add_subdirectory(test)
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(torch-mlir-dialects)
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(Dialect)
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(TMTensor)
|
|
@ -0,0 +1,2 @@
|
||||||
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
|
@ -0,0 +1,33 @@
|
||||||
|
function(_add_interfaces)
|
||||||
|
set(LLVM_TARGET_DEFINITIONS TMTensorInterfaces.td)
|
||||||
|
mlir_tablegen(TMTensorOpInterfaces.h.inc -gen-op-interface-decls)
|
||||||
|
mlir_tablegen(TMTensorOpInterfaces.cpp.inc -gen-op-interface-defs)
|
||||||
|
mlir_tablegen(TMTensorTypeInterfaces.h.inc -gen-type-interface-decls)
|
||||||
|
mlir_tablegen(TMTensorTypeInterfaces.cpp.inc -gen-type-interface-defs)
|
||||||
|
add_public_tablegen_target(TorchMLIRTMTensorInterfacesIncGen)
|
||||||
|
add_dependencies(TorchMLIRTMTensorOpsIncGen TorchMLIRTMTensorInterfacesIncGen)
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
function(_add_scalar_loop_op_interface)
|
||||||
|
set(LLVM_TARGET_DEFINITIONS ScalarLoopOpInterface.td)
|
||||||
|
mlir_tablegen(ScalarLoopOpInterface.h.inc -gen-op-interface-decls)
|
||||||
|
mlir_tablegen(ScalarLoopOpInterface.cpp.inc -gen-op-interface-defs)
|
||||||
|
add_public_tablegen_target(TorchMLIRTMTensorScalarLoopOpInterfaceIncGen)
|
||||||
|
add_dependencies(TorchMLIRTMTensorOpsIncGen TorchMLIRTMTensorScalarLoopOpInterfaceIncGen)
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
function(_add_dialect)
|
||||||
|
set(LLVM_TARGET_DEFINITIONS TMTensorOps.td)
|
||||||
|
mlir_tablegen(TMTensorOps.h.inc -gen-op-decls)
|
||||||
|
mlir_tablegen(TMTensorOps.cpp.inc -gen-op-defs)
|
||||||
|
mlir_tablegen(TMTensorTypes.h.inc -gen-typedef-decls)
|
||||||
|
mlir_tablegen(TMTensorTypes.cpp.inc -gen-typedef-defs)
|
||||||
|
mlir_tablegen(TMTensorDialect.h.inc -gen-dialect-decls -dialect=tm_tensor)
|
||||||
|
mlir_tablegen(TMTensorDialect.cpp.inc -gen-dialect-defs -dialect=tm_tensor)
|
||||||
|
add_public_tablegen_target(TorchMLIRTMTensorOpsIncGen)
|
||||||
|
add_dependencies(mlir-headers TorchMLIRTMTensorOpsIncGen)
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
_add_dialect()
|
||||||
|
_add_interfaces()
|
||||||
|
_add_scalar_loop_op_interface()
|
|
@ -0,0 +1,29 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
|
||||||
|
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
/// Include the ODS generated interface header files.
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
namespace TMTensor {} // namespace TMTensor
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
|
|
@ -0,0 +1,79 @@
|
||||||
|
//===-------------------------------------------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_SCALARLOOPOPINTERFACE
|
||||||
|
#define TORCH_MLIR_DIALECT_TMTENSOR_SCALARLOOPOPINTERFACE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def ScalarLoopOpInterface : OpInterface<"ScalarLoopOpInterface"> {
|
||||||
|
let description = [{
|
||||||
|
Interface for allowing operations to expose information needed to
|
||||||
|
lower it to for loops
|
||||||
|
}];
|
||||||
|
let cppNamespace = "::mlir::torch::TMTensor";
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Returns the destination operands. For op with `memref`
|
||||||
|
operands, this is the result buffers. For op with `tensor`
|
||||||
|
operands, this is the operands that contain the initial
|
||||||
|
value for the result. These are "tied" to the result
|
||||||
|
buffers. For example, for a `LinalgOp`/`TMTensor` ops, it
|
||||||
|
is the `outs` parameters. For `tensor.insert_slice`, it is
|
||||||
|
the `dest` parameter.
|
||||||
|
}],
|
||||||
|
/*retType=*/"SmallVector<Value>",
|
||||||
|
/*methodName=*/"getDestinationOperands",
|
||||||
|
/*args=*/(ins "OpBuilder &":$b),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/"return ValueRange{};"
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Returns a list of `StringRef`s that describe the number of
|
||||||
|
loops and the iterator types of the operation. The list is
|
||||||
|
expected to use
|
||||||
|
`getParallelIteratorTypeName()`/`getReductionIteratorTypeName()`
|
||||||
|
from MLIR Structured Op Utils.
|
||||||
|
}],
|
||||||
|
/*retType=*/"SmallVector<StringRef>",
|
||||||
|
/*methodName=*/"getLoopIteratorTypes"
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Returns a list of ranges that describe the loop bounds and
|
||||||
|
step for the loops of the operation.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"SmallVector<Range>",
|
||||||
|
/*methodName=*/"getIterationDomain",
|
||||||
|
/*args=*/(ins "OpBuilder &":$b)
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Generates the loop body implementation. Assume that all the parallel
|
||||||
|
loops and reduction loops are created and the insertion point of the
|
||||||
|
build is set to the innermost of the loop. This method implements the
|
||||||
|
loop body IRs.
|
||||||
|
}],
|
||||||
|
/*retType=*/"LogicalResult",
|
||||||
|
/*methodName=*/"generateScalarImplementation",
|
||||||
|
/*args=*/(ins
|
||||||
|
"OpBuilder &":$b,
|
||||||
|
"Location ":$loc,
|
||||||
|
"ValueRange ":$ivs),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
return failure();
|
||||||
|
}]
|
||||||
|
>
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECT_TMTENSOR_SCALARLOOPOPINTERFACES
|
|
@ -0,0 +1,59 @@
|
||||||
|
//===-------------------------------------------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_BASE
|
||||||
|
#define TORCH_MLIR_DIALECT_TMTENSOR_BASE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect definition
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def TMTensor_Dialect : Dialect {
|
||||||
|
let name = "tm_tensor";
|
||||||
|
let cppNamespace = "::mlir::torch::TMTensor";
|
||||||
|
let description = [{
|
||||||
|
The tm_tensor (tm = torch-mlir) dialect is a temporary staging ground in
|
||||||
|
the torch-mlir project for a set of widely-accepted tensor compute
|
||||||
|
operations that are not well-served by existing representations in MLIR
|
||||||
|
upstream. These ops are currently heavily inspired by the linalg_ext
|
||||||
|
dialect (which itself is heavily inspired by the structured ops of the
|
||||||
|
linalg dialect). But while linalg_ext is meant to power specific codegen
|
||||||
|
transformations, the tm_tensor dialect is a much more pure "interface
|
||||||
|
dialect" agnostic to any particular set of transformations applied to
|
||||||
|
the operations. We simply require a way to name the specified operations
|
||||||
|
for interchange between projects, without taking strong opinions on the
|
||||||
|
mechanics of transformations.
|
||||||
|
|
||||||
|
The dialect does include interfaces to generate scalar reference code for
|
||||||
|
the operations, which simultaneously provides a precise definition of their
|
||||||
|
semantics, and aids in producing executable reference implementations of
|
||||||
|
the operations.
|
||||||
|
|
||||||
|
The goal of this dialect is to eventually either be upstreamed or to be
|
||||||
|
subsumed by functionality included by upstream MLIR. It should also be kept
|
||||||
|
consistent with the linalg_ext dialect unless there is a good reason not
|
||||||
|
to.
|
||||||
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Type definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class RankedTensorOrMemRefOf<list<Type> allowedTypes> :
|
||||||
|
ShapedContainerType<allowedTypes,
|
||||||
|
Or<[IsMemRefTypePred, And<[IsTensorTypePred, HasRankPred]>]>,
|
||||||
|
"ranked tensor or memref", "::mlir::ShapedType">;
|
||||||
|
|
||||||
|
def AnyRankedTensorOrMemRefType : RankedTensorOrMemRefOf<[AnyType]>;
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECT_TMTENSOR_BASE
|
|
@ -0,0 +1,20 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_
|
||||||
|
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
|
||||||
|
// clang-format off: must be included after all LLVM/MLIR headers
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h.inc" // IWYU pragma: keep
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_
|
|
@ -0,0 +1,42 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
|
||||||
|
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
namespace TMTensor {
|
||||||
|
class TMTensorOp;
|
||||||
|
|
||||||
|
/// OpOperand vector that implicitly converts to a Value vector.
|
||||||
|
struct OpOperandVector : public SmallVector<OpOperand *> {
|
||||||
|
operator SmallVector<Value>();
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
LogicalResult verifyTMTensorOpInterface(Operation *op);
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
|
||||||
|
|
||||||
|
/// Include the generated interface declarations.
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export
|
||||||
|
|
||||||
|
} // namespace TMTensor
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
|
|
@ -0,0 +1,493 @@
|
||||||
|
//===-------------------------------------------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_INTERFACES
|
||||||
|
#define TORCH_MLIR_DIALECT_TMTENSOR_INTERFACES
|
||||||
|
|
||||||
|
include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td"
|
||||||
|
|
||||||
|
// The interface is a subset of LinalgStructuredInterface.
|
||||||
|
def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
||||||
|
let methods = [
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
// Num input/output arguments handling.
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
// `inputs` must be defined by each op that wants to implement the
|
||||||
|
// LinalgStructuredInterface.
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the input shape operands.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"ValueRange",
|
||||||
|
/*methodName=*/"inputs",
|
||||||
|
/*args=*/(ins)
|
||||||
|
>,
|
||||||
|
// These special methods rely on `inputs` and `outputs` being defined by
|
||||||
|
// each op that wants to implement the LinalgStructuredInterface.
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the number of inputs.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"int64_t",
|
||||||
|
/*methodName=*/"getNumInputs",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
return $_op.inputs().size();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
// `outputs` must be defined by each op that wants to implement the
|
||||||
|
// LinalgStructuredInterface.
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the output shape operands.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"ValueRange",
|
||||||
|
/*methodName=*/"outputs",
|
||||||
|
/*args=*/(ins)
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the number of outputs.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"int64_t",
|
||||||
|
/*methodName=*/"getNumOutputs",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
return $_op.outputs().size();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the number of inputs and outputs.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"int64_t",
|
||||||
|
/*methodName=*/"getNumInputsAndOutputs",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
return getNumInputs() + getNumOutputs();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
// Input operands handling.
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the input operands.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"OpOperandVector",
|
||||||
|
/*methodName=*/"getInputOperands",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
int64_t numInputs = getNumInputs();
|
||||||
|
OpOperandVector result;
|
||||||
|
result.reserve(numInputs);
|
||||||
|
llvm::transform(
|
||||||
|
this->getOperation()->getOpOperands().take_front(numInputs),
|
||||||
|
std::back_inserter(result),
|
||||||
|
[](OpOperand &opOperand) { return &opOperand; });
|
||||||
|
return result;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the `i`-th input operand.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"OpOperand*",
|
||||||
|
/*methodName=*/"getInputOperand",
|
||||||
|
/*args=*/(ins "int64_t":$i),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
assert(i >= 0 && i < getNumInputs());
|
||||||
|
return &this->getOperation()->getOpOperand(i);
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the subset of input operands that are of buffer type.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"OpOperandVector",
|
||||||
|
/*methodName=*/"getInputBufferOperands",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
OpOperandVector result;
|
||||||
|
result.reserve(getNumInputs());
|
||||||
|
llvm::copy_if(getInputOperands(),
|
||||||
|
std::back_inserter(result),
|
||||||
|
[](OpOperand *opOperand) {
|
||||||
|
return opOperand->get().getType().template isa<MemRefType>();
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the subset of input operands that are of tensor type.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"OpOperandVector",
|
||||||
|
/*methodName=*/"getInputTensorOperands",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
OpOperandVector result;
|
||||||
|
result.reserve(getNumInputs());
|
||||||
|
llvm::copy_if(getInputOperands(),
|
||||||
|
std::back_inserter(result),
|
||||||
|
[](OpOperand *opOperand) {
|
||||||
|
return opOperand->get().getType().template isa<RankedTensorType>();
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
// Output operands handling.
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the output operands.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"OpOperandVector",
|
||||||
|
/*methodName=*/"getOutputOperands",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
int64_t numOutputs = getNumOutputs();
|
||||||
|
OpOperandVector result;
|
||||||
|
result.reserve(numOutputs);
|
||||||
|
llvm::transform(
|
||||||
|
this->getOperation()->getOpOperands()
|
||||||
|
.drop_front(getNumInputs())
|
||||||
|
.take_front(numOutputs),
|
||||||
|
std::back_inserter(result),
|
||||||
|
[](OpOperand &opOperand) { return &opOperand; });
|
||||||
|
return result;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the `i`-th output operand.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"OpOperand*",
|
||||||
|
/*methodName=*/"getOutputOperand",
|
||||||
|
/*args=*/(ins "int64_t":$i),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
assert(i >= 0 && i < getNumOutputs());
|
||||||
|
return &this->getOperation()->getOpOperand(getNumInputs() + i);
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the subset of output operands that are of buffer type.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"OpOperandVector",
|
||||||
|
/*methodName=*/"getOutputBufferOperands",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
OpOperandVector result;
|
||||||
|
result.reserve(getNumOutputs());
|
||||||
|
llvm::copy_if(getOutputOperands(),
|
||||||
|
std::back_inserter(result),
|
||||||
|
[](OpOperand *opOperand) {
|
||||||
|
return opOperand->get().getType().template isa<MemRefType>();
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the subset of output operands that are of tensor type.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"OpOperandVector",
|
||||||
|
/*methodName=*/"getOutputTensorOperands",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
OpOperandVector result;
|
||||||
|
result.reserve(getNumOutputs());
|
||||||
|
llvm::copy_if(getOutputOperands(),
|
||||||
|
std::back_inserter(result),
|
||||||
|
[](OpOperand *opOperand) {
|
||||||
|
return opOperand->get().getType().template isa<RankedTensorType>();
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the types of the subset of output operands that are of buffer type.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"SmallVector<MemRefType>",
|
||||||
|
/*methodName=*/"getOutputBufferTypes",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
SmallVector<MemRefType> result;
|
||||||
|
result.reserve(getNumOutputs());
|
||||||
|
llvm::transform(getOutputBufferOperands(),
|
||||||
|
std::back_inserter(result),
|
||||||
|
[](OpOperand *opOperands) {
|
||||||
|
return opOperands->get().getType().cast<MemRefType>();
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the types of the subset of output operands that are of tensor type.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"SmallVector<RankedTensorType>",
|
||||||
|
/*methodName=*/"getOutputTensorTypes",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
SmallVector<RankedTensorType> result;
|
||||||
|
result.reserve(getNumOutputs());
|
||||||
|
llvm::transform(getOutputTensorOperands(),
|
||||||
|
std::back_inserter(result),
|
||||||
|
[](OpOperand *opOperands) {
|
||||||
|
return opOperands->get().getType().cast<RankedTensorType>();
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
// Input and Output arguments handling.
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the range over input and output operands.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"OpOperandVector",
|
||||||
|
/*methodName=*/"getInputAndOutputOperands",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
int64_t numInputsAndOutputs = getNumInputsAndOutputs();
|
||||||
|
OpOperandVector result;
|
||||||
|
result.reserve(numInputsAndOutputs);
|
||||||
|
llvm::transform(
|
||||||
|
this->getOperation()->getOpOperands()
|
||||||
|
.take_front(numInputsAndOutputs),
|
||||||
|
std::back_inserter(result),
|
||||||
|
[](OpOperand &opOperand) { return &opOperand; });
|
||||||
|
return result;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return true if the payload uses the value loaded from `opOperand`. This
|
||||||
|
is useful to avoid loading from "write-only" memory that may be
|
||||||
|
uninitialized, as well as properly cloning "read-write" operands.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"payloadUsesValueFromOperand",
|
||||||
|
/*args=*/(ins "OpOperand *":$opOperand),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
unsigned bbArgNumber = opOperand->getOperandNumber();
|
||||||
|
// Safeguard against the named linalg ops that are manually defined and
|
||||||
|
// that only support buffer semantics: we should not be there.
|
||||||
|
// Such ops have an empty regionBuilder and are not constructed with a
|
||||||
|
// region for now. In the future they are slated to disappear.
|
||||||
|
assert(this->getOperation()->getNumRegions() == 1 && "unexpected "
|
||||||
|
"missing region (calling `payloadUsesValueFromOperand` on "
|
||||||
|
"manually defined named Linalg op?)");
|
||||||
|
Block &block = this->getOperation()->getRegion(0).front();
|
||||||
|
// Init tensors have uses.
|
||||||
|
return !block.getArgument(bbArgNumber).use_empty();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return true if `opOperand` is an input tensor.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"isInputTensor",
|
||||||
|
/*args=*/(ins "OpOperand *":$opOperand),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
if (!opOperand->get().getType().template isa<RankedTensorType>())
|
||||||
|
return false;
|
||||||
|
if (opOperand->getOperandNumber() < $_op.getNumInputs())
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return true if `opOperand` is an output tensor.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"isOutputTensor",
|
||||||
|
/*args=*/(ins "OpOperand *":$opOperand),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
if (!opOperand->get().getType().template isa<RankedTensorType>())
|
||||||
|
return false;
|
||||||
|
if (opOperand->getOperandNumber() >= $_op.getNumInputs())
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return true if `opOperand` is an init tensor. This is true when it is
|
||||||
|
an output tensor operand whose value is used in the payload region.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"isInitTensor",
|
||||||
|
/*args=*/(ins "OpOperand *":$opOperand),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
if (!$_op.isOutputTensor(opOperand))
|
||||||
|
return false;
|
||||||
|
return payloadUsesValueFromOperand(opOperand);
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the `opOperand` rank or zero for scalars.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"int64_t",
|
||||||
|
/*methodName=*/"getRank",
|
||||||
|
/*args=*/(ins "OpOperand*":$opOperand),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
assert(opOperand->getOwner() == this->getOperation());
|
||||||
|
if (auto shapedType =
|
||||||
|
opOperand->get().getType().template dyn_cast<ShapedType>())
|
||||||
|
return shapedType.getRank();
|
||||||
|
return 0;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return the `opOperand` shape or an empty vector for scalars.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"ArrayRef<int64_t>",
|
||||||
|
/*methodName=*/"getShape",
|
||||||
|
/*args=*/(ins "OpOperand*":$opOperand),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
assert(opOperand->getOwner() == this->getOperation());
|
||||||
|
if (auto shapedType =
|
||||||
|
opOperand->get().getType().template dyn_cast<ShapedType>())
|
||||||
|
return shapedType.getShape();
|
||||||
|
return {};
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return true if the `opOperand` is a scalar value.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"isScalar",
|
||||||
|
/*args=*/(ins "OpOperand*":$opOperand),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
assert(opOperand->getOwner() == this->getOperation());
|
||||||
|
return !opOperand->get().getType().template isa<ShapedType>();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
// Other interface methods.
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return whether the op has only MemRef input and outputs.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"hasBufferSemantics",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
return this->getOperation()->getNumResults() == 0 &&
|
||||||
|
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
|
||||||
|
return isScalar(opOperand) ||
|
||||||
|
opOperand->get().getType().template isa<MemRefType>();
|
||||||
|
}) &&
|
||||||
|
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
|
||||||
|
return opOperand->get().getType().template isa<MemRefType>();
|
||||||
|
});
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Return whether the op has only RankedTensor input and outputs.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"hasTensorSemantics",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
return
|
||||||
|
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
|
||||||
|
return isScalar(opOperand) ||
|
||||||
|
opOperand->get().getType().template isa<RankedTensorType>();
|
||||||
|
}) &&
|
||||||
|
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
|
||||||
|
return opOperand->get().getType().template isa<RankedTensorType>();
|
||||||
|
});
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
// Other static interface methods.
|
||||||
|
//===------------------------------------------------------------------===//
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
Clone the current operation with the given location and operands. This
|
||||||
|
is used to abstract away the optional underlying region creation. This
|
||||||
|
does not change the balance between input, output_buffer and
|
||||||
|
init_tensors operands.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"Operation *",
|
||||||
|
/*methodName=*/"clone",
|
||||||
|
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
|
||||||
|
"ValueRange":$operands),
|
||||||
|
[{
|
||||||
|
BlockAndValueMapping bvm;
|
||||||
|
OperationState state(
|
||||||
|
loc, ConcreteOp::getOperationName(), operands, resultTypes,
|
||||||
|
$_op->getAttrs());
|
||||||
|
for (Region &r : $_op->getRegions())
|
||||||
|
r.cloneInto(state.addRegion(), bvm);
|
||||||
|
return b.createOperation(state);
|
||||||
|
}]
|
||||||
|
>
|
||||||
|
];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
//========================================================================//
|
||||||
|
// Helper functions to mutate the `operand_segment_sizes` attribute.
|
||||||
|
// These are useful when cloning and changing operand types.
|
||||||
|
//========================================================================//
|
||||||
|
void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
|
||||||
|
void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void setOperandSegmentAt(unsigned idx, unsigned val) {
|
||||||
|
auto attr = (*this)->getAttr("operand_segment_sizes")
|
||||||
|
.cast<DenseIntElementsAttr>();
|
||||||
|
unsigned i = 0;
|
||||||
|
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
|
||||||
|
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
|
||||||
|
getOperation()->setAttr("operand_segment_sizes", newAttr);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let verify = [{ return detail::verifyTMTensorOpInterface($_op); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECT_TMTENSOR_INTERFACES
|
|
@ -0,0 +1,23 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
|
||||||
|
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
/// Include the ODS generated interface header files.
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc"
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
|
|
@ -0,0 +1,41 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSOROPS_H_
|
||||||
|
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSOROPS_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
namespace TMTensor {
|
||||||
|
|
||||||
|
/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
|
||||||
|
/// `dim`.
|
||||||
|
Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim);
|
||||||
|
|
||||||
|
/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
|
||||||
|
/// `dim`. If the shape is constant, returns the shape as an `IntegerAttr`.
|
||||||
|
OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim);
|
||||||
|
|
||||||
|
} // namespace TMTensor
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSOROPS_H_
|
|
@ -0,0 +1,205 @@
|
||||||
|
//===-------------------------------------------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_OPS
|
||||||
|
#define TORCH_MLIR_DIALECT_TMTENSOR_OPS
|
||||||
|
|
||||||
|
include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td"
|
||||||
|
include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td"
|
||||||
|
include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.td"
|
||||||
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Base class.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class TMTensor_PureOp<string mnemonic, list<OpTrait> traits = []> :
|
||||||
|
Op<TMTensor_Dialect, mnemonic, traits> {
|
||||||
|
}
|
||||||
|
|
||||||
|
class TMTensor_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
|
TMTensor_PureOp<mnemonic, !listconcat(traits,
|
||||||
|
[AttrSizedOperandSegments,
|
||||||
|
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||||
|
TMTensorInterface,
|
||||||
|
SingleBlockImplicitTerminator<"::mlir::torch::TMTensor::YieldOp">
|
||||||
|
])> {
|
||||||
|
let verifier = [{ return verify$cppClass(*this); }];
|
||||||
|
let printer = [{ return print$cppClass(p, *this); }];
|
||||||
|
let parser = [{ return parse$cppClass(parser, result); }];
|
||||||
|
code extraTMTensorOpClassDeclaration = [{
|
||||||
|
SmallVector<Value> getDestinationOperands(OpBuilder &b) {
|
||||||
|
SmallVector<Value> dest(outputs().begin(), outputs().end());
|
||||||
|
return dest;
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Non-structured ops
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def TMTensor_ScanOp : TMTensor_Op<"scan"
|
||||||
|
,[DeclareOpInterfaceMethods<ScalarLoopOpInterface,
|
||||||
|
["generateScalarImplementation"]>]> {
|
||||||
|
let summary = "Scan operator";
|
||||||
|
let description = [{
|
||||||
|
Computes the inclusive/exclusive scan along a given dimension.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyShaped>:$inputs,
|
||||||
|
Variadic<AnyShaped>:$outputs,
|
||||||
|
I64Attr:$dimension,
|
||||||
|
BoolAttr:$inclusive
|
||||||
|
);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
|
||||||
|
CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive)>
|
||||||
|
];
|
||||||
|
|
||||||
|
let results = (outs Variadic<AnyRankedTensor>:$results);
|
||||||
|
let regions = (region AnyRegion:$region);
|
||||||
|
let hasFolder = 1;
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`dimension` `(` $dimension `)`
|
||||||
|
`inclusive` `(` $inclusive `)`
|
||||||
|
attr-dict
|
||||||
|
`ins` `(` $inputs `:` type($inputs) `)`
|
||||||
|
`outs` `(` $outputs `:` type($outputs) `)`
|
||||||
|
$region (`->` type($results)^)?
|
||||||
|
}];
|
||||||
|
|
||||||
|
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
|
||||||
|
Value input() {
|
||||||
|
return getInputOperand(0)->get();
|
||||||
|
}
|
||||||
|
Value accumulator() {
|
||||||
|
return getOutputOperand(1)->get();
|
||||||
|
}
|
||||||
|
Value output() {
|
||||||
|
return getOutputOperand(0)->get();
|
||||||
|
}
|
||||||
|
ShapedType getOperandType() {
|
||||||
|
return input().getType().cast<ShapedType>();
|
||||||
|
}
|
||||||
|
int64_t getOperandRank() {
|
||||||
|
return getOperandType().getRank();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TMTensor_ScatterOp : TMTensor_Op<"scatter",
|
||||||
|
[DeclareOpInterfaceMethods<ScalarLoopOpInterface,
|
||||||
|
["generateScalarImplementation"]>]> {
|
||||||
|
let summary = "Scatter operator";
|
||||||
|
let description = [{
|
||||||
|
Based on XLA operation semantics, takes two `inputs` (`update` and
|
||||||
|
`indices`) and `outputs` value (`original`). The operation updates
|
||||||
|
the value at the slices specified by `indices` by combining the
|
||||||
|
current value with the value in `updates` using the computation
|
||||||
|
specified in `region`. The `region` specifies a binary operation
|
||||||
|
of signature (T, T) -> T, where `T` is the element-type of
|
||||||
|
`updates` (and `original`). The first argument correspond the
|
||||||
|
value to be updated (i.e. from `updates`), and the second the
|
||||||
|
current value (i.e. value from `original`).
|
||||||
|
|
||||||
|
The `indices` is a 2D tensor/memref type. The first dim is the number of
|
||||||
|
updates, and the second dim is index depth. The index depth should always be
|
||||||
|
static.
|
||||||
|
|
||||||
|
The first dim of `updates` and `indices` is identical, since they represent
|
||||||
|
the number of updates.
|
||||||
|
|
||||||
|
The rank of the `original`/`result` is `index_depth + rank(%updates) - 1`.
|
||||||
|
The first `index_depth` indices are derived from `indices` and the shape of
|
||||||
|
update value must match the rest shape of `original`.
|
||||||
|
|
||||||
|
The shapes definition follows tensorflow operations execept that it force
|
||||||
|
batch dims to be 1D. See more information in
|
||||||
|
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
Variadic<AnyRankedTensorOrMemRefType>:$inputs,
|
||||||
|
Variadic<AnyRankedTensorOrMemRefType>:$outputs
|
||||||
|
);
|
||||||
|
let results = (outs Variadic<AnyRankedTensor>:$results);
|
||||||
|
let regions = (region AnyRegion:$region);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
|
||||||
|
`outs` `(` $outputs `:` type($outputs) `)`
|
||||||
|
$region (`->` type($results)^)?
|
||||||
|
}];
|
||||||
|
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
|
||||||
|
|
||||||
|
int64_t getIndexDepth() {
|
||||||
|
return getInputOperand(1)
|
||||||
|
->get()
|
||||||
|
.getType()
|
||||||
|
.cast<ShapedType>()
|
||||||
|
.getShape()
|
||||||
|
.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value updates() {
|
||||||
|
return getInputOperand(0)->get();
|
||||||
|
}
|
||||||
|
|
||||||
|
ShapedType getUpdateType() {
|
||||||
|
return updates().getType().cast<ShapedType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value indices() {
|
||||||
|
return getInputOperand(1)->get();
|
||||||
|
}
|
||||||
|
|
||||||
|
ShapedType getIndicesType() {
|
||||||
|
return indices().getType().cast<ShapedType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value original() {
|
||||||
|
return getOutputOperand(0)->get();
|
||||||
|
}
|
||||||
|
|
||||||
|
ShapedType getOriginalType() {
|
||||||
|
return original().getType().cast<ShapedType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t getUpdateSliceRank() {
|
||||||
|
return updates().getType().cast<ShapedType>().getRank() - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isScalarUpdate() {
|
||||||
|
return getUpdateSliceRank() == 0;
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pure ops
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def TMTensor_YieldOp : TMTensor_PureOp<"yield", [NoSideEffect, ReturnLike, Terminator]> {
|
||||||
|
let summary = "TMTensor yield op";
|
||||||
|
let description = [{
|
||||||
|
`tm_tensor.yield` is a special terminator operation for blocks inside
|
||||||
|
regions in `tm_tensor` ops.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Variadic<AnyType>:$operands);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins), [{ /* nothing to do */ }]>,
|
||||||
|
];
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECT_TMTENSOR_OPS
|
|
@ -0,0 +1,5 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header)
|
||||||
|
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl)
|
||||||
|
add_public_tablegen_target(TorchMLIRTMTensorTransformsPassesIncGen)
|
|
@ -0,0 +1,26 @@
|
||||||
|
//===- PassDetail.h - TMTensor Pass class details -------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
|
||||||
|
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
namespace TMTensor {
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" // IWYU pragma: keep
|
||||||
|
|
||||||
|
} // namespace TMTensor
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
|
|
@ -0,0 +1,27 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
|
||||||
|
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
namespace TMTensor {
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();
|
||||||
|
|
||||||
|
void registerPasses();
|
||||||
|
|
||||||
|
} // namespace TMTensor
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
|
|
@ -0,0 +1,21 @@
|
||||||
|
//===-------------------------------------------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_PASSES
|
||||||
|
#define TORCH_MLIR_DIALECT_TMTENSOR_PASSES
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
def TMTensorToLoops :
|
||||||
|
Pass<"torch-mlir-tm-tensor-to-loops", "FuncOp"> {
|
||||||
|
let summary = "Convert TMTensor ops to loops and Linalg ops.";
|
||||||
|
let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TORCH_MLIR_DIALECT_TMTENSOR_PASSES
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(Dialect)
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(TMTensor)
|
|
@ -0,0 +1,2 @@
|
||||||
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
|
@ -0,0 +1,29 @@
|
||||||
|
add_mlir_library(TorchMLIRTMTensorDialect
|
||||||
|
TMTensorDialect.cpp
|
||||||
|
TMTensorInterfaces.cpp
|
||||||
|
TMTensorOps.cpp
|
||||||
|
ScalarLoopOpInterface.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TorchMLIRTMTensorOpsIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRAffine
|
||||||
|
MLIRDialectUtils
|
||||||
|
MLIRIR
|
||||||
|
MLIRLinalg
|
||||||
|
MLIRMath
|
||||||
|
MLIRMemRef
|
||||||
|
MLIRPass
|
||||||
|
MLIRSideEffectInterfaces
|
||||||
|
MLIRSupport
|
||||||
|
MLIRSCF
|
||||||
|
MLIRStandard
|
||||||
|
MLIRTensor
|
||||||
|
MLIRViewLikeInterface
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_mlir_dialects_target_includes(TorchMLIRTMTensorDialect)
|
|
@ -0,0 +1,24 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "torch-mlir-tiled-op-interface"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch::TMTensor;
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp.inc"
|
|
@ -0,0 +1,30 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/Support/SourceMgr.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch::TMTensor;
|
||||||
|
|
||||||
|
void TMTensorDialect::initialize() {
|
||||||
|
#define GET_OP_LIST
|
||||||
|
addOperations<
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc"
|
||||||
|
>();
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.cpp.inc"
|
|
@ -0,0 +1,54 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::TMTensor;
|
||||||
|
|
||||||
|
OpOperandVector::operator SmallVector<Value>() {
|
||||||
|
SmallVector<Value> result;
|
||||||
|
result.reserve(this->size());
|
||||||
|
llvm::transform(*this, std::back_inserter(result),
|
||||||
|
[](OpOperand *opOperand) { return opOperand->get(); });
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
mlir::torch::TMTensor::detail::verifyTMTensorOpInterface(Operation *op) {
|
||||||
|
TMTensorOp mtTensorOp = cast<TMTensorOp>(op);
|
||||||
|
if (op->getNumResults()) {
|
||||||
|
if (!mtTensorOp.hasTensorSemantics()) {
|
||||||
|
return mtTensorOp.emitOpError(
|
||||||
|
"expected inputs and outputs to be RankedTensorType or scalar");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op->getNumResults() != mtTensorOp.outputs().size()) {
|
||||||
|
return mtTensorOp.emitOpError(
|
||||||
|
"expected number of outputs to be same as the number of results");
|
||||||
|
}
|
||||||
|
for (auto en : llvm::enumerate(op->getResultTypes())) {
|
||||||
|
Type outputType = mtTensorOp.outputs()[en.index()].getType();
|
||||||
|
if (en.value() != outputType) {
|
||||||
|
return mtTensorOp.emitOpError("expected type of `outs` operand #")
|
||||||
|
<< en.index() << " " << outputType
|
||||||
|
<< " to be same as result type " << en.value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (!mtTensorOp.hasBufferSemantics()) {
|
||||||
|
return mtTensorOp.emitOpError(
|
||||||
|
"expected inputs and outputs to be MemRefType or scalar");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.cpp.inc" // IWYU pragma: export
|
|
@ -0,0 +1,483 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Diagnostics.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/OperationSupport.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
#include "llvm/Support/SMLoc.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::TMTensor;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Utils.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static void getEffectsImpl(
|
||||||
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||||
|
&effects,
|
||||||
|
ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) {
|
||||||
|
for (Value value : results) {
|
||||||
|
effects.emplace_back(MemoryEffects::Allocate::get(), value,
|
||||||
|
SideEffects::DefaultResource::get());
|
||||||
|
}
|
||||||
|
for (Value value : inputBuffers) {
|
||||||
|
effects.emplace_back(MemoryEffects::Read::get(), value,
|
||||||
|
SideEffects::DefaultResource::get());
|
||||||
|
}
|
||||||
|
for (Value value : outputBuffers) {
|
||||||
|
effects.emplace_back(MemoryEffects::Read::get(), value,
|
||||||
|
SideEffects::DefaultResource::get());
|
||||||
|
effects.emplace_back(MemoryEffects::Write::get(), value,
|
||||||
|
SideEffects::DefaultResource::get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v,
|
||||||
|
int64_t dim) {
|
||||||
|
return TypeSwitch<Type, Value>(v.getType())
|
||||||
|
.Case<RankedTensorType>([&](RankedTensorType t) -> Value {
|
||||||
|
return builder.create<tensor::DimOp>(loc, v, dim);
|
||||||
|
})
|
||||||
|
.Case<MemRefType>([&](MemRefType t) -> Value {
|
||||||
|
return builder.create<memref::DimOp>(loc, v, dim);
|
||||||
|
})
|
||||||
|
.Default([&](Type t) { return Value(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
|
||||||
|
int64_t dim) {
|
||||||
|
auto t = v.getType().cast<ShapedType>();
|
||||||
|
if (t.isDynamicDim(dim)) {
|
||||||
|
return getDimValue(builder, loc, v, dim);
|
||||||
|
}
|
||||||
|
return builder.getI64IntegerAttr(t.getDimSize(dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ScanOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static LogicalResult verifyScanOp(ScanOp op) {
|
||||||
|
if (op.getNumInputs() != 1) {
|
||||||
|
return op.emitOpError("expected one input operands");
|
||||||
|
}
|
||||||
|
if (op.getNumOutputs() != 2) {
|
||||||
|
return op.emitOpError("expected two output operands");
|
||||||
|
}
|
||||||
|
if (!op.input().getType().isa<ShapedType>()) {
|
||||||
|
return op.emitOpError("expected first input element type to be shaped");
|
||||||
|
}
|
||||||
|
auto accumulatorType = op.accumulator().getType().cast<ShapedType>();
|
||||||
|
auto inputType = op.input().getType().cast<ShapedType>();
|
||||||
|
auto outputType = op.output().getType().cast<ShapedType>();
|
||||||
|
ArrayRef<int64_t> inputShapes = inputType.getShape();
|
||||||
|
ArrayRef<int64_t> outputShapes = outputType.getShape();
|
||||||
|
if (accumulatorType.getElementType() != inputType.getElementType()) {
|
||||||
|
return op.emitOpError(
|
||||||
|
"expected input/accumulator element types to be identical");
|
||||||
|
}
|
||||||
|
ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
|
||||||
|
int64_t accumulatorRank = accumulatorType.getRank();
|
||||||
|
if (accumulatorRank != inputType.getRank() - 1) {
|
||||||
|
return op.emitOpError(
|
||||||
|
"expected accumulator rank to be equal to input rank - 1");
|
||||||
|
}
|
||||||
|
SmallVector<int64_t> expectedAccumulatorShape;
|
||||||
|
for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
|
||||||
|
if (i != op.dimension())
|
||||||
|
expectedAccumulatorShape.push_back(inputShapes[i]);
|
||||||
|
}
|
||||||
|
if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
|
||||||
|
[](std::tuple<int64_t, int64_t> s) {
|
||||||
|
return std::get<0>(s) != ShapedType::kDynamicSize &&
|
||||||
|
std::get<1>(s) != ShapedType::kDynamicSize &&
|
||||||
|
std::get<0>(s) != std::get<1>(s);
|
||||||
|
})) {
|
||||||
|
return op.emitOpError("incompatible input/accumulator shapes");
|
||||||
|
}
|
||||||
|
if (inputType.getElementType() != outputType.getElementType()) {
|
||||||
|
return op.emitOpError(
|
||||||
|
"expected input/output element types to be identical");
|
||||||
|
}
|
||||||
|
if (inputShapes.size() != outputShapes.size()) {
|
||||||
|
return op.emitOpError("expected input/output to have identical ranks");
|
||||||
|
}
|
||||||
|
if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
|
||||||
|
[](std::tuple<int64_t, int64_t> s) {
|
||||||
|
return std::get<0>(s) != ShapedType::kDynamicSize &&
|
||||||
|
std::get<1>(s) != ShapedType::kDynamicSize &&
|
||||||
|
std::get<0>(s) != std::get<1>(s);
|
||||||
|
})) {
|
||||||
|
return op.emitOpError("incompatible input/output shapes");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) {
|
||||||
|
int64_t operandRank = getOperandRank();
|
||||||
|
SmallVector<Range> loopBounds(operandRank);
|
||||||
|
Location loc = getLoc();
|
||||||
|
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
Value source = input();
|
||||||
|
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
|
||||||
|
loopBounds[dim].offset = zero;
|
||||||
|
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
|
||||||
|
loopBounds[dim].stride = one;
|
||||||
|
}
|
||||||
|
return loopBounds;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<StringRef> ScanOp::getLoopIteratorTypes() {
|
||||||
|
SmallVector<StringRef> iteratorTypes(getOperandRank(),
|
||||||
|
getParallelIteratorTypeName());
|
||||||
|
iteratorTypes[dimension()] = getReductionIteratorTypeName();
|
||||||
|
return iteratorTypes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates naive scalar implementation of scan for a given operator f.
|
||||||
|
// For inclusive,
|
||||||
|
// output[0] = input[0]
|
||||||
|
// output[i] = f(output[i-1], input[i])
|
||||||
|
//
|
||||||
|
// For exclusive,
|
||||||
|
// output[0] = 0
|
||||||
|
// output[i] = f(output[i-1], input[i-1])
|
||||||
|
|
||||||
|
LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
|
||||||
|
ValueRange ivs) {
|
||||||
|
SmallVector<Value> indices, scanBlkArgs;
|
||||||
|
indices.append(ivs.begin(), ivs.end());
|
||||||
|
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
uint64_t scanDim = dimension();
|
||||||
|
Value cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
|
||||||
|
indices[scanDim], zero);
|
||||||
|
bool isInclusive = inclusive();
|
||||||
|
SmallVector<Value> accIndices;
|
||||||
|
for (size_t i = 0; i < indices.size(); i++) {
|
||||||
|
if (i != scanDim)
|
||||||
|
accIndices.push_back(indices[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scfIf = b.create<scf::IfOp>(
|
||||||
|
loc, TypeRange{}, cond,
|
||||||
|
[&](OpBuilder &b, Location loc) {
|
||||||
|
if (isInclusive) {
|
||||||
|
auto value = b.create<memref::LoadOp>(loc, input(), indices);
|
||||||
|
b.create<memref::StoreOp>(loc, value, output(), indices);
|
||||||
|
} else {
|
||||||
|
auto value = b.create<memref::LoadOp>(loc, accumulator(), accIndices);
|
||||||
|
b.create<memref::StoreOp>(loc, value, output(), indices);
|
||||||
|
}
|
||||||
|
b.create<scf::YieldOp>(loc);
|
||||||
|
},
|
||||||
|
[&](OpBuilder &b, Location loc) {
|
||||||
|
SmallVector<Value> indices(ivs.begin(), ivs.end());
|
||||||
|
Value iv = indices[scanDim];
|
||||||
|
Value ivMinusOne = b.create<arith::SubIOp>(loc, iv, one);
|
||||||
|
indices[scanDim] = ivMinusOne;
|
||||||
|
scanBlkArgs.push_back(b.create<memref::LoadOp>(loc, output(), indices));
|
||||||
|
Value i0;
|
||||||
|
if (!isInclusive)
|
||||||
|
i0 = b.create<memref::LoadOp>(loc, input(), indices);
|
||||||
|
indices[scanDim] = iv;
|
||||||
|
if (isInclusive)
|
||||||
|
i0 = b.create<memref::LoadOp>(loc, input(), indices);
|
||||||
|
scanBlkArgs.push_back(i0);
|
||||||
|
});
|
||||||
|
|
||||||
|
auto &srcBlock = region().front();
|
||||||
|
Region ®ion = scfIf.getElseRegion();
|
||||||
|
BlockAndValueMapping bvm;
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(b);
|
||||||
|
auto &block = region.front();
|
||||||
|
b.setInsertionPointToEnd(&block);
|
||||||
|
for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
|
||||||
|
bvm.map(std::get<0>(it), std::get<1>(it));
|
||||||
|
}
|
||||||
|
for (auto &blockOp : srcBlock.without_terminator()) {
|
||||||
|
b.clone(blockOp, bvm);
|
||||||
|
}
|
||||||
|
b.create<memref::StoreOp>(
|
||||||
|
loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
|
||||||
|
output(), indices);
|
||||||
|
b.create<memref::StoreOp>(
|
||||||
|
loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
|
||||||
|
accumulator(), accIndices);
|
||||||
|
b.create<scf::YieldOp>(loc);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult foldMemRefCast(Operation *op) {
|
||||||
|
bool folded = false;
|
||||||
|
for (OpOperand &operand : op->getOpOperands()) {
|
||||||
|
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
|
||||||
|
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
|
||||||
|
operand.set(castOp.getOperand());
|
||||||
|
folded = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success(folded);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ScanOp::fold(ArrayRef<Attribute>,
|
||||||
|
SmallVectorImpl<OpFoldResult> &) {
|
||||||
|
return foldMemRefCast(*this);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ScatterOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
static LogicalResult verifyScatterOp(ScatterOp op) {
|
||||||
|
if (op.inputs().size() != 2) {
|
||||||
|
return op.emitOpError("expected two input operands");
|
||||||
|
}
|
||||||
|
if (op.outputs().size() != 1) {
|
||||||
|
return op.emitOpError("expected one output operand");
|
||||||
|
}
|
||||||
|
auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
|
||||||
|
return t1.getShape()[dim] == t2.getShape()[dim];
|
||||||
|
};
|
||||||
|
|
||||||
|
auto indicesType = op.getIndicesType();
|
||||||
|
if (indicesType.getRank() != 2 ||
|
||||||
|
!indicesType.getElementType().isInteger(32)) {
|
||||||
|
return op.emitOpError(
|
||||||
|
"expected indices to be of rank 2 of i32 element type");
|
||||||
|
}
|
||||||
|
auto indexDepth = op.getIndexDepth();
|
||||||
|
if (indexDepth == ShapedType::kDynamicSize) {
|
||||||
|
return op.emitOpError("expected index depth is static");
|
||||||
|
}
|
||||||
|
|
||||||
|
// The first dimension of the indices should match the first dimension of the
|
||||||
|
// output. They indicate to the number of updates.
|
||||||
|
auto updateType = op.getUpdateType();
|
||||||
|
if (updateType.getRank() < 1) {
|
||||||
|
return op.emitOpError("expected update value to be at least rank 1");
|
||||||
|
}
|
||||||
|
if (!checkDimensionsMatch(indicesType, updateType, 0)) {
|
||||||
|
return op.emitOpError(
|
||||||
|
"mismatch in shape of indices and update value at dim#0");
|
||||||
|
}
|
||||||
|
auto originalType = op.getOriginalType();
|
||||||
|
// indexDepth + update dims should match to original dims. The first dim of
|
||||||
|
// update is the number of updates.
|
||||||
|
if (originalType.getRank() != indexDepth + updateType.getRank() - 1) {
|
||||||
|
return op.emitOpError(
|
||||||
|
"mismatch in rank of update value, index depth and original value");
|
||||||
|
}
|
||||||
|
for (auto dim : llvm::seq<unsigned>(indexDepth, originalType.getRank())) {
|
||||||
|
// Offset one because the first dim is the number of updates.
|
||||||
|
if (updateType.getDimSize(1 + dim - indexDepth) !=
|
||||||
|
originalType.getDimSize(dim)) {
|
||||||
|
return op.emitOpError("mismatch in shape of update value dim#")
|
||||||
|
<< (1 + dim - indexDepth) << " and original value at dim#" << dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Region ®ion = op.region();
|
||||||
|
Block *body = ®ion.front();
|
||||||
|
if (body->getNumArguments() != 2) {
|
||||||
|
return op.emitOpError("expected region to have two arguments");
|
||||||
|
}
|
||||||
|
Type arg0Type = body->getArgument(0).getType();
|
||||||
|
Type arg1Type = body->getArgument(1).getType();
|
||||||
|
if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
|
||||||
|
return op.emitOpError(
|
||||||
|
"expected region to have scalar argument of integer or float types");
|
||||||
|
}
|
||||||
|
if (arg0Type != updateType.getElementType()) {
|
||||||
|
return op.emitOpError("mismatch in argument 0 of region ")
|
||||||
|
<< arg0Type << " and element type of update value "
|
||||||
|
<< updateType.getElementType();
|
||||||
|
}
|
||||||
|
if (arg1Type != originalType.getElementType()) {
|
||||||
|
return op.emitOpError("mismatch in argument 1 of region ")
|
||||||
|
<< arg1Type << " and element type of original value "
|
||||||
|
<< originalType.getElementType();
|
||||||
|
}
|
||||||
|
if (arg0Type != arg1Type) {
|
||||||
|
return op.emitOpError("mismatch in region argument types ")
|
||||||
|
<< arg0Type << " and " << arg1Type;
|
||||||
|
}
|
||||||
|
auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator());
|
||||||
|
if (yieldOp->getNumOperands() != 1) {
|
||||||
|
return yieldOp.emitOpError("expected region to yield a single value");
|
||||||
|
}
|
||||||
|
auto yieldedType = yieldOp->getOperand(0).getType();
|
||||||
|
if (yieldedType != arg0Type) {
|
||||||
|
return yieldOp.emitOpError("mismatch in type of yielded value ")
|
||||||
|
<< yieldedType << " and argument of the region " << arg0Type;
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
|
||||||
|
SmallVector<StringRef> iteratorTypes(getUpdateType().getRank(),
|
||||||
|
getParallelIteratorTypeName());
|
||||||
|
return iteratorTypes;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
|
||||||
|
Location loc = getLoc();
|
||||||
|
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
SmallVector<Range> ranges;
|
||||||
|
for (auto dim : llvm::seq<int64_t>(0, getUpdateType().getRank())) {
|
||||||
|
Value ub = getDimValue(builder, loc, updates(), dim);
|
||||||
|
ranges.emplace_back(Range{zero, ub, one});
|
||||||
|
}
|
||||||
|
return ranges;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
|
||||||
|
Location loc,
|
||||||
|
ValueRange ivs) {
|
||||||
|
auto indexDepth = getIndexDepth();
|
||||||
|
Value update = b.create<memref::LoadOp>(loc, updates(), ivs);
|
||||||
|
SmallVector<Value> starts;
|
||||||
|
SmallVector<Value> loadIndices;
|
||||||
|
loadIndices.push_back(ivs.front());
|
||||||
|
loadIndices.push_back(Value());
|
||||||
|
for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
|
||||||
|
loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
|
||||||
|
Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
|
||||||
|
starts.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx));
|
||||||
|
}
|
||||||
|
starts.append(std::next(ivs.begin()), ivs.end());
|
||||||
|
Value init = b.create<memref::LoadOp>(loc, original(), starts);
|
||||||
|
|
||||||
|
BlockAndValueMapping bvm;
|
||||||
|
Block &block = region().front();
|
||||||
|
bvm.map(block.getArgument(0), update);
|
||||||
|
bvm.map(block.getArgument(1), init);
|
||||||
|
for (auto &blockOp : block.without_terminator()) {
|
||||||
|
b.clone(blockOp, bvm);
|
||||||
|
}
|
||||||
|
// The last op is linalg_ext.yield op. Store the operand to
|
||||||
|
// destination.
|
||||||
|
b.create<memref::StoreOp>(
|
||||||
|
loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
|
||||||
|
original(), starts);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
|
||||||
|
void OP_NAME::getEffects( \
|
||||||
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
|
||||||
|
&effects) { \
|
||||||
|
SmallVector<Value> inputBuffers = getInputBufferOperands(); \
|
||||||
|
SmallVector<Value> outputBuffers = getOutputBufferOperands(); \
|
||||||
|
getEffectsImpl(effects, getOperation()->getResults(), inputBuffers, \
|
||||||
|
outputBuffers); \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFINE_OP_GET_EFFECTS(ScanOp)
|
||||||
|
DEFINE_OP_GET_EFFECTS(ScatterOp)
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
|
||||||
|
/// changes.
|
||||||
|
struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
|
||||||
|
using OpInterfaceRewritePattern<TMTensorOp>::OpInterfaceRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(TMTensorOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// If no operand comes from a tensor::CastOp and can be folded then fail.
|
||||||
|
bool hasTensorCastOperand =
|
||||||
|
llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
|
||||||
|
if (opOperand->get().isa<BlockArgument>())
|
||||||
|
return false;
|
||||||
|
auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
|
||||||
|
return castOp && canFoldIntoConsumerOp(castOp);
|
||||||
|
});
|
||||||
|
if (!hasTensorCastOperand)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<Type, 4> newResultTypes;
|
||||||
|
newResultTypes.reserve(op->getNumResults());
|
||||||
|
SmallVector<Value, 4> newOperands;
|
||||||
|
newOperands.reserve(op->getNumOperands());
|
||||||
|
// Inputs may fold.
|
||||||
|
for (OpOperand *opOperand : op.getInputOperands()) {
|
||||||
|
auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
|
||||||
|
newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
|
||||||
|
? tensorCastOp.source()
|
||||||
|
: opOperand->get());
|
||||||
|
}
|
||||||
|
// Init tensors may fold, in which case the resultType must also change.
|
||||||
|
for (OpOperand *opOperand : op.getOutputOperands()) {
|
||||||
|
auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
|
||||||
|
bool fold = canFoldIntoConsumerOp(tensorCastOp);
|
||||||
|
newOperands.push_back(fold ? tensorCastOp.getOperand()
|
||||||
|
: opOperand->get());
|
||||||
|
newResultTypes.push_back(newOperands.back().getType());
|
||||||
|
}
|
||||||
|
// Clone op.
|
||||||
|
Operation *newOp =
|
||||||
|
op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
|
||||||
|
SmallVector<Value, 4> replacements;
|
||||||
|
replacements.reserve(newOp->getNumResults());
|
||||||
|
for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
|
||||||
|
Value oldResult = std::get<0>(result);
|
||||||
|
Value newResult = std::get<1>(result);
|
||||||
|
if (newResult.getType() != oldResult.getType()) {
|
||||||
|
replacements.push_back(rewriter.create<tensor::CastOp>(
|
||||||
|
op->getLoc(), oldResult.getType(), newResult));
|
||||||
|
} else {
|
||||||
|
replacements.push_back(newResult);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, replacements);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TMTensorDialect
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void TMTensorDialect::getCanonicalizationPatterns(
|
||||||
|
RewritePatternSet &results) const {
|
||||||
|
results.add<FoldTensorCastOp>(getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc"
|
|
@ -0,0 +1,22 @@
|
||||||
|
add_mlir_library(TorchMLIRTMTensorPasses
|
||||||
|
ConvertToLoops.cpp
|
||||||
|
Passes.cpp
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
TorchMLIRTMTensorTransformsPassesIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
TorchMLIRTMTensorDialect
|
||||||
|
MLIRAffine
|
||||||
|
MLIRIR
|
||||||
|
MLIRLinalg
|
||||||
|
MLIRLinalgTransforms
|
||||||
|
MLIRMath
|
||||||
|
MLIRMemRef
|
||||||
|
MLIRPass
|
||||||
|
MLIRSCF
|
||||||
|
MLIRStandard
|
||||||
|
MLIRSupport
|
||||||
|
MLIRTensor
|
||||||
|
MLIRTransforms
|
||||||
|
)
|
|
@ -0,0 +1,117 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch::TMTensor;
|
||||||
|
|
||||||
|
/// Recursive method that lowers one dimension of the `ScalarLoopOpInterface` to
|
||||||
|
/// scalar loops at a time.
|
||||||
|
static LogicalResult lowerToLoopsImpl(OpBuilder &builder,
|
||||||
|
ScalarLoopOpInterface scalarLoopOp,
|
||||||
|
ArrayRef<Range> loopRanges,
|
||||||
|
unsigned loopDepth,
|
||||||
|
SmallVectorImpl<Value> &ivs) {
|
||||||
|
Location loc = scalarLoopOp.getLoc();
|
||||||
|
if (loopDepth == loopRanges.size()) {
|
||||||
|
return scalarLoopOp.generateScalarImplementation(builder, loc, ivs);
|
||||||
|
}
|
||||||
|
LogicalResult status = success();
|
||||||
|
builder.create<scf::ForOp>(
|
||||||
|
loc, loopRanges[loopDepth].offset, loopRanges[loopDepth].size,
|
||||||
|
loopRanges[loopDepth].stride, ValueRange{},
|
||||||
|
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
|
||||||
|
ivs.push_back(iv);
|
||||||
|
status =
|
||||||
|
lowerToLoopsImpl(b, scalarLoopOp, loopRanges, loopDepth + 1, ivs);
|
||||||
|
b.create<scf::YieldOp>(loc);
|
||||||
|
});
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Main entry point for lowering `ScalarLoopOpInterface` op to loops.
|
||||||
|
static LogicalResult lowerToLoops(OpBuilder &builder,
|
||||||
|
ScalarLoopOpInterface scalarLoopOp) {
|
||||||
|
SmallVector<Range> loopBounds = scalarLoopOp.getIterationDomain(builder);
|
||||||
|
SmallVector<Value> ivs;
|
||||||
|
return lowerToLoopsImpl(builder, scalarLoopOp, loopBounds, 0, ivs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pattern rewriter hook to lower a `ScalarLoopOpInterface` to loops.
|
||||||
|
namespace {
|
||||||
|
struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern {
|
||||||
|
ScalarLoopOpInterfaceLowerToLoopsPattern(MLIRContext *context,
|
||||||
|
PatternBenefit benefit = 1)
|
||||||
|
: RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto scalarLoopOp = dyn_cast<ScalarLoopOpInterface>(op);
|
||||||
|
if (!scalarLoopOp) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (llvm::any_of(scalarLoopOp->getResults(),
|
||||||
|
[&](Value v) { return v.getType().isa<ShapedType>(); })) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
scalarLoopOp, "lower to loops needs to have tensor semantics");
|
||||||
|
}
|
||||||
|
if (failed(lowerToLoops(rewriter, scalarLoopOp))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pass
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<linalg::LinalgDialect, StandardOpsDialect,
|
||||||
|
mlir::arith::ArithmeticDialect, math::MathDialect,
|
||||||
|
memref::MemRefDialect, scf::SCFDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
patterns.insert<ScalarLoopOpInterfaceLowerToLoopsPattern>(context);
|
||||||
|
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||||
|
std::move(patterns)))) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
torch::TMTensor::createTMTensorToLoopsPass() {
|
||||||
|
return std::make_unique<TMTensorToLoopsPass>();
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
namespace TMTensor {
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" // IWYU pragma: export
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
} // namespace TMTensor
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
void torch::TMTensor::registerPasses() {
|
||||||
|
torch::TMTensor::detail::registerPasses();
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
configure_lit_site_cfg(
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
|
||||||
|
MAIN_CONFIG
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
|
||||||
|
)
|
||||||
|
|
||||||
|
set(TORCH_MLIR_DIALECTS_TEST_DEPENDS
|
||||||
|
FileCheck count not
|
||||||
|
torch-mlir-dialects-opt
|
||||||
|
)
|
||||||
|
|
||||||
|
add_lit_testsuite(check-torch-mlir-dialects "Running the torch-mlir-dialects regression tests"
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}
|
||||||
|
DEPENDS ${TORCH_MLIR_DIALECTS_TEST_DEPENDS}
|
||||||
|
)
|
||||||
|
set_target_properties(check-torch-mlir-dialects PROPERTIES FOLDER "Tests")
|
||||||
|
|
||||||
|
add_lit_testsuites(TORCH_MLIR_DIALECTS ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_DIALECTS_TEST_DEPENDS})
|
|
@ -0,0 +1,69 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import lit.formats
|
||||||
|
import lit.util
|
||||||
|
|
||||||
|
from lit.llvm import llvm_config
|
||||||
|
from lit.llvm.subst import ToolSubst
|
||||||
|
from lit.llvm.subst import FindTool
|
||||||
|
|
||||||
|
# Configuration file for the 'lit' test runner.
|
||||||
|
|
||||||
|
# name: The name of this test suite.
|
||||||
|
config.name = 'TORCH_MLIR_DIALECTS'
|
||||||
|
|
||||||
|
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||||
|
|
||||||
|
# suffixes: A list of file extensions to treat as test files.
|
||||||
|
config.suffixes = ['.mlir', '.py']
|
||||||
|
|
||||||
|
# test_source_root: The root path where tests are located.
|
||||||
|
config.test_source_root = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
# test_exec_root: The root path where tests should be run.
|
||||||
|
config.test_exec_root = os.path.join(config.torch_mlir_dialects_obj_root, 'test')
|
||||||
|
|
||||||
|
config.substitutions.append(('%PATH%', config.environment['PATH']))
|
||||||
|
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
|
||||||
|
config.substitutions.append(
|
||||||
|
('%resources_dir', os.path.join(config.torch_mlir_dialects_obj_root,
|
||||||
|
'resources')))
|
||||||
|
|
||||||
|
llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
|
||||||
|
|
||||||
|
#llvm_config.use_default_substitutions()
|
||||||
|
|
||||||
|
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
|
||||||
|
# subdirectories contain auxiliary inputs for various tests in their parent
|
||||||
|
# directories.
|
||||||
|
config.excludes = [
|
||||||
|
'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
|
||||||
|
'lit.cfg.py', 'lit.site.cfg.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# test_source_root: The root path where tests are located.
|
||||||
|
config.test_source_root = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
# test_exec_root: The root path where tests should be run.
|
||||||
|
config.test_exec_root = os.path.join(config.torch_mlir_dialects_obj_root, 'test')
|
||||||
|
config.standalone_tools_dir = os.path.join(config.torch_mlir_dialects_obj_root, 'bin')
|
||||||
|
|
||||||
|
# Tweak the PATH to include the tools dir.
|
||||||
|
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
||||||
|
|
||||||
|
tool_dirs = [config.llvm_tools_dir]
|
||||||
|
tools = [
|
||||||
|
ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
|
||||||
|
]
|
||||||
|
|
||||||
|
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
|
@ -0,0 +1,26 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
@LIT_SITE_CFG_IN_HEADER@
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
config.torch_mlir_dialects_obj_root = "@TORCH_MLIR_DIALECTS_BINARY_DIR@"
|
||||||
|
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
|
||||||
|
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
|
||||||
|
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
|
||||||
|
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
|
||||||
|
config.llvm_shlib_dir = "@SHLIBDIR@"
|
||||||
|
config.llvm_shlib_ext = "@SHLIBEXT@"
|
||||||
|
config.llvm_exe_ext = "@EXEEXT@"
|
||||||
|
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
|
||||||
|
config.python_executable = sys.executable
|
||||||
|
|
||||||
|
import lit.llvm
|
||||||
|
lit.llvm.initialize(lit_config, config)
|
||||||
|
|
||||||
|
# Let the main config do the real work.
|
||||||
|
lit_config.load_config(config, "@TORCH_MLIR_DIALECTS_SOURCE_DIR@/test/lit.cfg.py")
|
|
// RUN: torch-mlir-dialects-opt -canonicalize -split-input-file %s | FileCheck %s

// Folding the tensor.cast round-trips into the static-shaped form of
// tm_tensor.scan.
// CHECK-LABEL: func @tensor.cast(
func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> {
  %init = linalg.init_tensor [128] : tensor<128xi32>
  %c0 = linalg.init_tensor [] : tensor<i32>

  %casted_arg0 = tensor.cast %arg0 : tensor<128xi32> to tensor<?xi32>
  %casted_init = tensor.cast %init : tensor<128xi32> to tensor<?xi32>
  // CHECK: tm_tensor.scan
  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<128xi32>)
  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<128xi32>, tensor<i32>)
  %0, %1 = tm_tensor.scan dimension(0) inclusive(true)
    ins(%casted_arg0 : tensor<?xi32>)
    outs(%casted_init, %c0: tensor<?xi32>, tensor<i32>) {
  ^bb0(%barg0 : i32, %barg1 : i32, %barg2 : i32):
    %sum = arith.addi %barg0, %barg1 : i32
    tm_tensor.yield %sum : i32
  } -> tensor<?xi32>, tensor<i32>

  %2 = tensor.cast %0: tensor<?xi32> to tensor<128xi32>

  return %2: tensor<128xi32>
}
// RUN: torch-mlir-dialects-opt -split-input-file -torch-mlir-tm-tensor-to-loops %s | FileCheck %s

func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
  %c0 = memref.alloc() : memref<i32>
  tm_tensor.scan dimension(0) inclusive(true)
    ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
  ^bb0(%arg0 : i32, %arg1 : i32):
    %sum = arith.addi %arg0, %arg1 : i32
    tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func @scan_1d_inclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK:   %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK:   scf.if %[[COND]] {
// CHECK:     %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK:     memref.store %[[V1]], %[[BUFO]][%[[ARG1]]]
// CHECK:   } else {
// CHECK:     %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK:     %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK:     %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK:     %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK:     memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK:     memref.store %[[V4]], %[[ACC]][]
// CHECK:   }

// -----
func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
  %c0 = memref.alloc() : memref<i32>
  tm_tensor.scan dimension(0) inclusive(false)
    ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
  ^bb0(%arg0 : i32, %arg1 : i32):
    %sum = arith.addi %arg0, %arg1 : i32
    tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func @scan_1d_exclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK:   %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK:   scf.if %[[COND]] {
// CHECK:     %[[V0:.+]] = memref.load %[[ACC]][] : memref<i32>
// CHECK:     memref.store %[[V0]], %[[BUFO]][%[[ARG1]]]
// CHECK:   } else {
// CHECK:     %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK:     %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK:     %[[V3:.+]] = memref.load %[[BUFI]][%[[T1]]]
// CHECK:     %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK:     memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK:     memref.store %[[V4]], %[[ACC]][]
// CHECK:   }

// -----
func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
  %t0 = memref.alloc() : memref<32xi32>
  tm_tensor.scan dimension(0) inclusive(true)
    ins(%0 : memref<16x32xi32>) outs(%1, %t0 : memref<16x32xi32>, memref<32xi32>) {
  ^bb0(%arg0 : i32, %arg1 : i32):
    %sum = arith.addi %arg0, %arg1 : i32
    tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func @scan_2d
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<32xi32>
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]]
// CHECK:   scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]]
// CHECK:     %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK:     scf.if %[[COND]] {
// CHECK:       %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK:       memref.store %[[V1]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK:     } else {
// CHECK:       %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK:       %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]], %[[ARG2]]]
// CHECK:       %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK:       %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK:       memref.store %[[V4]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK:       memref.store %[[V4]], %[[ACC]][%[[ARG2]]]
// CHECK:     }

// -----
func @scatter_update_scalar_1D(
    %original: memref<8xi32>, %indices: memref<3x1xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
    outs(%original : memref<8xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_update_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:   %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK:   %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK:   %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:   memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]

// -----
func @scatter_add_scalar_2D(
    %original: memref<4x3xi32>, %indices: memref<3x2xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
    outs(%original : memref<4x3xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    %0 = arith.addi %arg1, %arg0 : i32
    tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_add_scalar_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:   %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK:   %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x2xi32>
// CHECK:   %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:   %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<3x2xi32>
// CHECK:   %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK:   %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<4x3xi32>
// CHECK:   %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK:   memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]

// -----
func @scatter_update_slice_2D(
    %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
    %updates: memref<2x3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
    outs(%original : memref<4x3xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    tm_tensor.yield %arg0 : i32
  }
  return
}
// NOTE(review): upgraded bare `CHECK:` on the function line to `CHECK-LABEL:`
// for consistency with the other tests in this file and to partition matching.
// CHECK-LABEL: func @scatter_update_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK:   scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:     %[[UPDATE:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:     %[[INDEX:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:     %[[LOC:.+]] = arith.index_cast %[[INDEX]] : i32 to index
// CHECK:     memref.store %[[UPDATE]], %[[ORIGINAL]][%[[LOC]], %[[J]]]
// CHECK:   }
// CHECK: }

// -----
func @scatter_add_scalar_1D(
    %original: memref<8xi32>, %indices: memref<3x1xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
    outs(%original : memref<8xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    %0 = arith.addi %arg1, %arg0 : i32
    tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_add_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:   %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK:   %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK:   %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:   %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX]]] : memref<8xi32>
// CHECK:   %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK:   memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX]]]

// -----
func @scatter_add_slice_2D(
    %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
    %updates: memref<2x3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
    outs(%original : memref<4x3xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    %0 = arith.addi %arg1, %arg0 : i32
    tm_tensor.yield %0 : i32
  }
  return
}
// NOTE(review): upgraded bare `CHECK:` on the function line to `CHECK-LABEL:`
// for consistency with the other tests in this file.
// CHECK-LABEL: func @scatter_add_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// NOTE(review): `%[[C3]]` was used below without ever being bound; FileCheck
// rejects uses of undefined variables. The 2x3 update shape guarantees an
// `arith.constant 3 : index` in the lowered IR, so bind it here.
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK:   scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:     %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:     %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:     %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK:     %[[ORIGINALVAL:.+]] = memref.load %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// CHECK:     %[[STOREVAL:.+]] = arith.addi %[[ORIGINALVAL]], %[[UPDATEVAL]]
// CHECK:     memref.store %[[STOREVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]

// -----
func @scatter_update_scalar_dynamic_1D(
    %original: memref<?xi32>, %indices: memref<?x1xi32>,
    %updates: memref<?xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
    outs(%original : memref<?xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_update_scalar_dynamic_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK:   %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK:   %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x1xi32>
// CHECK:   %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:   memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]

// -----
func @scatter_add_scalar_dynamic_2D(
    %original: memref<?x?xi32>, %indices: memref<?x2xi32>,
    %updates: memref<?xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
    outs(%original : memref<?x?xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    %0 = arith.addi %arg1, %arg0 : i32
    tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_add_scalar_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK:   %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK:   %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x2xi32>
// CHECK:   %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:   %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<?x2xi32>
// CHECK:   %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK:   %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<?x?xi32>
// CHECK:   %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK:   memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]

// -----
func @scatter_update_slice_dynamic_2D(
    %original: memref<?x?xi32>, %indices: memref<?x1xi32>,
    %updates: memref<?x?xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
    outs(%original : memref<?x?xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    tm_tensor.yield %arg0 : i32
  }
  return
}
// NOTE(review): upgraded bare `CHECK:` on the function line to `CHECK-LABEL:`
// for consistency with the other tests in this file.
// CHECK-LABEL: func @scatter_update_slice_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[UB1:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?x?xi32>
// CHECK-DAG: %[[UB2:.+]] = memref.dim %[[UPDATES]], %[[C1]] : memref<?x?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB1]] step %[[C1]] {
// CHECK:   scf.for %[[J:.+]] = %[[C0]] to %[[UB2]] step %[[C1]] {
// CHECK:     %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:     %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:     %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK:     memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
# Build the standalone opt-style driver for the torch-mlir dialects.
add_subdirectory(torch-mlir-dialects-opt)
# Libraries the opt driver links against: upstream MLIR components plus the
# local TMTensor dialect and its passes.
set(LIBS
  MLIRArithmetic
  MLIRDialect
  MLIRLinalg
  MLIRMemRef
  MLIROptLib
  MLIRSCF
  MLIRSCFTransforms
  MLIRStandard
  MLIRTensor
  MLIRTransforms
  TorchMLIRTMTensorDialect
  TorchMLIRTMTensorPasses
)

add_llvm_tool(torch-mlir-dialects-opt
  torch-mlir-dialects-opt.cpp

  # DEPENDS keeps LLVM-style generated headers (TableGen) of these libraries
  # built before the tool compiles; linking is handled separately below.
  DEPENDS
  ${LIBS}
)
target_link_libraries(torch-mlir-dialects-opt PRIVATE ${LIBS})
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"

using namespace mlir;

// Standalone mlir-opt-style driver registering only the dialects and passes
// needed to exercise the TMTensor dialect in isolation.
int main(int argc, char **argv) {
  registerAsmPrinterCLOptions();
  registerMLIRContextCLOptions();

  // Upstream passes used by the TMTensor tests (canonicalize, SCF lowering).
  registerTransformsPasses();
  registerSCFPasses();

  // Local dialects.
  mlir::torch::TMTensor::registerPasses();

  DialectRegistry registry;
  registry.insert<
      // Local dialects
      mlir::torch::TMTensor::TMTensorDialect,
      // Upstream dialects
      mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
      mlir::memref::MemRefDialect, mlir::StandardOpsDialect,
      mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
                        /*preloadDialectsInContext=*/false));
}
@ -19,6 +19,9 @@ add_mlir_library(TorchMLIRInitAll
|
||||||
TorchMLIRTorchDialect
|
TorchMLIRTorchDialect
|
||||||
TorchMLIRTorchConversionPasses
|
TorchMLIRTorchConversionPasses
|
||||||
|
|
||||||
|
TorchMLIRTMTensorPasses
|
||||||
|
TorchMLIRTMTensorDialect
|
||||||
|
|
||||||
TorchMLIRConversionPasses
|
TorchMLIRConversionPasses
|
||||||
TorchMLIRRefBackend
|
TorchMLIRRefBackend
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,6 +10,8 @@
|
||||||
#include "torch-mlir/InitAll.h"
|
#include "torch-mlir/InitAll.h"
|
||||||
|
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||||
|
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
|
||||||
#include "torch-mlir/Conversion/Passes.h"
|
#include "torch-mlir/Conversion/Passes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
@ -20,6 +22,7 @@
|
||||||
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||||
registry.insert<mlir::torch::Torch::TorchDialect>();
|
registry.insert<mlir::torch::Torch::TorchDialect>();
|
||||||
registry.insert<mlir::torch::TorchConversion::TorchConversionDialect>();
|
registry.insert<mlir::torch::TorchConversion::TorchConversionDialect>();
|
||||||
|
registry.insert<mlir::torch::TMTensor::TMTensorDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::torch::registerAllPasses() {
|
void mlir::torch::registerAllPasses() {
|
||||||
|
@ -28,4 +31,5 @@ void mlir::torch::registerAllPasses() {
|
||||||
|
|
||||||
mlir::torch::registerConversionPasses();
|
mlir::torch::registerConversionPasses();
|
||||||
mlir::torch::RefBackend::registerRefBackendPasses();
|
mlir::torch::RefBackend::registerRefBackendPasses();
|
||||||
|
mlir::torch::TMTensor::registerPasses();
|
||||||
}
|
}
|
||||||
|
|
3
setup.py
3
setup.py
|
@ -62,8 +62,9 @@ class CMakeBuild(build_py):
|
||||||
f"-DLLVM_TARGETS_TO_BUILD=host",
|
f"-DLLVM_TARGETS_TO_BUILD=host",
|
||||||
f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON",
|
f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON",
|
||||||
f"-DLLVM_ENABLE_PROJECTS=mlir",
|
f"-DLLVM_ENABLE_PROJECTS=mlir",
|
||||||
f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir",
|
f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir;torch-mlir-dialects",
|
||||||
f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}",
|
f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}",
|
||||||
|
f"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR={src_dir}/external/llvm-external-projects/torch-mlir-dialects",
|
||||||
# Optimization options for building wheels.
|
# Optimization options for building wheels.
|
||||||
f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON",
|
f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON",
|
||||||
f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
|
f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
|
||||||
|
|
Loading…
Reference in New Issue