mirror of https://github.com/llvm/torch-mlir

Add TMTensor dialect to torch-mlir

This is intended to explore support for non-structured ops that can't be modeled by the Linalg dialect. `tm_tensor.scan` and `tm_tensor.scatter` are added as the first such ops. The dialect should aim to be upstreamed in the future.

parent cd21dda867
commit 869daf3c22
@@ -53,8 +53,9 @@ jobs:
     -DPython3_EXECUTABLE=$(which python) \
     -DLLVM_ENABLE_ASSERTIONS=ON \
     -DLLVM_ENABLE_PROJECTS=mlir \
-    -DLLVM_EXTERNAL_PROJECTS=torch-mlir \
+    -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
     -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \
+    -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/external/llvm-external-projects/torch-mlir-dialects" \
     -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
     -DLLVM_TARGETS_TO_BUILD=host
   ninja check-torch-mlir-all
@@ -25,6 +25,22 @@ project(torch-mlir LANGUAGES CXX C)
 set(CMAKE_C_STANDARD 11)
 set(CMAKE_CXX_STANDARD 14)
+
+macro(torch_mlir_add_llvm_external_project name identifier location)
+  message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}")
+  if(NOT EXISTS "${location}/CMakeLists.txt")
+    message(FATAL_ERROR "External project location ${location} is not valid")
+  endif()
+  list(APPEND LLVM_EXTERNAL_PROJECTS ${name})
+  list(REMOVE_DUPLICATES LLVM_EXTERNAL_PROJECTS)
+  set(LLVM_EXTERNAL_${identifier}_SOURCE_DIR ${location} CACHE STRING "" FORCE)
+  set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
+endmacro()
+
+torch_mlir_add_llvm_external_project(
+  torch-mlir-dialects
+  TORCH_MLIR_DIALECTS
+  ${CMAKE_CURRENT_SOURCE_DIR}/external/llvm-external-projects/torch-mlir-dialects)
+
 if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
   # Out-of-tree build
@@ -129,6 +145,7 @@ add_subdirectory(tools)
 add_custom_target(check-torch-mlir-all)
 add_dependencies(check-torch-mlir-all
   check-torch-mlir
+  check-torch-mlir-dialects
 )

 if(MLIR_ENABLE_BINDINGS_PYTHON)
@@ -62,8 +62,9 @@ cmake -GNinja -Bbuild \
   -DCMAKE_CXX_COMPILER=clang++ \
   -DPython3_FIND_VIRTUALENV=ONLY \
   -DLLVM_ENABLE_PROJECTS=mlir \
-  -DLLVM_EXTERNAL_PROJECTS=torch-mlir \
+  -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
   -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
+  -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=`pwd`/external/llvm-external-projects/torch-mlir-dialects \
   -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
   -DLLVM_TARGETS_TO_BUILD=host \
   external/llvm-project/llvm
@@ -19,8 +19,9 @@ cmake -GNinja -B"$build_dir" "$llvm_project_dir/llvm" \
   -DCMAKE_BUILD_TYPE=Release \
   -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
   -DLLVM_ENABLE_PROJECTS=mlir \
-  -DLLVM_EXTERNAL_PROJECTS=torch-mlir \
+  -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
   -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$project_dir" \
+  -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${project_dir}/external/llvm-external-projects/torch-mlir-dialects" \
   -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
   -DLLVM_ENABLE_ASSERTIONS=ON \
   -DLLVM_TARGETS_TO_BUILD=host
@@ -0,0 +1,54 @@
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
  message(FATAL_ERROR
    "This project is intended to be built as part of LLVM via "
    "-DLLVM_EXTERNAL_PROJECTS=torch-mlir-dialects "
    "-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR}")
endif()

option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)

set(TORCH_MLIR_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TORCH_MLIR_DIALECTS_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS "Building torch-mlir-dialects project at ${TORCH_MLIR_DIALECTS_SOURCE_DIR} (into ${TORCH_MLIR_DIALECTS_BINARY_DIR})")

# TODO: Fix this upstream so that global include directories are not needed.
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)

# TODO: Needed for tablegen. Remove.
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR})
include_directories(SYSTEM ${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include)

function(torch_mlir_dialects_target_includes target)
  set(_dirs
    $<BUILD_INTERFACE:${MLIR_INCLUDE_DIR}>
    $<BUILD_INTERFACE:${MLIR_GENERATED_INCLUDE_DIR}>
    $<BUILD_INTERFACE:${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include>
    $<BUILD_INTERFACE:${TORCH_MLIR_DIALECTS_BINARY_DIR}/include>
  )
  # In LLVM parlance, the actual target may just be an interface and may not
  # be responsible for actually compiling anything. The corresponding obj.
  # target, when present, is just used for compilation and does not
  # contribute to the interface properties.
  # TODO: Normalize this upstream.
  target_include_directories(${target} PUBLIC ${_dirs})
  if(TARGET obj.${target})
    target_include_directories(obj.${target} PRIVATE ${_dirs})
  endif()
endfunction()

# Configure CMake and tablegen.
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
set(MLIR_TABLEGEN_EXE mlir-tblgen)

include(TableGen)
include(AddLLVM)
include(AddMLIR)

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(tools)
add_subdirectory(test)
@@ -0,0 +1 @@
add_subdirectory(torch-mlir-dialects)
@@ -0,0 +1 @@
add_subdirectory(Dialect)
@@ -0,0 +1 @@
add_subdirectory(TMTensor)
@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
@@ -0,0 +1,33 @@
function(_add_interfaces)
  set(LLVM_TARGET_DEFINITIONS TMTensorInterfaces.td)
  mlir_tablegen(TMTensorOpInterfaces.h.inc -gen-op-interface-decls)
  mlir_tablegen(TMTensorOpInterfaces.cpp.inc -gen-op-interface-defs)
  mlir_tablegen(TMTensorTypeInterfaces.h.inc -gen-type-interface-decls)
  mlir_tablegen(TMTensorTypeInterfaces.cpp.inc -gen-type-interface-defs)
  add_public_tablegen_target(TorchMLIRTMTensorInterfacesIncGen)
  add_dependencies(TorchMLIRTMTensorOpsIncGen TorchMLIRTMTensorInterfacesIncGen)
endfunction()

function(_add_scalar_loop_op_interface)
  set(LLVM_TARGET_DEFINITIONS ScalarLoopOpInterface.td)
  mlir_tablegen(ScalarLoopOpInterface.h.inc -gen-op-interface-decls)
  mlir_tablegen(ScalarLoopOpInterface.cpp.inc -gen-op-interface-defs)
  add_public_tablegen_target(TorchMLIRTMTensorScalarLoopOpInterfaceIncGen)
  add_dependencies(TorchMLIRTMTensorOpsIncGen TorchMLIRTMTensorScalarLoopOpInterfaceIncGen)
endfunction()

function(_add_dialect)
  set(LLVM_TARGET_DEFINITIONS TMTensorOps.td)
  mlir_tablegen(TMTensorOps.h.inc -gen-op-decls)
  mlir_tablegen(TMTensorOps.cpp.inc -gen-op-defs)
  mlir_tablegen(TMTensorTypes.h.inc -gen-typedef-decls)
  mlir_tablegen(TMTensorTypes.cpp.inc -gen-typedef-defs)
  mlir_tablegen(TMTensorDialect.h.inc -gen-dialect-decls -dialect=tm_tensor)
  mlir_tablegen(TMTensorDialect.cpp.inc -gen-dialect-defs -dialect=tm_tensor)
  add_public_tablegen_target(TorchMLIRTMTensorOpsIncGen)
  add_dependencies(mlir-headers TorchMLIRTMTensorOpsIncGen)
endfunction()

_add_dialect()
_add_interfaces()
_add_scalar_loop_op_interface()
@@ -0,0 +1,29 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"

/// Include the ODS generated interface header files.
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc"

namespace mlir {
namespace torch {
namespace TMTensor {} // namespace TMTensor
} // namespace torch
} // namespace mlir

#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
@@ -0,0 +1,79 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECT_TMTENSOR_SCALARLOOPOPINTERFACE
#define TORCH_MLIR_DIALECT_TMTENSOR_SCALARLOOPOPINTERFACE

include "mlir/IR/OpBase.td"

def ScalarLoopOpInterface : OpInterface<"ScalarLoopOpInterface"> {
  let description = [{
    Interface that allows operations to expose the information needed to
    lower them to for loops.
  }];
  let cppNamespace = "::mlir::torch::TMTensor";
  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Returns the destination operands. For ops with `memref`
        operands, these are the result buffers. For ops with `tensor`
        operands, these are the operands that contain the initial
        value for the result. They are "tied" to the result
        buffers. For example, for `LinalgOp`/`TMTensor` ops, these are
        the `outs` parameters; for `tensor.insert_slice`, it is
        the `dest` parameter.
      }],
      /*retTy=*/"SmallVector<Value>",
      /*methodName=*/"getDestinationOperands",
      /*args=*/(ins "OpBuilder &":$b),
      /*methodBody=*/"",
      /*defaultImplementation=*/"return ValueRange{};"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns a list of `StringRef`s that describe the number of
        loops and the iterator types of the operation. The list is
        expected to use
        `getParallelIteratorTypeName()`/`getReductionIteratorTypeName()`
        from MLIR Structured Op Utils.
      }],
      /*retTy=*/"SmallVector<StringRef>",
      /*methodName=*/"getLoopIteratorTypes"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns a list of ranges that describe the loop bounds and
        step for the loops of the operation.
      }],
      /*retTy=*/"SmallVector<Range>",
      /*methodName=*/"getIterationDomain",
      /*args=*/(ins "OpBuilder &":$b)
    >,
    InterfaceMethod<
      /*desc=*/[{
        Generates the loop body implementation. Assumes that all the parallel
        and reduction loops have been created and that the insertion point of
        the builder is set to the innermost loop. This method emits the IR for
        the loop body.
      }],
      /*retTy=*/"LogicalResult",
      /*methodName=*/"generateScalarImplementation",
      /*args=*/(ins
        "OpBuilder &":$b,
        "Location ":$loc,
        "ValueRange ":$ivs),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return failure();
      }]
    >
  ];
}

#endif // TORCH_MLIR_DIALECT_TMTENSOR_SCALARLOOPOPINTERFACE
@@ -0,0 +1,59 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECT_TMTENSOR_BASE
#define TORCH_MLIR_DIALECT_TMTENSOR_BASE

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// Dialect definition
//===----------------------------------------------------------------------===//

def TMTensor_Dialect : Dialect {
  let name = "tm_tensor";
  let cppNamespace = "::mlir::torch::TMTensor";
  let description = [{
    The tm_tensor (tm = torch-mlir) dialect is a temporary staging ground in
    the torch-mlir project for a set of widely-accepted tensor compute
    operations that are not well-served by existing representations in
    upstream MLIR. These ops are currently heavily inspired by the linalg_ext
    dialect (which itself is heavily inspired by the structured ops of the
    linalg dialect). But while linalg_ext is meant to power specific codegen
    transformations, the tm_tensor dialect is a much more pure "interface
    dialect" agnostic to any particular set of transformations applied to
    the operations. We simply require a way to name the specified operations
    for interchange between projects, without taking strong opinions on the
    mechanics of transformations.

    The dialect does include interfaces to generate scalar reference code for
    the operations, which simultaneously provides a precise definition of
    their semantics and aids in producing executable reference
    implementations of the operations.

    The goal of this dialect is to eventually either be upstreamed or to be
    subsumed by functionality included in upstream MLIR. It should also be
    kept consistent with the linalg_ext dialect unless there is a good reason
    not to.
  }];
  let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Type definitions
//===----------------------------------------------------------------------===//

class RankedTensorOrMemRefOf<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes,
    Or<[IsMemRefTypePred, And<[IsTensorTypePred, HasRankPred]>]>,
    "ranked tensor or memref", "::mlir::ShapedType">;

def AnyRankedTensorOrMemRefType : RankedTensorOrMemRefOf<[AnyType]>;

#endif // TORCH_MLIR_DIALECT_TMTENSOR_BASE
@@ -0,0 +1,20 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_

#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"

// clang-format off: must be included after all LLVM/MLIR headers
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h.inc" // IWYU pragma: keep
// clang-format on

#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_
@@ -0,0 +1,42 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_

#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

namespace mlir {
namespace torch {
namespace TMTensor {
class TMTensorOp;

/// OpOperand vector that implicitly converts to a Value vector.
struct OpOperandVector : public SmallVector<OpOperand *> {
  operator SmallVector<Value>();
};

namespace detail {
LogicalResult verifyTMTensorOpInterface(Operation *op);
} // namespace detail

#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export

/// Include the generated interface declarations.
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export

} // namespace TMTensor
} // namespace torch
} // namespace mlir

#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
@@ -0,0 +1,493 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECT_TMTENSOR_INTERFACES
#define TORCH_MLIR_DIALECT_TMTENSOR_INTERFACES

include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td"

// The interface is a subset of LinalgStructuredInterface.
def TMTensorInterface : OpInterface<"TMTensorOp"> {
  let methods = [
    //===------------------------------------------------------------------===//
    // Num input/output arguments handling.
    //===------------------------------------------------------------------===//
    // `inputs` must be defined by each op that wants to implement the
    // LinalgStructuredInterface.
    InterfaceMethod<
      /*desc=*/[{
        Return the input shape operands.
      }],
      /*retTy=*/"ValueRange",
      /*methodName=*/"inputs",
      /*args=*/(ins)
    >,
    // These special methods rely on `inputs` and `outputs` being defined by
    // each op that wants to implement the LinalgStructuredInterface.
    InterfaceMethod<
      /*desc=*/[{
        Return the number of inputs.
      }],
      /*retTy=*/"int64_t",
      /*methodName=*/"getNumInputs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.inputs().size();
      }]
    >,
    // `outputs` must be defined by each op that wants to implement the
    // LinalgStructuredInterface.
    InterfaceMethod<
      /*desc=*/[{
        Return the output shape operands.
      }],
      /*retTy=*/"ValueRange",
      /*methodName=*/"outputs",
      /*args=*/(ins)
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of outputs.
      }],
      /*retTy=*/"int64_t",
      /*methodName=*/"getNumOutputs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.outputs().size();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of inputs and outputs.
      }],
      /*retTy=*/"int64_t",
      /*methodName=*/"getNumInputsAndOutputs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getNumInputs() + getNumOutputs();
      }]
    >,
    //===------------------------------------------------------------------===//
    // Input operands handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the input operands.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getInputOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        int64_t numInputs = getNumInputs();
        OpOperandVector result;
        result.reserve(numInputs);
        llvm::transform(
            this->getOperation()->getOpOperands().take_front(numInputs),
            std::back_inserter(result),
            [](OpOperand &opOperand) { return &opOperand; });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `i`-th input operand.
      }],
      /*retTy=*/"OpOperand*",
      /*methodName=*/"getInputOperand",
      /*args=*/(ins "int64_t":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i >= 0 && i < getNumInputs());
        return &this->getOperation()->getOpOperand(i);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of input operands that are of buffer type.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getInputBufferOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        OpOperandVector result;
        result.reserve(getNumInputs());
        llvm::copy_if(getInputOperands(),
            std::back_inserter(result),
            [](OpOperand *opOperand) {
              return opOperand->get().getType().template isa<MemRefType>();
            });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of input operands that are of tensor type.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getInputTensorOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        OpOperandVector result;
        result.reserve(getNumInputs());
        llvm::copy_if(getInputOperands(),
            std::back_inserter(result),
            [](OpOperand *opOperand) {
              return opOperand->get().getType().template isa<RankedTensorType>();
            });
        return result;
      }]
    >,
    //===------------------------------------------------------------------===//
    // Output operands handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the output operands.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getOutputOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        int64_t numOutputs = getNumOutputs();
        OpOperandVector result;
        result.reserve(numOutputs);
        llvm::transform(
            this->getOperation()->getOpOperands()
                .drop_front(getNumInputs())
                .take_front(numOutputs),
            std::back_inserter(result),
            [](OpOperand &opOperand) { return &opOperand; });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `i`-th output operand.
      }],
      /*retTy=*/"OpOperand*",
      /*methodName=*/"getOutputOperand",
      /*args=*/(ins "int64_t":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i >= 0 && i < getNumOutputs());
        return &this->getOperation()->getOpOperand(getNumInputs() + i);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of output operands that are of buffer type.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getOutputBufferOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        OpOperandVector result;
        result.reserve(getNumOutputs());
        llvm::copy_if(getOutputOperands(),
            std::back_inserter(result),
            [](OpOperand *opOperand) {
              return opOperand->get().getType().template isa<MemRefType>();
            });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of output operands that are of tensor type.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getOutputTensorOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        OpOperandVector result;
        result.reserve(getNumOutputs());
        llvm::copy_if(getOutputOperands(),
            std::back_inserter(result),
            [](OpOperand *opOperand) {
              return opOperand->get().getType().template isa<RankedTensorType>();
            });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the types of the subset of output operands that are of buffer type.
      }],
      /*retTy=*/"SmallVector<MemRefType>",
      /*methodName=*/"getOutputBufferTypes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<MemRefType> result;
        result.reserve(getNumOutputs());
        llvm::transform(getOutputBufferOperands(),
            std::back_inserter(result),
            [](OpOperand *opOperands) {
              return opOperands->get().getType().cast<MemRefType>();
            });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the types of the subset of output operands that are of tensor type.
      }],
      /*retTy=*/"SmallVector<RankedTensorType>",
      /*methodName=*/"getOutputTensorTypes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<RankedTensorType> result;
        result.reserve(getNumOutputs());
        llvm::transform(getOutputTensorOperands(),
            std::back_inserter(result),
            [](OpOperand *opOperands) {
              return opOperands->get().getType().cast<RankedTensorType>();
            });
        return result;
      }]
    >,
    //===------------------------------------------------------------------===//
    // Input and Output arguments handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the range over input and output operands.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getInputAndOutputOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        int64_t numInputsAndOutputs = getNumInputsAndOutputs();
        OpOperandVector result;
        result.reserve(numInputsAndOutputs);
        llvm::transform(
            this->getOperation()->getOpOperands()
                .take_front(numInputsAndOutputs),
            std::back_inserter(result),
            [](OpOperand &opOperand) { return &opOperand; });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if the payload uses the value loaded from `opOperand`. This
        is useful to avoid loading from "write-only" memory that may be
        uninitialized, as well as properly cloning "read-write" operands.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"payloadUsesValueFromOperand",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        unsigned bbArgNumber = opOperand->getOperandNumber();
        // Safeguard against the named linalg ops that are manually defined and
        // that only support buffer semantics: we should not be there.
        // Such ops have an empty regionBuilder and are not constructed with a
        // region for now. In the future they are slated to disappear.
        assert(this->getOperation()->getNumRegions() == 1 && "unexpected "
               "missing region (calling `payloadUsesValueFromOperand` on "
               "manually defined named Linalg op?)");
        Block &block = this->getOperation()->getRegion(0).front();
        // Init tensors have uses.
        return !block.getArgument(bbArgNumber).use_empty();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if `opOperand` is an input tensor.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isInputTensor",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (!opOperand->get().getType().template isa<RankedTensorType>())
          return false;
        if (opOperand->getOperandNumber() < $_op.getNumInputs())
          return true;
        return false;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if `opOperand` is an output tensor.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isOutputTensor",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (!opOperand->get().getType().template isa<RankedTensorType>())
          return false;
        if (opOperand->getOperandNumber() >= $_op.getNumInputs())
          return true;
        return false;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if `opOperand` is an init tensor. This is true when it is
        an output tensor operand whose value is used in the payload region.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isInitTensor",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (!$_op.isOutputTensor(opOperand))
          return false;
        return payloadUsesValueFromOperand(opOperand);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `opOperand` rank or zero for scalars.
      }],
      /*retTy=*/"int64_t",
      /*methodName=*/"getRank",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        if (auto shapedType =
                opOperand->get().getType().template dyn_cast<ShapedType>())
          return shapedType.getRank();
        return 0;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `opOperand` shape or an empty vector for scalars.
      }],
      /*retTy=*/"ArrayRef<int64_t>",
      /*methodName=*/"getShape",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        if (auto shapedType =
                opOperand->get().getType().template dyn_cast<ShapedType>())
          return shapedType.getShape();
        return {};
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if the `opOperand` is a scalar value.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isScalar",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        return !opOperand->get().getType().template isa<ShapedType>();
      }]
    >,
    //===------------------------------------------------------------------===//
    // Other interface methods.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return whether the op has only MemRef input and outputs.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasBufferSemantics",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return this->getOperation()->getNumResults() == 0 &&
               llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
                 return isScalar(opOperand) ||
                        opOperand->get().getType().template isa<MemRefType>();
               }) &&
               llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
                 return opOperand->get().getType().template isa<MemRefType>();
               });
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return whether the op has only RankedTensor input and outputs.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasTensorSemantics",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return
            llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
              return isScalar(opOperand) ||
                     opOperand->get().getType().template isa<RankedTensorType>();
            }) &&
            llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
              return opOperand->get().getType().template isa<RankedTensorType>();
            });
      }]
    >,
    //===------------------------------------------------------------------===//
    // Other static interface methods.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Clone the current operation with the given location and operands. This
        is used to abstract away the optional underlying region creation. This
        does not change the balance between input, output_buffer and
        init_tensors operands.
      }],
      /*retTy=*/"Operation *",
      /*methodName=*/"clone",
      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
           "ValueRange":$operands),
      [{
        BlockAndValueMapping bvm;
        OperationState state(
            loc, ConcreteOp::getOperationName(), operands, resultTypes,
            $_op->getAttrs());
        for (Region &r : $_op->getRegions())
          r.cloneInto(state.addRegion(), bvm);
        return b.createOperation(state);
      }]
    >
  ];

  let extraClassDeclaration = [{
    //========================================================================//
    // Helper functions to mutate the `operand_segment_sizes` attribute.
    // These are useful when cloning and changing operand types.
    //========================================================================//
    void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
    void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }

    private:
    void setOperandSegmentAt(unsigned idx, unsigned val) {
      auto attr = (*this)->getAttr("operand_segment_sizes")
                      .cast<DenseIntElementsAttr>();
      unsigned i = 0;
      auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
          [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
      getOperation()->setAttr("operand_segment_sizes", newAttr);
    }
  }];

  let verify = [{ return detail::verifyTMTensorOpInterface($_op); }];
}

#endif // TORCH_MLIR_DIALECT_TMTENSOR_INTERFACES
@@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"

/// Include the ODS generated interface header files.
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc"

#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
@@ -0,0 +1,41 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSOROPS_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSOROPS_H_

#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h"

namespace mlir {
namespace torch {
namespace TMTensor {

/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
/// `dim`.
Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim);

/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
/// `dim`. If the shape is constant, returns the shape as an `IntegerAttr`.
OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim);

} // namespace TMTensor
} // namespace torch
} // namespace mlir

#define GET_OP_CLASSES
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export

#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSOROPS_H_
@@ -0,0 +1,205 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECT_TMTENSOR_OPS
#define TORCH_MLIR_DIALECT_TMTENSOR_OPS

include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td"
include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td"
include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"

//===----------------------------------------------------------------------===//
// Base class.
//===----------------------------------------------------------------------===//

class TMTensor_PureOp<string mnemonic, list<OpTrait> traits = []> :
    Op<TMTensor_Dialect, mnemonic, traits> {
}

class TMTensor_Op<string mnemonic, list<OpTrait> traits = []> :
    TMTensor_PureOp<mnemonic, !listconcat(traits,
        [AttrSizedOperandSegments,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
         TMTensorInterface,
         SingleBlockImplicitTerminator<"::mlir::torch::TMTensor::YieldOp">
        ])> {
  let verifier = [{ return verify$cppClass(*this); }];
  let printer = [{ return print$cppClass(p, *this); }];
  let parser = [{ return parse$cppClass(parser, result); }];
  code extraTMTensorOpClassDeclaration = [{
    SmallVector<Value> getDestinationOperands(OpBuilder &b) {
      SmallVector<Value> dest(outputs().begin(), outputs().end());
      return dest;
    }
  }];
}

//===----------------------------------------------------------------------===//
// Non-structured ops
//===----------------------------------------------------------------------===//

def TMTensor_ScanOp : TMTensor_Op<"scan",
    [DeclareOpInterfaceMethods<ScalarLoopOpInterface,
        ["generateScalarImplementation"]>]> {
  let summary = "Scan operator";
  let description = [{
    Computes the inclusive/exclusive scan along a given dimension.
  }];

  let arguments = (ins Variadic<AnyShaped>:$inputs,
                       Variadic<AnyShaped>:$outputs,
                       I64Attr:$dimension,
                       BoolAttr:$inclusive
  );

  let builders = [
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
        CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive)>
  ];

  let results = (outs Variadic<AnyRankedTensor>:$results);
  let regions = (region AnyRegion:$region);
  let hasFolder = 1;
  let assemblyFormat = [{
    `dimension` `(` $dimension `)`
    `inclusive` `(` $inclusive `)`
    attr-dict
    `ins` `(` $inputs `:` type($inputs) `)`
    `outs` `(` $outputs `:` type($outputs) `)`
    $region (`->` type($results)^)?
  }];

  let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
    Value input() {
      return getInputOperand(0)->get();
    }
    Value accumulator() {
      return getOutputOperand(1)->get();
    }
    Value output() {
      return getOutputOperand(0)->get();
    }
    ShapedType getOperandType() {
      return input().getType().cast<ShapedType>();
    }
    int64_t getOperandRank() {
      return getOperandType().getRank();
    }
  }];
}
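// For illustration only: a hypothetical use of the scan op, assembled from
// the assemblyFormat above. The shapes and the arith.addi payload are
// invented for this sketch, not taken from the commit.
//
//   %0:2 = tm_tensor.scan dimension(0) inclusive(true)
//       ins(%input : tensor<128xi32>)
//       outs(%out, %acc : tensor<128xi32>, tensor<i32>) {
//     ^bb0(%arg0: i32, %arg1: i32):
//       // Combine the running value with the next element.
//       %sum = arith.addi %arg0, %arg1 : i32
//       tm_tensor.yield %sum : i32
//   } -> tensor<128xi32>, tensor<i32>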

def TMTensor_ScatterOp : TMTensor_Op<"scatter",
    [DeclareOpInterfaceMethods<ScalarLoopOpInterface,
        ["generateScalarImplementation"]>]> {
  let summary = "Scatter operator";
  let description = [{
    Based on XLA operation semantics, takes two `inputs` (`updates` and
    `indices`) and one `outputs` value (`original`). The operation updates
    the values at the slices specified by `indices` by combining the
    current value with the value in `updates` using the computation
    specified in `region`. The `region` specifies a binary operation
    of signature (T, T) -> T, where `T` is the element type of
    `updates` (and `original`). The first argument corresponds to the
    value to be updated (i.e. from `updates`), and the second to the
    current value (i.e. the value from `original`).

    `indices` is a 2D tensor/memref type. The first dim is the number of
    updates, and the second dim is the index depth. The index depth should
    always be static.

    The first dims of `updates` and `indices` are identical, since they
    represent the number of updates.

    The rank of the `original`/`result` is `index_depth + rank(%updates) - 1`.
    The first `index_depth` indices are derived from `indices`, and the shape
    of an update value must match the remaining shape of `original`.

    The shape definitions follow the TensorFlow operation, except that this op
    forces batch dims to be 1D. See more information in
    https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
  }];
  let arguments = (ins
      Variadic<AnyRankedTensorOrMemRefType>:$inputs,
      Variadic<AnyRankedTensorOrMemRefType>:$outputs
  );
  let results = (outs Variadic<AnyRankedTensor>:$results);
  let regions = (region AnyRegion:$region);
  let assemblyFormat = [{
    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
    `outs` `(` $outputs `:` type($outputs) `)`
    $region (`->` type($results)^)?
  }];
  let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{

    int64_t getIndexDepth() {
      return getInputOperand(1)
          ->get()
          .getType()
          .cast<ShapedType>()
          .getShape()
          .back();
    }

    Value updates() {
      return getInputOperand(0)->get();
    }

    ShapedType getUpdateType() {
      return updates().getType().cast<ShapedType>();
    }

    Value indices() {
      return getInputOperand(1)->get();
    }

    ShapedType getIndicesType() {
      return indices().getType().cast<ShapedType>();
    }

    Value original() {
      return getOutputOperand(0)->get();
    }

    ShapedType getOriginalType() {
      return original().getType().cast<ShapedType>();
    }

    int64_t getUpdateSliceRank() {
      return updates().getType().cast<ShapedType>().getRank() - 1;
    }

    bool isScalarUpdate() {
      return getUpdateSliceRank() == 0;
    }
  }];
}
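// Again purely illustrative and hypothetical: a scatter that overwrites the
// slices of %original selected by %indices, following the
// (update, current) -> T region convention described above. All values and
// shapes here are invented.
//
//   %0 = tm_tensor.scatter
//       ins(%updates, %indices : tensor<4xi32>, tensor<4x1xi32>)
//       outs(%original : tensor<8xi32>) {
//     ^bb0(%update: i32, %current: i32):
//       // Keep the update value, ignoring the current one.
//       tm_tensor.yield %update : i32
//   } -> tensor<8xi32>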

//===----------------------------------------------------------------------===//
// Pure ops
//===----------------------------------------------------------------------===//

def TMTensor_YieldOp : TMTensor_PureOp<"yield", [NoSideEffect, ReturnLike, Terminator]> {
  let summary = "TMTensor yield op";
  let description = [{
    `tm_tensor.yield` is a special terminator operation for blocks inside
    regions in `tm_tensor` ops.
  }];

  let arguments = (ins Variadic<AnyType>:$operands);

  let builders = [
    OpBuilder<(ins), [{ /* nothing to do */ }]>,
  ];

  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

#endif // TORCH_MLIR_DIALECT_TMTENSOR_OPS
@@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl)
add_public_tablegen_target(TorchMLIRTMTensorTransformsPassesIncGen)
@@ -0,0 +1,26 @@
//===- PassDetail.h - TMTensor Pass class details --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace torch {
namespace TMTensor {

#define GEN_PASS_CLASSES
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" // IWYU pragma: keep

} // namespace TMTensor
} // namespace torch
} // namespace mlir

#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
@@ -0,0 +1,27 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace torch {
namespace TMTensor {

std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();

void registerPasses();

} // namespace TMTensor
} // namespace torch
} // namespace mlir

#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
@@ -0,0 +1,21 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCH_MLIR_DIALECT_TMTENSOR_PASSES
#define TORCH_MLIR_DIALECT_TMTENSOR_PASSES

include "mlir/Pass/PassBase.td"

def TMTensorToLoops :
    Pass<"torch-mlir-tm-tensor-to-loops", "FuncOp"> {
  let summary = "Convert TMTensor ops to loops and Linalg ops.";
  let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
}

#endif // TORCH_MLIR_DIALECT_TMTENSOR_PASSES
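As a rough sketch of the kind of IR this pass emits through
`generateScalarImplementation` — a hand-written approximation for a 1-D
inclusive scan with buffer semantics, assuming arguments %in and %out of type
memref<128xi32>; this is not verified pass output:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c128 = arith.constant 128 : index
    scf.for %i = %c0 to %c128 step %c1 {
      %first = arith.cmpi eq, %i, %c0 : index
      scf.if %first {
        // The first element passes through unchanged.
        %v = memref.load %in[%i] : memref<128xi32>
        memref.store %v, %out[%i] : memref<128xi32>
      } else {
        // Combine the previous partial result with the current element.
        %iprev = arith.subi %i, %c1 : index
        %prev = memref.load %out[%iprev] : memref<128xi32>
        %v = memref.load %in[%i] : memref<128xi32>
        %sum = arith.addi %prev, %v : i32
        memref.store %sum, %out[%i] : memref<128xi32>
      }
    }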
@@ -0,0 +1 @@
add_subdirectory(Dialect)
@@ -0,0 +1 @@
add_subdirectory(TMTensor)
@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
@@ -0,0 +1,29 @@
add_mlir_library(TorchMLIRTMTensorDialect
  TMTensorDialect.cpp
  TMTensorInterfaces.cpp
  TMTensorOps.cpp
  ScalarLoopOpInterface.cpp

  ADDITIONAL_HEADER_DIRS
  ${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include

  DEPENDS
  TorchMLIRTMTensorOpsIncGen

  LINK_LIBS PUBLIC
  MLIRAffine
  MLIRDialectUtils
  MLIRIR
  MLIRLinalg
  MLIRMath
  MLIRMemRef
  MLIRPass
  MLIRSideEffectInterfaces
  MLIRSupport
  MLIRSCF
  MLIRStandard
  MLIRTensor
  MLIRViewLikeInterface
)

torch_mlir_dialects_target_includes(TorchMLIRTMTensorDialect)
@@ -0,0 +1,24 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "torch-mlir-tiled-op-interface"

using namespace mlir;
using namespace mlir::torch::TMTensor;

#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp.inc"
@@ -0,0 +1,30 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"

#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/SourceMgr.h"

using namespace mlir;
using namespace mlir::torch::TMTensor;

void TMTensorDialect::initialize() {
#define GET_OP_LIST
  addOperations<
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc"
      >();
}

#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.cpp.inc"
@@ -0,0 +1,54 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TMTensor;

OpOperandVector::operator SmallVector<Value>() {
  SmallVector<Value> result;
  result.reserve(this->size());
  llvm::transform(*this, std::back_inserter(result),
                  [](OpOperand *opOperand) { return opOperand->get(); });
  return result;
}

LogicalResult
mlir::torch::TMTensor::detail::verifyTMTensorOpInterface(Operation *op) {
  TMTensorOp tmTensorOp = cast<TMTensorOp>(op);
  if (op->getNumResults()) {
    if (!tmTensorOp.hasTensorSemantics()) {
      return tmTensorOp.emitOpError(
          "expected inputs and outputs to be RankedTensorType or scalar");
    }

    if (op->getNumResults() != tmTensorOp.outputs().size()) {
      return tmTensorOp.emitOpError(
          "expected number of outputs to be same as the number of results");
    }
    for (auto en : llvm::enumerate(op->getResultTypes())) {
      Type outputType = tmTensorOp.outputs()[en.index()].getType();
      if (en.value() != outputType) {
        return tmTensorOp.emitOpError("expected type of `outs` operand #")
               << en.index() << " " << outputType
               << " to be same as result type " << en.value();
      }
    }
  } else {
    if (!tmTensorOp.hasBufferSemantics()) {
      return tmTensorOp.emitOpError(
          "expected inputs and outputs to be MemRefType or scalar");
    }
  }
  return success();
}

#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.cpp.inc" // IWYU pragma: export
@ -0,0 +1,483 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SMLoc.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TMTensor;

//===----------------------------------------------------------------------===//
// Utils.
//===----------------------------------------------------------------------===//

static void getEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) {
  for (Value value : results) {
    effects.emplace_back(MemoryEffects::Allocate::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (Value value : inputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (Value value : outputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), value,
                         SideEffects::DefaultResource::get());
  }
}

Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v,
                            int64_t dim) {
  return TypeSwitch<Type, Value>(v.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
        return builder.create<tensor::DimOp>(loc, v, dim);
      })
      .Case<MemRefType>([&](MemRefType t) -> Value {
        return builder.create<memref::DimOp>(loc, v, dim);
      })
      .Default([&](Type t) { return Value(); });
}

OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
                              int64_t dim) {
  auto t = v.getType().cast<ShapedType>();
  if (t.isDynamicDim(dim)) {
    return getDimValue(builder, loc, v, dim);
  }
  return builder.getI64IntegerAttr(t.getDimSize(dim));
}
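
// For example, querying dim 0 of a value of type tensor<4x?xf32> folds to the
// constant attribute 4, while dim 1 materializes a tensor.dim op.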

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyScanOp(ScanOp op) {
  if (op.getNumInputs() != 1) {
    return op.emitOpError("expected one input operand");
  }
  if (op.getNumOutputs() != 2) {
    return op.emitOpError("expected two output operands");
  }
  if (!op.input().getType().isa<ShapedType>()) {
    return op.emitOpError("expected first input to be of shaped type");
  }
  auto accumulatorType = op.accumulator().getType().cast<ShapedType>();
  auto inputType = op.input().getType().cast<ShapedType>();
  auto outputType = op.output().getType().cast<ShapedType>();
  ArrayRef<int64_t> inputShapes = inputType.getShape();
  ArrayRef<int64_t> outputShapes = outputType.getShape();
  if (accumulatorType.getElementType() != inputType.getElementType()) {
    return op.emitOpError(
        "expected input/accumulator element types to be identical");
  }
  ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
  int64_t accumulatorRank = accumulatorType.getRank();
  if (accumulatorRank != inputType.getRank() - 1) {
    return op.emitOpError(
        "expected accumulator rank to be equal to input rank - 1");
  }
  SmallVector<int64_t> expectedAccumulatorShape;
  for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
    if (i != op.dimension())
      expectedAccumulatorShape.push_back(inputShapes[i]);
  }
  if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != ShapedType::kDynamicSize &&
                            std::get<1>(s) != ShapedType::kDynamicSize &&
                            std::get<0>(s) != std::get<1>(s);
                   })) {
    return op.emitOpError("incompatible input/accumulator shapes");
  }
  if (inputType.getElementType() != outputType.getElementType()) {
    return op.emitOpError(
        "expected input/output element types to be identical");
  }
  if (inputShapes.size() != outputShapes.size()) {
    return op.emitOpError("expected input/output to have identical ranks");
  }
  if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != ShapedType::kDynamicSize &&
                            std::get<1>(s) != ShapedType::kDynamicSize &&
                            std::get<0>(s) != std::get<1>(s);
                   })) {
    return op.emitOpError("incompatible input/output shapes");
  }
  return success();
}

SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = input();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

SmallVector<StringRef> ScanOp::getLoopIteratorTypes() {
  SmallVector<StringRef> iteratorTypes(getOperandRank(),
                                       getParallelIteratorTypeName());
  iteratorTypes[dimension()] = getReductionIteratorTypeName();
  return iteratorTypes;
}

// Generates a naive scalar implementation of scan for a given operator f.
// For inclusive,
//     output[0] = input[0]
//     output[i] = f(output[i-1], input[i])
//
// For exclusive,
//     output[0] = 0
//     output[i] = f(output[i-1], input[i-1])
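//
// For example, an inclusive scan of [1, 2, 3] under addition produces
// [1, 3, 6], while the exclusive variant (with the accumulator initialized
// to 0) produces [0, 1, 3].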

LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  SmallVector<Value> indices, scanBlkArgs;
  indices.append(ivs.begin(), ivs.end());
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  uint64_t scanDim = dimension();
  Value cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                       indices[scanDim], zero);
  bool isInclusive = inclusive();
  SmallVector<Value> accIndices;
  for (size_t i = 0; i < indices.size(); i++) {
    if (i != scanDim)
      accIndices.push_back(indices[i]);
  }

  auto scfIf = b.create<scf::IfOp>(
      loc, TypeRange{}, cond,
      [&](OpBuilder &b, Location loc) {
        if (isInclusive) {
          auto value = b.create<memref::LoadOp>(loc, input(), indices);
          b.create<memref::StoreOp>(loc, value, output(), indices);
        } else {
          auto value = b.create<memref::LoadOp>(loc, accumulator(), accIndices);
          b.create<memref::StoreOp>(loc, value, output(), indices);
        }
        b.create<scf::YieldOp>(loc);
      },
      [&](OpBuilder &b, Location loc) {
        SmallVector<Value> indices(ivs.begin(), ivs.end());
        Value iv = indices[scanDim];
        Value ivMinusOne = b.create<arith::SubIOp>(loc, iv, one);
        indices[scanDim] = ivMinusOne;
        scanBlkArgs.push_back(b.create<memref::LoadOp>(loc, output(), indices));
        Value i0;
        if (!isInclusive)
          i0 = b.create<memref::LoadOp>(loc, input(), indices);
        indices[scanDim] = iv;
        if (isInclusive)
          i0 = b.create<memref::LoadOp>(loc, input(), indices);
        scanBlkArgs.push_back(i0);
      });

  auto &srcBlock = region().front();
  Region &region = scfIf.getElseRegion();
  BlockAndValueMapping bvm;
  {
    OpBuilder::InsertionGuard guard(b);
    auto &block = region.front();
    b.setInsertionPointToEnd(&block);
    for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
      bvm.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvm);
    }
    b.create<memref::StoreOp>(
        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
        output(), indices);
    b.create<memref::StoreOp>(
        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
        accumulator(), accIndices);
    b.create<scf::YieldOp>(loc);
  }
  return success();
}

static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

LogicalResult ScanOp::fold(ArrayRef<Attribute>,
                           SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
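//
// tm_tensor.scatter writes each row of `updates` into (or combines it with)
// the slice of `original` addressed by the matching row of `indices`. For
// example, with original = [0, 0, 0, 0], indices = [[3], [1]], and
// updates = [7, 9], a region that simply yields the update value produces
// [0, 9, 0, 7].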
static LogicalResult verifyScatterOp(ScatterOp op) {
  if (op.inputs().size() != 2) {
    return op.emitOpError("expected two input operands");
  }
  if (op.outputs().size() != 1) {
    return op.emitOpError("expected one output operand");
  }
  auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
    return t1.getShape()[dim] == t2.getShape()[dim];
  };

  auto indicesType = op.getIndicesType();
  if (indicesType.getRank() != 2 ||
      !indicesType.getElementType().isInteger(32)) {
    return op.emitOpError(
        "expected indices to be of rank 2 with i32 element type");
  }
  auto indexDepth = op.getIndexDepth();
  if (indexDepth == ShapedType::kDynamicSize) {
    return op.emitOpError("expected index depth to be static");
  }

  // The first dimension of the indices should match the first dimension of the
  // output. They indicate the number of updates.
  auto updateType = op.getUpdateType();
  if (updateType.getRank() < 1) {
    return op.emitOpError("expected update value to be at least rank 1");
  }
  if (!checkDimensionsMatch(indicesType, updateType, 0)) {
    return op.emitOpError(
        "mismatch in shape of indices and update value at dim#0");
  }
  auto originalType = op.getOriginalType();
  // indexDepth + update dims should match the original dims. The first dim of
  // update is the number of updates.
  if (originalType.getRank() != indexDepth + updateType.getRank() - 1) {
    return op.emitOpError(
        "mismatch in rank of update value, index depth and original value");
  }
  for (auto dim : llvm::seq<unsigned>(indexDepth, originalType.getRank())) {
    // Offset by one because the first dim is the number of updates.
    if (updateType.getDimSize(1 + dim - indexDepth) !=
        originalType.getDimSize(dim)) {
      return op.emitOpError("mismatch in shape of update value dim#")
             << (1 + dim - indexDepth) << " and original value at dim#" << dim;
    }
  }
  Region &region = op.region();
  Block *body = &region.front();
  if (body->getNumArguments() != 2) {
    return op.emitOpError("expected region to have two arguments");
  }
  Type arg0Type = body->getArgument(0).getType();
  Type arg1Type = body->getArgument(1).getType();
  if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
    return op.emitOpError(
        "expected region to have scalar arguments of integer or float types");
  }
  if (arg0Type != updateType.getElementType()) {
    return op.emitOpError("mismatch in argument 0 of region ")
           << arg0Type << " and element type of update value "
           << updateType.getElementType();
  }
  if (arg1Type != originalType.getElementType()) {
    return op.emitOpError("mismatch in argument 1 of region ")
           << arg1Type << " and element type of original value "
           << originalType.getElementType();
  }
  if (arg0Type != arg1Type) {
    return op.emitOpError("mismatch in region argument types ")
           << arg0Type << " and " << arg1Type;
  }
  auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator());
  if (yieldOp->getNumOperands() != 1) {
    return yieldOp.emitOpError("expected region to yield a single value");
  }
  auto yieldedType = yieldOp->getOperand(0).getType();
  if (yieldedType != arg0Type) {
    return yieldOp.emitOpError("mismatch in type of yielded value ")
           << yieldedType << " and argument of the region " << arg0Type;
  }
  return success();
}

SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
  SmallVector<StringRef> iteratorTypes(getUpdateType().getRank(),
                                       getParallelIteratorTypeName());
  return iteratorTypes;
}

SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  SmallVector<Range> ranges;
  for (auto dim : llvm::seq<int64_t>(0, getUpdateType().getRank())) {
    Value ub = getDimValue(builder, loc, updates(), dim);
    ranges.emplace_back(Range{zero, ub, one});
  }
  return ranges;
}

LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
                                                      Location loc,
                                                      ValueRange ivs) {
  auto indexDepth = getIndexDepth();
  Value update = b.create<memref::LoadOp>(loc, updates(), ivs);
  SmallVector<Value> starts;
  SmallVector<Value> loadIndices;
  loadIndices.push_back(ivs.front());
  loadIndices.push_back(Value());
  for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
    loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
    Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
    starts.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx));
  }
  starts.append(std::next(ivs.begin()), ivs.end());
  Value init = b.create<memref::LoadOp>(loc, original(), starts);

  BlockAndValueMapping bvm;
  Block &block = region().front();
  bvm.map(block.getArgument(0), update);
  bvm.map(block.getArgument(1), init);
  for (auto &blockOp : block.without_terminator()) {
    b.clone(blockOp, bvm);
  }
  // The last op is the tm_tensor.yield op. Store its operand to the
  // destination.
  b.create<memref::StoreOp>(
      loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
      original(), starts);
  return success();
}

#define DEFINE_OP_GET_EFFECTS(OP_NAME)                                        \
  void OP_NAME::getEffects(                                                   \
      SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>     \
          &effects) {                                                         \
    SmallVector<Value> inputBuffers = getInputBufferOperands();               \
    SmallVector<Value> outputBuffers = getOutputBufferOperands();             \
    getEffectsImpl(effects, getOperation()->getResults(), inputBuffers,       \
                   outputBuffers);                                            \
  }

DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(ScatterOp)

namespace {
/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
/// changes.
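///
/// For example, it lets the canonicalizer rewrite
///   %0 = tensor.cast %arg : tensor<128xi32> to tensor<?xi32>
///   %1 = tm_tensor.scan ... ins(%0 : tensor<?xi32>) ...
/// so that the scan consumes %arg with its static type directly, casting the
/// results back when needed (see the canonicalize test below).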
struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
  using OpInterfaceRewritePattern<TMTensorOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(TMTensorOp op,
                                PatternRewriter &rewriter) const override {
    // If no operand comes from a tensor::CastOp that can be folded, fail.
    bool hasTensorCastOperand =
        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
          if (opOperand->get().isa<BlockArgument>())
            return false;
          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
        });
    if (!hasTensorCastOperand)
      return failure();

    SmallVector<Type, 4> newResultTypes;
    newResultTypes.reserve(op->getNumResults());
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(op->getNumOperands());
    // Inputs may fold.
    for (OpOperand *opOperand : op.getInputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
                                ? tensorCastOp.source()
                                : opOperand->get());
    }
    // Init tensors may fold, in which case the resultType must also change.
    for (OpOperand *opOperand : op.getOutputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      bool fold = canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand()
                                 : opOperand->get());
      newResultTypes.push_back(newOperands.back().getType());
    }
    // Clone op.
    Operation *newOp =
        op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
      Value oldResult = std::get<0>(result);
      Value newResult = std::get<1>(result);
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// TMTensorDialect
//===----------------------------------------------------------------------===//

void TMTensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastOp>(getContext());
}

#define GET_OP_CLASSES
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc"
@ -0,0 +1,22 @@
add_mlir_library(TorchMLIRTMTensorPasses
  ConvertToLoops.cpp
  Passes.cpp

  DEPENDS
  TorchMLIRTMTensorTransformsPassesIncGen

  LINK_LIBS PUBLIC
  TorchMLIRTMTensorDialect
  MLIRAffine
  MLIRIR
  MLIRLinalg
  MLIRLinalgTransforms
  MLIRMath
  MLIRMemRef
  MLIRPass
  MLIRSCF
  MLIRStandard
  MLIRSupport
  MLIRTensor
  MLIRTransforms
)
@ -0,0 +1,117 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::torch::TMTensor;

/// Recursive method that lowers the `ScalarLoopOpInterface` op to scalar
/// loops, one dimension of the iteration domain at a time.
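///
/// For a 2-D iteration domain this produces a nest of the form
///   scf.for %i = %lb0 to %ub0 step %s0 {
///     scf.for %j = %lb1 to %ub1 step %s1 {
///       <scalar implementation at (%i, %j)>
///     }
///   }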
static LogicalResult lowerToLoopsImpl(OpBuilder &builder,
                                      ScalarLoopOpInterface scalarLoopOp,
                                      ArrayRef<Range> loopRanges,
                                      unsigned loopDepth,
                                      SmallVectorImpl<Value> &ivs) {
  Location loc = scalarLoopOp.getLoc();
  if (loopDepth == loopRanges.size()) {
    return scalarLoopOp.generateScalarImplementation(builder, loc, ivs);
  }
  LogicalResult status = success();
  builder.create<scf::ForOp>(
      loc, loopRanges[loopDepth].offset, loopRanges[loopDepth].size,
      loopRanges[loopDepth].stride, ValueRange{},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        ivs.push_back(iv);
        status =
            lowerToLoopsImpl(b, scalarLoopOp, loopRanges, loopDepth + 1, ivs);
        b.create<scf::YieldOp>(loc);
      });
  return status;
}

/// Main entry point for lowering a `ScalarLoopOpInterface` op to loops.
static LogicalResult lowerToLoops(OpBuilder &builder,
                                  ScalarLoopOpInterface scalarLoopOp) {
  SmallVector<Range> loopBounds = scalarLoopOp.getIterationDomain(builder);
  SmallVector<Value> ivs;
  return lowerToLoopsImpl(builder, scalarLoopOp, loopBounds, 0, ivs);
}

/// Pattern rewriter hook to lower a `ScalarLoopOpInterface` to loops.
namespace {
struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern {
  ScalarLoopOpInterfaceLowerToLoopsPattern(MLIRContext *context,
                                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto scalarLoopOp = dyn_cast<ScalarLoopOpInterface>(op);
    if (!scalarLoopOp) {
      return failure();
    }
    // Ops with tensor results cannot be lowered to loops; the op must have
    // buffer (memref) semantics.
    if (llvm::any_of(scalarLoopOp->getResults(),
                     [&](Value v) { return v.getType().isa<ShapedType>(); })) {
      return rewriter.notifyMatchFailure(
          scalarLoopOp, "lower to loops needs to have buffer semantics");
    }
    if (failed(lowerToLoops(rewriter, scalarLoopOp))) {
      return failure();
    }
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//

namespace {
struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect, StandardOpsDialect,
                    mlir::arith::ArithmeticDialect, math::MathDialect,
                    memref::MemRefDialect, scf::SCFDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();

    RewritePatternSet patterns(context);
    patterns.insert<ScalarLoopOpInterfaceLowerToLoopsPattern>(context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
torch::TMTensor::createTMTensorToLoopsPass() {
  return std::make_unique<TMTensorToLoopsPass>();
}
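
// The pass is exposed as -torch-mlir-tm-tensor-to-loops and can be exercised
// with, e.g.:
//   torch-mlir-dialects-opt -torch-mlir-tm-tensor-to-loops input.mlir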
@ -0,0 +1,33 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

namespace mlir {
namespace torch {
namespace TMTensor {

namespace detail {
#define GEN_PASS_REGISTRATION
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" // IWYU pragma: export
} // namespace detail

} // namespace TMTensor
} // namespace torch
} // namespace mlir

void torch::TMTensor::registerPasses() {
  torch::TMTensor::detail::registerPasses();
}
@ -0,0 +1,19 @@
configure_lit_site_cfg(
  ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
  ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
  MAIN_CONFIG
  ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
)

set(TORCH_MLIR_DIALECTS_TEST_DEPENDS
  FileCheck count not
  torch-mlir-dialects-opt
)

add_lit_testsuite(check-torch-mlir-dialects "Running the torch-mlir-dialects regression tests"
  ${CMAKE_CURRENT_BINARY_DIR}
  DEPENDS ${TORCH_MLIR_DIALECTS_TEST_DEPENDS}
)
set_target_properties(check-torch-mlir-dialects PROPERTIES FOLDER "Tests")

add_lit_testsuites(TORCH_MLIR_DIALECTS ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_DIALECTS_TEST_DEPENDS})
@ -0,0 +1,69 @@
# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import os
import platform
import re
import subprocess
import tempfile

import lit.formats
import lit.util

from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
from lit.llvm.subst import FindTool

# Configuration file for the 'lit' test runner.

# name: The name of this test suite.
config.name = 'TORCH_MLIR_DIALECTS'

config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)

# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir', '.py']

config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
config.substitutions.append(
    ('%resources_dir', os.path.join(config.torch_mlir_dialects_obj_root,
                                    'resources')))

llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])

#llvm_config.use_default_substitutions()

# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = [
    'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
    'lit.cfg.py', 'lit.site.cfg.py'
]

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)

# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.torch_mlir_dialects_obj_root, 'test')
config.standalone_tools_dir = os.path.join(config.torch_mlir_dialects_obj_root, 'bin')

# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)

tool_dirs = [config.llvm_tools_dir]
tools = [
    ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
]

llvm_config.add_tool_substitutions(tools, tool_dirs)
@ -0,0 +1,26 @@
# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

@LIT_SITE_CFG_IN_HEADER@

import sys

config.torch_mlir_dialects_obj_root = "@TORCH_MLIR_DIALECTS_BINARY_DIR@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.python_executable = sys.executable

import lit.llvm
lit.llvm.initialize(lit_config, config)

# Let the main config do the real work.
lit_config.load_config(config, "@TORCH_MLIR_DIALECTS_SOURCE_DIR@/test/lit.cfg.py")
@ -0,0 +1,24 @@
// RUN: torch-mlir-dialects-opt -canonicalize -split-input-file %s | FileCheck %s

// CHECK-LABEL: func @tensor.cast(
func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> {
  %init = linalg.init_tensor [128] : tensor<128xi32>
  %c0 = linalg.init_tensor [] : tensor<i32>

  %casted_arg0 = tensor.cast %arg0 : tensor<128xi32> to tensor<?xi32>
  %casted_init = tensor.cast %init : tensor<128xi32> to tensor<?xi32>
  // CHECK:      tm_tensor.scan
  // CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<128xi32>)
  // CHECK-SAME:   outs(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<128xi32>, tensor<i32>)
  %0, %1 = tm_tensor.scan dimension(0) inclusive(true)
             ins(%casted_arg0 : tensor<?xi32>)
             outs(%casted_init, %c0 : tensor<?xi32>, tensor<i32>) {
    ^bb0(%barg0 : i32, %barg1 : i32, %barg2 : i32):
      %sum = arith.addi %barg0, %barg1 : i32
      tm_tensor.yield %sum : i32
  } -> tensor<?xi32>, tensor<i32>

  %2 = tensor.cast %0 : tensor<?xi32> to tensor<128xi32>

  return %2 : tensor<128xi32>
}
@ -0,0 +1,332 @@
// RUN: torch-mlir-dialects-opt -split-input-file -torch-mlir-tm-tensor-to-loops %s | FileCheck %s

func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
  %c0 = memref.alloc() : memref<i32>
  tm_tensor.scan dimension(0) inclusive(true)
    ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
    ^bb0(%arg0 : i32, %arg1 : i32):
      %sum = arith.addi %arg0, %arg1 : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func @scan_1d_inclusive
// CHECK-SAME:    %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK:         scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK:           %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK:           scf.if %[[COND]] {
// CHECK:             %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK:             memref.store %[[V1]], %[[BUFO]][%[[ARG1]]]
// CHECK:           } else {
// CHECK:             %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK:             %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK:             %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK:             %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK:             memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK:             memref.store %[[V4]], %[[ACC]][]
// CHECK:           }

// -----

func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
  %c0 = memref.alloc() : memref<i32>
  tm_tensor.scan dimension(0) inclusive(false)
    ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
    ^bb0(%arg0 : i32, %arg1 : i32):
      %sum = arith.addi %arg0, %arg1 : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func @scan_1d_exclusive
// CHECK-SAME:    %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK:         scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK:           %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK:           scf.if %[[COND]] {
// CHECK:             %[[V0:.+]] = memref.load %[[ACC]][] : memref<i32>
// CHECK:             memref.store %[[V0]], %[[BUFO]][%[[ARG1]]]
// CHECK:           } else {
// CHECK:             %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK:             %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK:             %[[V3:.+]] = memref.load %[[BUFI]][%[[T1]]]
// CHECK:             %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK:             memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK:             memref.store %[[V4]], %[[ACC]][]
// CHECK:           }

// -----

func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
  %t0 = memref.alloc() : memref<32xi32>
  tm_tensor.scan dimension(0) inclusive(true)
    ins(%0 : memref<16x32xi32>) outs(%1, %t0 : memref<16x32xi32>, memref<32xi32>) {
    ^bb0(%arg0 : i32, %arg1 : i32):
      %sum = arith.addi %arg0, %arg1 : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func @scan_2d
// CHECK-SAME:    %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG:     %[[C32:.+]] = arith.constant 32 : index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[ACC:.+]] = memref.alloc() : memref<32xi32>
// CHECK:         scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]]
// CHECK:           scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]]
// CHECK:             %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK:             scf.if %[[COND]] {
// CHECK:               %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK:               memref.store %[[V1]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK:             } else {
// CHECK:               %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK:               %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]], %[[ARG2]]]
// CHECK:               %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK:               %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK:               memref.store %[[V4]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK:               memref.store %[[V4]], %[[ACC]][%[[ARG2]]]
// CHECK:             }

// -----

func @scatter_update_scalar_1D(
    %original: memref<8xi32>, %indices: memref<3x1xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
    outs(%original : memref<8xi32>) {
    ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
      tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_update_scalar_1D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK:           %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK:           %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]

// -----

func @scatter_add_scalar_2D(
    %original: memref<4x3xi32>, %indices: memref<3x2xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
    outs(%original : memref<4x3xi32>) {
    ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
      %0 = arith.addi %arg1, %arg0 : i32
      tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_add_scalar_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK:           %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x2xi32>
// CHECK:           %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<3x2xi32>
// CHECK:           %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK:           %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<4x3xi32>
// CHECK:           %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK:           memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]

// -----

func @scatter_update_slice_2D(
    %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
    %updates: memref<2x3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
    outs(%original : memref<4x3xi32>) {
    ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
      tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK:       func @scatter_update_slice_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:             %[[UPDATE:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:             %[[INDEX:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:             %[[LOC:.+]] = arith.index_cast %[[INDEX]] : i32 to index
// CHECK:             memref.store %[[UPDATE]], %[[ORIGINAL]][%[[LOC]], %[[J]]]
// CHECK:           }
// CHECK:         }

// -----

func @scatter_add_scalar_1D(
    %original: memref<8xi32>, %indices: memref<3x1xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
    outs(%original : memref<8xi32>) {
    ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
      %0 = arith.addi %arg1, %arg0 : i32
      tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_add_scalar_1D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK:           %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK:           %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX]]] : memref<8xi32>
// CHECK:           %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK:           memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX]]]

// -----

func @scatter_add_slice_2D(
    %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
    %updates: memref<2x3xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
    outs(%original : memref<4x3xi32>) {
    ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
      %0 = arith.addi %arg1, %arg0 : i32
      tm_tensor.yield %0 : i32
  }
  return
}
// CHECK:       func @scatter_add_slice_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:             %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:             %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:             %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK:             %[[ORIGINALVAL:.+]] = memref.load %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// CHECK:             %[[STOREVAL:.+]] = arith.addi %[[ORIGINALVAL]], %[[UPDATEVAL]]
// CHECK:             memref.store %[[STOREVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]

// -----

func @scatter_update_scalar_dynamic_1D(
    %original: memref<?xi32>, %indices: memref<?x1xi32>,
    %updates: memref<?xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
    outs(%original : memref<?xi32>) {
    ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
      tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_update_scalar_dynamic_1D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK:           %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x1xi32>
// CHECK:           %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]

// -----

func @scatter_add_scalar_dynamic_2D(
    %original: memref<?x?xi32>, %indices: memref<?x2xi32>,
    %updates: memref<?xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
    outs(%original : memref<?x?xi32>) {
    ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
      %0 = arith.addi %arg1, %arg0 : i32
      tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func @scatter_add_scalar_dynamic_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK:           %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x2xi32>
// CHECK:           %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<?x2xi32>
// CHECK:           %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK:           %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<?x?xi32>
// CHECK:           %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK:           memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]

// -----

func @scatter_update_slice_dynamic_2D(
    %original: memref<?x?xi32>, %indices: memref<?x1xi32>,
    %updates: memref<?x?xi32>) {
  tm_tensor.scatter
    ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
    outs(%original : memref<?x?xi32>) {
    ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
      tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK:       func @scatter_update_slice_dynamic_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[UB1:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?x?xi32>
// CHECK-DAG:     %[[UB2:.+]] = memref.dim %[[UPDATES]], %[[C1]] : memref<?x?xi32>
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[UB1]] step %[[C1]] {
// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[UB2]] step %[[C1]] {
// CHECK:             %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:             %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:             %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK:             memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
@ -0,0 +1 @@
add_subdirectory(torch-mlir-dialects-opt)
@ -0,0 +1,22 @@
set(LIBS
  MLIRArithmetic
  MLIRDialect
  MLIRLinalg
  MLIRMemRef
  MLIROptLib
  MLIRSCF
  MLIRSCFTransforms
  MLIRStandard
  MLIRTensor
  MLIRTransforms
  TorchMLIRTMTensorDialect
  TorchMLIRTMTensorPasses
)

add_llvm_tool(torch-mlir-dialects-opt
  torch-mlir-dialects-opt.cpp

  DEPENDS
  ${LIBS}
)
target_link_libraries(torch-mlir-dialects-opt PRIVATE ${LIBS})
@ -0,0 +1,49 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"

using namespace mlir;

int main(int argc, char **argv) {
  registerAsmPrinterCLOptions();
  registerMLIRContextCLOptions();

  registerTransformsPasses();
  registerSCFPasses();

  // Local passes.
  mlir::torch::TMTensor::registerPasses();

  DialectRegistry registry;
  registry.insert<
      // Local dialects
      mlir::torch::TMTensor::TMTensorDialect,
      // Upstream dialects
      mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
      mlir::memref::MemRefDialect, mlir::StandardOpsDialect,
      mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
                        /*preloadDialectsInContext=*/false));
}
@ -19,6 +19,9 @@ add_mlir_library(TorchMLIRInitAll
  TorchMLIRTorchDialect
  TorchMLIRTorchConversionPasses

  TorchMLIRTMTensorPasses
  TorchMLIRTMTensorDialect

  TorchMLIRConversionPasses
  TorchMLIRRefBackend
)
@ -10,6 +10,8 @@
#include "torch-mlir/InitAll.h"

#include "mlir/IR/Dialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
#include "torch-mlir/Conversion/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -20,6 +22,7 @@
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
  registry.insert<mlir::torch::Torch::TorchDialect>();
  registry.insert<mlir::torch::TorchConversion::TorchConversionDialect>();
  registry.insert<mlir::torch::TMTensor::TMTensorDialect>();
}

void mlir::torch::registerAllPasses() {
@ -28,4 +31,5 @@ void mlir::torch::registerAllPasses() {

  mlir::torch::registerConversionPasses();
  mlir::torch::RefBackend::registerRefBackendPasses();
  mlir::torch::TMTensor::registerPasses();
}
setup.py
@ -62,8 +62,9 @@ class CMakeBuild(build_py):
f"-DLLVM_TARGETS_TO_BUILD=host",
f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON",
f"-DLLVM_ENABLE_PROJECTS=mlir",
f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir",
f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir;torch-mlir-dialects",
f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}",
f"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR={src_dir}/external/llvm-external-projects/torch-mlir-dialects",
# Optimization options for building wheels.
f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON",
f"-DCMAKE_C_VISIBILITY_PRESET=hidden",