More progress on PyTorch acap device capture.

* Now gets far enough to capture batch_norm.
* Has some issues still with in-place ops.
* Can materialize constants.
* Includes an upgrade to PyTorch nightly, which has important bug fixes for fallback and boxed kernel dispatch.
* Fixes #78, #79, #80.
* Will do more testing in a follow-up, once the remaining bugs that block access to the other features are fixed.
pull/84/head
Stella Laurenzo 2020-10-15 18:28:30 -07:00
parent 06a8ba6900
commit 9e52f6235b
15 changed files with 168 additions and 39 deletions

View File

@ -124,7 +124,7 @@ Create docker image (or follow your own preferences):
* Mount the `/build` directory (in the container) appropriately for your case. * Mount the `/build` directory (in the container) appropriately for your case.
```shell ```shell
docker build docker/pytorch-1.6 --tag local/npcomp:build-pytorch-1.6 docker build docker/pytorch-nightly --tag local/npcomp:build-pytorch-nightly
docker volume create npcomp-build docker volume create npcomp-build
``` ```
@ -134,7 +134,7 @@ Shell into docker image:
docker run \ docker run \
--mount type=bind,source=$HOME/src/mlir-npcomp,target=/src/mlir-npcomp \ --mount type=bind,source=$HOME/src/mlir-npcomp,target=/src/mlir-npcomp \
--mount source=npcomp-build,target=/build \ --mount source=npcomp-build,target=/build \
--rm -it local/npcomp:build-pytorch-1.6 /bin/bash --rm -it local/npcomp:build-pytorch-nightly /bin/bash
``` ```
Build/test npcomp (from within docker image): Build/test npcomp (from within docker image):

View File

