Add TMTensor dialect to torch-mlir

This is intended to explore support for non-structured ops that can't
be modeled by the Linalg dialect. `tm_tensor.scan` and `tm_tensor.scatter`
are added as the first such ops. The dialect should aim to be
upstreamed in the future.
pull/603/head
Yi Zhang 2022-02-02 18:01:38 -05:00
parent cd21dda867
commit 869daf3c22
45 changed files with 2532 additions and 4 deletions


@@ -53,8 +53,9 @@ jobs:
           -DPython3_EXECUTABLE=$(which python) \
           -DLLVM_ENABLE_ASSERTIONS=ON \
           -DLLVM_ENABLE_PROJECTS=mlir \
-          -DLLVM_EXTERNAL_PROJECTS=torch-mlir \
+          -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
           -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \
+          -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/external/llvm-external-projects/torch-mlir-dialects" \
           -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
           -DLLVM_TARGETS_TO_BUILD=host
         ninja check-torch-mlir-all


@@ -25,6 +25,22 @@ project(torch-mlir LANGUAGES CXX C)
 set(CMAKE_C_STANDARD 11)
 set(CMAKE_CXX_STANDARD 14)
+
+macro(torch_mlir_add_llvm_external_project name identifier location)
+  message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}")
+  if(NOT EXISTS "${location}/CMakeLists.txt")
+    message(FATAL_ERROR "External project location ${location} is not valid")
+  endif()
+  list(APPEND LLVM_EXTERNAL_PROJECTS ${name})
+  list(REMOVE_DUPLICATES LLVM_EXTERNAL_PROJECTS)
+  set(LLVM_EXTERNAL_${identifier}_SOURCE_DIR ${location} CACHE STRING "" FORCE)
+  set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
+endmacro()
+
+torch_mlir_add_llvm_external_project(
+  torch-mlir-dialects
+  TORCH_MLIR_DIALECTS
+  ${CMAKE_CURRENT_SOURCE_DIR}/external/llvm-external-projects/torch-mlir-dialects)
+
 if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
   # Out-of-tree build
@@ -129,6 +145,7 @@ add_subdirectory(tools)
 add_custom_target(check-torch-mlir-all)
 add_dependencies(check-torch-mlir-all
   check-torch-mlir
+  check-torch-mlir-dialects
 )
 if(MLIR_ENABLE_BINDINGS_PYTHON)


@@ -62,8 +62,9 @@ cmake -GNinja -Bbuild \
     -DCMAKE_CXX_COMPILER=clang++ \
     -DPython3_FIND_VIRTUALENV=ONLY \
     -DLLVM_ENABLE_PROJECTS=mlir \
-    -DLLVM_EXTERNAL_PROJECTS=torch-mlir \
+    -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
     -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
+    -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="`pwd`/external/llvm-external-projects/torch-mlir-dialects" \
     -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
     -DLLVM_TARGETS_TO_BUILD=host \
     external/llvm-project/llvm


@@ -19,8 +19,9 @@ cmake -GNinja -B"$build_dir" "$llvm_project_dir/llvm" \
   -DCMAKE_BUILD_TYPE=Release \
   -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
   -DLLVM_ENABLE_PROJECTS=mlir \
-  -DLLVM_EXTERNAL_PROJECTS=torch-mlir \
+  -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
   -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$project_dir" \
+  -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${project_dir}/external/llvm-external-projects/torch-mlir-dialects" \
   -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
   -DLLVM_ENABLE_ASSERTIONS=ON \
   -DLLVM_TARGETS_TO_BUILD=host


@@ -0,0 +1,54 @@
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
message(FATAL_ERROR
"This project is intended to be built as part of LLVM via "
"-DLLVM_EXTERNAL_PROJECTS=torch-mlir-dialects "
"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR}")
endif()
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
set(TORCH_MLIR_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TORCH_MLIR_DIALECTS_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS "Building torch-mlir-dialects project at ${TORCH_MLIR_DIALECTS_SOURCE_DIR} (into ${TORCH_MLIR_DIALECTS_BINARY_DIR})")
# TODO: Fix this upstream so that global include directories are not needed.
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
# TODO: Needed for tablegen. Remove.
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR})
include_directories(SYSTEM ${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include)
function(torch_mlir_dialects_target_includes target)
set(_dirs
$<BUILD_INTERFACE:${MLIR_INCLUDE_DIR}>
$<BUILD_INTERFACE:${MLIR_GENERATED_INCLUDE_DIR}>
$<BUILD_INTERFACE:${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${TORCH_MLIR_DIALECTS_BINARY_DIR}/include>
)
# In LLVM parlance, the actual target may just be an interface and may not
# be responsible for actually compiling anything. The corresponding obj.
# target, when present, is just used for compilation and does not
# contribute to the interface properties.
# TODO: Normalize this upstream.
target_include_directories(${target} PUBLIC ${_dirs})
if(TARGET obj.${target})
target_include_directories(obj.${target} PRIVATE ${_dirs})
endif()
endfunction()
# Configure CMake and tablegen.
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
set(MLIR_TABLEGEN_EXE mlir-tblgen)
include(TableGen)
include(AddLLVM)
include(AddMLIR)
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(tools)
add_subdirectory(test)


@@ -0,0 +1 @@
add_subdirectory(torch-mlir-dialects)


@@ -0,0 +1 @@
add_subdirectory(Dialect)


@@ -0,0 +1 @@
add_subdirectory(TMTensor)


@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)


@@ -0,0 +1,33 @@
function(_add_interfaces)
set(LLVM_TARGET_DEFINITIONS TMTensorInterfaces.td)
mlir_tablegen(TMTensorOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(TMTensorOpInterfaces.cpp.inc -gen-op-interface-defs)
mlir_tablegen(TMTensorTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TMTensorTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(TorchMLIRTMTensorInterfacesIncGen)
add_dependencies(TorchMLIRTMTensorOpsIncGen TorchMLIRTMTensorInterfacesIncGen)
endfunction()
function(_add_scalar_loop_op_interface)
set(LLVM_TARGET_DEFINITIONS ScalarLoopOpInterface.td)
mlir_tablegen(ScalarLoopOpInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(ScalarLoopOpInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(TorchMLIRTMTensorScalarLoopOpInterfaceIncGen)
add_dependencies(TorchMLIRTMTensorOpsIncGen TorchMLIRTMTensorScalarLoopOpInterfaceIncGen)
endfunction()
function(_add_dialect)
set(LLVM_TARGET_DEFINITIONS TMTensorOps.td)
mlir_tablegen(TMTensorOps.h.inc -gen-op-decls)
mlir_tablegen(TMTensorOps.cpp.inc -gen-op-defs)
mlir_tablegen(TMTensorTypes.h.inc -gen-typedef-decls)
mlir_tablegen(TMTensorTypes.cpp.inc -gen-typedef-defs)
mlir_tablegen(TMTensorDialect.h.inc -gen-dialect-decls -dialect=tm_tensor)
mlir_tablegen(TMTensorDialect.cpp.inc -gen-dialect-defs -dialect=tm_tensor)
add_public_tablegen_target(TorchMLIRTMTensorOpsIncGen)
add_dependencies(mlir-headers TorchMLIRTMTensorOpsIncGen)
endfunction()
_add_dialect()
_add_interfaces()
_add_scalar_loop_op_interface()


@@ -0,0 +1,29 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
/// Include the ODS generated interface header files.
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc"
namespace mlir {
namespace torch {
namespace TMTensor {} // namespace TMTensor
} // namespace torch
} // namespace mlir
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_


@@ -0,0 +1,79 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_SCALARLOOPOPINTERFACE
#define TORCH_MLIR_DIALECT_TMTENSOR_SCALARLOOPOPINTERFACE
include "mlir/IR/OpBase.td"
def ScalarLoopOpInterface : OpInterface<"ScalarLoopOpInterface"> {
let description = [{
Interface that allows operations to expose the information needed to
lower them to loops.
}];
let cppNamespace = "::mlir::torch::TMTensor";
let methods = [
InterfaceMethod<
/*desc=*/[{
Returns the destination operands. For ops with `memref`
operands, these are the result buffers. For ops with `tensor`
operands, these are the operands that contain the initial
values for the results and are "tied" to the result
buffers. For example, for `LinalgOp`/`TMTensor` ops these are
the `outs` parameters; for `tensor.insert_slice`, it is
the `dest` parameter.
}],
/*retType=*/"SmallVector<Value>",
/*methodName=*/"getDestinationOperands",
/*args=*/(ins "OpBuilder &":$b),
/*methodBody=*/"",
/*defaultImplementation=*/"return ValueRange{};"
>,
InterfaceMethod<
/*desc=*/[{
Returns a list of `StringRef`s that describe the number of
loops and the iterator types of the operation. The list is
expected to use
`getParallelIteratorTypeName()`/`getReductionIteratorTypeName()`
from MLIR Structured Op Utils.
}],
/*retType=*/"SmallVector<StringRef>",
/*methodName=*/"getLoopIteratorTypes"
>,
InterfaceMethod<
/*desc=*/[{
Returns a list of ranges that describe the loop bounds and
step for the loops of the operation.
}],
/*retTy=*/"SmallVector<Range>",
/*methodName=*/"getIterationDomain",
/*args=*/(ins "OpBuilder &":$b)
>,
InterfaceMethod<
/*desc=*/[{
Generates the loop body implementation. This assumes that all the
parallel and reduction loops have been created and that the insertion
point of the builder is set to the innermost loop. This method emits
the IR for the loop body.
}],
/*retType=*/"LogicalResult",
/*methodName=*/"generateScalarImplementation",
/*args=*/(ins
"OpBuilder &":$b,
"Location ":$loc,
"ValueRange ":$ivs),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
}]
>
];
}
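// A minimal sketch of how a lowering is expected to drive this interface
// (illustrative only; the loop-materialization details are left to the
// consumer, and `scalarLoopOp`, `b`, `loc`, and `ivs` are hypothetical names):
//
//   SmallVector<Range> domain = scalarLoopOp.getIterationDomain(b);
//   // Create one loop per `Range` (e.g. an scf.for nest), collecting the
//   // induction variables in `ivs`, then emit the body at the innermost
//   // insertion point:
//   if (failed(scalarLoopOp.generateScalarImplementation(b, loc, ivs)))
//     return failure();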
#endif // TORCH_MLIR_DIALECT_TMTENSOR_SCALARLOOPOPINTERFACE


@@ -0,0 +1,59 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_BASE
#define TORCH_MLIR_DIALECT_TMTENSOR_BASE
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// Dialect definition
//===----------------------------------------------------------------------===//
def TMTensor_Dialect : Dialect {
let name = "tm_tensor";
let cppNamespace = "::mlir::torch::TMTensor";
let description = [{
The tm_tensor (tm = torch-mlir) dialect is a temporary staging ground in
the torch-mlir project for a set of widely-accepted tensor compute
operations that are not well-served by existing representations in MLIR
upstream. These ops are currently heavily inspired by the linalg_ext
dialect (which itself is heavily inspired by the structured ops of the
linalg dialect). But while linalg_ext is meant to power specific codegen
transformations, the tm_tensor dialect is a much more pure "interface
dialect" agnostic to any particular set of transformations applied to
the operations. We simply require a way to name the specified operations
for interchange between projects, without taking strong opinions on the
mechanics of transformations.

The dialect does include interfaces to generate scalar reference code for
the operations, which simultaneously provides a precise definition of their
semantics, and aids in producing executable reference implementations of
the operations.

The goal of this dialect is to eventually either be upstreamed or to be
subsumed by functionality included by upstream MLIR. It should also be kept
consistent with the linalg_ext dialect unless there is a good reason not
to.
}];
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// Type definitions
//===----------------------------------------------------------------------===//
class RankedTensorOrMemRefOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes,
Or<[IsMemRefTypePred, And<[IsTensorTypePred, HasRankPred]>]>,
"ranked tensor or memref", "::mlir::ShapedType">;
def AnyRankedTensorOrMemRefType : RankedTensorOrMemRefOf<[AnyType]>;
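// For example (illustrative): `tensor<4x?xf32>` and `memref<4x?xf32>` satisfy
// this constraint, while the unranked `tensor<*xf32>` does not.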
#endif // TORCH_MLIR_DIALECT_TMTENSOR_BASE


@@ -0,0 +1,20 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
// clang-format off: must be included after all LLVM/MLIR headers
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h.inc" // IWYU pragma: keep
// clang-format on
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORDIALECT_H_


@@ -0,0 +1,42 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
namespace torch {
namespace TMTensor {
class TMTensorOp;
/// OpOperand vector that implicitly converts to a Value vector.
struct OpOperandVector : public SmallVector<OpOperand *> {
operator SmallVector<Value>();
};
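// This allows, for example (illustrative; `someTMTensorOp` is a hypothetical
// op handle), the implicit conversion:
//   SmallVector<Value> inputValues = someTMTensorOp.getInputOperands();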
namespace detail {
LogicalResult verifyTMTensorOpInterface(Operation *op);
} // namespace detail
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
/// Include the generated interface declarations.
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export
} // namespace TMTensor
} // namespace torch
} // namespace mlir
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_


@@ -0,0 +1,493 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_INTERFACES
#define TORCH_MLIR_DIALECT_TMTENSOR_INTERFACES
include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td"
// The interface is a subset of LinalgStructuredInterface.
def TMTensorInterface : OpInterface<"TMTensorOp"> {
let methods = [
//===------------------------------------------------------------------===//
// Num input/output arguments handling.
//===------------------------------------------------------------------===//
// `inputs` must be defined by each op that wants to implement the
// TMTensorInterface.
InterfaceMethod<
/*desc=*/[{
Return the input operands.
}],
/*retTy=*/"ValueRange",
/*methodName=*/"inputs",
/*args=*/(ins)
>,
// These special methods rely on `inputs` and `outputs` being defined by
// each op that wants to implement the TMTensorInterface.
InterfaceMethod<
/*desc=*/[{
Return the number of inputs.
}],
/*retTy=*/"int64_t",
/*methodName=*/"getNumInputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.inputs().size();
}]
>,
// `outputs` must be defined by each op that wants to implement the
// TMTensorInterface.
InterfaceMethod<
/*desc=*/[{
Return the output operands.
}],
/*retTy=*/"ValueRange",
/*methodName=*/"outputs",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
Return the number of outputs.
}],
/*retTy=*/"int64_t",
/*methodName=*/"getNumOutputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.outputs().size();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the number of inputs and outputs.
}],
/*retTy=*/"int64_t",
/*methodName=*/"getNumInputsAndOutputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumInputs() + getNumOutputs();
}]
>,
//===------------------------------------------------------------------===//
// Input operands handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
Return the input operands.
}],
/*retTy=*/"OpOperandVector",
/*methodName=*/"getInputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
int64_t numInputs = getNumInputs();
OpOperandVector result;
result.reserve(numInputs);
llvm::transform(
this->getOperation()->getOpOperands().take_front(numInputs),
std::back_inserter(result),
[](OpOperand &opOperand) { return &opOperand; });
return result;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the `i`-th input operand.
}],
/*retTy=*/"OpOperand*",
/*methodName=*/"getInputOperand",
/*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i >= 0 && i < getNumInputs());
return &this->getOperation()->getOpOperand(i);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the subset of input operands that are of buffer type.
}],
/*retTy=*/"OpOperandVector",
/*methodName=*/"getInputBufferOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
OpOperandVector result;
result.reserve(getNumInputs());
llvm::copy_if(getInputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
});
return result;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the subset of input operands that are of tensor type.
}],
/*retTy=*/"OpOperandVector",
/*methodName=*/"getInputTensorOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
OpOperandVector result;
result.reserve(getNumInputs());
llvm::copy_if(getInputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
});
return result;
}]
>,
//===------------------------------------------------------------------===//
// Output operands handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
Return the output operands.
}],
/*retTy=*/"OpOperandVector",
/*methodName=*/"getOutputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
int64_t numOutputs = getNumOutputs();
OpOperandVector result;
result.reserve(numOutputs);
llvm::transform(
this->getOperation()->getOpOperands()
.drop_front(getNumInputs())
.take_front(numOutputs),
std::back_inserter(result),
[](OpOperand &opOperand) { return &opOperand; });
return result;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the `i`-th output operand.
}],
/*retTy=*/"OpOperand*",
/*methodName=*/"getOutputOperand",
/*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i >= 0 && i < getNumOutputs());
return &this->getOperation()->getOpOperand(getNumInputs() + i);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the subset of output operands that are of buffer type.
}],
/*retTy=*/"OpOperandVector",
/*methodName=*/"getOutputBufferOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
OpOperandVector result;
result.reserve(getNumOutputs());
llvm::copy_if(getOutputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
});
return result;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the subset of output operands that are of tensor type.
}],
/*retTy=*/"OpOperandVector",
/*methodName=*/"getOutputTensorOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
OpOperandVector result;
result.reserve(getNumOutputs());
llvm::copy_if(getOutputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
});
return result;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the types of the subset of output operands that are of buffer type.
}],
/*retTy=*/"SmallVector<MemRefType>",
/*methodName=*/"getOutputBufferTypes",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<MemRefType> result;
result.reserve(getNumOutputs());
llvm::transform(getOutputBufferOperands(),
std::back_inserter(result),
[](OpOperand *opOperands) {
return opOperands->get().getType().cast<MemRefType>();
});
return result;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the types of the subset of output operands that are of tensor type.
}],
/*retTy=*/"SmallVector<RankedTensorType>",
/*methodName=*/"getOutputTensorTypes",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<RankedTensorType> result;
result.reserve(getNumOutputs());
llvm::transform(getOutputTensorOperands(),
std::back_inserter(result),
[](OpOperand *opOperands) {
return opOperands->get().getType().cast<RankedTensorType>();
});
return result;
}]
>,
//===------------------------------------------------------------------===//
// Input and Output arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
Return the range over input and output operands.
}],
/*retTy=*/"OpOperandVector",
/*methodName=*/"getInputAndOutputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
int64_t numInputsAndOutputs = getNumInputsAndOutputs();
OpOperandVector result;
result.reserve(numInputsAndOutputs);
llvm::transform(
this->getOperation()->getOpOperands()
.take_front(numInputsAndOutputs),
std::back_inserter(result),
[](OpOperand &opOperand) { return &opOperand; });
return result;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return true if the payload uses the value loaded from `opOperand`. This
is useful to avoid loading from "write-only" memory that may be
uninitialized, as well as properly cloning "read-write" operands.
}],
/*retTy=*/"bool",
/*methodName=*/"payloadUsesValueFromOperand",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
unsigned bbArgNumber = opOperand->getOperandNumber();
// Safeguard against the named linalg ops that are manually defined and
// that only support buffer semantics: we should not be there.
// Such ops have an empty regionBuilder and are not constructed with a
// region for now. In the future they are slated to disappear.
assert(this->getOperation()->getNumRegions() == 1 && "unexpected "
"missing region (calling `payloadUsesValueFromOperand` on "
"manually defined named Linalg op?)");
Block &block = this->getOperation()->getRegion(0).front();
// Init tensors have uses.
return !block.getArgument(bbArgNumber).use_empty();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return true if `opOperand` is an input tensor.
}],
/*retTy=*/"bool",
/*methodName=*/"isInputTensor",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (!opOperand->get().getType().template isa<RankedTensorType>())
return false;
if (opOperand->getOperandNumber() < $_op.getNumInputs())
return true;
return false;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return true if `opOperand` is an output tensor.
}],
/*retTy=*/"bool",
/*methodName=*/"isOutputTensor",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (!opOperand->get().getType().template isa<RankedTensorType>())
return false;
if (opOperand->getOperandNumber() >= $_op.getNumInputs())
return true;
return false;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return true if `opOperand` is an init tensor. This is true when it is
an output tensor operand whose value is used in the payload region.
}],
/*retTy=*/"bool",
/*methodName=*/"isInitTensor",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (!$_op.isOutputTensor(opOperand))
return false;
return payloadUsesValueFromOperand(opOperand);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the `opOperand` rank or zero for scalars.
}],
/*retTy=*/"int64_t",
/*methodName=*/"getRank",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType =
opOperand->get().getType().template dyn_cast<ShapedType>())
return shapedType.getRank();
return 0;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the `opOperand` shape or an empty vector for scalars.
}],
/*retTy=*/"ArrayRef<int64_t>",
/*methodName=*/"getShape",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType =
opOperand->get().getType().template dyn_cast<ShapedType>())
return shapedType.getShape();
return {};
}]
>,
InterfaceMethod<
/*desc=*/[{
Return true if the `opOperand` is a scalar value.
}],
/*retTy=*/"bool",
/*methodName=*/"isScalar",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
return !opOperand->get().getType().template isa<ShapedType>();
}]
>,
//===------------------------------------------------------------------===//
// Other interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
Return whether the op has only MemRef input and outputs.
}],
/*retTy=*/"bool",
/*methodName=*/"hasBufferSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return this->getOperation()->getNumResults() == 0 &&
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) ||
opOperand->get().getType().template isa<MemRefType>();
}) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
});
}]
>,
InterfaceMethod<
/*desc=*/[{
Return whether the op has only RankedTensor input and outputs.
}],
/*retTy=*/"bool",
/*methodName=*/"hasTensorSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) ||
opOperand->get().getType().template isa<RankedTensorType>();
}) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
});
}]
>,
//===------------------------------------------------------------------===//
// Other static interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
Clone the current operation with the given location and operands. This
is used to abstract away the optional underlying region creation. This
does not change the balance between input, output_buffer and
init_tensors operands.
}],
/*retTy=*/"Operation *",
/*methodName=*/"clone",
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands),
[{
BlockAndValueMapping bvm;
OperationState state(
loc, ConcreteOp::getOperationName(), operands, resultTypes,
$_op->getAttrs());
for (Region &r : $_op->getRegions())
r.cloneInto(state.addRegion(), bvm);
return b.createOperation(state);
}]
>
];
let extraClassDeclaration = [{
//========================================================================//
// Helper functions to mutate the `operand_segment_sizes` attribute.
// These are useful when cloning and changing operand types.
//========================================================================//
void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }
private:
void setOperandSegmentAt(unsigned idx, unsigned val) {
auto attr = (*this)->getAttr("operand_segment_sizes")
.cast<DenseIntElementsAttr>();
unsigned i = 0;
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
getOperation()->setAttr("operand_segment_sizes", newAttr);
}
}];
let verify = [{ return detail::verifyTMTensorOpInterface($_op); }];
}
#endif // TORCH_MLIR_DIALECT_TMTENSOR_INTERFACES


@@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
/// Include the ODS generated interface header files.
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h.inc"
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_SCALARLOOPOPINTERFACE_H_


@@ -0,0 +1,41 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSOROPS_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSOROPS_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h"
namespace mlir {
namespace torch {
namespace TMTensor {
/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
/// `dim`.
Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim);
/// Returns a `memref.dim` or `tensor.dim` operation to get the shape of `v` at
/// `dim`. If the shape is constant, returns the shape as an `IntegerAttr`.
OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim);
} // namespace TMTensor
} // namespace torch
} // namespace mlir
#define GET_OP_CLASSES
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSOROPS_H_


@@ -0,0 +1,205 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_OPS
#define TORCH_MLIR_DIALECT_TMTENSOR_OPS
include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td"
include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td"
include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
//===----------------------------------------------------------------------===//
// Base class.
//===----------------------------------------------------------------------===//
class TMTensor_PureOp<string mnemonic, list<OpTrait> traits = []> :
Op<TMTensor_Dialect, mnemonic, traits> {
}
class TMTensor_Op<string mnemonic, list<OpTrait> traits = []> :
TMTensor_PureOp<mnemonic, !listconcat(traits,
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TMTensorInterface,
SingleBlockImplicitTerminator<"::mlir::torch::TMTensor::YieldOp">
])> {
let verifier = [{ return verify$cppClass(*this); }];
let printer = [{ return print$cppClass(p, *this); }];
let parser = [{ return parse$cppClass(parser, result); }];
code extraTMTensorOpClassDeclaration = [{
SmallVector<Value> getDestinationOperands(OpBuilder &b) {
SmallVector<Value> dest(outputs().begin(), outputs().end());
return dest;
}
}];
}
//===----------------------------------------------------------------------===//
// Non-structured ops
//===----------------------------------------------------------------------===//
def TMTensor_ScanOp : TMTensor_Op<"scan",
    [DeclareOpInterfaceMethods<ScalarLoopOpInterface,
        ["generateScalarImplementation"]>]> {
let summary = "Scan operator";
let description = [{
Computes the inclusive/exclusive scan along a given dimension.
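
    Example of an inclusive cumulative sum along dimension 0 (an illustrative
    sketch; the SSA names and shapes are hypothetical, not taken from tests):

    ```mlir
    %res:2 = tm_tensor.scan dimension(0) inclusive(true)
        ins(%input : tensor<128xi32>)
        outs(%output, %acc : tensor<128xi32>, tensor<i32>) {
    ^bb0(%arg0: i32, %arg1: i32):
      %sum = arith.addi %arg0, %arg1 : i32
      tm_tensor.yield %sum : i32
    } -> tensor<128xi32>, tensor<i32>
    ```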
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$dimension,
BoolAttr:$inclusive
);
let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive)>
];
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
let hasFolder = 1;
let assemblyFormat = [{
`dimension` `(` $dimension `)`
`inclusive` `(` $inclusive `)`
attr-dict
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
}];
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
Value input() {
return getInputOperand(0)->get();
}
Value accumulator() {
return getOutputOperand(1)->get();
}
Value output() {
return getOutputOperand(0)->get();
}
ShapedType getOperandType() {
return input().getType().cast<ShapedType>();
}
int64_t getOperandRank() {
return getOperandType().getRank();
}
}];
}
def TMTensor_ScatterOp : TMTensor_Op<"scatter",
[DeclareOpInterfaceMethods<ScalarLoopOpInterface,
["generateScalarImplementation"]>]> {
let summary = "Scatter operator";
let description = [{
Based on XLA operation semantics, takes two `inputs` (`updates` and
`indices`) and one `outputs` value (`original`). The operation updates
the values at the slices specified by `indices` by combining the
current value with the value in `updates` using the computation
specified in `region`. The `region` specifies a binary operation
of signature (T, T) -> T, where `T` is the element type of
`updates` (and `original`). The first argument corresponds to the
value to be updated (i.e. from `updates`), and the second to the
current value (i.e. the value from `original`).

The `indices` operand is a 2D tensor/memref. The first dim is the number
of updates, and the second dim is the index depth. The index depth must
always be static.

The first dim of `updates` and `indices` is identical, since both
represent the number of updates.

The rank of `original`/`result` is `index_depth + rank(%updates) - 1`.
The first `index_depth` indices are derived from `indices`, and the shape
of the update value must match the remaining shape of `original`.

The shape definitions follow the TensorFlow operations, except that batch
dims are forced to be 1D. See more information at
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
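
Example (an illustrative sketch; the SSA names, shapes, and the additive
combining region are hypothetical):

```mlir
%0 = tm_tensor.scatter
    ins(%updates, %indices : tensor<4xi32>, tensor<4x1xi32>)
    outs(%original : tensor<8xi32>) {
^bb0(%update: i32, %current: i32):
  %new = arith.addi %update, %current : i32
  tm_tensor.yield %new : i32
} -> tensor<8xi32>
```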
}];
let arguments = (ins
Variadic<AnyRankedTensorOrMemRefType>:$inputs,
Variadic<AnyRankedTensorOrMemRefType>:$outputs
);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
let assemblyFormat = [{
attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
}];
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
int64_t getIndexDepth() {
return getInputOperand(1)
->get()
.getType()
.cast<ShapedType>()
.getShape()
.back();
}
Value updates() {
return getInputOperand(0)->get();
}
ShapedType getUpdateType() {
return updates().getType().cast<ShapedType>();
}
Value indices() {
return getInputOperand(1)->get();
}
ShapedType getIndicesType() {
return indices().getType().cast<ShapedType>();
}
Value original() {
return getOutputOperand(0)->get();
}
ShapedType getOriginalType() {
return original().getType().cast<ShapedType>();
}
int64_t getUpdateSliceRank() {
return updates().getType().cast<ShapedType>().getRank() - 1;
}
bool isScalarUpdate() {
return getUpdateSliceRank() == 0;
}
}];
}
//===----------------------------------------------------------------------===//
// Pure ops
//===----------------------------------------------------------------------===//
def TMTensor_YieldOp : TMTensor_PureOp<"yield", [NoSideEffect, ReturnLike, Terminator]> {
let summary = "TMTensor yield op";
let description = [{
`tm_tensor.yield` is a special terminator operation for blocks inside
regions in `tm_tensor` ops.
}];
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [
OpBuilder<(ins), [{ /* nothing to do */ }]>,
];
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
#endif // TORCH_MLIR_DIALECT_TMTENSOR_OPS


@@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl)
add_public_tablegen_target(TorchMLIRTMTensorTransformsPassesIncGen)


@@ -0,0 +1,26 @@
//===- PassDetail.h - TMTensor Pass class details -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
namespace TMTensor {
#define GEN_PASS_CLASSES
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" // IWYU pragma: keep
} // namespace TMTensor
} // namespace torch
} // namespace mlir
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_


@@ -0,0 +1,27 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
namespace TMTensor {
std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();
void registerPasses();
} // namespace TMTensor
} // namespace torch
} // namespace mlir
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_


@@ -0,0 +1,21 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCH_MLIR_DIALECT_TMTENSOR_PASSES
#define TORCH_MLIR_DIALECT_TMTENSOR_PASSES
include "mlir/Pass/PassBase.td"
def TMTensorToLoops :
Pass<"torch-mlir-tm-tensor-to-loops", "FuncOp"> {
let summary = "Convert TMTensor ops to loops and Linalg ops.";
let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
}
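
// Illustrative effect (a sketch, not a checked test): with buffer semantics, a
// `tm_tensor` op such as `tm_tensor.scan` is rewritten into an `scf.for` nest
// over its iteration domain, and the loop body is emitted by the op's
// `generateScalarImplementation` from ScalarLoopOpInterface.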
#endif // TORCH_MLIR_DIALECT_TMTENSOR_PASSES


@@ -0,0 +1 @@
add_subdirectory(Dialect)


@@ -0,0 +1 @@
add_subdirectory(TMTensor)


@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)


@@ -0,0 +1,29 @@
add_mlir_library(TorchMLIRTMTensorDialect
TMTensorDialect.cpp
TMTensorInterfaces.cpp
TMTensorOps.cpp
ScalarLoopOpInterface.cpp
ADDITIONAL_HEADER_DIRS
${TORCH_MLIR_DIALECTS_SOURCE_DIR}/include
DEPENDS
TorchMLIRTMTensorOpsIncGen
LINK_LIBS PUBLIC
MLIRAffine
MLIRDialectUtils
MLIRIR
MLIRLinalg
MLIRMath
MLIRMemRef
MLIRPass
MLIRSideEffectInterfaces
MLIRSupport
MLIRSCF
MLIRStandard
MLIRTensor
MLIRViewLikeInterface
)
torch_mlir_dialects_target_includes(TorchMLIRTMTensorDialect)


@@ -0,0 +1,24 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "torch-mlir-tiled-op-interface"
using namespace mlir;
using namespace mlir::torch::TMTensor;
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.cpp.inc"


@@ -0,0 +1,30 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using namespace mlir::torch::TMTensor;
void TMTensorDialect::initialize() {
#define GET_OP_LIST
addOperations<
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc"
>();
}
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.cpp.inc"


@@ -0,0 +1,54 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TMTensor;
OpOperandVector::operator SmallVector<Value>() {
SmallVector<Value> result;
result.reserve(this->size());
llvm::transform(*this, std::back_inserter(result),
[](OpOperand *opOperand) { return opOperand->get(); });
return result;
}
LogicalResult
mlir::torch::TMTensor::detail::verifyTMTensorOpInterface(Operation *op) {
TMTensorOp tmTensorOp = cast<TMTensorOp>(op);
if (op->getNumResults()) {
if (!tmTensorOp.hasTensorSemantics()) {
return tmTensorOp.emitOpError(
"expected inputs and outputs to be RankedTensorType or scalar");
}
if (op->getNumResults() != tmTensorOp.outputs().size()) {
return tmTensorOp.emitOpError(
"expected number of outputs to be same as the number of results");
}
for (auto en : llvm::enumerate(op->getResultTypes())) {
Type outputType = tmTensorOp.outputs()[en.index()].getType();
if (en.value() != outputType) {
return tmTensorOp.emitOpError("expected type of `outs` operand #")
<< en.index() << " " << outputType
<< " to be same as result type " << en.value();
}
}
} else {
if (!tmTensorOp.hasBufferSemantics()) {
return tmTensorOp.emitOpError(
"expected inputs and outputs to be MemRefType or scalar");
}
}
return success();
}
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.cpp.inc" // IWYU pragma: export


@@ -0,0 +1,483 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SMLoc.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TMTensor;
//===----------------------------------------------------------------------===//
// Utils.
//===----------------------------------------------------------------------===//
static void getEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) {
for (Value value : results) {
effects.emplace_back(MemoryEffects::Allocate::get(), value,
SideEffects::DefaultResource::get());
}
for (Value value : inputBuffers) {
effects.emplace_back(MemoryEffects::Read::get(), value,
SideEffects::DefaultResource::get());
}
for (Value value : outputBuffers) {
effects.emplace_back(MemoryEffects::Read::get(), value,
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), value,
SideEffects::DefaultResource::get());
}
}
Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v,
int64_t dim) {
return TypeSwitch<Type, Value>(v.getType())
.Case<RankedTensorType>([&](RankedTensorType t) -> Value {
return builder.create<tensor::DimOp>(loc, v, dim);
})
.Case<MemRefType>([&](MemRefType t) -> Value {
return builder.create<memref::DimOp>(loc, v, dim);
})
.Default([&](Type t) { return Value(); });
}
OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
int64_t dim) {
auto t = v.getType().cast<ShapedType>();
if (t.isDynamicDim(dim)) {
return getDimValue(builder, loc, v, dim);
}
return builder.getI64IntegerAttr(t.getDimSize(dim));
}
//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyScanOp(ScanOp op) {
if (op.getNumInputs() != 1) {
return op.emitOpError("expected one input operands");
}
if (op.getNumOutputs() != 2) {
return op.emitOpError("expected two output operands");
}
if (!op.input().getType().isa<ShapedType>()) {
return op.emitOpError("expected first input element type to be shaped");
}
auto accumulatorType = op.accumulator().getType().cast<ShapedType>();
auto inputType = op.input().getType().cast<ShapedType>();
auto outputType = op.output().getType().cast<ShapedType>();
ArrayRef<int64_t> inputShapes = inputType.getShape();
ArrayRef<int64_t> outputShapes = outputType.getShape();
if (accumulatorType.getElementType() != inputType.getElementType()) {
return op.emitOpError(
"expected input/accumulator element types to be identical");
}
ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
int64_t accumulatorRank = accumulatorType.getRank();
if (accumulatorRank != inputType.getRank() - 1) {
return op.emitOpError(
"expected accumulator rank to be equal to input rank - 1");
}
SmallVector<int64_t> expectedAccumulatorShape;
for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
if (i != op.dimension())
expectedAccumulatorShape.push_back(inputShapes[i]);
}
if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
[](std::tuple<int64_t, int64_t> s) {
return std::get<0>(s) != ShapedType::kDynamicSize &&
std::get<1>(s) != ShapedType::kDynamicSize &&
std::get<0>(s) != std::get<1>(s);
})) {
return op.emitOpError("incompatible input/accumulator shapes");
}
if (inputType.getElementType() != outputType.getElementType()) {
return op.emitOpError(
"expected input/output element types to be identical");
}
if (inputShapes.size() != outputShapes.size()) {
return op.emitOpError("expected input/output to have identical ranks");
}
if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
[](std::tuple<int64_t, int64_t> s) {
return std::get<0>(s) != ShapedType::kDynamicSize &&
std::get<1>(s) != ShapedType::kDynamicSize &&
std::get<0>(s) != std::get<1>(s);
})) {
return op.emitOpError("incompatible input/output shapes");
}
return success();
}
SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) {
int64_t operandRank = getOperandRank();
SmallVector<Range> loopBounds(operandRank);
Location loc = getLoc();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
Value source = input();
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
loopBounds[dim].offset = zero;
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
loopBounds[dim].stride = one;
}
return loopBounds;
}
SmallVector<StringRef> ScanOp::getLoopIteratorTypes() {
SmallVector<StringRef> iteratorTypes(getOperandRank(),
getParallelIteratorTypeName());
iteratorTypes[dimension()] = getReductionIteratorTypeName();
return iteratorTypes;
}
// Generates naive scalar implementation of scan for a given operator f.
// For inclusive,
// output[0] = input[0]
// output[i] = f(output[i-1], input[i])
//
// For exclusive,
// output[0] = 0
// output[i] = f(output[i-1], input[i-1])
LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
ValueRange ivs) {
SmallVector<Value> indices, scanBlkArgs;
indices.append(ivs.begin(), ivs.end());
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
uint64_t scanDim = dimension();
Value cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
indices[scanDim], zero);
bool isInclusive = inclusive();
SmallVector<Value> accIndices;
for (size_t i = 0; i < indices.size(); i++) {
if (i != scanDim)
accIndices.push_back(indices[i]);
}
auto scfIf = b.create<scf::IfOp>(
loc, TypeRange{}, cond,
[&](OpBuilder &b, Location loc) {
if (isInclusive) {
auto value = b.create<memref::LoadOp>(loc, input(), indices);
b.create<memref::StoreOp>(loc, value, output(), indices);
} else {
auto value = b.create<memref::LoadOp>(loc, accumulator(), accIndices);
b.create<memref::StoreOp>(loc, value, output(), indices);
}
b.create<scf::YieldOp>(loc);
},
[&](OpBuilder &b, Location loc) {
SmallVector<Value> indices(ivs.begin(), ivs.end());
Value iv = indices[scanDim];
Value ivMinusOne = b.create<arith::SubIOp>(loc, iv, one);
indices[scanDim] = ivMinusOne;
scanBlkArgs.push_back(b.create<memref::LoadOp>(loc, output(), indices));
Value i0;
if (!isInclusive)
i0 = b.create<memref::LoadOp>(loc, input(), indices);
indices[scanDim] = iv;
if (isInclusive)
i0 = b.create<memref::LoadOp>(loc, input(), indices);
scanBlkArgs.push_back(i0);
});
auto &srcBlock = region().front();
Region &region = scfIf.getElseRegion();
BlockAndValueMapping bvm;
{
OpBuilder::InsertionGuard guard(b);
auto &block = region.front();
b.setInsertionPointToEnd(&block);
for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
bvm.map(std::get<0>(it), std::get<1>(it));
}
for (auto &blockOp : srcBlock.without_terminator()) {
b.clone(blockOp, bvm);
}
b.create<memref::StoreOp>(
loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
output(), indices);
b.create<memref::StoreOp>(
loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
accumulator(), accIndices);
b.create<scf::YieldOp>(loc);
}
return success();
}
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
folded = true;
}
}
return success(folded);
}
LogicalResult ScanOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyScatterOp(ScatterOp op) {
if (op.inputs().size() != 2) {
return op.emitOpError("expected two input operands");
}
if (op.outputs().size() != 1) {
return op.emitOpError("expected one output operand");
}
auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
return t1.getShape()[dim] == t2.getShape()[dim];
};
auto indicesType = op.getIndicesType();
if (indicesType.getRank() != 2 ||
!indicesType.getElementType().isInteger(32)) {
return op.emitOpError(
"expected indices to be of rank 2 of i32 element type");
}
auto indexDepth = op.getIndexDepth();
if (indexDepth == ShapedType::kDynamicSize) {
return op.emitOpError("expected index depth is static");
}
// The first dimension of the indices should match the first dimension of the
// update value; both are the number of updates.
auto updateType = op.getUpdateType();
if (updateType.getRank() < 1) {
return op.emitOpError("expected update value to be at least rank 1");
}
if (!checkDimensionsMatch(indicesType, updateType, 0)) {
return op.emitOpError(
"mismatch in shape of indices and update value at dim#0");
}
auto originalType = op.getOriginalType();
// indexDepth plus the update dims should match the original dims; the first
// dim of update is the number of updates.
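// E.g. original 4x3 (rank 2), indexDepth 1, update 2x3 (rank 2):
// 2 == 1 + 2 - 1.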
if (originalType.getRank() != indexDepth + updateType.getRank() - 1) {
return op.emitOpError(
"mismatch in rank of update value, index depth and original value");
}
for (auto dim : llvm::seq<unsigned>(indexDepth, originalType.getRank())) {
// Offset one because the first dim is the number of updates.
if (updateType.getDimSize(1 + dim - indexDepth) !=
originalType.getDimSize(dim)) {
return op.emitOpError("mismatch in shape of update value dim#")
<< (1 + dim - indexDepth) << " and original value at dim#" << dim;
}
}
Region &region = op.region();
Block *body = &region.front();
if (body->getNumArguments() != 2) {
return op.emitOpError("expected region to have two arguments");
}
Type arg0Type = body->getArgument(0).getType();
Type arg1Type = body->getArgument(1).getType();
if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
return op.emitOpError(
"expected region to have scalar argument of integer or float types");
}
if (arg0Type != updateType.getElementType()) {
return op.emitOpError("mismatch in argument 0 of region ")
<< arg0Type << " and element type of update value "
<< updateType.getElementType();
}
if (arg1Type != originalType.getElementType()) {
return op.emitOpError("mismatch in argument 1 of region ")
<< arg1Type << " and element type of original value "
<< originalType.getElementType();
}
if (arg0Type != arg1Type) {
return op.emitOpError("mismatch in region argument types ")
<< arg0Type << " and " << arg1Type;
}
auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator());
if (yieldOp->getNumOperands() != 1) {
return yieldOp.emitOpError("expected region to yield a single value");
}
auto yieldedType = yieldOp->getOperand(0).getType();
if (yieldedType != arg0Type) {
return yieldOp.emitOpError("mismatch in type of yielded value ")
<< yieldedType << " and argument of the region " << arg0Type;
}
return success();
}
SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
SmallVector<StringRef> iteratorTypes(getUpdateType().getRank(),
getParallelIteratorTypeName());
return iteratorTypes;
}
SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
Location loc = getLoc();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
SmallVector<Range> ranges;
for (auto dim : llvm::seq<int64_t>(0, getUpdateType().getRank())) {
Value ub = getDimValue(builder, loc, updates(), dim);
ranges.emplace_back(Range{zero, ub, one});
}
return ranges;
}
LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
Location loc,
ValueRange ivs) {
auto indexDepth = getIndexDepth();
Value update = b.create<memref::LoadOp>(loc, updates(), ivs);
SmallVector<Value> starts;
SmallVector<Value> loadIndices;
loadIndices.push_back(ivs.front());
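// The second index slot is a placeholder; it is overwritten with each
// index-depth position constant in the loop below.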
loadIndices.push_back(Value());
for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
starts.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx));
}
starts.append(std::next(ivs.begin()), ivs.end());
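// `starts` is now the base position into `original`: the leading indexDepth
// coordinates come from the indices operand, the rest from the remaining
// update induction variables.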
Value init = b.create<memref::LoadOp>(loc, original(), starts);
BlockAndValueMapping bvm;
Block &block = region().front();
bvm.map(block.getArgument(0), update);
bvm.map(block.getArgument(1), init);
for (auto &blockOp : block.without_terminator()) {
b.clone(blockOp, bvm);
}
// The last op in the region is the tm_tensor.yield op. Store its operand to
// the destination.
b.create<memref::StoreOp>(
loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
original(), starts);
return success();
}
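/// Both ops only read their input buffers and read/write their output
/// buffers; expose that through the MemoryEffects interface.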
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
&effects) { \
SmallVector<Value> inputBuffers = getInputBufferOperands(); \
SmallVector<Value> outputBuffers = getOutputBufferOperands(); \
getEffectsImpl(effects, getOperation()->getResults(), inputBuffers, \
outputBuffers); \
}
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(ScatterOp)
namespace {
/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
/// changes.
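///
/// For example (cf. the canonicalization test below), in
///   %c = tensor.cast %arg0 : tensor<128xi32> to tensor<?xi32>
///   %0 = tm_tensor.scan ... ins(%c : tensor<?xi32>) ...
/// the scan is rewritten to consume %arg0 directly with its static type.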
struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
using OpInterfaceRewritePattern<TMTensorOp>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(TMTensorOp op,
PatternRewriter &rewriter) const override {
// Fail unless at least one operand comes from a tensor::CastOp that can be
// folded.
bool hasTensorCastOperand =
llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
if (opOperand->get().isa<BlockArgument>())
return false;
auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
return failure();
SmallVector<Type, 4> newResultTypes;
newResultTypes.reserve(op->getNumResults());
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
// Inputs may fold.
for (OpOperand *opOperand : op.getInputOperands()) {
auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
? tensorCastOp.source()
: opOperand->get());
}
// Init tensors may fold, in which case the resultType must also change.
for (OpOperand *opOperand : op.getOutputOperands()) {
auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand()
: opOperand->get());
newResultTypes.push_back(newOperands.back().getType());
}
// Clone op.
Operation *newOp =
op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
SmallVector<Value, 4> replacements;
replacements.reserve(newOp->getNumResults());
for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
Value oldResult = std::get<0>(result);
Value newResult = std::get<1>(result);
if (newResult.getType() != oldResult.getType()) {
replacements.push_back(rewriter.create<tensor::CastOp>(
op->getLoc(), oldResult.getType(), newResult));
} else {
replacements.push_back(newResult);
}
}
rewriter.replaceOp(op, replacements);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// TMTensorDialect
//===----------------------------------------------------------------------===//
void TMTensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastOp>(getContext());
}
#define GET_OP_CLASSES
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc"

View File

@ -0,0 +1,22 @@
add_mlir_library(TorchMLIRTMTensorPasses
ConvertToLoops.cpp
Passes.cpp
DEPENDS
TorchMLIRTMTensorTransformsPassesIncGen
LINK_LIBS PUBLIC
TorchMLIRTMTensorDialect
MLIRAffine
MLIRIR
MLIRLinalg
MLIRLinalgTransforms
MLIRMath
MLIRMemRef
MLIRPass
MLIRSCF
MLIRStandard
MLIRSupport
MLIRTensor
MLIRTransforms
)

View File

@ -0,0 +1,117 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
using namespace mlir::torch::TMTensor;
/// Recursive helper that lowers a `ScalarLoopOpInterface` op to scalar loops
/// one iteration-domain dimension at a time.
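/// For a rank-2 iteration domain this emits a perfect loop nest of the form:
///   scf.for %i = %lb0 to %ub0 step %s0 {
///     scf.for %j = %lb1 to %ub1 step %s1 {
///       <scalar implementation at (%i, %j)>
///     }
///   }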
static LogicalResult lowerToLoopsImpl(OpBuilder &builder,
ScalarLoopOpInterface scalarLoopOp,
ArrayRef<Range> loopRanges,
unsigned loopDepth,
SmallVectorImpl<Value> &ivs) {
Location loc = scalarLoopOp.getLoc();
if (loopDepth == loopRanges.size()) {
return scalarLoopOp.generateScalarImplementation(builder, loc, ivs);
}
LogicalResult status = success();
builder.create<scf::ForOp>(
loc, loopRanges[loopDepth].offset, loopRanges[loopDepth].size,
loopRanges[loopDepth].stride, ValueRange{},
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
ivs.push_back(iv);
status =
lowerToLoopsImpl(b, scalarLoopOp, loopRanges, loopDepth + 1, ivs);
b.create<scf::YieldOp>(loc);
});
return status;
}
/// Main entry point for lowering `ScalarLoopOpInterface` op to loops.
static LogicalResult lowerToLoops(OpBuilder &builder,
ScalarLoopOpInterface scalarLoopOp) {
SmallVector<Range> loopBounds = scalarLoopOp.getIterationDomain(builder);
SmallVector<Value> ivs;
return lowerToLoopsImpl(builder, scalarLoopOp, loopBounds, 0, ivs);
}
/// Pattern rewriter hook to lower a `ScalarLoopOpInterface` to loops.
namespace {
struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern {
ScalarLoopOpInterfaceLowerToLoopsPattern(MLIRContext *context,
PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto scalarLoopOp = dyn_cast<ScalarLoopOpInterface>(op);
if (!scalarLoopOp) {
return failure();
}
if (llvm::any_of(scalarLoopOp->getResults(),
[&](Value v) { return v.getType().isa<ShapedType>(); })) {
return rewriter.notifyMatchFailure(
scalarLoopOp, "lowering to loops requires memref (buffer) semantics");
}
if (failed(lowerToLoops(rewriter, scalarLoopOp))) {
return failure();
}
rewriter.eraseOp(op);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
namespace {
struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, StandardOpsDialect,
mlir::arith::ArithmeticDialect, math::MathDialect,
memref::MemRefDialect, scf::SCFDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<ScalarLoopOpInterfaceLowerToLoopsPattern>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
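/// Creates the pass exercised as -torch-mlir-tm-tensor-to-loops in the lit
/// tests below.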
std::unique_ptr<OperationPass<FuncOp>>
torch::TMTensor::createTMTensorToLoopsPass() {
return std::make_unique<TMTensorToLoopsPass>();
}

View File

@ -0,0 +1,33 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
namespace mlir {
namespace torch {
namespace TMTensor {
namespace detail {
#define GEN_PASS_REGISTRATION
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" // IWYU pragma: export
} // namespace detail
} // namespace TMTensor
} // namespace torch
} // namespace mlir
void torch::TMTensor::registerPasses() {
torch::TMTensor::detail::registerPasses();
}

View File

@ -0,0 +1,19 @@
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
MAIN_CONFIG
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
)
set(TORCH_MLIR_DIALECTS_TEST_DEPENDS
FileCheck count not
torch-mlir-dialects-opt
)
add_lit_testsuite(check-torch-mlir-dialects "Running the torch-mlir-dialects regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${TORCH_MLIR_DIALECTS_TEST_DEPENDS}
)
set_target_properties(check-torch-mlir-dialects PROPERTIES FOLDER "Tests")
add_lit_testsuites(TORCH_MLIR_DIALECTS ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_DIALECTS_TEST_DEPENDS})

View File

@ -0,0 +1,69 @@
# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import os
import platform
import re
import subprocess
import tempfile
import lit.formats
import lit.util
from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
from lit.llvm.subst import FindTool
# Configuration file for the 'lit' test runner.
# name: The name of this test suite.
config.name = 'TORCH_MLIR_DIALECTS'
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir', '.py']
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.torch_mlir_dialects_obj_root, 'test')
config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
config.substitutions.append(
('%resources_dir', os.path.join(config.torch_mlir_dialects_obj_root,
'resources')))
llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
#llvm_config.use_default_substitutions()
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = [
'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
'lit.cfg.py', 'lit.site.cfg.py'
]
config.standalone_tools_dir = os.path.join(config.torch_mlir_dialects_obj_root, 'bin')
# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
tool_dirs = [config.llvm_tools_dir]
tools = [
ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -0,0 +1,26 @@
# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
@LIT_SITE_CFG_IN_HEADER@
import sys
config.torch_mlir_dialects_obj_root = "@TORCH_MLIR_DIALECTS_BINARY_DIR@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.python_executable = sys.executable
import lit.llvm
lit.llvm.initialize(lit_config, config)
# Let the main config do the real work.
lit_config.load_config(config, "@TORCH_MLIR_DIALECTS_SOURCE_DIR@/test/lit.cfg.py")

View File

@ -0,0 +1,24 @@
// RUN: torch-mlir-dialects-opt -canonicalize -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @tensor.cast(
func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> {
%init = linalg.init_tensor [128] : tensor<128xi32>
%c0 = linalg.init_tensor [] : tensor<i32>
%casted_arg0 = tensor.cast %arg0 : tensor<128xi32> to tensor<?xi32>
%casted_init = tensor.cast %init : tensor<128xi32> to tensor<?xi32>
// CHECK: tm_tensor.scan
// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<128xi32>)
// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<128xi32>, tensor<i32>)
%0, %1 = tm_tensor.scan dimension(0) inclusive(true)
ins(%casted_arg0 : tensor<?xi32>)
outs(%casted_init, %c0: tensor<?xi32>, tensor<i32>) {
^bb0(%barg0 : i32, %barg1 : i32, %barg2 : i32):
%sum = arith.addi %barg0, %barg1 : i32
tm_tensor.yield %sum : i32
} -> tensor<?xi32>, tensor<i32>
%2 = tensor.cast %0: tensor<?xi32> to tensor<128xi32>
return %2: tensor<128xi32>
}

View File

@ -0,0 +1,332 @@
// RUN: torch-mlir-dialects-opt -split-input-file -torch-mlir-tm-tensor-to-loops %s | FileCheck %s
func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32>
tm_tensor.scan dimension(0) inclusive(true)
ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
tm_tensor.yield %sum : i32
}
return
}
// CHECK-LABEL: func @scan_1d_inclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK: scf.if %[[COND]] {
// CHECK: %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK: memref.store %[[V1]], %[[BUFO]][%[[ARG1]]]
// CHECK: } else {
// CHECK: %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK: %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK: %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK: %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK: memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK: memref.store %[[V4]], %[[ACC]][]
// CHECK: }
// -----
func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32>
tm_tensor.scan dimension(0) inclusive(false)
ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
tm_tensor.yield %sum : i32
}
return
}
// CHECK-LABEL: func @scan_1d_exclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK: scf.if %[[COND]] {
// CHECK: %[[V0:.+]] = memref.load %[[ACC]][] : memref<i32>
// CHECK: memref.store %[[V0]], %[[BUFO]][%[[ARG1]]]
// CHECK: } else {
// CHECK: %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK: %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK: %[[V3:.+]] = memref.load %[[BUFI]][%[[T1]]]
// CHECK: %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK: memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK: memref.store %[[V4]], %[[ACC]][]
// CHECK: }
// -----
func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
%t0 = memref.alloc() : memref<32xi32>
tm_tensor.scan dimension(0) inclusive(true)
ins(%0 : memref<16x32xi32>) outs(%1, %t0 : memref<16x32xi32>, memref<32xi32>) {
^bb0(%arg0 : i32, %arg1 : i32):
%sum = arith.addi %arg0, %arg1 : i32
tm_tensor.yield %sum : i32
}
return
}
// CHECK-LABEL: func @scan_2d
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[ACC:.+]] = memref.alloc() : memref<32xi32>
// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]]
// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]]
// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK: scf.if %[[COND]] {
// CHECK: %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK: memref.store %[[V1]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK: } else {
// CHECK: %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK: %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]], %[[ARG2]]]
// CHECK: %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK: %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK: memref.store %[[V4]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK: memref.store %[[V4]], %[[ACC]][%[[ARG2]]]
// CHECK: }
// -----
func @scatter_update_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
tm_tensor.scatter
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
outs(%original : memref<8xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
tm_tensor.yield %arg0 : i32
}
return
}
// CHECK-LABEL: func @scatter_update_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]
// -----
func @scatter_add_scalar_2D(
%original: memref<4x3xi32>, %indices: memref<3x2xi32>,
%updates: memref<3xi32>) {
tm_tensor.scatter
ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
%0 = arith.addi %arg1, %arg0 : i32
tm_tensor.yield %0 : i32
}
return
}
// CHECK-LABEL: func @scatter_add_scalar_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x2xi32>
// CHECK: %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<3x2xi32>
// CHECK: %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<4x3xi32>
// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]
// -----
func @scatter_update_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
tm_tensor.scatter
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
tm_tensor.yield %arg0 : i32
}
return
}
// CHECK: func @scatter_update_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[UPDATE:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK: %[[INDEX:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK: %[[LOC:.+]] = arith.index_cast %[[INDEX]] : i32 to index
// CHECK: memref.store %[[UPDATE]], %[[ORIGINAL]][%[[LOC]], %[[J]]]
// CHECK: }
// CHECK: }
// -----
func @scatter_add_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
tm_tensor.scatter
ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
outs(%original : memref<8xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
%0 = arith.addi %arg1, %arg0 : i32
tm_tensor.yield %0 : i32
}
return
}
// CHECK-LABEL: func @scatter_add_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX]]] : memref<8xi32>
// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX]]]
// -----
func @scatter_add_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
tm_tensor.scatter
ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
outs(%original : memref<4x3xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
%0 = arith.addi %arg1, %arg0 : i32
tm_tensor.yield %0 : i32
}
return
}
// CHECK: func @scatter_add_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK: %[[ORIGINALVAL:.+]] = memref.load %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// CHECK: %[[STOREVAL:.+]] = arith.addi %[[ORIGINALVAL]], %[[UPDATEVAL]]
// CHECK: memref.store %[[STOREVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// -----
func @scatter_update_scalar_dynamic_1D(
%original: memref<?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?xi32>) {
tm_tensor.scatter
ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
outs(%original : memref<?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
tm_tensor.yield %arg0 : i32
}
return
}
// CHECK-LABEL: func @scatter_update_scalar_dynamic_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x1xi32>
// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]
// -----
func @scatter_add_scalar_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x2xi32>,
%updates: memref<?xi32>) {
tm_tensor.scatter
ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
outs(%original : memref<?x?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
%0 = arith.addi %arg1, %arg0 : i32
tm_tensor.yield %0 : i32
}
return
}
// CHECK-LABEL: func @scatter_add_scalar_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x2xi32>
// CHECK: %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK: %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<?x2xi32>
// CHECK: %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<?x?xi32>
// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]
// -----
func @scatter_update_slice_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?x?xi32>) {
tm_tensor.scatter
ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
outs(%original : memref<?x?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
tm_tensor.yield %arg0 : i32
}
return
}
// CHECK: func @scatter_update_slice_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[UB1:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?x?xi32>
// CHECK-DAG: %[[UB2:.+]] = memref.dim %[[UPDATES]], %[[C1]] : memref<?x?xi32>
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB1]] step %[[C1]] {
// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[UB2]] step %[[C1]] {
// CHECK: %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK: memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]

View File

@ -0,0 +1 @@
add_subdirectory(torch-mlir-dialects-opt)

View File

@ -0,0 +1,22 @@
set(LIBS
MLIRArithmetic
MLIRDialect
MLIRLinalg
MLIRMemRef
MLIROptLib
MLIRSCF
MLIRSCFTransforms
MLIRStandard
MLIRTensor
MLIRTransforms
TorchMLIRTMTensorDialect
TorchMLIRTMTensorPasses
)
add_llvm_tool(torch-mlir-dialects-opt
torch-mlir-dialects-opt.cpp
DEPENDS
${LIBS}
)
target_link_libraries(torch-mlir-dialects-opt PRIVATE ${LIBS})

View File

@ -0,0 +1,49 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
using namespace mlir;
int main(int argc, char **argv) {
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
registerTransformsPasses();
registerSCFPasses();
// Local dialects.
mlir::torch::TMTensor::registerPasses();
DialectRegistry registry;
registry.insert<
// Local dialects
mlir::torch::TMTensor::TMTensorDialect,
// Upstream dialects
mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
mlir::memref::MemRefDialect, mlir::StandardOpsDialect,
mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
/*preloadDialectsInContext=*/false));
}

View File

@ -19,6 +19,9 @@ add_mlir_library(TorchMLIRInitAll
TorchMLIRTorchDialect
TorchMLIRTorchConversionPasses
TorchMLIRTMTensorPasses
TorchMLIRTMTensorDialect
TorchMLIRConversionPasses
TorchMLIRRefBackend
)

View File

@ -10,6 +10,8 @@
#include "torch-mlir/InitAll.h" #include "torch-mlir/InitAll.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
#include "torch-mlir/Conversion/Passes.h" #include "torch-mlir/Conversion/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -20,6 +22,7 @@
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::torch::Torch::TorchDialect>();
registry.insert<mlir::torch::TorchConversion::TorchConversionDialect>();
registry.insert<mlir::torch::TMTensor::TMTensorDialect>();
}
void mlir::torch::registerAllPasses() {
@ -28,4 +31,5 @@ void mlir::torch::registerAllPasses() {
mlir::torch::registerConversionPasses();
mlir::torch::RefBackend::registerRefBackendPasses();
mlir::torch::TMTensor::registerPasses();
}

View File

@ -62,8 +62,9 @@ class CMakeBuild(build_py):
f"-DLLVM_TARGETS_TO_BUILD=host", f"-DLLVM_TARGETS_TO_BUILD=host",
f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON", f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON",
f"-DLLVM_ENABLE_PROJECTS=mlir", f"-DLLVM_ENABLE_PROJECTS=mlir",
f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir", f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir;torch-mlir-dialects",
f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}", f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}",
f"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR={src_dir}/external/llvm-external-projects/torch-mlir-dialects",
# Optimization options for building wheels. # Optimization options for building wheels.
f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON", f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON",
f"-DCMAKE_C_VISIBILITY_PRESET=hidden", f"-DCMAKE_C_VISIBILITY_PRESET=hidden",