Support optional args/returns and other odds and ends.

* Nones out Device? args.
* Emits bool tensors if needed.
* Adds some stderr tracing to make it easier to see what is going on (see the sketch below).
* Adds a test case that exercises NLLLoss.
* The test case also emits something for the backward calculation, but there are still issues to be worked out, so that part is left out of the test case.
* Progress on #97
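
For illustration only (not part of the commit): a minimal sketch of how the new pieces fit together from Python, following the same capture API as the test added below. The names, shapes, and the simplified NLLLoss call are illustrative assumptions, not code from this change.

import torch
import torch.nn.functional as F
import torch_mlir

# Newly bound entry point: logs dispatches seen by the capture to stderr.
torch_mlir.debug_trace_to_stderr()

pred = torch.ones(3, 4).log_softmax(dim=1)
target = torch.zeros(3, dtype=torch.long)

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("example", [pred, target]) as f:
    # nll_loss takes an optional weight ("Tensor?"), which defaults to None and
    # is now imported as a None constant instead of failing the capture.
    result = F.nll_loss(pred, target)
    f.returns([result])
mb.module.operation.print()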
pull/101/head
Stella Laurenzo 2020-10-29 17:41:15 -07:00
parent a3f4db9fe8
commit 8d98dd4551
8 changed files with 154 additions and 23 deletions


@@ -9,6 +9,7 @@ include_directories(
link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(npcomp_torch_c10_dispatch_bindings
acap_dispatch.cpp
debug.cpp
func_builder.cpp
module_builder.cpp
python_bindings.cpp


@@ -6,6 +6,7 @@
//===----------------------------------------------------------------------===//
#include "acap_dispatch.h"
#include "debug.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
@@ -100,8 +101,8 @@ void AcapController::KernelCallBuilder::addOperand(const IValue &value) {
if (mlirValueIsNull(mlirValue)) {
std::stringstream out;
const std::string &kernelName = opHandle.operator_name().name;
out << "Unsupported capture value returned from kernel '" << kernelName
<< "' (" << value.tagKind() << "): " << value;
out << "Unsupported capture value passed to kernel '" << kernelName << "' ("
<< value.tagKind() << "): " << value;
throw std::invalid_argument(out.str());
}
mlirOperationStateAddOperands(state, 1, &mlirValue);
@@ -132,11 +133,6 @@ MlirOperation AcapController::KernelCallBuilder::create() {
MlirValue result = mlirOperationGetResult(op, it.first);
parent.funcBuilder->mapTensor(it.second, result);
}
// Add to debug log.
std::stringstream sout;
sout << "CAPTURE: " << opHandle.operator_name().name << "\n";
parent.captureLog.push_back(sout.str());
return op;
}
@@ -178,11 +174,13 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
for (auto &tensor : tensors) {
MlirValue v = funcBuilder->lookupTensor(tensor);
if (mlirValueIsNull(v)) {
// Finalize this instance so that everything that goes into printing an
// error message does not capture.
hasReturned = true;
// Exclude recursive dispatch in order to print tensor.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
std::stringstream msg;
msg << "Cannot return a tensor that is not from the capture context: "
<< tensor;
msg << "Cannot return a tensor that is not from the capture context";
throw std::invalid_argument(msg.str());
}
@@ -199,12 +197,6 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
hasReturned = true;
}
std::vector<std::string> AcapController::getDebugLog() {
std::vector<std::string> copy;
captureLog.swap(copy);
return copy;
}
std::shared_ptr<AcapController>
AcapController::getCurrentThreadAcapController() {
auto &stack = getThreadLocalActiveStack();
@@ -241,6 +233,12 @@ at::Tensor AcapController::convolutionKernel(
auto &dispatcher = c10::Dispatcher::singleton();
auto opHandle = dispatcher.findOp(opName);
assert(opHandle && "could not find convolution op");
if (isDebugTraceEnabled()) {
std::stringstream s;
s << "Convolution (unboxed) dispatch: " << opHandle->schema();
debugTrace(s.str());
}
auto opTyped = opHandle->typed<at::Tensor(
const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &,
const at::IntArrayRef, const at::IntArrayRef, const at::IntArrayRef,
@@ -307,6 +305,12 @@ void AcapController::redispatch(const c10::OperatorHandle &opHandle,
void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
Stack *stack) {
verifyHasNotReturned();
if (isDebugTraceEnabled()) {
std::stringstream s;
s << "Fallback (boxed) dispatch: " << opHandle.schema();
debugTrace(s.str());
}
// Exclude recursive dispatch to this kernel.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
@@ -352,6 +356,13 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
return funcBuilder->getScalarConstant(loc, ival.toScalar());
}
if (ival.isTensor()) {
auto tensor = ival.toTensor();
if (!tensor.defined()) {
// Optional tensors ("Tensor?" type) are represented as Tensor ivals
// that are undefined.
return funcBuilder->getNoneConstant(loc);
}
// Is it an already mapped tensor?
MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor());
if (!mlirValueIsNull(mappedValue)) {
@@ -377,6 +388,11 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
if (ival.isNone()) {
return funcBuilder->getNoneConstant(loc);
}
if (ival.isDevice()) {
// TODO: Do we need to model/preserve device? Currently, just None'ing
// it out.
return funcBuilder->getNoneConstant(loc);
}
return {nullptr};
// TODO: Implement mappings for the whole set (relevant to this use case):
// _(Tensor)
@@ -415,6 +431,9 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
if (ival.isNone()) {
return npcompNoneTypeGet(funcBuilder->getContext());
}
if (ival.isDevice()) {
return npcompNoneTypeGet(funcBuilder->getContext());
}
return {nullptr};
}
@@ -438,7 +457,15 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
// Construct the ShapedType.
auto loc = getCurrentLocation();
MlirType elementType = typeMapper.mapScalarType(tensor.scalar_type());
MlirType elementType;
if (tensor.scalar_type() == ScalarType::Bool) {
// Bool is a special case. When used as an element type, it must be i1.
// The generalized (non-Tensor) conversion assumes that Bool is the
// Basicpy bool type.
elementType = mlirIntegerTypeGet(funcBuilder->getContext(), 1);
} else {
elementType = typeMapper.mapScalarType(tensor.scalar_type());
}
llvm::SmallVector<int64_t, 4> shape(tensor.sizes().begin(),
tensor.sizes().end());
MlirType shapedType = mlirRankedTensorTypeGetChecked(
@@ -470,6 +497,13 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
valueAttribute = mlirDenseElementsAttrDoubleGet(
shapedType, numElements, static_cast<const double *>(tensorData));
break;
case ScalarType::Bool:
// TODO: Add a test specifically for bool and ensure consistency between
// storage format and load format
// (https://github.com/llvm/mlir-npcomp/issues/100).
valueAttribute = mlirDenseElementsAttrBoolGet(
shapedType, numElements, static_cast<const int *>(tensorData));
break;
default:
throwUnsupportedTensorError();
}


@@ -44,9 +44,6 @@ public:
// Terminates capture and returns tensors from the function.
void returns(std::vector<at::Tensor> tensors);
// Gets and clears the current debug log.
std::vector<std::string> getDebugLog();
// Returns the current AcapController (if it has been activated on this
// thread). Returns nullptr if none (not active on the current thread).
static std::shared_ptr<AcapController> getCurrentThreadAcapController();
@@ -108,7 +105,6 @@ private:
TypeMapper &typeMapper;
std::unique_ptr<FuncBuilder> funcBuilder;
std::vector<std::string> captureLog;
bool hasReturned = false;
};


@@ -0,0 +1,29 @@
//===- debug.cpp ------------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "debug.h"
#include <iostream>
namespace torch_mlir {
static bool debugTraceToStderrEnabled = false;
/// Whether debug tracing is enabled and calls to debugTrace() are more than
/// a no-op.
bool isDebugTraceEnabled() { return debugTraceToStderrEnabled; }
/// Writes a message to the debug trace log.
void debugTrace(const std::string &message) {
if (debugTraceToStderrEnabled)
std::cerr << "TORCH_MLIR TRACE: " << message << "\n" << std::flush;
}
/// Enables writing debug traces to stderr.
void enableDebugTraceToStderr() { debugTraceToStderrEnabled = true; }
} // namespace torch_mlir


@@ -0,0 +1,22 @@
//===- debug.h --------------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include <string>
namespace torch_mlir {
/// Whether debug tracing is enabled and calls to debugTrace() are more than
/// a no-op.
bool isDebugTraceEnabled();
/// Writes a message to the debug trace log.
void debugTrace(const std::string &message);
/// Enables writing debug traces to stderr.
void enableDebugTraceToStderr();
} // namespace torch_mlir


@@ -6,6 +6,7 @@
//===----------------------------------------------------------------------===//
#include "../pybind.h"
#include "debug.h"
#include <ATen/core/dispatch/Dispatcher.h>
@@ -126,12 +127,13 @@ py::list GetRegisteredOps() {
}
void InitModuleBindings(py::module &m) {
m.def("debug_trace_to_stderr", &enableDebugTraceToStderr);
py::class_<AcapController, std::shared_ptr<AcapController>>(m,
"AcapController")
.def("__enter__", &AcapController::contextEnter)
.def("__exit__", &AcapController::contextExit)
.def("returns", &AcapController::returns)
.def("get_debug_log", &AcapController::getDebugLog);
.def("returns", &AcapController::returns);
m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
ModuleBuilder::bind(m);


@@ -6,9 +6,10 @@
# and binds names locally. It exists to allow for customization of behavior
# prior to loading shared objects.
from _torch_mlir import ModuleBuilder
from _torch_mlir import *
__all__ = [
"debug_trace_to_stderr",
"ModuleBuilder",
]


@@ -0,0 +1,46 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch_mlir
torch_mlir.debug_trace_to_stderr()
N = 3
Cin = 16
Cout = 4
w = 10
h = 10
class Net(nn.Module):
def __init__(self, Cin, Cout):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(Cin, Cout, (3,3))
def forward(self, x):
x = self.conv1(x)
output = F.log_softmax(x, dim=1)
return output
model = Net(Cin, Cout)
inputs = torch.ones((N,Cin,h,w))
# Note that the NLLLoss kernel accepts an optional parameter, which is what
# this test is trying to verify.
loss = torch.nn.NLLLoss()
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout)
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("resa", [inputs, target]) as f:
result = loss(model(inputs), target)
f.returns([result])
# CHECK: "aten::convolution"
# CHECK: "aten::_log_softmax"
# CHECK: "aten::nll_loss2d_forward"
mb.module.operation.print(large_elements_limit=2)