Remove acap_dispatch.

This is old code that barely worked, and this approach just won't scale.
TorchFX seems like the tracing-based solution going forward.
pull/309/head
Sean Silva 2021-09-17 03:51:24 +00:00
parent 5f3b1ce0b8
commit 900f0e04aa
24 changed files with 3 additions and 1427 deletions
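
For context, a minimal sketch of the tracing-based approach the commit message refers to, using stock torch.fx from a recent PyTorch. The AddMul module and this snippet are illustrative assumptions only and are not part of this change:

import torch
import torch.fx

class AddMul(torch.nn.Module):
    def forward(self, x, y):
        # Ops are recorded into a Graph by symbolic tracing, rather than by
        # intercepting dispatch at runtime as acap_dispatch did.
        return (x + y) * y

traced = torch.fx.symbolic_trace(AddMul())
print(traced.graph)  # the captured op graph
print(traced.code)   # Python generated from the traced graph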

@ -13,10 +13,7 @@ include_directories(BEFORE
link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(TorchMLIRTorchPlugin SHARED
builder/acap_dispatch.cpp
builder/class_annotator.cpp
builder/debug.cpp
builder/func_builder.cpp
builder/function_importer.cpp
builder/module_builder.cpp
builder/node_importer.cpp

@ -1,595 +0,0 @@
//===- acap_dispatch.cpp --------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "acap_dispatch.h"
#include "debug.h"
#include "mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "torch-mlir-c/TorchTypes.h"
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <c10/core/DispatchKey.h>
#include <torch/library.h>
using namespace torch_mlir;
namespace py = pybind11;
using c10::FunctionSchema;
using c10::IValue;
using c10::OperatorHandle;
using c10::Stack;
// TODO: Private-use dispatch keys are not intended for real use. Allocate a
// proper dispatch key in upstream PyTorch (DispatchKey.h) prior to maturity.
// Note that the TORCH_LIBRARY_* macros expand the key by name while other APIs
// use its enum value, so we define both forms. We can get rid of both once we
// have our own key.
// TODO: Ask the PT devs why conv is special and only shows up if dispatching
// through the autograd keys.
#define ACAP_DISPATCH_KEY PrivateUse2
#define ACAP_GRAD_DISPATCH_KEY AutogradPrivateUse2
static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY;
static c10::DispatchKey kAcapGradDispatchKey =
c10::DispatchKey::ACAP_GRAD_DISPATCH_KEY;
AcapController::TracedSchemaOpBuilder::TracedSchemaOpBuilder(
AcapController &parent, MlirContext context, MlirLocation loc,
const c10::OperatorHandle &opHandle)
: parent(parent), loc(loc), opHandle(opHandle) {}
void AcapController::TracedSchemaOpBuilder::addOperand(const IValue &value) {
MlirValue mlirValue = parent.mapIValueToMlirValue(loc, value);
if (mlirValueIsNull(mlirValue)) {
std::stringstream out;
const std::string &kernelName = opHandle.operator_name().name;
out << "Unsupported capture value passed to kernel '" << kernelName << "' ("
<< value.tagKind() << "): " << value;
throw std::invalid_argument(out.str());
}
operands.push_back(mlirValue);
}
void AcapController::TracedSchemaOpBuilder::addResult(const IValue &value) {
MlirType resultType = parent.mapIValueToMlirType(loc, value);
if (mlirTypeIsNull(resultType)) {
std::stringstream out;
const std::string &kernelName = opHandle.operator_name().name;
out << "Unsupported capture value returned from kernel '" << kernelName
<< "' (" << value.tagKind() << "): " << value;
throw std::invalid_argument(out.str());
}
if (value.isTensor()) {
resultIndexToTensorMap.emplace_back(resultCount++, value.toTensor());
}
resultTypes.push_back(resultType);
}
MlirOperation AcapController::TracedSchemaOpBuilder::create() {
MlirOperation op =
createOperationFromSchema(parent.funcBuilder->getEntryBlock(), loc,
opHandle.schema(), resultTypes, operands);
// Map result tensors.
for (auto &it : resultIndexToTensorMap) {
MlirValue result = mlirOperationGetResult(op, it.first);
parent.funcBuilder->mapTensor(it.second, result);
}
return op;
}
std::list<AcapController::Activation> &
AcapController::getThreadLocalActiveStack() {
static thread_local std::list<Activation> threadLocalActiveStack;
return threadLocalActiveStack;
}
py::object AcapController::contextEnter() {
auto &stack = getThreadLocalActiveStack();
stack.emplace_front(shared_from_this());
Activation &current = stack.front();
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
current.includeGuard =
std::make_unique<c10::impl::IncludeDispatchKeyGuard>(keySet);
return py::cast(this);
}
void AcapController::contextExit(py::object exc_type, py::object exc_val,
py::object exc_tb) {
auto &stack = getThreadLocalActiveStack();
if (stack.empty() || stack.front().controller.get() != this) {
throw std::runtime_error("Mismatched context manager __exit__");
}
stack.pop_front();
if (!hasReturned) {
returns({});
}
}
void AcapController::returns(std::vector<at::Tensor> tensors) {
verifyHasNotReturned();
std::vector<MlirType> returnsTypes;
std::vector<MlirValue> returnsValues;
for (auto &tensor : tensors) {
MlirValue v = funcBuilder->lookupTensor(tensor);
if (mlirValueIsNull(v)) {
debugTrace(
"Return of imported-constant tensor (intentional memorization?)");
v = importTensorByValue(tensor);
}
returnsTypes.push_back(mlirValueGetType(v));
returnsValues.push_back(v);
}
MlirLocation loc = getCurrentLocation();
OperationStateHolder s("std.return", loc);
mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data());
funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(
s.createOperation());
funcBuilder->rewriteFuncReturnTypes(returnsTypes);
hasReturned = true;
}
std::shared_ptr<AcapController>
AcapController::getCurrentThreadAcapController() {
auto &stack = getThreadLocalActiveStack();
if (stack.empty())
return nullptr;
return stack.front().controller;
}
void AcapController::verifyHasNotReturned() {
if (hasReturned) {
throw std::runtime_error(
"Function has already returned. Cannot trace more operations.");
}
}
/* static */
void AcapController::fallbackKernel(const OperatorHandle &opHandle,
Stack *stack) {
auto redispatchCallback = [&]() {
// Exclude recursive dispatch to this kernel.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
// Passthrough.
auto &dispatcher = c10::Dispatcher::singleton();
dispatcher.callBoxed(opHandle, stack);
};
auto current = getCurrentThreadAcapController();
if (!current) {
redispatchCallback();
return;
}
current->fallbackKernelImpl(opHandle, stack, redispatchCallback);
}
at::Tensor AcapController::convolutionKernel(
const at::Tensor &input, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias, const at::IntArrayRef stride,
const at::IntArrayRef padding, const at::IntArrayRef dilation,
const bool transposed, const at::IntArrayRef output_padding,
const int64_t groups) {
static c10::OperatorName opName{"aten::convolution", ""};
auto &dispatcher = c10::Dispatcher::singleton();
auto opHandle = dispatcher.findOp(opName);
assert(opHandle && "could not find convolution op");
if (isDebugTraceEnabled()) {
std::stringstream s;
s << "Convolution (unboxed) dispatch: " << opHandle->schema();
debugTrace(s.str());
}
auto opTyped = opHandle->typed<at::Tensor(
const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &,
const at::IntArrayRef, const at::IntArrayRef, const at::IntArrayRef,
const bool, const at::IntArrayRef, const int64_t)>();
// Exclude recursive calls: convolution is completely emitted by this
// kernel.
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
auto current = getCurrentThreadAcapController();
if (!current) {
return opTyped.redispatch(c10::DispatchKeySet({c10::DispatchKey::AutogradOther}), input,
weight, bias, stride, padding, dilation,
transposed, output_padding, groups);
}
MlirContext context = current->funcBuilder->getContext();
MlirLocation loc = current->getCurrentLocation();
std::string kernelName{"aten::convolution"};
TracedSchemaOpBuilder opBuilder{*current, context, loc, *opHandle};
opBuilder.addOperand(IValue(input));
opBuilder.addOperand(IValue(weight));
// This is really sad: instead of storing a none in the optional, it stores
// an undefined tensor, which cannot convert to an IValue :(
// TODO: File PyTorch bug. Perhaps this is why they don't support boxing
// for it.
IValue biasIValue;
if (bias && bias->defined()) {
biasIValue = IValue(bias);
} else {
biasIValue = IValue(c10::optional<at::Tensor>());
}
opBuilder.addOperand(biasIValue);
opBuilder.addOperand(IValue(stride));
opBuilder.addOperand(IValue(padding));
opBuilder.addOperand(IValue(dilation));
opBuilder.addOperand(IValue(transposed));
opBuilder.addOperand(IValue(output_padding));
opBuilder.addOperand(IValue(groups));
auto result = opTyped.redispatch(
c10::DispatchKeySet({c10::DispatchKey::AutogradOther}), input, weight, bias, stride, padding,
dilation, transposed, output_padding, groups);
opBuilder.addResult(result);
opBuilder.create();
return result;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor>
AcapController::mklConvolutionBackward(
const at::Tensor &input, const at::Tensor &grad_output,
const at::Tensor &weight, const at::IntArrayRef padding,
const at::IntArrayRef stride, const at::IntArrayRef dilation,
const int64_t groups, std::array<bool, 3> output_mask) {
static c10::OperatorName opName{"aten::mkldnn_convolution_backward", ""};
auto &dispatcher = c10::Dispatcher::singleton();
auto opHandle = dispatcher.findOp(opName);
assert(opHandle && "could not find mkldnn_convolution_backward op");
if (isDebugTraceEnabled()) {
std::stringstream s;
s << "mkldnn_convolution_backward dispatch: " << opHandle->schema();
debugTrace(s.str());
}
auto opTyped = opHandle->typed<std::tuple<at::Tensor, at::Tensor, at::Tensor>(
const at::Tensor &input, const at::Tensor &grad_output,
const at::Tensor &weight, const at::IntArrayRef padding,
const at::IntArrayRef stride, const at::IntArrayRef dilation,
const int64_t groups, std::array<bool, 3> output_mask)>();
// Exclude recursive calls: convolution is completely emitted by this
// kernel.
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
auto current = getCurrentThreadAcapController();
if (!current) {
return opTyped.redispatch(c10::DispatchKeySet({c10::DispatchKey::AutogradOther}), input,
grad_output, weight, padding, stride,
dilation, groups, output_mask);
}
// Emit the call as if to aten::convolution_backward_overrideable, the generic,
// fully parameterized version that backends are supposed to implement.
// Requires some parameter swizzling.
// It has the signature:
// convolution_backward_overrideable(Tensor grad_output, Tensor input,
// Tensor weight, int[] stride, int[] padding, int[] dilation,
// bool transposed, int[] output_padding, int groups,
// bool[3] output_mask) ->
// (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
MlirContext context = current->funcBuilder->getContext();
MlirLocation loc = current->getCurrentLocation();
std::string kernelName{"aten::convolution_backward"};
static c10::OperatorName emitOpName{"aten::convolution_backward_overrideable",
""};
auto emitOpHandle = dispatcher.findOp(emitOpName);
assert(emitOpHandle && "could not find convolution_backward_overrideable op");
TracedSchemaOpBuilder opBuilder{*current, context, loc, *emitOpHandle};
opBuilder.addOperand(IValue(grad_output));
opBuilder.addOperand(IValue(input));
opBuilder.addOperand(IValue(weight));
opBuilder.addOperand(IValue(stride));
opBuilder.addOperand(IValue(padding));
opBuilder.addOperand(IValue(dilation));
opBuilder.addOperand(IValue(false));
std::vector<int64_t> output_padding(padding.size()); // Not provided.
opBuilder.addOperand(IValue(at::IntArrayRef(output_padding)));
opBuilder.addOperand(IValue(groups));
opBuilder.addOperand(IValue(output_mask));
auto results = opTyped.redispatch(
c10::DispatchKeySet({c10::DispatchKey::AutogradCPU}), input, grad_output, weight, padding,
stride, dilation, groups, output_mask);
opBuilder.addResult(std::get<0>(results));
opBuilder.addResult(std::get<1>(results));
opBuilder.addResult(std::get<2>(results));
opBuilder.create();
return results;
}
at::Tensor &AcapController::copyUnderKernel(at::Tensor &self,
const at::Tensor &src,
bool non_blocking) {
static c10::OperatorName opName{"aten::copy_", ""};
auto &dispatcher = c10::Dispatcher::singleton();
auto opHandle = dispatcher.findOp(opName);
assert(opHandle && "could not find copy_ op");
if (isDebugTraceEnabled()) {
std::stringstream s;
s << "copy_ dispatch: " << opHandle->schema();
debugTrace(s.str());
}
auto opTyped = opHandle->typed<at::Tensor &(
at::Tensor & self, const at::Tensor &src, bool non_blocking)>();
// Exclude recursive calls.
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
auto current = getCurrentThreadAcapController();
if (!current) {
return opTyped.redispatch(c10::DispatchKeySet({c10::DispatchKey::AutogradOther}), self,
src, non_blocking);
}
MlirContext context = current->funcBuilder->getContext();
MlirLocation loc = current->getCurrentLocation();
TracedSchemaOpBuilder opBuilder{*current, context, loc, *opHandle};
opBuilder.addOperand(IValue(self));
opBuilder.addOperand(IValue(src));
auto &result = opTyped.redispatch(c10::DispatchKeySet({c10::DispatchKey::CPU}), self, src,
non_blocking);
opBuilder.addResult(result);
opBuilder.create();
return result;
}
at::Tensor AcapController::arangeBackendSelectKernel(
const at::Scalar &end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
static c10::OperatorName opName{"aten::arange", ""};
auto &dispatcher = c10::Dispatcher::singleton();
auto opHandle = dispatcher.findOp(opName);
assert(opHandle && "could not find arange op");
// Exclude recursive calls.
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
// Dispatching in this fashion replicates the exact way that PyTorch
// built-in handlers dispatch to BackendSelect kernels.
auto targetDk = c10::computeDispatchKey(dtype, layout, device);
auto opTyped = opHandle->typed<at::Tensor(
const at::Scalar &end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory)>();
return opTyped.redispatch(c10::DispatchKeySet({targetDk}), end, dtype, layout, device,
pin_memory);
}
MlirLocation AcapController::getCurrentLocation() {
return mlirLocationUnknownGet(funcBuilder->getContext());
}
void AcapController::fallbackKernelImpl(
const OperatorHandle &opHandle, Stack *stack,
std::function<void()> redispatchCallback) {
verifyHasNotReturned();
if (isDebugTraceEnabled()) {
std::stringstream s;
s << "Fallback (boxed) dispatch: " << opHandle.schema()
<< " (stack size=" << stack->size() << ")";
debugTrace(s.str());
}
// Exclude recursive dispatch to this kernel.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
const FunctionSchema &schema = opHandle.schema();
// Check for unsupported.
if (schema.is_vararg() || schema.is_varret()) {
throw std::invalid_argument(
"Cannot capture ops with variable arguments or returns");
}
MlirContext context = funcBuilder->getContext();
MlirLocation loc = getCurrentLocation();
auto kernelName = schema.name();
TracedSchemaOpBuilder opBuilder{*this, context, loc, opHandle};
// Map arguments to operands.
// This must be accumulated into the OperationState prior to re-dispatch
// since the stack is modified at that point.
size_t argCount = schema.arguments().size();
assert(stack->size() >= argCount && "stack too short");
for (auto argIt = stack->end() - argCount; argIt != stack->end(); ++argIt) {
opBuilder.addOperand(*argIt);
}
// Invoke the original kernel.
redispatchCallback();
// Map returns to results.
size_t returnCount = schema.returns().size();
assert(stack->size() >= returnCount && "stack too short");
for (auto returnIt = stack->end() - returnCount; returnIt != stack->end();
++returnIt) {
opBuilder.addResult(*returnIt);
}
opBuilder.create();
}
MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
const IValue &ival) {
if (ival.isScalar()) {
return funcBuilder->getScalarConstant(loc, ival.toScalar());
}
if (ival.isTensor()) {
auto tensor = ival.toTensor();
if (!tensor.defined()) {
// Optional tensors ("Tensor?" type) are represented as Tensor ivals
// that are undefined.
return funcBuilder->getNoneConstant(loc);
}
// Is it an already mapped tensor?
MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor());
if (!mlirValueIsNull(mappedValue)) {
return mappedValue;
}
mappedValue = importTensorByValue(ival.toTensor());
assert(mappedValue.ptr);
return mappedValue;
}
if (ival.isBool()) {
// TODO: Switch to the numpy.bool type as that is a closer domain match.
return funcBuilder->getBoolConstant(loc, ival.toBool());
}
if (ival.isList()) {
auto list = ival.toList();
std::vector<MlirValue> elements;
for (IValue element : list) {
elements.push_back(mapIValueToMlirValue(loc, element));
}
return funcBuilder->buildList(loc,
typeMapper.mapFromTorchType(loc, list.elementType()), elements);
}
if (ival.isNone()) {
return funcBuilder->getNoneConstant(loc);
}
if (ival.isDevice()) {
// TODO: Do we need to model/preserve device? Currently, just None'ing
// it out.
return funcBuilder->getNoneConstant(loc);
}
return {nullptr};
// TODO: Implement mappings for the whole set (relevant to this use case):
// _(Tensor)
// _(Double)
// _(Int)
// _(Tuple)
// _(String)
// _(Blob)
// _(GenericList)
// _(GenericDict)
// _(Future)
// _(Device)
// _(Object)
// _(PyObject)
// _(Uninitialized)
// _(Capsule)
// _(RRef)
// _(Generator)
}
MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
const IValue &ival) {
if (ival.isScalar()) {
return typeMapper.mapFromTorchScalarType(ival.toScalar().type());
}
if (ival.isTensor()) {
return typeMapper.forwardTensorToType(ival.toTensor());
}
if (ival.isBool()) {
// TODO: Switch to the numpy.bool type as that is a closer domain match.
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
}
if (ival.isList()) {
return torchMlirTorchListTypeGet(
typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
}
if (ival.isNone()) {
return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
}
if (ival.isDevice()) {
return torchMlirTorchNoneTypeGet(funcBuilder->getContext());
}
return {nullptr};
}
MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
auto loc = getCurrentLocation();
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp = createMlirOperationAtEnd(
funcBuilder->getEntryBlock(), "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements));
MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);
funcBuilder->mapTensor(tensor, tensorValue);
return tensorValue;
}
TORCH_LIBRARY_IMPL(aten, BackendSelect, m) {
// PyTorch logs a warning when kernels are overridden, which is unavoidable
// for factory-function BackendSelect kernels (there is not yet a "safe"
// override mechanism). So, just silence it. Every override here is coded
// to be a superset of the default functionality, and there are only a few.
auto orig_log_level = FLAGS_caffe2_log_level;
FLAGS_caffe2_log_level = c10::GLOG_ERROR;
// Disable capture of arange: causes it to memorize the resulting tensor.
m.impl("arange", &AcapController::arangeBackendSelectKernel);
// Restore log level.
FLAGS_caffe2_log_level = orig_log_level;
}
TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<
&AcapController::fallbackKernel>());
}
TORCH_LIBRARY_IMPL(aten, ACAP_DISPATCH_KEY, m) {
m.impl("copy_", &AcapController::copyUnderKernel);
}
TORCH_LIBRARY_IMPL(aten, ACAP_GRAD_DISPATCH_KEY, m) {
// The at::convolution op is special in several ways. First, it presently
// does not support boxing, so all of the usual fanciness does not apply
// and it cannot be intercepted by generic fallthroughs, which is what
// would usually allow us to avoid intercepting it at the gradient phase.
// Second, the default implementation (see
// aten/src/ATen/native/Convolution.cpp) is very switchy based on hard-coded
// assumptions about device type. If we do nothing here, we will at best
// intercept an mkldnn_convolution, cudnn_convolution, etc. on the backend
// dispatch keys. Non-standard backends that don't have these switches
// just route to aten::convolution_overrideable (see the else in
// aten::convolution) as a convenience, but that is mostly a pass-through
// (except for 3d convolutions which contain a trailing squeeze that needs
// special casing). Therefore, we just intercept the aten::convolution op,
// record it specially, and then mask ourselves off and ask the CPU backend
// to invoke it. Not awesome.
// Presumably this is on someone's list to adapt to the dispatch machinery
// in a more appropriate way, but as the core of what the framework is,
// perhaps people are reticent to touch it. Maybe someday, this can go away.
m.impl("convolution", &AcapController::convolutionKernel);
// Sadly, there is no easy intercept point for the backwards convolution
// kernel which allows for chaining to an existing backend. And convolution
// is exceptionally special-cased in this way, more so than other ops.
// The "solution" is to intercept the backend specific backward convolution
// ops, emit it with the signature of the more generic
// "convolution_backward_overrideable" op, which is available for generic
// backends, and twiddle the parameters needed to get back to that form.
// For MKL, which is effectively the CPU implementation, this just means that
// some parameters are swapped and the full generality is not supported.
// The "right" answer at some point is probably just to implement a
// convolution kernel that fully does what is needed and delegates to an
// appropriate implementation behind the scenes.
m.impl("mkldnn_convolution_backward", AcapController::mklConvolutionBackward);
}

@ -1,131 +0,0 @@
//===- acap_dispatch.h ------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
// "ATen Capture" dispatcher: Defines facility for capturing programs by
// registering dispatch keys to intercept op execution.
// References:
// http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_ACAP_DISPATCH_H
#define TORCHMLIRPLUGIN_CSRC_BUILDER_ACAP_DISPATCH_H
#include <list>
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "mlir-c/IR.h"
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
namespace torch_mlir {
/// Main entry point for managing device capture.
class AcapController : public std::enable_shared_from_this<AcapController> {
public:
AcapController(TypeMapper &typeMapper,
std::unique_ptr<FuncBuilder> funcBuilder)
: typeMapper(typeMapper), funcBuilder(std::move(funcBuilder)) {}
// Enter and exit the context manager.
pybind11::object contextEnter();
void contextExit(pybind11::object exc_type, pybind11::object exc_val,
pybind11::object exc_tb);
// Terminates capture and returns tensors from the function.
void returns(std::vector<at::Tensor> tensors);
// Returns the current AcapController if it has been activated on this
// thread; returns nullptr if none is active on the current thread.
static std::shared_ptr<AcapController> getCurrentThreadAcapController();
// The fallback boxed kernel that we route captured dispatches through.
static void fallbackKernel(const c10::OperatorHandle &opHandle,
c10::Stack *stack);
// Kernel implementation for the boxing-incompatible convolution kernel.
static at::Tensor
convolutionKernel(const at::Tensor &input, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias,
const at::IntArrayRef stride, const at::IntArrayRef padding,
const at::IntArrayRef dilation, const bool transposed,
const at::IntArrayRef output_padding, const int64_t groups);
// Kernel implementation for the boxing-incompatible convolution backward
// kernel.
static std::tuple<at::Tensor, at::Tensor, at::Tensor> mklConvolutionBackward(
const at::Tensor &input, const at::Tensor &grad_output,
const at::Tensor &weight, const at::IntArrayRef padding,
const at::IntArrayRef stride, const at::IntArrayRef dilation,
const int64_t groups, std::array<bool, 3> output_mask);
// Implementation for the aten::copy_ kernel.
static at::Tensor &copyUnderKernel(at::Tensor &self, const at::Tensor &src,
bool non_blocking);
// Backend select kernel for arange factory function.
static at::Tensor arangeBackendSelectKernel(
const at::Scalar &end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory);
private:
/// Builds an MLIR operation for a Torch operator step by step.
class TracedSchemaOpBuilder {
public:
TracedSchemaOpBuilder(AcapController &parent, MlirContext context,
MlirLocation loc,
const c10::OperatorHandle &opHandle);
void addOperand(const c10::IValue &value);
void addResult(const c10::IValue &result);
MlirOperation create();
private:
AcapController &parent;
MlirLocation loc;
const c10::OperatorHandle &opHandle;
std::vector<MlirValue> operands;
std::vector<MlirType> resultTypes;
int resultCount = 0;
std::vector<std::pair<size_t, at::Tensor>> resultIndexToTensorMap;
};
MlirLocation getCurrentLocation();
void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
c10::Stack *stack,
std::function<void()> redispatchCallback);
MlirValue mapIValueToMlirValue(MlirLocation loc, const c10::IValue &ival);
MlirType mapIValueToMlirType(MlirLocation loc, const c10::IValue &ival);
/// Imports a tensor by value (as a constant), remembering the association.
MlirValue importTensorByValue(at::Tensor tensor);
void verifyHasNotReturned();
struct Activation {
Activation(std::shared_ptr<AcapController> controller)
: controller(std::move(controller)) {}
std::shared_ptr<AcapController> controller;
// The RAII dispatch key guard is not movable, so heap allocate it. This is
// a bit outside of its intended design, but since this is thread local as
// well, it should be fine.
std::unique_ptr<c10::impl::IncludeDispatchKeyGuard> includeGuard;
std::unique_ptr<c10::impl::ExcludeDispatchKeyGuard> excludeGuard;
};
// Gets the thread local stack of active acap controllers.
static std::list<Activation> &getThreadLocalActiveStack();
TypeMapper &typeMapper;
std::unique_ptr<FuncBuilder> funcBuilder;
bool hasReturned = false;
};
} // namespace torch_mlir
#endif // TORCHMLIRPLUGIN_CSRC_BUILDER_ACAP_DISPATCH_H

@ -1,29 +0,0 @@
//===- debug.cpp ------------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "debug.h"
#include <iostream>
namespace torch_mlir {
static bool debugTraceToStderrEnabled = false;
/// Whether debug tracing is enabled and calls to debugTrace() are more than
/// a no-op.
bool isDebugTraceEnabled() { return debugTraceToStderrEnabled; }
/// Writes a message to the debug trace log.
void debugTrace(const std::string &message) {
if (debugTraceToStderrEnabled)
std::cerr << "TORCH_MLIR TRACE: " << message << "\n" << std::flush;
}
/// Enables writing debug traces to stderr.
void enableDebugTraceToStderr() { debugTraceToStderrEnabled = true; }
} // namespace torch_mlir

@ -1,27 +0,0 @@
//===- debug.h --------------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_DEBUG_H
#define TORCHMLIRPLUGIN_CSRC_BUILDER_DEBUG_H
#include <string>
namespace torch_mlir {
/// Whether debug tracing is enabled and calls to debugTrace() are more than
/// a no-op.
bool isDebugTraceEnabled();
/// Writes a message to the debug trace log.
void debugTrace(const std::string &message);
/// Enables writing debug traces to stderr.
void enableDebugTraceToStderr();
} // namespace torch_mlir
#endif // TORCHMLIRPLUGIN_CSRC_BUILDER_DEBUG_H

@ -1,143 +0,0 @@
//===- func_builder.cpp ---------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "func_builder.h"
#include "op_builder.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "torch-mlir-c/TorchTypes.h"
using namespace torch_mlir;
std::unique_ptr<FuncBuilder>
FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
MlirLocation location, const std::string &name,
std::vector<MlirType> &inputTypes) {
auto context = mlirLocationGetContext(location);
// TODO: Create a dedicated API upstream for creating/manipulating func ops.
// (this is fragile and reveals details that are not guaranteed).
std::vector<MlirNamedAttribute> funcAttrs;
funcAttrs.push_back(toMlirNamedAttribute(
"type", mlirTypeAttrGet(mlirFunctionTypeGet(
context, inputTypes.size(), inputTypes.data(),
/*numResults=*/0, /*results=*/nullptr))));
funcAttrs.push_back(toMlirNamedAttribute(
"sym_name", mlirStringAttrGet(
context, mlirStringRefCreate(name.data(), name.size()))));
MlirOperationState state =
mlirOperationStateGet(toMlirStringRef("builtin.func"), location);
mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data());
{
// Don't access these once ownership transferred.
MlirRegion newBodyRegion = mlirRegionCreate();
MlirBlock newEntryBlock =
mlirBlockCreate(inputTypes.size(), inputTypes.data());
mlirRegionInsertOwnedBlockAfter(newBodyRegion, {nullptr}, newEntryBlock);
mlirOperationStateAddOwnedRegions(&state, 1, &newBodyRegion);
}
// Need to re-lookup the region/block because we relinquished ownership above.
MlirOperation funcOp = mlirOperationCreate(&state);
MlirRegion bodyRegion = mlirOperationGetRegion(funcOp, 0);
MlirBlock entryBlock = mlirRegionGetFirstBlock(bodyRegion);
inserter(funcOp);
return std::unique_ptr<FuncBuilder>(new FuncBuilder(
context, funcOp, BlockBuilder(entryBlock, /*returnOp=*/{nullptr}, true)));
}
void FuncBuilder::rewriteFuncReturnTypes(std::vector<MlirType> &resultTypes) {
// Get inputs from current function type.
MlirAttribute funcTypeAttr =
mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type"));
assert(!mlirAttributeIsNull(funcTypeAttr) &&
"function missing 'type' attribute");
assert(mlirAttributeIsAType(funcTypeAttr) &&
"function type is not a TypeAttr");
MlirType funcType = mlirTypeAttrGetValue(funcTypeAttr);
std::vector<MlirType> inputTypes;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(funcType); i < e; ++i) {
inputTypes.push_back(mlirFunctionTypeGetInput(funcType, i));
}
// Make new function type.
MlirType newFuncType =
mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
resultTypes.size(), resultTypes.data());
MlirAttribute newFuncTypeAttr = mlirTypeAttrGet(newFuncType);
mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"),
newFuncTypeAttr);
(void)newFuncTypeAttr;
}
MlirValue FuncBuilder::insertConstantOp(MlirOperation op) {
mlirBlockInsertOwnedOperationAfter(entryBlock.getBlock(), prevConstantOp, op);
prevConstantOp = op;
return mlirOperationGetResult(op, 0);
}
MlirValue FuncBuilder::lookupTensor(at::Tensor tensor) {
for (auto it = tensorValueMap.rbegin(), e = tensorValueMap.rend(); it != e;
++it) {
if (it->first.is_same(tensor))
return it->second;
}
return {nullptr};
}
MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
// Note that interpreter "scalars" match the Python semantics and are
// represented as one of double or int64_t, with a special tag for whether
// it should be interpreted as a bool.
if (s.isIntegral(/*includeBool=*/false)) {
MlirType t = torchMlirTorchIntTypeGet(context);
MlirAttribute value =
mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), s.to<int64_t>());
MlirOperation op = createMlirOperation(
"torch.constant.int", loc, t, toMlirNamedAttribute("value", value));
insertConstantOp(op);
return mlirOperationGetResult(op, 0);
}
if (s.isFloatingPoint()) {
MlirType t = torchMlirTorchFloatTypeGet(context);
MlirAttribute value = mlirFloatAttrDoubleGet(
context, mlirF64TypeGet(context), s.to<double>());
MlirOperation op = createMlirOperation(
"torch.constant.float", loc, t, toMlirNamedAttribute("value", value));
insertConstantOp(op);
return mlirOperationGetResult(op, 0);
}
if (s.isBoolean()) {
return getBoolConstant(loc, s.to<bool>());
}
// TODO: s.isComplex()
throw std::invalid_argument("TODO: Scalar of unknown kind");
}
MlirValue FuncBuilder::getBoolConstant(MlirLocation loc, bool v) {
return insertConstantOp(OpBuilder(context).createBoolConstant(loc, v));
}
MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) {
return insertConstantOp(OpBuilder(context).createNoneConstant(loc));
}
MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
std::vector<MlirValue> &elements) {
MlirType resultType = torchMlirTorchListTypeGet(elementType);
OperationStateHolder state{"torch.prim.ListConstruct", loc};
mlirOperationStateAddResults(state, 1, &resultType);
mlirOperationStateAddOperands(state, elements.size(), elements.data());
MlirOperation op = state.createOperation();
entryBlock.insertBeforeTerminator(op);
return mlirOperationGetResult(op, 0);
}

@ -1,153 +0,0 @@
//===- func_builder.h -------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_FUNC_BUILDER_H
#define TORCHMLIRPLUGIN_CSRC_BUILDER_FUNC_BUILDER_H
#include "mlir_utils.h"
#include "torch_to_mlir_utils.h"
#include "mlir-c/IR.h"
#include <ATen/Tensor.h>
#include <ATen/core/function_schema.h>
namespace torch_mlir {
/// Wraps an MlirOperationState, deallocating it on destruction unless
/// it has been consumed.
class OperationStateHolder {
public:
OperationStateHolder(const char *name, MlirLocation loc)
: state(mlirOperationStateGet(toMlirStringRef(name), loc)) {}
OperationStateHolder(const OperationStateHolder &) = delete;
OperationStateHolder(OperationStateHolder &&other) = delete;
~OperationStateHolder() {
if (owned) {
// Destroying is done by creating and then destroying the operation.
mlirOperationDestroy(createOperation());
}
}
operator MlirOperationState *() { return &state; }
MlirOperation createOperation() {
assert(owned && "cannot createOperation on unowned state");
owned = false;
return mlirOperationCreate(&state);
}
private:
MlirOperationState state;
bool owned = true;
};
/// Wraps an MlirBlock under construction, primarily tracking the terminator
/// and supporting manipulation of it. The terminator may be null if it has
/// not yet been constructed.
class BlockBuilder {
public:
BlockBuilder(MlirBlock block, MlirOperation terminator, bool isReturn)
: block(block), terminator(terminator), isReturn(isReturn) {}
MlirBlock getBlock() { return block; }
MlirOperation getTerminator() { return terminator; }
bool getIsReturnTerminator() { return isReturn; }
/// Inserts an owned operation before the terminator.
void insertBeforeTerminator(MlirOperation op) {
mlirBlockInsertOwnedOperationBefore(block, terminator, op);
}
private:
MlirBlock block;
MlirOperation terminator;
bool isReturn;
};
/// Wraps a 'func' MlirOperation and provides facilities for constructing
/// IR from some stream of Torch operations.
class FuncBuilder {
public:
/// Callback for inserting a function.
using Inserter = std::function<void(MlirOperation funcOp)>;
/// Creates a new func op with the given characteristics. The created
/// operation is not attached. The caller must either destroy it or add it
/// to a parent.
static std::unique_ptr<FuncBuilder>
createFunction(Inserter &inserter, MlirLocation location,
const std::string &name, std::vector<MlirType> &inputTypes);
MlirContext getContext() { return context; }
MlirOperation getFuncOp() { return funcOp; }
/// Gets the function's entry block.
MlirBlock getEntryBlock() { return entryBlock.getBlock(); }
BlockBuilder &getEntryBlockBuilder() { return entryBlock; }
/// Rewrites the function's signature to return the given types. It is
/// assumed that a compatible terminator has been added.
void rewriteFuncReturnTypes(std::vector<MlirType> &resultTypes);
/// Maps a live Tensor to an MlirValue.
void mapTensor(at::Tensor tensor, MlirValue value) {
tensorValueMap.push_back(std::make_pair(tensor, value));
}
/// Looks up a current mapping of tensor to an MlirValue, returning a null
/// value if not found.
MlirValue lookupTensor(at::Tensor tensor);
/// Gets a scalar constant value.
MlirValue getScalarConstant(MlirLocation loc, at::Scalar s);
/// Gets a bool constant value.
MlirValue getBoolConstant(MlirLocation loc, bool v);
/// Gets a None constant value.
MlirValue getNoneConstant(MlirLocation loc);
/// Builds a list with the given elements
MlirValue buildList(MlirLocation loc, MlirType elementType,
std::vector<MlirValue> &elements);
private:
FuncBuilder(MlirContext context, MlirOperation funcOp,
BlockBuilder entryBlock)
: context(context), funcOp(funcOp), entryBlock(std::move(entryBlock)) {
(void)this->context;
}
/// Inserts a constant op into the function, returning its first result.
/// The function maintains a constant area at the top where these are
/// inserted.
MlirValue insertConstantOp(MlirOperation op);
MlirContext context;
/// The func op under construction.
MlirOperation funcOp;
/// Block builder for the entry block.
BlockBuilder entryBlock;
/// Previously inserted constant op or null.
MlirOperation prevConstantOp = {nullptr};
/// Maps tensors to MlirValue. Unfortunately, this needs to be a linear scan
/// because the impl pointer for the Tensor is not accessible. To make this
/// slightly better, we add to the back and look up in reverse under the idea
/// that tensors may be mapped and accessed in proximity.
/// TODO: Tensors referenced via an IValue support hash code lookup and
/// identity checks. Switch to this instead of a linear scan.
std::vector<std::pair<at::Tensor, MlirValue>> tensorValueMap;
};
} // namespace torch_mlir
#endif // TORCHMLIRPLUGIN_CSRC_BUILDER_FUNC_BUILDER_H

@ -11,7 +11,6 @@
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "node_importer.h"
#include "mlir-c/IR.h"

@ -8,6 +8,7 @@
#include "ivalue_importer.h"
#include "class_annotator.h"
#include "function_importer.h"
#include "torch_to_mlir_utils.h"
#include <unordered_map>

@ -11,7 +11,6 @@
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "class_annotator.h"
#include "mlir-c/IR.h"

@ -110,7 +110,7 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
context(castPythonObjectToMlirContext(this->contextObj)),
module(createEmptyModule(this->context)),
moduleObj(castMlirModuleToPythonObject(module)),
unknownLoc(mlirLocationUnknownGet(context)), typeMapper(this->context) {
unknownLoc(mlirLocationUnknownGet(context)) {
// TODO: Rework this once dialect registration C-APIs are in place.
// https://reviews.llvm.org/D88162
mlirRegisterAllDialects(context);
@ -122,30 +122,6 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
terminator = mlirBlockGetFirstOperation(getBodyBlock());
}
std::shared_ptr<AcapController>
ModuleBuilder::startCaptureFunction(std::string &name,
std::vector<at::Tensor> args) {
// TODO: Verify that arguments do not alias each other.
std::vector<MlirType> inputTypes;
for (auto &arg : args) {
inputTypes.push_back(typeMapper.forwardTensorToType(arg));
}
// TODO: Extract a traceback and use in place of unknownLoc.
auto inserter = createInserter();
auto funcBuilder =
FuncBuilder::createFunction(inserter, unknownLoc, name, inputTypes);
// Map block arguments.
MlirBlock entryBlock = funcBuilder->getEntryBlock();
assert(mlirBlockGetNumArguments(entryBlock) ==
static_cast<intptr_t>(args.size()) &&
"entry block incorrect arg arity");
for (size_t i = 0; i < args.size(); ++i) {
funcBuilder->mapTensor(args[i], mlirBlockGetArgument(entryBlock, i));
}
return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
}
torch::jit::StrongFunctionPtr
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
MlirBlock block = getBodyBlock();
@ -166,14 +142,6 @@ void ModuleBuilder::importModule(torch::jit::Module jitModule,
mlirModuleGetContext(module), *classAnnotator);
}
FuncBuilder::Inserter ModuleBuilder::createInserter() {
MlirBlock block = getBodyBlock();
MlirOperation terminator = this->terminator;
return [=](MlirOperation op) {
mlirBlockInsertOwnedOperationBefore(block, terminator, op);
};
}
MlirBlock ModuleBuilder::getBodyBlock() {
MlirOperation moduleOp = mlirModuleGetOperation(module);
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
@ -184,8 +152,6 @@ void ModuleBuilder::bind(py::module &m) {
.def(py::init<py::object>(), py::arg("context") = py::none())
.def_property_readonly("context", &ModuleBuilder::getContextObj)
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
.def("capture_function", &ModuleBuilder::startCaptureFunction,
py::keep_alive<0, 1>())
.def("import_function", &ModuleBuilder::importFunction)
.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
py::arg("classAnnotator") = py::none());

@ -10,7 +10,6 @@
#include "../pybind.h"
#include "acap_dispatch.h"
#include "class_annotator.h"
#include "mlir-c/IR.h"
@ -34,10 +33,6 @@ public:
pybind11::object getContextObj() { return contextObj; }
pybind11::object getModuleObj() { return moduleObj; }
// Starts a device-capture based function.
std::shared_ptr<AcapController>
startCaptureFunction(std::string &name, std::vector<at::Tensor> args);
// Imports a traced function. Note that the python type
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
// Just a bit of naming cruft.
@ -52,7 +47,6 @@ public:
py::object maybeClassAnnotator);
private:
FuncBuilder::Inserter createInserter();
MlirBlock getBodyBlock();
// Capture references to the python-owned context and module. Ownership
@ -64,8 +58,6 @@ private:
pybind11::object moduleObj;
MlirOperation terminator;
MlirLocation unknownLoc;
TypeMapper typeMapper;
};
} // namespace torch_mlir

@ -11,7 +11,6 @@
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "mlir-c/IR.h"

@ -6,12 +6,10 @@
//===----------------------------------------------------------------------===//
#include "../pybind.h"
#include "debug.h"
#include <ATen/core/dispatch/Dispatcher.h>
#include "../init_python_bindings.h"
#include "acap_dispatch.h"
#include "module_builder.h"
#include "class_annotator.h"
@ -125,16 +123,7 @@ py::list GetRegisteredOps() {
} // namespace
void torch_mlir::InitBuilderBindings(py::module &m) {
m.def("debug_trace_to_stderr", &enableDebugTraceToStderr);
py::class_<AcapController, std::shared_ptr<AcapController>>(m,
"AcapController")
.def("__enter__", &AcapController::contextEnter)
.def("__exit__", &AcapController::contextExit)
.def("returns", &AcapController::returns);
m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
ModuleBuilder::bind(m);
initClassAnnotatorBindings(m);
}

@ -5,6 +5,7 @@
//
//===----------------------------------------------------------------------===//
#include "torch_to_mlir_utils.h"
#include "function_importer.h"
#include "ivalue_importer.h"

@ -1,18 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
torch_mlir.debug_trace_to_stderr()
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("arange_test", []) as f:
x = torch.arange(10)
f.returns([x])
# CHECK: %[[T:.*]] = torch.tensor.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi64>) : !torch.tensor<[10],si64>
# CHECK: return %[[T]]
mb.module.operation.print()

@ -1,52 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
# XFAIL: *
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch_mlir
torch_mlir.debug_trace_to_stderr()
N = 3
Cin = 16
Cout = 4
w = 10
h = 10
class Net(nn.Module):
def __init__(self, Cin, Cout):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(Cin, Cout, (3,3))
def forward(self, x):
x = self.conv1(x)
output = F.log_softmax(x, dim=1)
return output
model = Net(Cin, Cout)
inputs = torch.ones((N,Cin,h,w))
loss = torch.nn.NLLLoss()
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, Cout)
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("resa", [inputs, target]) as f:
result = loss(model(inputs), target)
result.backward()
f.returns([result] + [p.grad for p in model.parameters()])
# CHECK: torch.operator "aten.convolution"
# CHECK: torch.operator "aten._log_softmax"
# CHECK: %[[FWD:.*]]:2 = torch.operator "aten.nll_loss2d_forward"
# CHECK: torch.operator "aten.nll_loss2d_backward"
# CHECK: torch.operator "aten._log_softmax_backward_data"
# CHECK: %[[BWD_CONV:.*]]:3 = torch.operator "aten.convolution_backward_overrideable"
# CHECK: %[[BWD_CONV_WEIGHTS:.*]] = aten.copy_{{.*}}%[[BWD_CONV]]#1
# CHECK: %[[BWD_CONV_BIAS:.*]] = aten.copy_{{.*}}%[[BWD_CONV]]#2
# CHECK: return %[[FWD]]#0, %[[BWD_CONV_WEIGHTS]], %[[BWD_CONV_BIAS]]
mb.module.operation.print(large_elements_limit=2)

@ -1,33 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
t0 = torch.randn((1,2,3,4))
t1 = torch.randn((1,2,3,4))
t2 = torch.randn((1,2,3,4))
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("add3", [t0, t1, t2]) as f:
t3 = t0 + t1 + t2
f.returns([t3])
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
# CHECK-LABEL: func @add3(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.tensor<[1,2,3,4],f32>, %[[VAL_1:.*]]: !torch.tensor<[1,2,3,4],f32>,
# CHECK-SAME: %[[VAL_2:.*]]: !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32> {
# CHECK: %[[VAL_3:.*]] = torch.constant.int 1
# CHECK: %[[VAL_4:.*]] = torch.constant.int 1
# CHECK: %[[VAL_5:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_6:.*]] = torch.operator "aten.add.out"(%[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_5]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, !torch.int, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_7:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_8:.*]] = torch.operator "aten.add.out"(%[[VAL_6]], %[[VAL_2]], %[[VAL_4]], %[[VAL_7]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, !torch.int, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
# CHECK: return %[[VAL_8]] : !torch.tensor<[1,2,3,4],f32>
# CHECK: }
print(mb.module)

@ -1,25 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
ones = torch.ones(42,123,4,5)
with mb.capture_function("bn2d", [ones]) as f:
model = torch.nn.BatchNorm2d(123)
result = model(ones)
f.returns([result])
# TODO: This test exercises promotion of const to arrays, inplace zero_ and
# add, all of which should be checked individually because they have specific
# behavior.
# CHECK-LABEL: @bn2d
# CHECK: %[[RESULT:.*]]:3 = torch.operator "aten.native_batch_norm"(%arg0
# CHECK: return %[[RESULT]]#0 : !torch.tensor<[42,123,4,5],f32>
print(mb.module)

@ -1,47 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_mlir
torch_mlir.debug_trace_to_stderr()
N = 3
Cin = 16
Cout = 4
w = 10
h = 10
class Net(nn.Module):
def __init__(self, Cin, Cout):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(Cin, Cout, (3,3))
def forward(self, x):
x0 = self.conv1(x)
x1 = self.conv1(x)
z = torch.cat([x0, x1])
output = F.log_softmax(z, dim=1)
return output
model = Net(Cin, Cout)
inputs = torch.ones((N,Cin,h,w))
loss = torch.nn.NLLLoss()
target = torch.empty(2*N, 8, 8, dtype=torch.long).random_(0, Cout)
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("conv_cat", [inputs, target]) as f:
result = loss(model(inputs), target)
f.returns([result])
# CHECK: "aten.convolution"
# CHECK: "aten.convolution"
# CHECK: torch.prim.ListConstruct
# CHECK: "aten._cat"
# CHECK: "aten._log_softmax.out"
# CHECK: "aten.nll_loss2d_forward"
mb.module.operation.print(large_elements_limit=2)

@ -1,56 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
N = 3
Cin = 16
Cout = 4
w = 10
h = 10
model = torch.nn.Conv2d(Cin, Cout, (3,3))
ref_model = torch.nn.Conv2d(Cin, Cout, (3,3))
ref_model.weight.data = model.weight.clone()
ref_model.bias.data = model.bias.clone()
softmax = torch.nn.LogSoftmax(dim=1)
loss = torch.nn.NLLLoss()
tensor = torch.randn(N, Cin, h, w)
with mb.capture_function("conv2d_fwd", [tensor]) as f:
result = model(tensor)
f.returns([result])
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
# CHECK-LABEL: func @conv2d_fwd(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.tensor<[3,16,10,10],f32>) -> !torch.tensor<[3,4,8,8],f32> {
# CHECK: %[[VAL_1:.*]] = torch.constant.int 1
# CHECK: %[[VAL_2:.*]] = torch.constant.int 1
# CHECK: %[[VAL_3:.*]] = torch.constant.int 0
# CHECK: %[[VAL_4:.*]] = torch.constant.int 0
# CHECK: %[[VAL_5:.*]] = torch.constant.int 1
# CHECK: %[[VAL_6:.*]] = torch.constant.int 1
# CHECK: %[[VAL_7:.*]] = torch.constant.bool false
# CHECK: %[[VAL_8:.*]] = torch.constant.int 0
# CHECK: %[[VAL_9:.*]] = torch.constant.int 0
# CHECK: %[[VAL_10:.*]] = torch.constant.int 1
# CHECK: %[[VAL_11:.*]] = torch.tensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
# CHECK: %[[VAL_12:.*]] = torch.tensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
# CHECK: %[[VAL_13:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_14:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_15:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_9]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_17:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_7]], %[[VAL_16]], %[[VAL_10]]) : (!torch.tensor<[3,16,10,10],f32>, !torch.tensor<[4,16,3,3],f32>, !torch.tensor<[4],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool, !torch.list<!torch.int>, !torch.int) -> !torch.tensor<[3,4,8,8],f32>
# CHECK: return %[[VAL_17]] : !torch.tensor<[3,4,8,8],f32>
# CHECK: }
mb.module.operation.print(large_elements_limit=2)

@ -1,28 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
t0 = torch.randn(4)
t1 = torch.randn(4)
t2 = torch.randn(4)
with mb.capture_function("multi_output", [t0, t1, t2]) as f:
t4 = t0 + t1 + t2
t5 = t4 + t1
t6 = t5 + t4
f.returns([t4, t5, t6])
# CHECK-LABEL: func @multi_output
# CHECK: %[[ADD0:.*]] = torch.operator "aten.add.out"(%arg0
# CHECK: %[[ADD1:.*]] = torch.operator "aten.add.out"(%[[ADD0]]
# CHECK: %[[ADD2:.*]] = torch.operator "aten.add.out"(%[[ADD1]]
# CHECK: %[[ADD3:.*]] = torch.operator "aten.add.out"(%[[ADD2]]
# CHECK: return %[[ADD1]], %[[ADD2]], %[[ADD3]]
print(mb.module)

@ -1,30 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE for license information.
# RUN: %PYTHON %s | torch-mlir-opt -aten-recognize-kernels -numpy-public-functions-to-tensor -canonicalize | FileCheck %s
# TODO: Re-enable after adding support for 4-operand aten::add in `aten-recognize-kernels`.
# XFAIL: *
# TODO: This test should go away or become part of an e2e test suite. It is
# preserved right now as a stop-gap.
import torch
import torch_mlir
t0 = torch.randn((1,4))
t1 = torch.randn((4,1))
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("foobar", [t0, t1]) as f:
result = t0 + t1
f.returns([result])
# CHECK-LABEL: func @foobar(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4xf32>,
# CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xf32>) -> tensor<4x4xf32> {
# CHECK: %[[VAL_2:.*]] = torch.constant.int 1
# CHECK: %[[VAL_3:.*]] = "aten.add"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<1x4xf32>, tensor<4x1xf32>, !torch.int) -> tensor<4x4xf32>
# CHECK: return %[[VAL_3]] : tensor<4x4xf32>
# CHECK: }
mb.module.operation.print(large_elements_limit=2)