@ -4,25 +4,25 @@
# source $WHERE_YOU_CHECKED_OUT_NPCOMP/build_tools/docker_shell_funcs.sh # source $WHERE_YOU_CHECKED_OUT_NPCOMP/build_tools/docker_shell_funcs.sh
# ``` # ```
td="$(realpath $(dirname "${BASH_SOURCE[0]}")/..)" __npcomp_dir="$(realpath $(dirname "${BASH_SOURCE[0]}")/..)"
# Build the docker images for npcomp: # Build the docker images for npcomp:
# npcomp:build-pytorch-1.6 # npcomp:build-pytorch-nightly
# me/npcomp:build-pytorch-1.6 (additional dev packages and current user) # me/npcomp:build-pytorch-nightly (additional dev packages and current user)
function npcomp_docker_build() { function npcomp_docker_build() {
if ! [ -f "docker/pytorch-1.6/Dockerfile" ]; then if ! [ -f "docker/pytorch-nightly/Dockerfile" ]; then
echo "Please run out of mlir-npcomp/ source directory..." echo "Please run out of mlir-npcomp/ source directory..."
return 1 return 1
fi fi
echo "Building out of $(pwd)..." echo "Building out of $(pwd)..."
docker build docker/pytorch-1.6 --tag npcomp:build-pytorch-1.6 docker build docker/pytorch-nightly --tag npcomp:build-pytorch-nightly
npcomp_docker_build_for_me npcomp:build-pytorch-1.6 npcomp_docker_build_for_me npcomp:build-pytorch-nightly
} }
# Start a container named "npcomp" in the background with the current-user # Start a container named "npcomp" in the background with the current-user
# dev image built above. # dev image built above.
function npcomp_docker_start() { function npcomp_docker_start() {
local host_src_dir="${1-$td}" local host_src_dir="${1-$__npcomp_dir}"
if ! [ -d "$host_src_dir" ]; then if ! [ -d "$host_src_dir" ]; then
echo "mlir-npcomp source directory not found:" echo "mlir-npcomp source directory not found:"
echo "Pass path to host source directory as argument (default=$host_src_dir)." echo "Pass path to host source directory as argument (default=$host_src_dir)."
@ -32,7 +32,7 @@ function npcomp_docker_start() {
docker run -d --rm --name "npcomp" \ docker run -d --rm --name "npcomp" \
--mount source=npcomp-build,target=/build \ --mount source=npcomp-build,target=/build \
--mount type=bind,source=$host_src_dir,target=/src/mlir-npcomp \ --mount type=bind,source=$host_src_dir,target=/src/mlir-npcomp \
me/npcomp:build-pytorch-1.6 tail -f /dev/null me/npcomp:build-pytorch-nightly tail -f /dev/null
} }
# Stop the container named "npcomp". # Stop the container named "npcomp".

View File

@ -14,7 +14,8 @@ RUN ln -s /usr/bin/llvm-symbolizer-10 /usr/bin/llvm-symbolizer
# Install PyTorch # Install PyTorch
# Installs under: /usr/local/lib/python3.8/dist-packages/torch # Installs under: /usr/local/lib/python3.8/dist-packages/torch
RUN pip3 install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html RUN pip3 install numpy
RUN pip3 install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
RUN ln -s /usr/local/lib/python3.8/dist-packages/torch /pytorch RUN ln -s /usr/local/lib/python3.8/dist-packages/torch /pytorch
# Build configuration # Build configuration

@ -1 +1 @@
Subproject commit ee491ac91e123b90eeec3cce7e494936ea8cb85d Subproject commit 6771b98c4e4d5c0bd0a78a876bd212a76ec80a24

View File

@ -8,6 +8,8 @@
#include "acap_dispatch.h" #include "acap_dispatch.h"
#include "mlir-c/StandardAttributes.h" #include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
#include "npcomp-c/Types.h"
#include "npcomp/Python/PybindUtils.h" #include "npcomp/Python/PybindUtils.h"
#include <ATen/core/function_schema.h> #include <ATen/core/function_schema.h>
@ -81,8 +83,7 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
returnsValues.push_back(v); returnsValues.push_back(v);
} }
// TODO: Get location from traceback. MlirLocation loc = getCurrentLocation();
MlirLocation loc = mlirLocationUnknownGet(funcBuilder->getContext());
OperationStateHolder s("std.return", loc); OperationStateHolder s("std.return", loc);
mlirOperationStateAddOperands(&s.state, returnsValues.size(), mlirOperationStateAddOperands(&s.state, returnsValues.size(),
returnsValues.data()); returnsValues.data());
@ -123,6 +124,10 @@ void AcapController::fallbackKernel(const OperatorHandle &opHandle,
current->fallbackKernelImpl(opHandle, stack); current->fallbackKernelImpl(opHandle, stack);
} }
/// Returns the location to attach to operations emitted by this controller.
///
/// For now this is always the "unknown" location of the function builder's
/// context; no source/traceback information is consulted.
MlirLocation AcapController::getCurrentLocation() {
  auto context = funcBuilder->getContext();
  return mlirLocationUnknownGet(context);
}
void AcapController::redispatch(const c10::OperatorHandle &opHandle, void AcapController::redispatch(const c10::OperatorHandle &opHandle,
c10::Stack *stack) { c10::Stack *stack) {
// Exclude recursive dispatch to this kernel. // Exclude recursive dispatch to this kernel.
@ -168,8 +173,8 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
MlirValue mlirValue = mapIValueToMlirValue(loc, *argIt); MlirValue mlirValue = mapIValueToMlirValue(loc, *argIt);
if (mlirValueIsNull(mlirValue)) { if (mlirValueIsNull(mlirValue)) {
std::stringstream out; std::stringstream out;
out << "Unsupported capture value passed to kernel (" << argIt->tagKind() out << "Unsupported capture value returned from kernel '" << kernelName
<< "): " << *argIt; << "' (" << argIt->tagKind() << "): " << *argIt;
throw std::invalid_argument(out.str()); throw std::invalid_argument(out.str());
} }
operands.push_back(mlirValue); operands.push_back(mlirValue);
@ -191,8 +196,8 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
MlirType resultType = mapIValueToMlirType(loc, *returnIt); MlirType resultType = mapIValueToMlirType(loc, *returnIt);
if (mlirTypeIsNull(resultType)) { if (mlirTypeIsNull(resultType)) {
std::stringstream out; std::stringstream out;
out << "Unsupported capture value returned from kernel (" out << "Unsupported capture value returned from kernel '" << kernelName
<< returnIt->tagKind() << "): " << *returnIt; << "' (" << returnIt->tagKind() << "): " << *returnIt;
throw std::invalid_argument(out.str()); throw std::invalid_argument(out.str());
} }
resultTypes.push_back(resultType); resultTypes.push_back(resultType);
@ -227,13 +232,17 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
if (ival.isTensor()) { if (ival.isTensor()) {
// Is it an already mapped tensor? // Is it an already mapped tensor?
MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor()); MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor());
// TODO: Add mlirValueIsNull() if (!mlirValueIsNull(mappedValue)) {
if (mappedValue.ptr) {
return mappedValue; return mappedValue;
} }
throw std::invalid_argument( mappedValue = importTensorByValue(ival.toTensor());
"TODO: implement tensor import for non-arg tensors"); assert(mappedValue.ptr);
return mappedValue;
}
if (ival.isBool()) {
// TODO: Switch to the numpy.bool type as that is a closer domain match.
return funcBuilder->getBoolConstant(loc, ival.toBool());
} }
return {nullptr}; return {nullptr};
// TODO: Implement mappings for the whole set (relevant to this use case): // TODO: Implement mappings for the whole set (relevant to this use case):
@ -241,7 +250,6 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
// _(Tensor) // _(Tensor)
// _(Double) // _(Double)
// _(Int) // _(Int)
// _(Bool)
// _(Tuple) // _(Tuple)
// _(String) // _(String)
// _(Blob) // _(Blob)
@ -265,9 +273,86 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
if (ival.isTensor()) { if (ival.isTensor()) {
return typeMapper.forwardTensorToType(ival.toTensor()); return typeMapper.forwardTensorToType(ival.toTensor());
} }
if (ival.isBool()) {
// TODO: Switch to the numpy.bool type as that is a closer domain match.
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
}
return {nullptr}; return {nullptr};
} }
/// Imports `tensor` by value (as a constant) and returns the resulting array
/// value, remembering the tensor -> value association in the function builder.
///
/// The tensor contents are bulk-loaded into a DenseElementsAttr, materialized
/// through FuncBuilder::getGeneralConstant, and converted to an array with a
/// `numpy.create_array_from_tensor` op inserted before the entry block
/// terminator.
///
/// Throws std::invalid_argument when the element type is not one of
/// Int/Long/Float/Double, when no MLIR element type is available for it, or
/// when the ranked tensor type cannot be constructed.
MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
  using at::ScalarType;

  auto throwUnsupportedTensorError = [&]() {
    std::stringstream msg;
    msg << "Unsupported import tensor type: " << tensor;
    throw std::invalid_argument(msg.str());
  };

  // Get a C-contiguous form as we can bulk-load that into a DenseElementsAttr.
  // The contiguous form is held in a separate handle so that `tensor` keeps
  // referring to the caller's tensor: that is the one that must be mapped at
  // the end, otherwise the association would be recorded against a temporary
  // and the caller's tensor would never be found again.
  at::Tensor denseTensor =
      tensor.is_contiguous() ? tensor : tensor.contiguous();

  // The flat number of bytes throws an exception for tensors that are not
  // dense and accessible as such.
  at::checkLayout(at::CheckedFrom("accessing contiguous"), denseTensor,
                  c10::Layout::Strided);

  // Construct the ShapedType.
  auto loc = getCurrentLocation();
  MlirType elementType = typeMapper.mapScalarType(denseTensor.scalar_type());
  // NOTE(review): mapScalarType's failure mode is not visible here; guard
  // against a null type before handing it to the type constructor.
  if (mlirTypeIsNull(elementType)) {
    throwUnsupportedTensorError();
  }
  llvm::SmallVector<int64_t, 4> shape(denseTensor.sizes().begin(),
                                      denseTensor.sizes().end());
  MlirType shapedType = mlirRankedTensorTypeGetChecked(
      shape.size(), shape.data(), elementType, loc);
  if (mlirTypeIsNull(shapedType)) {
    throwUnsupportedTensorError();
  }

  // Import DenseElementsAttr data.
  // TODO: Support bool tensors.
  // TODO: More import formats in C-API.
  // Initialized to null so the value is never indeterminate, even though the
  // default case below always throws.
  MlirAttribute valueAttribute = {nullptr};
  auto numElements = denseTensor.numel();
  auto tensorData = denseTensor.data_ptr();
  switch (denseTensor.scalar_type()) {
  case ScalarType::Int:
    valueAttribute = mlirDenseElementsAttrInt32Get(
        shapedType, numElements, static_cast<const int32_t *>(tensorData));
    break;
  case ScalarType::Long:
    valueAttribute = mlirDenseElementsAttrInt64Get(
        shapedType, numElements, static_cast<const int64_t *>(tensorData));
    break;
  case ScalarType::Float:
    valueAttribute = mlirDenseElementsAttrFloatGet(
        shapedType, numElements, static_cast<const float *>(tensorData));
    break;
  case ScalarType::Double:
    valueAttribute = mlirDenseElementsAttrDoubleGet(
        shapedType, numElements, static_cast<const double *>(tensorData));
    break;
  default:
    throwUnsupportedTensorError();
  }
  MlirValue constTensorValue =
      funcBuilder->getGeneralConstant(loc, valueAttribute);

  // Create an array from the tensor constant via the
  // numpy.create_array_from_tensor op.
  MlirType constArrayType = npcompNdArrayTypeGetFromShaped(shapedType);
  MlirOperationState state =
      mlirOperationStateGet("numpy.create_array_from_tensor", loc);
  mlirOperationStateAddOperands(&state, 1, &constTensorValue);
  mlirOperationStateAddResults(&state, 1, &constArrayType);
  MlirOperation constArrayOp = mlirOperationCreate(&state);
  funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(constArrayOp);
  MlirValue constArrayValue = mlirOperationGetResult(constArrayOp, 0);

  // Associate the caller's tensor (not the contiguous temporary) with the
  // imported value.
  funcBuilder->mapTensor(tensor, constArrayValue);
  return constArrayValue;
}
TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) { TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction< m.fallback(torch::CppFunction::makeFromBoxedFunction<
&AcapController::fallbackKernel>()); &AcapController::fallbackKernel>());

