mirror of https://github.com/llvm/torch-mlir
More progress on PyTorch acap device capture.
* Now gets far enough to capture batch_norm. * Has some issues still with in-place ops. * Can materialize constants. * Includes an upgrade to PyTorch nightly, which has important bug fixes for fallback and boxed kernel dispatch. * Fixes #78, #79, #80. * Will do more testing in a follow-up once further bugs are fixed that facilitate getting at the other features.pull/84/head
parent
06a8ba6900
commit
9e52f6235b
|
@ -124,7 +124,7 @@ Create docker image (or follow your own preferences):
|
||||||
* Mount the `/build` directory (in the container) appropriately for your case.
|
* Mount the `/build` directory (in the container) appropriately for your case.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
docker build docker/pytorch-1.6 --tag local/npcomp:build-pytorch-1.6
|
docker build docker/pytorch-nightly --tag local/npcomp:build-pytorch-nightly
|
||||||
docker volume create npcomp-build
|
docker volume create npcomp-build
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ Shell into docker image:
|
||||||
docker run \
|
docker run \
|
||||||
--mount type=bind,source=$HOME/src/mlir-npcomp,target=/src/mlir-npcomp \
|
--mount type=bind,source=$HOME/src/mlir-npcomp,target=/src/mlir-npcomp \
|
||||||
--mount source=npcomp-build,target=/build \
|
--mount source=npcomp-build,target=/build \
|
||||||
--rm -it local/npcomp:build-pytorch-1.6 /bin/bash
|
--rm -it local/npcomp:build-pytorch-nightly /bin/bash
|
||||||
```
|
```
|
||||||
|
|
||||||
Build/test npcomp (from within docker image):
|
Build/test npcomp (from within docker image):
|
||||||
|
|
|
@ -4,25 +4,25 @@
|
||||||
# source $WHERE_YOU_CHECKED_OUT_NPCOMP/build_tools/docker_shell_funcs.sh
|
# source $WHERE_YOU_CHECKED_OUT_NPCOMP/build_tools/docker_shell_funcs.sh
|
||||||
# ```
|
# ```
|
||||||
|
|
||||||
td="$(realpath $(dirname "${BASH_SOURCE[0]}")/..)"
|
__npcomp_dir="$(realpath $(dirname "${BASH_SOURCE[0]}")/..)"
|
||||||
|
|
||||||
# Build the docker images for npcomp:
|
# Build the docker images for npcomp:
|
||||||
# npcomp:build-pytorch-1.6
|
# npcomp:build-pytorch-nightly
|
||||||
# me/npcomp:build-pytorch-1.6 (additional dev packages and current user)
|
# me/npcomp:build-pytorch-nightly (additional dev packages and current user)
|
||||||
function npcomp_docker_build() {
|
function npcomp_docker_build() {
|
||||||
if ! [ -f "docker/pytorch-1.6/Dockerfile" ]; then
|
if ! [ -f "docker/pytorch-nightly/Dockerfile" ]; then
|
||||||
echo "Please run out of mlir-npcomp/ source directory..."
|
echo "Please run out of mlir-npcomp/ source directory..."
|
||||||
return 1
|
return 1
|
||||||
fi
|
fi
|
||||||
echo "Building out of $(pwd)..."
|
echo "Building out of $(pwd)..."
|
||||||
docker build docker/pytorch-1.6 --tag npcomp:build-pytorch-1.6
|
docker build docker/pytorch-nightly --tag npcomp:build-pytorch-nightly
|
||||||
npcomp_docker_build_for_me npcomp:build-pytorch-1.6
|
npcomp_docker_build_for_me npcomp:build-pytorch-nightly
|
||||||
}
|
}
|
||||||
|
|
||||||
# Start a container named "npcomp" in the background with the current-user
|
# Start a container named "npcomp" in the background with the current-user
|
||||||
# dev image built above.
|
# dev image built above.
|
||||||
function npcomp_docker_start() {
|
function npcomp_docker_start() {
|
||||||
local host_src_dir="${1-$td}"
|
local host_src_dir="${1-$__npcomp_dir}"
|
||||||
if ! [ -d "$host_src_dir" ]; then
|
if ! [ -d "$host_src_dir" ]; then
|
||||||
echo "mlir-npcomp source directory not found:"
|
echo "mlir-npcomp source directory not found:"
|
||||||
echo "Pass path to host source directory as argument (default=$host_src_dir)."
|
echo "Pass path to host source directory as argument (default=$host_src_dir)."
|
||||||
|
@ -32,7 +32,7 @@ function npcomp_docker_start() {
|
||||||
docker run -d --rm --name "npcomp" \
|
docker run -d --rm --name "npcomp" \
|
||||||
--mount source=npcomp-build,target=/build \
|
--mount source=npcomp-build,target=/build \
|
||||||
--mount type=bind,source=$host_src_dir,target=/src/mlir-npcomp \
|
--mount type=bind,source=$host_src_dir,target=/src/mlir-npcomp \
|
||||||
me/npcomp:build-pytorch-1.6 tail -f /dev/null
|
me/npcomp:build-pytorch-nightly tail -f /dev/null
|
||||||
}
|
}
|
||||||
|
|
||||||
# Stop the container named "npcomp".
|
# Stop the container named "npcomp".
|
||||||
|
|
|
@ -14,7 +14,8 @@ RUN ln -s /usr/bin/llvm-symbolizer-10 /usr/bin/llvm-symbolizer
|
||||||
|
|
||||||
# Install PyTorch
|
# Install PyTorch
|
||||||
# Installs under: /usr/local/lib/python3.8/dist-packages/torch
|
# Installs under: /usr/local/lib/python3.8/dist-packages/torch
|
||||||
RUN pip3 install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
RUN pip3 install numpy
|
||||||
|
RUN pip3 install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||||
RUN ln -s /usr/local/lib/python3.8/dist-packages/torch /pytorch
|
RUN ln -s /usr/local/lib/python3.8/dist-packages/torch /pytorch
|
||||||
|
|
||||||
# Build configuration
|
# Build configuration
|
|
@ -1 +1 @@
|
||||||
Subproject commit ee491ac91e123b90eeec3cce7e494936ea8cb85d
|
Subproject commit 6771b98c4e4d5c0bd0a78a876bd212a76ec80a24
|
|
@ -8,6 +8,8 @@
|
||||||
#include "acap_dispatch.h"
|
#include "acap_dispatch.h"
|
||||||
|
|
||||||
#include "mlir-c/StandardAttributes.h"
|
#include "mlir-c/StandardAttributes.h"
|
||||||
|
#include "mlir-c/StandardTypes.h"
|
||||||
|
#include "npcomp-c/Types.h"
|
||||||
#include "npcomp/Python/PybindUtils.h"
|
#include "npcomp/Python/PybindUtils.h"
|
||||||
|
|
||||||
#include <ATen/core/function_schema.h>
|
#include <ATen/core/function_schema.h>
|
||||||
|
@ -81,8 +83,7 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
|
||||||
returnsValues.push_back(v);
|
returnsValues.push_back(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Get location from traceback.
|
MlirLocation loc = getCurrentLocation();
|
||||||
MlirLocation loc = mlirLocationUnknownGet(funcBuilder->getContext());
|
|
||||||
OperationStateHolder s("std.return", loc);
|
OperationStateHolder s("std.return", loc);
|
||||||
mlirOperationStateAddOperands(&s.state, returnsValues.size(),
|
mlirOperationStateAddOperands(&s.state, returnsValues.size(),
|
||||||
returnsValues.data());
|
returnsValues.data());
|
||||||
|
@ -123,6 +124,10 @@ void AcapController::fallbackKernel(const OperatorHandle &opHandle,
|
||||||
current->fallbackKernelImpl(opHandle, stack);
|
current->fallbackKernelImpl(opHandle, stack);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirLocation AcapController::getCurrentLocation() {
|
||||||
|
return mlirLocationUnknownGet(funcBuilder->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
void AcapController::redispatch(const c10::OperatorHandle &opHandle,
|
void AcapController::redispatch(const c10::OperatorHandle &opHandle,
|
||||||
c10::Stack *stack) {
|
c10::Stack *stack) {
|
||||||
// Exclude recursive dispatch to this kernel.
|
// Exclude recursive dispatch to this kernel.
|
||||||
|
@ -168,8 +173,8 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
|
||||||
MlirValue mlirValue = mapIValueToMlirValue(loc, *argIt);
|
MlirValue mlirValue = mapIValueToMlirValue(loc, *argIt);
|
||||||
if (mlirValueIsNull(mlirValue)) {
|
if (mlirValueIsNull(mlirValue)) {
|
||||||
std::stringstream out;
|
std::stringstream out;
|
||||||
out << "Unsupported capture value passed to kernel (" << argIt->tagKind()
|
out << "Unsupported capture value returned from kernel '" << kernelName
|
||||||
<< "): " << *argIt;
|
<< "' (" << argIt->tagKind() << "): " << *argIt;
|
||||||
throw std::invalid_argument(out.str());
|
throw std::invalid_argument(out.str());
|
||||||
}
|
}
|
||||||
operands.push_back(mlirValue);
|
operands.push_back(mlirValue);
|
||||||
|
@ -191,8 +196,8 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
|
||||||
MlirType resultType = mapIValueToMlirType(loc, *returnIt);
|
MlirType resultType = mapIValueToMlirType(loc, *returnIt);
|
||||||
if (mlirTypeIsNull(resultType)) {
|
if (mlirTypeIsNull(resultType)) {
|
||||||
std::stringstream out;
|
std::stringstream out;
|
||||||
out << "Unsupported capture value returned from kernel ("
|
out << "Unsupported capture value returned from kernel '" << kernelName
|
||||||
<< returnIt->tagKind() << "): " << *returnIt;
|
<< "' (" << returnIt->tagKind() << "): " << *returnIt;
|
||||||
throw std::invalid_argument(out.str());
|
throw std::invalid_argument(out.str());
|
||||||
}
|
}
|
||||||
resultTypes.push_back(resultType);
|
resultTypes.push_back(resultType);
|
||||||
|
@ -227,13 +232,17 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
|
||||||
if (ival.isTensor()) {
|
if (ival.isTensor()) {
|
||||||
// Is it an already mapped tensor?
|
// Is it an already mapped tensor?
|
||||||
MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor());
|
MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor());
|
||||||
// TODO: Add mlirValueIsNull()
|
if (!mlirValueIsNull(mappedValue)) {
|
||||||
if (mappedValue.ptr) {
|
|
||||||
return mappedValue;
|
return mappedValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument(
|
mappedValue = importTensorByValue(ival.toTensor());
|
||||||
"TODO: implement tensor import for non-arg tensors");
|
assert(mappedValue.ptr);
|
||||||
|
return mappedValue;
|
||||||
|
}
|
||||||
|
if (ival.isBool()) {
|
||||||
|
// TODO: Switch to the numpy.bool type as that is a closer domain match.
|
||||||
|
return funcBuilder->getBoolConstant(loc, ival.toBool());
|
||||||
}
|
}
|
||||||
return {nullptr};
|
return {nullptr};
|
||||||
// TODO: Implement mappings for the whole set (relevant to this use case):
|
// TODO: Implement mappings for the whole set (relevant to this use case):
|
||||||
|
@ -241,7 +250,6 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
|
||||||
// _(Tensor)
|
// _(Tensor)
|
||||||
// _(Double)
|
// _(Double)
|
||||||
// _(Int)
|
// _(Int)
|
||||||
// _(Bool)
|
|
||||||
// _(Tuple)
|
// _(Tuple)
|
||||||
// _(String)
|
// _(String)
|
||||||
// _(Blob)
|
// _(Blob)
|
||||||
|
@ -265,9 +273,86 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
|
||||||
if (ival.isTensor()) {
|
if (ival.isTensor()) {
|
||||||
return typeMapper.forwardTensorToType(ival.toTensor());
|
return typeMapper.forwardTensorToType(ival.toTensor());
|
||||||
}
|
}
|
||||||
|
if (ival.isBool()) {
|
||||||
|
// TODO: Switch to the numpy.bool type as that is a closer domain match.
|
||||||
|
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
|
||||||
|
}
|
||||||
return {nullptr};
|
return {nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
||||||
|
using at::ScalarType;
|
||||||
|
|
||||||
|
auto throwUnsupportedTensorError = [&]() {
|
||||||
|
std::stringstream msg;
|
||||||
|
msg << "Unsupported import tensor type: " << tensor;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get a C-contiguous form as we can bulk-load that into a DenseElementsAttr.
|
||||||
|
if (!tensor.is_contiguous())
|
||||||
|
tensor = tensor.contiguous();
|
||||||
|
|
||||||
|
// The flat number of bytes throws an exception for tensors that are not
|
||||||
|
// dense and accessible as such.
|
||||||
|
at::checkLayout(at::CheckedFrom("accessing contiguous"), tensor,
|
||||||
|
c10::Layout::Strided);
|
||||||
|
|
||||||
|
// Construct the ShapedType.
|
||||||
|
auto loc = getCurrentLocation();
|
||||||
|
MlirType elementType = typeMapper.mapScalarType(tensor.scalar_type());
|
||||||
|
llvm::SmallVector<int64_t, 4> shape(tensor.sizes().begin(),
|
||||||
|
tensor.sizes().end());
|
||||||
|
MlirType shapedType = mlirRankedTensorTypeGetChecked(
|
||||||
|
shape.size(), shape.data(), elementType, loc);
|
||||||
|
if (mlirTypeIsNull(shapedType)) {
|
||||||
|
throwUnsupportedTensorError();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import DenseElementsAttr data.
|
||||||
|
// TODO: Support bool tensors.
|
||||||
|
// TODO: More import formats in C-API.
|
||||||
|
MlirAttribute valueAttribute;
|
||||||
|
auto numElements = tensor.numel();
|
||||||
|
auto tensorData = tensor.data_ptr();
|
||||||
|
switch (tensor.scalar_type()) {
|
||||||
|
case ScalarType::Int:
|
||||||
|
valueAttribute = mlirDenseElementsAttrInt32Get(
|
||||||
|
shapedType, numElements, static_cast<const int32_t *>(tensorData));
|
||||||
|
break;
|
||||||
|
case ScalarType::Long:
|
||||||
|
valueAttribute = mlirDenseElementsAttrInt64Get(
|
||||||
|
shapedType, numElements, static_cast<const int64_t *>(tensorData));
|
||||||
|
break;
|
||||||
|
case ScalarType::Float:
|
||||||
|
valueAttribute = mlirDenseElementsAttrFloatGet(
|
||||||
|
shapedType, numElements, static_cast<const float *>(tensorData));
|
||||||
|
break;
|
||||||
|
case ScalarType::Double:
|
||||||
|
valueAttribute = mlirDenseElementsAttrDoubleGet(
|
||||||
|
shapedType, numElements, static_cast<const double *>(tensorData));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throwUnsupportedTensorError();
|
||||||
|
}
|
||||||
|
MlirValue constTensorValue =
|
||||||
|
funcBuilder->getGeneralConstant(loc, valueAttribute);
|
||||||
|
|
||||||
|
// Create an array from the tensor constant via the
|
||||||
|
// numpy.create_array_from_tensor op.
|
||||||
|
MlirType constArrayType = npcompNdArrayTypeGetFromShaped(shapedType);
|
||||||
|
MlirOperationState state =
|
||||||
|
mlirOperationStateGet("numpy.create_array_from_tensor", loc);
|
||||||
|
mlirOperationStateAddOperands(&state, 1, &constTensorValue);
|
||||||
|
mlirOperationStateAddResults(&state, 1, &constArrayType);
|
||||||
|
MlirOperation constArrayOp = mlirOperationCreate(&state);
|
||||||
|
|
||||||
|
funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(constArrayOp);
|
||||||
|
MlirValue constArrayValue = mlirOperationGetResult(constArrayOp, 0);
|
||||||
|
funcBuilder->mapTensor(tensor, constArrayValue);
|
||||||
|
return constArrayValue;
|
||||||
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
|
TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
|
||||||
m.fallback(torch::CppFunction::makeFromBoxedFunction<
|
m.fallback(torch::CppFunction::makeFromBoxedFunction<
|
||||||
&AcapController::fallbackKernel>());
|
&AcapController::fallbackKernel>());
|
||||||
|
|
|
@ -56,11 +56,14 @@ public:
|
||||||
c10::Stack *stack);
|
c10::Stack *stack);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
MlirLocation getCurrentLocation();
|
||||||
void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
|
void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
|
||||||
void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
|
void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
|
||||||
c10::Stack *stack);
|
c10::Stack *stack);
|
||||||
MlirValue mapIValueToMlirValue(MlirLocation loc, c10::IValue &ival);
|
MlirValue mapIValueToMlirValue(MlirLocation loc, c10::IValue &ival);
|
||||||
MlirType mapIValueToMlirType(MlirLocation loc, c10::IValue &ival);
|
MlirType mapIValueToMlirType(MlirLocation loc, c10::IValue &ival);
|
||||||
|
/// Imports a tensor by value (as a constant), remembering the association.
|
||||||
|
MlirValue importTensorByValue(at::Tensor tensor);
|
||||||
void verifyHasNotReturned();
|
void verifyHasNotReturned();
|
||||||
struct Activation {
|
struct Activation {
|
||||||
Activation(std::shared_ptr<AcapController> controller)
|
Activation(std::shared_ptr<AcapController> controller)
|
||||||
|
|
|
@ -149,18 +149,31 @@ MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
|
||||||
// TODO: Switch to a basicpy.constant that works properly with signed
|
// TODO: Switch to a basicpy.constant that works properly with signed
|
||||||
// integers and then switch this to a signed integer.
|
// integers and then switch this to a signed integer.
|
||||||
MlirType t = mlirIntegerTypeGet(context, 64);
|
MlirType t = mlirIntegerTypeGet(context, 64);
|
||||||
MlirOperation op =
|
MlirAttribute value = mlirIntegerAttrGet(t, s.to<int64_t>());
|
||||||
createStandardConstant(loc, t, mlirIntegerAttrGet(t, s.to<int64_t>()));
|
return getGeneralConstant(loc, value);
|
||||||
return insertConstantOp(op);
|
|
||||||
}
|
}
|
||||||
if (s.isFloatingPoint()) {
|
if (s.isFloatingPoint()) {
|
||||||
MlirType t = mlirF64TypeGet(context);
|
MlirType t = mlirF64TypeGet(context);
|
||||||
MlirOperation op = createStandardConstant(
|
MlirAttribute value = mlirFloatAttrDoubleGet(context, t, s.to<double>());
|
||||||
loc, t, mlirFloatAttrDoubleGet(context, t, s.to<double>()));
|
return getGeneralConstant(loc, value);
|
||||||
return insertConstantOp(op);
|
}
|
||||||
|
if (s.isBoolean()) {
|
||||||
|
return getBoolConstant(loc, s.to<bool>());
|
||||||
}
|
}
|
||||||
// TODO: s.isBoolean()
|
|
||||||
// TODO: s.isComplex()
|
// TODO: s.isComplex()
|
||||||
|
|
||||||
throw std::invalid_argument("TODO: Scalar of unknown kind");
|
throw std::invalid_argument("TODO: Scalar of unknown kind");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirValue FuncBuilder::getBoolConstant(MlirLocation loc, bool v) {
|
||||||
|
MlirAttribute value = mlirBoolAttrGet(context, v);
|
||||||
|
return getGeneralConstant(loc, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
|
||||||
|
MlirAttribute value) {
|
||||||
|
MlirType valueType = mlirAttributeGetType(value);
|
||||||
|
MlirOperation constOp = createStandardConstant(loc, valueType, value);
|
||||||
|
MlirValue constValue = insertConstantOp(constOp);
|
||||||
|
return constValue;
|
||||||
|
}
|
||||||
|
|
|
@ -120,6 +120,13 @@ public:
|
||||||
/// Gets a scalar constant value.
|
/// Gets a scalar constant value.
|
||||||
MlirValue getScalarConstant(MlirLocation loc, at::Scalar s);
|
MlirValue getScalarConstant(MlirLocation loc, at::Scalar s);
|
||||||
|
|
||||||
|
/// Gets a bool constant value.
|
||||||
|
MlirValue getBoolConstant(MlirLocation loc, bool v);
|
||||||
|
|
||||||
|
/// Gets a general constant value representing the given value
|
||||||
|
/// attribute.
|
||||||
|
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FuncBuilder(MlirContext context, MlirOperation funcOp,
|
FuncBuilder(MlirContext context, MlirOperation funcOp,
|
||||||
BlockBuilder entryBlock)
|
BlockBuilder entryBlock)
|
||||||
|
|
|
@ -11,6 +11,7 @@ configure_lit_site_cfg(
|
||||||
|
|
||||||
set(TEST_DEPENDS
|
set(TEST_DEPENDS
|
||||||
FileCheck count not
|
FileCheck count not
|
||||||
|
npcomp-opt
|
||||||
NPCOMPTorchMLIRExt
|
NPCOMPTorchMLIRExt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -5,23 +5,21 @@
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
|
|
||||||
# See bug references below and remove XFAIL when resolved.
|
|
||||||
# XFAIL: *
|
|
||||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# TODO: Both of these fail with the "unsupported from an unboxed API yet" error.
|
|
||||||
# The corresponding ops need to be manually coded. Then these can be moved into
|
|
||||||
# the capture. https://github.com/llvm/mlir-npcomp/issues/78
|
|
||||||
# TODO: These also create constant tensors (needs implementation of import of
|
|
||||||
# DenseElements constants). https://github.com/llvm/mlir-npcomp/issues/79
|
|
||||||
model = torch.nn.BatchNorm2d(123)
|
|
||||||
ones = torch.ones(42,123,4,5)
|
ones = torch.ones(42,123,4,5)
|
||||||
|
|
||||||
with mb.capture_function("bn2d", []) as f:
|
with mb.capture_function("bn2d", [ones]) as f:
|
||||||
|
model = torch.nn.BatchNorm2d(123)
|
||||||
result = model(ones)
|
result = model(ones)
|
||||||
f.returns([result])
|
f.returns([result])
|
||||||
|
|
||||||
|
# TODO: This test exercises promotion of const to arrays, inplace zero_ and
|
||||||
|
# add, all of which should be checked individually because they have specific
|
||||||
|
# behavior.
|
||||||
# CHECK-LABEL: @bn2d
|
# CHECK-LABEL: @bn2d
|
||||||
|
# CHECK: %[[RESULT:.*]]:3 = torch.kernel_call "aten::native_batch_norm" %arg0
|
||||||
|
# CHECK: return %[[RESULT]]#0 : !numpy.ndarray<[42,123,4,5]:f32>
|
||||||
print(mb.module)
|
print(mb.module)
|
||||||
|
|
|
@ -49,6 +49,9 @@ int npcompTypeIsANdArray(MlirType t);
|
||||||
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
||||||
MlirType elementType);
|
MlirType elementType);
|
||||||
|
|
||||||
|
/// Helper that gets an equivalent NdArrayType from a ShapedType.
|
||||||
|
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -40,6 +40,9 @@ public:
|
||||||
static NdArrayType get(Type dtype,
|
static NdArrayType get(Type dtype,
|
||||||
llvm::Optional<ArrayRef<int64_t>> shape = llvm::None);
|
llvm::Optional<ArrayRef<int64_t>> shape = llvm::None);
|
||||||
|
|
||||||
|
/// Helper that gets an equivalent NdArrayType from a ShapedType.
|
||||||
|
static NdArrayType getFromShapedType(ShapedType shapedType);
|
||||||
|
|
||||||
/// Returns whether the dtype is a concrete type (versus
|
/// Returns whether the dtype is a concrete type (versus
|
||||||
/// !basicpy.UnknownType).
|
/// !basicpy.UnknownType).
|
||||||
bool hasKnownDtype();
|
bool hasKnownDtype();
|
||||||
|
|
|
@ -86,6 +86,7 @@ def AnyTorchTensorType : AnyTypeOf<[
|
||||||
|
|
||||||
def AnyScalar : AnyTypeOf<[
|
def AnyScalar : AnyTypeOf<[
|
||||||
AnySignedInteger,
|
AnySignedInteger,
|
||||||
|
AnyFloat,
|
||||||
Basicpy_BoolType,
|
Basicpy_BoolType,
|
||||||
Basicpy_StrType,
|
Basicpy_StrType,
|
||||||
Basicpy_NoneType,
|
Basicpy_NoneType,
|
||||||
|
|
|
@ -9,9 +9,11 @@
|
||||||
#include "npcomp-c/Types.h"
|
#include "npcomp-c/Types.h"
|
||||||
|
|
||||||
#include "mlir/CAPI/IR.h"
|
#include "mlir/CAPI/IR.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP::Basicpy;
|
using namespace mlir::NPCOMP::Basicpy;
|
||||||
using namespace mlir::NPCOMP::Numpy;
|
using namespace mlir::NPCOMP::Numpy;
|
||||||
|
|
||||||
|
@ -46,3 +48,8 @@ MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
||||||
llvm::ArrayRef<int64_t> shapeArray(shape, rank);
|
llvm::ArrayRef<int64_t> shapeArray(shape, rank);
|
||||||
return wrap(NdArrayType::get(unwrap(elementType), shapeArray));
|
return wrap(NdArrayType::get(unwrap(elementType), shapeArray));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) {
|
||||||
|
return wrap(
|
||||||
|
NdArrayType::getFromShapedType(unwrap(shapedType).cast<ShapedType>()));
|
||||||
|
}
|
||||||
|
|
|
@ -195,6 +195,13 @@ NdArrayType NdArrayType::get(Type dtype,
|
||||||
return Base::get(dtype.getContext(), dtype, shape);
|
return Base::get(dtype.getContext(), dtype, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NdArrayType NdArrayType::getFromShapedType(ShapedType shapedType) {
|
||||||
|
llvm::Optional<ArrayRef<int64_t>> shape;
|
||||||
|
if (shapedType.hasRank())
|
||||||
|
shape = shapedType.getShape();
|
||||||
|
return get(shapedType.getElementType(), shape);
|
||||||
|
}
|
||||||
|
|
||||||
bool NdArrayType::hasKnownDtype() {
|
bool NdArrayType::hasKnownDtype() {
|
||||||
return getDtype() != Basicpy::UnknownType::get(getContext());
|
return getDtype() != Basicpy::UnknownType::get(getContext());
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue