mirror of https://github.com/llvm/torch-mlir
Support optional args/returns and other odds and ends.
* None's out Device? args. * Emits bool tensors if needed. * Adds some stderr tracing to better see what is going on. * Test case that exercises NLLLoss. * This test case emits something for backward calculations but there are some issues still to be worked out, so that part is left out of the test case. * Progress on #97pull/101/head
parent
a3f4db9fe8
commit
8d98dd4551
|
@ -9,6 +9,7 @@ include_directories(
|
|||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
add_library(npcomp_torch_c10_dispatch_bindings
|
||||
acap_dispatch.cpp
|
||||
debug.cpp
|
||||
func_builder.cpp
|
||||
module_builder.cpp
|
||||
python_bindings.cpp
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "acap_dispatch.h"
|
||||
#include "debug.h"
|
||||
|
||||
#include "mlir-c/StandardAttributes.h"
|
||||
#include "mlir-c/StandardTypes.h"
|
||||
|
@ -100,8 +101,8 @@ void AcapController::KernelCallBuilder::addOperand(const IValue &value) {
|
|||
if (mlirValueIsNull(mlirValue)) {
|
||||
std::stringstream out;
|
||||
const std::string &kernelName = opHandle.operator_name().name;
|
||||
out << "Unsupported capture value returned from kernel '" << kernelName
|
||||
<< "' (" << value.tagKind() << "): " << value;
|
||||
out << "Unsupported capture value passed to kernel '" << kernelName << "' ("
|
||||
<< value.tagKind() << "): " << value;
|
||||
throw std::invalid_argument(out.str());
|
||||
}
|
||||
mlirOperationStateAddOperands(state, 1, &mlirValue);
|
||||
|
@ -132,11 +133,6 @@ MlirOperation AcapController::KernelCallBuilder::create() {
|
|||
MlirValue result = mlirOperationGetResult(op, it.first);
|
||||
parent.funcBuilder->mapTensor(it.second, result);
|
||||
}
|
||||
|
||||
// Add to debug log.
|
||||
std::stringstream sout;
|
||||
sout << "CAPTURE: " << opHandle.operator_name().name << "\n";
|
||||
parent.captureLog.push_back(sout.str());
|
||||
return op;
|
||||
}
|
||||
|
||||
|
@ -178,11 +174,13 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
|
|||
for (auto &tensor : tensors) {
|
||||
MlirValue v = funcBuilder->lookupTensor(tensor);
|
||||
if (mlirValueIsNull(v)) {
|
||||
// Finalize this instance so that everything that goes into printing an
|
||||
// error message does not capture.
|
||||
hasReturned = true;
|
||||
// Exclude recursive dispatch in order to print tensor.
|
||||
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
|
||||
std::stringstream msg;
|
||||
msg << "Cannot return a tensor that is not from the capture context: "
|
||||
<< tensor;
|
||||
msg << "Cannot return a tensor that is not from the capture context";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
|
@ -199,12 +197,6 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
|
|||
hasReturned = true;
|
||||
}
|
||||
|
||||
std::vector<std::string> AcapController::getDebugLog() {
|
||||
std::vector<std::string> copy;
|
||||
captureLog.swap(copy);
|
||||
return copy;
|
||||
}
|
||||
|
||||
std::shared_ptr<AcapController>
|
||||
AcapController::getCurrentThreadAcapController() {
|
||||
auto &stack = getThreadLocalActiveStack();
|
||||
|
@ -241,6 +233,12 @@ at::Tensor AcapController::convolutionKernel(
|
|||
auto &dispatcher = c10::Dispatcher::singleton();
|
||||
auto opHandle = dispatcher.findOp(opName);
|
||||
assert(opHandle && "could not find convolution op");
|
||||
if (isDebugTraceEnabled()) {
|
||||
std::stringstream s;
|
||||
s << "Convolution (unboxed) dispatch: " << opHandle->schema();
|
||||
debugTrace(s.str());
|
||||
}
|
||||
|
||||
auto opTyped = opHandle->typed<at::Tensor(
|
||||
const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &,
|
||||
const at::IntArrayRef, const at::IntArrayRef, const at::IntArrayRef,
|
||||
|
@ -307,6 +305,12 @@ void AcapController::redispatch(const c10::OperatorHandle &opHandle,
|
|||
void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
|
||||
Stack *stack) {
|
||||
verifyHasNotReturned();
|
||||
if (isDebugTraceEnabled()) {
|
||||
std::stringstream s;
|
||||
s << "Fallback (boxed) dispatch: " << opHandle.schema();
|
||||
debugTrace(s.str());
|
||||
}
|
||||
|
||||
// Exclude recursive dispatch to this kernel.
|
||||
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
|
||||
|
||||
|
@ -352,6 +356,13 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
|
|||
return funcBuilder->getScalarConstant(loc, ival.toScalar());
|
||||
}
|
||||
if (ival.isTensor()) {
|
||||
auto tensor = ival.toTensor();
|
||||
if (!tensor.defined()) {
|
||||
// Optional tensors ("Tensor?" type) are represented as Tensor ivals
|
||||
// that are undefined.
|
||||
return funcBuilder->getNoneConstant(loc);
|
||||
}
|
||||
|
||||
// Is it an already mapped tensor?
|
||||
MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor());
|
||||
if (!mlirValueIsNull(mappedValue)) {
|
||||
|
@ -377,6 +388,11 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
|
|||
if (ival.isNone()) {
|
||||
return funcBuilder->getNoneConstant(loc);
|
||||
}
|
||||
if (ival.isDevice()) {
|
||||
// TODO: Do we need to model/preserve device? Currently, just None'ing
|
||||
// it out.
|
||||
return funcBuilder->getNoneConstant(loc);
|
||||
}
|
||||
return {nullptr};
|
||||
// TODO: Implement mappings for the whole set (relevant to this use case):
|
||||
// _(Tensor)
|
||||
|
@ -415,6 +431,9 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
|
|||
if (ival.isNone()) {
|
||||
return npcompNoneTypeGet(funcBuilder->getContext());
|
||||
}
|
||||
if (ival.isDevice()) {
|
||||
return npcompNoneTypeGet(funcBuilder->getContext());
|
||||
}
|
||||
return {nullptr};
|
||||
}
|
||||
|
||||
|
@ -438,7 +457,15 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
|||
|
||||
// Construct the ShapedType.
|
||||
auto loc = getCurrentLocation();
|
||||
MlirType elementType = typeMapper.mapScalarType(tensor.scalar_type());
|
||||
MlirType elementType;
|
||||
if (tensor.scalar_type() == ScalarType::Bool) {
|
||||
// Bool is a special case. When used as an element type, it must be i1.
|
||||
// The generalized (non-Tensor) conversion, assumes that Bool is the
|
||||
// Basicpy bool type.
|
||||
elementType = mlirIntegerTypeGet(funcBuilder->getContext(), 1);
|
||||
} else {
|
||||
elementType = typeMapper.mapScalarType(tensor.scalar_type());
|
||||
}
|
||||
llvm::SmallVector<int64_t, 4> shape(tensor.sizes().begin(),
|
||||
tensor.sizes().end());
|
||||
MlirType shapedType = mlirRankedTensorTypeGetChecked(
|
||||
|
@ -470,6 +497,13 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
|||
valueAttribute = mlirDenseElementsAttrDoubleGet(
|
||||
shapedType, numElements, static_cast<const double *>(tensorData));
|
||||
break;
|
||||
case ScalarType::Bool:
|
||||
// TODO: Add a test specifically for bool and ensure consistency between
|
||||
// storage format and load format
|
||||
// (https://github.com/llvm/mlir-npcomp/issues/100).
|
||||
valueAttribute = mlirDenseElementsAttrBoolGet(
|
||||
shapedType, numElements, static_cast<const int *>(tensorData));
|
||||
break;
|
||||
default:
|
||||
throwUnsupportedTensorError();
|
||||
}
|
||||
|
|
|
@ -44,9 +44,6 @@ public:
|
|||
// Terminates capture and returns tensors from the function.
|
||||
void returns(std::vector<at::Tensor> tensors);
|
||||
|
||||
// Gets and clears the current debug log.
|
||||
std::vector<std::string> getDebugLog();
|
||||
|
||||
// Returns the current AcapController (if it has been activated on this
|
||||
// thread. Returns nullptr if none (not active on the current thread).
|
||||
static std::shared_ptr<AcapController> getCurrentThreadAcapController();
|
||||
|
@ -108,7 +105,6 @@ private:
|
|||
|
||||
TypeMapper &typeMapper;
|
||||
std::unique_ptr<FuncBuilder> funcBuilder;
|
||||
std::vector<std::string> captureLog;
|
||||
bool hasReturned = false;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
//===- debug.cpp ------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "debug.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
static bool debugTraceToStderrEnabled = false;
|
||||
|
||||
/// Whether debug tracing is enabled and calls to debugTrace() are more than
|
||||
/// a no-op.
|
||||
bool isDebugTraceEnabled() { return debugTraceToStderrEnabled; }
|
||||
|
||||
/// Writes a message to the debug trace log.
|
||||
void debugTrace(const std::string &message) {
|
||||
if (debugTraceToStderrEnabled)
|
||||
std::cerr << "TORCH_MLIR TRACE: " << message << "\n" << std::flush;
|
||||
}
|
||||
|
||||
/// Enables writing debug traces to stderr.
|
||||
void enableDebugTraceToStderr() { debugTraceToStderrEnabled = true; }
|
||||
|
||||
} // namespace torch_mlir
|
|
@ -0,0 +1,22 @@
|
|||
//===- debug.h --------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Whether debug tracing is enabled and calls to debugTrace() are more than
|
||||
/// a no-op.
|
||||
bool isDebugTraceEnabled();
|
||||
|
||||
/// Writes a message to the debug trace log.
|
||||
void debugTrace(const std::string &message);
|
||||
|
||||
/// Enables writing debug traces to stderr.
|
||||
void enableDebugTraceToStderr();
|
||||
|
||||
} // namespace torch_mlir
|
|
@ -6,6 +6,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "debug.h"
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
|
||||
|
@ -126,12 +127,13 @@ py::list GetRegisteredOps() {
|
|||
}
|
||||
|
||||
void InitModuleBindings(py::module &m) {
|
||||
m.def("debug_trace_to_stderr", &enableDebugTraceToStderr);
|
||||
|
||||
py::class_<AcapController, std::shared_ptr<AcapController>>(m,
|
||||
"AcapController")
|
||||
.def("__enter__", &AcapController::contextEnter)
|
||||
.def("__exit__", &AcapController::contextExit)
|
||||
.def("returns", &AcapController::returns)
|
||||
.def("get_debug_log", &AcapController::getDebugLog);
|
||||
.def("returns", &AcapController::returns);
|
||||
m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
|
||||
|
||||
ModuleBuilder::bind(m);
|
||||
|
|
|
@ -6,9 +6,10 @@
|
|||
# and binds names locally. It exists to allow for customization of behavior
|
||||
# prior to loading shared objects.
|
||||
|
||||
from _torch_mlir import ModuleBuilder
|
||||
from _torch_mlir import *
|
||||
|
||||
|
||||
__all__ = [
|
||||
"debug_trace_to_stderr",
|
||||
"ModuleBuilder",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_mlir
|
||||
|
||||
torch_mlir.debug_trace_to_stderr()
|
||||
|
||||
N = 3
|
||||
Cin = 16
|
||||
Cout = 4
|
||||
w = 10
|
||||
h = 10
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, Cin, Cout):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(Cin, Cout, (3,3))
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
output = F.log_softmax(x, dim=1)
|
||||
return output
|
||||
|
||||
model = Net(Cin, Cout)
|
||||
|
||||
inputs = torch.ones((N,Cin,h,w))
|
||||
# Note that the NLLLoss kernel accepts an optional parameter, which is what
|
||||
# this test is trying to verify.
|
||||
loss = torch.nn.NLLLoss()
|
||||
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout)
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
with mb.capture_function("resa", [inputs, target]) as f:
|
||||
result = loss(model(inputs), target)
|
||||
f.returns([result])
|
||||
|
||||
# CHECK: "aten::convolution"
|
||||
# CHECK: "aten::_log_softmax"
|
||||
# CHECK: "aten::nll_loss2d_forward"
|
||||
mb.module.operation.print(large_elements_limit=2)
|
Loading…
Reference in New Issue