View File

@ -56,11 +56,14 @@ public:
c10::Stack *stack); c10::Stack *stack);
private: private:
MlirLocation getCurrentLocation();
void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack); void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
void fallbackKernelImpl(const c10::OperatorHandle &opHandle, void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
c10::Stack *stack); c10::Stack *stack);
MlirValue mapIValueToMlirValue(MlirLocation loc, c10::IValue &ival); MlirValue mapIValueToMlirValue(MlirLocation loc, c10::IValue &ival);
MlirType mapIValueToMlirType(MlirLocation loc, c10::IValue &ival); MlirType mapIValueToMlirType(MlirLocation loc, c10::IValue &ival);
/// Imports a tensor by value (as a constant), remembering the association.
MlirValue importTensorByValue(at::Tensor tensor);
void verifyHasNotReturned(); void verifyHasNotReturned();
struct Activation { struct Activation {
Activation(std::shared_ptr<AcapController> controller) Activation(std::shared_ptr<AcapController> controller)

View File

@ -149,18 +149,31 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
// TODO: Switch to a basicpy.constant that works properly with signed // TODO: Switch to a basicpy.constant that works properly with signed
// integers and then switch this to a signed integer. // integers and then switch this to a signed integer.
MlirType t = mlirIntegerTypeGet(context, 64); MlirType t = mlirIntegerTypeGet(context, 64);
MlirOperation op = MlirAttribute value = mlirIntegerAttrGet(t, s.to<int64_t>());
createStandardConstant(loc, t, mlirIntegerAttrGet(t, s.to<int64_t>())); return getGeneralConstant(loc, value);
return insertConstantOp(op);
} }
if (s.isFloatingPoint()) { if (s.isFloatingPoint()) {
MlirType t = mlirF64TypeGet(context); MlirType t = mlirF64TypeGet(context);
MlirOperation op = createStandardConstant( MlirAttribute value = mlirFloatAttrDoubleGet(context, t, s.to<double>());
loc, t, mlirFloatAttrDoubleGet(context, t, s.to<double>())); return getGeneralConstant(loc, value);
return insertConstantOp(op); }
if (s.isBoolean()) {
return getBoolConstant(loc, s.to<bool>());
} }
// TODO: s.isBoolean()
// TODO: s.isComplex() // TODO: s.isComplex()
throw std::invalid_argument("TODO: Scalar of unknown kind"); throw std::invalid_argument("TODO: Scalar of unknown kind");
} }
/// Materializes the boolean `v` as a constant: builds a bool attribute in this
/// builder's context and hands it to getGeneralConstant.
MlirValue FuncBuilder::getBoolConstant(MlirLocation loc, bool v) {
  return getGeneralConstant(loc, mlirBoolAttrGet(context, v));
}
/// Materializes an arbitrary attribute as a constant value.
///
/// The result type is taken from the attribute itself; the constant op is
/// created with createStandardConstant and placed via insertConstantOp, whose
/// returned value is passed back to the caller.
MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
                                          MlirAttribute value) {
  MlirOperation op =
      createStandardConstant(loc, mlirAttributeGetType(value), value);
  return insertConstantOp(op);
}

