mirror of https://github.com/llvm/torch-mlir
Support optional args/returns and other odds and ends.
* None's out Device? args.
* Emits bool tensors if needed.
* Adds some stderr tracing to better see what is going on (see the Python sketch after the commit metadata below).
* Test case that exercises NLLLoss.
* This test case emits something for backward calculations, but there are some issues still to be worked out, so that part is left out of the test case.
* Progress on #97

Branch: pull/101/head
parent a3f4db9fe8
commit 8d98dd4551
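A minimal usage sketch (not part of the commit itself) of the new stderr tracing entry point, assuming a torch_mlir build that includes this change; the trace prefix and the "Fallback (boxed) dispatch" wording come from the diff below, but whether a given op (here tanh) actually reaches the boxed fallback is an assumption:

    import torch
    import torch_mlir

    # New in this commit: route capture/dispatch traces to stderr.
    torch_mlir.debug_trace_to_stderr()

    t = torch.ones(2, 3)
    mb = torch_mlir.ModuleBuilder()
    with mb.capture_function("trace_me", [t]) as f:
        f.returns([t.tanh()])

    # Expect lines like "TORCH_MLIR TRACE: Fallback (boxed) dispatch: ..." on stderr.
    mb.module.operation.print()
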
CMakeLists.txt:

@@ -9,6 +9,7 @@ include_directories(
 link_directories("${TORCH_INSTALL_PREFIX}/lib")
 add_library(npcomp_torch_c10_dispatch_bindings
   acap_dispatch.cpp
+  debug.cpp
   func_builder.cpp
   module_builder.cpp
   python_bindings.cpp

acap_dispatch.cpp:

@@ -6,6 +6,7 @@
 //===----------------------------------------------------------------------===//

 #include "acap_dispatch.h"
+#include "debug.h"

 #include "mlir-c/StandardAttributes.h"
 #include "mlir-c/StandardTypes.h"

@@ -100,8 +101,8 @@ void AcapController::KernelCallBuilder::addOperand(const IValue &value) {
   if (mlirValueIsNull(mlirValue)) {
     std::stringstream out;
     const std::string &kernelName = opHandle.operator_name().name;
-    out << "Unsupported capture value returned from kernel '" << kernelName
-        << "' (" << value.tagKind() << "): " << value;
+    out << "Unsupported capture value passed to kernel '" << kernelName << "' ("
+        << value.tagKind() << "): " << value;
     throw std::invalid_argument(out.str());
   }
   mlirOperationStateAddOperands(state, 1, &mlirValue);

@@ -132,11 +133,6 @@ MlirOperation AcapController::KernelCallBuilder::create() {
     MlirValue result = mlirOperationGetResult(op, it.first);
     parent.funcBuilder->mapTensor(it.second, result);
   }
-
-  // Add to debug log.
-  std::stringstream sout;
-  sout << "CAPTURE: " << opHandle.operator_name().name << "\n";
-  parent.captureLog.push_back(sout.str());
   return op;
 }

@@ -178,11 +174,13 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
   for (auto &tensor : tensors) {
     MlirValue v = funcBuilder->lookupTensor(tensor);
     if (mlirValueIsNull(v)) {
+      // Finalize this instance so that everything that goes into printing an
+      // error message does not capture.
+      hasReturned = true;
       // Exclude recursive dispatch in order to print tensor.
       c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
       std::stringstream msg;
-      msg << "Cannot return a tensor that is not from the capture context: "
-          << tensor;
+      msg << "Cannot return a tensor that is not from the capture context";
       throw std::invalid_argument(msg.str());
     }

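The hunk above finalizes the controller (hasReturned = true) before building the error message, so nothing involved in reporting the failure is itself captured, and it drops the tensor from the message text. A hypothetical sketch of the user-visible behavior, assuming pybind11's default translation of std::invalid_argument to Python ValueError and that tensor b below is neither a capture input nor produced under capture; the context manager's cleanup behavior after the throw is not verified here:

    import torch
    import torch_mlir

    a = torch.ones(2, 3)
    b = torch.ones(2, 3)  # never passed in and never produced under capture

    mb = torch_mlir.ModuleBuilder()
    try:
        with mb.capture_function("bad_return", [a]) as f:
            f.returns([b])  # b is unknown to the capture context
    except ValueError as e:
        print("rejected:", e)
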
@@ -199,12 +197,6 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
   hasReturned = true;
 }

-std::vector<std::string> AcapController::getDebugLog() {
-  std::vector<std::string> copy;
-  captureLog.swap(copy);
-  return copy;
-}
-
 std::shared_ptr<AcapController>
 AcapController::getCurrentThreadAcapController() {
   auto &stack = getThreadLocalActiveStack();

@@ -241,6 +233,12 @@ at::Tensor AcapController::convolutionKernel(
   auto &dispatcher = c10::Dispatcher::singleton();
   auto opHandle = dispatcher.findOp(opName);
   assert(opHandle && "could not find convolution op");
+  if (isDebugTraceEnabled()) {
+    std::stringstream s;
+    s << "Convolution (unboxed) dispatch: " << opHandle->schema();
+    debugTrace(s.str());
+  }
+
   auto opTyped = opHandle->typed<at::Tensor(
       const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &,
       const at::IntArrayRef, const at::IntArrayRef, const at::IntArrayRef,

@@ -307,6 +305,12 @@ void AcapController::redispatch(const c10::OperatorHandle &opHandle,
 void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
                                         Stack *stack) {
   verifyHasNotReturned();
+  if (isDebugTraceEnabled()) {
+    std::stringstream s;
+    s << "Fallback (boxed) dispatch: " << opHandle.schema();
+    debugTrace(s.str());
+  }
+
   // Exclude recursive dispatch to this kernel.
   c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);

@@ -352,6 +356,13 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
     return funcBuilder->getScalarConstant(loc, ival.toScalar());
   }
   if (ival.isTensor()) {
+    auto tensor = ival.toTensor();
+    if (!tensor.defined()) {
+      // Optional tensors ("Tensor?" type) are represented as Tensor ivals
+      // that are undefined.
+      return funcBuilder->getNoneConstant(loc);
+    }
+
     // Is it an already mapped tensor?
     MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor());
     if (!mlirValueIsNull(mappedValue)) {

@@ -377,6 +388,11 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
   if (ival.isNone()) {
     return funcBuilder->getNoneConstant(loc);
   }
+  if (ival.isDevice()) {
+    // TODO: Do we need to model/preserve device? Currently, just None'ing
+    // it out.
+    return funcBuilder->getNoneConstant(loc);
+  }
   return {nullptr};
   // TODO: Implement mappings for the whole set (relevant to this use case):
   // _(Tensor)

@@ -415,6 +431,9 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
   if (ival.isNone()) {
     return npcompNoneTypeGet(funcBuilder->getContext());
   }
+  if (ival.isDevice()) {
+    return npcompNoneTypeGet(funcBuilder->getContext());
+  }
   return {nullptr};
 }

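The three hunks above map undefined optional tensors ("Tensor?") and Device arguments to None constants (and the corresponding None type on the type side) instead of failing. A hedged sketch (not part of the commit) of a capture that feeds an explicit weight=None through such a kernel; the assumption is that an omitted/None weight reaches the dispatcher as an undefined Tensor IValue, which this change now lowers to a None constant:

    import torch
    import torch.nn.functional as F
    import torch_mlir

    logits = torch.ones(4, 10)
    target = torch.tensor([0, 1, 2, 3], dtype=torch.long)

    mb = torch_mlir.ModuleBuilder()
    with mb.capture_function("optional_args", [logits, target]) as f:
        # weight=None -> an undefined "Tensor?" argument at the dispatcher level.
        loss = F.nll_loss(F.log_softmax(logits, dim=1), target, weight=None)
        f.returns([loss])

    mb.module.operation.print()
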
@@ -438,7 +457,15 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {

   // Construct the ShapedType.
   auto loc = getCurrentLocation();
-  MlirType elementType = typeMapper.mapScalarType(tensor.scalar_type());
+  MlirType elementType;
+  if (tensor.scalar_type() == ScalarType::Bool) {
+    // Bool is a special case. When used as an element type, it must be i1.
+    // The generalized (non-Tensor) conversion, assumes that Bool is the
+    // Basicpy bool type.
+    elementType = mlirIntegerTypeGet(funcBuilder->getContext(), 1);
+  } else {
+    elementType = typeMapper.mapScalarType(tensor.scalar_type());
+  }
   llvm::SmallVector<int64_t, 4> shape(tensor.sizes().begin(),
                                       tensor.sizes().end());
   MlirType shapedType = mlirRankedTensorTypeGetChecked(

@@ -470,6 +497,13 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
     valueAttribute = mlirDenseElementsAttrDoubleGet(
         shapedType, numElements, static_cast<const double *>(tensorData));
     break;
+  case ScalarType::Bool:
+    // TODO: Add a test specifically for bool and ensure consistency between
+    // storage format and load format
+    // (https://github.com/llvm/mlir-npcomp/issues/100).
+    valueAttribute = mlirDenseElementsAttrBoolGet(
+        shapedType, numElements, static_cast<const int *>(tensorData));
+    break;
   default:
     throwUnsupportedTensorError();
   }

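The two hunks above special-case bool tensors when importing them by value: the element type becomes i1 and the backing data goes through mlirDenseElementsAttrBoolGet (with the storage-format caveat tracked in issue #100). A hedged sketch (not part of the commit) of a capture that should hit this path, assuming masked_fill routes through the boxed fallback and that the constant bool mask is imported by value:

    import torch
    import torch_mlir

    mask = torch.tensor([True, False, True])  # bool tensor, expected as i1 elements
    x = torch.ones(3)

    mb = torch_mlir.ModuleBuilder()
    with mb.capture_function("bool_capture", [x]) as f:
        # The mask is not a capture input, so it would be imported by value,
        # exercising the new ScalarType::Bool handling in importTensorByValue.
        f.returns([x.masked_fill(mask, 0.0)])

    mb.module.operation.print()
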
acap_dispatch.h:

@@ -44,9 +44,6 @@ public:
   // Terminates capture and returns tensors from the function.
   void returns(std::vector<at::Tensor> tensors);

-  // Gets and clears the current debug log.
-  std::vector<std::string> getDebugLog();
-
   // Returns the current AcapController (if it has been activated on this
   // thread. Returns nullptr if none (not active on the current thread).
   static std::shared_ptr<AcapController> getCurrentThreadAcapController();

@@ -108,7 +105,6 @@ private:

   TypeMapper &typeMapper;
   std::unique_ptr<FuncBuilder> funcBuilder;
-  std::vector<std::string> captureLog;
   bool hasReturned = false;
 };

debug.cpp (new file):

@@ -0,0 +1,29 @@
+//===- debug.cpp ------------------------------------------------*- C++ -*-===//
+//
+// This file is licensed under a pytorch-style license
+// See frontends/pytorch/LICENSE for license information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "debug.h"
+
+#include <iostream>
+
+namespace torch_mlir {
+
+static bool debugTraceToStderrEnabled = false;
+
+/// Whether debug tracing is enabled and calls to debugTrace() are more than
+/// a no-op.
+bool isDebugTraceEnabled() { return debugTraceToStderrEnabled; }
+
+/// Writes a message to the debug trace log.
+void debugTrace(const std::string &message) {
+  if (debugTraceToStderrEnabled)
+    std::cerr << "TORCH_MLIR TRACE: " << message << "\n" << std::flush;
+}
+
+/// Enables writing debug traces to stderr.
+void enableDebugTraceToStderr() { debugTraceToStderrEnabled = true; }
+
+} // namespace torch_mlir

debug.h (new file):

@@ -0,0 +1,22 @@
+//===- debug.h --------------------------------------------------*- C++ -*-===//
+//
+// This file is licensed under a pytorch-style license
+// See frontends/pytorch/LICENSE for license information.
+//
+//===----------------------------------------------------------------------===//
+
+#include <string>
+
+namespace torch_mlir {
+
+/// Whether debug tracing is enabled and calls to debugTrace() are more than
+/// a no-op.
+bool isDebugTraceEnabled();
+
+/// Writes a message to the debug trace log.
+void debugTrace(const std::string &message);
+
+/// Enables writing debug traces to stderr.
+void enableDebugTraceToStderr();
+
+} // namespace torch_mlir

python_bindings.cpp:

@@ -6,6 +6,7 @@
 //===----------------------------------------------------------------------===//

 #include "../pybind.h"
+#include "debug.h"

 #include <ATen/core/dispatch/Dispatcher.h>

@@ -126,12 +127,13 @@ py::list GetRegisteredOps() {
 }

 void InitModuleBindings(py::module &m) {
+  m.def("debug_trace_to_stderr", &enableDebugTraceToStderr);
+
   py::class_<AcapController, std::shared_ptr<AcapController>>(m,
                                                               "AcapController")
       .def("__enter__", &AcapController::contextEnter)
       .def("__exit__", &AcapController::contextExit)
-      .def("returns", &AcapController::returns)
-      .def("get_debug_log", &AcapController::getDebugLog);
+      .def("returns", &AcapController::returns);
   m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);

   ModuleBuilder::bind(m);

torch_mlir package __init__.py:

@@ -6,9 +6,10 @@
 # and binds names locally. It exists to allow for customization of behavior
 # prior to loading shared objects.

-from _torch_mlir import ModuleBuilder
+from _torch_mlir import *


 __all__ = [
+    "debug_trace_to_stderr",
     "ModuleBuilder",
 ]

New test file exercising NLLLoss capture (path not shown in this view):

@@ -0,0 +1,46 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
+
+import torch
+from torch.autograd import Variable
+import torch.nn as nn
+import torch.nn.functional as F
+import torch_mlir
+
+torch_mlir.debug_trace_to_stderr()
+
+N = 3
+Cin = 16
+Cout = 4
+w = 10
+h = 10
+
+class Net(nn.Module):
+    def __init__(self, Cin, Cout):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(Cin, Cout, (3,3))
+    def forward(self, x):
+        x = self.conv1(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+model = Net(Cin, Cout)
+
+inputs = torch.ones((N,Cin,h,w))
+# Note that the NLLLoss kernel accepts an optional parameter, which is what
+# this test is trying to verify.
+loss = torch.nn.NLLLoss()
+target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout)
+
+mb = torch_mlir.ModuleBuilder()
+with mb.capture_function("resa", [inputs, target]) as f:
+    result = loss(model(inputs), target)
+    f.returns([result])
+
+# CHECK: "aten::convolution"
+# CHECK: "aten::_log_softmax"
+# CHECK: "aten::nll_loss2d_forward"
+mb.module.operation.print(large_elements_limit=2)