View File

@ -120,6 +120,13 @@ public:
/// Gets a scalar constant value. /// Gets a scalar constant value.
MlirValue getScalarConstant(MlirLocation loc, at::Scalar s); MlirValue getScalarConstant(MlirLocation loc, at::Scalar s);
/// Gets a bool constant value.
MlirValue getBoolConstant(MlirLocation loc, bool v);
/// Gets a general constant value representing the given value
/// attribute.
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
private: private:
FuncBuilder(MlirContext context, MlirOperation funcOp, FuncBuilder(MlirContext context, MlirOperation funcOp,
BlockBuilder entryBlock) BlockBuilder entryBlock)

View File

@ -11,6 +11,7 @@ configure_lit_site_cfg(
set(TEST_DEPENDS set(TEST_DEPENDS
FileCheck count not FileCheck count not
npcomp-opt
NPCOMPTorchMLIRExt NPCOMPTorchMLIRExt
) )

View File

@ -5,23 +5,21 @@
import torch import torch
import torch_mlir import torch_mlir
# See bug references below and remove XFAIL when resolved.
# XFAIL: *
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s # RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# TODO: Both of these fail with the "unsupported from an unboxed API yet" error.
# The corresponding ops need to be manually coded. Then these can be moved into
# the capture. https://github.com/llvm/mlir-npcomp/issues/78
# TODO: These also create constant tensors (needs implementation of import of
# DenseElements constants). https://github.com/llvm/mlir-npcomp/issues/79
model = torch.nn.BatchNorm2d(123)
ones = torch.ones(42,123,4,5) ones = torch.ones(42,123,4,5)
with mb.capture_function("bn2d", []) as f: with mb.capture_function("bn2d", [ones]) as f:
model = torch.nn.BatchNorm2d(123)
result = model(ones) result = model(ones)
f.returns([result]) f.returns([result])
# TODO: This test exercises promotion of const to arrays, inplace zero_ and
# add, all of which should be checked individually because they have specific
# behavior.
# CHECK-LABEL: @bn2d # CHECK-LABEL: @bn2d
# CHECK: %[[RESULT:.*]]:3 = torch.kernel_call "aten::native_batch_norm" %arg0
# CHECK: return %[[RESULT]]#0 : !numpy.ndarray<[42,123,4,5]:f32>
print(mb.module) print(mb.module)

View File

@ -49,6 +49,9 @@ int npcompTypeIsANdArray(MlirType t);
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape, MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
MlirType elementType); MlirType elementType);
/// Helper that gets an equivalent NdArrayType from a ShapedType.
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -40,6 +40,9 @@ public:
static NdArrayType get(Type dtype, static NdArrayType get(Type dtype,
llvm::Optional<ArrayRef<int64_t>> shape = llvm::None); llvm::Optional<ArrayRef<int64_t>> shape = llvm::None);
/// Helper that gets an equivalent NdArrayType from a ShapedType.
static NdArrayType getFromShapedType(ShapedType shapedType);
/// Returns whether the dtype is a concrete type (versus /// Returns whether the dtype is a concrete type (versus
/// !basicpy.UnknownType). /// !basicpy.UnknownType).
bool hasKnownDtype(); bool hasKnownDtype();

View File

@ -86,6 +86,7 @@ def AnyTorchTensorType : AnyTypeOf<[
def AnyScalar : AnyTypeOf<[ def AnyScalar : AnyTypeOf<[
AnySignedInteger, AnySignedInteger,
AnyFloat,
Basicpy_BoolType, Basicpy_BoolType,
Basicpy_StrType, Basicpy_StrType,
Basicpy_NoneType, Basicpy_NoneType,

View File

@ -9,9 +9,11 @@
#include "npcomp-c/Types.h" #include "npcomp-c/Types.h"
#include "mlir/CAPI/IR.h" #include "mlir/CAPI/IR.h"
#include "mlir/IR/StandardTypes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
using namespace mlir;
using namespace mlir::NPCOMP::Basicpy; using namespace mlir::NPCOMP::Basicpy;
using namespace mlir::NPCOMP::Numpy; using namespace mlir::NPCOMP::Numpy;
@ -46,3 +48,8 @@ MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
llvm::ArrayRef<int64_t> shapeArray(shape, rank); llvm::ArrayRef<int64_t> shapeArray(shape, rank);
return wrap(NdArrayType::get(unwrap(elementType), shapeArray)); return wrap(NdArrayType::get(unwrap(elementType), shapeArray));
} }
/// C API entry point: gets the NdArrayType equivalent to `shapedType`.
///
/// The wrapped type is cast to ShapedType unconditionally, so callers are
/// expected to pass a type that really is a ShapedType.
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) {
  auto shaped = unwrap(shapedType).cast<ShapedType>();
  return wrap(NdArrayType::getFromShapedType(shaped));
}

View File

@ -195,6 +195,13 @@ NdArrayType NdArrayType::get(Type dtype,
return Base::get(dtype.getContext(), dtype, shape); return Base::get(dtype.getContext(), dtype, shape);
} }
/// Builds the NdArrayType that corresponds to `shapedType`: same element type,
/// and the same shape when the source type is ranked (no shape otherwise).
NdArrayType NdArrayType::getFromShapedType(ShapedType shapedType) {
  // Unranked: there is no shape to carry over.
  if (!shapedType.hasRank())
    return get(shapedType.getElementType(), llvm::None);
  return get(shapedType.getElementType(), shapedType.getShape());
}
bool NdArrayType::hasKnownDtype() { bool NdArrayType::hasKnownDtype() {
return getDtype() != Basicpy::UnknownType::get(getContext()); return getDtype() != Basicpy::UnknownType::get(getContext());
} }